# About

This notebook does the following at a high level:
1. Generates adversarial PGD images and saves them for use in fine-tuning
2. Visualizes some of the perturbed images
3. Creates the fine-tuning dataset (with some random augmentations)
4. Fine-tunes the robust model
5. Evaluates the model on the test set

# 0. Importing required libraries


In [None]:
import random
import os
import datasets
from datasets import load_dataset
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# 1. Setup

In [None]:
# Manual seeds for reproducibility: torch's RNG draws the PGD random start
# (uniform_ in pgd_attack); Python's `random` draws the per-batch number of
# PGD steps in generate_adv_images, so it must be seeded as well.
SEED = 1234
torch.manual_seed(SEED)
random.seed(SEED)

# Params for PGD (L∞ attack).
# NOTE(review): these budgets are applied in whatever space the model input lives in;
# if the preprocessing transform normalizes images, they are not pixel-space x/255 budgets.
ALPHA = 2/255    # step size of each PGD iteration
STEPS = 20       # upper bound on PGD iterations used by generate_adv_images
EPSILON = 8/255  # radius of the L∞ ball around each clean image


### 1.1 Loading the model

In [None]:
# Pretrained ResNet-50 weights, the inference transforms that ship with them,
# and the model instantiated from those weights.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
model = resnet50(weights=weights)

### 1.2 Loading the datasets

In [None]:
def preprocess_img(example):
    '''
    Replace the example's 'image' entry with its preprocessed version.

    :param example: dict-like dataset example holding an 'image' entry
    :return: the same example object, mutated in place
    '''
    raw_image = example['image']
    example['image'] = preprocess(raw_image)
    return example

def get_dataloader(split: str, batch_num: int, batch_size: int):
    '''
    Build a DataLoader over a streamed, shuffled slice of ILSVRC/imagenet-1k.

    :param split: can be either train, test or validation
    :param batch_num: number of batches the returned loader should yield
    :param batch_size: number of examples per batch
    :return: DataLoader over at most batch_num * batch_size preprocessed examples
    '''
    print(f"Loading {split} ILSVRC/imagenet-1k dataset...")
    stream = load_dataset("ILSVRC/imagenet-1k", split=split, streaming=True, trust_remote_code=True)

    # Keep only images whose PIL mode is exactly 'RGB' (drops grayscale and any other mode)
    stream = stream.filter(lambda ex: ex['image'].mode == 'RGB')

    # Streaming map is lazy: preprocessing runs as examples are pulled, then shuffle with a fixed seed
    stream = stream.map(preprocess_img).shuffle(seed=SEED)

    # Truncate the stream to the requested number of examples
    n_examples = batch_num * batch_size
    stream = stream.take(n_examples)

    print(f"Creating dataloader with {batch_num} batches for split {split} each with size {batch_size}")
    return DataLoader(stream, batch_size=batch_size)
    

### 1.3 My Resnet PGD Attacker

