# About

This notebook attempts to finetune a resnet model to be more Robust by leveraging Projected Gradient Descent (PGD) in the two different ways:
1. include it as a part of the dataset
2. include the projected gradient as well while training along with the regular gradient from the loss function 

# 0. Importing required libraries


In [1]:
# !pip install lightning[extra]

In [2]:
import random
import os
from datasets import load_dataset
from torchvision.models import resnet50, ResNet50_Weights
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from tqdm import tqdm

# 1. Setup

In [3]:
# Manual seed for reproducibility
SEED = 1234
torch.manual_seed(SEED)

# Params for PGD
ALPHA = 2/255
STEPS = 20
EPSILON = 8/255

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### 1.1 Loading the model

In [4]:
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
preprocess = weights.transforms()

### 1.2 Loading the datasets

In [5]:
def preprocess_img(example):
    example['image'] = preprocess(example['image'])
    return example

def get_dataloader(split: str, batch_num: int, batch_size: int):
    '''
        
    :param split: can be either train, test or validation 
    :return: 
    '''
    print(f"Loading {split} ILSVRC/imagenet-1k dataset...")
    ds = load_dataset("ILSVRC/imagenet-1k", split=split, streaming=True, trust_remote_code=True)
    
    # Filter out grayscale images
    ds = ds.filter(lambda example: example['image'].mode == 'RGB')
    
    # Preprocess function will be applied to images on-the-fly whenever they are being accessed in the loop
    ds = ds.map(preprocess_img)
    ds = ds.shuffle(seed=SEED)
    
    # Only take desired portion of dataset
    ds = ds.take(batch_num * batch_size)
    print(f"Creating dataloader with {batch_num} batches for split {split} each with size {batch_size}")
    return DataLoader(ds, batch_size=batch_size)
    

### 1.3 My Resnet PGD Attacker

In [13]:
class ResnetPGDAttacker:
    def __init__(self, model, device=DEVICE):
        '''
        The PGD attack on Resnet model.
        :param model: The resnet model on which we perform the attack
        :param dataloader: The dataloader loading the input data on which we perform the attack
        '''
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.eps = 0
        self.alpha = 0
        self.steps = 0
        self.device = DEVICE
        self.model.to(self.device)
        
        # Nullify gradient for model params
        for p in self.model.parameters():
            p.requires_grad = False

    def pgd_attack(self, image, label, eps=None, alpha=None, steps=None):
        '''
        Create adversarial images for given batch of images and labels

        :param image: Batch of input images on which we perform the attack, size (BATCH_SIZE, 3, 224, 224)
        :param label: Batch of input labels on which we perform the attack, size (BATCH_SIZE)
        :return: Adversarial images for the given input images
        '''
        if eps is None:
            eps = self.eps
        if alpha is None:
            alpha = self.alpha
        if steps is None:
            steps = self.steps

        images = image.clone().detach().to(self.device)
        adv_images = image.clone().detach().to(self.device)
        labels = label.clone().detach().to(self.device)

        # Starting at a uniformly random point within the eps ball
        random_noise = torch.zeros_like(adv_images).uniform_(-eps, eps)
        adv_images = adv_images + random_noise
        
        # Enable gradient tracking for adversarial images
        adv_images.requires_grad = True

        for _ in range(steps):
            # Get model predictions and apply softmax
            outputs = self.model(adv_images).softmax(1)

            # Calculate loss
            loss = self.loss_fn(outputs, labels)

            print(f"Loss: {loss}")
            print(f"Adversarial images: {adv_images}")
            
            assert loss.requires_grad == True, "loss should be requires_grad" #TODO.x why does this fail for jupyter notebook but work in python script?
            assert adv_images.requires_grad == True, "adv_images should be requires_grad"
            # Compute gradient wrt images
            grad = torch.autograd.grad(
                loss, adv_images, retain_graph=False, create_graph=False
            )[0]

            # Gradient update
            adv_images = adv_images + alpha * grad.sign()  # Update adversarial images using the sign of the gradient

            # Projection step
            # Clamping the adversarial images to ensure they are within the L∞ ball of eps radius of original image
            adv_images = torch.clamp(adv_images, images - eps, images + eps)

        #detaching only in the end
        adv_images = adv_images.detach()
        return adv_images  # Return the generated adversarial images


