# SPML HW3 - Defensive Distillation (30)


In this notebook, you are going to attack a defensively distilled model

Please write your code in specified sections and do not change anything else. If you have a question regarding this homework, please ask it on the course page.

Also, it is recommended to use google colab to do this homework. You can connect to your drive using the code below:

## Initializations

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import resnet18
import torch.optim as optim

## Defining Teacher and Student Classes

In [2]:
class Teacher(nn.Module):
  def __init__(self, num_cls, T=1):
    super().__init__()
    self.conv = nn.Sequential(
        *list(resnet18(pretrained=False).children())[:-2])
    
    self.fc = nn.Linear(512, num_cls)
    self.temp = T
  
  def forward(self, x, T=None):
    if T is None:
      T = self.temp
    x = self.conv(x)
    x = torch.flatten(x, start_dim=1)
    logits = self.fc(x)
    output = torch.softmax(logits / T, dim=1)

    return logits, output

class Student(nn.Module):
  def __init__(self, num_cls, T=1):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, num_cls)
    self.temp = T

  def forward(self, x, T=None):
    if T is None:
      T = self.temp
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    logits = self.fc3(x)

    output = torch.softmax(logits / T, dim=1)
    return logits, output

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

###################### Problem 1 (3 points) ####################################
# todo: Define your data loaders for training, testing, and validation         #
################################################################################

# your code goes here

train_size = int(0.8 * len(trainset))
valid_size = len(trainset) - train_size
trainset, validset = torch.utils.data.random_split(trainset, [train_size, valid_size])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
################################ End ###########################################

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:14<00:00, 12114555.64it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## Training

### Teacher Network Training

In [4]:
def standard_train(model, loader, num_epoch, optimizer, criterion, T=None, 
                   device=device):
  ###################### Problem 2 (4 points) ##################################
  # todo: Iterate over loader in each epoch                                    #
  # todo: Compute the model's output for each batch at the given temperature T #
  # todo: Compute the loss function and take a step by the optimizer           #
  # todo: Monitor the training procedure                                       #
  ##############################################################################

  # your code goes here
    model.train() 
    model.to(device)  
    
    for epoch in range(num_epoch):
        running_loss = 0.0
        
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()  
            
            if T is None:
                logits, outputs = model(inputs)
            else:
                logits, outputs = model(inputs, T=T)

            loss = criterion(logits, labels)  
            loss.backward()  
            optimizer.step()  
            running_loss += loss.item() * inputs.size(0)
        
        epoch_loss = running_loss / len(loader.dataset)
        print(f"Epoch {epoch+1}/{num_epoch}, Loss: {epoch_loss:.4f}")
  ##############################################################################

In [5]:
T = 100
teacher = Teacher(len(classes), T=T).to(device)
teacher_optim = optim.Adam(teacher.parameters())
teacher_criterion = nn.CrossEntropyLoss()



In [None]:
standard_train(model=teacher, 
            loader=trainloader, 
            num_epoch=15, 
            optimizer=teacher_optim, 
            criterion=teacher_criterion, 
            T=T, 
            device=device)

Epoch 1/15, Loss: 1.3531
Epoch 2/15, Loss: 0.9579
Epoch 3/15, Loss: 0.7785
Epoch 4/15, Loss: 0.6561
Epoch 5/15, Loss: 0.5515
Epoch 6/15, Loss: 0.4605
Epoch 7/15, Loss: 0.3744
Epoch 8/15, Loss: 0.3028
Epoch 9/15, Loss: 0.2465
Epoch 10/15, Loss: 0.1995
Epoch 11/15, Loss: 0.1653
Epoch 12/15, Loss: 0.1403
Epoch 13/15, Loss: 0.1217
Epoch 14/15, Loss: 0.1076
Epoch 15/15, Loss: 0.0991


### Student Network Training

