In [165]:
# Load dependencies

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.models as models
import torch.optim as optim
import os
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchattacks
from torchmetrics.functional.image import peak_signal_noise_ratio, structural_similarity_index_measure

In [166]:
# Setup CUDA Device

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device using: {device}')

Device using: cuda


In [167]:
model_name = "Mobilenetv3Small"
version = "v4"
training_name = "ARD_Alpha=0.7_Temperature=5"

height = 224
num_classes = 7
epochs = 100
lr = 0.00017
lr_factor = 0.1
lr_threshold = 6
weight_decay = 0.0002
batch_size = 32

# Attack hyperparameters 
epsilon = 8.0 / 255
alpha = 2.0 / 255
steps = 10

# Knowledge Distillation hyperparameters
temp = 5.0
alpha = 0.8

In [168]:
# Graph writer initialize for data visualization

writer = SummaryWriter("runs/trashbox/" + f'{training_name}--{model_name}.{version}')

In [169]:
preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
test_preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.ToTensor()
])

In [170]:
# Training and Validation dataset

print('===>> Preparing data...')
trash_train_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/train', transform=preprocessing)
trash_train_loader = torch.utils.data.DataLoader(dataset=trash_train_dataset, shuffle=True, batch_size=batch_size)
trash_val_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/val', transform=test_preprocessing)
trash_val_loader = torch.utils.data.DataLoader(dataset=trash_val_dataset, shuffle=True, batch_size=batch_size)

===>> Preparing data...


In [171]:
print('====>> Setting up teacher model...')

# Initialize architecture with modified output features
teacher_model = models.googlenet(weights=models.GoogLeNet_Weights.DEFAULT)
infeatures = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(infeatures, num_classes, True)

# Load saved weights 
checkpoint = torch.load('./best_trained_models/best_AT--Googlenet.v1_epoch98.pth')

if 'module' in list(checkpoint['net'].keys())[0]:
    new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['net'].items()}
    teacher_model.load_state_dict(new_state_dict)
else:
    teacher_model.load_state_dict(checkpoint['net'])

teacher_model = teacher_model.to(device)

for param in teacher_model.parameters():
    param.requires_grad = False

teacher_model.eval()

====>> Setting up teacher model...


GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track

In [172]:
# Setup student model 

print('====>> Setting up student model...')
student_model = models.mobilenet_v3_small(weights=None)

# for param in student_model.parameters():
#     param.requires_grad = False

student_model.classifier = nn.Sequential(
    nn.Linear(in_features=576, out_features=1024, bias=True),
    nn.Hardswish(),
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1024, out_features=num_classes, bias=True)
)

# # Load saved weights 
# checkpoint = torch.load('./best_trained_models/best_NORMAL--Mobilenetv3Small.v1_epoch40.pth')

# if 'module' in list(checkpoint['net'].keys())[0]:
#     new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['net'].items()}
#     student_model.load_state_dict(new_state_dict)
# else:
#     student_model.load_state_dict(checkpoint['net'])

student_model = student_model.to(device)

====>> Setting up student model...


In [173]:
# Initialize adversarial attack for generating adversarial samples

attack = torchattacks.TPGD(student_model, eps=epsilon, alpha=alpha, steps=steps)

In [174]:
# Setup Loss functions

XENT_loss = nn.CrossEntropyLoss()
KL_loss = nn.KLDivLoss()

In [175]:
# Train Loop

