# About:

This notebook basically finetunes a "Robust Resnet" by using a fine tuning dataset of PGD perturbed images and its augmentations

# Installing Libraries


In [None]:
!pip cache purge

Files removed: 2913


In [None]:
# !pip install pytorch-lightning
# !pip install lightning
# !pip install lightning[extra]
!pip install datasets
!pip install -U "huggingface_hub[cli]"
# !pip install tensorboard

Collecting datasets
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.0-py3-none-any.whl (474 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.3/474.3 kB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K 

# 0. Importing libraries

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from datasets import load_dataset
from torch.optim import lr_scheduler

import uuid
import random
from torchvision.models import resnet50, ResNet50_Weights
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
import copy
import shutil
import zipfile
from torch.optim import lr_scheduler

# 1.Setup

In [None]:
# Manual seed for reproducibility
# SEED = 1234
# torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {DEVICE}")

base_path = "/content/drive/MyDrive/trustworthyml"
zips_path = os.path.join(base_path, "zips")
model_path = os.path.join(base_path, "model", "ft_stream")
logs_path = os.path.join(base_path, "logs")

device: cuda


# 2. Creating data loaders

In [None]:
class ResnetPGDAttacker:
    def __init__(self):
        '''
        The PGD attack on Resnet model.
        :param model: The resnet model on which we perform the attack
        :param dataloader: The dataloader loading the input data on which we perform the attack
        '''
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.loss_fn = nn.CrossEntropyLoss()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.resnet.to(self.device)

        # Nullify gradient for model params
        # for p in self.resnet.parameters():
        #     p.requires_grad = False

    def pgd_attack(self, images, labels, eps, alpha, steps):
        images = images.requires_grad_(True).to(self.device)
        labels = labels.to(self.device)
        adv_images = images.clone().requires_grad_(True).to(self.device)

        # Starting at a uniformly random point within the eps ball
        random_noise = torch.zeros_like(adv_images).uniform_(-eps, eps)
        adv_images = adv_images + random_noise

        for _ in range(steps):
            outputs = self.resnet(adv_images)
            pgd_loss = nn.CrossEntropyLoss()(outputs, labels)
            grad = torch.autograd.grad(pgd_loss, adv_images, retain_graph=True)[0]

            with torch.no_grad():
                adv_images = adv_images + alpha * grad.sign()
                adv_images = torch.clamp(adv_images, images - eps, images + eps)

            adv_images = adv_images.detach().requires_grad_(True).to(self.device)

        return adv_images

In [None]:
class RobustDataLoader:
    def __init__(self, batch_num, batch_size, allow_perturbations=True, perturbation_chance=1.0):
        self.batch_num = batch_num
        self.batch_size = batch_size
        self.allow_perturbations = allow_perturbations
        self.pgd_attacker = ResnetPGDAttacker()
        self.perturbation_chance = perturbation_chance

        # Load the dataset
        self.ds = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True, trust_remote_code=True)
        self.ds = self.ds.shuffle()
        self.ds = self.ds.filter(lambda example: example['image'].mode == 'RGB')
        self.ds = self.ds.take(self.batch_num * self.batch_size)  # Take a fixed number of examples
        self.ds = self.ds.map(self.preprocess_img)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Create a DataLoader
        self.data_loader = DataLoader(self.ds, batch_size=self.batch_size)
        self.data_iter = iter(self.data_loader)

        print(
            f"Robust DataLoader created with batch_num = {self.batch_num}, batch_size = {self.batch_size}, perturbation_chance = {self.perturbation_chance}"
        )

    def preprocess_img(self, example):
        weights = ResNet50_Weights.DEFAULT
        transform = weights.transforms()
        example['image'] = transform(example['image'])
        return example

    def __iter__(self):
        return self

    def __next__(self):
        # Get the next batch from the DataLoader iterator
        try:
            batch = next(self.data_iter)
        except StopIteration:
            raise StopIteration  # Raise StopIteration when the DataLoader is exhausted

        # Get the next batch from the DataLoader
        images = batch['image']
        labels = batch['label']

        # Randomly apply PGD attack if allowed
        if self.allow_perturbations and random.random() < self.perturbation_chance:
            random_eps = random.uniform(0.01, 0.3)
            random_alpha = random.uniform(0.01, 0.07)
            random_steps = random.randint(15, 20)
            perturbed_images = self.pgd_attacker.pgd_attack(images, labels, random_eps, random_alpha, random_steps)
            return perturbed_images.to(self.device), labels.to(self.device)
        else:
            return images.to(self.device), labels.to(self.device)

    def __len__(self):
        return self.batch_num


# 3. Implementing Robust Resnet

