In [2]:
from tensorflow import keras 
import sys
import h5py
import numpy as np
clean_data_filename = './drive/MyDrive/Lab3/cl/valid.h5'
clean_data_filename_test = './drive/MyDrive/Lab3/cl/test.h5'
poisoned_data_filename = './drive/MyDrive/Lab3/bd/bd_valid.h5'
model_filename = 'bd_net.h5'

def data_loader(filepath):
    data = h5py.File(filepath, 'r')
    x_data = np.array(data['data'])
    y_data = np.array(data['label'])
    x_data = x_data.transpose((0,2,3,1))
    return x_data, y_data

cl_x_valid, cl_y_valid = data_loader(clean_data_filename)
cl_x_test, cl_y_test = data_loader(clean_data_filename_test)
bd_x_valid, bd_y_valid = data_loader(poisoned_data_filename)

In [3]:
# import tensorflow
# tensorflow.config.experimental.list_physical_devices('GPU')

In [4]:
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tqdm import tqdm

def get_intermediate(model, layer_ind):
    intermediate_layer_model = keras.models.Model(inputs=model.input,
                                 outputs=model.layers[layer_ind].output)
    return intermediate_layer_model

class G(keras.Model):
    def __init__(self, B, B_prime):
        super(G, self).__init__()
        self.B = B
        self.B_prime = B_prime

    def call(self,data):
        y = np.argmax(self.B(data), axis=1)
        y_prime = np.argmax(self.B_prime(data), axis=1)
        tmpRes = np.array([y[i] if y[i] == y_prime[i] else 1283 for i in range(y.shape[0])])
        res = np.zeros((y.shape[0],1284))
        res[np.arange(tmpRes.size),tmpRes] = 1
        return res
      
def prune_defense(X):
    layer_ind = 6
    B = keras.models.load_model(model_filename)
    B_pruned = keras.models.load_model(model_filename)

    cl_label_p = np.argmax(B.predict(cl_x_valid), axis=1)
    original_accuracy = np.mean(np.equal(cl_label_p, cl_y_valid))*100
    print("orig accuracy clean: ",original_accuracy)

    bd_label_p = np.argmax(B.predict(bd_x_valid), axis=1)
    asr = np.mean(np.equal(bd_label_p, bd_y_valid))*100
    print("asr badnet: ",asr)

    intermediate_repr = get_intermediate(B, layer_ind)
    activations = intermediate_repr.predict(cl_x_valid)

    avg_activations = activations.mean(axis=(0,1,2))
    sorted_activations = np.argsort(avg_activations)
    #remove channels in decreasing order of avg_activations
    for channel in tqdm(sorted_activations):
        # Prune the last layer
        layer= B_pruned.layers[layer_ind-1]
        K.set_value(layer.kernel[:, :, :, channel], np.zeros_like(layer.kernel[:, :, :, channel]))
        K.set_value(layer.bias[channel], np.zeros_like(layer.bias[channel]))

        # B_pruned = delete_channels(B_pruned, B_pruned.layers[layer_ind-1], [i])
        cl_label_p = np.argmax(B_pruned.predict(cl_x_valid), axis=1)
        pruned_accuracy = np.mean(np.equal(cl_label_p, cl_y_valid))*100
        print('Validation accuracy after pruning channel', channel, ':', pruned_accuracy)
        if abs(original_accuracy - pruned_accuracy) > X:
            break

    goodnet = G(B, B_pruned)
    
    preds = goodnet(cl_x_test)
    cl_label_p = np.argmax(preds, axis=1)
    asr = np.mean(np.equal(cl_label_p, cl_y_test))*100
    print("clean test acc goodnet: ",asr)
    
    preds = goodnet(bd_x_valid)
    bd_label_p = np.argmax(preds, axis=1)
    asr = np.mean(np.equal(bd_label_p, bd_y_valid))*100
    print("asr goodnet: ",asr)
    return B_pruned

In [10]:
repaired_2 = prune_defense(2)

orig accuracy clean:  98.64899974019225
asr badnet:  100.0


  0%|          | 0/60 [00:00<?, ?it/s]



  2%|▏         | 1/60 [00:01<01:28,  1.50s/it]