# 2. Functions related to generating & saving + fetching saved data from a location

(TODO) : maybe this is not needed as I am creating a dataset directly!

In [14]:
def fetch_perturbed_images_from_location(location: str):
    """
    Fetch and concatenate all perturbed images and labels from the specified directory.

    :param location: The directory where the .pt files are saved.
    :return: Tuple of concatenated adversarial images and labels.
    """
    adv_images = []
    labels = []

    # Iterate through all .pt files in the specified directory
    for file_name in os.listdir(location):
        if file_name.endswith(".pt"):
            file_path = os.path.join(location, file_name)
            data = torch.load(file_path)
            adv_images.append(data['adv_images'])
            labels.append(data['labels'])

    # Concatenate all adversarial images and labels
    adv_images = torch.cat(adv_images)
    labels = torch.cat(labels)

    return adv_images, labels

def generate_adv_images(dataloader, model, store: bool = False, store_dir: str = 'adv_images'):
    """
    Generate adversarial images using the PGD method and save them in a directory.

    :param dataloader: dataloader of images, labels.
    :param model: The ResNet model used for generating adversarial images.
    :param store: Boolean flag to indicate whether to save the generated images.
    :param store_dir: The directory where the adversarial images and labels will be saved.
    :return: Tuple of adversarial images and their corresponding labels.
    """
    # Initialize the attacker
    attacker = ResnetPGDAttacker(model, dataloader)

    # Generate adversarial images
    adv_images = []
    labels = []
    for batch in tqdm(dataloader):
        img_batch, label_batch = batch['image'], batch['label']
        adv_images_batch = attacker.pgd_attack(img_batch, label_batch, steps = random.randint(5, STEPS)) #randomly take 5->20 steps
        adv_images.append(adv_images_batch)
        labels.append(label_batch)

    # Concatenate all adversarial images
    adv_images = torch.cat(adv_images)
    labels = torch.cat(labels)

    if store:
        # Create the store directory if it doesn't exist
        os.makedirs(store_dir, exist_ok=True)
    
        # Save the adversarial images and labels in multiple files
        for j in range(0, len(adv_images), 1000):  # Save in batches of 1000
            start = j
            end = min(j + 1000, len(adv_images))
            file_path = os.path.join(store_dir, f"adv_images_{start}-{end-1}.pt")
            torch.save({'adv_images': adv_images[start:end], 'labels': labels[start:end]}, file_path)
            print(f"Adversarial images saved at: {file_path}")
        
    return store_dir, adv_images, labels

# 3. Creating the fine-tuning dataset

In [15]:
class RobustDataLoader(DataLoader):
    def __init__(self, dataset, model, num_perturbations=2, num_transformations=1, device=DEVICE, **kwargs):
        """
        Robust DataLoader that includes original, perturbed, and augmented images.

        :param dataset: Dataset containing original images and labels.
        :param model: The ResNet model used for generating adversarial images.
        :param num_perturbations: Number of perturbations to generate for each original image.
        :param num_transformations: Number of transformations to apply to each image.
        :param kwargs: Additional arguments to be passed to the parent class constructor.
        """
        super().__init__(dataset, **kwargs) # self.dataset is defined inside of this       
        self.device = device
        self.model = model.to(self.device)
        self.num_perturbations = num_perturbations
        self.num_transformations = num_transformations
        self.pgd_attacker = ResnetPGDAttacker(model=self.model)

        # Define transformations
        self.transformations = [
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
            ]),
            transforms.Compose([
                transforms.RandomVerticalFlip(),
            ]),
            transforms.Compose([
                transforms.RandomRotation(20),
            ]),
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(20),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            ]),
            transforms.Compose([
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(20),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            ])
        ]

    def perform_pgd(self, image, label):
        random_eps = random.uniform(0.01, 0.3)
        random_alpha = random.uniform(0.01, 0.1)
        random_steps = 15
        perturbed_image = self.pgd_attacker.pgd_attack(image.unsqueeze(0), label.unsqueeze(0), eps=random_eps,
                                                  alpha=random_alpha, steps=random_steps)
        return perturbed_image.squeeze(0).to(self.device)

    def __iter__(self):
        for datapoint in self.dataset:
            image = datapoint['image'].to(self.device)
            label = torch.tensor(datapoint['label']).to(self.device)

            original_images = []
            original_labels = []
            perturbed_images = []
            transformed_images = []

            # Collect original images and labels
            original_images.append(image)
            original_labels.append(label)

            # Generate and collect perturbations
            for _ in range(self.num_perturbations):
                perturbed_image = self.perform_pgd(image, label)
                perturbed_images.append(perturbed_image)
                original_labels.append(label)

            # Generate and collect transformations
            for _ in range(self.num_transformations):
                transformation_index = random.randint(0, len(self.transformations) - 1)
                transformed_image = self.transformations[transformation_index](image).to(self.device)
                transformed_images.append(transformed_image)
                original_labels.append(label)

            # Concatenate images and labels into batches
            all_images = original_images + perturbed_images + transformed_images
            all_labels = original_labels

            # Yield the batch
            yield torch.stack(all_images).to(self.device), torch.tensor(all_labels).to(self.device)