In [None]:
class RobustResnet:
    def __init__(self, data_loader, epochs, train_steps, val_steps, test_steps, checkpoint_dir, learning_rate=1e-3,
                 freeze_interval=10):
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.freeze_interval = freeze_interval
        self.loss_fn = nn.CrossEntropyLoss()
        self.data_loader = data_loader
        self.data_iter = iter(self.data_loader)
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.train_steps = train_steps
        self.val_steps = val_steps
        self.test_steps = test_steps
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.global_step = 0  # Track the number of training steps
        self.best_val_accuracy = 0.0  # Track the best validation accuracy
        self.checkpoint_dir = checkpoint_dir
        self.info = dict()

        # Learning rate scheduler (example: StepLR)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.1)

    def load_model_from_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)

        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)
        print(f"Loaded model from path {checkpoint_path}")

    def training_phase(self):
        self.model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0

        # Use TQDM for training
        pbar = tqdm(total=self.train_steps, desc="Training", colour="green")
        for _ in range(self.train_steps):
            try:
                batch = next(self.data_iter)
            except StopIteration:
                print("Training dataloader exhausted!")
                raise StopIteration

            if self.global_step % self.freeze_interval == 0:
                self.random_freeze_layers()

            images, labels = batch
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.loss_fn(outputs, labels)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            self.global_step += 1

            # Update TQDM
            pbar.set_postfix(loss=loss.item(), accuracy=(predicted == labels).sum().item() / len(labels))
            pbar.update(1)

        pbar.close()

        # Step the scheduler at the end of the epoch
        self.scheduler.step()

        if self.train_steps == 0:
            return 0, 0

        epoch_train_loss = train_loss / self.train_steps
        epoch_train_accuracy = train_correct / train_total
        return epoch_train_loss, epoch_train_accuracy

    def validation_phase(self):
        self.model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        pbar = tqdm(total=self.val_steps, desc="Validation", colour="orange")
        for _ in range(self.val_steps):
            try:
                batch = next(self.data_iter)
            except StopIteration:
                print("Training dataloader exhausted!")
                raise StopIteration

            with torch.no_grad():
                images, labels = batch
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.loss_fn(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

                # Update TQDM
                pbar.set_postfix(loss=loss.item(), accuracy=(predicted == labels).sum().item() / len(labels))
                pbar.update(1)

        pbar.close()

        epoch_val_loss = val_loss / self.val_steps
        epoch_val_accuracy = val_correct / val_total
        return epoch_val_loss, epoch_val_accuracy

    def train(self):
        for epoch in range(1, self.epochs + 1):
            print("-" * 50)
            print(f"EPOCH {epoch}/{self.epochs}")

            epoch_train_loss, epoch_train_accuracy = self.training_phase()
            print(f"Epoch {epoch} Training Loss: {epoch_train_loss}, Accuracy: {epoch_train_accuracy}")

            epoch_val_loss, epoch_val_accuracy = self.validation_phase()
            print(f"Epoch {epoch} Validation Loss: {epoch_val_loss}, Accuracy: {epoch_val_accuracy}")

            self.info[f"epoch-{epoch}"] = {
                "train_loss": epoch_train_loss,
                "train_accuracy": epoch_train_accuracy,
                "val_loss": epoch_val_loss,
                "val_accuracy": epoch_val_accuracy,
                "is_checkpointed": False
            }

            if epoch_val_accuracy >= self.best_val_accuracy:
                print(f"SAVING NEW WEIGHTS!!")
                print()
                print(f"Previous best accuracy: {self.best_val_accuracy}, Current epoch accuracy: {epoch_val_accuracy}")
                self.best_val_accuracy = epoch_val_accuracy
                checkpoint_file_name = f"ft-stream-crazy-epoch-{epoch}-{str(uuid.uuid4())}.ckpt"
                checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_file_name)
                torch.save({"state_dict": self.model.state_dict()}, checkpoint_path)
                print(f"Checkpoint saved: {checkpoint_path}")
                self.info[f"epoch-{epoch}"]["is_checkpointed"] = True

        # saving the final state as is
        print("Going to save final checkpoint")
        checkpoint_file_name = f"ft-stream-crazy-epoch-final-{str(uuid.uuid4())}.ckpt"
        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_file_name)
        torch.save({"state_dict": self.model.state_dict()}, checkpoint_path)
        print(f"Final Checkpoint saved: {checkpoint_path}")
        return self.info

    def test(self):
        self.model.eval()
        test_loss, test_correct, test_total = 0.0, 0, 0
        pbar = tqdm(total=self.test_steps, desc="Test", colour="red")
        for _ in range(self.test_steps):
            try:
                batch = next(self.data_iter)
            except StopIteration:
                print("Training dataloader exhausted!")
                raise StopIteration

            with torch.no_grad():
                images, labels = batch
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.loss_fn(outputs, labels)

                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()

                # Update TQDM
                pbar.set_postfix(loss=loss.item(), accuracy=(predicted == labels).sum().item() / len(labels))
                pbar.update(1)

        pbar.close()

        epoch_test_loss = test_loss / self.test_steps
        epoch_test_accuracy = test_correct / test_total
        print(f"Test Loss: {epoch_test_loss}, Test Accuracy: {epoch_test_accuracy}")
        return epoch_test_loss, epoch_test_accuracy

    def random_freeze_layers(self):
        num_layers = len(list(self.model.children()))
        layers_to_freeze = int(num_layers * 0.6)

        # Unfreeze all layers first
        for layer in self.model.children():
            for param in layer.parameters():
                param.requires_grad = True

        # Randomly select layers to freeze
        layers_to_freeze_indices = random.sample(range(num_layers), layers_to_freeze)
        for i, child in enumerate(self.model.children()):
            if i in layers_to_freeze_indices:
                for param in child.parameters():
                    param.requires_grad = False  # Freeze parameters
            else:
                for param in child.parameters():
                    param.requires_grad = True

        # truly crazy approach , where we determine when to freeze next by again doing a random number pick
        self.freeze_interval = random.randint(20,50)


