In [15]:
# Load dependencies

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.models as models
import torch.optim as optim 
import os
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchattacks
from torchmetrics.functional.image import peak_signal_noise_ratio, structural_similarity_index_measure

In [16]:
# Setup CUDA Device

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device using: {device}')

Device using: cuda


In [17]:
model_name = "Mobilenetv3Small"
version = "v1"
training_name = "ARD_0.7_Temp=5"

height = 224
num_classes = 7
epochs = 100
lr = 0.0001
lr_factor = 0.1
lr_threshold = 6
weight_decay = 0.0002
batch_size = 32

# Attack hyperparameters 
epsilon = 8.0 / 255
alpha = 2.0 / 255
steps = 4

# Knowledge Distillation hyperparameters
temp = 5.0
alpha = 0.7

In [18]:
# Graph writer initialize for data visualization

writer = SummaryWriter("runs/trashbox/" + f'{training_name}--{model_name}.{version}')

In [19]:
preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
test_preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.ToTensor()
])

In [20]:
# Training and Validation dataset

print('===>> Preparing data...')
trash_train_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/train', transform=preprocessing)
trash_train_loader = torch.utils.data.DataLoader(dataset=trash_train_dataset, shuffle=True, batch_size=batch_size)
trash_val_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/val', transform=test_preprocessing)
trash_val_loader = torch.utils.data.DataLoader(dataset=trash_val_dataset, shuffle=True, batch_size=batch_size)

===>> Preparing data...


In [21]:
print('====>> Setting up teacher model...')

# Initialize architecture with modified output features
teacher_model = models.googlenet(weights=models.GoogLeNet_Weights.DEFAULT)
infeatures = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(infeatures, num_classes, True)

# Load saved weights 
checkpoint = torch.load('./best_trained_models/best_AT--Googlenet.v1_epoch98.pth')

if 'module' in list(checkpoint['net'].keys())[0]:
    new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['net'].items()}
    teacher_model.load_state_dict(new_state_dict)
else:
    teacher_model.load_state_dict(checkpoint['net'])

teacher_model = teacher_model.to(device)

for param in teacher_model.parameters():
    param.requires_grad = False

teacher_model.eval()

====>> Setting up teacher model...


GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track

In [22]:
# Setup student model 

print('====>> Setting up student model...')
student_model = models.mobilenet_v3_small(weights=None)

# for param in student_model.parameters():
#     param.requires_grad = False

student_model.classifier = nn.Sequential(
    nn.Linear(in_features=576, out_features=1024, bias=True),
    nn.Hardswish(),
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1024, out_features=num_classes, bias=True)
)

# Load saved weights 
# checkpoint = torch.load('./best_trained_models/best_NORMAL--Mobilenetv3Small.v1_epoch40.pth')

# if 'module' in list(checkpoint['net'].keys())[0]:
#     new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['net'].items()}
#     student_model.load_state_dict(new_state_dict)
# else:
#     student_model.load_state_dict(checkpoint['net'])

student_model = student_model.to(device)

====>> Setting up student model...


In [23]:
# Initialize adversarial attack for generating adversarial samples

attack = torchattacks.PGD(student_model, eps=epsilon, alpha=alpha, steps=steps)

In [24]:
# Setup Loss functions

XENT_loss = nn.CrossEntropyLoss()
KL_loss = nn.KLDivLoss()

In [25]:
# Train Loop