In [16]:
def create_dataloader(split, batch_num=1, batch_size=1):
    BATCH_NUM = batch_num
    BATCH_SIZE = batch_size
    
    # Load the dataset directly using load_dataset
    ds = load_dataset("ILSVRC/imagenet-1k", split=split, streaming=True, trust_remote_code=True)
    
    # Preprocess the dataset
    ds = ds.filter(lambda example: example['image'].mode == 'RGB')
    ds = ds.map(preprocess_img)
    ds = ds.shuffle(seed=SEED)
    ds = ds.take(BATCH_NUM * BATCH_SIZE)

    # Create the RobustDataLoader
    return RobustDataLoader(ds, model, num_perturbations=1, num_transformations=1, batch_size=BATCH_SIZE)


In [17]:
train_loader = create_dataloader("train")
val_loader = create_dataloader("validation")
test_loader = create_dataloader("test")

# 4. Implementing Robust Resnet

In [18]:
class RobustResnet(LightningModule):
    def __init__(self, model=model, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader,
                 learning_rate=1e-3):
        super().__init__()
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader
        self.model = model
        self.learning_rate = learning_rate
        self.train_acc = Accuracy("multiclass", num_classes=1000)
        self.val_acc = Accuracy("multiclass", num_classes=1000)
        self.test_acc = Accuracy("multiclass", num_classes=1000)

        # Initialize the PGD attacker once
        self.pgd_attacker = ResnetPGDAttacker(model=self.model)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.model(images)
        ce_loss = F.cross_entropy(logits, labels)
        pgd_loss = self.pgd_loss(images, labels)
        loss = ce_loss + pgd_loss

        self.train_acc(logits, labels)
        self.log('train_ce_loss', ce_loss)
        self.log('train_pgd_loss', pgd_loss)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.model(images)
        ce_loss = F.cross_entropy(logits, labels)
        pgd_loss = self.pgd_loss(images, labels)
        loss = ce_loss + pgd_loss
        self.val_acc(logits, labels)
        self.log('val_ce_loss', ce_loss, on_step=False, on_epoch=True)
        self.log('val_pgd_loss', pgd_loss, on_step=False, on_epoch=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.model(images)
        ce_loss = F.cross_entropy(logits, labels)
        pgd_loss = self.pgd_loss(images, labels)
        loss = ce_loss + pgd_loss
        acc = self.test_acc(logits, labels)
        self.log('test_ce_loss', ce_loss)
        self.log('test_pgd_loss', pgd_loss)
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def pgd_loss(self, images, labels):
        # Generate adversarial examples and calculate the loss
        adv_images = self.pgd_attacker.pgd_attack(images.to(self.device), labels.to(self.device))
        adv_logits = self.model(adv_images)
        pgd_loss = F.cross_entropy(adv_logits, labels.to(self.device))

        return pgd_loss

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader

In [19]:
# Create the classifier
classifier = RobustResnet(model, train_loader, val_loader, test_loader)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Monitor validation loss for saving the best model
    dirpath='checkpoints/',  # Directory where checkpoints will be saved
    filename='best-checkpoint',  # Checkpoint filename
    save_top_k=1,  # Save only the best model
    mode='min'  # Minimize the monitored metric
)

