## PyTorch Lightning CIFAR10 ~94% Baseline Tutorial
Train a ResNet to ~94% accuracy on CIFAR10!

### Setup


In [None]:
#! pip install pytorch-lightning lightning-bolts -qU

In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn

from torchmetrics.functional import accuracy

In [2]:
# Fix the global random seed so runs are repeatable.
seed_everything(7)

# Dataset location can be overridden through the PATH_DATASETS environment variable.
PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
# Use at most one GPU; 0 means no CUDA device was found.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Larger batches when a GPU is available, smaller ones otherwise.
BATCH_SIZE = 256 if AVAIL_GPUS else 64
# Half of the CPUs for data loading. os.cpu_count() may return None when the
# count cannot be determined, so fall back to 1 instead of raising a TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2

Global seed set to 7


### CIFAR10 Data Module
Import the existing data module from bolts and modify the train and test transforms.

In [3]:
# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip for augmentation, then tensor conversion and normalization.
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

In [4]:
# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

In [5]:
# Data module for CIFAR10; validation reuses the (non-augmenting) test transforms.
cifar10_dm = CIFAR10DataModule(
    data_dir=PATH_DATASETS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)

### ResNet
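Modify the pre-existing `ResNet` architecture from `torchvision`. The stock architecture is designed for ImageNet-sized inputs (224x224), so we adapt it to the much smaller CIFAR10 images (32x32) by replacing the first convolution with a 3x3, stride-1 convolution and removing the initial max-pooling layer.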

In [6]:
def create_model():
    """Build an untrained 10-class ResNet-18 adapted for small input images.

    The first convolution is replaced by a 3x3, stride-1, padding-1 convolution
    without bias, and the max-pooling layer is replaced by an identity.

    Returns:
        The modified `torchvision` ResNet-18 module.
    """
    net = torchvision.models.resnet18(pretrained=False, num_classes=10)

    # Swap the stem convolution and disable the pooling that follows it.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    return net

### Lightning Module
Check out the `configure_optimizers` method to use custom Learning Rate schedulers. 

The `OneCycleLR` with `SGD` will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. 

Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

In [22]:
class LitResNet(LightningModule):
    """LightningModule that trains the network returned by `create_model`.

    Args:
        lr: Learning rate passed to the SGD optimizer; stored in `self.hparams`.
    """

    def __init__(self, lr=0.05):
        super().__init__()

        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        """Return per-class log-probabilities for the input batch `x`."""
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the negative log-likelihood loss for one training batch."""
        x, y = batch
        # Uses self.model directly (not self(x)) so that a subclass overriding
        # `forward` still trains the underlying network.
        log_probs = F.log_softmax(self.model(x), dim=1)
        loss = F.nll_loss(log_probs, y)
        self.log('train_loss', loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for a batch and log them under `<stage>_loss` / `<stage>_acc`.

        Nothing is logged when `stage` is falsy.
        """
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch."""
        self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch."""
        self.evaluate(batch, 'test')

    def configure_optimizers(self):
        """Return SGD with momentum and weight decay plus a per-step OneCycleLR schedule."""
        optimizer = torch.optim.SGD(self.parameters(),
                                    lr=self.hparams.lr,
                                    momentum=0.9,
                                    weight_decay=5e-4)

        # 45000 is assumed to be the size of the training split served by the
        # data module -- TODO confirm if the split is changed.
        num_train_samples = 45000
        # Round up: a trailing partial batch still produces an optimizer step, and
        # OneCycleLR raises a ValueError once it is stepped more often than
        # epochs * steps_per_epoch.
        steps_per_epoch = (num_train_samples + BATCH_SIZE - 1) // BATCH_SIZE
        # NOTE(review): OneCycleLR derives its own schedule from max_lr=0.1, so the
        # optimizer's initial `lr` appears to be overridden -- confirm this is intended.
        scheduler_dict = {
            'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),
            'interval': 'step'
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}

In [23]:
# Instantiate the Lightning module; the bare name on the last line displays its layer summary.
model = LitResNet(lr=0.05)
model

LitResNet(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): Identity()
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (con

In [24]:
# Attach the CIFAR10 data module to the model instance.
model.datamodule = cifar10_dm

In [25]:
# Trainer for the baseline run: logs to TensorBoard under lightning_logs/resnet
# and records the learning rate at every step.
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=30,
    progress_bar_refresh_rate=10,
    logger=TensorBoardLogger('lightning_logs/', name='resnet'),
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
    ],
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [None]:
# Train on the data module, then report metrics on its test set.
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm)


  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 7


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

### Bonus: Use Stochastic Weight Averaging (SWA) to get a boost on performance
Use SWA from `torch.optim` to get a quick performance boost. This section also shows a couple of cool features from Lightning:

- Use `training_epoch_end` to run code after the end of every epoch
- Use a pretrained model directly with this wrapper for SWA

In [None]:
class SWAResNet(LitResNet):
    """Fine-tune an already-trained network while keeping a weight-averaged (SWA) copy.

    `self.model` is the network being optimized; `self.swa_model` is an
    `AveragedModel` wrapped around it and is what `forward` uses.

    Args:
        trained_model: The network to continue training and to average.
        lr: Learning rate for the plain SGD optimizer used during SWA.
    """

    def __init__(self, trained_model, lr=0.01):
        super().__init__()

        self.save_hyperparameters('lr')
        self.model = trained_model
        # Fixed: the class imported from torch.optim.swa_utils is `AveragedModel`.
        self.swa_model = AveragedModel(self.model)

    def forward(self, x):
        """Return per-class log-probabilities computed by the averaged model."""
        out = self.swa_model(x)
        return F.log_softmax(out, dim=1)

    def training_epoch_end(self, training_step_outputs):
        """Fold the current weights of `self.model` into the running average."""
        self.swa_model.update_parameters(self.model)

    def validation_step(self, batch, batch_idx, stage=None):
        """Log validation loss and accuracy of the non-averaged `self.model`."""
        x, y = batch
        log_probs = F.log_softmax(self.model(x), dim=1)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        """Return plain SGD over `self.model`'s parameters, with no LR scheduler.

        Fixed: this hook must be named `configure_optimizers`; under the old name
        `configure_optimizer` it was never picked up, and the parent's
        SGD + OneCycleLR configuration was used instead.
        """
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=self.hparams.lr,
                                    momentum=0.9,
                                    weight_decay=5e-4)
        return optimizer

    def on_train_end(self):
        """Recompute the averaged model's BatchNorm statistics on the training data.

        Relies on a `datamodule` attribute having been assigned to this instance
        before training starts.
        """
        update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)

In [None]:
# Wrap the network trained above so it can be fine-tuned with weight averaging.
swa_model = SWAResNet(model.model, lr=0.01)
# SWAResNet.on_train_end reads the training dataloader from this attribute.
swa_model.datamodule = cifar10_dm

# Separate trainer for the SWA phase; logs go to lightning_logs/swa_resnet.
swa_trainer = Trainer(
    max_epochs=20,
    gpus=AVAIL_GPUS,
    progress_bar_refresh_rate=20,
    logger=TensorBoardLogger('lightning_logs/', name='swa_resnet'),
)

# Fine-tune, then evaluate on the test set (test_step goes through forward, i.e. the averaged model).
swa_trainer.fit(swa_model, cifar10_dm)
swa_trainer.test(swa_model, datamodule=cifar10_dm)

In [None]:
# Start TensorBoard inside the notebook, pointed at the directory both trainers log to.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/