# Pruning Models to Mitigate Backdoor Attacks

## Importing libraries, cloning the GitHub repo, and downloading datasets

In [None]:
#import libraries
import numpy as np
import matplotlib 
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import tensorflow as tf
from tensorflow import keras
from keras import models
import h5py
import matplotlib.image as mpimg
import imageio as im

Clone the CSAW-HackML-2020 repository, which provides the backdoored models used below.

In [None]:
# Clone the CSAW-HackML-2020 repo into the current directory; later cells assume
# it lives at /content/CSAW-HackML-2020 and read its models/ and data/ folders.
# git refuses to clone into an existing non-empty directory, so re-running this
# cell after a successful clone reports an error.
! git clone https://github.com/csaw-hackml/CSAW-HackML-2020.git

Cloning into 'CSAW-HackML-2020'...
remote: Enumerating objects: 220, done.[K
remote: Counting objects: 100% (56/56), done.[K
remote: Compressing objects: 100% (52/52), done.[K
remote: Total 220 (delta 27), reused 2 (delta 0), pack-reused 164[K
Receiving objects: 100% (220/220), 85.94 MiB | 28.83 MiB/s, done.
Resolving deltas: 100% (82/82), done.


Download the datasets from Google Drive

In [None]:
# Work inside the cloned repo's data/ directory so the downloaded .h5 files land
# where the dataset paths used later point (/content/CSAW-HackML-2020/data/...).
%cd /content/CSAW-HackML-2020/data
# NOTE(review): this creates a nested data/data directory that no later cell
# reads; it looks removable — confirm nothing else depends on it.
%mkdir data

#download datasets
# Each id is a Google Drive file; the target file names below are taken from the
# cell output.
# -> clean_validation_data.h5
!gdown --id 19OKCkY2CjV3ASkOe6nMSYTsOVcxAoCnA
# -> clean_test_data.h5
!gdown --id 1XtYnM-IopU-QYVc99U51EiDvI5zxK0nV
# -> sunglasses_poisoned_data.h5
!gdown --id 1P8PTL62x3cfpV9mrC0unqZjRFhlTTOSG
# -> anonymous_1_poisoned_data.h5
!gdown --id 1XFKaTse6gflUFK7lDPxXBUaq4oQA8-qy
# -> lipstick_poisoned_data.h5
!gdown --id 1TiBviHoi-nh-aDRCP-1ZQlP0Nis6wOCw
# NOTE(review): output for this id is truncated; presumably it is
# eyebrows_poisoned_data.h5, which the data-loading cell reads — confirm.
!gdown --id 1SrObV38DPLgsMfpPYTdeX7nzjrEUAEwW

/content/CSAW-HackML-2020/data
Downloading...
From: https://drive.google.com/uc?id=19OKCkY2CjV3ASkOe6nMSYTsOVcxAoCnA
To: /content/CSAW-HackML-2020/data/clean_validation_data.h5
100% 716M/716M [00:02<00:00, 302MB/s]
Downloading...
From: https://drive.google.com/uc?id=1XtYnM-IopU-QYVc99U51EiDvI5zxK0nV
To: /content/CSAW-HackML-2020/data/clean_test_data.h5
100% 398M/398M [00:01<00:00, 262MB/s]
Downloading...
From: https://drive.google.com/uc?id=1P8PTL62x3cfpV9mrC0unqZjRFhlTTOSG
To: /content/CSAW-HackML-2020/data/sunglasses_poisoned_data.h5
100% 398M/398M [00:02<00:00, 141MB/s]
Downloading...
From: https://drive.google.com/uc?id=1XFKaTse6gflUFK7lDPxXBUaq4oQA8-qy
To: /content/CSAW-HackML-2020/data/anonymous_1_poisoned_data.h5
100% 637M/637M [00:06<00:00, 96.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=1TiBviHoi-nh-aDRCP-1ZQlP0Nis6wOCw
To: /content/CSAW-HackML-2020/data/lipstick_poisoned_data.h5
100% 637M/637M [00:03<00:00, 198MB/s]
Downloading...
From: https://drive.google.com/

## The BadNets

In this section, we do the following:


1.   Load the badnets
2.   Load the data
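3.   Compute each channel's mean activation in the convolutional layer to be pruned (`layers[5]`) on clean validation data, and sort the channels in ascending order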



### Load the models

In [None]:
# A bare %cd switches to the home directory (/root in the output below); the
# next line then moves to /content, the directory that the relative save paths
# in later cells resolve against.
%cd 
%cd /content

# Each backdoored network in the cloned repo's models/ folder comes as a pair:
# a saved model file (<name>_bd_net.h5) and a weights file (<name>_bd_weights.h5).
# The two lists below are index-aligned: entry i of each belongs to the same net.
MODEL_DIR = '/content/CSAW-HackML-2020/models'
MODEL_NAMES = ['anonymous_1', 'anonymous_2', 'multi_trigger_multi_target', 'sunglasses']

badNetPaths = [f'{MODEL_DIR}/{name}_bd_net.h5' for name in MODEL_NAMES]
weightsPaths = [f'{MODEL_DIR}/{name}_bd_weights.h5' for name in MODEL_NAMES]

def CreateModel(badnet,weights):
  """Load a saved Keras model, overwrite its weights, and compile it.

  Args:
    badnet: path passed to `keras.models.load_model`.
    weights: path passed to `Model.load_weights`, applied after loading.

  Returns:
    The compiled model (Adam optimizer, sparse categorical cross-entropy loss,
    accuracy metric).
  """
  # NOTE(review): from_logits=True only gives a correct loss value if the
  # model's last layer emits raw logits; if it ends in a softmax, the reported
  # loss is off (accuracy is unaffected) — confirm the architecture.
  lossFn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  net = keras.models.load_model(badnet)
  net.load_weights(weights)
  net.compile(optimizer='adam', loss=lossFn, metrics=['accuracy'])
  return net

# Build and compile the four backdoored networks, pairing each model file with
# its weights file in list order (A1..A4 = anonymous_1, anonymous_2,
# multi_trigger_multi_target, sunglasses).
BadNetA1, BadNetA2, BadNetA3, BadNetA4 = [
    CreateModel(netPath, weightPath)
    for netPath, weightPath in zip(badNetPaths, weightsPaths)
]

/root
/content


### Load the data

In [None]:
# Load one HDF5 dataset (used below for the clean and poisoned sets alike)
def loadData(filePath):
  """Read images and labels from an HDF5 file with 'data' and 'label' datasets.

  The file is opened read-only and closed before returning.

  Args:
    filePath: path to the .h5 file.

  Returns:
    (x, y): x is the 'data' array with axis 1 moved to the end (presumably
    channels-first (N, C, H, W) -> channels-last (N, H, W, C); confirm against
    the dataset) and divided by 255; y is the 'label' array as stored.
  """
  # The context manager guarantees the file handle is released; previously the
  # h5py.File object was left open for every dataset loaded.
  with h5py.File(filePath, 'r') as data:
    x = np.array(data['data'])
    y = np.array(data['label'])

  x = x.transpose((0, 2, 3, 1))
  x = x / 255

  return x, y

#set paths for all datasets
valCleanPath = '/content/CSAW-HackML-2020/data/clean_validation_data.h5'
testCleanPath = '/content/CSAW-HackML-2020/data/clean_test_data.h5'
valBadPaths = ['/content/CSAW-HackML-2020/data/sunglasses_poisoned_data.h5',
               '/content/CSAW-HackML-2020/data/lipstick_poisoned_data.h5',
               '/content/CSAW-HackML-2020/data/eyebrows_poisoned_data.h5',
               '/content/CSAW-HackML-2020/data/anonymous_1_poisoned_data.h5']

#load data
valCleanX, valCleanY = loadData(valCleanPath)
testCleanX, testCleanY = loadData(testCleanPath)
valBadSunGlassesX, valBadSunGlassesY = loadData(valBadPaths[0])
valBadLipstickX, valBadLipstickY = loadData(valBadPaths[1])
valBadEyebrowsX, valBadEyeBrowsY = loadData(valBadPaths[2])
# These are the anonymous_1 poisoned images and labels, so both halves are named A1.
valBadA1X, valBadA1Y = loadData(valBadPaths[3])
# Backward-compatible alias: later cells still refer to these labels as
# valBadA2Y, although they are NOT anonymous_2 data.
valBadA2Y = valBadA1Y

### Computing Activations

In [None]:
def ComputeActivation(badNet,valCleanX,layerIndex=5):
  """Rank the channels of one layer of `badNet` by mean activation on `valCleanX`.

  The output of `badNet.layers[layerIndex]` is computed for every input and
  averaged over every axis except the last (channel) axis, giving one mean
  activation per channel. For a 4-D (images, height, width, channels) output
  this is the mean over images, height and width. Unlike a fixed-size buffer,
  this works for any layer shape.

  Args:
    badNet: Keras model to inspect.
    valCleanX: inputs to run through the model (the notebook passes the clean
      validation images).
    layerIndex: index into `badNet.layers` of the layer to rank. Defaults to 5,
      the layer index the pruning step edits.

  Returns:
    List of `(channelIndex, meanActivation)` tuples in ascending order of mean
    activation, so the least-active channels come first. Ties keep ascending
    channel order because `sorted` is stable.
  """
  activationModel = models.Model(inputs=badNet.input, outputs=badNet.layers[layerIndex].output)
  # Average in float64 to limit rounding error when averaging many images.
  layerActivations = np.asarray(activationModel.predict(valCleanX), dtype=np.float64)

  nonChannelAxes = tuple(range(layerActivations.ndim - 1))
  channelMeans = layerActivations.mean(axis=nonChannelAxes)

  return sorted(enumerate(channelMeans), key=lambda pair: pair[1])

# Channel ranking (least active first) for each badnet, all measured on the
# same clean validation images.
sortedActA1, sortedActA2, sortedActA3, sortedActA4 = [
    ComputeActivation(net, valCleanX)
    for net in (BadNetA1, BadNetA2, BadNetA3, BadNetA4)
]

## Pruning the Badnet: Creating Repaired Networks

In [None]:
#let us now finally get to pruning
#we will prune on validation data
def pruneModel(targetAcc,RepairedModel,valCleanX,valCleanY,testBadX,testBadY,sortedActivations,
               layerIndex=5,copyModel=True):
  """Prune one layer's channels, least active first, until clean accuracy drops by `targetAcc`.

  Channels are removed in the order given by `sortedActivations` by zeroing
  their kernel slice and bias in layer `layerIndex`. After each removal the
  model is re-evaluated on the clean validation data. Pruning stops as soon as
  the accuracy drop relative to the unpruned model reaches `targetAcc`, or
  when every channel has been pruned.

  Args:
    targetAcc: accuracy drop that ends pruning, as a fraction (e.g. 0.02,
      0.04, 0.10; the notebook calls it with 0.15).
    RepairedModel: compiled Keras model to prune.
    valCleanX, valCleanY: clean data used to measure the accuracy drop.
    testBadX, testBadY: poisoned data scored once pruning stops.
    sortedActivations: `(channelIndex, activation)` pairs, least active first,
      e.g. as returned by `ComputeActivation`.
    layerIndex: index into `RepairedModel.layers` of the layer to prune. Its
      weights must be `[kernel, bias]` with output channels on the kernel's
      last axis.
    copyModel: if True (default), prune a clone so the model passed in keeps
      its original weights; if False, prune `RepairedModel` in place.

  Returns:
    (pruned model, fraction of the layer's channels pruned, clean validation
    accuracy after pruning, accuracy on the poisoned data). The last value is
    reported as the attack success rate; that assumes `testBadY` holds the
    attacker's target labels — confirm against the dataset.
  """
  if copyModel:
    # Pruning the caller's object in place would silently alter the badnet
    # variable it came from, so work on an independent copy.
    sourceModel = RepairedModel
    RepairedModel = keras.models.clone_model(sourceModel)
    RepairedModel.set_weights(sourceModel.get_weights())
    RepairedModel.compile(optimizer='adam', loss=sourceModel.loss, metrics=['accuracy'])

  layer = RepairedModel.layers[layerIndex]
  weights, biases = layer.get_weights()
  numChannels = weights.shape[-1]

  _, baseAcc = RepairedModel.evaluate(valCleanX,valCleanY,verbose=2)
  repAcc = baseAcc
  run = 0
  for index, _ in sortedActivations[:numChannels]:
    weights[..., index] = 0  # zero every kernel weight feeding this output channel
    biases[index] = 0
    layer.set_weights([weights, biases])
    _, repAcc = RepairedModel.evaluate(valCleanX,valCleanY,verbose=2)
    run += 1

    if baseAcc - repAcc >= targetAcc:
      break

  chanelsPrunedFraction = run / numChannels
  _, attackSuccess = RepairedModel.evaluate(testBadX, testBadY, verbose=0)

  return RepairedModel, chanelsPrunedFraction, repAcc, attackSuccess

In [None]:
# Prune each badnet until its clean validation accuracy has dropped by at least
# PRUNE_ACC_DROP, then score it on a poisoned set.
# NOTE(review): no anonymous_2 poisoned set is loaded, so BadNetA2 is scored on
# the anonymous_1 poisoned data and its attack success rate (0.0 in the output)
# is likely not meaningful. BadNetA3 (multi_trigger_multi_target) is scored on
# the eyebrows trigger only.
PRUNE_ACC_DROP = 0.15

RepairedNetA1, chanelsPrunedFractionA1, repAccA1, attackSuccessA1 = pruneModel(
    targetAcc=PRUNE_ACC_DROP, RepairedModel=BadNetA1,
    valCleanX=valCleanX, valCleanY=valCleanY,
    testBadX=valBadA1X, testBadY=valBadA2Y,
    sortedActivations=sortedActA1)

RepairedNetA2, chanelsPrunedFractionA2, repAccA2, attackSuccessA2 = pruneModel(
    targetAcc=PRUNE_ACC_DROP, RepairedModel=BadNetA2,
    valCleanX=valCleanX, valCleanY=valCleanY,
    testBadX=valBadA1X, testBadY=valBadA2Y,
    sortedActivations=sortedActA2)

RepairedNetA3, chanelsPrunedFractionA3, repAccA3, attackSuccessA3 = pruneModel(
    targetAcc=PRUNE_ACC_DROP, RepairedModel=BadNetA3,
    valCleanX=valCleanX, valCleanY=valCleanY,
    testBadX=valBadEyebrowsX, testBadY=valBadEyeBrowsY,
    sortedActivations=sortedActA3)

RepairedNetA4, chanelsPrunedFractionA4, repAccA4, attackSuccessA4 = pruneModel(
    targetAcc=PRUNE_ACC_DROP, RepairedModel=BadNetA4,
    valCleanX=valCleanX, valCleanY=valCleanY,
    testBadX=valBadSunGlassesX, testBadY=valBadSunGlassesY,
    sortedActivations=sortedActA4)

Now, we save the repaired models

In [None]:
# Persist each repaired network under its own variable name, relative to the
# current working directory.
for _saveName, _repairedNet in (('RepairedNetA1', RepairedNetA1),
                                ('RepairedNetA2', RepairedNetA2),
                                ('RepairedNetA3', RepairedNetA3),
                                ('RepairedNetA4', RepairedNetA4)):
  _repairedNet.save(_saveName)

In addition, we also save the badnets for reference.

In [None]:
# Save the *unpruned* badnets for reference. They are rebuilt from the original
# .h5 files rather than saving BadNetA1..A4 directly: the pruning step sets
# layer weights on a Keras model object, and if that object was one of
# BadNetA1..A4 itself, saving those variables would store the pruned networks.
for _saveName, _netPath, _weightPath in zip(['BadNetA1', 'BadNetA2', 'BadNetA3', 'BadNetA4'],
                                            badNetPaths, weightsPaths):
  CreateModel(_netPath, _weightPath).save(_saveName)

In [None]:
# One report section per repaired network: header, then the pruned fraction,
# clean validation accuracy and attack success rate, then a separator line.
_metricReports = [
  ('For our A1 repaired model, we have the following metrics:', chanelsPrunedFractionA1, repAccA1, attackSuccessA1),
  ('For our A2 repaired model, we have the following metrics:', chanelsPrunedFractionA2, repAccA2, attackSuccessA2),
  ('For our A3 repaired model, we have the following metrics:', chanelsPrunedFractionA3, repAccA3, attackSuccessA3),
  ('For our A4 repaired model, we have the following metrics:', chanelsPrunedFractionA4, repAccA4, attackSuccessA4),
]
for _header, _fraction, _cleanAcc, _attackRate in _metricReports:
  print(_header)
  print('Fraction of chanels pruned: ',_fraction)
  print('Accuracy on the clean validation dataset: ', _cleanAcc)
  print('Attack success rate: ',_attackRate)
  print('------------------------------------------------------')

For our A1 repaired model, we have the following metrics:
Fraction of chanels pruned:  0.5666666666666667
Accuracy on the clean validation dataset:  0.8118125796318054
Attack success rate:  0.6192517280578613
------------------------------------------------------
For our A2 repaired model, we have the following metrics:
Fraction of chanels pruned:  0.5166666666666667
Accuracy on the clean validation dataset:  0.7880834937095642
Attack success rate:  0.0
------------------------------------------------------
For our A3 repaired model, we have the following metrics:
Fraction of chanels pruned:  0.55
Accuracy on the clean validation dataset:  0.803412139415741
Attack success rate:  0.8066056370735168
------------------------------------------------------
For our A4 repaired model, we have the following metrics:
Fraction of chanels pruned:  0.5833333333333334
Accuracy on the clean validation dataset:  0.82792067527771
Attack success rate:  0.9855027198791504
-------------------------------

## GoodNet

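For the GoodNet, we feed the test image to both the BadNet and the repaired network and compare their predicted labels. If the two predictions agree, that label is returned; if they differ, the input is treated as backdoored and assigned the extra label N+1 (class index N, which the code sets to 1283).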

In [None]:
from PIL import Image

# Fill these in before calling Goodnet: the repaired model, the matching
# BadNet, and the image to classify. All start out empty.
repairedNetPath = badNetPath = testImagePath = ''

def imagePreProcess(imagePath):
  """Load one image file as a batch-of-one float32 tensor scaled to [0, 1].

  The image is converted to 3-channel RGB, so RGBA and greyscale files produce
  the same channel count as RGB ones. The file is closed before returning.

  Args:
    imagePath: path of the image file to read.

  Returns:
    Tensor of shape (1, height, width, 3) with pixel values divided by 255.
    NOTE(review): the image is not resized; it must already match the
    networks' input size.
  """
  with Image.open(imagePath) as image:
    # convert('RGB') forces the pixel data to be read while the file is open.
    imageArray = np.asarray(image.convert('RGB'), dtype=np.float32)

  imageTensor = tf.expand_dims(tf.convert_to_tensor(imageArray), 0)
  return imageTensor / 255

def _loadIfPath(modelOrPath):
  """Return `modelOrPath` as-is if it is already a model, else load it with Keras."""
  if isinstance(modelOrPath, (str, bytes, os.PathLike)):
    return keras.models.load_model(modelOrPath)
  return modelOrPath

def Goodnet(testImagePath,badNetPath,repairedNetPath,numClasses=1283):
  """Classify one image with a BadNet / repaired-network pair.

  Both networks predict a label for the image. If they agree, that label is
  returned. If they disagree, the input is flagged as backdoored and the extra
  label `numClasses` (the "N+1"-th class, index N) is returned.

  Args:
    testImagePath: path of the image to classify.
    badNetPath: path to the BadNet for `keras.models.load_model`, or an
      already-loaded model (avoids reloading it on every call).
    repairedNetPath: path to, or already-loaded, repaired network.
    numClasses: N, the label returned on disagreement. Defaults to 1283, the
      value this function previously hard-coded; pass the badnet's number of
      output classes if it differs.

  Returns:
    The agreed class index, or `numClasses` when the networks disagree.
  """
  testImage = imagePreProcess(testImagePath)
  repairedNet = _loadIfPath(repairedNetPath)
  badNet = _loadIfPath(badNetPath)

  predictionRNet = np.argmax(repairedNet.predict(testImage))
  predictionBNet = np.argmax(badNet.predict(testImage))

  if predictionRNet == predictionBNet:
    return predictionRNet
  return numClasses

In [None]:
!mkdir MyModels
!zip -r /content/MyModels.zip /content/MyModels