# 3. Training the model

In [None]:
print("DONT RUN!")

In [None]:
# buffer

In [None]:
epochs = 15
train_steps = 1400
val_steps = 30
test_steps = 60
batch_size = 16
freeze_steps = 15

data_loader = RobustDataLoader(
    batch_num=epochs * (train_steps + val_steps) + test_steps,
    batch_size=batch_size,
    perturbation_chance=0.72
)

classifier = RobustResnet(data_loader, epochs=epochs, train_steps=train_steps, val_steps=val_steps,
                          test_steps=test_steps, checkpoint_dir=model_path)

classifier.train()



Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 192MB/s]


imagenet-1k.py:   0%|          | 0.00/4.58k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/85.4k [00:00<?, ?B/s]

classes.py:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

Robust DataLoader created with batch_num = 21510, batch_size = 16, perturbation_chance = 0.72
--------------------------------------------------
EPOCH 1/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]



Epoch 1 Training Loss: 4.498416296924863, Accuracy: 0.18004464285714286


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1 Validation Loss: 4.223898180325826, Accuracy: 0.19166666666666668
Previous best accuracy: 0.0, Current epoch accuracy: 0.19166666666666668
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-1.ckpt
--------------------------------------------------
EPOCH 2/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 2 Training Loss: 4.21191226380212, Accuracy: 0.19058035714285715


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 2 Validation Loss: 4.169860315322876, Accuracy: 0.2125
Previous best accuracy: 0.19166666666666668, Current epoch accuracy: 0.2125
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-2.ckpt
--------------------------------------------------
EPOCH 3/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 3 Training Loss: 4.04775197335652, Accuracy: 0.20616071428571428


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 3 Validation Loss: 3.7000436941782633, Accuracy: 0.2604166666666667
Previous best accuracy: 0.2125, Current epoch accuracy: 0.2604166666666667
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-3.ckpt
--------------------------------------------------
EPOCH 4/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 4 Training Loss: 3.8319600100176676, Accuracy: 0.2303125


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 4 Validation Loss: 3.9693466424942017, Accuracy: 0.24375
--------------------------------------------------
EPOCH 5/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 5 Training Loss: 3.7576196094921657, Accuracy: 0.24736607142857142


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 5 Validation Loss: 3.7863086064656577, Accuracy: 0.2604166666666667
--------------------------------------------------
EPOCH 6/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 6 Training Loss: 3.3460745852334157, Accuracy: 0.3069642857142857


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 6 Validation Loss: 3.099467424551646, Accuracy: 0.35208333333333336
Previous best accuracy: 0.2604166666666667, Current epoch accuracy: 0.35208333333333336
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-6.ckpt
--------------------------------------------------
EPOCH 7/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]



