In [80]:
# Load dependencies

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.models as models
import torch.optim as optim
import os
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchattacks

In [81]:
# Setup CUDA Device

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device using: {device}')

Device using: cuda


In [82]:
model_name = "Mobilenetv3Small"
version = "v1"
training_name = "ARD_NotPretrainMobilenet"

height = 224
num_classes = 7
epochs = 100
lr = 0.0001
lr_factor = 0.1
lr_threshold = 6
weight_decay = 0.0002
batch_size = 32

# Attack hyperparameters 
epsilon = 8.0 / 255
alpha = 2.0 / 255
steps = 10

# Knowledge Distillation hyperparameters
temp = 20.0
alpha = 1.0

In [83]:
# Graph writer initialize for data visualization

writer = SummaryWriter("runs/trashbox/" + f'{training_name}--{model_name}.{version}')

In [84]:

# Initialize pre-trained weights - IMAGENET
# Get the preprocessing for the respective pre-trained model

preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=mean, std=std)
])
test_preprocessing = transforms.Compose([
    transforms.RandomResizedCrop((height, height)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=mean, std=std)
])

In [85]:
# Training and Validation dataset

print('===>> Preparing data...')
trash_train_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/train', transform=preprocessing)
trash_train_loader = torch.utils.data.DataLoader(dataset=trash_train_dataset, shuffle=True, batch_size=batch_size)
trash_val_dataset = torchvision.datasets.ImageFolder('dataset/trashbox/val', transform=test_preprocessing)
trash_val_loader = torch.utils.data.DataLoader(dataset=trash_val_dataset, shuffle=True, batch_size=batch_size)

===>> Preparing data...


In [86]:
# Setup model 

print('====>> Setting up teacher model...')
checkpoint = torch.load('./best_trained_models/best_AT--Googlenet.v1_epoch98.pth')
# model = local_models.xception(num_classes=num_classes).to(device)
teacher_model = models.googlenet(weights=models.GoogLeNet_Weights.DEFAULT)
infeatures = teacher_model.fc.in_features

for param in teacher_model.parameters():
    param.requires_grad = False



teacher_model.fc = nn.Linear(infeatures, num_classes, True)



if 'module' in list(checkpoint['net'].keys())[0]:
    new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['net'].items()}
    teacher_model.load_state_dict(new_state_dict)
else:
    teacher_model.load_state_dict(checkpoint['net'])

teacher_model = teacher_model.to(device)


====>> Setting up teacher model...


In [87]:
# Setup model 

print('====>> Setting up student model...')
# model = local_models.xception(num_classes=num_classes).to(device)
model = models.mobilenet_v3_small(weights=None)
model.classifier = nn.Sequential(
    nn.Linear(in_features=576, out_features=1024, bias=True),
    nn.Hardswish(),
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1024, out_features=num_classes, bias=True)
)
model = model.to(device)


====>> Setting up student model...


In [88]:
# Setup Adversarial attack 

attack = torchattacks.PGD(model=model, eps=epsilon, alpha=alpha, steps=steps)

