# Part a) 
- load and prepare Oxford Pets Data

In [1]:
import pytorch_lightning as pl
import torch
import torchvision
from pathlib import Path
import random
import torchvision.transforms.functional as F

class RandomRotation(object):
    def __init__(self, degrees, seed=1):
        self.degrees = (-degrees, degrees)
        random.seed(seed)
    
    @staticmethod
    def get_params(degrees):
        angle = random.uniform(degrees[0], degrees[1])
        return angle

    def __call__(self, img):
        angle = self.get_params(self.degrees)
        return F.rotate(img, angle)

In [2]:
from torchvision import transforms
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]


train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

ds_train = torchvision.datasets.OxfordIIITPet(
    root=".", 
    split="trainval", 
    transform=train_transforms, 
    download=True
)

ds_test = torchvision.datasets.OxfordIIITPet(
    root=".", 
    split="test", 
    transform=test_transforms, 
    download=True
)

ds_train_normal = torchvision.datasets.OxfordIIITPet(
    root=".", 
    split="trainval", 
    transform=test_transforms, 
    download=True
)

BATCH_SIZE = 512

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

Downloading https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz to oxford-iiit-pet/images.tar.gz


  0%|          | 0/791918971 [00:00<?, ?it/s]

Extracting oxford-iiit-pet/images.tar.gz to oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/~vgg/data/pets/annotations.tar.gz to oxford-iiit-pet/annotations.tar.gz


  0%|          | 0/19173078 [00:00<?, ?it/s]

Extracting oxford-iiit-pet/annotations.tar.gz to oxford-iiit-pet


# Part b and c)
- define and train a simple CNN
- use skip connections and batch normalization


# DISCLAIMER: The Test Set accuracy is evaluated all at once at the bottom of this document to make it easier to compare them.

In [3]:
import pytorch_lightning as pl
from typing import Any
import torchmetrics
import torch

class CnnWrapper(pl.LightningModule):

    def __init__(self, 
        loss: callable, 
        lr: float, 
        architecture: torch.nn.Module, 
        classification_head:torch.nn.Module,
        num_classes:int
    ) -> None:
        super().__init__()
        self.architecture = architecture
        self.classification_head = classification_head 
        self.num_classes = num_classes
        self.loss = loss
        self.lr = lr
        
        self.test_auroc = torchmetrics.AUROC(num_classes=self.num_classes)
        self.test_acc = torchmetrics.Accuracy(num_classes=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.architecture(x)
        if self.classification_head is not None:
            x = self.classification_head(x)
        return x

    def _step(self, batch) -> torch.Tensor:
        x, y = batch
        pred = self.forward(x)
        loss = self.loss(pred, y)
        return pred, loss

    def training_step(self, batch) -> torch.Tensor:
        pred, loss = self._step(batch)
        self.log("train/loss", loss)
        pred = torch.nn.functional.softmax(pred, dim=1)
        acc = torchmetrics.functional.accuracy(pred, batch[-1], num_classes=self.num_classes)
        self.log("train/acc", acc)
        return loss
    
    def _eval_step(self, batch, auroc, acc):
        pred, loss = self._step(batch)
        pred = torch.nn.functional.softmax(pred, dim=1)
        auroc.update(pred, batch[-1])
        acc.update(pred, batch[-1])
        return loss
        
    def test_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def test_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def validation_epoch_end(self, outputs) -> None:
        print(f"Val AUROC: {self.test_auroc.compute().data}")
        print(f"Val Accuracy: {self.test_acc.compute().data}")

        
    def configure_optimizers(self) -> Any:
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim

In [4]:
simple_architecture = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3),
    torch.nn.MaxPool2d((4, 4)),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, 3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 40, 3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(40, 48, 3),
    torch.nn.ReLU()
)

class SkipBlock(torch.nn.Module):

    def __init__(self, in_channels) -> None:
        super().__init__()
        self.in_chanels = in_channels
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(self.in_chanels, self.in_chanels, kernel_size=(3, 3), padding="same"),
            torch.nn.BatchNorm2d(self.in_chanels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(self.in_chanels, self.in_chanels, kernel_size=(3, 3), padding="same"),
            torch.nn.BatchNorm2d(self.in_chanels),
        )

    def forward(self, x):
        id = x
        out = self.block(x)
        out += id
        return torch.nn.functional.relu(out)

