In [1]:
import pytorch_lightning as pl
import torch
import torchvision
from pathlib import Path
import random
import torchvision.transforms.functional as F

class RandomRotation(object):
    def __init__(self, degrees, seed=1):
        self.degrees = (-degrees, degrees)
        random.seed(seed)
    
    @staticmethod
    def get_params(degrees):
        angle = random.uniform(degrees[0], degrees[1])
        return angle

    def __call__(self, img):
        angle = self.get_params(self.degrees)
        return F.rotate(img, angle)

In [2]:
from torchvision import transforms
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]


train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

ds_train = torchvision.datasets.OxfordIIITPet(
    root=".", 
    split="trainval", 
    transform=train_transforms, 
    download=True
)

ds_test = torchvision.datasets.OxfordIIITPet(
    root=".", 
    split="test", 
    transform=test_transforms, 
    download=True
)

ds_train_normal = torchvision.datasets.OxfordIIITPet(
    root=".", 
    split="trainval", 
    transform=test_transforms, 
    download=True
)

BATCH_SIZE = 512

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

Downloading https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz to oxford-iiit-pet/images.tar.gz


  0%|          | 0/791918971 [00:00<?, ?it/s]

Extracting oxford-iiit-pet/images.tar.gz to oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/~vgg/data/pets/annotations.tar.gz to oxford-iiit-pet/annotations.tar.gz


  0%|          | 0/19173078 [00:00<?, ?it/s]

Extracting oxford-iiit-pet/annotations.tar.gz to oxford-iiit-pet


In [3]:
import pytorch_lightning as pl
from typing import Any
import torchmetrics
import torch

class CnnWrapper(pl.LightningModule):

    def __init__(self, 
        loss: callable, 
        lr: float, 
        architecture: torch.nn.Module, 
        classification_head:torch.nn.Module,
        num_classes:int
    ) -> None:
        super().__init__()
        self.architecture = architecture
        self.classification_head = classification_head 
        self.num_classes = num_classes
        self.loss = loss
        self.lr = lr
        
        self.test_auroc = torchmetrics.AUROC(num_classes=self.num_classes)
        self.test_acc = torchmetrics.Accuracy(num_classes=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.architecture(x)
        if self.classification_head is not None:
            x = self.classification_head(x)
        return x

    def _step(self, batch) -> torch.Tensor:
        x, y = batch
        pred = self.forward(x)
        loss = self.loss(pred, y)
        return pred, loss

    def training_step(self, batch) -> torch.Tensor:
        pred, loss = self._step(batch)
        self.log("train/loss", loss)
        pred = torch.nn.functional.softmax(pred, dim=1)
        acc = torchmetrics.functional.accuracy(pred, batch[-1], num_classes=self.num_classes)
        self.log("train/acc", acc)
        return loss
    
    def _eval_step(self, batch, auroc, acc):
        pred, loss = self._step(batch)
        pred = torch.nn.functional.softmax(pred, dim=1)
        auroc.update(pred, batch[-1])
        acc.update(pred, batch[-1])
        return loss
        
    def test_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def test_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def validation_epoch_end(self, outputs) -> None:
        print(f"Val AUROC: {self.test_auroc.compute().data}")
        print(f"Val Accuracy: {self.test_acc.compute().data}")

        
    def configure_optimizers(self) -> Any:
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim

In [4]:
simple_architecture = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3),
    torch.nn.MaxPool2d((4, 4)),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, 3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 40, 3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(40, 48, 3),
    torch.nn.ReLU()
)

class SkipBlock(torch.nn.Module):

    def __init__(self, in_channels) -> None:
        super().__init__()
        self.in_chanels = in_channels
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(self.in_chanels, self.in_chanels, kernel_size=(3, 3), padding="same"),
            torch.nn.BatchNorm2d(self.in_chanels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(self.in_chanels, self.in_chanels, kernel_size=(3, 3), padding="same"),
            torch.nn.BatchNorm2d(self.in_chanels),
        )

    def forward(self, x):
        id = x
        out = self.block(x)
        out += id
        return torch.nn.functional.relu(out)

better_architecture = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), 
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    SkipBlock(8),
    torch.nn.Conv2d(8, 16, (3, 3)),
    torch.nn.MaxPool2d((4, 4)),
    SkipBlock(16),
    SkipBlock(16),
    torch.nn.Conv2d(16, 32, (3, 3)),
    torch.nn.MaxPool2d((2, 2)),
    SkipBlock(32),
    SkipBlock(32),
    torch.nn.Conv2d(32, 40, (3, 3)),
    torch.nn.Dropout(),
    SkipBlock(40),
    SkipBlock(40),    
    torch.nn.Conv2d(40, 48, (3, 3)),
    torch.nn.Dropout(),
    SkipBlock(48),
    SkipBlock(48),
)

In [5]:
num_classes = len(ds_train.classes)
loss = torch.nn.CrossEntropyLoss()
lr = 1e-3
epochs = 100

In [6]:
classification_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(48, num_classes),
)

simple_model = CnnWrapper(loss, lr, simple_architecture, classification_head, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=5, check_val_every_n_epoch=5)