def train(epoch, optimizer):
    train_loss = 0
    correct = 0
    total = 0
    adv_correct = 0
    student_model.train()
    total_ssim = 0
    total_psnr = 0

    iterator = tqdm(trash_train_loader, ncols=0, leave=False)
    for i, (inputs, targets)in enumerate(iterator):
        inputs, targets = inputs.to(device),targets.to(device)

        optimizer.zero_grad()
        adv_image = attack(inputs, targets)

        # Soft labels
        teacher_output = teacher_model(inputs)
        student_output = student_model(inputs)
        adv_student_output = student_model(adv_image)
        
        # ARD loss function formula
        loss =  alpha* temp* temp*KL_loss(F.log_softmax(adv_student_output/ temp, dim=1),F.softmax(teacher_output/ temp, dim=1))+(1.0- alpha)*XENT_loss(student_output, targets)
        loss.backward()
        optimizer.step()

        # Measure loss
        train_loss += loss.item()
        iterator.set_description(str(loss.item()))      

        # Get total 
        total += targets.size(0)

        # SSIM and PSNR 
        total_psnr += peak_signal_noise_ratio(adv_image, inputs, reduction='sum')
        total_ssim += structural_similarity_index_measure(adv_image, inputs, reduction='sum')
        
        # Measure clean and adversarial accuracy 
        _, predicted = student_output.max(1)
        correct += predicted.eq(targets).sum().item()
        _, adv_predicted = adv_student_output.max(1)
        adv_correct += adv_predicted.eq(targets).sum().item()
    
    # SSIM and PSNR Average
    avg_ssim = total_ssim / total
    avg_psnr = total_psnr / total

    writer.add_scalar("Average SSIM: " + model_name, avg_ssim, epoch)
    writer.add_scalar("Average PSNR: " + model_name, avg_psnr, epoch)

    training_loss = train_loss  / total
    train_adv_accuracy = 100.0 * correct / total
    adv_train_adv_accuracy = 100.0 * adv_correct / total
    
    print('\nTotal adversarial train accuracy:', 100. * correct / total)
    print('Total adversarial train loss:', train_loss)
    
    # Write graph over epoch
    writer.add_scalar('Train loss: ' + model_name, training_loss, epoch)
    writer.add_scalar('Train accuracy: ' + model_name, train_adv_accuracy, epoch)
    writer.add_scalar('Adversarial Train accuracy: ' + model_name, adv_train_adv_accuracy, epoch)

    return training_loss

In [26]:
# Test function
best_acc = float(0)

def test(epoch, optimizer):
    global best_acc
    print('\n[ Test epoch: %d ]' % epoch)
    student_model.eval()
    benign_correct = 0
    adv_correct = 0
    total = 0
    with torch.no_grad():
        iterator = tqdm(trash_val_loader, ncols=0, leave=False)
        for i, (inputs, targets) in enumerate(iterator):
            inputs, targets = inputs.to(device), targets.to(device)
            
            total += targets.size(0)

            with torch.enable_grad():
                adv_images = attack(inputs, targets)
            
            # Output
            natural_outputs = student_model(inputs)
            adv_outputs = student_model(adv_images)

            # Prediction
            _, adv_predicted = adv_outputs.max(1)
            _, natural_predicted = natural_outputs.max(1)
            
            # Correct top 1
            benign_correct += natural_predicted.eq(targets).sum().item()
            adv_correct += adv_predicted.eq(targets).sum().item()

            iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0)))
    
    # Adversarial and Clean Accuracy 
    benign_val_accuracy = 100.0 * benign_correct / total
    adv_val_accuracy = 100.0 * adv_correct / total 
    
    # Logs
    print('\nTotal benign test accuarcy:', benign_val_accuracy)
    print('Total adversarial test Accuarcy:', adv_val_accuracy)
    
    # Graph
    writer.add_scalar("Natural test accuracy: " + model_name, benign_val_accuracy, epoch)
    writer.add_scalar("Adversarial test accuracy: " + model_name, adv_val_accuracy, epoch)
    
    # Save checkpoint
    state = {
        'epoch' : epoch,
        'net': student_model.state_dict(),
        'optim' : optimizer.state_dict()
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/' + f'{training_name}--{model_name}.{version}.pth')
    if benign_val_accuracy > best_acc:
        print(f'Model saved: f{benign_val_accuracy}')
        best_acc = benign_val_accuracy
        torch.save(state, './trained_model/' + f'best_{training_name}_{model_name}_{version}_epoch{epoch}.pth')
    print('Model Saved!')
    return benign_val_accuracy, adv_val_accuracy

In [27]:
def main():
    learning_rate = lr
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate, weight_decay=0.0002)
    model_path = f'./checkpoint/{training_name}--{model_name}.{version}.pth'
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=lr_threshold, factor=lr_factor)
    if os.path.exists(model_path):
        # Load the saved model and optimizer state
        checkpoint = torch.load(model_path)
        student_model.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optim'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"=> Loaded checkpoint '{model_path}' (epoch {start_epoch})")
    else:
        start_epoch = 0
        print(f"=> No checkpoint found at '{model_path}'. Starting training from scratch.")


    for epoch in range(start_epoch, epochs):
        train_loss = train(epoch, optimizer)
        benign_val_accuracy , _ = test(epoch, optimizer)
        scheduler.step(metrics=train_loss, epoch=epoch)
        scheduler.print_lr(True, student_model.parameters(), learning_rate, epoch)

