# TODOs:

1. For dataloaders keep the model separate ( or use the same model for train, test and val which just the base one) but for the actual finetuning use a different initialization
2. Since filter is possible on dataset, construct a dataset/dataloader with balanced class representation and then perturb it. 3 images * 2 perturbations * 1000 * (1 + 2 augmentations) =18k images ( check if this is enough first of all )
3. the model has issues with the validation step. (when is it called, why is grad_fn not available)
 (check from here -> https://github.com/Lightning-AI/pytorch-lightning/issues/13948)


# About

This notebook attempts to finetune a resnet model to be more Robust by leveraging Projected Gradient Descent (PGD) in the two different ways:
1. include it as a part of the dataset
2. include the projected gradient as well while training along with the regular gradient from the loss function 

# 0. Importing required libraries


In [1]:
# !pip install lightning[extra]

In [2]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import random
from datasets import load_dataset
from torchvision.models import resnet50, ResNet50_Weights
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from tqdm import tqdm

# 1. Setup

In [3]:
# Manual seed for reproducibility
SEED = 1234
torch.manual_seed(SEED)

# Params for PGD
ALPHA = 2/255
STEPS = 20
EPSILON = 8/255

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {DEVICE}")

# model = resnet50(weights=ResNet50_Weights.DEFAULT)
# print("initialized the model")

device: cuda


### 1.1 Loading the datasets

In [4]:
def preprocess_img(example):
    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()
    example['image'] = preprocess(example['image'])
    return example

def get_dataloader(split: str, batch_num: int, batch_size: int):
    '''
        
    :param split: can be either train, test or validation 
    :return: 
    '''
    print(f"Loading {split} ILSVRC/imagenet-1k dataset...")
    ds = load_dataset("ILSVRC/imagenet-1k", split=split, streaming=True, trust_remote_code=True)
    
    # Filter out grayscale images
    ds = ds.filter(lambda example: example['image'].mode == 'RGB')
    
    # Preprocess function will be applied to images on-the-fly whenever they are being accessed in the loop
    ds = ds.map(preprocess_img)
    ds = ds.shuffle(seed=SEED)
    
    # Only take desired portion of dataset
    ds = ds.take(batch_num * batch_size)
    print(f"Creating dataloader with {batch_num} batches for split {split} each with size {batch_size}")
    return DataLoader(ds, batch_size=batch_size)
    

In [4]:
dl = get_dataloader("train", 2, 2)
for batch in dl:
    print(batch)
    break

Loading train ILSVRC/imagenet-1k dataset...
Creating dataloader with 2 batches for split train each with size 2
{'image': tensor([[[[-1.7925, -1.7925, -1.7925,  ..., -1.7412, -1.7583, -1.7583],
          [-1.7754, -1.7583, -1.7754,  ..., -1.7240, -1.7240, -1.7240],
          [-1.7754, -1.7925, -1.7925,  ..., -1.7240, -1.7240, -1.7240],
          ...,
          [-1.4843, -1.4843, -1.4843,  ..., -1.4843, -1.4843, -1.4843],
          [-1.4672, -1.4843, -1.4843,  ..., -1.4843, -1.4672, -1.4672],
          [-1.4843, -1.4843, -1.4843,  ..., -1.4843, -1.4843, -1.4672]],

         [[-0.7752, -0.7577, -0.7577,  ..., -0.7227, -0.7227, -0.7402],
          [-0.7752, -0.7577, -0.7752,  ..., -0.7402, -0.7577, -0.7402],
          [-0.7752, -0.7927, -0.7927,  ..., -0.7227, -0.7227, -0.7227],
          ...,
          [-0.1800, -0.1800, -0.1800,  ..., -0.1625, -0.1800, -0.1800],
          [-0.1625, -0.1800, -0.1800,  ..., -0.1800, -0.1625, -0.1625],
          [-0.1800, -0.1800, -0.1800,  ..., -0.1800, -

In [13]:
# Can use this for filtering
def chumma(example):
    # print(f"example: {example}")
    # print(f"example['label']: {example['label']} -> {type(example['label'])}")
    return example['label'] == 916

ds = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True, trust_remote_code=True)
ds = ds.filter(chumma)
# ds = ds.filter(lambda example: example['image'].mode == 'RGB')
# ds = ds.filter(lambda example: example['label'] == 726)
ds = ds.map(preprocess_img)
ds = ds.take(2 * 2)
dl = DataLoader(ds, batch_size=2)

for batch in dl:
    image, label = batch["image"], batch["label"]
    print(label)
    break

tensor([916, 916])


### 1.2 My Resnet PGD Attacker

In [5]:
class ResnetPGDAttacker:
    def __init__(self, model=None, device=DEVICE):
        '''
        The PGD attack on Resnet model.
        :param model: The resnet model on which we perform the attack
        :param dataloader: The dataloader loading the input data on which we perform the attack
        '''
        self.device = device
        
        if model == None:
            print(f"Creating new model")
            self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        else:
            print(f"Using existing model")
            self.model = model     
        self.model.to(self.device)
        
        self.loss_fn = nn.CrossEntropyLoss()
        self.eps = 0
        self.alpha = 0
        self.steps = 0
        
        # Nullify gradient for model params
        for p in self.model.parameters():
            p.requires_grad = False

    def pgd_attack(self, image, label, eps=None, alpha=None, steps=None):
        '''
        Create adversarial images for given batch of images and labels

        :param image: Batch of input images on which we perform the attack, size (BATCH_SIZE, 3, 224, 224)
        :param label: Batch of input labels on which we perform the attack, size (BATCH_SIZE)
        :return: Adversarial images for the given input images
        '''
        if eps is None:
            eps = self.eps
        if alpha is None:
            alpha = self.alpha
        if steps is None:
            steps = self.steps

        images = image.clone().to(self.device)
        adv_images = image.clone().to(self.device)
        labels = label.clone().to(self.device)

        # Starting at a uniformly random point within the eps ball
        random_noise = torch.zeros_like(adv_images).uniform_(-eps, eps)
        adv_images = adv_images + random_noise
        
        # Enable gradient tracking for adversarial images
        adv_images.requires_grad = True

        for _ in range(steps):
            # Get model predictions and apply softmax
            outputs = self.model(adv_images).softmax(1)

            # Calculate loss
            loss = self.loss_fn(outputs, labels)
                        
            assert loss.requires_grad == True, f"loss should be requires_grad {loss}."
            assert adv_images.requires_grad == True, "adv_images should be requires_grad"
            
            # Compute gradient wrt images
            grad = torch.autograd.grad(
                loss, adv_images, retain_graph=False, create_graph=False
            )[0]

            # Gradient update
            adv_images = adv_images + alpha * grad.sign()  # Update adversarial images using the sign of the gradient

            # Projection step
            # Clamping the adversarial images to ensure they are within the L∞ ball of eps radius of original image
            adv_images = torch.clamp(adv_images, images - eps, images + eps)

        #detaching only in the end
        adv_images = adv_images.detach()
        return adv_images  # Return the generated adversarial images


# 2. Functions related to generating & saving + fetching saved data from a location

(TODO) : maybe this is not needed as I am creating a dataset directly!

In [6]:
def fetch_perturbed_images_from_location(location: str):
    """
    Fetch and concatenate all perturbed images and labels from the specified directory.

    :param location: The directory where the .pt files are saved.
    :return: Tuple of concatenated adversarial images and labels.
    """
    adv_images = []
    labels = []

    # Iterate through all .pt files in the specified directory
    for file_name in os.listdir(location):
        if file_name.endswith(".pt"):
            file_path = os.path.join(location, file_name)
            data = torch.load(file_path)
            adv_images.append(data['adv_images'])
            labels.append(data['labels'])

    # Concatenate all adversarial images and labels
    adv_images = torch.cat(adv_images)
    labels = torch.cat(labels)

    return adv_images, labels

def generate_adv_images(dataloader, store: bool = False, store_dir: str = 'adv_images'):
    """
    Generate adversarial images using the PGD method and save them in a directory.

    :param dataloader: dataloader of images, labels.
    :param model: The ResNet model used for generating adversarial images.
    :param store: Boolean flag to indicate whether to save the generated images.
    :param store_dir: The directory where the adversarial images and labels will be saved.
    :return: Tuple of adversarial images and their corresponding labels.
    """    
    # Initialize the attacker
    attacker = ResnetPGDAttacker()

    # Generate adversarial images
    adv_images = []
    labels = []
    for batch in tqdm(dataloader):
        img_batch, label_batch = batch['image'], batch['label']
        adv_images_batch = attacker.pgd_attack(img_batch, label_batch, steps = random.randint(5, STEPS)) #randomly take 5->20 steps
        adv_images.append(adv_images_batch)
        labels.append(label_batch)

    # Concatenate all adversarial images
    adv_images = torch.cat(adv_images)
    labels = torch.cat(labels)

    if store:
        # Create the store directory if it doesn't exist
        os.makedirs(store_dir, exist_ok=True)
    
        # Save the adversarial images and labels in multiple files
        for j in range(0, len(adv_images), 1000):  # Save in batches of 1000
            start = j
            end = min(j + 1000, len(adv_images))
            file_path = os.path.join(store_dir, f"adv_images_{start}-{end-1}.pt")
            torch.save({'adv_images': adv_images[start:end], 'labels': labels[start:end]}, file_path)
            print(f"Adversarial images saved at: {file_path}")
        
    return store_dir, adv_images, labels

# 3. Creating the fine-tuning dataset

In [7]:
class RobustDataLoader(DataLoader):
    def __init__(self, dataset, split, batch_size=1, batch_num=1, num_perturbations=2, num_transformations=1,
                 device=DEVICE,
                 **kwargs):
        """
        Robust DataLoader that includes original, perturbed, and augmented images.

        :param dataset: Dataset containing original images and labels.
        :param model: The ResNet model used for generating adversarial images.
        :param num_perturbations: Number of perturbations to generate for each original image.
        :param num_transformations: Number of transformations to apply to each image.
        :param kwargs: Additional arguments to be passed to the parent class constructor.
        """
        super().__init__(dataset, batch_size=batch_size,
                         **kwargs)  # self.dataset & self.batch_size is defined inside of this       
        self.device = device
        self.split = split
        self.num_perturbations = num_perturbations
        self.num_transformations = num_transformations
        self.pgd_attacker = ResnetPGDAttacker()
        self.batch_num = batch_num
        self.dataset_size = self.batch_num * self.batch_size
        
        if self.split == "test": 
            self.per_yield_limit = self.batch_size
            print(f"setting per_yield_limit to {self.per_yield_limit}, which is the same as batch size as it is in test mode")
        else:
            self.per_yield_limit = self.batch_size * (1 + self.num_transformations + self.num_perturbations)
            print(f"setting per_yield_limit to {self.per_yield_limit}, as there are additional #{self.num_transformations} transformations & #{self.num_perturbations} perturbations")
            

        # Define transformations
        self.transformations = [
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
            ]),
            transforms.Compose([
                transforms.RandomVerticalFlip(),
            ]),
            transforms.Compose([
                transforms.RandomRotation(20),
            ]),
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(20),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            ]),
            transforms.Compose([
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(20),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            ])
        ]
        print(
            f"RobustDataLoader initialized with split={split}, batch_num={batch_num}, batch_size={self.batch_size}, num_perturbations={self.num_perturbations}, num_transformations={self.num_transformations}, per_yield_limit={self.per_yield_limit}")

    def perform_pgd(self, image, label):
        random_eps = random.uniform(0.01, 0.3)
        random_alpha = random.uniform(0.01, 0.1)
        random_steps = random.randint(15, 20)
        perturbed_image = self.pgd_attacker.pgd_attack(image.unsqueeze(0), label.unsqueeze(0), eps=random_eps,
                                                       alpha=random_alpha, steps=random_steps)
        return perturbed_image.squeeze(0).to(self.device)

    def __len__(self):
        '''
        For each image, we do T transformations and P perturbations and retain the original image.
        
        :return: 
        '''
        return (self.batch_num * self.batch_size) * (1 + self.num_transformations + self.num_perturbations)

    def __iter__(self):
        collected_images = []
        collected_labels = []

        for datapoint in self.dataset:
            image = datapoint['image'].to(self.device)
            label = torch.tensor(datapoint['label']).to(self.device)

            # Collect original images and labels
            collected_images.append(image)
            collected_labels.append(label)

            if self.split != "test":
                # Generate and collect perturbations
                for _ in range(self.num_perturbations):
                    perturbed_image = self.perform_pgd(image, label)
                    collected_images.append(perturbed_image)
                    collected_labels.append(label)
    
                # Generate and collect transformations
                for _ in range(self.num_transformations):
                    transformation_index = random.randint(0, len(self.transformations) - 1)
                    transformed_image = self.transformations[transformation_index](image).to(self.device)
                    collected_images.append(transformed_image)
                    collected_labels.append(label)

            # Check if we have collected enough images for the batch
            if len(collected_images) == self.per_yield_limit:
                # Yield the batch
                yield {
                    "image": torch.stack(collected_images[:self.per_yield_limit]).to(self.device),
                    "label": torch.tensor(collected_labels[:self.per_yield_limit]).to(self.device)
                }

                # Clear the lists after yielding
                collected_images.clear()
                collected_labels.clear()


In [8]:
def create_dataloader(split, batch_num=1, batch_size=1, num_perturbations=1, num_transformations=1):
    # Load the dataset directly using load_dataset
    ds = load_dataset("ILSVRC/imagenet-1k", split=split, streaming=True, trust_remote_code=True)

    # Preprocess the dataset
    ds = ds.filter(lambda example: example['image'].mode == 'RGB')
    ds = ds.map(preprocess_img)
    ds = ds.shuffle(seed=SEED)
    ds = ds.take(batch_num * batch_size)
    print(f"loaded dataset of {batch_num * batch_size} images & labels from split {split}")

    # Create the RobustDataLoader
    data_loader = RobustDataLoader(ds,
                                   split=split,
                                   num_perturbations=num_perturbations,
                                   num_transformations=num_transformations,
                                   batch_size=batch_size,
                                   batch_num=batch_num
                                   )
    print(f"Instantiated dataloader for split {split}, with {len(data_loader)} image, label pairs")
    assert len(data_loader) == batch_num * batch_size * (
            1 + num_transformations + num_perturbations), f"Length of the dataset should have been {batch_num * batch_size * (1 + num_transformations) + num_perturbations}"
    return data_loader


### 3.1 Testing Data loaders

In [9]:
loader = create_dataloader("train", batch_num=2, batch_size=1)

for i, batch in enumerate(loader):
    print(f"batch {i}")
    img, label = batch['image'], batch['label']
    print(f"image shape: {img.shape}")
    print(f"label shape: {label.shape}, labels: {label}")
    print("-"* 30)

loaded dataset of 2 images & labels from split train
Creating new model
setting per_yield_limit to 3, as there are additional #1 transformations & #1 perturbations
RobustDataLoader initialized with split=train, batch_num=2, batch_size=1, num_perturbations=1, num_transformations=1, per_yield_limit=3
Instantiated dataloader for split train, with 6 image, label pairs
batch 0
image shape: torch.Size([3, 3, 224, 224])
label shape: torch.Size([3]), labels: tensor([417, 417, 417], device='cuda:0')
------------------------------
batch 1
image shape: torch.Size([3, 3, 224, 224])
label shape: torch.Size([3]), labels: tensor([476, 476, 476], device='cuda:0')
------------------------------


In [10]:
loader = create_dataloader("validation", batch_num=2, batch_size=1)
for batch in loader:
    print(f"val batch {i}")
    img, label = batch['image'], batch['label']
    print(f"image shape: {img.shape}")
    print(f"label shape: {label.shape}")
    print("-"* 30)


loaded dataset of 2 images & labels from split validation
Creating new model
setting per_yield_limit to 3, as there are additional #1 transformations & #1 perturbations
RobustDataLoader initialized with split=validation, batch_num=2, batch_size=1, num_perturbations=1, num_transformations=1, per_yield_limit=3
Instantiated dataloader for split validation, with 6 image, label pairs
val batch 1
image shape: torch.Size([3, 3, 224, 224])
label shape: torch.Size([3])
------------------------------
val batch 1
image shape: torch.Size([3, 3, 224, 224])
label shape: torch.Size([3])
------------------------------


In [11]:
loader = create_dataloader("test", batch_num=2, batch_size=1)
for batch in loader:
    print(f"test batch {i}")
    img, label = batch['image'], batch['label']
    print(f"image shape: {img.shape}")
    print(f"label shape: {label.shape}")
    print("-"* 30)

del loader

loaded dataset of 2 images & labels from split test
Creating new model
setting per_yield_limit to 1, which is the same as batch size as it is in test mode
RobustDataLoader initialized with split=test, batch_num=2, batch_size=1, num_perturbations=1, num_transformations=1, per_yield_limit=1
Instantiated dataloader for split test, with 6 image, label pairs
test batch 1
image shape: torch.Size([1, 3, 224, 224])
label shape: torch.Size([1])
------------------------------
test batch 1
image shape: torch.Size([1, 3, 224, 224])
label shape: torch.Size([1])
------------------------------


# 4. Implementing Robust Resnet

In [13]:
class RobustResnet(LightningModule):
    def __init__(self, train_loader, val_loader, test_loader, learning_rate=1e-3):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader
        self.learning_rate = learning_rate
        self.train_acc = Accuracy("multiclass", num_classes=1000)
        self.val_acc = Accuracy("multiclass", num_classes=1000)
        self.test_acc = Accuracy("multiclass", num_classes=1000)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        logits = self.model(images)
        ce_loss = self.loss_fn(logits, labels)
        pgd_loss = self.pgd_loss(images, labels)
        loss = ce_loss + pgd_loss

        self.train_acc(logits, labels)
        self.log('train_ce_loss', ce_loss)
        self.log('train_pgd_loss', pgd_loss)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        logits = self.model(images)
        ce_loss = self.loss_fn(logits, labels)
        pgd_loss = self.pgd_loss(images, labels)
        loss = ce_loss + pgd_loss
        self.val_acc(logits, labels)
        self.log('val_ce_loss', ce_loss, on_step=False, on_epoch=True)
        self.log('val_pgd_loss', pgd_loss, on_step=False, on_epoch=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)

    #TODO. this wont work as the labels will be -1 in this case
    # def test_step(self, batch, batch_idx):
    #     images, labels = batch
    #     logits = self.model(images)
    #     ce_loss = F.cross_entropy(logits, labels)
    #     pgd_loss = self.pgd_loss(images, labels)
    #     loss = ce_loss + pgd_loss
    #     acc = self.test_acc(logits, labels)
    #     self.log('test_ce_loss', ce_loss)
    #     self.log('test_pgd_loss', pgd_loss)
    #     self.log('test_loss', loss)
    #     self.log('test_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def pgd_loss(self, images, labels):
        # randomize pgd params
        eps = random.uniform(0.01, 0.3)
        alpha = random.uniform(0.01, 0.1)
        steps = random.randint(15, 20)
        adv_images = images.clone() #TODO.maybe this is the problem?
        #TODO.x print this and see if it has a grad_fn function at all 
        
        # Starting at a uniformly random point within the eps ball
        random_noise = torch.zeros_like(adv_images).uniform_(-eps, eps)
        adv_images = adv_images + random_noise
        
        # Enable gradient tracking for adversarial images
        adv_images.requires_grad = True
        
        # do randomized PGD process and compute loss
        for _ in range(steps):
            # Get model predictions and apply softmax
            outputs = self.model(adv_images).softmax(1)

            # Calculate loss
            pgd_loss = self.loss_fn(outputs, labels)

            # Compute gradient wrt images
            grad = torch.autograd.grad(
                pgd_loss, adv_images, retain_graph=True, create_graph=True
            )
            grad = grad[0]

            # Gradient update
            adv_images = adv_images + alpha * grad.sign()  # Update adversarial images using the sign of the gradient

            # Projection step
            # Clamping the adversarial images to ensure they are within the L∞ ball of eps radius of original image
            adv_images = torch.clamp(adv_images, images - eps, images + eps)

        return pgd_loss

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    # def test_dataloader(self):
    #     return self.test_loader
    
#regular kind
train_loader = get_dataloader("train", batch_num=2, batch_size=1)
val_loader = get_dataloader("validation", batch_num=2, batch_size=1)
test_loader = get_dataloader("test", batch_num=2, batch_size=1)


# Create the classifier
classifier = RobustResnet(train_loader, val_loader, test_loader)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Monitor validation loss for saving the best model
    dirpath='checkpoints/',  # Directory where checkpoints will be saved
    filename='best-checkpoint',  # Checkpoint filename
    save_top_k=1,  # Save only the best model
    mode='min'  # Minimize the monitored metric
)

# Create a Trainer
#TODO.x check fi training step can be done first insteadf of validation step
trainer = Trainer(
    max_epochs=3,
    accelerator='gpu',  # Specify the accelerator type
    devices=1,  # Use 1 GPU if available
    callbacks=[checkpoint_callback],  # Add the checkpoint callback
    log_every_n_steps=1  # Log metrics every step
)

# Train the model
trainer.fit(classifier)



Loading train ILSVRC/imagenet-1k dataset...
Creating dataloader with 2 batches for split train each with size 1
Loading validation ILSVRC/imagenet-1k dataset...
Creating dataloader with 2 batches for split validation each with size 1
Loading test ILSVRC/imagenet-1k dataset...
Creating dataloader with 2 batches for split test each with size 1


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | ResNet             | 25.6 M | train
1 | loss_fn   | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
4 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
25.6 M    Trainable params
0         Non-trainable params
25.6 M    Total params
102.228   Total estimated model params size (MB)
155       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\paras\anaconda3\envs\trustworthyml\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [10]:
# Create dataloaders
# Robust kind
# train_loader = create_dataloader("train", batch_num=2, batch_size=1)
# val_loader = create_dataloader("validation", batch_num=2, batch_size=1)
# test_loader = create_dataloader("test", batch_num=2, batch_size=1)


# Create the classifier
classifier = RobustResnet(train_loader, val_loader, test_loader)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Monitor validation loss for saving the best model
    dirpath='checkpoints/',  # Directory where checkpoints will be saved
    filename='best-checkpoint',  # Checkpoint filename
    save_top_k=1,  # Save only the best model
    mode='min'  # Minimize the monitored metric
)