In [None]:
class ResnetPGDAttacker:
    def __init__(self, model, dataloader: DataLoader):
        '''
        The PGD (L∞) attack on a Resnet model.

        Moves the model to the first CUDA device when available (otherwise CPU) and sets
        requires_grad=False on every model parameter, so attacks only differentiate with
        respect to the input images. This class never re-enables those gradients.

        :param model: The resnet model on which we perform the attack
        :param dataloader: The dataloader loading the input data on which we perform the attack;
            batches are expected to be dicts with 'image' and 'label' entries
        '''
        self.model = model
        self.dataloader = dataloader
        self.batch_size = dataloader.batch_size
        self.loss_fn = nn.CrossEntropyLoss()
        self.adv_images = []  # set to a CPU tensor by pgd_batch_attack
        self.labels = []  # set to a CPU tensor by pgd_batch_attack, aligned with adv_images
        self.eps = 0
        self.alpha = 0
        self.steps = 0
        self.acc = 0
        self.adv_acc = 0
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Freeze model params: only the input images need gradients during the attack
        for p in self.model.parameters():
            p.requires_grad = False

    def pgd_attack(self, image, label, eps=None, alpha=None, steps=None):
        '''
        Create adversarial images for given batch of images and labels.

        Starts from a uniformly random point in the L∞ ball of radius eps around each image,
        then takes `steps` signed-gradient ascent steps of size alpha on the cross-entropy loss,
        projecting back into the ball after every step. No clamping to a pixel range is done.
        NOTE(review): if inputs are normalized by the preprocessing transform, eps/alpha are
        measured in normalized units, not raw pixel units.

        :param image: Batch of input images on which we perform the attack, size (BATCH_SIZE, 3, 224, 224)
        :param label: Batch of input labels on which we perform the attack, size (BATCH_SIZE)
        :param eps: L∞ radius; defaults to self.eps
        :param alpha: step size; defaults to self.alpha
        :param steps: number of PGD iterations; defaults to self.steps
        :return: Adversarial images for the given input images, on self.device
        '''
        if eps is None:
            eps = self.eps
        if alpha is None:
            alpha = self.alpha
        if steps is None:
            steps = self.steps

        images = image.clone().detach().to(self.device)
        labels = label.clone().detach().to(self.device)

        # Random start: uniform noise inside the L∞ ball of radius eps
        adv_images = images + torch.zeros_like(images).uniform_(-eps, eps)

        for _ in range(steps):
            adv_images.requires_grad_(True)

            # CrossEntropyLoss expects raw logits (it applies log-softmax itself); feeding
            # softmax probabilities would apply softmax twice and flatten the gradient.
            logits = self.model(adv_images)
            loss = self.loss_fn(logits, labels)

            grad = torch.autograd.grad(loss, adv_images, retain_graph=False, create_graph=False)[0]

            # Ascend the loss along the sign of the gradient (L∞ steepest ascent)
            adv_images = adv_images.detach() + alpha * grad.sign()

            # Project back into the L∞ ball of radius eps around the original images
            adv_images = torch.clamp(adv_images, images - eps, images + eps).detach()

        return adv_images

    def pgd_batch_attack(self, eps, alpha, steps, batch_num):
        '''
        Launch the attack on up to `batch_num` batches and store the results on the instance.

        :param eps: Epsilon value in PGD attack
        :param alpha: Alpha value in PGD attack
        :param steps: Step value in PGD attack
        :param batch_num: Number of batches to run the attack on
        :return: None. Sets self.acc (accuracy on clean images), self.adv_acc (accuracy on
            adversarial images), self.adv_images and self.labels (aligned CPU tensors).
            If no batch is processed, accuracies are 0.0 and the tensors are empty.
        '''
        self.model.eval()
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        adv_correct = 0
        correct = 0
        total = 0
        adv_images_lst = []
        labels_lst = []
        for i, inputs in enumerate(tqdm(self.dataloader, total=batch_num)):
            if i == batch_num:
                break
            adv_images = self.pgd_attack(**inputs)
            with torch.no_grad():
                # softmax is monotonic, so argmax over logits gives the same predictions
                adv_predictions = self.model(adv_images).argmax(dim=1).cpu()
                predictions = self.model(inputs['image'].to(self.device)).argmax(dim=1).cpu()
            labels = inputs['label'].cpu()
            adv_correct += torch.sum(adv_predictions == labels).item()
            correct += torch.sum(predictions == labels).item()
            total += len(labels)
            # Move to CPU per batch so device memory does not grow with batch_num
            adv_images_lst.append(adv_images.cpu())
            labels_lst.append(labels)

        if total == 0:
            # Nothing was attacked (empty loader or batch_num == 0)
            self.adv_images = torch.empty(0)
            self.labels = torch.empty(0, dtype=torch.long)
            self.acc = 0.0
            self.adv_acc = 0.0
            return

        self.adv_images = torch.cat(adv_images_lst)
        self.labels = torch.cat(labels_lst)
        self.acc = correct / total
        self.adv_acc = adv_correct / total

    def compute_accuracy(self, batch_num):
        '''
        Compute model accuracy for specified number of data batches from self.dataloader.

        :param batch_num: Number of batches on which we compute model accuracy
        :return: None. Sets self.acc (0.0 if no batch is processed).
        '''
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for i, inputs in enumerate(tqdm(self.dataloader, total=batch_num)):
                if i == batch_num:
                    break
                images = inputs['image'].to(self.device)
                labels = inputs['label'].to(self.device)
                # softmax is monotonic, so argmax over logits gives the same predictions
                predictions = self.model(images).argmax(dim=1)
                correct += (predictions == labels).sum().item()
                total += predictions.size(0)
        # Guard against dividing by zero when no batch was processed
        self.acc = correct / total if total else 0.0

# 2. Functions for generating and saving adversarial images, and for loading saved ones from disk

In [11]:
def _adv_file_sort_key(file_name: str):
    """Sort key ordering chunk files like ``adv_images_1000-1999.pt`` by their numeric start index."""
    stem = file_name[:-len(".pt")]
    start = stem.rsplit('_', 1)[-1].split('-', 1)[0]
    # Files that do not follow the "<prefix>_<start>-<end>.pt" scheme go last, ordered by name
    return (0, int(start), file_name) if start.isdigit() else (1, 0, file_name)