trainer.fit(simple_model, dl_train, val_dataloaders=[dl_test])



Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.15117962658405304
Val Accuracy: 0.09765625




Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.5981079936027527
Val Accuracy: 0.06222032755613327


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6233459711074829
Val Accuracy: 0.06350155174732208


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6412146687507629
Val Accuracy: 0.06940404325723648


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6538829207420349
Val Accuracy: 0.07178343832492828


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6615874767303467
Val Accuracy: 0.07615261524915695


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6682275533676147
Val Accuracy: 0.07843562960624695


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6741829514503479
Val Accuracy: 0.08095256239175797


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6790103316307068
Val Accuracy: 0.08282855153083801


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6822530031204224
Val Accuracy: 0.0840945839881897


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6853702664375305
Val Accuracy: 0.08604232221841812


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6887919902801514
Val Accuracy: 0.08834545314311981


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.691373348236084
Val Accuracy: 0.09002929925918579


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.694680392742157
Val Accuracy: 0.0920342355966568


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6969338655471802
Val Accuracy: 0.09316663444042206


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6990909576416016
Val Accuracy: 0.09459676593542099


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7016346454620361
Val Accuracy: 0.09640369564294815


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7036406397819519
Val Accuracy: 0.09785951673984528


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7056942582130432
Val Accuracy: 0.10019980370998383


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7073661088943481
Val Accuracy: 0.1015763059258461


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7090961337089539
Val Accuracy: 0.10313961654901505


In [7]:
classification_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(48, num_classes),
)

better_model = CnnWrapper(loss, lr, better_architecture, classification_head, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=5, check_val_every_n_epoch=5)

#trainer = pl.Trainer(max_epochs=epochs, log_every_n_steps=10)
trainer.fit(better_model, dl_train, val_dataloaders=[dl_test])



Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.1388256549835205
Val Accuracy: 0.0




Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6531271934509277
Val Accuracy: 0.04815682768821716


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6726138591766357
Val Accuracy: 0.07223152369260788


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.69271320104599
Val Accuracy: 0.08311861008405685


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7014686465263367
Val Accuracy: 0.0903184711933136


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7115429043769836
Val Accuracy: 0.09783674776554108


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7151222229003906
Val Accuracy: 0.10217900574207306


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7237372994422913
Val Accuracy: 0.11098214238882065


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7333308458328247
Val Accuracy: 0.11927179247140884


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7447397708892822
Val Accuracy: 0.1297987997531891


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7527160048484802
Val Accuracy: 0.13578511774539948


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7609254717826843
Val Accuracy: 0.14201483130455017


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7667263150215149
Val Accuracy: 0.1496049016714096


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7727746963500977
Val Accuracy: 0.15436875820159912


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.778714120388031
Val Accuracy: 0.16050772368907928


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7827345728874207
Val Accuracy: 0.16518311202526093


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7875475287437439
Val Accuracy: 0.17085789144039154


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7916057705879211
Val Accuracy: 0.175197571516037


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7953654527664185
Val Accuracy: 0.17930099368095398


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8014101982116699
Val Accuracy: 0.18659786880016327


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8068161606788635
Val Accuracy: 0.1936051845550537


In [8]:
from typing import Tuple
import torch
import numpy as np

class MixupPets(torch.utils.data.Dataset):

    def __init__(self, pets_ds: torch.utils.data.Dataset) -> None:
        super().__init__()
        self.ds_pets = pets_ds
        self.num_classes = len(pets_ds.classes)
        
    def __len__(self) -> int:
        return len(self.ds_pets)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x, y = self.ds_pets[idx]

        l = torch.zeros(self.num_classes)
        l[y] = 1.

        rv = random.randint(0, self.__len__() - 1)
        x_m, y_m = self.ds_pets[rv]
        l_m = torch.zeros(self.num_classes)
        l_m[y_m] = 1

        alpha = 0.2
        lam = np.random.beta(alpha, alpha)
        image = lam * x + (1 - lam) * x_m
        label = lam * l + (1 - lam) * l_m

        return image, label


In [9]:
ds_mixup = MixupPets(ds_train_normal)
dl_mixup = torch.utils.data.DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

classification_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(48, num_classes),
)

mixup_model = CnnWrapper(loss, lr, better_architecture, classification_head, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=5, check_val_every_n_epoch=5)

trainer.fit(mixup_model, dl_mixup, val_dataloaders=[dl_test])

Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.18091297149658203
Val Accuracy: 0.09765625


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.6916982531547546
Val Accuracy: 0.08906882256269455


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.741266131401062
Val Accuracy: 0.12126285582780838


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7736793756484985
Val Accuracy: 0.15509933233261108


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.7982895374298096
Val Accuracy: 0.18484076857566833


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8095681667327881
Val Accuracy: 0.19665445387363434


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8214246034622192
Val Accuracy: 0.21169371902942657


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8328657746315002
Val Accuracy: 0.22634515166282654


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8301730155944824
Val Accuracy: 0.228041872382164


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8376994132995605
Val Accuracy: 0.24074019491672516


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8421407341957092
Val Accuracy: 0.24972158670425415


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8487892746925354
Val Accuracy: 0.26034843921661377


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8531479239463806
Val Accuracy: 0.26809021830558777


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8585013747215271
Val Accuracy: 0.27741631865501404


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8627896904945374
Val Accuracy: 0.2849017083644867


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8658816814422607
Val Accuracy: 0.28985533118247986


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8694906830787659
Val Accuracy: 0.2974986732006073


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8696590065956116
Val Accuracy: 0.29879963397979736


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8724551200866699
Val Accuracy: 0.304535835981369


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8756592273712158
Val Accuracy: 0.3112037777900696


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8772422671318054
Val Accuracy: 0.3149024248123169