Validation accuracy after pruning channel 0 : 98.64899974019225


  3%|▎         | 2/60 [00:03<01:34,  1.63s/it]

Validation accuracy after pruning channel 26 : 98.64899974019225


  5%|▌         | 3/60 [00:04<01:35,  1.68s/it]

Validation accuracy after pruning channel 27 : 98.64899974019225


  7%|▋         | 4/60 [00:06<01:29,  1.61s/it]

Validation accuracy after pruning channel 30 : 98.64899974019225


  8%|▊         | 5/60 [00:07<01:25,  1.56s/it]

Validation accuracy after pruning channel 31 : 98.64899974019225


 10%|█         | 6/60 [00:09<01:27,  1.62s/it]

Validation accuracy after pruning channel 33 : 98.64899974019225


 12%|█▏        | 7/60 [00:11<01:22,  1.56s/it]

Validation accuracy after pruning channel 34 : 98.64899974019225


 13%|█▎        | 8/60 [00:12<01:18,  1.51s/it]

Validation accuracy after pruning channel 36 : 98.64899974019225


 15%|█▌        | 9/60 [00:14<01:20,  1.59s/it]

Validation accuracy after pruning channel 37 : 98.64899974019225


 17%|█▋        | 10/60 [00:16<01:22,  1.64s/it]

Validation accuracy after pruning channel 38 : 98.64899974019225


 18%|█▊        | 11/60 [00:17<01:18,  1.60s/it]

Validation accuracy after pruning channel 25 : 98.64899974019225


 20%|██        | 12/60 [00:19<01:18,  1.64s/it]

Validation accuracy after pruning channel 39 : 98.64899974019225


 22%|██▏       | 13/60 [00:20<01:14,  1.58s/it]

Validation accuracy after pruning channel 41 : 98.64899974019225


 23%|██▎       | 14/60 [00:22<01:10,  1.53s/it]

Validation accuracy after pruning channel 44 : 98.64899974019225


 25%|██▌       | 15/60 [00:23<01:07,  1.50s/it]

Validation accuracy after pruning channel 45 : 98.64899974019225


 27%|██▋       | 16/60 [00:25<01:09,  1.57s/it]

Validation accuracy after pruning channel 47 : 98.64899974019225


 28%|██▊       | 17/60 [00:26<01:05,  1.53s/it]

Validation accuracy after pruning channel 48 : 98.64899974019225


 30%|███       | 18/60 [00:28<01:03,  1.50s/it]

Validation accuracy after pruning channel 49 : 98.64899974019225


 32%|███▏      | 19/60 [00:29<01:04,  1.57s/it]

Validation accuracy after pruning channel 50 : 98.64899974019225
Validation accuracy after pruning channel 53

 33%|███▎      | 20/60 [00:31<01:02,  1.57s/it]

 : 98.64899974019225


 35%|███▌      | 21/60 [00:32<00:59,  1.52s/it]

Validation accuracy after pruning channel 55 : 98.64899974019225


 37%|███▋      | 22/60 [00:34<00:57,  1.50s/it]

Validation accuracy after pruning channel 40 : 98.64899974019225


 38%|███▊      | 23/60 [00:36<00:58,  1.57s/it]

Validation accuracy after pruning channel 24 : 98.64899974019225


 40%|████      | 24/60 [00:37<00:55,  1.54s/it]

Validation accuracy after pruning channel 59 : 98.64899974019225


 42%|████▏     | 25/60 [00:39<00:55,  1.60s/it]

Validation accuracy after pruning channel 9 : 98.64899974019225


 43%|████▎     | 26/60 [00:40<00:55,  1.64s/it]

Validation accuracy after pruning channel 2 : 98.64899974019225


 45%|████▌     | 27/60 [00:42<00:52,  1.58s/it]

Validation accuracy after pruning channel 12 : 98.64899974019225


 47%|████▋     | 28/60 [00:44<00:51,  1.62s/it]

Validation accuracy after pruning channel 13 : 98.64899974019225


 48%|████▊     | 29/60 [00:46<00:52,  1.70s/it]