def fetch_perturbed_images_from_location(location: str):
    """
    Fetch and concatenate all perturbed images and labels from the specified directory.

    Files are read in ascending order of their chunk start index (numeric, not lexical),
    so the result order is deterministic and matches the order in which chunks were saved.
    Tensors are loaded onto the CPU so files written from GPU tensors load anywhere.

    :param location: The directory where the .pt files are saved.
    :return: Tuple of concatenated adversarial images and labels.
    :raises FileNotFoundError: if the directory contains no .pt files.
    """
    pt_files = sorted(
        (name for name in os.listdir(location) if name.endswith(".pt")),
        key=_adv_file_sort_key,
    )
    if not pt_files:
        raise FileNotFoundError(f"No .pt files found in {location!r}")

    adv_images = []
    labels = []
    for file_name in pt_files:
        # NOTE: torch.load unpickles data; only load files from a trusted source.
        data = torch.load(os.path.join(location, file_name), map_location="cpu")
        adv_images.append(data['adv_images'])
        labels.append(data['labels'])

    return torch.cat(adv_images), torch.cat(labels)

def generate_adv_images(dataloader, model, store: bool = False, store_dir: str = 'adv_images',
                        eps: float = None, alpha: float = None,
                        min_steps: int = 5, max_steps: int = None):
    """
    Generate adversarial images with PGD for every batch of `dataloader` and optionally save them.

    Each batch is attacked with a step count drawn uniformly from [min_steps, max_steps].
    The model is put in eval mode for the attack, so BatchNorm running statistics are not
    updated. Afterwards its original train/eval mode and per-parameter requires_grad flags
    are restored. Note that the model is left on the attacker's device.

    :param dataloader: dataloader of dict batches with 'image' and 'label' entries.
    :param model: The ResNet model used for generating adversarial images.
    :param store: Boolean flag to indicate whether to save the generated images.
    :param store_dir: The directory where the adversarial images and labels will be saved
        (in chunk files of up to 1000 examples each).
    :param eps: L∞ radius of the attack; defaults to the module-level EPSILON.
    :param alpha: PGD step size; defaults to the module-level ALPHA.
    :param min_steps: smallest number of PGD iterations per batch (inclusive).
    :param max_steps: largest number of PGD iterations per batch (inclusive); defaults to STEPS.
    :return: Tuple (store_dir, adv_images, labels); store_dir is returned even when store is
        False. Both tensors are on the CPU.
    :raises ValueError: if the dataloader yields no batches.
    """
    # Without explicit eps/alpha, pgd_attack would fall back to the attacker's initial
    # eps = alpha = 0 and return the clean images unchanged.
    eps = EPSILON if eps is None else eps
    alpha = ALPHA if alpha is None else alpha
    max_steps = STEPS if max_steps is None else max_steps

    # Remember the caller's model state: the attacker freezes every parameter, and the
    # attack itself must run in eval mode.
    was_training = model.training
    grad_flags = [p.requires_grad for p in model.parameters()]

    attacker = ResnetPGDAttacker(model, dataloader)
    model.eval()

    adv_images = []
    labels = []
    try:
        for batch in tqdm(dataloader):
            img_batch, label_batch = batch['image'], batch['label']
            n_steps = random.randint(min_steps, max_steps)  # random attack strength per batch
            adv_images_batch = attacker.pgd_attack(img_batch, label_batch, eps=eps, alpha=alpha, steps=n_steps)
            # Keep results on the CPU: bounds device memory and makes saved files device-independent
            adv_images.append(adv_images_batch.cpu())
            labels.append(label_batch.cpu())
    finally:
        model.train(was_training)
        for p, flag in zip(model.parameters(), grad_flags):
            p.requires_grad = flag

    if not adv_images:
        raise ValueError("dataloader yielded no batches; nothing to attack")

    adv_images = torch.cat(adv_images)
    labels = torch.cat(labels)

    if store:
        os.makedirs(store_dir, exist_ok=True)

        chunk_size = 1000
        for start in range(0, len(adv_images), chunk_size):
            end = min(start + chunk_size, len(adv_images))
            file_path = os.path.join(store_dir, f"adv_images_{start}-{end-1}.pt")
            # clone(): torch.save on a slice view would serialize the whole underlying
            # storage into every chunk file.
            torch.save({'adv_images': adv_images[start:end].clone(), 'labels': labels[start:end].clone()}, file_path)
            print(f"Adversarial images saved at: {file_path}")

    return store_dir, adv_images, labels

# 3. Creating the fine-tuning dataset

# 4. Implementing Robust Resnet

# 5. Saving model & Evaluating Results