Epoch 7 Training Loss: 3.0931578053746906, Accuracy: 0.3467857142857143


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 7 Validation Loss: 2.848844623565674, Accuracy: 0.4
Previous best accuracy: 0.35208333333333336, Current epoch accuracy: 0.4
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-7.ckpt
--------------------------------------------------
EPOCH 8/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 8 Training Loss: 2.9736335290329796, Accuracy: 0.3721875


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 8 Validation Loss: 2.765092138449351, Accuracy: 0.39166666666666666
--------------------------------------------------
EPOCH 9/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 9 Training Loss: 2.865706052354404, Accuracy: 0.3833035714285714


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 9 Validation Loss: 2.674315098921458, Accuracy: 0.425
Previous best accuracy: 0.4, Current epoch accuracy: 0.425
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-9.ckpt
--------------------------------------------------
EPOCH 10/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 10 Training Loss: 2.8079429468086787, Accuracy: 0.39102678571428573


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 10 Validation Loss: 2.7770190517107647, Accuracy: 0.3729166666666667
--------------------------------------------------
EPOCH 11/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 11 Training Loss: 2.733821140740599, Accuracy: 0.4059375


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 11 Validation Loss: 2.6677047411600747, Accuracy: 0.4083333333333333
--------------------------------------------------
EPOCH 12/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 12 Training Loss: 2.716946328708104, Accuracy: 0.40950892857142857


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 12 Validation Loss: 2.6638808727264403, Accuracy: 0.425
--------------------------------------------------
EPOCH 13/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 13 Training Loss: 2.729991740158626, Accuracy: 0.4077232142857143


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 13 Validation Loss: 2.735597725709279, Accuracy: 0.4166666666666667
--------------------------------------------------
EPOCH 14/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 4dbd03eb-bfee-4187-afa8-0540ea8d2623)')' thrown while requesting GET https://huggingface.co/datasets/ILSVRC/imagenet-1k/resolve/main/data/train_images_0.tar.gz
Retrying in 1s [Retry 1/5].


Epoch 14 Training Loss: 2.686984382527215, Accuracy: 0.4148660714285714


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 14 Validation Loss: 2.92523060242335, Accuracy: 0.375
--------------------------------------------------
EPOCH 15/15


Training:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch 15 Training Loss: 2.6864608507922716, Accuracy: 0.41535714285714287


Validation:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 15 Validation Loss: 2.480913786093394, Accuracy: 0.45416666666666666
Previous best accuracy: 0.425, Current epoch accuracy: 0.45416666666666666
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-15.ckpt


In [None]:
classifier.test()

Test:   0%|          | 0/60 [00:00<?, ?it/s]

Test Loss: 2.7773940642674764, Test Accuracy: 0.4114583333333333


(2.7773940642674764, 0.4114583333333333)

# 4. More training (from previous run)



In [None]:
def get_newest_file(directory_path):
    # Initialize variables to hold the name of the most recent file and its modification time
    most_recent_file = None
    most_recent_time = 0

    # Iterate over the files in the directory using os.scandir()
    for entry in os.scandir(directory_path):
        if entry.is_file():  # Check if the entry is a file
            # Get the modification time of the file
            mod_time = entry.stat().st_mtime
            # Update if this file is more recent than the current most recent
            if mod_time >= most_recent_time:
                most_recent_file = entry.name  # Store the full path
                most_recent_time = mod_time

    return os.path.join(directory_path, most_recent_file)

In [None]:
checkpoint_to_continue_training_from_path = get_newest_file(model_path)
print(checkpoint_to_continue_training_from_path)

/content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-final-6fc9177b-fe57-4950-98f8-45ef577d1bda.ckpt


In [20]:
epochs = 35
train_steps = 1500
val_steps = 35
test_steps = 60
batch_size = 16


data_loader = RobustDataLoader(
    batch_num=epochs * (train_steps + val_steps) + test_steps,
    batch_size=batch_size,
    perturbation_chance=0.6 # we would like clean data as well
)

re_classifier = RobustResnet(data_loader, epochs=epochs, train_steps=train_steps, val_steps=val_steps,
                          test_steps=test_steps, checkpoint_dir=model_path)


# re_classifier.load_model_from_checkpoint(checkpoint_to_continue_training_from_path)

re_classifier.train()



Robust DataLoader created with batch_num = 53785, batch_size = 16, perturbation_chance = 0.6
--------------------------------------------------
EPOCH 1/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 1 Training Loss: 3.9863444365660348, Accuracy: 0.2425


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 1 Validation Loss: 3.9069802011762347, Accuracy: 0.23214285714285715
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.0, Current epoch accuracy: 0.23214285714285715
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-1-c3c4b227-ffc6-462e-a9cc-e9d349df602d.ckpt
--------------------------------------------------
EPOCH 2/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]



Epoch 2 Training Loss: 3.8473707874615988, Accuracy: 0.242125


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 2 Validation Loss: 3.8532280649457658, Accuracy: 0.21428571428571427
--------------------------------------------------
EPOCH 3/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 3 Training Loss: 3.761446208000183, Accuracy: 0.24316666666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 3 Validation Loss: 3.6212562424795967, Accuracy: 0.26964285714285713
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.23214285714285715, Current epoch accuracy: 0.26964285714285713
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-3-a1e9ab66-fbfa-4fd4-9dfd-901aa84c62b1.ckpt
--------------------------------------------------
EPOCH 4/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]



