# Phase 1: Data Collection

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import torchattacks

# Use the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Path prefix for every checkpoint / attack / .npy file written by this notebook.
# It is combined with file names by plain string concatenation, so a non-empty
# value should end with a path separator. '' means the current working directory.
folder_path = ''

In [None]:
batch_size = 64

# Both splits use only ToTensor(); no mean/std normalisation is applied.
cifar10_train = dsets.CIFAR10(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
cifar10_test = dsets.CIFAR10(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Training batches are shuffled. The test loader keeps a fixed order because
# later cells compare its predictions with saved attack outputs index-by-index.
train_loader = DataLoader(cifar10_train, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(cifar10_test, shuffle=False, batch_size=batch_size)

print("Training data: ", len(cifar10_train))
print("Testing data: ", len(cifar10_test))

In [None]:
# Start from an ImageNet-pretrained ResNet-18 and swap only its classifier head.
model = models.resnet18(pretrained=True)

# Width of the features feeding the original classifier.
num_frts = model.fc.in_features
# 10-way CIFAR-10 head. It is wrapped in a Sequential, so its parameters are
# named "fc.0.*" in the state dict; saved checkpoints depend on this layout.
model.fc = nn.Sequential(nn.Linear(in_features=num_frts, out_features=10))

model = model.to(device)

In [None]:
def model_training(num_epoch=100):
  """Fine-tune the global ``model`` on ``train_loader`` and save its weights.

  Uses Adam (lr=0.001) with cross-entropy loss. After the last epoch the
  state dict is written to ``folder_path + 'cifar10_resnet18.pt'``.

  Args:
    num_epoch: number of passes over ``train_loader``.

  Returns:
    (losses, accuracy): per-epoch mean training loss and training accuracy,
    as lists of Python floats. Callers that ignore the return value (as the
    original, which returned None, required) are unaffected.

  Raises:
    ValueError: if ``train_loader`` yields no samples in an epoch.
  """
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=0.001)

  losses = []
  accuracy = []

  for epoch in range(num_epoch):
    print('Epoch', epoch+1, '/', num_epoch)
    model.train()

    # Accumulate on the device with detached values: this avoids a host sync
    # per batch and never keeps a batch's autograd graph alive.
    running_loss = torch.zeros((), device=device)
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0

    for images, labels in train_loader:
      images = images.to(device)
      labels = labels.to(device)

      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      batch = labels.size(0)
      # CrossEntropyLoss averages over the batch, so weight by batch size to
      # get a true per-sample mean over the epoch.
      running_loss += loss.detach() * batch
      correct += (outputs.argmax(dim=1) == labels).sum()
      total += batch

    if total == 0:
      raise ValueError('train_loader yielded no samples')

    epoch_loss = running_loss.item() / total
    epoch_acc = correct.item() / total
    # Store plain floats, not tensors, so no graph or device memory is retained.
    losses.append(epoch_loss)
    accuracy.append(epoch_acc)

    print("train loss ", epoch_loss, "train acc ", epoch_acc)

  print('Finished Training')
  torch.save(model.state_dict(), folder_path+'cifar10_resnet18.pt')
  return losses, accuracy

In [None]:
# Reuse a saved checkpoint when one exists; otherwise train (which also writes
# the checkpoint). Either way the model ends up in eval mode.
# The existence check and the load must use the same path that
# model_training() saves to. The old check added an extra '/', so with
# folder_path='' it looked at the filesystem root and always retrained.
checkpoint_path = folder_path+'cifar10_resnet18.pt'
if os.path.isfile(checkpoint_path):
  # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
  model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
  model_training(100)
model = model.eval()

In [None]:
def load_attack(attack_name, attacker=None):
  """Load a saved attack file, print clean/robust accuracy, return its tensors.

  Loads ``folder_path + 'cifar10_<attack_name>.pt'`` (as written by
  ``atk.save(..., save_predictions=True, save_clean_inputs=True)``). It then
  evaluates the global ``model`` on the clean ``test_loader`` and on the loaded
  adversarial images, printing "Standard accuracy" and "Robust accuracy".

  Args:
    attack_name: attack identifier used in the file name, e.g. 'pgd'.
    attacker: object whose ``.load`` reads the file. Defaults to the global
      ``atk``, matching the original behaviour.

  Returns:
    (orig_labels, clean_images, clean_preds, adv_images, adv_preds).
    ``clean_preds`` holds the model's int64 predictions on ``test_loader`` in
    loader order. It lines up index-for-index with the saved tensors only
    while ``test_loader`` is unshuffled.
  """
  if attacker is None:
    attacker = atk

  # batch_size=10000 returns the whole file as a single batch (assumes it was
  # generated from the 10000-image CIFAR-10 test split).
  adv_loader = attacker.load(load_path=folder_path+'cifar10_'+str(attack_name)+'.pt',
                             load_predictions=True, load_clean_inputs=True,
                             batch_size=10000)
  adv_images, orig_labels, adv_preds, clean_images = next(iter(adv_loader))

  # Feed the model test_loader-sized chunks, never the 10000-image batch at once.
  eval_chunk = test_loader.batch_size or 64

  # Inference only: no_grad avoids building and holding autograd graphs.
  with torch.no_grad():
    pred_batches = []
    correct = 0
    total = 0
    for images, labels in test_loader:
      pred = model(images.to(device)).argmax(dim=1).cpu()
      total += labels.size(0)
      correct += (pred == labels.cpu()).sum().item()
      pred_batches.append(pred)
    # Concatenate once at the end instead of growing a tensor per batch.
    clean_preds = (torch.cat(pred_batches) if pred_batches
                   else torch.empty(0, dtype=torch.long))

    print('Standard accuracy: %.2f %%' % (100 * float(correct) / total))

    adv_correct = 0
    adv_total = 0
    for batch_images, batch_labels, _, _ in adv_loader:
      for images, labels in zip(torch.split(batch_images, eval_chunk),
                                torch.split(batch_labels, eval_chunk)):
        pred = model(images.to(device)).argmax(dim=1).cpu()
        adv_total += labels.size(0)
        adv_correct += (pred == labels.cpu()).sum().item()

    print('Robust accuracy: %.2f %%' % (100 * float(adv_correct) / adv_total))

  return orig_labels, clean_images, clean_preds, adv_images, adv_preds

In [None]:
def generate_success_attack(orig_labels, clean_images, clean_preds, adv_images, adv_preds, attack_name, save_prefix=None):
  """Build a clean-vs-adversarial dataset from successful attacks and save it.

  A sample is a successful attack when the clean image was classified
  correctly (``clean_preds == orig_labels``) and its adversarial version
  incorrectly (``adv_preds != orig_labels``). For every such sample, the
  clean image (label 0) and the adversarial image (label 1) are emitted as
  an adjacent pair, in original sample order.

  Results are saved to ``<prefix>images_<attack_name>.npy`` and
  ``<prefix>labels_<attack_name>.npy``.

  Args:
    orig_labels: per-sample ground-truth labels; its length sets how many
      samples are considered.
    clean_images, clean_preds, adv_images, adv_preds: per-sample tensors,
      index-aligned with ``orig_labels``.
    attack_name: suffix used in the output file names.
    save_prefix: path prefix for the .npy files. It is concatenated directly,
      so include a trailing separator. Defaults to the global ``folder_path``.

  Returns:
    (images, labels): images of shape (2*k, *image_shape) and float32 labels
    of shape (2*k,), where k is the number of successful attacks.
  """
  prefix = folder_path if save_prefix is None else save_prefix

  orig_labels = torch.as_tensor(orig_labels).cpu()
  n = len(orig_labels)
  clean_preds = torch.as_tensor(clean_preds).cpu()[:n]
  adv_preds = torch.as_tensor(adv_preds).cpu()[:n]

  # Clean prediction right AND adversarial prediction wrong. This already
  # implies clean_preds != adv_preds, so no separate check is needed.
  success = (orig_labels == clean_preds) & (orig_labels != adv_preds)
  num_success = int(success.sum())

  clean_sel = torch.as_tensor(clean_images).cpu()[:n][success]
  adv_sel = torch.as_tensor(adv_images).cpu()[:n][success]

  # Interleave as clean_0, adv_0, clean_1, adv_1, ... in one vectorized step,
  # instead of growing tensors with torch.cat inside a loop (quadratic).
  images = torch.stack((clean_sel, adv_sel), dim=1).flatten(0, 1)
  labels = torch.tensor([0.0, 1.0]).repeat(num_success)

  print("Successful attack: ", num_success)
  print(images.shape)
  print(labels.shape)

  np.save(prefix + 'images_'+str(attack_name)+'.npy', images.numpy())
  np.save(prefix + 'labels_'+str(attack_name)+'.npy', labels.numpy())
  return images, labels

In [None]:
# For each attack: build the torchattacks object, generate and save the
# adversarial test set (or reuse a saved one), report accuracies, and save
# the successful clean/adversarial pairs.
attack_name = ['pgd','deepfool','fgsm']

for attack in attack_name:
  print(attack)
  if attack == 'pgd':
    # Step size is 2/255 on the same pixel scale as eps=8/255
    # (the previous 2/225 was a typo).
    atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=100, random_start=True)
  elif attack == 'deepfool':
    atk = torchattacks.DeepFool(model, steps=50, overshoot=0.02)
  elif attack == 'fgsm':
    atk = torchattacks.FGSM(model, eps=8/255)
  else:
    print("Error method")
    # Skip, rather than continuing with a stale (or undefined) `atk`.
    continue

  # This must be the same path atk.save writes and load_attack reads. The old
  # check added an extra '/', so with folder_path='' it never matched and
  # attacks were always regenerated.
  attack_path = folder_path+'cifar10_'+str(attack)+'.pt'
  if not os.path.isfile(attack_path):
    atk.save(data_loader=test_loader, save_path=attack_path, save_predictions=True, save_clean_inputs=True)

  o_labels, c_images, c_preds, a_images, a_preds = load_attack(attack)
  generate_success_attack(o_labels, c_images, c_preds, a_images, a_preds, attack)