# About

This notebook attempts to finetune a resnet model to be more Robust by leveraging Projected Gradient Descent (PGD) in the two different ways:
1. include it as a part of the dataset
2. include the projected gradient as well while training along with the regular gradient from the loss function 

# 0. Importing required libraries


In [1]:
# !pip install lightning[extra]
# !pip install tensorboard

In [10]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import uuid
import random
from datasets import load_dataset
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision.models import resnet50, ResNet50_Weights
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from tqdm import tqdm
import copy

# 1. Setup

In [11]:
# Manual seed for reproducibility
SEED = 1234
torch.manual_seed(SEED)

# Params for PGD
ALPHA = 2/255
STEPS = 20
EPSILON = 8/255

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {DEVICE}")

# model = resnet50(weights=ResNet50_Weights.DEFAULT)
# print("initialized the model")

device: cuda


## 2. Creating a finetuning dataset

Idea is to create a balanced finetuning dataset which is run only once to be saved onto the disk and then from there we can just create dataloaders on that for finetuning

### 2.1 Reusing the same PGD attacker class from before

In [12]:
class ResnetPGDAttacker:
    def __init__(self):
        '''
        The PGD attack on Resnet model.
        :param model: The resnet model on which we perform the attack
        :param dataloader: The dataloader loading the input data on which we perform the attack
        '''
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.loss_fn = nn.CrossEntropyLoss()
        self.adv_images = []
        self.labels = []
        self.eps = 0
        self.alpha = 0
        self.steps = 0
        self.acc = 0
        self.adv_acc = 0
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Nullify gradient for model params
        for p in self.model.parameters():
            p.requires_grad = False

    def pgd_attack(self, image, label, eps=None, alpha=None, steps=None):
        '''
        Create adversarial images for given batch of images and labels

        :param image: Batch of input images on which we perform the attack, size (BATCH_SIZE, 3, 224, 224)
        :param label: Batch of input labels on which we perform the attack, size (BATCH_SIZE)
        :return: Adversarial images for the given input images
        '''
        if eps is None:
            eps = self.eps
        if alpha is None:
            alpha = self.alpha
        if steps is None:
            steps = self.steps

        images = image.clone().detach().to(self.device)
        adv_images = images.clone()
        labels = label.clone().detach().to(self.device)

        # Starting at a uniformly random point within the eps ball
        random_noise = torch.zeros_like(adv_images).uniform_(-eps, eps)
        adv_images = adv_images + random_noise

        for _ in range(steps):
            # Enable gradient tracking for adversarial images
            adv_images.requires_grad = True

            # Get model predictions and apply softmax
            outputs = self.model(adv_images).softmax(1)

            # Calculate loss
            loss = self.loss_fn(outputs, labels)

            # Compute gradient wrt images
            grad = torch.autograd.grad(
                loss, adv_images, retain_graph=False, create_graph=False
            )[0]
            adv_images = adv_images.detach()

            # Gradient update
            adv_images = adv_images + alpha * grad.sign()  # Update adversarial images using the sign of the gradient

            # Projection step
            # Clamping the adversarial images to ensure they are within the L∞ ball of eps radius of original image
            adv_images = torch.clamp(adv_images, images - eps, images + eps)

            adv_images = adv_images.detach()

        return adv_images  # Return the generated adversarial images



### 2.2 Creating a finetuned dataset and saving it in disk