class AttackPGD(nn.Module):
    def __init__(self, basic_net, config):
        super(AttackPGD, self).__init__()
        self.basic_net = basic_net
        self.step_size = config['step_size']
        self.epsilon = config['epsilon']
        self.num_steps = config['num_steps']

    def forward(self, inputs, targets):
        x = inputs.detach()
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
        for i in range(self.num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                loss = F.cross_entropy(self.basic_net(x), targets, size_average=False)
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + self.step_size*torch.sign(grad.detach())
            x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
            x = torch.clamp(x, 0.0, 1.0)
        return self.basic_net(x), x

In [89]:
config = {
    'step_size' : alpha,
    'epsilon' : epsilon, 
    'num_steps' : steps
}
net = AttackPGD(model, config)

In [90]:
if device == 'cuda':
    cudnn.benchmark = True

In [91]:
# Setup Loss function

XENT_loss = nn.CrossEntropyLoss()
KL_loss = nn.KLDivLoss()

In [92]:
# Train Loop

def train(epoch, optimizer):
    train_loss = 0
    correct = 0
    total = 0
    adv_correct = 0
    net.train()
    iterator = tqdm(trash_train_loader, ncols=0, leave=False)
    for i, (inputs, targets)in enumerate(iterator):
        inputs, targets = inputs.to(device),targets.to(device)

        optimizer.zero_grad()
        # adv_image = attack(inputs, targets)
        output,per_inputs = net(inputs, targets)
        teacher_output = teacher_model(inputs)
        basic_output = model(inputs)
        # adv_basic_output = model(adv_image)
        # _, predicted = adv_basic_output.max(1)
        
        loss =  alpha* temp* temp*KL_loss(F.log_softmax(output/ temp, dim=1),F.softmax(teacher_output/ temp, dim=1))+(1.0- alpha)*XENT_loss(basic_output, targets)
        loss.backward()
    
        optimizer.step()
        train_loss += loss.item()
        iterator.set_description(str(loss.item()))       
        total += targets.size(0)
        
        _, predicted = basic_output.max(1)
        correct += predicted.eq(targets).sum().item()
        _, adv_predicted = output.max(1)
        adv_correct += adv_predicted.eq(targets).sum().item()
        
    validation_loss = train_loss  / len(trash_val_dataset)
    train_adv_accuracy = 100.0 * correct / total
    adv_train_adv_accuracy = 100.0 * adv_correct / total
    
    print('\nTotal adversarial train accuarcy:', 100. * correct / total)
    print('Total adversarial train loss:', train_loss)
    
    writer.add_scalar('Train loss: ' + model_name, validation_loss, epoch)
    writer.add_scalar('Train accuracy: ' + model_name, train_adv_accuracy, epoch)
    writer.add_scalar('Adversarial Train accuracy: ' + model_name, adv_train_adv_accuracy, epoch)

In [93]:
# Test function
best_loss = float(0)

def test(epoch, optimizer):
    global best_loss
    print('\n[ Test epoch: %d ]' % epoch)
    net.eval()
    benign_loss = 0
    adv_loss = 0
    benign_correct = 0
    adv_correct = 0
    total = 0
    with torch.no_grad():
        iterator = tqdm(trash_val_loader, ncols=0, leave=False)
        for i, (inputs, targets) in enumerate(iterator):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
        
            adv_outputs, pert_inputs = net(inputs, targets)
            natural_outputs = model(inputs)
            _, adv_predicted = adv_outputs.max(1)
            _, natural_predicted = natural_outputs.max(1)
            
            benign_correct += natural_predicted.eq(targets).sum().item()
            adv_correct += adv_predicted.eq(targets).sum().item()

            iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0)))
    
    benign_val_accuracy = 100.0 * benign_correct / total
    adv_val_accuracy = 100.0 * adv_correct / total 
    
    print('\nTotal benign test accuarcy:', benign_val_accuracy)
    print('Total adversarial test Accuarcy:', adv_val_accuracy)
    
    # Graph
    writer.add_scalar("Natural test accuracy: " + model_name, benign_val_accuracy, epoch)
    writer.add_scalar("Adversarial test accuracy: " + model_name, adv_val_accuracy, epoch)
    
    # Save checkpoint
    state = {
        'epoch' : epoch,
        'net': model.state_dict(),
        'second_net' : net.state_dict(),
        'optim' : optimizer.state_dict()
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/' + f'{training_name}--{model_name}.{version}.pth')
    if benign_val_accuracy > best_loss:
        print(f'Model saved: f{benign_val_accuracy}')
        torch.save(state, './trained_model/' + f'best_{training_name}_{model_name}_{version}_epoch{epoch}.pth')
    print('Model Saved!')
    return benign_val_accuracy, adv_val_accuracy

In [94]:
def main():
    learning_rate = lr
    optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.0002)
    model_path = f'./checkpoint/{training_name}--{model_name}.{version}.pth'
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=lr_threshold, factor=lr_factor)
    if os.path.exists(model_path):
        # Load the saved model and optimizer state
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optim'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"=> Loaded checkpoint '{model_path}' (epoch {start_epoch})")
    else:
        start_epoch = 0
        print(f"=> No checkpoint found at '{model_path}'. Starting training from scratch.")


    for epoch in range(start_epoch, epochs):
        train_loss = train(epoch, optimizer)
        benign_val_accuracy , _ =test(epoch, optimizer)
        scheduler.step(metrics=benign_val_accuracy, epoch=epoch)
        scheduler.print_lr(True, model.parameters(), learning_rate, epoch)