Epoch 4 Training Loss: 3.6292912650903064, Accuracy: 0.2612083333333333


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 4 Validation Loss: 3.375184828894479, Accuracy: 0.3142857142857143
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.26964285714285713, Current epoch accuracy: 0.3142857142857143
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-4-af7a9594-c297-4e07-8d8d-8b9e825c066e.ckpt
--------------------------------------------------
EPOCH 5/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 5 Training Loss: 3.5420033915837608, Accuracy: 0.27316666666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 5 Validation Loss: 3.831254277910505, Accuracy: 0.2732142857142857
--------------------------------------------------
EPOCH 6/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 6 Training Loss: 3.114917760372162, Accuracy: 0.34120833333333334


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 6 Validation Loss: 2.8230250460760935, Accuracy: 0.4017857142857143
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.3142857142857143, Current epoch accuracy: 0.4017857142857143
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-6-73dd8b0d-3ef1-41a3-b07b-6b5653b09bd5.ckpt
--------------------------------------------------
EPOCH 7/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 7 Training Loss: 2.869030921379725, Accuracy: 0.38366666666666666


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 7 Validation Loss: 2.9386260202952794, Accuracy: 0.3732142857142857
--------------------------------------------------
EPOCH 8/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 8 Training Loss: 2.7674787809848787, Accuracy: 0.40129166666666666


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 8 Validation Loss: 2.750288282121931, Accuracy: 0.41964285714285715
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.4017857142857143, Current epoch accuracy: 0.41964285714285715
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-8-1e0c5da5-af6f-414a-9b3d-a937e3037d17.ckpt
--------------------------------------------------
EPOCH 9/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 9 Training Loss: 2.638507659117381, Accuracy: 0.4270833333333333


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 9 Validation Loss: 2.6977296931403023, Accuracy: 0.42142857142857143
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.41964285714285715, Current epoch accuracy: 0.42142857142857143
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-9-a7426719-278a-4f51-9822-bad5a0ff5035.ckpt
--------------------------------------------------
EPOCH 10/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 10 Training Loss: 2.5577388672828674, Accuracy: 0.43233333333333335


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 10 Validation Loss: 2.5273122889654975, Accuracy: 0.4660714285714286
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.42142857142857143, Current epoch accuracy: 0.4660714285714286
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-10-5c557ca7-5866-4c7f-996c-bc16713f7991.ckpt
--------------------------------------------------
EPOCH 11/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 11 Training Loss: 2.5329241843223573, Accuracy: 0.4414166666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 11 Validation Loss: 2.438398698398045, Accuracy: 0.48035714285714287
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.4660714285714286, Current epoch accuracy: 0.48035714285714287
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-11-f179f158-12bb-458f-9c38-d027d50f4d77.ckpt
--------------------------------------------------
EPOCH 12/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 12 Training Loss: 2.520109809954961, Accuracy: 0.44283333333333336


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 12 Validation Loss: 2.768987972395761, Accuracy: 0.3982142857142857
--------------------------------------------------
EPOCH 13/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 13 Training Loss: 2.469308384656906, Accuracy: 0.455625


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 13 Validation Loss: 2.529043480328151, Accuracy: 0.4589285714285714
--------------------------------------------------
EPOCH 14/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 14 Training Loss: 2.472635795990626, Accuracy: 0.4500416666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 14 Validation Loss: 2.4241785628455026, Accuracy: 0.4607142857142857
--------------------------------------------------
EPOCH 15/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]



