In [115]:
# Load dependencies

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.models as models
import torch.optim as optim
import os
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchattacks
from utils.psnr import PSNR, SSIM

In [116]:
# Setup CUDA Device

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device using: {device}')

Device using: cuda


In [117]:
model_name = "Mobilenetv3Small"
version = "v3"
training_name = "ARD_PreTrain_Frozen"

height = 224
num_classes = 7
epochs = 100
lr = 0.0001
lr_factor = 0.1
lr_threshold = 6
weight_decay = 0.0002
batch_size = 32

# Attack hyperparameters 
epsilon = 8.0 / 255
alpha = 2.0 / 255
steps = 10

# Knowledge Distillation hyperparameters
temp = 20.0
alpha = 1.0

In [118]:
# Graph writer initialize for data visualization

writer = SummaryWriter("runs/trashbox/" + f'{training_name}--{model_name}.{version}')

In [119]:
preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
test_preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.ToTensor()
])

In [120]:
# Training and Validation dataset

print('===>> Preparing data...')
trash_train_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/train', transform=preprocessing)
trash_train_loader = torch.utils.data.DataLoader(dataset=trash_train_dataset, shuffle=True, batch_size=batch_size)
trash_val_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/val', transform=test_preprocessing)
trash_val_loader = torch.utils.data.DataLoader(dataset=trash_val_dataset, shuffle=True, batch_size=batch_size)

===>> Preparing data...


In [121]:
print('====>> Setting up teacher model...')

# Initialize architecture with modified output features
teacher_model = models.googlenet(weights=models.GoogLeNet_Weights.DEFAULT)
infeatures = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(infeatures, num_classes, True)

# Load saved weights 
checkpoint = torch.load('./best_trained_models/best_AT--Googlenet.v1_epoch98.pth')

if 'module' in list(checkpoint['net'].keys())[0]:
    new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['net'].items()}
    teacher_model.load_state_dict(new_state_dict)
else:
    teacher_model.load_state_dict(checkpoint['net'])

teacher_model = teacher_model.to(device)

for param in teacher_model.parameters():
    param.requires_grad = False

teacher_model.eval()

====>> Setting up teacher model...


GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track

In [122]:
# Setup student model 

print('====>> Setting up student model...')
student_model = models.mobilenet_v3_small(weights=None)

for param in teacher_model.parameters():
    param.requires_grad = False

student_model.classifier = nn.Sequential(
    nn.Linear(in_features=576, out_features=1024, bias=True),
    nn.Hardswish(),
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1024, out_features=num_classes, bias=True)
)

# Load saved weights 
checkpoint = torch.load('./best_trained_models/best_NORMAL--Mobilenetv3Small.v1_epoch40.pth')

if 'module' in list(checkpoint['net'].keys())[0]:
    new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['net'].items()}
    student_model.load_state_dict(new_state_dict)
else:
    student_model.load_state_dict(checkpoint['net'])



student_model = student_model.to(device)

====>> Setting up student model...


In [123]:
# Initialize adversarial attack for generating adversarial samples

attack = torchattacks.TPGD(student_model, eps=epsilon, alpha=alpha, steps=steps)

In [124]:
# Setup Loss functions

XENT_loss = nn.CrossEntropyLoss()
KL_loss = nn.KLDivLoss()

In [125]:
# Train Loop

