# SPML HW3 - Defensive Distillation (30)

### Deadline: 1402/2/30

#### Name: Hamidreza Amirzadeh
#### Student No.: 401206999

In this notebook, you are going to attack a defensively distilled model

Please write your code in specified sections and do not change anything else. If you have a question regarding this homework, please ask it on the course page.

Also, it is recommended to use google colab to do this homework. You can connect to your drive using the code below:

## Initializations

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import resnet18
import torch.optim as optim

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Defining Teacher and Student Classes

In [3]:
class Teacher(nn.Module):
  def __init__(self, num_cls, T=1):
    super().__init__()
    self.conv = nn.Sequential(
        *list(resnet18(pretrained=False).children())[:-2])
    
    self.fc = nn.Linear(512, num_cls)
    self.temp = T
  
  def forward(self, x, T=None):
    if T is None:
      T = self.temp
    x = self.conv(x)
    x = torch.flatten(x, start_dim=1)
    logits = self.fc(x)
    output = torch.softmax(logits / T, dim=1)

    return logits, output

class Student(nn.Module):
  def __init__(self, num_cls, T=1):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, num_cls)
    self.temp = T

  def forward(self, x, T=None):
    if T is None:
      T = self.temp
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    logits = self.fc3(x)

    output = torch.softmax(logits / T, dim=1)
    return logits, output

In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

###################### Problem 1 (3 points) ####################################
# todo: Define your data loaders for training, testing, and validation         #
################################################################################
# your code goes here
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
################################ End ###########################################

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 86394534.94it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## Training

### Teacher Network Training

In [5]:
def standard_train(model, loader, num_epoch, optimizer, criterion, T=None, 
                   device=device):
  ###################### Problem 2 (4 points) ##################################
  # todo: Iterate over loader in each epoch                                    #
  # todo: Compute the model's output for each batch at the given temperature T #
  # todo: Compute the loss function and take a step by the optimizer           #
  # todo: Monitor the training procedure                                       #
  ##############################################################################

  # your code goes here
    for epoch in range(num_epoch):
      running_loss = 0.0
      for i, data in enumerate(loader, 0):
          inputs, labels = data[0].to(device), data[1].to(device)
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = criterion(outputs[1], labels)
          loss.backward()
          optimizer.step()
          running_loss += loss.item()

      print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))  
  ##############################################################################

In [6]:
T = 100
teacher = Teacher(len(classes), T=T).to(device)
teacher_optim = optim.Adam(teacher.parameters())
teacher_criterion = nn.CrossEntropyLoss()



In [10]:
# load teacher model
model_name = "teacherModel.pth"
teacher_model_PATH = "/content/drive/MyDrive/" + model_name
state_dict = torch.load(teacher_model_PATH)
teacher.load_state_dict(state_dict)
teacher = teacher.to(device)

In [None]:
standard_train(model=teacher, 
            loader=trainloader, 
            num_epoch=15, 
            optimizer=teacher_optim, 
            criterion=teacher_criterion, 
            T=T, 
            device=device)

Epoch 1 loss: 2.103
Epoch 2 loss: 1.950
Epoch 3 loss: 1.887
Epoch 4 loss: 1.842
Epoch 5 loss: 1.811
Epoch 6 loss: 1.789
Epoch 7 loss: 1.775
Epoch 8 loss: 1.759
Epoch 9 loss: 1.743
Epoch 10 loss: 1.733
Epoch 11 loss: 1.723
Epoch 12 loss: 1.714
Epoch 13 loss: 1.704
Epoch 14 loss: 1.697
Epoch 15 loss: 1.691


In [9]:
# save teacher model
teacher.eval()
model_name = "teacherModel.pth"
teacher_model_PATH = "/content/drive/MyDrive/" + model_name
torch.save(teacher.state_dict(), teacher_model_PATH)

### Student Network Training

