## PyTorch Lightning: Image Classification on CIFAR10 with ResNet50

### Load modules

In [1]:
import os
import pdb
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
import torchvision.models as models
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import loggers as pl_loggers
import lightning.pytorch as pl

### Load Dataset

In [2]:
# The frozen backbone is torchvision's ImageNet-pretrained ResNet50, which expects inputs
# normalized with the ImageNet per-channel mean/std; plain ToTensor() output in [0, 1]
# would be out of distribution for its pretrained conv/BatchNorm layers.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
cifar_transform = transforms.Compose([transforms.ToTensor(), imagenet_normalize])

# CIFAR10 train and test splits, downloaded into the current working directory if missing.
data_set = CIFAR10(os.getcwd(), download=True, train=True, transform=cifar_transform)
test_set = CIFAR10(os.getcwd(), download=True, train=False, transform=cifar_transform)

Files already downloaded and verified
Files already downloaded and verified


### Split dataset

Hold out 20% of the CIFAR10 training set as a validation set.

In [3]:
# 80/20 train/validation split of the CIFAR10 training set; the seeded generator
# makes the split reproducible across runs.
train_set_size = int(0.8 * len(data_set))
valid_set_size = len(data_set) - train_set_size

seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(
    data_set,
    [train_set_size, valid_set_size],
    generator=seed,
)

print(f"Train size: {train_set_size}\nValid size: {valid_set_size}")

Train size: 40000
Valid size: 10000


### Set up the model

In [None]:
class ImagenetTransferLearning(pl.LightningModule):
    """Linear classifier trained on top of a frozen, ImageNet-pretrained ResNet50.

    Only ``self.classifier`` is trained; the ResNet50 body (everything except its
    final fully connected layer) is used as a fixed feature extractor.

    Args:
        num_target_classes: number of outputs of the linear head (10 for CIFAR10).
    """

    def __init__(self, num_target_classes=10):
        super().__init__()

        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        # Keep every ResNet50 child except the final fully connected layer.
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # Freeze the backbone explicitly so it is reported as non-trainable and never
        # accumulates gradients, even if called outside torch.no_grad().
        self.feature_extractor.requires_grad_(False)

        self.classifier = nn.Linear(num_filters, num_target_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        # Lightning puts the whole module in train mode for training; switch the frozen
        # backbone back to eval so its BatchNorm layers keep the pretrained statistics.
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        return self.classifier(representations)

    def _shared_step(self, batch, stage):
        """Compute the cross-entropy loss on ``batch`` and log it as ``f"{stage}_loss"``."""
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        # Only the linear head is trainable, so only its parameters are optimized.
        return torch.optim.AdamW(self.classifier.parameters(), lr=1e-3)

In [None]:
# One output per CIFAR10 class (10 is also the constructor's default).
model = ImagenetTransferLearning(num_target_classes=10)

In [None]:
BATCH_SIZE = 512

# Reshuffle the training data every epoch so mini-batches are not identical across epochs;
# validation and test order does not affect their metrics, so they stay unshuffled.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)

In [None]:
# Take the logger from the same `lightning.pytorch` package as the Trainer. The standalone
# `pytorch_lightning` package is a separate install whose classes are distinct from (not
# instances of) their `lightning.pytorch` counterparts, so mixing the two is fragile.
from lightning.pytorch.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger('cifar10_logs/')

In [None]:
# Five-epoch training run, logged to TensorBoard, with checkpointing enabled.
trainer = pl.Trainer(
    max_epochs=5,
    default_root_dir="resnet50/",
    enable_checkpointing=True,
    logger=tb_logger,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

In [None]:
# Evaluate on the held-out CIFAR10 test split.
trainer.test(model, dataloaders=test_loader)

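### Adding an argument parser for a .py script

When the same code runs inside Jupyter, `sys.argv` contains the kernel's own launch arguments, so the parser must tolerate arguments it does not define.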

In [None]:
from argparse import ArgumentParser

In [None]:
# Parser for the command-line options used when this code is run as a .py script.
parser = ArgumentParser()

In [None]:
# Trainer arguments
parser.add_argument("--devices", default=2, type=int)  # forwarded to pl.Trainer(devices=...)

# Hyperparameters for the model
parser.add_argument("--layer_1_dim", default=128, type=int)

In [None]:
# Parse the user inputs and defaults (returns an argparse.Namespace).
# parse_known_args() ignores flags this parser does not define. Inside Jupyter, sys.argv
# carries the kernel launcher's own arguments (e.g. `-f <connection-file>`), and plain
# parse_args() would exit with "unrecognized arguments".
args, _unknown_args = parser.parse_known_args()

In [None]:
# Use the parsed arguments in your program.
# `Trainer` is not imported by name in this notebook; it lives in the `pl` module.
trainer = pl.Trainer(devices=args.devices)
# ImagenetTransferLearning has no `layer_1_dim` hyperparameter (its only argument is
# num_target_classes), so args.layer_1_dim is not passed to it here.
model = ImagenetTransferLearning()

### Debugging

In [None]:
def function_to_debug():
    """Print a value, pause in the pdb debugger, then square the value.

    Demonstrates pdb.set_trace(): execution stops at that line so `x` can be
    inspected before `y` is computed. Nothing is returned.
    """
    x = 2
    print(x)

    # Interactive breakpoint: opens a pdb prompt at this point.
    pdb.set_trace()

    y = x * x

In [None]:
# Calling this opens an interactive pdb prompt at the set_trace() call inside the function.
function_to_debug()

### Run all your model code once quickly

The `fast_dev_run` argument of the Trainer runs a single batch (or *n* batches when set to an integer *n*) of training, validation, test and prediction data through your model to quickly check for bugs:

In [None]:
# fast_dev_run=True runs a single batch of train/val/test/predict (one epoch), and
# disables checkpointing and loggers for that run; it is a smoke test, not training.
trainer = pl.Trainer(fast_dev_run=True, max_epochs=1)

In [None]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

In [None]:
# pass an integer to fast_dev_run to run that many batches instead of one

In [None]:
# Assign the new Trainer. Without the assignment, the following fit() call reuses the
# previous `trainer` (fast_dev_run=True, one batch) instead of running 7 batches.
trainer = pl.Trainer(fast_dev_run=7)

In [None]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

### Shorten the epoch length

In [None]:
# Fractional limits: each epoch sees 20% of the training batches and 10% of the validation batches.
trainer = pl.Trainer(max_epochs=2, limit_train_batches=0.2, limit_val_batches=0.1)

In [None]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

In [None]:
# Integer limits: each epoch runs exactly 10 training batches and 5 validation batches.
trainer = pl.Trainer(max_epochs=5, limit_train_batches=10, limit_val_batches=5)

In [None]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

### Run a Sanity Check

In [None]:
# Run 2 validation batches before the first training epoch, so bugs in the
# validation loop surface early (2 is also Lightning's default).
trainer = pl.Trainer(max_epochs=5, num_sanity_val_steps=2)

In [None]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

### Print LightningModule weights summary

In [None]:
from lightning.pytorch.callbacks import ModelSummary

In [None]:
# The ModelSummary callback with max_depth=-1 lists every nested layer when fit() starts.
trainer = pl.Trainer(
    max_epochs=5,
    limit_train_batches=10,
    limit_val_batches=5,
    callbacks=[ModelSummary(max_depth=-1)],
)

In [None]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

To print the model summary without calling `.fit()`:

In [None]:
# Import under an alias: `lightning.pytorch.callbacks.ModelSummary` (the callback passed
# to the Trainer) has the same name. Rebinding `ModelSummary` to this utility class would
# break a re-run of that Trainer cell, because the utility requires a `model` argument.
from lightning.pytorch.utilities.model_summary import ModelSummary as LayerSummary

summary = LayerSummary(model, max_depth=-1)
print(summary)

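### Print the input and output dimensions of each layer

Setting `self.example_input_array` lets the model summary run a forward pass and report each layer's input and output sizes; this requires the module to define `forward`.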

In [None]:
class ImagenetTransferLearning(pl.LightningModule):
    """Linear classifier on top of a frozen, ImageNet-pretrained ResNet50.

    This version defines ``forward`` and ``example_input_array`` so that a model
    summary can report per-layer input/output sizes.

    Args:
        num_target_classes: number of outputs of the linear head (10 for CIFAR10).
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, num_target_classes=10, *args, **kwargs):
        super().__init__()

        # Dummy batch used only to trace layer shapes for the summary. torch.zeros gives
        # well-defined values (torch.Tensor(*sizes) returns uninitialized memory). The
        # 3x32x32 image size matches the CIFAR10 images this model is trained on.
        self.example_input_array = torch.zeros(32, 3, 32, 32)

        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        # Keep every ResNet50 child except the final fully connected layer.
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # Freeze the backbone explicitly so the summary reports it as non-trainable.
        self.feature_extractor.requires_grad_(False)

        self.classifier = nn.Linear(num_filters, num_target_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_):
        """Return class logits for an image batch ``input_``."""
        # Lightning puts the whole module in train mode for training; switch the frozen
        # backbone back to eval so its BatchNorm layers keep the pretrained statistics.
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(input_).flatten(1)
        return self.classifier(representations)

    def _shared_step(self, batch, stage):
        """Compute the cross-entropy loss on ``batch`` and log it as ``f"{stage}_loss"``."""
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        # Only the linear head is trainable, so only its parameters are optimized.
        return torch.optim.AdamW(self.classifier.parameters(), lr=1e-3)

In [None]:
# Import under an alias so the name `ModelSummary` keeps referring to the
# `lightning.pytorch.callbacks.ModelSummary` callback used when building Trainers.
from lightning.pytorch.utilities.model_summary import ModelSummary as LayerSummary

# Re-create the model from the redefined class so the summary can use its
# example_input_array to report each layer's input/output sizes.
model = ImagenetTransferLearning()
summary = LayerSummary(model, max_depth=-1)
print(summary)