In [13]:
class FineTuneDatasetGenerator:
    def __init__(self, num_classes=1000, num_images_per_class=4, num_transforms=3, num_perturbations=2,
                 save_path="./dataset"
                 ):
        self.num_classes = num_classes
        self.num_images_per_class = num_images_per_class
        self.num_transforms = num_transforms
        self.num_perturbations = num_perturbations
        self.save_path = save_path
        
        # Create the save directory if it doesn't exist
        os.makedirs(self.save_path, exist_ok=True)
        weights = ResNet50_Weights.DEFAULT
        self.resnet_transform = weights.transforms()  #PIL -> tensor

        self.transformations = [
            transforms.Compose([
                transforms.RandomRotation(15)
            ]),
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15)
            ]),
            transforms.Compose([
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(15)
            ])
        ]

        self.pgd_attacker = ResnetPGDAttacker()

        self.ds = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True, trust_remote_code=True)
        self.ds = self.ds.shuffle()
        print(f"Fine Tune Dataset Generator has been initialized.")

    def save_datapoint_with_augmentation_and_perturbations(self, img, label):
        images = []
        tensor_img = self.resnet_transform(img)

        #original image
        images.append(tensor_img)

        #augmentations (will be four)
        for _ in range(self.num_transforms):
            transformed_img = random.choice(self.transformations)(img)
            transformed_tensor_img = self.resnet_transform(transformed_img)
            images.append(transformed_tensor_img)

        #perturbations on all images thus far
        batch = list(
            map(
                lambda img: img.to("cpu"),
                images
            )
        )
        # Get a batch of images and create corresponding labels
        img_batch = torch.stack(batch)
        label_batch = torch.tensor([label] * len(img_batch))  # Repeat the label for the batch

        for _ in range(self.num_perturbations):
            # Generate random parameters for PGD attack
            random_eps = random.uniform(0.01, 0.3)
            random_alpha = random.uniform(0.01, 0.1)
            random_steps = random.randint(15, 20)

            # Perform the PGD attack
            perturbed_images = self.pgd_attacker.pgd_attack(img_batch,
                                                            label_batch,
                                                            eps=random_eps,
                                                            alpha=random_alpha,
                                                            steps=random_steps)

            for perturbed_img in perturbed_images:
                images.append(perturbed_img)

        #save this
        self.save_images(images, label)

    def generate(self):
        for label in tqdm(range(self.num_classes), desc="Fetching images for each class"):
            # Define the filter function for the current class label
            def filter_class(example):
                return example['label'] == label and example['image'].mode == 'RGB'

            # Load the dataset and filter for the current class
            ds = copy.deepcopy(self.ds)
            ds = ds.filter(filter_class)

            # Use take to get the desired number of images
            ds = ds.take(self.num_images_per_class)

            i = 0
            for data in ds:
                img = data["image"]
                self.save_datapoint_with_augmentation_and_perturbations(img, label)
                i += 1
                if i == self.num_images_per_class:
                    break

            print(
                f"Saved {self.num_images_per_class * (1 + self.num_transforms) * self.num_perturbations} images for class {label}."
            )

    def save_images(self, images, label):
        for idx, image in enumerate(images):
            img_id = str(uuid.uuid4())
            save_file = os.path.join(self.save_path, f"class_{label}_img_{img_id}.pt")
            torch.save(image, save_file)

        del images


# Example usage
# fine_tune_gen = FineTuneDatasetGenerator(num_perturbations=2, num_transforms=2, num_images_per_class=2, num_classes=2)
# fine_tune_gen.generate()


### 2.3 Creating a Pytorch dataset and dataloaders over the saved data in the disk

In [14]:
from torch.utils.data import random_split


class FineTunedDataset(Dataset):
    def __init__(self, root_dir, device=DEVICE):
        self.root_dir = root_dir
        self.image_files = os.listdir(root_dir)
        self.device = device

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load the image from the file
        image_file = self.image_files[idx]
        image = torch.load(os.path.join(self.root_dir, image_file)).to(self.device)

        # Extract label from the filename
        label = int(os.path.basename(image_file).split("_")[1])
        label = torch.tensor(label).to(self.device)  # e.g., "class_0_img_3.pt" -> label = 0

        return image, label

def get_dataloaders(root_dir, batch_size, train_val_test_split):
    dataset = FineTunedDataset(root_dir)

    train_size = int(len(dataset) * train_val_test_split[0])
    val_size = int(len(dataset) * train_val_test_split[1])
    test_size = len(dataset) - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, val_loader, test_loader


In [15]:
path = "C:\\Parashara\\Projects\\NUS projects\\Sem3\\Trustworthy ML\\Assignment 1\\PGDAttack\\dataset"
train_loader, val_loader, test_loader = get_dataloaders(path, batch_size=1, train_val_test_split=[0.33, 0.33, 0.33])

# 3. Implementing Robust Resnet