In [11]:
def distillation(teacher, student, loader, num_epoch, optimizer, criterion, 
                 T=None, device=device):
  ###################### Problem 3 (6 points) ##################################
  # todo: Iterate over loader in each epoch                                    #
  # todo: Compute MSE loss between student's logit and teacher's logit         #
  # todo: Take a step by the optimizer                                         #
  # todo: Monitor the training procedure                                       #
  ##############################################################################
  # your code goes here    
  for epoch in range(num_epoch):
    running_loss = 0.0
    for i, data in enumerate(loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = student(inputs)
        teacher.eval()
        with torch.no_grad():
          outputs_teacher = teacher(inputs)
        loss = criterion(outputs[0], outputs_teacher[0])
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print('Epoch %d loss: %.3f' % (epoch, running_loss / len(trainloader))) 
  ################################ End #########################################

In [12]:
T = 100
student = Student(len(classes), T=T).to(device)
student_optim = optim.Adam(student.parameters())
std_criterion = nn.MSELoss()

In [14]:
# load student model
model_name = "studentModel.pth"
student_model_PATH = "/content/drive/MyDrive/" + model_name
state_dict = torch.load(student_model_PATH)
student.load_state_dict(state_dict)
student = student.to(device)

In [None]:
distillation(teacher=teacher, 
             student=student, 
             loader=trainloader, 
             num_epoch=50, 
             optimizer=student_optim, 
             criterion=std_criterion, 
             T=T, 
             device=device)

Epoch 0 loss: 197814.864
Epoch 1 loss: 181382.306
Epoch 2 loss: 164789.905
Epoch 3 loss: 155433.515
Epoch 4 loss: 147304.658
Epoch 5 loss: 139767.458
Epoch 6 loss: 133575.293
Epoch 7 loss: 128247.517
Epoch 8 loss: 123480.855
Epoch 9 loss: 118911.982
Epoch 10 loss: 115263.871
Epoch 11 loss: 111419.158
Epoch 12 loss: 108117.410
Epoch 13 loss: 105084.380
Epoch 14 loss: 102011.637
Epoch 15 loss: 99757.691
Epoch 16 loss: 97059.545
Epoch 17 loss: 95027.708
Epoch 18 loss: 92772.566
Epoch 19 loss: 90987.901
Epoch 20 loss: 89101.935
Epoch 21 loss: 87445.933
Epoch 22 loss: 86005.737
Epoch 23 loss: 84493.557
Epoch 24 loss: 82978.072
Epoch 25 loss: 81742.326
Epoch 26 loss: 80430.235
Epoch 27 loss: 79518.892
Epoch 28 loss: 78344.848
Epoch 29 loss: 77262.287
Epoch 30 loss: 76446.392
Epoch 31 loss: 75579.115
Epoch 32 loss: 74706.370
Epoch 33 loss: 73979.775
Epoch 34 loss: 73318.592
Epoch 35 loss: 72400.093
Epoch 36 loss: 71681.119
Epoch 37 loss: 71126.060
Epoch 38 loss: 70633.454
Epoch 39 loss: 69883

In [None]:
# save student model
student.eval()
model_name = "studentModel.pth"
teacher_model_PATH = "/content/drive/MyDrive/" + model_name
torch.save(student.state_dict(), teacher_model_PATH)

### Computing Clean Accuracy

In [15]:
def standard_test(model, loader, T=1, device=device):
  correct = 0
  total = 0
  ###################### Problem 4 (3 points) ##################################
  # todo: Iterate over loader, compute the output and predicted                #
  # label, and update "correct" and "total" counters accordingly.              # 
  ##############################################################################
  # your code goes here
  student.eval()
  with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = student(images)
        _, predicted = torch.max(outputs[1].data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
  ################################ End #########################################
  print(f'Clean accuracy of the network on the 10000 test images: {100 * correct // total} %')

In [16]:
standard_test(student, testloader)

Clean accuracy of the network on the 10000 test images: 62 %


## Attacks

In [17]:
mean=[0.5, 0.5, 0.5]
std=[0.5, 0.5, 0.5]
inv_normalize = torchvision.transforms.Normalize(mean=[-m/s for m, s in zip(mean, std)], std=[1/s for s in std])
normalize = torchvision.transforms.Normalize(mean=mean, std=std)

In [18]:
def fgsm_attack (model, X, y, epsilon, T=1, mode="output"):
  delta = torch.zeros_like(X, requires_grad=True).to(device)
  ###################### Problem 5 (9 points) ######################
  # todo: Perform forward path on model with the image             #
  # todo: Compute loss:                                            #
  #       - In output mode, set cross entropy as the loss function #
  #       - In logit mode, set the logit value of the target label #
  #         as the loss function                                   #
  # todo: Perform backward path on loss function                   #
  # todo: Calculate the gradient w.r.t. the data                   #
  # todo: Determine delta based on the gradient and epsilon        #
  # Also, if the perturbed image exceeds the valid range, clamp    #
  # the delta in order to obtain an image in the valid range       #  
  ##################################################################
  # your code goes here
  
  X.requires_grad = True 
  out = model(X , T=1)
  if mode == "output":
    loss = F.nll_loss(out[1][0], y)
    model.zero_grad()
    loss.backward()
    perturbed_image = X + epsilon * X.grad.data.sign()
    perturbed_image = inv_normalize(perturbed_image)
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    perturbed_image = normalize(perturbed_image)
    perturbation = epsilon * X.grad.data.sign()

  else:
    loss = out[0][0][y]
    model.zero_grad()
    loss.backward()
    perturbed_image = X - epsilon * X.grad.data.sign()
    perturbed_image = inv_normalize(perturbed_image)
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    perturbed_image = normalize(perturbed_image)
    perturbation = -epsilon * X.grad.data.sign()
  ########################### End ##################################
    
  return perturbation + X

In [20]:
def attack_test(model, attack_model, loader, mode="output", epsilon=4/255, T=1, 
                device=device):
  correct = 0
  total = 0
  ###################### Problem 6 (4.5 points) ################################
  # todo: Iterate over loader                                                  #
  # todo: Find an adversarial example by FGSM attack on attack_model           #
  # todo: Compute the output and predicted label, and updated "correct" and    #
  # "total" counters accordingly.                                              # 
  ##############################################################################
  # your code goes here
  for data in loader:
    images, labels = data
    images, labels = images.to(device), labels.to(device)
    purterbed = []
    for i in range(len(images)):
      purterbed.append(fgsm_attack(attack_model, images[i].unsqueeze(0), labels[i], epsilon, T=1, mode=mode).squeeze(0))
    perturbedImage = torch.stack(purterbed)
    model.eval()
    with torch.no_grad():
        outputs = model(perturbedImage)
        _, predicted = torch.max(outputs[1].data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
  ################################ End #########################################
  print(f'Accuracy of the network on the 10000 test images after attacking {mode} layer: {100 * correct // total} %')

In [22]:
###################### Problem 7 (1.5 points) ##################################
# todo: Report the accuracy of the student model under attack output and logit #
#       layers                                                                 #
# todo: Do not forget to set the temperature of the student model to 1         #
################################################################################
# your code goes here
print('attack on student model and logits:')
attack_test(student, student, testloader, mode="logits", epsilon=4/255, T=1, device=device)

print('attack on student model and output:')
attack_test(student, student, testloader, mode="output", epsilon=4/255, T=1, device=device)
################################ End ###########################################

attack on student model and logits:
Accuracy of the network on the 10000 test images after attacking logits layer: 30 %
attack on student model and output:
Accuracy of the network on the 10000 test images after attacking output layer: 59 %