In [6]:
def distillation(teacher, student, loader, num_epoch, optimizer, criterion, 
                 T=None, device=device):
  ###################### Problem 3 (6 points) ##################################
  # todo: Iterate over loader in each epoch                                    #
  # todo: Compute MSE loss between student's logit and teacher's logit         #
  # todo: Take a step by the optimizer                                         #
  # todo: Monitor the training procedure                                       #
  ##############################################################################

  # your code goes here    
    teacher.eval()  
    teacher.to(device)  
    student.train()  
    student.to(device)  
    
    for epoch in range(num_epoch):
        running_loss = 0.0
        
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad() 
            teacher_logits, _ = teacher(inputs, T=T) 
            student_logits, _ = student(inputs)  

            loss = criterion(student_logits, teacher_logits)  
            loss.backward() 
            optimizer.step()  
            
            running_loss += loss.item() * inputs.size(0)
        
        epoch_loss = running_loss / len(loader.dataset)
        print(f"Epoch {epoch+1}/{num_epoch}, Loss: {epoch_loss:.4f}")
  ################################ End #########################################

In [10]:
T = 100
student = Student(len(classes), T=T).to(device)
student_optim = optim.Adam(student.parameters())
std_criterion = nn.MSELoss()

In [None]:
distillation(teacher=teacher, 
             student=student, 
             loader=trainloader, 
             num_epoch=100, 
             optimizer=student_optim, 
             criterion=std_criterion, 
             T=T, 
             device=device)

Epoch 1/100, Loss: 28.4551
Epoch 2/100, Loss: 20.7455
Epoch 3/100, Loss: 18.3272
Epoch 4/100, Loss: 16.9647
Epoch 5/100, Loss: 15.7545
Epoch 6/100, Loss: 14.8682
Epoch 7/100, Loss: 14.0867
Epoch 8/100, Loss: 13.4199
Epoch 9/100, Loss: 12.8785
Epoch 10/100, Loss: 12.4718
Epoch 11/100, Loss: 12.0440
Epoch 12/100, Loss: 11.6775
Epoch 13/100, Loss: 11.4168
Epoch 14/100, Loss: 11.1060
Epoch 15/100, Loss: 10.9150
Epoch 16/100, Loss: 10.6756
Epoch 17/100, Loss: 10.5095
Epoch 18/100, Loss: 10.3439
Epoch 19/100, Loss: 10.1503
Epoch 20/100, Loss: 9.9972
Epoch 21/100, Loss: 9.8667
Epoch 22/100, Loss: 9.7239
Epoch 23/100, Loss: 9.6335
Epoch 24/100, Loss: 9.4485
Epoch 25/100, Loss: 9.3934
Epoch 26/100, Loss: 9.2665
Epoch 27/100, Loss: 9.2041
Epoch 28/100, Loss: 9.1264
Epoch 29/100, Loss: 9.0207
Epoch 30/100, Loss: 8.9613
Epoch 31/100, Loss: 8.8919
Epoch 32/100, Loss: 8.8218
Epoch 33/100, Loss: 8.7165
Epoch 34/100, Loss: 8.7040
Epoch 35/100, Loss: 8.6447
Epoch 36/100, Loss: 8.5949
Epoch 37/100, Loss

### Computing Clean Accuracy

In [None]:
def standard_test(model, loader, T=1, device=device):
    correct = 0
    total = 0
    ###################### Problem 4 (3 points) ##################################
    # todo: Iterate over loader, compute the output and predicted                #
    # label, and update "correct" and "total" counters accordingly.              # 
    ##############################################################################
    
    # your code goes here
    model.eval()  
    model.to(device) 
    
    correct = 0
    total = 0
    
    with torch.no_grad():  
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            if T is None:
                _, outputs = model(inputs)
            else:
                _, outputs = model(inputs, T=T)
            
            _, predicted = torch.max(outputs.data, 1)  
            total += labels.size(0)  
            correct += (predicted == labels).sum().item()  
    
    ################################ End #########################################
    print(f'Clean accuracy of the network on the 10000 test images: {100 * correct // total} %')

