In [1]:
import os

import pytorch_lightning as pl
# your favorite machine learning tracking tool
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.progress import TQDMProgressBar

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import StanfordCars
from torchvision.datasets.utils import download_url
import torchvision.models as models


import wandb

In [3]:
class StanfordCarsDataModule(pl.LightningDataModule):
    """Lightning data module for the torchvision ``StanfordCars`` dataset.

    The official "train" split is used for training (with augmentation). The
    official "test" split is divided 50/50 into a validation and a test subset.

    Args:
        batch_size: Batch size used by all three dataloaders.
        data_dir: Root directory handed to ``StanfordCars``.
        num_workers: Worker processes per dataloader.
        split_seed: Seed of the generator that partitions the official test
            split. ``setup`` is invoked once per trainer stage (fit, test, ...),
            so an unseeded split would yield a *different* val/test partition
            on every call and let validation images reappear in the test set.
    """

    def __init__(self, batch_size, data_dir: str = './', num_workers: int = 2, split_seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed

        # Channel statistics commonly used with ImageNet-pretrained torchvision models.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]

        # Augmentation policy for training set
        self.augmentation = transforms.Compose([
            # Photometric / geometric perturbations, each applied with probability 0.5.
            transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
            transforms.RandomGrayscale(p=0.5),
            transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
            transforms.RandomPosterize(bits=2, p=0.5),
            # Random crop to 256, small rotation and flip, then the same 224 centre crop as evaluation.
            transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        # Preprocessing steps applied to validation and test set.
        self.transform = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        self.num_classes = 196

    def prepare_data(self):
        """Download both official splits into ``data_dir`` if necessary."""
        for split in ("train", "test"):
            StanfordCars(root=self.data_dir, download=True, split=split)

    def setup(self, stage=None):
        """Build the train / val / test datasets.

        All three datasets are rebuilt regardless of ``stage``; the val/test
        partition is reproducible because it is drawn from a freshly seeded
        generator on every call.
        """
        self.train = StanfordCars(root=self.data_dir, split="train", transform=self.augmentation)

        held_out = StanfordCars(root=self.data_dir, split="test", transform=self.transform)
        # A dedicated generator keeps the partition independent of the global RNG state.
        split_rng = torch.Generator().manual_seed(self.split_seed)
        self.val, self.test = random_split(held_out, [0.5, 0.5], generator=split_rng)

    def train_dataloader(self):
        """Shuffled loader over the augmented training split."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Loader over the validation half of the official test split."""
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Loader over the test half of the official test split."""
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)

In [4]:
from torch.nn.modules.linear import Linear
    
class IC(nn.Module):
    """Batch normalisation followed by dropout for 4-D feature maps.

    Args:
        input_shape: Number of channels passed to ``nn.BatchNorm2d``.
        p: Dropout probability.
    """

    def __init__(self, input_shape, p=0.01):
        super().__init__()
        # Remaining BatchNorm2d settings are the library defaults
        # (eps=1e-5, momentum=0.1, affine=True, track_running_stats=True).
        self.norm = nn.BatchNorm2d(input_shape)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Normalise ``x`` per channel, then apply dropout."""
        return self.drop(self.norm(x))

class LitModel(pl.LightningModule):
    """EfficientNet-B1 classifier in which every conv block's BatchNorm is swapped for an ``IC`` layer.

    Each ``(conv, norm[, activation])`` block of the torchvision backbone is
    rebuilt as ``(IC, conv[, activation])``: the block's own normalisation layer
    is dropped and an ``IC`` layer (BatchNorm + Dropout) sized to the
    convolution's input channels is placed in front of it. Squeeze-excitation
    modules are left untouched. The classifier head is replaced by a fresh
    ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        input_shape: Shape of one input sample; stored as ``self.dim`` and not
            otherwise used by the constructor.
        num_classes: Number of target classes (classifier outputs and metric).
        learning_rate: Learning rate for the Adam optimizer.
        transfer: If true, start from ImageNet-pretrained weights. All
            parameters remain trainable either way.
    """

    def __init__(self, input_shape, num_classes, learning_rate=0.0001, transfer=False):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes

        # `weights=` supersedes the deprecated `pretrained=` flag; IMAGENET1K_V1 is
        # the checkpoint that `pretrained=True` used to select.
        weights = models.EfficientNet_B1_Weights.IMAGENET1K_V1 if transfer else None
        efficient = models.efficientnet_b1(weights=weights)

        # Nothing is frozen: parameters are trainable by default, so the whole
        # network (pretrained convolutions plus the new IC layers) is fine-tuned.
        self._insert_ic_layers(efficient)

        # Replace the classification head, keeping the backbone's feature width.
        head_in = efficient.classifier[1].in_features
        efficient.classifier[1] = nn.Linear(in_features=head_in, out_features=num_classes, bias=True)

        self.model = efficient

        self.criterion = nn.CrossEntropyLoss()
        # Follow the constructor argument rather than a hard-coded class count.
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    @staticmethod
    def _prepend_ic(conv_block):
        """Rebuild a ``(conv, norm[, activation])`` block as ``(IC, conv[, activation])``."""
        conv = conv_block[0]
        # Index 1 is the block's own norm layer, which is intentionally discarded.
        trailing = list(conv_block)[2:]
        return nn.Sequential(IC(conv.in_channels), conv, *trailing)

    @classmethod
    def _insert_ic_layers(cls, backbone):
        """Apply ``_prepend_ic`` to the stem, every MBConv conv block and the final conv of ``backbone``."""
        features = backbone.features
        last = len(features) - 1

        features[0] = cls._prepend_ic(features[0])
        for stage in list(features)[1:last]:
            for mbconv in stage:
                # Snapshot the children first because entries are replaced while walking.
                for idx, layer in list(enumerate(mbconv.block)):
                    # Conv blocks are nn.Sequential subclasses; the squeeze-excitation
                    # module is a plain nn.Module and is skipped.
                    if isinstance(layer, nn.Sequential):
                        mbconv.block[idx] = cls._prepend_ic(layer)
        features[last] = cls._prepend_ic(features[last])

    def _get_conv_output(self, shape):
        """Return the flattened per-sample size of the model output for a random input of ``shape``.

        Note: ``_forward_features`` runs the complete model (classifier
        included), so this measures the final output width, not the width of the
        features entering the linear layer. Currently unused by ``__init__``.
        """
        batch_size = 1
        tmp_input = torch.rand(batch_size, *shape)

        output_feat = self._forward_features(tmp_input)
        n_size = output_feat.view(batch_size, -1).size(1)
        return n_size

    def _forward_features(self, x):
        """Run ``x`` through the wrapped model."""
        return self.model(x)

    def ic(self, x, p=0.01):
        """Functional BatchNorm + Dropout built on the fly.

        A brand-new ``BatchNorm2d`` is created on every call, so nothing it
        learns or tracks persists between calls. Not used by ``forward``; the
        ``IC`` modules inserted into the backbone are what the model relies on.
        """
        x = nn.BatchNorm2d(num_features=x.shape[1], eps=1.1e-5).to(x.device)(x)
        x = nn.Dropout(p)(x)
        return x

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx=None):
        """Compute and log cross-entropy loss and accuracy for one training batch."""
        images, targets = batch[0], batch[1]
        logits = self.forward(images)
        loss = self.criterion(logits, targets)

        acc = self.accuracy(logits, targets)

        self.log("train/loss", loss, prog_bar=True, on_epoch=True)
        self.log("train/acc", acc, prog_bar=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss and accuracy for one validation batch."""
        images, targets = batch[0], batch[1]
        logits = self.forward(images)
        loss = self.criterion(logits, targets)

        self.log("val/loss", loss, prog_bar=True, on_epoch=True)

        acc = self.accuracy(logits, targets)
        self.log("val/acc", acc, prog_bar=True, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log loss/accuracy for one test batch; also return logits and targets."""
        images, targets = batch[0], batch[1]
        logits = self.forward(images)
        acc = self.accuracy(logits, targets)
        loss = self.criterion(logits, targets)
        self.log('test_acc', acc, on_step=False, on_epoch=True)
        self.log("test/loss", loss, prog_bar=True, on_epoch=True)
        return {"loss": loss, "outputs": logits, "gt": targets}

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Seed the Python, NumPy and torch RNGs (and dataloader workers) so that weight
# initialisation, augmentation and shuffling are repeatable across runs.
pl.seed_everything(42, workers=True)

dm = StanfordCarsDataModule(batch_size=32)
# Take the class count from the datamodule instead of repeating the literal.
model = LitModel((3, 224, 224), dm.num_classes, learning_rate=0.001, transfer=True)
# "auto" selects an available accelerator and falls back to CPU when there is no GPU.
trainer = pl.Trainer(max_epochs=50, accelerator="auto", logger=WandbLogger(project="IC"))

In [None]:
# Train the model; the datamodule supplies the train and validation loaders.
trainer.fit(model, datamodule=dm)

In [None]:
# Evaluate the model as it stands after training, on the datamodule's test loader.
trainer.test(model, datamodule=dm)