def train(epoch, optimizer):
    train_loss = 0
    correct = 0
    total = 0
    adv_correct = 0
    student_model.train()
    total_ssim = 0
    total_psnr = 0

    iterator = tqdm(trash_train_loader, ncols=0, leave=False)
    for i, (inputs, targets)in enumerate(iterator):
        inputs, targets = inputs.to(device),targets.to(device)

        optimizer.zero_grad()
        adv_image = attack(inputs, targets)

        # Soft labels
        teacher_output = teacher_model(inputs)
        student_output = student_model(inputs)
        adv_student_output = student_model(adv_image)
        
        # ARD loss function formula
        loss =  alpha* temp* temp*KL_loss(F.log_softmax(adv_student_output/ temp, dim=1),F.softmax(teacher_output/ temp, dim=1))+(1.0- alpha)*XENT_loss(student_output, targets)
        loss.backward()
        optimizer.step()

        # Measure loss
        train_loss += loss.item()
        iterator.set_description(str(loss.item()))      

        # Get total 
        total += targets.size(0)

        # SSIM and PSNR 
        print(inputs.shape)
        psnr = PSNR()
        ssim = SSIM()
        total_psnr += psnr(inputs, adv_image)
        total_ssim += ssim(inputs, adv_image)
        
        # Measure clean and adversarial accuracy 
        _, predicted = student_output.max(1)
        correct += predicted.eq(targets).sum().item()
        _, adv_predicted = adv_student_output.max(1)
        adv_correct += adv_predicted.eq(targets).sum().item()
    
    # SSIM and PSNR Average
    avg_ssim = total_ssim / total
    avg_psnr = total_psnr / total

    writer.add_scalar("Average SSIM " + model_name, avg_ssim, epoch)
    writer.add_scalar("Average PSNR " + model_name, avg_psnr, epoch)

    training_loss = train_loss  / total
    train_adv_accuracy = 100.0 * correct / total
    adv_train_adv_accuracy = 100.0 * adv_correct / total
    
    print('\nTotal adversarial train accuracy:', 100. * correct / total)
    print('Total adversarial train loss:', train_loss)
    
    # Write graph over epoch
    writer.add_scalar('Train loss: ' + model_name, training_loss, epoch)
    writer.add_scalar('Train accuracy: ' + model_name, train_adv_accuracy, epoch)
    writer.add_scalar('Adversarial Train accuracy: ' + model_name, adv_train_adv_accuracy, epoch)

In [126]:
# Test function
best_loss = float(0)

def test(epoch, optimizer):
    global best_loss
    print('\n[ Test epoch: %d ]' % epoch)
    student_model.eval()
    benign_correct = 0
    adv_correct = 0
    total = 0
    with torch.no_grad():
        iterator = tqdm(trash_val_loader, ncols=0, leave=False)
        for i, (inputs, targets) in enumerate(iterator):
            inputs, targets = inputs.to(device), targets.to(device)
            
            total += targets.size(0)

            with torch.enable_grad():
                adv_images = attack(inputs, targets)
            
            # Output
            natural_outputs = student_model(inputs)
            adv_outputs = student_model(adv_images)

            # Prediction
            _, adv_predicted = adv_outputs.max(1)
            _, natural_predicted = natural_outputs.max(1)
            
            # Correct top 1
            benign_correct += natural_predicted.eq(targets).sum().item()
            adv_correct += adv_predicted.eq(targets).sum().item()

            iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0)))
    
    # Adversarial and Clean Accuracy 
    benign_val_accuracy = 100.0 * benign_correct / total
    adv_val_accuracy = 100.0 * adv_correct / total 
    
    # Logs
    print('\nTotal benign test accuarcy:', benign_val_accuracy)
    print('Total adversarial test Accuarcy:', adv_val_accuracy)
    
    # Graph
    writer.add_scalar("Natural test accuracy: " + model_name, benign_val_accuracy, epoch)
    writer.add_scalar("Adversarial test accuracy: " + model_name, adv_val_accuracy, epoch)
    
    # Save checkpoint
    state = {
        'epoch' : epoch,
        'net': student_model.state_dict(),
        'optim' : optimizer.state_dict()
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/' + f'{training_name}--{model_name}.{version}.pth')
    if benign_val_accuracy > best_loss:
        print(f'Model saved: f{benign_val_accuracy}')
        torch.save(state, './trained_model/' + f'best_{training_name}_{model_name}_{version}_epoch{epoch}.pth')
    print('Model Saved!')
    return benign_val_accuracy, adv_val_accuracy

In [127]:
def main():
    learning_rate = lr
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate, weight_decay=0.0002)
    model_path = f'./checkpoint/{training_name}--{model_name}.{version}.pth'
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=lr_threshold, factor=lr_factor)
    if os.path.exists(model_path):
        # Load the saved model and optimizer state
        checkpoint = torch.load(model_path)
        student_model.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optim'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"=> Loaded checkpoint '{model_path}' (epoch {start_epoch})")
    else:
        start_epoch = 0
        print(f"=> No checkpoint found at '{model_path}'. Starting training from scratch.")


    for epoch in range(start_epoch, epochs):
        train_loss = train(epoch, optimizer)
        benign_val_accuracy , _ =test(epoch, optimizer)
        scheduler.step(metrics=benign_val_accuracy, epoch=epoch)
        scheduler.print_lr(True, student_model.parameters(), learning_rate, epoch)

In [128]:
if __name__ == '__main__':
    main()

=> No checkpoint found at './checkpoint/ARD_PreTrain_Frozen--Mobilenetv3Small.v3.pth'. Starting training from scratch.


                                               

torch.Size([32, 3, 224, 224])




ValueError: Wrong input image dimensions.