Validation accuracy after pruning channel 17 : 98.64899974019225


 50%|█████     | 30/60 [00:47<00:51,  1.72s/it]

Validation accuracy after pruning channel 14 : 98.64899974019225


 52%|█████▏    | 31/60 [00:49<00:47,  1.64s/it]

Validation accuracy after pruning channel 15 : 98.64899974019225


 53%|█████▎    | 32/60 [00:50<00:43,  1.57s/it]

Validation accuracy after pruning channel 23 : 98.64899974019225


 55%|█████▌    | 33/60 [00:52<00:43,  1.62s/it]

Validation accuracy after pruning channel 6 : 98.64899974019225


 57%|█████▋    | 34/60 [00:53<00:40,  1.56s/it]

Validation accuracy after pruning channel 51 : 98.64033948211657


 58%|█████▊    | 35/60 [00:55<00:38,  1.54s/it]

Validation accuracy after pruning channel 32 : 98.64033948211657


 60%|██████    | 36/60 [00:57<00:38,  1.60s/it]

Validation accuracy after pruning channel 22 : 98.63167922404088


 62%|██████▏   | 37/60 [00:58<00:37,  1.64s/it]

Validation accuracy after pruning channel 21 : 98.65765999826795


 63%|██████▎   | 38/60 [01:00<00:36,  1.66s/it]

Validation accuracy after pruning channel 20 : 98.64899974019225


 65%|██████▌   | 39/60 [01:02<00:35,  1.68s/it]

Validation accuracy after pruning channel 19 : 98.6056984498138


 67%|██████▋   | 40/60 [01:04<00:34,  1.71s/it]

Validation accuracy after pruning channel 43 : 98.57105741751104


 68%|██████▊   | 41/60 [01:05<00:32,  1.71s/it]

Validation accuracy after pruning channel 58 : 98.53641638520828


 70%|███████   | 42/60 [01:07<00:30,  1.71s/it]

Validation accuracy after pruning channel 3 : 98.19000606218066


 72%|███████▏  | 43/60 [01:09<00:29,  1.72s/it]

Validation accuracy after pruning channel 42 : 97.65307006148784


 73%|███████▎  | 44/60 [01:10<00:27,  1.72s/it]

Validation accuracy after pruning channel 1 : 97.50584567420108


 73%|███████▎  | 44/60 [01:12<00:26,  1.65s/it]

Validation accuracy after pruning channel 29 : 95.75647354291158





clean test acc goodnet:  95.74434918160561
asr goodnet:  100.0


In [13]:
repaired_2.save("repaired_2", save_format='h5')

In [5]:
repaired_4 = prune_defense(4)

orig accuracy clean:  98.64899974019225
asr badnet:  100.0


  0%|          | 0/60 [00:00<?, ?it/s]



  2%|▏         | 1/60 [00:01<01:48,  1.84s/it]

Validation accuracy after pruning channel 0 : 98.64899974019225


  3%|▎         | 2/60 [00:03<01:43,  1.78s/it]

Validation accuracy after pruning channel 26 : 98.64899974019225


  5%|▌         | 3/60 [00:05<01:32,  1.62s/it]

Validation accuracy after pruning channel 27 : 98.64899974019225


  7%|▋         | 4/60 [00:06<01:29,  1.60s/it]

Validation accuracy after pruning channel 30 : 98.64899974019225


  8%|▊         | 5/60 [00:08<01:27,  1.59s/it]

Validation accuracy after pruning channel 31 : 98.64899974019225


 10%|█         | 6/60 [00:10<01:31,  1.69s/it]

Validation accuracy after pruning channel 33 : 98.64899974019225


 12%|█▏        | 7/60 [00:11<01:26,  1.63s/it]

Validation accuracy after pruning channel 34 : 98.64899974019225


 13%|█▎        | 8/60 [00:13<01:26,  1.66s/it]

Validation accuracy after pruning channel 36 : 98.64899974019225


 15%|█▌        | 9/60 [00:14<01:25,  1.69s/it]

Validation accuracy after pruning channel 37 : 98.64899974019225


 17%|█▋        | 10/60 [00:16<01:20,  1.61s/it]