better_architecture = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), 
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    SkipBlock(8),
    torch.nn.Conv2d(8, 16, (3, 3)),
    torch.nn.MaxPool2d((4, 4)),
    SkipBlock(16),
    SkipBlock(16),
    torch.nn.Conv2d(16, 32, (3, 3)),
    torch.nn.MaxPool2d((2, 2)),
    SkipBlock(32),
    SkipBlock(32),
    torch.nn.Conv2d(32, 40, (3, 3)),
    torch.nn.Dropout(),
    SkipBlock(40),
    SkipBlock(40),    
    torch.nn.Conv2d(40, 48, (3, 3)),
    torch.nn.Dropout(),
    SkipBlock(48),
    SkipBlock(48),
)

In [5]:
num_classes = len(ds_train.classes)
loss = torch.nn.CrossEntropyLoss()
lr = 1e-3
epochs = 150

In [6]:
classification_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(48, num_classes),
)

simple_model = CnnWrapper(loss, lr, simple_architecture, classification_head, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=10, check_val_every_n_epoch=5)

# i use the test set as the validation set, to get a feeling of how the network is performing, but no actions are taken on it so it should be fine
# this comment also applies to every following network training loop
trainer.fit(simple_model, dl_train, val_dataloaders=[dl_test])



Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.1438794732093811
Val Accuracy: 0.0380859375




Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.5918474197387695
Val Accuracy: 0.05391008034348488


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6257268786430359
Val Accuracy: 0.06290361285209656


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6407859325408936
Val Accuracy: 0.06616241484880447


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6545436382293701
Val Accuracy: 0.07133758068084717


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6638306379318237
Val Accuracy: 0.07501678168773651


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6694058775901794
Val Accuracy: 0.07726365327835083


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.674612820148468
Val Accuracy: 0.08035346865653992


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6807834506034851
Val Accuracy: 0.08398077636957169


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6853423118591309
Val Accuracy: 0.08667939156293869


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6897886395454407
Val Accuracy: 0.08917112648487091


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.693607747554779
Val Accuracy: 0.09175264835357666


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6971468925476074
Val Accuracy: 0.09418005496263504


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7002459764480591
Val Accuracy: 0.09591346979141235


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7035743594169617
Val Accuracy: 0.09822484850883484


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7064107656478882
Val Accuracy: 0.10037638992071152


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7092479467391968
Val Accuracy: 0.10221336781978607


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7123202681541443
Val Accuracy: 0.1046106293797493


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7143179774284363
Val Accuracy: 0.10610443353652954


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7166846990585327
Val Accuracy: 0.10766947269439697


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7188039422035217
Val Accuracy: 0.10917423665523529


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7210561633110046
Val Accuracy: 0.11074250936508179


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7233818173408508
Val Accuracy: 0.11209660023450851


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7256131768226624
Val Accuracy: 0.11334605515003204


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7275205254554749
Val Accuracy: 0.11463852971792221


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.729231595993042
Val Accuracy: 0.11591499298810959


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7307936549186707
Val Accuracy: 0.11713580787181854


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7316803932189941
Val Accuracy: 0.11753774434328079


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7334133982658386
Val Accuracy: 0.11859554797410965


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7351413369178772
Val Accuracy: 0.11981382220983505


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7367457151412964
Val Accuracy: 0.12096063047647476


In [7]:
classification_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(48, num_classes),
)

better_model = CnnWrapper(loss, lr, better_architecture, classification_head, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=10, check_val_every_n_epoch=5)