# Create a Trainer
trainer = Trainer(
    max_epochs=3,
    accelerator='gpu',  # Specify the accelerator type
    devices=1,  # Use 1 GPU if available
    callbacks=[checkpoint_callback],  # Add the checkpoint callback
    log_every_n_steps=1  # Log metrics every step
)

# Train the model
trainer.fit(classifier)


Loading train ILSVRC/imagenet-1k dataset...
Creating dataloader with 2 batches for split train each with size 1
Loading validation ILSVRC/imagenet-1k dataset...
Creating dataloader with 2 batches for split validation each with size 1
Loading test ILSVRC/imagenet-1k dataset...
Creating dataloader with 2 batches for split test each with size 1


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | ResNet             | 25.6 M | train
1 | loss_fn   | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
4 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
25.6 M    Trainable params
0         N

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\paras\anaconda3\envs\trustworthyml\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


TypeError: conv2d() received an invalid combination of arguments - got (str, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!str!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!str!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)


In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Test the model
trainer.test()

In [None]:
'''
TODO:
1. check tensorboard here in the cell directly 
2. evaluate the evaluate code as the checkpointing and reloading of that model is not as straightforward as I thought it would be 
'''

# 5. Saving model & Evaluating Results

In [None]:
def load_model_from_checkpoint(checkpoint_path):
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)

    # Create a new instance of your model
    finetuned_model = RobustResnet()  # Initialize with the required parameters

    # Load the model weights from the checkpoint
    finetuned_model.load_state_dict(checkpoint['state_dict'])

    # Set the model to evaluation mode
    finetuned_model.eval()
    
    return finetuned_model

def evaluate_model(checkpoint_path, original_model, test_dataloader, device=DEVICE):
    # Load the fine-tuned model from the checkpoint
    fine_tuned_model = load_model_from_checkpoint(checkpoint_path)

    # Move models to the appropriate device
    original_model.to(device)
    fine_tuned_model.to(device)

    # Initialize accuracy metrics
    original_acc = Accuracy()
    fine_tuned_acc = Accuracy()

    # Evaluate the original model
    with torch.no_grad():
        for batch in test_dataloader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)

            # Original model predictions
            original_logits = original_model(images)
            original_acc(original_logits, labels)

            # Fine-tuned model predictions
            fine_tuned_logits = fine_tuned_model(images)
            fine_tuned_acc(fine_tuned_logits, labels)

    # Calculate accuracies
    original_accuracy = original_acc.compute()
    fine_tuned_accuracy = fine_tuned_acc.compute()

    print(f'Evaluation Original Model Accuracy: {original_accuracy:.4f}')
    print(f'Evaluation Fine-Tuned Model Accuracy: {fine_tuned_accuracy:.4f}')

    return original_accuracy, fine_tuned_accuracy