In [10]:
!pip install timm
!pip install torchinfo==1.6.3

Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.6.12
[0mCollecting torchinfo==1.6.3
  Downloading torchinfo-1.6.3-py3-none-any.whl (20 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.6.3
[0m

In [11]:
from torchinfo import summary
import timm

#timm.list_models()
timm_model = timm.create_model('resnest50d', pretrained=True, num_classes=num_classes)
summary(timm_model, input_size = (BATCH_SIZE, 3, 224, 224))


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth" to /root/.cache/torch/hub/checkpoints/resnest50-528c19ca.pth


Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   --                        --
├─Sequential: 1-1                        [512, 64, 112, 112]       --
│    └─Conv2d: 2-1                       [512, 32, 112, 112]       864
│    └─BatchNorm2d: 2-2                  [512, 32, 112, 112]       64
│    └─ReLU: 2-3                         [512, 32, 112, 112]       --
│    └─Conv2d: 2-4                       [512, 32, 112, 112]       9,216
│    └─BatchNorm2d: 2-5                  [512, 32, 112, 112]       64
│    └─ReLU: 2-6                         [512, 32, 112, 112]       --
│    └─Conv2d: 2-7                       [512, 64, 112, 112]       18,432
├─BatchNorm2d: 1-2                       [512, 64, 112, 112]       128
├─ReLU: 1-3                              [512, 64, 112, 112]       --
├─MaxPool2d: 1-4                         [512, 64, 56, 56]         --
├─Sequential: 1-5                        [512, 256, 56, 56]        --
│    └

In [12]:
def freeze_timm_model(model):
    '''Freeze all layers except the last layer(fc or classifier)'''
    for param in model.parameters():
            param.requires_grad = False
    # reset the last two layers to require grad
    model.fc.weight.requires_grad = True
    model.fc.bias.requires_grad = True
    
freeze_timm_model(timm_model)

In [13]:
pretrained_model = CnnWrapper(loss, lr, timm_model, None, num_classes)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs, log_every_n_steps=5, check_val_every_n_epoch=5)

trainer.fit(better_model, dl_train, val_dataloaders=[dl_test])

Sanity Checking: 0it [00:00, ?it/s]

Val AUROC: 0.8011611700057983
Val Accuracy: 0.1910828799009323


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8066568374633789
Val Accuracy: 0.1984904557466507


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8114463686943054
Val Accuracy: 0.20493923127651215


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8168789744377136
Val Accuracy: 0.21337421238422394


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8222010135650635
Val Accuracy: 0.22124433517456055


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8273944854736328
Val Accuracy: 0.22945837676525116


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8320415616035461
Val Accuracy: 0.23733092844486237


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8341246247291565
Val Accuracy: 0.24144752323627472


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8383299708366394
Val Accuracy: 0.2490551620721817


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.84098219871521
Val Accuracy: 0.2534186542034149


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8435371518135071
Val Accuracy: 0.25827252864837646


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.847313642501831
Val Accuracy: 0.2649347484111786


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8508251309394836
Val Accuracy: 0.2706938087940216


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8540908098220825
Val Accuracy: 0.27653199434280396


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8568659424781799
Val Accuracy: 0.28162217140197754


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8594464659690857
Val Accuracy: 0.28636470437049866


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.861634373664856
Val Accuracy: 0.29053470492362976


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8634628057479858
Val Accuracy: 0.29427942633628845


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8646062016487122
Val Accuracy: 0.29752597212791443


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.866421103477478
Val Accuracy: 0.30124226212501526


Validation: 0it [00:00, ?it/s]

Val AUROC: 0.8681001663208008
Val Accuracy: 0.30503734946250916


In [14]:
print(f"Testing Simple CNN")
trainer.test(simple_model, dl_test)
print(f"Testing Better CNN")
trainer.test(better_model, dl_test)
print(f"Testing Better CNN with Mixup")
trainer.test(mixup_model, dl_test)
print(f"Testing Finetuned Resnet50d")
trainer.test(pretrained_model, dl_test)

Testing Simple CNN


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.710661768913269
Test Accuracy: 0.10455599427223206


Testing Better CNN


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.8697017431259155
Test Accuracy: 0.30864983797073364


Testing Better CNN with Mixup


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.8728142976760864
Test Accuracy: 0.308416485786438


Testing Finetuned Resnet50d


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.4768258333206177
Test Accuracy: 0.019896429032087326


[{'test': 3.6787164211273193}]