# About:

This notebook fine-tunes a "Robust ResNet" using a fine-tuning dataset of PGD-perturbed images and their augmentations.

# Installing Libraries


In [None]:
# Empty pip's local cache before the packages below are installed.
!pip cache purge

In [None]:
# Install the third-party packages used by this notebook. No versions are pinned.
!pip install pytorch-lightning
!pip install lightning
!pip install lightning[extra]
# Hugging Face tooling: `datasets` provides load_dataset, used to stream ImageNet.
!pip install datasets
!pip install -U "huggingface_hub[cli]"
!pip install tensorboard

# 0. Importing libraries

In [None]:
# Mount Google Drive at /content/drive; the checkpoint and log folders configured
# in the setup cell live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from datasets import load_dataset
from torch.optim import lr_scheduler

from pytorch_lightning.loggers import TensorBoardLogger
import uuid
import random
from torchvision.models import resnet50, ResNet50_Weights
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
import copy
import shutil
import zipfile
from torch.optim import lr_scheduler

# 1. Setup

In [None]:
# Seeding is currently disabled, so each run draws fresh randomness.
# Re-enable the two lines below to make runs repeatable:
# SEED = 1234
# torch.manual_seed(SEED)

# Use the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"device: {DEVICE}")

# All notebook artefacts live under a single Google Drive folder.
base_path = "/content/drive/MyDrive/trustworthyml"
zips_path, logs_path = (os.path.join(base_path, leaf) for leaf in ("zips", "logs"))
model_path = os.path.join(base_path, "model", "ft_stream")

# 2. Creating data loaders

In [None]:
class ResnetPGDAttacker:
    """Crafts PGD (projected gradient descent) adversarial examples against a ResNet-50.

    The attacked network is built internally from ``ResNet50_Weights.DEFAULT`` and
    placed on ``cuda:0`` when CUDA is available, otherwise on the CPU. It is used
    purely as a fixed attack target and is never trained by this class.
    """

    def __init__(self):
        """Load the pretrained ResNet-50 and prepare it as a frozen, eval-mode target."""
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.loss_fn = nn.CrossEntropyLoss()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.resnet.to(self.device)

        # torchvision returns the model in train mode. Without eval(), BatchNorm would
        # normalise with per-batch statistics and overwrite its running statistics on
        # every attack forward pass, so the attack target would drift over time.
        self.resnet.eval()

        # Only gradients w.r.t. the input are needed; freezing the weights keeps
        # autograd from tracking them.
        for param in self.resnet.parameters():
            param.requires_grad_(False)

    def pgd_attack(self, images, labels, eps, alpha, steps):
        """Return adversarial versions of ``images`` found by ``steps`` PGD iterations.

        :param images: Batch of input tensors. The caller's tensor is not modified.
        :param labels: Ground-truth class indices for ``images``.
        :param eps: Radius of the L-infinity ball around ``images``, expressed in the
            same units as ``images`` (no pixel-range clamping is applied here).
        :param alpha: Step size of each signed-gradient ascent step.
        :param steps: Number of ascent steps; ``0`` returns the random start point.
        :return: Tensor on ``self.device`` with the shape of ``images``, detached from
            the autograd graph, with every element within ``eps`` of the input.
        """
        # detach() yields a new tensor, so the caller's `images` keeps its own
        # requires_grad flag (the in-place requires_grad_ would have changed it).
        images = images.detach().to(self.device)
        labels = labels.to(self.device)

        # Start from a uniformly random point inside the eps ball.
        adv_images = images + torch.empty_like(images).uniform_(-eps, eps)

        for _ in range(steps):
            # Fresh leaf each iteration: the graph of the previous step is dropped.
            adv_images = adv_images.detach().requires_grad_(True)
            outputs = self.resnet(adv_images)
            pgd_loss = self.loss_fn(outputs, labels)
            (grad,) = torch.autograd.grad(pgd_loss, adv_images)

            with torch.no_grad():
                # Ascend the loss, then project back into the eps ball.
                adv_images = adv_images + alpha * grad.sign()
                adv_images = torch.clamp(adv_images, images - eps, images + eps)

        # Consumers only need the values; a detached result avoids dragging an
        # input-gradient requirement into training or evaluation.
        return adv_images.detach()

