In [1]:
import numpy as np
import keras
import h5py
from tqdm import tqdm
from google.colab import drive

In [3]:
# Mount Google Drive at /content/drive so the data and model paths defined
# later in the notebook resolve inside this Colab runtime.
drive.mount('/content/drive')

Mounted at /content/drive


In [11]:
# Input locations on the mounted Google Drive.
# Clean validation set (HDF5), read by data_loader.
clean_data_filename = '/content/drive/MyDrive/MLCS_4/cl/valid.h5'
# Poisoned ("bd") validation set (HDF5), read by data_loader.
poisoned_data_filename = '/content/drive/MyDrive/MLCS_4/bd/bd_valid.h5'
# Saved Keras model that is loaded, evaluated and pruned below.
model_filename = '/content/drive/MyDrive/MLCS_4/bd_net.h5'


def data_loader(filepath):
    """Load the 'data' and 'label' datasets from an HDF5 file.

    Both datasets are copied fully into memory as NumPy arrays, and the file
    handle is closed before returning.

    Args:
        filepath: Path to an HDF5 file containing 'data' and 'label' datasets.

    Returns:
        Tuple ``(x_data, y_data)``. ``x_data`` is the 'data' array with its
        axes reordered as (0, 2, 3, 1); ``y_data`` is the 'label' array.
    """
    # The context manager releases the file handle even if a dataset is
    # missing; np.array() materialises the data before the file is closed.
    with h5py.File(filepath, 'r') as data:
        x_data = np.array(data['data'])
        y_data = np.array(data['label'])
    # Presumably converts channels-first (N, C, H, W) storage to the
    # channels-last layout the model expects -- TODO confirm against the data.
    x_data = x_data.transpose((0, 2, 3, 1))
    return x_data, y_data

In [12]:
def load_and_evaluate_model(model_filename, cl_x_test, cl_y_test, bd_x_test, bd_y_test):
    """Load a saved Keras model and report how it scores on two datasets.

    Args:
        model_filename: Path passed to ``keras.models.load_model``.
        cl_x_test: Clean inputs fed to ``model.predict``.
        cl_y_test: Labels compared against the clean predictions.
        bd_x_test: Poisoned inputs fed to ``model.predict``.
        bd_y_test: Labels compared against the poisoned predictions.

    Returns:
        Tuple ``(model, clean_accuracy)`` where ``clean_accuracy`` is the
        percentage of clean samples whose argmax prediction equals its label.
        The attack success rate is printed but not returned.
    """
    model = keras.models.load_model(model_filename)

    def percent_matching(inputs, labels):
        # Percentage of samples whose highest-scoring class equals the label.
        predicted = np.argmax(model.predict(inputs), axis=1)
        return np.mean(np.equal(predicted, labels)) * 100

    clean_accuracy = percent_matching(cl_x_test, cl_y_test)
    print('Clean Classification accuracy:', clean_accuracy)

    # On the poisoned set, a "match" means the prediction equals bd_y_test.
    asr = percent_matching(bd_x_test, bd_y_test)
    print('Attack Success Rate:', asr)

    return model, clean_accuracy

In [13]:
# Load the saved network and display its layer table.
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() would only append a stray "None" line to the output.
model = keras.models.load_model(model_filename)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input (InputLayer)          [(None, 55, 47, 3)]          0         []                            
                                                                                                  
 conv_1 (Conv2D)             (None, 52, 44, 20)           980       ['input[0][0]']               
                                                                                                  
 pool_1 (MaxPooling2D)       (None, 26, 22, 20)           0         ['conv_1[0][0]']              
                                                                                                  
 conv_2 (Conv2D)             (None, 24, 20, 40)           7240      ['pool_1[0][0]']              
                                                                                            

In [14]:

def _label_match_percent(model, inputs, labels):
    """Return the percentage of samples whose argmax prediction equals its label."""
    predicted = np.argmax(model.predict(inputs), axis=1)
    return np.mean(np.equal(predicted, labels)) * 100