In [95]:
if __name__ == '__main__':
    main()

=> Loaded checkpoint './checkpoint/ARD_NotPretrainMobilenet--Mobilenetv3Small.v1.pth' (epoch 89)


                                                           


Total adversarial train accuarcy: 54.828769521675184
Total adversarial train loss: 15.318321684375405

[ Test epoch: 89 ]





Total benign test accuarcy: 57.495788882650196
Total adversarial test Accuarcy: 70.01684446939922
Model saved: f57.495788882650196
Model Saved!
Epoch 00089: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7920> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 55.16492751593249
Total adversarial train loss: 15.098812878131866

[ Test epoch: 90 ]


                                                       


Total benign test accuarcy: 55.3621560920831
Total adversarial test Accuarcy: 69.79225154407636
Model saved: f55.3621560920831
Model Saved!
Epoch 00090: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7920> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 55.311996638420055
Total adversarial train loss: 15.196843275800347

[ Test epoch: 91 ]


                                                       


Total benign test accuarcy: 56.26052779337451
Total adversarial test Accuarcy: 69.39921392476137
Model saved: f56.26052779337451
Model Saved!
Epoch 00091: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7920> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 54.6256740668114
Total adversarial train loss: 15.176218964159489

[ Test epoch: 92 ]


                                                       


Total benign test accuarcy: 56.99045480067378
Total adversarial test Accuarcy: 69.56765861875351
Model saved: f56.99045480067378
Model Saved!
Epoch 00092: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7840> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 54.919812311786536
Total adversarial train loss: 15.429466512054205

[ Test epoch: 93 ]


                                                       


Total benign test accuarcy: 57.776530039303765
Total adversarial test Accuarcy: 68.05165637282425
Model saved: f57.776530039303765
Model Saved!
Epoch 00093: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7A00> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 55.1929406821206
Total adversarial train loss: 15.160516142845154

[ Test epoch: 94 ]


                                                       


Total benign test accuarcy: 55.24985962942167
Total adversarial test Accuarcy: 68.16395283548569
Model saved: f55.24985962942167
Model Saved!
Epoch 00094: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7840> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 55.55711184256601
Total adversarial train loss: 14.971419904381037

[ Test epoch: 95 ]


                                                       


Total benign test accuarcy: 56.541268950028076
Total adversarial test Accuarcy: 69.90454800673778
Model saved: f56.541268950028076
Model Saved!
Epoch 00095: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7920> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 54.71671685692275
Total adversarial train loss: 15.128390615805984

[ Test epoch: 96 ]


                                                       


Total benign test accuarcy: 55.8113419427288
Total adversarial test Accuarcy: 69.45536215609208
Model saved: f55.8113419427288
Model Saved!
Epoch 00096: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7840> to 1.0000e-04.


                                                           


Total adversarial train accuarcy: 54.60466419217032
Total adversarial train loss: 15.06790179759264

[ Test epoch: 97 ]


                                                       


Total benign test accuarcy: 55.30600786075239
Total adversarial test Accuarcy: 69.39921392476137
Model saved: f55.30600786075239
Model Saved!
Epoch 00097: adjusting learning rate of group <generator object Module.parameters at 0x00000235391F7920> to 1.0000e-04.


                                                           

KeyboardInterrupt: 