trainer.fit(better_model, dl_train, val_dataloaders=[dl_test])



Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.13891097903251648
Val Accuracy: 0.0029296875




Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6641319990158081
Val Accuracy: 0.05902407690882683


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6758483052253723
Val Accuracy: 0.07737383246421814


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6993874907493591
Val Accuracy: 0.08827196061611176


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7088572978973389
Val Accuracy: 0.09273885190486908


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7275397777557373
Val Accuracy: 0.10702668875455856


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7274879217147827
Val Accuracy: 0.10782185941934586


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7280219793319702
Val Accuracy: 0.11049537360668182


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7400941848754883
Val Accuracy: 0.12111535668373108


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7495619058609009
Val Accuracy: 0.1309737116098404


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7486884593963623
Val Accuracy: 0.13430026173591614


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7493529915809631
Val Accuracy: 0.13631200790405273


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7510527968406677
Val Accuracy: 0.139860600233078


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7601968050003052
Val Accuracy: 0.1477186381816864


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.768867015838623
Val Accuracy: 0.1555449515581131


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7755992412567139
Val Accuracy: 0.1622576266527176


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7832453846931458
Val Accuracy: 0.17181222140789032


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.788700520992279
Val Accuracy: 0.17735855281352997


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7953680157661438
Val Accuracy: 0.18598097562789917


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8023134469985962
Val Accuracy: 0.1935534030199051


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8065605759620667
Val Accuracy: 0.19922316074371338


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8122743964195251
Val Accuracy: 0.20666556060314178


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8159444332122803
Val Accuracy: 0.21094419062137604


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8203426003456116
Val Accuracy: 0.21602603793144226


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8246272206306458
Val Accuracy: 0.22177818417549133


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8234084844589233
Val Accuracy: 0.22204013168811798


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8279910087585449
Val Accuracy: 0.2288784235715866


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8308993577957153
Val Accuracy: 0.23270754516124725


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.833551287651062
Val Accuracy: 0.23701761662960052


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8365272879600525
Val Accuracy: 0.2414056360721588


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8384901881217957
Val Accuracy: 0.24495472013950348


# Part d mixup

- Mixup just modifes the data, so we can just write a custom Dataset that does the mixup
- We do not need to change network interners since CrossEntropyCost works in multiclass classification

In [8]:
from typing import Tuple
import torch
import numpy as np

class MixupPets(torch.utils.data.Dataset):

    def __init__(self, pets_ds: torch.utils.data.Dataset) -> None:
        super().__init__()
        self.ds_pets = pets_ds
        self.num_classes = len(pets_ds.classes)
        
    def __len__(self) -> int:
        return len(self.ds_pets)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x, y = self.ds_pets[idx]

        l = torch.zeros(self.num_classes)
        l[y] = 1.

        rv = random.randint(0, self.__len__() - 1)
        x_m, y_m = self.ds_pets[rv]
        l_m = torch.zeros(self.num_classes)
        l_m[y_m] = 1

        alpha = 0.2
        lam = np.random.beta(alpha, alpha)
        image = lam * x + (1 - lam) * x_m
        label = lam * l + (1 - lam) * l_m

        return image, label


In [9]:
ds_mixup = MixupPets(ds_train_normal)
dl_mixup = torch.utils.data.DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

classification_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(48, num_classes),
)

mixup_model = CnnWrapper(loss, lr, better_architecture, classification_head, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=10, check_val_every_n_epoch=5)

trainer.fit(mixup_model, dl_mixup, val_dataloaders=[dl_test])

Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.14747370779514313
Val Accuracy: 0.0


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7042688131332397
Val Accuracy: 0.09716599434614182


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7593755125999451
Val Accuracy: 0.1397990882396698


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.791878342628479
Val Accuracy: 0.17579586803913116


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8035458922386169
Val Accuracy: 0.1931210160255432


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8261843323707581
Val Accuracy: 0.2270638644695282


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8408961296081543
Val Accuracy: 0.2523656487464905


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8500735759735107
Val Accuracy: 0.26869359612464905


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8533228039741516
Val Accuracy: 0.2728798985481262


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8574577569961548
Val Accuracy: 0.27948305010795593


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8651895523071289
Val Accuracy: 0.29418784379959106


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.871636152267456
Val Accuracy: 0.30710679292678833


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8758500814437866
Val Accuracy: 0.31459203362464905


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8790343403816223
Val Accuracy: 0.3209909498691559


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.883802056312561
Val Accuracy: 0.33138003945350647


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8850723505020142
Val Accuracy: 0.33388036489486694


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8872936367988586
Val Accuracy: 0.3375636339187622


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8904999494552612
Val Accuracy: 0.3453160226345062


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.892476499080658
Val Accuracy: 0.3492977023124695


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8933688998222351
Val Accuracy: 0.35187673568725586


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8947128057479858
Val Accuracy: 0.35371753573417664


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8973397016525269
Val Accuracy: 0.3598170876502991


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8982574939727783
Val Accuracy: 0.3619069755077362


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8982824683189392
Val Accuracy: 0.3626464903354645


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8998205065727234
Val Accuracy: 0.3665020167827606


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.90118408203125
Val Accuracy: 0.37052690982818604


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9018569588661194
Val Accuracy: 0.37257567048072815


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9035568237304688
Val Accuracy: 0.3775315582752228


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9047836065292358
Val Accuracy: 0.38078761100769043


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9058112502098083
Val Accuracy: 0.38312309980392456


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9059702157974243
Val Accuracy: 0.38417917490005493


