In [1]:
import pandas as pd
import numpy as np
from skimage import io, transform

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils, models
import torchvision as tv
from torchvision.transforms import v2

import lightning as L
import torchmetrics as tm
import torch.nn.functional as F

### Creating the datasets and dataloaders from `imagenet-mini`

In [2]:
# WordNet ID (e.g. "n00001740") -> human-readable label, read from a headerless
# two-column tab-separated file whose first column becomes the index.
labelLookup = pd.read_csv(
    "imagenet-words.txt",
    sep="\t",
    header=None,
    names=["label"],
    index_col=0,
)["label"]

In [3]:
# Preview the first few WordNet ID -> label rows.
labelLookup.head()

n00001740                          entity
n00001930                 physical entity
n00002137    abstraction, abstract entity
n00002452                           thing
n00002684         object, physical object
Name: label, dtype: object

In [4]:
# Preprocessing from the torchvision ResNet-50 docs:
# https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
# PIL image -> tensor image -> resize shorter side to 256 -> 224x224 centre crop
# -> float32 scaled to [0, 1] -> per-channel ImageNet mean/std normalisation.
# NOTE(review): this same pipeline is used for the training set too, so training
# runs without any data augmentation.
ds_transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize(256),
        v2.CenterCrop(224),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [5]:
# ImageFolder assigns class indices from the sorted sub-directory names.
# imagenet-mini's own `val` folder serves as the held-out test set; a validation
# split is carved out of `train` further below.
train_whole_dataset = tv.datasets.ImageFolder(root="imagenet-mini/train", transform=ds_transforms)
test_dataset = tv.datasets.ImageFolder(root="imagenet-mini/val", transform=ds_transforms)

In [6]:
# Inspect the training dataset: number of images, root folder and transform pipeline.
train_whole_dataset

Dataset ImageFolder
    Number of datapoints: 34745
    Root location: imagenet-mini/train
    StandardTransform
Transform: Compose(
                 ToImage()
                 Resize(size=[256], interpolation=InterpolationMode.BILINEAR, antialias=True)
                 CenterCrop(size=(224, 224))
                 ToDtype(scale=True)
                 Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
           )

In [7]:
# Inspect the test dataset: number of images, root folder and transform pipeline.
test_dataset

Dataset ImageFolder
    Number of datapoints: 3923
    Root location: imagenet-mini/val
    StandardTransform
Transform: Compose(
                 ToImage()
                 Resize(size=[256], interpolation=InterpolationMode.BILINEAR, antialias=True)
                 CenterCrop(size=(224, 224))
                 ToDtype(scale=True)
                 Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
           )

In [8]:
# Seed the split with its own generator. `L.seed_everything(42)` only runs in a later
# cell, so an unseeded split draws a different 90/10 partition on every kernel restart,
# and a later evaluation session would score "validation" images that the checkpoint
# was actually trained on.
# NOTE(review): checkpoints trained on an earlier, unseeded split must be retrained (or
# their split indices saved) before their validation metrics can be trusted.
train_dataset, val_dataset = random_split(
    train_whole_dataset, [.9, .1], generator=torch.Generator().manual_seed(42)
)

### Creating ResNet models & training

In [10]:
# Seed Python/NumPy/torch RNGs for model initialisation and shuffling. This runs after
# the random_split cell above, so on its own it does not make that split reproducible.
L.seed_everything(seed=42)

[rank: 0] Seed set to 42


42

In [11]:
# Supported ResNet depths -> torchvision constructor (models.resnet18, ..., models.resnet152).
resnets = {
    depth: getattr(models, f"resnet{depth}")
    for depth in (18, 34, 50, 101, 152)
}