In [None]:
standard_test(student, testloader)

Clean accuracy of the network on the 10000 test images: 66 %


## Attacks

In [21]:
def fgsm_attack (model, X, y, epsilon, T=1, mode="output"):
    delta = torch.zeros_like(X, requires_grad=True).to(device)
    ###################### Problem 5 (9 points) ######################
    # todo: Perform forward path on model with the image             #
    # todo: Compute loss:                                            #
    #       - In output mode, set cross entropy as the loss function #
    #       - In logit mode, set the logit value of the target label #
    #         as the loss function                                   #
    # todo: Perform backward path on loss function                   #
    # todo: Calculate the gradient w.r.t. the data                   #
    # todo: Determine delta based on the gradient and epsilon        #
    # Also, if the perturbed image exceeds the valid range, clamp    #
    # the delta in order to obtain an image in the valid range       #  
    ##################################################################

    # your code goes here
    X = X.to(device) 
    y = y.to(device)  
    X.requires_grad = True
    logits, outputs = model(X, T=T)  

    if mode == "output":
        loss = nn.CrossEntropyLoss()(outputs, y)  
    elif mode == "logit":
        loss = -logits[torch.arange(logits.size(0)), y].mean()

    model.zero_grad()  
    loss.backward()  
    gradient = X.grad.detach().data
    delta.data = epsilon * torch.sign(gradient)
    delta.data = torch.clamp(delta.data, -epsilon, epsilon)

    #perturbed_image = torch.clamp(X + delta, 0, 1)

    return X+delta
  ########################### End ##################################
    

In [28]:
def attack_test(model, attack_model, loader, mode="output", epsilon=4/255, T=1, 
                device=device):
    correct = 0
    total = 0
    ###################### Problem 6 (4.5 points) ################################
    # todo: Iterate over loader                                                  #
    # todo: Find an adversarial example by FGSM attack on attack_model           #
    # todo: Compute the output and predicted label, and updated "correct" and    #
    # "total" counters accordingly.                                              # 
    ##############################################################################

    # your code goes here
      
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        adversarial_images = fgsm_attack(attack_model, images, labels, epsilon, T, mode)
        
        outputs, _ = model(adversarial_images, T=T)
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    ################################ End #########################################
    print(f'Accuracy of the network on the 10000 test images after attacking {mode} layer: {100 * correct // total} %')

In [29]:
###################### Problem 7 (1.5 points) ##################################
# todo: Report the accuracy of the student model under attack output and logit #
#       layers                                                                 #
# todo: Do not forget to set the temperature of the student model to 1         #
################################################################################

# your code goes here
attack_test(model=student,
            attack_model=teacher,
            loader=testloader,
            mode="output",
            epsilon=4/255,
            T=1,
            device=device)

attack_test(model=student,
            attack_model=teacher,
            loader=testloader,
            mode="logit",
            epsilon=4/255,
            T=1,
            device=device)

attack_test(model=student,
            attack_model=student,
            loader=testloader,
            mode="output",
            epsilon=4/255,
            T=1,
            device=device)

attack_test(model=student,
            attack_model=student,
            loader=testloader,
            mode="logit",
            epsilon=4/255,
            T=1,
            device=device)
################################ End ###########################################

Accuracy of the network on the 10000 test images after attacking output layer: 63 %
Accuracy of the network on the 10000 test images after attacking logit layer: 60 %
Accuracy of the network on the 10000 test images after attacking output layer: 58 %
Accuracy of the network on the 10000 test images after attacking logit layer: 22 %


همانطور که می‌بینیم دقت حمله زمانی که روی لاجیت‌های شبکه‌ی استیودنت اعمال می‌شود بیشتر است و دقت نهایی کاهش می‌یابد. به صورت کلی هم حمله‌ی FGSM روی این دفاع نتوانسته زیاد موفق عمل کند که این امر مطابق انتظار است.