Epoch 15 Training Loss: 2.491005914926529, Accuracy: 0.4487083333333333


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 15 Validation Loss: 2.3927333831787108, Accuracy: 0.4714285714285714
--------------------------------------------------
EPOCH 16/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 16 Training Loss: 2.4672432165940603, Accuracy: 0.4510416666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 16 Validation Loss: 2.6792311089379446, Accuracy: 0.44107142857142856
--------------------------------------------------
EPOCH 17/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 17 Training Loss: 2.46222508875529, Accuracy: 0.45675


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 17 Validation Loss: 2.325783494540623, Accuracy: 0.48928571428571427
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.48035714285714287, Current epoch accuracy: 0.48928571428571427
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-17-023683f3-4073-467f-a6da-91a4e3c22496.ckpt
--------------------------------------------------
EPOCH 18/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 18 Training Loss: 2.433821765780449, Accuracy: 0.4575


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 18 Validation Loss: 2.3529835905347554, Accuracy: 0.4732142857142857
--------------------------------------------------
EPOCH 19/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 19 Training Loss: 2.4724899258613586, Accuracy: 0.44925


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 19 Validation Loss: 2.4586020248276848, Accuracy: 0.4607142857142857
--------------------------------------------------
EPOCH 20/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 20 Training Loss: 2.473184358437856, Accuracy: 0.4514166666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 20 Validation Loss: 2.276105533327375, Accuracy: 0.48928571428571427
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.48928571428571427, Current epoch accuracy: 0.48928571428571427
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-20-12185abf-b9df-4720-8f60-56d72d652fa5.ckpt
--------------------------------------------------
EPOCH 21/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 21 Training Loss: 2.48071765601635, Accuracy: 0.450875


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 21 Validation Loss: 2.4502825771059307, Accuracy: 0.4785714285714286
--------------------------------------------------
EPOCH 22/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 22 Training Loss: 2.4596696434020995, Accuracy: 0.4530416666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 22 Validation Loss: 2.372183265004839, Accuracy: 0.4767857142857143
--------------------------------------------------
EPOCH 23/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 23 Training Loss: 2.5032312384049096, Accuracy: 0.44825


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 23 Validation Loss: 2.4632119280951366, Accuracy: 0.48392857142857143
--------------------------------------------------
EPOCH 24/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 24 Training Loss: 2.477288980325063, Accuracy: 0.44908333333333333


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 24 Validation Loss: 2.3973424843379427, Accuracy: 0.45714285714285713
--------------------------------------------------
EPOCH 25/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 25 Training Loss: 2.4432792309125264, Accuracy: 0.4590416666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 25 Validation Loss: 2.4299794503620693, Accuracy: 0.4589285714285714
--------------------------------------------------
EPOCH 26/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 26 Training Loss: 2.4802408179044724, Accuracy: 0.451875


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 26 Validation Loss: 2.6475999525615146, Accuracy: 0.4357142857142857
--------------------------------------------------
EPOCH 27/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 27 Training Loss: 2.4538530507087706, Accuracy: 0.45608333333333334


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 27 Validation Loss: 2.5312048264912197, Accuracy: 0.45357142857142857
--------------------------------------------------
EPOCH 28/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 28 Training Loss: 2.480343128760656, Accuracy: 0.45416666666666666


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 28 Validation Loss: 2.270134162902832, Accuracy: 0.4875
--------------------------------------------------
EPOCH 29/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 29 Training Loss: 2.4373019577264787, Accuracy: 0.45729166666666665


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 29 Validation Loss: 2.332135592188154, Accuracy: 0.4642857142857143
--------------------------------------------------
EPOCH 30/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 30 Training Loss: 2.4627515080769857, Accuracy: 0.451875


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 30 Validation Loss: 2.532383877890451, Accuracy: 0.4982142857142857
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.48928571428571427, Current epoch accuracy: 0.4982142857142857
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-30-91119256-ff3a-40c1-a0af-2044956b1323.ckpt
--------------------------------------------------
EPOCH 31/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 31 Training Loss: 2.4688870392243065, Accuracy: 0.450125


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 31 Validation Loss: 2.2024461235318866, Accuracy: 0.5196428571428572
SAVING NEW WEIGHTS!!

Previous best accuracy: 0.4982142857142857, Current epoch accuracy: 0.5196428571428572
Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-31-1102d439-4c93-4a96-b1c6-a2ea21534227.ckpt
--------------------------------------------------
EPOCH 32/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]



Epoch 32 Training Loss: 2.450085377375285, Accuracy: 0.458


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 32 Validation Loss: 2.252433810915266, Accuracy: 0.5125
--------------------------------------------------
EPOCH 33/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 33 Training Loss: 2.496230677008629, Accuracy: 0.44779166666666664


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 33 Validation Loss: 2.3494615282331193, Accuracy: 0.46964285714285714
--------------------------------------------------
EPOCH 34/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 34 Training Loss: 2.504154915571213, Accuracy: 0.4495416666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 34 Validation Loss: 2.358935151781355, Accuracy: 0.48035714285714287
--------------------------------------------------
EPOCH 35/35


Training:   0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 35 Training Loss: 2.4869305351177853, Accuracy: 0.4469166666666667


Validation:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 35 Validation Loss: 2.4348430923053197, Accuracy: 0.4660714285714286
Going to save final checkpoint
Final Checkpoint saved: /content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-final-fc4d6b21-4fcc-4220-90c4-bf93a71a4068.ckpt