# Create a Trainer
trainer = Trainer(
    max_epochs=3,
    accelerator='gpu',  # Specify the accelerator type
    devices=1,  # Use 1 GPU if available
    callbacks=[checkpoint_callback],  # Add the checkpoint callback
    log_every_n_steps=1  # Log metrics every step
)

# Train the model
trainer.fit(classifier)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | ResNet             | 25.6 M | eval 
1 | train_acc | MulticlassAccuracy | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
0         Trainable params
25.6 M    Non-trainable params
25.6 M    Total params
102.228   Total estimated model params size (MB)
3         Modules in train mode
151       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Loss: 6.653629779815674
Adversarial images: tensor([[[[-0.5023, -0.4569, -0.5377,  ...,  0.9713,  0.8745,  0.7386],
          [-0.4527, -0.2359, -0.4827,  ...,  0.9710,  0.9936,  0.9518],
          [-0.2093, -0.2199, -0.5619,  ...,  0.6689,  0.8476,  1.1675],
          ...,
          [ 1.3939,  1.5134,  1.6850,  ...,  1.7968,  2.0799,  2.0123],
          [ 1.5227,  1.7923,  1.5735,  ...,  1.6647,  1.8153,  1.9972],
          [ 1.7974,  1.5117,  1.4493,  ...,  1.9405,  2.0178,  1.9242]],

         [[-0.4390, -0.2397, -0.0431,  ...,  0.8970,  0.7794,  1.1366],
          [-0.4884, -0.3496, -0.4493,  ...,  0.7607,  1.0689,  1.1576],
          [-0.1161, -0.3655, -0.3715,  ...,  1.0805,  1.1811,  1.1440],
          ...,
          [ 2.0228,  1.7886,  2.0053,  ...,  2.1421,  2.2692,  2.2802],
          [ 1.9608,  1.8276,  1.7157,  ...,  2.2434,  2.2565,  2.0161],
          [ 1.9258,  1.7670,  1.8564,  ...,  2.2285,  2.1459,  2.3452]],

         [[-0.1415, -0.3137, -0.0630,  ...,  1.1809,  1.35

AssertionError: loss should be requires_grad

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Test the model
trainer.test()

In [None]:
'''
TODO:
1. check tensorboard here in the cell directly 
2. evaluate the evaluate code as the checkpointing and reloading of that model is not as straightforward as I thought it would be 
'''

# 5. Saving model & Evaluating Results

In [None]:
def load_model_from_checkpoint(checkpoint_path):
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)

    # Create a new instance of your model
    finetuned_model = RobustResnet()  # Initialize with the required parameters

    # Load the model weights from the checkpoint
    finetuned_model.load_state_dict(checkpoint['state_dict'])

    # Set the model to evaluation mode
    finetuned_model.eval()
    
    return finetuned_model

def evaluate_model(checkpoint_path, original_model, test_dataloader, device=DEVICE):
    # Load the fine-tuned model from the checkpoint
    fine_tuned_model = load_model_from_checkpoint(checkpoint_path)

    # Move models to the appropriate device
    original_model.to(device)
    fine_tuned_model.to(device)

    # Initialize accuracy metrics
    original_acc = Accuracy()
    fine_tuned_acc = Accuracy()

    # Evaluate the original model
    with torch.no_grad():
        for batch in test_dataloader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)

            # Original model predictions
            original_logits = original_model(images)
            original_acc(original_logits, labels)

            # Fine-tuned model predictions
            fine_tuned_logits = fine_tuned_model(images)
            fine_tuned_acc(fine_tuned_logits, labels)

    # Calculate accuracies
    original_accuracy = original_acc.compute()
    fine_tuned_accuracy = fine_tuned_acc.compute()

    print(f'Evaluation Original Model Accuracy: {original_accuracy:.4f}')
    print(f'Evaluation Fine-Tuned Model Accuracy: {fine_tuned_accuracy:.4f}')

    return original_accuracy, fine_tuned_accuracy