Validation accuracy after pruning channel 38 : 98.64899974019225


 18%|█▊        | 11/60 [00:17<01:17,  1.58s/it]

Validation accuracy after pruning channel 25 : 98.64899974019225


 20%|██        | 12/60 [00:19<01:19,  1.66s/it]

Validation accuracy after pruning channel 39 : 98.64899974019225


 22%|██▏       | 13/60 [00:21<01:21,  1.73s/it]

Validation accuracy after pruning channel 41 : 98.64899974019225


 23%|██▎       | 14/60 [00:23<01:15,  1.64s/it]

Validation accuracy after pruning channel 44 : 98.64899974019225


 25%|██▌       | 15/60 [00:24<01:11,  1.58s/it]

Validation accuracy after pruning channel 45 : 98.64899974019225


 27%|██▋       | 16/60 [00:26<01:11,  1.63s/it]

Validation accuracy after pruning channel 47 : 98.64899974019225


 28%|██▊       | 17/60 [00:28<01:11,  1.66s/it]

Validation accuracy after pruning channel 48 : 98.64899974019225


 30%|███       | 18/60 [00:29<01:10,  1.69s/it]

Validation accuracy after pruning channel 49 : 98.64899974019225


 32%|███▏      | 19/60 [00:31<01:07,  1.66s/it]

Validation accuracy after pruning channel 50 : 98.64899974019225


 33%|███▎      | 20/60 [00:32<01:03,  1.58s/it]

Validation accuracy after pruning channel 53 : 98.64899974019225


 35%|███▌      | 21/60 [00:34<01:03,  1.63s/it]

Validation accuracy after pruning channel 55 : 98.64899974019225


 37%|███▋      | 22/60 [00:35<00:59,  1.56s/it]

Validation accuracy after pruning channel 40 : 98.64899974019225


 38%|███▊      | 23/60 [00:37<00:56,  1.52s/it]

Validation accuracy after pruning channel 24 : 98.64899974019225


 40%|████      | 24/60 [00:38<00:53,  1.49s/it]

Validation accuracy after pruning channel 59 : 98.64899974019225


 42%|████▏     | 25/60 [00:40<00:56,  1.61s/it]

Validation accuracy after pruning channel 9 : 98.64899974019225


 43%|████▎     | 26/60 [00:42<00:55,  1.64s/it]

Validation accuracy after pruning channel 2 : 98.64899974019225


 45%|████▌     | 27/60 [00:44<00:55,  1.67s/it]

Validation accuracy after pruning channel 12 : 98.64899974019225


 47%|████▋     | 28/60 [00:45<00:54,  1.69s/it]

Validation accuracy after pruning channel 13 : 98.64899974019225


 48%|████▊     | 29/60 [00:47<00:50,  1.62s/it]

Validation accuracy after pruning channel 17 : 98.64899974019225


 50%|█████     | 30/60 [00:49<00:49,  1.65s/it]

Validation accuracy after pruning channel 14 : 98.64899974019225


 52%|█████▏    | 31/60 [00:50<00:48,  1.68s/it]

Validation accuracy after pruning channel 15 : 98.64899974019225


 53%|█████▎    | 32/60 [00:52<00:47,  1.70s/it]

Validation accuracy after pruning channel 23 : 98.64899974019225


 55%|█████▌    | 33/60 [00:54<00:46,  1.71s/it]

Validation accuracy after pruning channel 6 : 98.64899974019225


 57%|█████▋    | 34/60 [00:55<00:42,  1.64s/it]

Validation accuracy after pruning channel 51 : 98.64033948211657


 58%|█████▊    | 35/60 [00:57<00:40,  1.60s/it]

Validation accuracy after pruning channel 32 : 98.64033948211657


 60%|██████    | 36/60 [00:59<00:39,  1.64s/it]

Validation accuracy after pruning channel 22 : 98.63167922404088


 62%|██████▏   | 37/60 [01:00<00:38,  1.67s/it]

Validation accuracy after pruning channel 21 : 98.65765999826795


 63%|██████▎   | 38/60 [01:02<00:35,  1.60s/it]