def prune_and_save_models(model, cl_x_test, cl_y_test, bd_x_test, bd_y_test, clean_data_acc,
                          thresholds=(2, 4, 10)):
    """Zero out channels one at a time and checkpoint the model as accuracy falls.

    A copy of ``model`` is pruned cumulatively. Channels are visited in
    ascending order of their mean 'pool_3' activation over ``cl_x_test``; for
    each one, the matching kernel slice and bias of ``layers[5]`` are set to
    zero. After every step the clean accuracy and attack success rate are
    measured and printed. The first time the clean accuracy has fallen at
    least ``X`` points below ``clean_data_acc`` (for each ``X`` in
    ``thresholds``) the pruned model is saved as
    ``model_X={X}_channel_{channel}.h5`` in the working directory.

    Args:
        model: Keras model to prune; it is cloned, not modified.
        cl_x_test: Clean inputs used for activation ranking and accuracy.
        cl_y_test: Labels for ``cl_x_test``.
        bd_x_test: Poisoned inputs used to measure the attack success rate.
        bd_y_test: Labels for ``bd_x_test``.
        clean_data_acc: Baseline clean accuracy (percent) of the unpruned model.
        thresholds: Accuracy drops (percentage points) that each trigger one
            checkpoint. Defaults to ``(2, 4, 10)``.

    Returns:
        Tuple ``(prune_index, clean_acc, asrate)`` of equal-length lists: the
        channel pruned at each step, and the clean accuracy and attack success
        rate (both percent) measured right after that step.
    """
    # Work on a copy so the caller's model keeps its original weights.
    model_copy = keras.models.clone_model(model)
    model_copy.set_weights(model.get_weights())

    prune_index = []
    clean_acc = []
    asrate = []
    saved_thresholds = set()  # thresholds whose checkpoint was already written

    # Rank channels by their average 'pool_3' activation on the clean data.
    layer_output = model_copy.get_layer('pool_3').output
    intermediate_model = keras.models.Model(inputs=model_copy.input, outputs=layer_output)
    mean_activation = np.mean(intermediate_model.predict(cl_x_test), axis=(0, 1, 2))
    seq = np.argsort(mean_activation)

    # NOTE(review): layers[5] is presumably the convolution feeding 'pool_3';
    # confirm the index if the architecture changes.
    conv_layer = model_copy.layers[5]
    weight_0, bias_0 = conv_layer.get_weights()

    for channel_index in tqdm(seq):
        # Pruning is cumulative: earlier channels stay zeroed in weight_0/bias_0.
        weight_0[:, :, :, channel_index] = 0
        bias_0[channel_index] = 0
        conv_layer.set_weights([weight_0, bias_0])

        clean_accuracy = _label_match_percent(model_copy, cl_x_test, cl_y_test)
        accuracy_drop = clean_data_acc - clean_accuracy
        for threshold in thresholds:
            if accuracy_drop >= threshold and threshold not in saved_thresholds:
                print(f"The accuracy drops at least {threshold}%, saved the model")
                model_copy.save(f'model_X={threshold}_channel_{channel_index}.h5')
                saved_thresholds.add(threshold)

        asr = _label_match_percent(model_copy, bd_x_test, bd_y_test)

        prune_index.append(int(channel_index))
        clean_acc.append(clean_accuracy)
        asrate.append(asr)

        print()
        print("The clean accuracy is: ", clean_accuracy)
        print("The attack success rate is: ", asr)
        print("The pruned channel index is: ", channel_index)
        # NOTE(review): resets Keras global state every iteration, presumably to
        # limit memory growth from repeated predict() calls -- confirm it is needed.
        keras.backend.clear_session()

    return prune_index, clean_acc, asrate

if __name__ == '__main__':
    # Load data: the clean validation set and its poisoned ("bd") counterpart,
    # each returned by data_loader as an (inputs, labels) pair.
    cl_x_test, cl_y_test = data_loader(clean_data_filename)
    bd_x_test, bd_y_test = data_loader(poisoned_data_filename)

    # Evaluate original model. The returned clean accuracy is the baseline that
    # prune_and_save_models compares against when deciding to save a checkpoint.
    original_model, clean_data_acc = load_and_evaluate_model(model_filename, cl_x_test, cl_y_test, bd_x_test, bd_y_test)

    # Prune and save models based on the original model. Checkpoints are written
    # to the current working directory as model_X=<drop>_channel_<index>.h5.
    prune_and_save_models(original_model, cl_x_test, cl_y_test, bd_x_test, bd_y_test, clean_data_acc)