def train(epoch, optimizer):
    train_loss = 0
    correct = 0
    total = 0
    adv_correct = 0
    student_model.train()
    total_ssim = 0
    total_psnr = 0

    iterator = tqdm(trash_train_loader, ncols=0, leave=False)
    for i, (inputs, targets)in enumerate(iterator):
        inputs, targets = inputs.to(device),targets.to(device)

        optimizer.zero_grad()
        adv_image = attack(inputs, targets)

        # Soft labels
        teacher_output = teacher_model(inputs)
        student_output = student_model(inputs)
        adv_student_output = student_model(adv_image)
        
        # ARD loss function formula
        loss =  alpha* temp* temp*KL_loss(F.log_softmax(adv_student_output/ temp, dim=1),F.softmax(teacher_output/ temp, dim=1))+(1.0- alpha)*XENT_loss(student_output, targets)
        loss.backward()
        optimizer.step()

        # Measure loss
        train_loss += loss.item()
        iterator.set_description(str(loss.item()))      

        # Get total 
        total += targets.size(0)

        # SSIM and PSNR 
        total_psnr += peak_signal_noise_ratio(adv_image, inputs)
        total_ssim += structural_similarity_index_measure(adv_image, inputs)
        
        # Measure clean and adversarial accuracy 
        _, predicted = student_output.max(1)
        correct += predicted.eq(targets).sum().item()
        _, adv_predicted = adv_student_output.max(1)
        adv_correct += adv_predicted.eq(targets).sum().item()
    
    # SSIM and PSNR Average
    avg_ssim = total_ssim / total
    avg_psnr = total_psnr / total

    writer.add_scalar("Average SSIM: " + model_name, avg_ssim, epoch)
    writer.add_scalar("Average PSNR: " + model_name, avg_psnr, epoch)

    training_loss = train_loss  / total
    train_adv_accuracy = 100.0 * correct / total
    adv_train_adv_accuracy = 100.0 * adv_correct / total
    
    print('\nTotal adversarial train accuracy:', 100. * correct / total)
    print('Total adversarial train loss:', train_loss)
    
    # Write graph over epoch
    writer.add_scalar('Train loss: ' + model_name, training_loss, epoch)
    writer.add_scalar('Train accuracy: ' + model_name, train_adv_accuracy, epoch)
    writer.add_scalar('Adversarial Train accuracy: ' + model_name, adv_train_adv_accuracy, epoch)

In [176]:
# Test function
best_loss = float(0)

def test(epoch, optimizer):
    global best_loss
    print('\n[ Test epoch: %d ]' % epoch)
    student_model.eval()
    benign_correct = 0
    adv_correct = 0
    total = 0
    with torch.no_grad():
        iterator = tqdm(trash_val_loader, ncols=0, leave=False)
        for i, (inputs, targets) in enumerate(iterator):
            inputs, targets = inputs.to(device), targets.to(device)
            
            total += targets.size(0)

            with torch.enable_grad():
                adv_images = attack(inputs, targets)
            
            # Output
            natural_outputs = student_model(inputs)
            adv_outputs = student_model(adv_images)

            # Prediction
            _, adv_predicted = adv_outputs.max(1)
            _, natural_predicted = natural_outputs.max(1)
            
            # Correct top 1
            benign_correct += natural_predicted.eq(targets).sum().item()
            adv_correct += adv_predicted.eq(targets).sum().item()

            iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0)))
    
    # Adversarial and Clean Accuracy 
    benign_val_accuracy = 100.0 * benign_correct / total
    adv_val_accuracy = 100.0 * adv_correct / total 
    
    # Logs
    print('\nTotal benign test accuarcy:', benign_val_accuracy)
    print('Total adversarial test Accuarcy:', adv_val_accuracy)
    
    # Graph
    writer.add_scalar("Natural test accuracy: " + model_name, benign_val_accuracy, epoch)
    writer.add_scalar("Adversarial test accuracy: " + model_name, adv_val_accuracy, epoch)
    
    # Save checkpoint
    state = {
        'epoch' : epoch,
        'net': student_model.state_dict(),
        'optim' : optimizer.state_dict()
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/' + f'{training_name}--{model_name}.{version}.pth')
    if benign_val_accuracy > best_loss:
        print(f'Model saved: f{benign_val_accuracy}')
        torch.save(state, './trained_model/' + f'best_{training_name}_{model_name}_{version}_epoch{epoch}.pth')
    print('Model Saved!')
    return benign_val_accuracy, adv_val_accuracy

In [177]:
def main():
    learning_rate = lr
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate, weight_decay=0.0002)
    model_path = f'./checkpoint/{training_name}--{model_name}.{version}.pth'
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=lr_threshold, factor=lr_factor)
    if os.path.exists(model_path):
        # Load the saved model and optimizer state
        checkpoint = torch.load(model_path)
        student_model.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optim'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"=> Loaded checkpoint '{model_path}' (epoch {start_epoch})")
    else:
        start_epoch = 0
        print(f"=> No checkpoint found at '{model_path}'. Starting training from scratch.")


    for epoch in range(start_epoch, epochs):
        train_loss = train(epoch, optimizer)
        benign_val_accuracy , _ = test(epoch, optimizer)
        scheduler.step(metrics=benign_val_accuracy, epoch=epoch)
        scheduler.print_lr(True, student_model.parameters(), learning_rate, epoch)