Validation accuracy after pruning channel 20 : 98.64899974019225


 65%|██████▌   | 39/60 [01:03<00:32,  1.56s/it]

Validation accuracy after pruning channel 19 : 98.6056984498138


 67%|██████▋   | 40/60 [01:05<00:32,  1.61s/it]

Validation accuracy after pruning channel 43 : 98.57105741751104


 68%|██████▊   | 41/60 [01:06<00:29,  1.56s/it]

Validation accuracy after pruning channel 58 : 98.53641638520828


 70%|███████   | 42/60 [01:08<00:27,  1.53s/it]

Validation accuracy after pruning channel 3 : 98.19000606218066


 72%|███████▏  | 43/60 [01:10<00:27,  1.59s/it]

Validation accuracy after pruning channel 42 : 97.65307006148784


 73%|███████▎  | 44/60 [01:11<00:25,  1.57s/it]

Validation accuracy after pruning channel 1 : 97.50584567420108


 75%|███████▌  | 45/60 [01:12<00:23,  1.53s/it]

Validation accuracy after pruning channel 29 : 95.75647354291158


 77%|███████▋  | 46/60 [01:14<00:21,  1.52s/it]

Validation accuracy after pruning channel 16 : 95.20221702606739


 78%|███████▊  | 47/60 [01:15<00:19,  1.51s/it]

Validation accuracy after pruning channel 56 : 94.7172425738287


 78%|███████▊  | 47/60 [01:17<00:21,  1.65s/it]

Validation accuracy after pruning channel 46 : 92.09318437689443





clean test acc goodnet:  92.1278254091972
asr goodnet:  99.9913397419243


In [7]:
repaired_4.save("repaired_4.h5")

In [6]:
repaired_10 = prune_defense(10)

orig accuracy clean:  98.64899974019225
asr badnet:  100.0


  0%|          | 0/60 [00:00<?, ?it/s]



  2%|▏         | 1/60 [00:01<01:27,  1.48s/it]

Validation accuracy after pruning channel 0 : 98.64899974019225


  3%|▎         | 2/60 [00:02<01:23,  1.45s/it]

Validation accuracy after pruning channel 26 : 98.64899974019225


  5%|▌         | 3/60 [00:04<01:29,  1.57s/it]

Validation accuracy after pruning channel 27 : 98.64899974019225


  7%|▋         | 4/60 [00:06<01:31,  1.64s/it]

Validation accuracy after pruning channel 30 : 98.64899974019225


  8%|▊         | 5/60 [00:07<01:26,  1.56s/it]

Validation accuracy after pruning channel 31 : 98.64899974019225


 10%|█         | 6/60 [00:09<01:21,  1.52s/it]

Validation accuracy after pruning channel 33 : 98.64899974019225


 12%|█▏        | 7/60 [00:10<01:24,  1.59s/it]

Validation accuracy after pruning channel 34 : 98.64899974019225


 13%|█▎        | 8/60 [00:12<01:24,  1.63s/it]

Validation accuracy after pruning channel 36 : 98.64899974019225


 15%|█▌        | 9/60 [00:14<01:24,  1.66s/it]

Validation accuracy after pruning channel 37 : 98.64899974019225


 17%|█▋        | 10/60 [00:15<01:21,  1.64s/it]

Validation accuracy after pruning channel 38 : 98.64899974019225


 18%|█▊        | 11/60 [00:17<01:23,  1.71s/it]

Validation accuracy after pruning channel 25 : 98.64899974019225


 20%|██        | 12/60 [00:19<01:22,  1.72s/it]

Validation accuracy after pruning channel 39 : 98.64899974019225


 22%|██▏       | 13/60 [00:21<01:16,  1.63s/it]

Validation accuracy after pruning channel 41 : 98.64899974019225


 23%|██▎       | 14/60 [00:22<01:12,  1.57s/it]

Validation accuracy after pruning channel 44 : 98.64899974019225


 25%|██▌       | 15/60 [00:23<01:09,  1.55s/it]

Validation accuracy after pruning channel 45 : 98.64899974019225


 27%|██▋       | 16/60 [00:25<01:06,  1.52s/it]