In [22]:
class RobustPGDResnet(LightningModule):
    def __init__(self, train_loader=train_loader, val_loader=val_loader, test_loader=val_loader, learning_rate=1e-3):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader
        self.learning_rate = learning_rate
        self.accuracy_metric = Accuracy("multiclass", num_classes=1000)

        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def forward(self, x):
        return self.model(x)

    def compute_pgd_loss(self, images, labels):
        # Randomize PGD parameters
        eps = random.uniform(0.01, 0.3)
        alpha = random.uniform(0.01, 0.1)
        steps = random.randint(5, 13)

        # Create adversarial images
        adv_images = images.clone().detach().requires_grad_(True)

        # Starting at a uniformly random point within the eps ball
        random_noise = torch.zeros_like(adv_images).uniform_(-eps, eps)
        adv_images = adv_images + random_noise

        # PGD process
        for _ in range(steps):
            outputs = self.model(adv_images)

            # Calculate PGD loss
            pgd_loss = self.loss_fn(outputs, labels)

            #Manual calculation of grad
            grad = torch.autograd.grad(pgd_loss, adv_images, retain_graph=True)[0]

            with torch.no_grad():
                # Gradient update
                adv_images = adv_images + alpha * grad.sign()
                adv_images = torch.clamp(adv_images, images - eps, images + eps)
                
            adv_images = adv_images.detach().requires_grad_(True)

        return pgd_loss

    def training_step(self, batch, batch_idx):
        images, labels = batch

        # Manually optimize
        optimizer = self.optimizers()
        optimizer.zero_grad()  # Clear previous gradients

        # Forward pass
        logits = self.model(images)
        ce_loss = self.loss_fn(logits, labels)

        # Compute PGD loss
        pgd_loss = self.compute_pgd_loss(images, labels)

        # Combine losses (you can adjust the weighting)
        total_loss = ce_loss + pgd_loss

        # Backward pass
        self.manual_backward(total_loss)
        optimizer.step()  # Update model parameters

        # compute accuracy
        train_acc = self.accuracy_metric(logits, labels)

        # Log metrics
        self.log('train_ce_loss', ce_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_pgd_loss', pgd_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss', total_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', train_acc, on_step=True, on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.model(images)
        ce_loss = self.loss_fn(logits, labels)
        val_acc = self.accuracy_metric(logits, labels)
        self.log('val_loss', ce_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', val_acc, on_step=True, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.model(images)
        ce_loss = self.loss_fn(logits, labels)
        test_acc = self.accuracy_metric(logits, labels)
        self.log('test_loss', ce_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('test_acc', test_acc, on_step=True, on_epoch=True, prog_bar=True)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader


#TODO.x just for testing out pgd
# Create the classifier
classifier = RobustPGDResnet(train_loader, val_loader, test_loader)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Monitor validation loss for saving the best model
    dirpath='./pgd_checkpoints/',  # Directory where checkpoints will be saved
    filename='best-checkpoint',  # Checkpoint filename
    save_top_k=1,  # Save only the best model
    mode='min'  # Minimize the monitored metric
)

# Create a Trainer
trainer = Trainer(
    # fast_dev_run =True, #Flag for debugging
    callbacks=[checkpoint_callback],
    # profiler="simple", In case you want to time each step
    max_epochs=2,
    accelerator="auto",
    devices=1,
    check_val_every_n_epoch=1,
    # val_check_interval=0.5, Use this in case the epoch takes too long, fraction of epoch
    log_every_n_steps=1,
    enable_checkpointing=True,
    enable_progress_bar=True,
    logger=TensorBoardLogger(save_dir='./logs', name="ft_pgd_resnet")
)
# Other interesting flags
#max_time=some timedelta

# Train the model
trainer.fit(classifier)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | model           | ResNet             | 25.6 M | train
1 | loss_fn         | CrossEntropyLoss   | 0      | train
2 | accuracy_metric | MulticlassAccuracy | 0      | train
---------------------------------------------------------------
25.6 M    Trainable params
0         Non-trainable params
25.6 M    Total params
102.228   Total estimated model params size (MB)
153       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  image = torch.load(os.path.join(self.root_dir, image_file)).to(self.device)


Training: |          | 0/? [00:00<?, ?it/s]

  image = torch.load(os.path.join(self.root_dir, image_file)).to(self.device)


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


### 3.1 Training the model

In [17]:
# Create the classifier
classifier = RobustResnet(train_loader, val_loader, test_loader)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Monitor validation loss for saving the best model
    dirpath='./pgd_checkpoints/',  # Directory where checkpoints will be saved
    filename='best-checkpoint',  # Checkpoint filename
    save_top_k=1,  # Save only the best model
    mode='min'  # Minimize the monitored metric
)

# Create a Trainer
trainer = Trainer(
    # fast_dev_run =True, #Flag for debugging
    callbacks=[checkpoint_callback],
    # profiler="simple", In case you want to time each step
    max_epochs=4,
    accelerator="auto",
    devices=1,
    check_val_every_n_epoch=1,
    # val_check_interval=0.5, Use this in case the epoch takes too long, fraction of epoch
    log_every_n_steps=1,
    enable_checkpointing=True,
    enable_progress_bar=True,
    logger=TensorBoardLogger(save_dir='./logs', name="ft_pgd_resnet")
)
    # Other interesting flags
    #max_time=some timedelta

# Train the model
trainer.fit(classifier)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | model           | ResNet             | 25.6 M | train
1 | loss_fn         | CrossEntropyLoss   | 0      | train
2 | accuracy_metric | MulticlassAccuracy | 0      | train
---------------------------------------------------------------
25.6 M    Trainable params
0         Non-trainable params
25.6 M    Total params
102.228   Total estimated

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\paras\anaconda3\envs\trustworthyml\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
  image = torch.load(os.path.join(self.root_dir, image_file)).to(self.device)
C:\Users\paras\anaconda3\envs\trustworthyml\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=4` reached.


### 3.2 Tensorboard visualizations

In [23]:
%load_ext tensorboard
%tensorboard --logdir ./logs/ft_pgd_resnet/

### 3.3 Testing the trained model

In [None]:
# Test the model (can pass checkpoint path also)
trainer.test()

# 4. Evaluating Results

### 4.1 Evaluating the best model from checkpoint

In [41]:
def load_model_from_checkpoint(checkpoint_path):
    # Option 1: Use pytorch lightning
    finetuned_model = RobustResnet.load_from_checkpoint(checkpoint_path)

    # Option 2: Regular approach
    # Load the checkpoint
    # checkpoint = torch.load(checkpoint_path)
    # 
    # # Create a new instance of your model
    # finetuned_model = RobustResnet()  # Initialize with the required parameters
    # 
    # # Load the model weights from the checkpoint
    # finetuned_model.load_state_dict(checkpoint['state_dict'])

    return finetuned_model


def evaluate_model(test_loader=test_loader, checkpoint_path="./checkpoints/best-checkpoint.ckpt",
                   original_model=resnet50(weights=ResNet50_Weights.DEFAULT), device=DEVICE):
    # Load the fine-tuned model from the checkpoint
    fine_tuned_model = load_model_from_checkpoint(checkpoint_path)

    # Move models to the appropriate device
    original_model.to(device)
    fine_tuned_model.to(device)

    # set to eval modes
    original_model.eval()
    fine_tuned_model.eval()

    # Initialize accuracy metrics
    total = 0
    orig_correct = 0
    ft_correct = 0

    # Evaluate the original model
    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)

            # Original model predictions
            original_logits = original_model(images).softmax(1)
            original_predictions = original_logits.argmax(dim=1)

            # Fine-tuned model predictions
            fine_tuned_logits = fine_tuned_model(images).softmax(1)
            fine_tuned_predictions = fine_tuned_logits.argmax(dim=1)

            # Accumulate accuracy counts
            orig_correct += torch.sum(original_predictions == labels).item()
            ft_correct += torch.sum(fine_tuned_predictions == labels).item()
            total += len(labels)

    # Calculate accuracies
    original_accuracy = orig_correct / total
    fine_tuned_accuracy = ft_correct / total

    print(f'Evaluation Original Model Accuracy: {original_accuracy * 100} %')
    print(f'Evaluation Fine-Tuned Model Accuracy: {fine_tuned_accuracy * 100} %')

    return original_accuracy, fine_tuned_accuracy

In [42]:
evaluate_model()

Evaluation Original Model Accuracy: 100.0 %
Evaluation Fine-Tuned Model Accuracy: 100.0 %


  image = torch.load(os.path.join(self.root_dir, image_file)).to(self.device)


(1.0, 1.0)