In [178]:
if __name__ == '__main__':
    main()

=> No checkpoint found at './checkpoint/ARD_Alpha=0.7_Temperature=5--Mobilenetv3Small.v4.pth'. Starting training from scratch.


                                                          


Total adversarial train accuracy: 31.556831710904124
Total adversarial train loss: 256.04753854870796

[ Test epoch: 0 ]





Total benign test accuarcy: 34.306569343065696
Total adversarial test Accuarcy: 21.224031443009544
Model saved: f34.306569343065696
Model Saved!
Epoch 00000: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFD80> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 39.071363540864205
Total adversarial train loss: 232.23380780220032

[ Test epoch: 1 ]


                                                        


Total benign test accuarcy: 40.53902302077485
Total adversarial test Accuarcy: 26.445816956765864
Model saved: f40.53902302077485
Model Saved!
Epoch 00001: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFD80> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 41.80965053575181
Total adversarial train loss: 223.2721776664257

[ Test epoch: 2 ]


                                                        


Total benign test accuarcy: 41.661987647389104
Total adversarial test Accuarcy: 29.084783829309377
Model saved: f41.661987647389104
Model Saved!
Epoch 00002: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FF4C0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 43.448420757756146
Total adversarial train loss: 215.19185507297516

[ Test epoch: 3 ]


                                                        


Total benign test accuarcy: 45.87310499719259
Total adversarial test Accuarcy: 32.6221224031443
Model saved: f45.87310499719259
Model Saved!
Epoch 00003: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFAE0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 44.87008894180265
Total adversarial train loss: 210.01050329208374

[ Test epoch: 4 ]


                                                        


Total benign test accuarcy: 45.98540145985402
Total adversarial test Accuarcy: 31.723750701852893
Model saved: f45.98540145985402
Model Saved!
Epoch 00004: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFAE0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 46.36179004131942
Total adversarial train loss: 203.4052678346634

[ Test epoch: 5 ]


                                                        


Total benign test accuarcy: 49.3542953396968
Total adversarial test Accuarcy: 35.03649635036496
Model saved: f49.3542953396968
Model Saved!
Epoch 00005: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFD80> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 47.216191610056725
Total adversarial train loss: 201.176798671484

[ Test epoch: 6 ]


                                                       


Total benign test accuarcy: 47.95058955642897
Total adversarial test Accuarcy: 37.675463222908476
Model saved: f47.95058955642897
Model Saved!
Epoch 00006: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FF920> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 48.20365571818755
Total adversarial train loss: 196.11781778931618

[ Test epoch: 7 ]


                                                        


Total benign test accuarcy: 49.466591802358224
Total adversarial test Accuarcy: 36.94553621560921
Model saved: f49.466591802358224
Model Saved!
Epoch 00007: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FF920> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 48.16163596890539
Total adversarial train loss: 191.96959222853184

[ Test epoch: 8 ]


                                                        


Total benign test accuarcy: 50.25266704098821
Total adversarial test Accuarcy: 37.113980909601345
Model saved: f50.25266704098821
Model Saved!
Epoch 00008: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFA00> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 49.31017578261783
Total adversarial train loss: 188.92543596029282

[ Test epoch: 9 ]


                                                       