In [28]:
if __name__ == '__main__':
    main()

=> Loaded checkpoint './checkpoint/ARD_0.7_Temp=5--Mobilenetv3Small.v1.pth' (epoch 75)


                                                          


Total adversarial train accuracy: 65.06057847188178
Total adversarial train loss: 173.72823160886765

[ Test epoch: 75 ]


                                                       


Total benign test accuarcy: 59.292532285233015
Total adversarial test Accuarcy: 32.45367770915216
Model saved: f59.292532285233015
Model Saved!
Epoch 00075: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 64.46529869038449
Total adversarial train loss: 173.86176976561546

[ Test epoch: 76 ]


                                                        


Total benign test accuarcy: 60.47164514317799
Total adversarial test Accuarcy: 36.10331274564851
Model saved: f60.47164514317799
Model Saved!
Epoch 00076: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 64.83647314237692
Total adversarial train loss: 172.60813653469086

[ Test epoch: 77 ]


                                                       


Total benign test accuarcy: 59.74171813587872
Total adversarial test Accuarcy: 34.7557551937114
Model Saved!
Epoch 00077: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133AE0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 65.34071013376287
Total adversarial train loss: 172.51842385530472

[ Test epoch: 78 ]


                                                        


Total benign test accuarcy: 58.394160583941606
Total adversarial test Accuarcy: 34.19427288040427
Model Saved!
Epoch 00078: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133AE0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 65.64885496183206
Total adversarial train loss: 169.20270401239395

[ Test epoch: 79 ]


                                                        


Total benign test accuarcy: 59.629421673217294
Total adversarial test Accuarcy: 35.09264458169568
Model Saved!
Epoch 00079: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133AE0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 65.64185167028504
Total adversarial train loss: 170.5258210003376

[ Test epoch: 80 ]


                                                        


Total benign test accuarcy: 61.706906232453676
Total adversarial test Accuarcy: 37.057832678270636
Model saved: f61.706906232453676
Model Saved!
Epoch 00080: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133BC0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 65.13761467889908
Total adversarial train loss: 169.595505297184

[ Test epoch: 81 ]


                                                        


Total benign test accuarcy: 60.86468276249298
Total adversarial test Accuarcy: 35.204941044357106
Model Saved!
Epoch 00081: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133140> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 66.24413474332937
Total adversarial train loss: 168.61082424223423

[ Test epoch: 82 ]


                                                        


Total benign test accuarcy: 60.078607523863
Total adversarial test Accuarcy: 33.29590117911286
Model Saved!
Epoch 00082: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133140> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 66.49625323902234
Total adversarial train loss: 166.79673896729946

[ Test epoch: 83 ]


                                                        


Total benign test accuarcy: 61.48231330713082
Total adversarial test Accuarcy: 34.86805165637283
Model Saved!
Epoch 00083: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133AE0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 66.32817424189369
Total adversarial train loss: 167.26428523659706

[ Test epoch: 84 ]


                                                        


Total benign test accuarcy: 60.41549691184728
Total adversarial test Accuarcy: 34.306569343065696
Model Saved!
Epoch 00084: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133760> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 66.18810841095315
Total adversarial train loss: 166.46629737317562

[ Test epoch: 85 ]


                                                        