Validation accuracy after pruning channel 47 : 98.64899974019225


 28%|██▊       | 17/60 [00:27<01:09,  1.62s/it]

Validation accuracy after pruning channel 48 : 98.64899974019225


 30%|███       | 18/60 [00:28<01:05,  1.57s/it]

Validation accuracy after pruning channel 49 : 98.64899974019225


 32%|███▏      | 19/60 [00:30<01:02,  1.53s/it]

Validation accuracy after pruning channel 50 : 98.64899974019225


 33%|███▎      | 20/60 [00:31<01:00,  1.51s/it]

Validation accuracy after pruning channel 53 : 98.64899974019225


 35%|███▌      | 21/60 [00:33<01:01,  1.57s/it]

Validation accuracy after pruning channel 55 : 98.64899974019225


 37%|███▋      | 22/60 [00:34<00:58,  1.53s/it]

Validation accuracy after pruning channel 40 : 98.64899974019225


 38%|███▊      | 23/60 [00:36<01:00,  1.64s/it]

Validation accuracy after pruning channel 24 : 98.64899974019225


 40%|████      | 24/60 [00:38<00:56,  1.58s/it]

Validation accuracy after pruning channel 59 : 98.64899974019225


 42%|████▏     | 25/60 [00:39<00:54,  1.55s/it]

Validation accuracy after pruning channel 9 : 98.64899974019225


 43%|████▎     | 26/60 [00:41<00:51,  1.52s/it]

Validation accuracy after pruning channel 2 : 98.64899974019225


 45%|████▌     | 27/60 [00:42<00:52,  1.58s/it]

Validation accuracy after pruning channel 12 : 98.64899974019225


 47%|████▋     | 28/60 [00:44<00:49,  1.54s/it]

Validation accuracy after pruning channel 13 : 98.64899974019225


 48%|████▊     | 29/60 [00:46<00:50,  1.64s/it]

Validation accuracy after pruning channel 17 : 98.64899974019225


 50%|█████     | 30/60 [00:47<00:49,  1.67s/it]

Validation accuracy after pruning channel 14 : 98.64899974019225


 52%|█████▏    | 31/60 [00:49<00:48,  1.68s/it]

Validation accuracy after pruning channel 15 : 98.64899974019225


 53%|█████▎    | 32/60 [00:51<00:47,  1.70s/it]

Validation accuracy after pruning channel 23 : 98.64899974019225


 55%|█████▌    | 33/60 [00:52<00:46,  1.71s/it]

Validation accuracy after pruning channel 6 : 98.64899974019225


 57%|█████▋    | 34/60 [00:54<00:44,  1.71s/it]

Validation accuracy after pruning channel 51 : 98.64033948211657


 58%|█████▊    | 35/60 [00:56<00:44,  1.78s/it]

Validation accuracy after pruning channel 32 : 98.64033948211657


 60%|██████    | 36/60 [00:58<00:42,  1.76s/it]

Validation accuracy after pruning channel 22 : 98.63167922404088


 62%|██████▏   | 37/60 [01:00<00:40,  1.75s/it]

Validation accuracy after pruning channel 21 : 98.65765999826795


 63%|██████▎   | 38/60 [01:01<00:38,  1.76s/it]

Validation accuracy after pruning channel 20 : 98.64899974019225


 65%|██████▌   | 39/60 [01:03<00:34,  1.66s/it]

Validation accuracy after pruning channel 19 : 98.6056984498138


 67%|██████▋   | 40/60 [01:04<00:31,  1.58s/it]

Validation accuracy after pruning channel 43 : 98.57105741751104


 68%|██████▊   | 41/60 [01:06<00:31,  1.67s/it]

Validation accuracy after pruning channel 58 : 98.53641638520828


 70%|███████   | 42/60 [01:08<00:30,  1.69s/it]

Validation accuracy after pruning channel 3 : 98.19000606218066


 72%|███████▏  | 43/60 [01:10<00:29,  1.72s/it]

Validation accuracy after pruning channel 42 : 97.65307006148784


 73%|███████▎  | 44/60 [01:12<00:29,  1.87s/it]