In [12]:
class ResnetClassifier(L.LightningModule):
    """Lightning wrapper around a randomly initialised torchvision ResNet.

    Args:
        variant: ResNet depth; must be a key of the module-level ``resnets``
            mapping (18, 34, 50, 101 or 152).
        lr: learning rate for the Adam optimizer.
        num_classes: size of the classification head and of the accuracy
            metric. Defaults to 1000 (the previously hard-coded value), so
            existing construction calls and checkpoints keep working.

    Raises:
        ValueError: if ``variant`` is not a supported depth.
    """

    def __init__(self, variant, lr=1e-2, num_classes=1000):
        super().__init__()
        # Stores variant/lr/num_classes in checkpoints (used by load_from_checkpoint).
        self.save_hyperparameters()
        if variant not in resnets:
            raise ValueError(f"`variant` argument is invalid (should be {list(resnets)})")
        # weights=None: the network is trained from scratch, no pretrained weights.
        self.resnet_model = resnets[variant](weights=None, num_classes=num_classes)
        self.accuracy = tm.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.lr = lr

    def forward(self, x):
        """Return the ResNet's raw class logits for the image batch ``x``."""
        return self.resnet_model(x)

    def _batch_step(self, batch, batch_kind):
        """Shared train/val/test step: forward pass, loss, accuracy and logging.

        ``batch_kind`` ('train', 'val' or 'test') selects the module mode and
        prefixes the logged metric names. Returns the cross-entropy loss.
        """
        # Set the mode explicitly rather than relying solely on the trainer.
        if batch_kind == 'train':
            self.resnet_model.train()
        else:
            self.resnet_model.eval()
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Batch-level value returned by the metric's forward; Lightning aggregates the
        # logged values per epoch.
        acc = self.accuracy(y_hat, y)
        # logging onto tensorboard
        self.log(f"{batch_kind}_loss", loss, prog_bar=True)
        # NOTE: despite the "_acc_f1" suffix this is multiclass accuracy, not F1; the
        # key is left unchanged so existing runs and result dicts stay comparable.
        self.log(f"{batch_kind}_acc_f1", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; returns the loss to optimise."""
        return self._batch_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; logs val_loss / val_acc_f1."""
        return self._batch_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Lightning test hook; logs test_loss / test_acc_f1."""
        return self._batch_step(batch, 'test')

    def predict_step(self, batch, batch_idx):
        """Return raw logits for the images of an ``(images, labels)`` batch."""
        self.eval()
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [12]:
# Randomly initialised ResNet-50 classifier with the default learning rate.
resnet50_model = ResnetClassifier(variant=50)

In [None]:
BATCH_SIZE = 64
# Only the training loader shuffles; the validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)

In [None]:
# Train for at most 80 epochs, stopping early once val_loss stops decreasing.
trainer = L.Trainer(
    max_epochs=80,
    callbacks=[L.pytorch.callbacks.EarlyStopping(monitor="val_loss", mode="min")],
)

In [None]:
# Train on the 90% split, validating on the held-out 10% after every epoch.
trainer.fit(resnet50_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Where the baseline (imagenet-mini only) checkpoint is written.
CKPT_PATH = "resnet50-imagenetmini-raw.ckpt"

In [None]:
# Reuse CKPT_PATH instead of repeating the literal, so the file written here and the
# path configured in the previous cell cannot drift apart.
trainer.save_checkpoint(CKPT_PATH)

---
### Evaluating models

In [None]:
# Allow faster reduced-precision (e.g. TF32) float32 matmuls where the hardware supports
# them. This is set after training in this notebook, so it only affects evaluation.
torch.set_float32_matmul_precision("high")

In [None]:
# Checkpoint evaluated in this section (overrides the path used for saving above).
CKPT_PATH = "resnet50-imagenetmini-raw-SD-augmented-2.ckpt"

In [None]:
# Unshuffled, so batch order follows test_dataset order.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=4)

In [None]:
# ImageFolder already stores each sample's class index in `.targets`; reading it avoids
# decoding and transforming every test image just to collect labels (same values/order).
y = np.array(test_dataset.targets)

In [None]:
# Build a fresh module and load only the weights from the Lightning checkpoint.
# (`ResnetClassifier.load_from_checkpoint(CKPT_PATH)` would also restore the saved
# hyperparameters.)
loaded_model = ResnetClassifier(50)
# map_location="cpu" makes the load independent of the device the checkpoint was saved
# on (works on CPU-only hosts) and avoids putting every saved tensor on the GPU;
# Lightning moves the model to the accelerator when `trainer.test` runs.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
loaded_model.load_state_dict(checkpoint["state_dict"])

In [None]:
# Same configuration as for training; the early-stopping callback and epoch cap
# only matter for `fit`, not for the `test` calls below.
trainer = L.Trainer(
    max_epochs=80,
    callbacks=[L.pytorch.callbacks.EarlyStopping(monitor="val_loss", mode="min")],
)

In [None]:
# Loss/accuracy on the training split (loader shuffles, which does not change the totals).
trainer.test(loaded_model, dataloaders=train_loader)

In [None]:
# Loss/accuracy on the validation split. NOTE(review): only meaningful if val_dataset is
# the same split the checkpoint was trained with; otherwise it overlaps training images.
trainer.test(loaded_model, dataloaders=val_loader)

In [None]:
# Loss/accuracy on the held-out imagenet-mini `val` folder used as the test set.
trainer.test(loaded_model, dataloaders=test_loader)

---

### Automating Evaluation

In [None]:
# NOTE(review): `!` runs in a subshell that may not target this kernel's environment;
# prefer the `%conda install` magic (or `%mamba install` on recent IPython) and pin a
# version so the evaluation below is reproducible.
!mamba install scikit-learn -y

In [13]:
# Allow faster reduced-precision (e.g. TF32) float32 matmuls for the evaluation runs below.
torch.set_float32_matmul_precision("high")

In [14]:
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

In [15]:
# Evaluation loaders are unshuffled, so their predictions line up row-for-row with
# label arrays built in dataset order.
BATCH_SIZE = 64
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)

In [21]:
# ImageFolder stores every sample's class index in `.targets`, so no image is decoded.
y = test_dataset.targets
# random_split returns a Subset whose `.indices` point into train_whole_dataset; indexing
# the parent's targets gives the validation labels in val_loader order (it does not
# shuffle) without loading and transforming every validation image.
y_val = np.asarray(train_whole_dataset.targets)[val_dataset.indices]

In [17]:
def get_test_preds(loaded_model, test_loader):
    """Run ``loaded_model`` over ``test_loader`` and collect its class scores.

    The model is frozen (``LightningModule.freeze``) before prediction. The
    predictions follow the loader's iteration order, so pass an unshuffled
    loader when comparing them against a label array.

    Args:
        loaded_model: a LightningModule whose ``predict_step`` returns one 2-D
            score tensor per batch (as ``ResnetClassifier.predict_step`` does).
        test_loader: the DataLoader to predict on.

    Returns:
        ``(top_preds, predictions)``: ``predictions`` is the stacked
        ``(n_samples, n_classes)`` NumPy score array and ``top_preds`` the
        1-D array of argmax class indices, one per sample.
    """
    trainer = L.Trainer()
    loaded_model.freeze()

    # One (batch_size, n_classes) tensor per batch.
    predictions_list = trainer.predict(loaded_model, test_loader)
    # Stack the batches into a single (n_samples, n_classes) array.
    predictions = torch.vstack(predictions_list).numpy()
    # argmax over the class axis is already 1-D; no flatten needed.
    top_preds = predictions.argmax(axis=1)

    return top_preds, predictions

def top_preds(all_predictions, ks=(5, 3)):
    """Return the k highest-scoring class indices per sample, for each k in ``ks``.

    Args:
        all_predictions: ``(n_samples, n_classes)`` array of class scores.
        ks: the k values to extract. The default ``(5, 3)`` yields the
            ``(top5, top3)`` pair.

    Returns:
        A tuple with one ``(n_samples, k)`` integer array per entry of ``ks``.
        Within each row the indices are in ascending score order, so the best
        class is in the last column. If k exceeds the number of classes, all
        classes are returned.

    Raises:
        ValueError: if any k is < 1 (``[:, -0:]`` would silently select every class).
    """
    if any(k < 1 for k in ks):
        raise ValueError(f"every k must be >= 1, got {ks!r}")
    # Sort the score matrix once and slice it for every k, instead of re-sorting per k.
    order = np.argsort(all_predictions, axis=1)
    return tuple(order[:, -k:] for k in ks)

def get_topk_accuracy(top_preds, ground_truths):
    """Percentage of samples whose true class is among their predicted top-k classes.

    Args:
        top_preds: ``(n_samples, k)`` array-like of predicted class indices. A
            1-D array is treated as top-1 predictions, shape ``(n_samples,)``.
        ground_truths: length-``n_samples`` sequence of true class indices,
            aligned row-for-row with ``top_preds``.

    Returns:
        Top-k accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: if the predictions are not 1-D/2-D or their row count does
            not match the number of labels. Otherwise numpy would broadcast
            silently (e.g. a single label) and return a wrong number.
    """
    top_preds = np.asarray(top_preds)
    if top_preds.ndim == 1:
        # Without this, comparing an (n,) vector with the (n, 1) label column would
        # broadcast to an (n, n) matrix and count cross-sample matches.
        top_preds = top_preds.reshape(-1, 1)
    ground_truths = np.asarray(ground_truths).reshape(-1, 1)
    if top_preds.ndim != 2 or top_preds.shape[0] != ground_truths.shape[0]:
        raise ValueError(
            "expected (n_samples, k) predictions with one row per label; got predictions "
            f"of shape {top_preds.shape} and {ground_truths.shape[0]} labels"
        )

    # A sample is a hit if any of its k predicted classes equals its ground-truth class.
    matches = np.any(top_preds == ground_truths, axis=1)
    num_matches = np.sum(matches)
    return (num_matches / top_preds.shape[0]) * 100

def performance_metrics(predictions, ground_truth, metric_type="Test"):
    """Print accuracy and support-weighted recall, precision and F1, and return them.

    ``zero_division=0`` scores classes with no predicted (or no true) samples as
    0. That is the same value sklearn's default uses, but without emitting an
    UndefinedMetricWarning on every call. Support-weighted recall is
    mathematically equal to accuracy, so those two printed values always coincide.

    Args:
        predictions: predicted class index per sample.
        ground_truth: true class index per sample, aligned with ``predictions``.
        metric_type: prefix for the printed lines (e.g. "Test", "Val").

    Returns:
        Dict mapping "Accuracy", "Recall", "Precision", "F1 Score" to their values.
    """
    scores = {
        "Accuracy": accuracy_score(ground_truth, predictions),
        "Recall": recall_score(ground_truth, predictions, average='weighted', zero_division=0),
        "Precision": precision_score(ground_truth, predictions, average='weighted', zero_division=0),
        "F1 Score": f1_score(ground_truth, predictions, average='weighted', zero_division=0),
    }
    for name, value in scores.items():
        print(f"{metric_type} {name}: {value}")
    return scores

In [18]:
# Checkpoints to compare: the imagenet-mini-only baseline first, then the variants whose
# names suggest added generated ("SD") data. NOTE(review): confirm what each suffix means.
CKPTS = [
    f"resnet50-imagenetmini-raw{suffix}.ckpt"
    for suffix in ("", "-SD-only", "-SD-augmented", "-SD-augmented-2")
]

In [23]:
# Score every checkpoint on the held-out test set: top-5/top-3 accuracy plus
# accuracy and weighted recall/precision/F1.
for ckpt in CKPTS:
    print(f"---{ckpt}---")
    loaded_model = ResnetClassifier(50)
    # Load onto the CPU (works without a GPU, avoids a GPU memory spike); Lightning
    # moves the model to the accelerator inside trainer.predict.
    checkpoint = torch.load(ckpt, map_location="cpu")
    loaded_model.load_state_dict(checkpoint["state_dict"])
    # Drop the full checkpoint dict before predicting instead of holding it until the
    # next iteration overwrites it.
    del checkpoint
    resnet_pred, resnet_all_pred = get_test_preds(loaded_model, test_loader)

    resnet_top5, resnet_top3 = top_preds(resnet_all_pred)
    print(f"acc@5 (top 5): {get_topk_accuracy(resnet_top5, y)}")
    print(f"acc@3 (top 3): {get_topk_accuracy(resnet_top3, y)}")

    performance_metrics(resnet_pred, y)
    print()

---resnet50-imagenetmini-raw.ckpt---


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 11.419831761407085
acc@3 (top 3): 8.233494774407342
Test Accuracy: 0.036961509049197046
Test Recall: 0.036961509049197046
Test Precision: 0.029064363131703703
Test F1 Score: 0.025455297923017692

---resnet50-imagenetmini-raw-SD-only.ckpt---


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 4.511853173591639
acc@3 (top 3): 3.2373183787917412
Test Accuracy: 0.013510068824878919
Test Recall: 0.013510068824878919
Test Precision: 0.012859353017196767
Test F1 Score: 0.009902197065871457

---resnet50-imagenetmini-raw-SD-augmented.ckpt---


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 20.469028804486364
acc@3 (top 3): 15.166964058118786
Test Accuracy: 0.07723680856487382
Test Recall: 0.07723680856487382
Test Precision: 0.07879437079020538
Test F1 Score: 0.06602193894976481

---resnet50-imagenetmini-raw-SD-augmented-2.ckpt---


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 18.37879174101453
acc@3 (top 3): 13.535559520774918
Test Accuracy: 0.06066785623247515
Test Recall: 0.06066785623247515
Test Precision: 0.0596396036438425
Test F1 Score: 0.05019554290424637



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [23]:
# Score every checkpoint on the validation split.
# NOTE(review): these numbers are only trustworthy if val_dataset is exactly the split the
# checkpoint was trained against. An unseeded random_split (or a checkpoint trained on a
# differently composed dataset) puts training images into this "val" set. The large
# val-vs-test gap in the outputs (e.g. 41.5% vs 7.7% accuracy) is consistent with that.
for ckpt in CKPTS:
    print(f"---{ckpt}---")
    loaded_model = ResnetClassifier(50)
    # Load onto the CPU (works without a GPU, avoids a GPU memory spike); Lightning
    # moves the model to the accelerator inside trainer.predict.
    checkpoint = torch.load(ckpt, map_location="cpu")
    loaded_model.load_state_dict(checkpoint["state_dict"])
    # Drop the full checkpoint dict before predicting instead of holding it until the
    # next iteration overwrites it.
    del checkpoint
    resnet_pred, resnet_all_pred = get_test_preds(loaded_model, val_loader)

    resnet_top5, resnet_top3 = top_preds(resnet_all_pred)
    print(f"acc@5 (top 5): {get_topk_accuracy(resnet_top5, y_val)}")
    print(f"acc@3 (top 3): {get_topk_accuracy(resnet_top3, y_val)}")

    performance_metrics(resnet_pred, y_val, metric_type="Val")
    print()

---resnet50-imagenetmini-raw.ckpt---


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 23.51755900978699
acc@3 (top 3): 18.07714450201497
Val Accuracy: 0.08347725964306275
Val Recall: 0.08347725964306275
Val Precision: 0.0855904093840094
Val F1 Score: 0.065577967596862

---resnet50-imagenetmini-raw-SD-only.ckpt---


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 4.864709268854347
acc@3 (top 3): 3.3390903857225105
Val Accuracy: 0.013241220495106506
Val Recall: 0.013241220495106506
Val Precision: 0.015230656951810747
Val F1 Score: 0.011305410142877035

---resnet50-imagenetmini-raw-SD-augmented.ckpt---


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 66.2636729994243
acc@3 (top 3): 58.693149107656886
Val Accuracy: 0.41508347725964306
Val Recall: 0.41508347725964306
Val Precision: 0.4849336686187748
Val F1 Score: 0.4037153090725819

---resnet50-imagenetmini-raw-SD-augmented-2.ckpt---


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Predicting: |          | 0/? [00:00<?, ?it/s]

acc@5 (top 5): 50.77720207253886
acc@3 (top 3): 42.74611398963731
Val Accuracy: 0.2645365572826713
Val Recall: 0.2645365572826713
Val Precision: 0.31281753297661347
Val F1 Score: 0.2500595773058515



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
