## PyTorch Lightning: Image Classification on CIFAR-10 with ResNet-50

### Load modules

In [1]:
import os
from argparse import ArgumentParser
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
import torchvision.models as models
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import loggers as pl_loggers
import lightning.pytorch as pl

### Load Dataset

In [None]:
# The backbone uses torchvision's pretrained ResNet-50 weights, which were trained
# on images normalized with the ImageNet per-channel mean/std below. Applying the
# same normalization keeps CIFAR-10 inputs in the range the pretrained features expect.
imagenet_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
data_set = CIFAR10(os.getcwd(), download=True, train=True, transform=imagenet_transform)
test_set = CIFAR10(os.getcwd(), download=True, train=False, transform=imagenet_transform)

### Split Dataset into Training and Validation Sets

In [None]:
# Hold out 20% of the CIFAR-10 training images for validation. The generator is
# seeded so the same images land in each subset on every run.
train_set_size = int(0.8 * len(data_set))
valid_set_size = len(data_set) - train_set_size

seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(
    data_set,
    lengths=[train_set_size, valid_set_size],
    generator=seed,
)

print(f"Train size: {train_set_size}\nValid size: {valid_set_size}")

### Set Up Model

In [None]:
class ImagenetTransferLearning(pl.LightningModule):
    """Linear classifier trained on top of a frozen, ImageNet-pretrained ResNet-50.

    Every child module of torchvision's ResNet-50 except its final ``fc`` layer
    is reused as a fixed feature extractor. Only the new linear head
    (``self.classifier``) is optimized.

    Args:
        num_target_classes: Number of output classes of the linear head
            (10 for CIFAR-10).
        learning_rate: Learning rate of the AdamW optimizer for the head.
    """

    def __init__(self, num_target_classes=10, learning_rate=1e-3):
        super().__init__()

        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        # Keep every child module except the original ImageNet ``fc`` layer.
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # The backbone is never trained. Freezing its parameters makes that
        # explicit and ensures no gradients are ever tracked for them.
        self.feature_extractor.requires_grad_(False)

        self.classifier = nn.Linear(num_filters, num_target_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def _extract_features(self, x):
        """Return flattened backbone features for the image batch ``x``."""
        # Lightning puts the whole module into train mode for training, so the
        # frozen backbone is forced back to eval mode on every call. This keeps
        # its normalization layers from updating running statistics on CIFAR batches.
        self.feature_extractor.eval()
        with torch.no_grad():
            return self.feature_extractor(x).flatten(1)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.classifier(self._extract_features(x))

    def _shared_step(self, batch, stage):
        """Compute the loss for one ``(x, y)`` batch and log loss and accuracy.

        Metrics are logged per epoch as ``f"{stage}_loss"`` and ``f"{stage}_acc"``.
        Returns the cross-entropy loss tensor.
        """
        x, y = batch
        y_pred = self(x)

        loss = self.criterion(y_pred, y)
        acc = (y_pred.argmax(dim=1) == y).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{stage}_acc", acc, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        # Only the linear head is trainable, so only its parameters are optimized.
        return torch.optim.AdamW(self.classifier.parameters(), lr=self.learning_rate)

In [None]:
# CIFAR-10 has ten classes, which is also the constructor's default.
model = ImagenetTransferLearning(num_target_classes=10)

In [None]:
BATCH_SIZE = 512

# Shuffle the training data so every epoch visits the samples in a new order.
# Validation and test order does not change the aggregated metrics, so those
# loaders stay deterministic.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)

In [None]:
# Take the TensorBoard logger from `lightning.pytorch`, the same package that
# provides `pl.Trainer` and `pl.LightningModule`, instead of the separate
# `pytorch_lightning` distribution. Logger and Trainer then come from a single
# Lightning installation.
from lightning.pytorch.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger('cifar10_logs/')

In [None]:
# Train for 5 epochs and send metrics to the TensorBoard logger.
# enable_checkpointing=True lets Lightning attach its default checkpoint callback.
trainer_kwargs = dict(
    max_epochs=5,
    default_root_dir="resnet50/",
    enable_checkpointing=True,
    logger=tb_logger,
)
trainer = pl.Trainer(**trainer_kwargs)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

In [None]:
trainer.test(model=model, dataloaders=test_loader)

### Add an Argument Parser for Running as a `.py` Script

In [None]:
# Command-line parser used when this notebook is exported and run as a script.
parser = ArgumentParser(add_help=True)

In [None]:
# Trainer arguments
parser.add_argument("--devices", type=int, default=2,
                    help="Number of devices passed to pl.Trainer(devices=...).")

# Hyperparameters for the model
parser.add_argument("--num_target_classes", type=int, default=10,
                    help="Number of output classes of the linear classifier head.")
# NOTE: ImagenetTransferLearning has no layer_1_dim hyperparameter. The option is
# kept only so that existing command lines passing it still parse.
parser.add_argument("--layer_1_dim", type=int, default=128,
                    help="Unused; kept for backward compatibility.")

In [None]:
# Parse the user inputs and defaults (returns an argparse.Namespace).
# parse_known_args() tolerates arguments this parser does not define, such as the
# connection-file argument (`-f ...`) that a Jupyter/IPython kernel is started
# with. This lets the cell run inside a notebook as well as from the command line.
# Unrecognized arguments are reported rather than silently dropped.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print(f"Ignoring unrecognized arguments: {_unknown_args}")

In [None]:
# Use the parsed arguments in your program.
# `Trainer` is not imported by name in this notebook; it is reached through the
# `lightning.pytorch` alias `pl`.
trainer = pl.Trainer(devices=args.devices)
# ImagenetTransferLearning's only model-size hyperparameter is num_target_classes.
# Fall back to the CIFAR-10 class count if the parser does not define that option.
model = ImagenetTransferLearning(num_target_classes=getattr(args, "num_target_classes", 10))