Total benign test accuarcy: 48.568220101066814
Total adversarial test Accuarcy: 38.6299831555306
Model saved: f48.568220101066814
Model Saved!
Epoch 00009: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFAE0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 50.01050493732054
Total adversarial train loss: 187.40328082442284

[ Test epoch: 10 ]


                                                       


Total benign test accuarcy: 52.44244806288602
Total adversarial test Accuarcy: 37.7877596855699
Model saved: f52.44244806288602
Model Saved!
Epoch 00010: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFAE0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 50.73884725821136
Total adversarial train loss: 183.38335034251213

[ Test epoch: 11 ]


                                                       


Total benign test accuarcy: 50.870297585626055
Total adversarial test Accuarcy: 38.12464907355418
Model saved: f50.870297585626055
Model Saved!
Epoch 00011: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FF4C0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 50.92093283843406
Total adversarial train loss: 182.70768719911575

[ Test epoch: 12 ]


                                                       


Total benign test accuarcy: 50.870297585626055
Total adversarial test Accuarcy: 40.651319483436275
Model saved: f50.870297585626055
Model Saved!
Epoch 00012: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FF4C0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 51.117025001750825
Total adversarial train loss: 179.14206394553185

[ Test epoch: 13 ]


                                                       


Total benign test accuarcy: 53.621560920830994
Total adversarial test Accuarcy: 41.04435710275126
Model saved: f53.621560920830994
Model Saved!
Epoch 00013: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFAE0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 52.25856152391624
Total adversarial train loss: 176.93881058692932

[ Test epoch: 14 ]


                                                        


Total benign test accuarcy: 53.45311622683885
Total adversarial test Accuarcy: 40.7636159460977
Model saved: f53.45311622683885
Model Saved!
Epoch 00014: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFD80> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 52.16051544225786
Total adversarial train loss: 175.99961787462234

[ Test epoch: 15 ]


                                                        


Total benign test accuarcy: 53.340819764177425
Total adversarial test Accuarcy: 38.012352610892755
Model saved: f53.340819764177425
Model Saved!
Epoch 00015: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFD80> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 52.87485118005463
Total adversarial train loss: 174.10306030511856

[ Test epoch: 16 ]


                                                       


Total benign test accuarcy: 54.01459854014598
Total adversarial test Accuarcy: 41.15665356541269
Model saved: f54.01459854014598
Model Saved!
Epoch 00016: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFA00> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 53.25302892359409
Total adversarial train loss: 172.7299683690071

[ Test epoch: 17 ]


                                                        


Total benign test accuarcy: 53.84615384615385
Total adversarial test Accuarcy: 42.27961819202695
Model saved: f53.84615384615385
Model Saved!
Epoch 00017: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFD80> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 53.239022340500036
Total adversarial train loss: 171.4578753709793

[ Test epoch: 18 ]


                                                        


Total benign test accuarcy: 55.75519371139809
Total adversarial test Accuarcy: 39.58450308815272
Model saved: f55.75519371139809
Model Saved!
Epoch 00018: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FF4C0> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 54.2895160725541
Total adversarial train loss: 170.31835132837296

[ Test epoch: 19 ]


                                                       


Total benign test accuarcy: 53.79000561482313
Total adversarial test Accuarcy: 40.707467714766985
Model saved: f53.79000561482313
Model Saved!
Epoch 00019: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFA00> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 54.12844036697248
Total adversarial train loss: 170.14632646739483

[ Test epoch: 20 ]


                                                        


Total benign test accuarcy: 53.06007860752386
Total adversarial test Accuarcy: 40.14598540145985
Model saved: f53.06007860752386
Model Saved!
Epoch 00020: adjusting learning rate of group <generator object Module.parameters at 0x000002A1888FFA00> to 1.7000e-04.


                                                          


Total adversarial train accuracy: 54.57665102598221
Total adversarial train loss: 166.44353500008583

[ Test epoch: 21 ]


0.34375:  52% 29/56 [00:30<00:29,  1.08s/it]