# Part e)

- Install timm and torchinfo to get a feeling of how the loaded model looks
- Freeze model excluding the classification head
- Finetune

In [10]:
# !pip install timm
# !pip install torchinfo==1.6.3

Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.6.12
[0mCollecting torchinfo==1.6.3
  Downloading torchinfo-1.6.3-py3-none-any.whl (20 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.6.3
[0m

In [11]:
from torchinfo import summary
import timm

#timm.list_models()
timm_model = timm.create_model('resnest50d', pretrained=True, num_classes=num_classes)
# inspect the loaded model to get a feeling for how it looks and what to change / freeze
summary(timm_model, input_size = (BATCH_SIZE, 3, 224, 224))


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth" to /root/.cache/torch/hub/checkpoints/resnest50-528c19ca.pth


Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   --                        --
├─Sequential: 1-1                        [512, 64, 112, 112]       --
│    └─Conv2d: 2-1                       [512, 32, 112, 112]       864
│    └─BatchNorm2d: 2-2                  [512, 32, 112, 112]       64
│    └─ReLU: 2-3                         [512, 32, 112, 112]       --
│    └─Conv2d: 2-4                       [512, 32, 112, 112]       9,216
│    └─BatchNorm2d: 2-5                  [512, 32, 112, 112]       64
│    └─ReLU: 2-6                         [512, 32, 112, 112]       --
│    └─Conv2d: 2-7                       [512, 64, 112, 112]       18,432
├─BatchNorm2d: 1-2                       [512, 64, 112, 112]       128
├─ReLU: 1-3                              [512, 64, 112, 112]       --
├─MaxPool2d: 1-4                         [512, 64, 56, 56]         --
├─Sequential: 1-5                        [512, 256, 56, 56]        --
│    └

In [12]:
# freeze model layers that don't need finetuning
def freeze_timm_model(model):
    '''Freeze all layers except the last layer(fc or classifier)'''
    for param in model.parameters():
            param.requires_grad = False
    # reset the last two layers to require grad
    model.fc.weight.requires_grad = True
    model.fc.bias.requires_grad = True
    
freeze_timm_model(timm_model)

In [13]:
# Finetune the network
epochs = 30 # we don't need as many epochs 
pretrained_model = CnnWrapper(loss, lr, timm_model, None, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=5, check_val_every_n_epoch=5)

trainer.fit(pretrained_model, dl_train, val_dataloaders=[dl_test])

Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.18190491199493408
Val Accuracy: 0.0927734375


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9670823812484741
Val Accuracy: 0.679096519947052


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9810598492622375
Val Accuracy: 0.7633341550827026


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9867335557937622
Val Accuracy: 0.798271119594574


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9896060228347778
Val Accuracy: 0.818980872631073


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9913065433502197
Val Accuracy: 0.8313283920288086


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.9925006628036499
Val Accuracy: 0.8409150242805481


# Run the Test Step for every Network to see how the performance differs

In [14]:
print(f"Testing Simple CNN")
trainer.test(simple_model, dl_test)
print(f"Testing Better CNN")
trainer.test(better_model, dl_test)
print(f"Testing Better CNN with Mixup")
trainer.test(mixup_model, dl_test)
print(f"Testing Finetuned Resnet50d")
trainer.test(pretrained_model, dl_test)

Testing Simple CNN


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.738244891166687
Test Accuracy: 0.12203410267829895


Testing Better CNN


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.8288923501968384
Test Accuracy: 0.2403910607099533


Testing Better CNN with Mixup


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.90616375207901
Test Accuracy: 0.3851676881313324


Testing Finetuned Resnet50d


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.9933205842971802
Test Accuracy: 0.8478676080703735


[{'test': 0.3644622564315796}]

# Findings:
- The performance of the networks increases with each subtask -> New Architectures + model training methods improve performance
- Models probably overfited a little bit, the performance of the first 3 models was better in a test run with only ~80 epochs. 
- Finetunig a Pretrained Model is way less expansive and often performs better than handcrafted models 
    - Due to training set size, computing power, etc