## PyTorch Lightning: Image Classification using CIFAR10 and ResNet50

### Load modules

In [1]:
import os
import pdb
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
import torchvision.models as models
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import loggers as pl_loggers
import lightning.pytorch as pl

### Load Dataset

In [2]:
# Download CIFAR10 into the current working directory (skipped when the files
# are already present) and convert every image to a tensor.
data_root = os.getcwd()
to_tensor = transforms.ToTensor()

data_set = CIFAR10(data_root, train=True, download=True, transform=to_tensor)
test_set = CIFAR10(data_root, train=False, download=True, transform=to_tensor)

Files already downloaded and verified
Files already downloaded and verified


### Split Dataset

In [3]:
# Hold out 20% of the official training data for validation (80/20 split).
total_size = len(data_set)
train_set_size = int(total_size * 0.8)
valid_set_size = total_size - train_set_size

# Seeded generator so the split is identical on every run.
seed = torch.Generator()
seed.manual_seed(42)

train_set, valid_set = random_split(
    data_set,
    lengths=[train_set_size, valid_set_size],
    generator=seed,
)

print(f"Train size: {train_set_size}")
print(f"Valid size: {valid_set_size}")

Train size: 40000
Valid size: 10000


### Setup Model

In [4]:
class ImagenetTransferLearning(pl.LightningModule):
    """Linear classifier trained on top of a frozen, pretrained ResNet-50.

    The torchvision ResNet-50 (``weights="DEFAULT"``) is stripped of its final
    fully connected layer and used purely as a feature extractor. Only the new
    ``nn.Linear`` head receives gradients.

    Args:
        num_target_classes: Number of output classes of the linear head.
    """

    def __init__(self, num_target_classes=10):
        super().__init__()

        # Record the constructor arguments as hyperparameters.
        self.save_hyperparameters()

        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        # Keep every child module except the last one (the original ``fc`` head).
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # The backbone only ever runs under ``torch.no_grad()``, so mark its
        # weights as frozen. This keeps parameter counts honest and avoids
        # "unused parameter" problems with distributed strategies.
        self.feature_extractor.requires_grad_(False)

        self.classifier = nn.Linear(num_filters, num_target_classes)
        self.criterion = nn.CrossEntropyLoss()

    def _extract_features(self, x):
        """Run the frozen backbone on ``x`` and flatten the result to 2-D."""
        # Force eval mode on every call: a surrounding ``model.train()`` would
        # otherwise put the backbone's layers back into training behaviour.
        self.feature_extractor.eval()
        with torch.no_grad():
            return self.feature_extractor(x).flatten(1)

    def _shared_step(self, batch, log_name):
        """Compute the cross-entropy loss for ``batch`` and log it per epoch.

        Args:
            batch: An ``(inputs, targets)`` pair.
            log_name: Metric name passed to ``self.log``.

        Returns:
            The loss tensor for the batch.
        """
        x, y = batch
        loss = self.criterion(self(x), y)
        self.log(log_name, loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``valid_loss``)."""
        return self._shared_step(batch, "valid_loss")

    def test_step(self, batch, batch_idx):
        """Return the test loss for one batch (logged as ``test_loss``)."""
        return self._shared_step(batch, "test_loss")

    def forward(self, x):
        """Return the classifier logits for a batch of images ``x``."""
        return self.classifier(self._extract_features(x))

    def configure_optimizers(self):
        """Return an AdamW optimizer with a learning rate of 1e-3.

        All parameters are handed over, but the frozen backbone never receives
        gradients, so only the classifier head is actually updated.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
# CIFAR10 has 10 classes, so the classification head gets 10 outputs.
model = ImagenetTransferLearning(num_target_classes=10)

In [None]:
# Batch size shared by all three loaders.
BATCH_SIZE = 512

# Reshuffle the training samples every epoch so batches differ between epochs;
# the evaluation loaders keep a fixed order.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)

In [None]:
# Take the logger from the same package as the Trainer (`lightning.pytorch`)
# rather than the legacy `pytorch_lightning` package, so that all Lightning
# objects in this notebook share one class hierarchy.
tb_logger = pl.loggers.TensorBoardLogger('resnet50_level6/')

In [None]:
# Train for five epochs, logging to TensorBoard and saving checkpoints.
trainer = pl.Trainer(
    logger=tb_logger,
    default_root_dir="resnet50_level6/",
    enable_checkpointing=True,
    max_epochs=5,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

In [None]:
# Evaluate the trained model on the held-out CIFAR10 test split.
trainer.test(model, dataloaders=test_loader)

### Load Checkpoint

In [5]:
# Fresh model with the pretrained backbone and a 10-class head.
model_resnet = ImagenetTransferLearning(num_target_classes=10)

In [6]:
# `load_from_checkpoint` is a classmethod that builds and returns a new model,
# so call it on the class. Recent Lightning releases raise a TypeError when it
# is invoked on an instance.
model_resnet = ImagenetTransferLearning.load_from_checkpoint(
    "resnet50_level6/lightning_logs/version_0/checkpoints/epoch=4-step=395.ckpt"
)

In [13]:
# Random dummy image batch, moved to whatever device the loaded model lives on
# so the cell also works without a GPU.
x = torch.rand(1, 3, 224, 224, dtype=torch.float).to(model_resnet.device)

In [14]:
# Switch to evaluation mode (disables dropout and similar training-only behaviour).
model_resnet.eval()

# Predict without building an autograd graph: gradients are not needed here.
with torch.no_grad():
    y_hat = model_resnet(x)