# Lab 4: Adversarial Training with MNIST

In [24]:
import os
import sys
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import torchattacks
from torchattacks import RPGD, FGSM
import time

In [7]:
import matplotlib.pyplot as plt
%matplotlib inline
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# Kept as a plain string because later cells compare it against 'cuda'.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

## 1. Dataset Preparation and DNN Training

### 1.1 MNIST Dataset

In [None]:
# Load (downloading into data/ if missing) the MNIST training split first and
# then the test split; each image is converted to a tensor via ToTensor.
mnist_train, mnist_test = (
    dsets.MNIST(root='data/',
                train=is_train_split,
                transform=transforms.ToTensor(),
                download=True)
    for is_train_split in (True, False)
)

In [None]:
batch_size = 128

# Shuffle the training set so every epoch visits the mini-batches in a fresh
# random order; with shuffle=False each epoch would replay exactly the same
# batch sequence. The test set keeps a fixed order so evaluation (and the
# "first five test images" demo below) stays reproducible.
train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=batch_size,
                                          shuffle=False)

### 1.2 Define Model

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 MNIST digits.

    Three 5x5 convolutions (the last two each followed by 2x2 max-pooling)
    reduce a 1x28x28 input to 64 feature maps of size 3x3; two fully
    connected layers then map those 576 features to 10 class logits.
    """

    def __init__(self):
        super().__init__()

        # Spatial size: 28 -> 24 (conv5) -> 20 (conv5) -> 10 (pool)
        #               -> 6 (conv5) -> 3 (pool), giving 64 x 3 x 3 features.
        self.layer = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.ReLU(),
            nn.Conv2d(16, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(64 * 3 * 3, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        """Return unnormalized logits of shape (batch, 10).

        ``x`` is expected to be a (batch, 1, 28, 28) tensor.
        """
        out = self.layer(x)
        # Flatten per sample while keeping the batch dimension explicit, so an
        # input of the wrong spatial size fails loudly in the first Linear
        # layer instead of being silently reshaped into extra "samples".
        out = torch.flatten(out, 1)
        return self.fc_layer(out)

In [None]:
model = CNN().to(device)
# Report which device the training below will run on.
print("Train on GPU..." if device == 'cuda' else "Train on CPU...")

In [None]:
# Adam (learning rate 1e-3) over all model parameters, minimizing
# cross-entropy between the model's logits and the integer labels.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
loss = nn.CrossEntropyLoss()

### 1.3 Normal training

In [None]:
num_epochs = 5
# Number of mini-batches per epoch, including the final partial batch.
# len(mnist_train) // batch_size would undercount by one whenever the
# dataset size is not a multiple of batch_size.
total_batch = len(train_loader)
model.train()
print('start training')
start = time.time()
for epoch in range(num_epochs):
    for i, (batch_images, batch_labels) in enumerate(train_loader):
        X = batch_images.to(device)
        Y = batch_labels.to(device)
        pre = model(X)
        cost = loss(pre, Y)
        # Clear gradients from the previous step before backpropagating.
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            print('Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, cost.item()))
print('training finished')
end = time.time()
print('the total running time is:', end - start)
torch.save(model.state_dict(), 'normal_model.pt')

### 1.4 Accuracy on clean images

In [None]:
# Accuracy of the normally trained model on the clean (unperturbed) test set.
model.eval()

correct = 0
total = 0

# Plain inference needs no gradients; disabling autograd saves memory.
with torch.no_grad():
    for images, labels in test_loader:

        images = images.to(device)
        # Follow the selected device rather than hard-coding .cuda(), which
        # crashes on machines without a GPU.
        labels = labels.to(device)
        outputs = model(images)

        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of Clean images: %f %%' % (100 * float(correct) / total))

## 2. Generating adversarial examples

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline


# Reload the normally trained weights, craft FGSM adversarial examples
# (eps=0.3) for the first five test images, print true labels vs. predictions
# on the adversarial images, and display each clean image followed by its
# perturbed counterpart.
model = CNN().to(device)
# map_location lets the checkpoint load onto whichever device is active now.
model.load_state_dict(torch.load("normal_model.pt", map_location=device))
model.eval()
fgsm_attack = FGSM(model, eps=0.3)

images, labels = next(iter(test_loader))
images = images[:5].to(device)
labels = labels[:5].to(device)
adver_images = fgsm_attack(images, labels).to(device)

with torch.no_grad():
    outputs = model(adver_images)
_, predicted = torch.max(outputs, 1)
print(labels)     # true labels
print(predicted)  # predictions on the adversarial images
for clean_img, adver_img in zip(images, adver_images):
    plt.figure(figsize=(1.5, 1.5))
    plt.imshow(clean_img[0].cpu().detach().numpy(), cmap='gray')
    plt.figure(figsize=(1.5, 1.5))
    plt.imshow(adver_img[0].cpu().detach().numpy(), cmap='gray')

## 3. Accuracy under Adversarial Attacks

### 3.1 FGSM

In [None]:
# Accuracy of the normally trained model on FGSM adversarial examples (eps=0.2).
model.eval()

correct = 0
total = 0

fgsm_attack = FGSM(model, eps=0.2)

for images, labels in test_loader:

    # The attack needs input gradients, so only the evaluation forward pass
    # below runs under no_grad.
    images = fgsm_attack(images, labels).to(device)
    # Follow the selected device instead of hard-coding .cuda().
    labels = labels.to(device)
    with torch.no_grad():
        outputs = model(images)

    _, predicted = torch.max(outputs, 1)

    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print('Accuracy of Adversarial images: %f %%' % (100 * float(correct) / total))

### 3.2 PGD

In [None]:
# Accuracy of the normally trained model on RPGD adversarial examples (eps=0.2).
model.eval()

correct = 0
total = 0

pgd_attack = RPGD(model, eps=0.2)

for images, labels in test_loader:

    # The attack needs input gradients, so only the evaluation forward pass
    # below runs under no_grad.
    images = pgd_attack(images, labels).to(device)
    # Follow the selected device instead of hard-coding .cuda().
    labels = labels.to(device)
    with torch.no_grad():
        outputs = model(images)

    _, predicted = torch.max(outputs, 1)

    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print('Accuracy of Adversarial images: %f %%' % (100 * float(correct) / total))

## 4. Train Model using adversarial training

### 4.1 Training with PGD

In [None]:
# Fresh, untrained model for PGD adversarial training, together with its own
# cross-entropy loss, Adam optimizer (lr 1e-3), and an RPGD attacker bound to
# it (eps=0.2, iters=20).
model = CNN().to(device)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
pgd_attack = RPGD(model, eps=0.2, iters=20)
num_epochs = 5

In [None]:
# PGD adversarial training: each mini-batch is replaced by its RPGD
# adversarial counterpart (crafted against the current weights) before the
# usual cross-entropy update.
print('start training')
model.train()
start = time.time()
# Number of mini-batches per epoch, including the final partial batch
# (len(mnist_train) // batch_size would undercount by one).
total_batch = len(train_loader)
for epoch in range(num_epochs):
    for i, (batch_images, batch_labels) in enumerate(train_loader):
        X = pgd_attack(batch_images, batch_labels).to(device)
        Y = batch_labels.to(device)
        pre = model(X)
        cost = loss(pre, Y)
        # Clear any gradients left over before backpropagating this batch.
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            print('Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, cost.item()))
print('training finished')
end = time.time()
print('the total running time is:', end - start)
torch.save(model.state_dict(), 'adver_training_model.pt')

### 4.2 Test on clean images

In [None]:
# Accuracy of the adversarially trained model on the clean test set.
model = CNN().to(device)
# map_location lets the checkpoint load onto whichever device is active now.
model.load_state_dict(torch.load("adver_training_model.pt", map_location=device))
model.eval()
correct = 0
total = 0

# Plain inference needs no gradients; disabling autograd saves memory.
with torch.no_grad():
    for images, labels in test_loader:

        images = images.to(device)
        # Follow the selected device instead of hard-coding .cuda().
        labels = labels.to(device)
        outputs = model(images)

        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of Clean images: %f %%' % (100 * float(correct) / total))

## 5. Test on Adversarial images

In [None]:
# Accuracy of the adversarially trained model on RPGD adversarial examples
# (eps=0.2) crafted against that same model.
model = CNN().to(device)
# map_location lets the checkpoint load onto whichever device is active now.
model.load_state_dict(torch.load("adver_training_model.pt", map_location=device))
model.eval()

correct = 0
total = 0

pgd_attack = RPGD(model, eps=0.2)

for images, labels in test_loader:

    # The attack needs input gradients, so only the evaluation forward pass
    # below runs under no_grad.
    images = pgd_attack(images, labels).to(device)
    # Follow the selected device instead of hard-coding .cuda().
    labels = labels.to(device)
    with torch.no_grad():
        outputs = model(images)

    _, predicted = torch.max(outputs, 1)

    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print('Accuracy of Adversarial images: %f %%' % (100 * float(correct) / total))