In [None]:

class RobustDataLoader:
    """Single-pass iterator over streamed ImageNet batches, optionally PGD-perturbed.

    Each ``next()`` yields ``(images, labels)`` on ``self.device``. With probability
    ``perturbation_chance`` (and only if ``allow_perturbations`` is true) the whole
    batch is replaced by PGD adversarial examples with randomly drawn attack
    hyper-parameters. The object is its own iterator, so it can be consumed once.

    :param batch_num: Number of batches the stream is truncated to.
    :param batch_size: Number of examples per batch.
    :param allow_perturbations: Master switch for the PGD perturbation.
    :param perturbation_chance: Per-batch probability of perturbing.
    :param split: Dataset split to stream. Defaults to ``"train"``.
    :param seed: Optional seed forwarded to the streaming shuffle. ``None`` keeps
        the library default.
    """

    def __init__(self, batch_num, batch_size, allow_perturbations=True, perturbation_chance=1.0,
                 split="train", seed=None):
        self.batch_num = batch_num
        self.batch_size = batch_size
        self.allow_perturbations = allow_perturbations
        self.perturbation_chance = perturbation_chance
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # The attacker holds a full ResNet-50; only build it when a perturbation
        # can actually fire. __next__ creates it lazily if the settings change later.
        if self.allow_perturbations and self.perturbation_chance > 0:
            self.pgd_attacker = ResnetPGDAttacker()
        else:
            self.pgd_attacker = None

        # Build the preprocessing transform once instead of once per example.
        self.transform = ResNet50_Weights.DEFAULT.transforms()

        # Load the dataset
        self.ds = load_dataset("ILSVRC/imagenet-1k", split=split, streaming=True, trust_remote_code=True)
        self.ds = self.ds.shuffle(seed=seed)
        self.ds = self.ds.filter(lambda example: example['image'].mode == 'RGB')
        self.ds = self.ds.take(self.batch_num * self.batch_size)  # Take a fixed number of examples
        self.ds = self.ds.map(self.preprocess_img)

        # Create a DataLoader
        self.data_loader = DataLoader(self.ds, batch_size=self.batch_size)
        self.data_iter = iter(self.data_loader)

        print(
            f"Robust DataLoader created with batch_num = {self.batch_num}, batch_size = {self.batch_size}, perturbation_chance = {self.perturbation_chance}"
        )

    def preprocess_img(self, example):
        """Replace ``example['image']`` with its ResNet-50 preprocessed version."""
        example['image'] = self.transform(example['image'])
        return example

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(images, labels)`` pair, or raise ``StopIteration`` when the stream ends."""
        # StopIteration from the underlying iterator propagates unchanged.
        batch = next(self.data_iter)
        images = batch['image']
        labels = batch['label']

        # Randomly apply PGD attack if allowed
        if self.allow_perturbations and random.random() < self.perturbation_chance:
            if self.pgd_attacker is None:
                self.pgd_attacker = ResnetPGDAttacker()
            random_eps = random.uniform(0.01, 0.3)
            random_alpha = random.uniform(0.01, 0.07)
            random_steps = random.randint(15, 20)
            images = self.pgd_attacker.pgd_attack(images, labels, random_eps, random_alpha, random_steps)
            # Consumers only need the pixel values; drop any autograd history so
            # downstream forward passes do not compute input gradients.
            images = images.detach()

        return images.to(self.device), labels.to(self.device)

    def __len__(self):
        return self.batch_num


# 3. Implementing Robust Resnet

In [None]:
class RobustResnet:
    """Fine-tunes a pretrained ResNet-50 on batches drawn from one shared batch iterator.

    Training, validation and test batches are all pulled, in call order, from the
    same iterator over ``data_loader``; the loader must therefore provide at least
    ``epochs * (train_steps + val_steps) + test_steps`` batches of
    ``(images, labels)``.

    :param data_loader: Iterable yielding ``(images, labels)`` batches.
    :param epochs: Number of train + validation rounds run by :meth:`train`.
    :param train_steps: Batches consumed per training phase.
    :param val_steps: Batches consumed per validation phase.
    :param test_steps: Batches consumed by :meth:`test`.
    :param checkpoint_dir: Directory that receives the best-so-far checkpoints.
    :param learning_rate: Initial Adam learning rate.
    :param freeze_interval: Re-draw the randomly frozen layers every this many
        optimisation steps. A falsy value (``0`` / ``None``) disables freezing.
    """

    def __init__(self, data_loader, epochs, train_steps, val_steps, test_steps, checkpoint_dir, learning_rate=1e-3,
                 freeze_interval=10):
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.freeze_interval = freeze_interval
        self.loss_fn = nn.CrossEntropyLoss()
        self.data_loader = data_loader
        self.data_iter = iter(self.data_loader)
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.train_steps = train_steps
        self.val_steps = val_steps
        self.test_steps = test_steps
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.global_step = 0  # Track the number of training steps
        self.best_val_accuracy = 0.0  # Track the best validation accuracy
        self.checkpoint_dir = checkpoint_dir

        # Learning rate scheduler: multiply the LR by 0.1 every 5 scheduler steps
        # (the scheduler is stepped once per training phase, i.e. once per epoch).
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.1)

    def _next_batch(self, phase):
        """Pull the next batch from the shared iterator and move it to ``self.device``.

        Prints which phase ran out of data and re-raises ``StopIteration`` when the
        iterator is exhausted.
        """
        try:
            images, labels = next(self.data_iter)
        except StopIteration:
            print(f"{phase} dataloader exhausted!")
            raise
        # detach(): batches may arrive with autograd history (e.g. adversarial
        # examples); the model only needs their values.
        return images.detach().to(self.device), labels.to(self.device)

    def training_phase(self):
        """Run ``train_steps`` optimisation steps and return ``(mean loss, accuracy)``.

        Returns ``(0, 0)`` when ``train_steps`` is 0. The LR scheduler is stepped
        once per call in either case.
        """
        self.model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0

        pbar = tqdm(total=self.train_steps, desc="Training", colour="green")
        try:
            for _ in range(self.train_steps):
                images, labels = self._next_batch("Training")

                if self.freeze_interval and self.global_step % self.freeze_interval == 0:
                    self.random_freeze_layers()

                # set_to_none=True guarantees frozen parameters have grad None after
                # backward, so Adam skips them instead of applying stale momentum.
                self.optimizer.zero_grad(set_to_none=True)
                outputs = self.model(images)
                loss = self.loss_fn(outputs, labels)
                loss.backward()
                self.optimizer.step()

                batch_loss = loss.item()
                batch_correct = (outputs.argmax(dim=1) == labels).sum().item()
                train_loss += batch_loss
                train_total += labels.size(0)
                train_correct += batch_correct
                self.global_step += 1

                pbar.set_postfix(loss=batch_loss, accuracy=batch_correct / len(labels))
                pbar.update(1)
        finally:
            # Close the bar even when the data stream ends early.
            pbar.close()

        # Step the scheduler at the end of the epoch
        self.scheduler.step()

        if self.train_steps == 0:
            return 0, 0

        epoch_train_loss = train_loss / self.train_steps
        epoch_train_accuracy = train_correct / train_total
        return epoch_train_loss, epoch_train_accuracy

    def _evaluation_phase(self, steps, phase, colour):
        """Score the model on ``steps`` batches without updating it.

        Returns ``(mean loss per batch, accuracy)``, or ``(0.0, 0.0)`` when no
        example was evaluated.
        """
        self.model.eval()
        total_loss, correct, seen = 0.0, 0, 0

        pbar = tqdm(total=steps, desc=phase, colour=colour)
        try:
            for _ in range(steps):
                # Fetch OUTSIDE no_grad: the loader may run a gradient-based attack
                # while producing the batch.
                images, labels = self._next_batch(phase)

                with torch.no_grad():
                    outputs = self.model(images)
                    loss = self.loss_fn(outputs, labels)

                batch_loss = loss.item()
                batch_correct = (outputs.argmax(dim=1) == labels).sum().item()
                total_loss += batch_loss
                seen += labels.size(0)
                correct += batch_correct

                pbar.set_postfix(loss=batch_loss, accuracy=batch_correct / len(labels))
                pbar.update(1)
        finally:
            pbar.close()

        if seen == 0:
            return 0.0, 0.0
        return total_loss / steps, correct / seen

    def validation_phase(self):
        """Evaluate on the next ``val_steps`` batches; returns ``(loss, accuracy)``."""
        return self._evaluation_phase(self.val_steps, "Validation", "orange")

    def train(self):
        """Run ``epochs`` rounds of training + validation.

        A checkpoint ``{"state_dict": ...}`` is written to ``checkpoint_dir`` every
        time the validation accuracy beats the best value seen so far.
        """
        for epoch in range(1, self.epochs + 1):
            print("-" * 50)
            print(f"EPOCH {epoch}/{self.epochs}")

            epoch_train_loss, epoch_train_accuracy = self.training_phase()
            print(f"Epoch {epoch} Training Loss: {epoch_train_loss}, Accuracy: {epoch_train_accuracy}")

            epoch_val_loss, epoch_val_accuracy = self.validation_phase()
            print(f"Epoch {epoch} Validation Loss: {epoch_val_loss}, Accuracy: {epoch_val_accuracy}")

            if epoch_val_accuracy > self.best_val_accuracy:
                print(f"Previous best accuracy: {self.best_val_accuracy}, Current epoch accuracy: {epoch_val_accuracy}")
                self.best_val_accuracy = epoch_val_accuracy
                checkpoint_file_name = f"ft-stream-crazy-epoch-{epoch}.ckpt"
                checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_file_name)
                # torch.save does not create missing parent directories.
                os.makedirs(self.checkpoint_dir, exist_ok=True)
                torch.save({"state_dict": self.model.state_dict()}, checkpoint_path)
                print(f"Checkpoint saved: {checkpoint_path}")

    def test(self):
        """Evaluate the current model on the next ``test_steps`` batches.

        Prints and returns ``(loss, accuracy)``. The weights used are whatever the
        model holds at call time; no checkpoint is reloaded.
        """
        epoch_test_loss, epoch_test_accuracy = self._evaluation_phase(self.test_steps, "Test", "red")
        print(f"Test Loss: {epoch_test_loss}, Test Accuracy: {epoch_test_accuracy}")
        return epoch_test_loss, epoch_test_accuracy

    def random_freeze_layers(self):
        """Freeze a fresh random 60% of the model's direct children, unfreezing the rest.

        The selection is made over all direct children by position, including
        children that own no parameters.
        """
        children = list(self.model.children())
        num_layers = len(children)
        layers_to_freeze = int(num_layers * 0.6)

        frozen_indices = set(random.sample(range(num_layers), layers_to_freeze))
        for index, child in enumerate(children):
            trainable = index not in frozen_indices
            for param in child.parameters():
                param.requires_grad = trainable


# 4. Training the model

In [None]:
# Training schedule. One stream of batches feeds every phase, so the loader is
# sized to cover all training + validation rounds plus the final test pass.
epochs = 15
train_steps = 1200
val_steps = 25
test_steps = 60
batch_size = 16
freeze_steps = 15  # re-draw the randomly frozen layers every `freeze_steps` optimisation steps

# torch.save does not create missing directories; make sure the checkpoint folder exists.
os.makedirs(model_path, exist_ok=True)

data_loader = RobustDataLoader(
    batch_num=epochs * (train_steps + val_steps) + test_steps,
    batch_size=batch_size,
    perturbation_chance=0.72
)

# `freeze_steps` was previously defined but never passed on, so the class default
# (freeze_interval=10) was silently used instead.
classifier = RobustResnet(data_loader, epochs=epochs, train_steps=train_steps, val_steps=val_steps,
                          test_steps=test_steps, checkpoint_dir=model_path, freeze_interval=freeze_steps)

classifier.train()



In [None]:
# Score the weights currently held by `classifier` (the state after the last epoch,
# not a reloaded checkpoint) on the next `test_steps` batches of the shared stream.
classifier.test()

# 5. Evaluating Results

In [None]:
class Evaluation:
    """Compares the pretrained ResNet-50 with a fine-tuned checkpoint on one loader.

    :param checkpoint_path: File holding ``{"state_dict": ...}`` for the fine-tuned model.
    :param test_loader: Iterable yielding ``(images, labels)`` batches.
    :param evaluate_original: Whether to score the pretrained ResNet-50.
    :param evaluate_finetuned: Whether to load and score the fine-tuned checkpoint.
    """

    def __init__(self, checkpoint_path, test_loader, evaluate_original=True, evaluate_finetuned=True):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.evaluate_original = evaluate_original
        self.evaluate_finetuned = evaluate_finetuned
        self.checkpoint_path = checkpoint_path
        self.test_loader = test_loader

        if self.evaluate_original:
            self.original_model = resnet50(weights=ResNet50_Weights.DEFAULT)
            self.original_model.to(self.device)
            self.original_model.eval()

        if self.evaluate_finetuned:
            self.fine_tuned_model = resnet50()
            self.load_model_from_checkpoint()
            self.fine_tuned_model.eval()

        print(f"evaluate original : {evaluate_original}")
        print(f"evaluate finetuned : {evaluate_finetuned}")

    def load_model_from_checkpoint(self):
        """Load the fine-tuned weights from ``checkpoint_path`` into ``fine_tuned_model``."""
        # map_location lets a checkpoint saved on a GPU be opened on a CPU-only runtime.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

        self.fine_tuned_model.load_state_dict(checkpoint['state_dict'])
        self.fine_tuned_model.to(self.device)
        self.fine_tuned_model.eval()

    def evaluate_model(self):
        """Run the enabled models over ``test_loader`` and report top-1 accuracy.

        :return: Dict with ``"original_accuracy"`` and/or ``"fine_tuned_accuracy"``
            (fractions in [0, 1]); an accuracy is ``0.0`` when the loader yields
            no examples.
        """
        total = 0
        orig_correct = 0
        ft_correct = 0

        # The loop itself must stay outside no_grad: the loader may run a
        # gradient-based attack while producing each batch.
        for batch in tqdm(self.test_loader):
            images, labels = batch
            total += len(labels)

            images, labels = images.to(self.device), labels.to(self.device)

            # Inference only: no autograd graph is needed for the forward passes.
            # argmax over logits equals argmax over softmax, so softmax is skipped.
            with torch.no_grad():
                if self.evaluate_original:
                    original_predictions = self.original_model(images).argmax(dim=1)
                    orig_correct += torch.sum(original_predictions == labels).item()

                if self.evaluate_finetuned:
                    fine_tuned_predictions = self.fine_tuned_model(images).argmax(dim=1)
                    ft_correct += torch.sum(fine_tuned_predictions == labels).item()

        # Calculate accuracies
        result = {}
        if self.evaluate_original:
            original_accuracy = orig_correct / total if total else 0.0
            print(f'Evaluation Original Model Accuracy: {original_accuracy * 100} %')
            result["original_accuracy"] = original_accuracy

        if self.evaluate_finetuned:
            fine_tuned_accuracy = ft_correct / total if total else 0.0
            print(f'Evaluation Fine-Tuned Model Accuracy: {fine_tuned_accuracy * 100} %')
            result["fine_tuned_accuracy"] = fine_tuned_accuracy

        return result

In [None]:
# Three 60-batch evaluation streams that differ only in how often a batch is
# PGD-perturbed: never (clean), always (perturbed), or 55% of the time (mix).
clean_test_loader, perturbed_test_loader, mix_test_loader = (
    RobustDataLoader(batch_num=60, batch_size=batch_size, perturbation_chance=chance)
    for chance in (0.0, 1.0, 0.55)
)

# TODO: set `epoch` to the epoch whose checkpoint should be evaluated.
epoch = 1
checkpoint_to_evaluate = os.path.join(
    model_path,
    f"ft-stream-crazy-epoch-{epoch}.ckpt",
)

Testing on only clean data

In [None]:
# Score both the pretrained and the fine-tuned model on the loader that never
# perturbs its batches (perturbation_chance=0.0).
evaluation = Evaluation(checkpoint_to_evaluate, clean_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

Testing on only perturbed data

In [None]:
# Score both models on the loader that PGD-perturbs every batch
# (perturbation_chance=1.0).
evaluation = Evaluation(checkpoint_to_evaluate, perturbed_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

Testing on mix data

In [None]:
# Score both models on the loader that PGD-perturbs each batch with
# probability 0.55.
evaluation = Evaluation(checkpoint_to_evaluate, mix_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()