Validation accuracy after pruning channel 1 : 97.50584567420108


 75%|███████▌  | 45/60 [01:14<00:31,  2.10s/it]

Validation accuracy after pruning channel 29 : 95.75647354291158


 77%|███████▋  | 46/60 [01:16<00:29,  2.08s/it]

Validation accuracy after pruning channel 16 : 95.20221702606739


 78%|███████▊  | 47/60 [01:19<00:26,  2.06s/it]

Validation accuracy after pruning channel 56 : 94.7172425738287


 80%|████████  | 48/60 [01:20<00:22,  1.89s/it]

Validation accuracy after pruning channel 46 : 92.09318437689443


 82%|████████▏ | 49/60 [01:21<00:19,  1.76s/it]

Validation accuracy after pruning channel 5 : 91.49562656967177


 83%|████████▎ | 50/60 [01:23<00:17,  1.75s/it]

Validation accuracy after pruning channel 8 : 91.01931237550879


 85%|████████▌ | 51/60 [01:25<00:15,  1.75s/it]

Validation accuracy after pruning channel 11 : 89.17467740538669


 85%|████████▌ | 51/60 [01:26<00:15,  1.70s/it]

Validation accuracy after pruning channel 54 : 84.43751623798389





clean test acc goodnet:  84.3335931410756
asr goodnet:  77.015675067117


In [8]:
repaired_10.save("repaired_10.h5")

In [None]:
## Creating stats for report

In [17]:
def prune_defense_all(X):
    layer_ind = 6
    B = keras.models.load_model(model_filename)
    B_pruned = keras.models.load_model(model_filename)

    cl_label_p = np.argmax(B.predict(cl_x_valid), axis=1)
    original_accuracy = np.mean(np.equal(cl_label_p, cl_y_valid))*100
    print("orig accuracy clean: ",original_accuracy)

    bd_label_p = np.argmax(B.predict(bd_x_valid), axis=1)
    asr = np.mean(np.equal(bd_label_p, bd_y_valid))*100
    print("asr badnet: ",asr)

    intermediate_repr = get_intermediate(B, layer_ind)
    activations = intermediate_repr.predict(cl_x_valid)

    avg_activations = activations.mean(axis=(0,1,2))
    sorted_activations = np.argsort(avg_activations)
    l = len(sorted_activations)
    #remove channels in decreasing order of avg_activations
    i = 0
    for channel in (sorted_activations):
        # Prune the last layer
        layer= B_pruned.layers[layer_ind-1]
        K.set_value(layer.kernel[:, :, :, channel], np.zeros_like(layer.kernel[:, :, :, channel]))
        K.set_value(layer.bias[channel], np.zeros_like(layer.bias[channel]))

        # B_pruned = delete_channels(B_pruned, B_pruned.layers[layer_ind-1], [i])
        # cl_label_p = np.argmax(B_pruned.predict(cl_x_valid), axis=1)
        # pruned_accuracy = np.mean(np.equal(cl_label_p, cl_y_valid))*100
        # print('Validation accuracy after pruning channel', channel, ':', pruned_accuracy)
        
        goodnet = G(B, B_pruned)
        print(i/l)
        
        preds = goodnet(cl_x_test)
        cl_label_p = np.argmax(preds, axis=1)
        asr = np.mean(np.equal(cl_label_p, cl_y_test))*100
        print("clean test acc goodnet: ",asr)
        
        preds = goodnet(bd_x_valid)
        bd_label_p = np.argmax(preds, axis=1)
        asr = np.mean(np.equal(bd_label_p, bd_y_valid))*100
        print("asr goodnet: ",asr)
        i+=1

In [18]:
prune_defense_all(0)

orig accuracy clean:  98.64899974019225
asr badnet:  100.0
0.0
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.016666666666666666
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.03333333333333333
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.05
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.06666666666666667
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.08333333333333333
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.1
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.11666666666666667
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.13333333333333333
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.15
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.16666666666666666
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.18333333333333332
clean test acc goodnet:  98.62042088854248
asr goodnet:  100.0
0.2
cl