{'epoch-1': {'train_loss': 3.9863444365660348,
  'train_accuracy': 0.2425,
  'val_loss': 3.9069802011762347,
  'val_accuracy': 0.23214285714285715,
  'is_checkpointed': True},
 'epoch-2': {'train_loss': 3.8473707874615988,
  'train_accuracy': 0.242125,
  'val_loss': 3.8532280649457658,
  'val_accuracy': 0.21428571428571427,
  'is_checkpointed': False},
 'epoch-3': {'train_loss': 3.761446208000183,
  'train_accuracy': 0.24316666666666667,
  'val_loss': 3.6212562424795967,
  'val_accuracy': 0.26964285714285713,
  'is_checkpointed': True},
 'epoch-4': {'train_loss': 3.6292912650903064,
  'train_accuracy': 0.2612083333333333,
  'val_loss': 3.375184828894479,
  'val_accuracy': 0.3142857142857143,
  'is_checkpointed': True},
 'epoch-5': {'train_loss': 3.5420033915837608,
  'train_accuracy': 0.27316666666666667,
  'val_loss': 3.831254277910505,
  'val_accuracy': 0.2732142857142857,
  'is_checkpointed': False},
 'epoch-6': {'train_loss': 3.114917760372162,
  'train_accuracy': 0.341208333333333

In [21]:
re_classifier.test()

Test:   0%|          | 0/60 [00:00<?, ?it/s]

Test Loss: 2.2602140804131827, Test Accuracy: 0.49270833333333336


(2.2602140804131827, 0.49270833333333336)

# 5. Evaluating Results

In [None]:
class Evaluation:
    def __init__(self, checkpoint_path, test_loader, evaluate_original=True, evaluate_finetuned=True):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.evaluate_original = evaluate_original
        self.evaluate_finetuned = evaluate_finetuned
        self.checkpoint_path = checkpoint_path
        self.test_loader = test_loader

        if self.evaluate_original:
            self.original_model = resnet50(weights=ResNet50_Weights.DEFAULT)
            self.original_model.to(self.device)
            self.original_model.eval()

        if self.evaluate_finetuned:
            self.fine_tuned_model = resnet50()
            self.load_model_from_checkpoint()
            self.fine_tuned_model.eval()

        print(f"evaluate original : {evaluate_original}")
        print(f"evaluate finetuned : {evaluate_finetuned}")

    def load_model_from_checkpoint(self):
        checkpoint = torch.load(self.checkpoint_path)

        self.fine_tuned_model.load_state_dict(checkpoint['state_dict'])
        self.fine_tuned_model.to(self.device)
        self.fine_tuned_model.eval()
        print(f"Loaded model from {self.checkpoint_path}")

    def evaluate_model(self):
        # Initialize accuracy metrics
        total = 0
        orig_correct = 0
        ft_correct = 0

        # Evaluate the original model
        for batch in tqdm(self.test_loader):
            images, labels = batch
            total += len(labels)

            images, labels = images.to(self.device), labels.to(self.device)

            # Original model predictions
            if self.evaluate_original:
                original_logits = self.original_model(images).softmax(1)
                original_predictions = original_logits.argmax(dim=1)
                orig_correct += torch.sum(original_predictions == labels).item()

            if self.evaluate_finetuned:
                ft_images = images.clone().detach().to(self.device)
                ft_labels = labels.clone().detach().to(self.device)

                # Fine-tuned model predictions
                fine_tuned_logits = self.fine_tuned_model(ft_images).softmax(1)
                fine_tuned_predictions = fine_tuned_logits.argmax(dim=1)
                ft_correct += torch.sum(fine_tuned_predictions == ft_labels).item()

        # Calculate accuracies
        result = {}
        if self.evaluate_original:
            original_accuracy = orig_correct / total
            print(f'Evaluation Original Model Accuracy: {original_accuracy * 100} %')
            result["original_accuracy"] = original_accuracy

        if self.evaluate_finetuned:
            fine_tuned_accuracy = ft_correct / total
            print(f'Evaluation Fine-Tuned Model Accuracy: {fine_tuned_accuracy * 100} %')
            result["fine_tuned_accuracy"] = fine_tuned_accuracy

        return result

In [23]:
clean_test_loader = RobustDataLoader(
    batch_num=65,
    batch_size=batch_size,
    perturbation_chance=0.0
)

perturbed_test_loader = RobustDataLoader(
    batch_num=65,
    batch_size=batch_size,
    perturbation_chance=1.0
)

mix_test_loader = RobustDataLoader(
    batch_num=65,
    batch_size=batch_size,
    perturbation_chance=0.55
)



Robust DataLoader created with batch_num = 65, batch_size = 16, perturbation_chance = 0.0
Robust DataLoader created with batch_num = 65, batch_size = 16, perturbation_chance = 1.0
Robust DataLoader created with batch_num = 65, batch_size = 16, perturbation_chance = 0.55


In [24]:
checkpoint_to_evaluate = get_newest_file(model_path)
print(checkpoint_to_evaluate)

/content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-final-fc4d6b21-4fcc-4220-90c4-bf93a71a4068.ckpt


Testing on only clean data

In [25]:
evaluation = Evaluation(checkpoint_to_evaluate, clean_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

  checkpoint = torch.load(self.checkpoint_path)


evaluate original : True
evaluate finetuned : True


  0%|          | 0/65 [00:00<?, ?it/s]

Evaluation Original Model Accuracy: 92.01923076923076 %
Evaluation Fine-Tuned Model Accuracy: 52.307692307692314 %


{'original_accuracy': 0.9201923076923076,
 'fine_tuned_accuracy': 0.5230769230769231}

Testing on only perturbed data

In [29]:
perturbed_test_loader = RobustDataLoader(
    batch_num=65,
    batch_size=batch_size,
    perturbation_chance=1.0
)


evaluation = Evaluation(checkpoint_to_evaluate, perturbed_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

Robust DataLoader created with batch_num = 65, batch_size = 16, perturbation_chance = 1.0


  checkpoint = torch.load(self.checkpoint_path)


evaluate original : True
evaluate finetuned : True


  0%|          | 0/65 [00:00<?, ?it/s]

Evaluation Original Model Accuracy: 8.75 %
Evaluation Fine-Tuned Model Accuracy: 44.13461538461539 %


{'original_accuracy': 0.0875, 'fine_tuned_accuracy': 0.44134615384615383}

Testing on mix data

In [27]:
evaluation = Evaluation(checkpoint_to_evaluate, mix_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

  checkpoint = torch.load(self.checkpoint_path)


evaluate original : True
evaluate finetuned : True


  0%|          | 0/65 [00:00<?, ?it/s]

Evaluation Original Model Accuracy: 52.78846153846154 %
Evaluation Fine-Tuned Model Accuracy: 49.80769230769231 %


{'original_accuracy': 0.5278846153846154,
 'fine_tuned_accuracy': 0.4980769230769231}

Testing another checkpoint which had highest validation accuracy

In [31]:
checkpoint_to_evaluate = "/content/drive/MyDrive/trustworthyml/model/ft_stream/ft-stream-crazy-epoch-31-1102d439-4c93-4a96-b1c6-a2ea21534227.ckpt"

In [32]:
clean_test_loader = RobustDataLoader(
    batch_num=65,
    batch_size=batch_size,
    perturbation_chance=0.0
)

perturbed_test_loader = RobustDataLoader(
    batch_num=65,
    batch_size=batch_size,
    perturbation_chance=1.0
)

mix_test_loader = RobustDataLoader(
    batch_num=65,
    batch_size=batch_size,
    perturbation_chance=0.55
)



Robust DataLoader created with batch_num = 65, batch_size = 16, perturbation_chance = 0.0
Robust DataLoader created with batch_num = 65, batch_size = 16, perturbation_chance = 1.0
Robust DataLoader created with batch_num = 65, batch_size = 16, perturbation_chance = 0.55


Clean data

In [33]:
evaluation = Evaluation(checkpoint_to_evaluate, clean_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

  checkpoint = torch.load(self.checkpoint_path)


evaluate original : True
evaluate finetuned : True


  0%|          | 0/65 [00:00<?, ?it/s]

Evaluation Original Model Accuracy: 90.38461538461539 %
Evaluation Fine-Tuned Model Accuracy: 47.01923076923077 %


{'original_accuracy': 0.9038461538461539,
 'fine_tuned_accuracy': 0.4701923076923077}

Mix Data

In [34]:
evaluation = Evaluation(checkpoint_to_evaluate, mix_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

  checkpoint = torch.load(self.checkpoint_path)


evaluate original : True
evaluate finetuned : True


  0%|          | 0/65 [00:00<?, ?it/s]

Evaluation Original Model Accuracy: 54.61538461538461 %
Evaluation Fine-Tuned Model Accuracy: 48.26923076923077 %


{'original_accuracy': 0.5461538461538461,
 'fine_tuned_accuracy': 0.4826923076923077}

Perturbed data

In [35]:
evaluation = Evaluation(checkpoint_to_evaluate, perturbed_test_loader, evaluate_original=True, evaluate_finetuned=True)
evaluation.evaluate_model()

  checkpoint = torch.load(self.checkpoint_path)


evaluate original : True
evaluate finetuned : True


  0%|          | 0/65 [00:00<?, ?it/s]

Evaluation Original Model Accuracy: 9.615384615384617 %
Evaluation Fine-Tuned Model Accuracy: 45.76923076923077 %


{'original_accuracy': 0.09615384615384616,
 'fine_tuned_accuracy': 0.4576923076923077}