Total benign test accuarcy: 59.12408759124087
Total adversarial test Accuarcy: 37.00168444693992
Model Saved!
Epoch 00085: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133760> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 66.39120386581693
Total adversarial train loss: 165.8582319021225

[ Test epoch: 86 ]


                                                        


Total benign test accuarcy: 61.706906232453676
Total adversarial test Accuarcy: 35.82257158899495
Model Saved!
Epoch 00086: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133AE0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.44870088941802
Total adversarial train loss: 163.87282022833824

[ Test epoch: 87 ]


                                                       


Total benign test accuarcy: 60.30320044918585
Total adversarial test Accuarcy: 36.60864682762493
Model Saved!
Epoch 00087: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133BC0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 66.46824007283423
Total adversarial train loss: 164.537299990654

[ Test epoch: 88 ]


                                                        


Total benign test accuarcy: 60.30320044918585
Total adversarial test Accuarcy: 36.27175743964065
Model Saved!
Epoch 00088: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133AE0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.28062189228937
Total adversarial train loss: 164.63056966662407

[ Test epoch: 89 ]


                                                        


Total benign test accuarcy: 59.40482874789444
Total adversarial test Accuarcy: 37.3385738349242
Model Saved!
Epoch 00089: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133AE0> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 66.79039148399748
Total adversarial train loss: 164.00206357240677

[ Test epoch: 90 ]


                                                       


Total benign test accuarcy: 59.74171813587872
Total adversarial test Accuarcy: 35.93486805165637
Model Saved!
Epoch 00090: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133760> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.39267455704181
Total adversarial train loss: 162.35554759204388

[ Test epoch: 91 ]


                                                        


Total benign test accuarcy: 59.40482874789444
Total adversarial test Accuarcy: 35.99101628298708
Model Saved!
Epoch 00091: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.43469430632398
Total adversarial train loss: 161.01713724434376

[ Test epoch: 92 ]


                                                        


Total benign test accuarcy: 61.03312745648512
Total adversarial test Accuarcy: 35.261089275687816
Model Saved!
Epoch 00092: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133760> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.4136844316829
Total adversarial train loss: 161.74532687664032

[ Test epoch: 93 ]


                                                       


Total benign test accuarcy: 59.06793935991016
Total adversarial test Accuarcy: 35.317237507018525
Model Saved!
Epoch 00093: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133140> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.53974367952938
Total adversarial train loss: 162.15012370049953

[ Test epoch: 94 ]


                                                        


Total benign test accuarcy: 61.53846153846154
Total adversarial test Accuarcy: 33.40819764177429
Model Saved!
Epoch 00094: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.40668114013586
Total adversarial train loss: 159.75850777328014

[ Test epoch: 95 ]


                                                       


Total benign test accuarcy: 58.05727119595733
Total adversarial test Accuarcy: 38.23694553621561
Model Saved!
Epoch 00095: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133760> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.700819385111
Total adversarial train loss: 160.87530943751335

[ Test epoch: 96 ]


                                                       


Total benign test accuarcy: 61.089275687815835
Total adversarial test Accuarcy: 35.99101628298708
Model Saved!
Epoch 00096: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 68.10000700329155
Total adversarial train loss: 159.9154079258442

[ Test epoch: 97 ]


                                                        


Total benign test accuarcy: 60.02245929253228
Total adversarial test Accuarcy: 37.28242560359349
Model Saved!
Epoch 00097: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 67.97394775544505
Total adversarial train loss: 159.7767633497715

[ Test epoch: 98 ]


                                                       


Total benign test accuarcy: 59.57327344188658
Total adversarial test Accuarcy: 36.552498596294214
Model Saved!
Epoch 00098: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


                                                          


Total adversarial train accuracy: 68.1910497934029
Total adversarial train loss: 158.32385815680027

[ Test epoch: 99 ]


                                                       


Total benign test accuarcy: 58.338012352610896
Total adversarial test Accuarcy: 36.83323975294778
Model Saved!
Epoch 00099: adjusting learning rate of group <generator object Module.parameters at 0x0000021216133680> to 1.0000e-04.