Clean Classification accuracy: 98.64899974019225
Attack Success Rate: 100.0


  0%|          | 0/60 [00:00<?, ?it/s]



  2%|▏         | 1/60 [00:03<03:48,  3.87s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  0


  3%|▎         | 2/60 [00:07<03:43,  3.85s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  26


  5%|▌         | 3/60 [00:11<03:43,  3.93s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  27


  7%|▋         | 4/60 [00:15<03:42,  3.97s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  30

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  

  8%|▊         | 5/60 [00:19<03:36,  3.93s/it]

31


 10%|█         | 6/60 [00:23<03:31,  3.91s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  33


 12%|█▏        | 7/60 [00:27<03:28,  3.93s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  34


 13%|█▎        | 8/60 [00:31<03:23,  3.91s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  36


 15%|█▌        | 9/60 [00:35<03:18,  3.90s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  37


 17%|█▋        | 10/60 [00:39<03:17,  3.95s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  38

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  25


 18%|█▊        | 11/60 [00:43<03:15,  3.99s/it]



 20%|██        | 12/60 [00:47<03:09,  3.95s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  39


 22%|██▏       | 13/60 [00:51<03:06,  3.97s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  41


 23%|██▎       | 14/60 [00:55<03:03,  4.00s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  44


 25%|██▌       | 15/60 [00:59<02:58,  3.96s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  45


 27%|██▋       | 16/60 [01:03<02:56,  4.02s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  47

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  48


 28%|██▊       | 17/60 [01:07<02:51,  3.99s/it]



 30%|███       | 18/60 [01:11<02:46,  3.96s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  49

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  50


 32%|███▏      | 19/60 [01:15<02:41,  3.95s/it]



 33%|███▎      | 20/60 [01:18<02:37,  3.93s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  53

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  55


 35%|███▌      | 21/60 [01:22<02:34,  3.96s/it]



 37%|███▋      | 22/60 [01:26<02:30,  3.96s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  40


 38%|███▊      | 23/60 [01:30<02:25,  3.93s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  24


 40%|████      | 24/60 [01:34<02:21,  3.92s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  59


 42%|████▏     | 25/60 [01:38<02:17,  3.94s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  9

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  2


 43%|████▎     | 26/60 [01:42<02:13,  3.93s/it]



 45%|████▌     | 27/60 [01:46<02:09,  3.91s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  12

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  13


 47%|████▋     | 28/60 [01:50<02:05,  3.93s/it]



 48%|████▊     | 29/60 [01:54<02:01,  3.91s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  17


 50%|█████     | 30/60 [01:58<01:57,  3.90s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  14

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  15


 52%|█████▏    | 31/60 [02:02<01:53,  3.93s/it]



 53%|█████▎    | 32/60 [02:06<01:49,  3.92s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  23

The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  6


 55%|█████▌    | 33/60 [02:09<01:45,  3.91s/it]



 57%|█████▋    | 34/60 [02:13<01:42,  3.93s/it]


The clean accuracy is:  98.64033948211657
The attack success rate is:  100.0
The pruned channel index is:  51


 58%|█████▊    | 35/60 [02:17<01:37,  3.92s/it]


The clean accuracy is:  98.64033948211657
The attack success rate is:  100.0
The pruned channel index is:  32


 60%|██████    | 36/60 [02:21<01:33,  3.90s/it]


The clean accuracy is:  98.63167922404088
The attack success rate is:  100.0
The pruned channel index is:  22

The clean accuracy is:  98.65765999826795
The attack success rate is:  100.0
The pruned channel index is:  21


 62%|██████▏   | 37/60 [02:25<01:29,  3.89s/it]



 63%|██████▎   | 38/60 [02:29<01:26,  3.91s/it]


The clean accuracy is:  98.64899974019225
The attack success rate is:  100.0
The pruned channel index is:  20


 65%|██████▌   | 39/60 [02:33<01:21,  3.90s/it]


The clean accuracy is:  98.6056984498138
The attack success rate is:  100.0
The pruned channel index is:  19


 67%|██████▋   | 40/60 [02:37<01:18,  3.93s/it]


The clean accuracy is:  98.57105741751104
The attack success rate is:  100.0
The pruned channel index is:  43


 68%|██████▊   | 41/60 [02:41<01:15,  3.95s/it]


The clean accuracy is:  98.53641638520828
The attack success rate is:  100.0
The pruned channel index is:  58


 70%|███████   | 42/60 [02:45<01:11,  3.99s/it]


The clean accuracy is:  98.19000606218066
The attack success rate is:  100.0
The pruned channel index is:  3


 72%|███████▏  | 43/60 [02:49<01:08,  4.02s/it]


The clean accuracy is:  97.65307006148784
The attack success rate is:  100.0
The pruned channel index is:  42


 73%|███████▎  | 44/60 [02:53<01:03,  3.98s/it]


The clean accuracy is:  97.50584567420108
The attack success rate is:  100.0
The pruned channel index is:  1


  saving_api.save_model(


The accuracy drops at least 2%, saved the model


 75%|███████▌  | 45/60 [02:58<01:02,  4.20s/it]


The clean accuracy is:  95.75647354291158
The attack success rate is:  100.0
The pruned channel index is:  29


 77%|███████▋  | 46/60 [03:02<00:58,  4.17s/it]


The clean accuracy is:  95.20221702606739
The attack success rate is:  99.9913397419243
The pruned channel index is:  16


 78%|███████▊  | 47/60 [03:06<00:53,  4.10s/it]


The clean accuracy is:  94.7172425738287
The attack success rate is:  99.9913397419243
The pruned channel index is:  56




The accuracy drops at least 4%, saved the model


 80%|████████  | 48/60 [03:10<00:48,  4.08s/it]


The clean accuracy is:  92.09318437689443
The attack success rate is:  99.9913397419243
The pruned channel index is:  46

The clean accuracy is:  91.49562656967177
The attack success rate is:  99.9913397419243
The pruned channel index is:  5


 82%|████████▏ | 49/60 [03:14<00:44,  4.06s/it]



 83%|████████▎ | 50/60 [03:18<00:40,  4.03s/it]


The clean accuracy is:  91.01931237550879
The attack success rate is:  99.98267948384861
The pruned channel index is:  8

The clean accuracy is:  89.17467740538669
The attack success rate is:  80.73958603966398
The pruned channel index is:  11


 85%|████████▌ | 51/60 [03:22<00:36,  4.01s/it]





The accuracy drops at least 10%, saved the model

The clean accuracy is:  84.43751623798389
The attack success rate is:  77.015675067117
The pruned channel index is:  54


 87%|████████▋ | 52/60 [03:26<00:32,  4.03s/it]


The clean accuracy is:  76.48739932449988
The attack success rate is:  35.71490430414826
The pruned channel index is:  10


 88%|████████▊ | 53/60 [03:30<00:28,  4.00s/it]


The clean accuracy is:  54.8627349095003
The attack success rate is:  6.954187234779596
The pruned channel index is:  28


 90%|█████████ | 54/60 [03:34<00:23,  3.96s/it]



 92%|█████████▏| 55/60 [03:37<00:19,  3.96s/it]


The clean accuracy is:  27.08928726076037
The attack success rate is:  0.4243526457088421
The pruned channel index is:  35

The clean accuracy is:  13.87373343725643
The attack success rate is:  0.0
The pruned channel index is:  18


 93%|█████████▎| 56/60 [03:41<00:15,  3.95s/it]


The clean accuracy is:  7.101411622066338
The attack success rate is:  0.0
The pruned channel index is:  4


 95%|█████████▌| 57/60 [03:45<00:11,  3.91s/it]



 97%|█████████▋| 58/60 [03:49<00:07,  3.94s/it]


The clean accuracy is:  1.5501861955486274
The attack success rate is:  0.0
The pruned channel index is:  7


 98%|█████████▊| 59/60 [03:53<00:03,  3.93s/it]


The clean accuracy is:  0.7188014202823244
The attack success rate is:  0.0
The pruned channel index is:  52


100%|██████████| 60/60 [03:57<00:00,  3.96s/it]


The clean accuracy is:  0.0779423226812159
The attack success rate is:  0.0
The pruned channel index is:  57



