## Set Up

In [None]:
!pip install pytorch-lightning
!pip install pytorch-lightning-bolts
!pip install torch torchvision

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
from pytorch_lightning.metrics.functional import accuracy

In [None]:
# Put the parent directory on the module search path so the `models` package
# that lives one level above this notebook can be imported.
import sys
sys.path.append("..")
# NOTE(review): the star import is presumably what provides `ResNet18`, which
# the Train cell uses — confirm against the `models` package.
from models import *

## Data

In [None]:
# Per-channel mean / std of the MNIST training images (one grayscale channel).
# These replace CIFAR-10 statistics that had been carried over by mistake.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training-time augmentation + normalisation.
transform_train = transforms.Compose([
    # Zero-pad by 4 px on every side, then take a random 28x28 crop: this
    # yields small random translations while keeping the native MNIST size.
    transforms.RandomCrop(28, padding=4),
    # Deliberately no RandomHorizontalFlip: mirrored digits are not valid
    # samples of their class and never occur in the test split.
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Evaluation uses the same normalisation but no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Training split: downloaded to ./data on first use, shuffled every epoch.
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

# Test split: fixed order so evaluation runs are reproducible.
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Human-readable label names, indexed by the integer target (MNIST targets
# are the digits 0-9).
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

## Model

In [None]:
class LitResnet18Adam(pl.LightningModule):
    """LightningModule that trains a wrapped classifier with Adam.

    The wrapped ``model`` is called directly on the input batch and its
    output is fed to ``F.cross_entropy``, so it is expected to produce
    unnormalised class scores (one row per sample).

    Args:
        model: The network to optimise; its ``parameters()`` are handed to
            the optimizer.
        lr: Learning rate for Adam. Defaults to 0.1 to keep the behaviour
            of existing callers unchanged.
            NOTE(review): this is far above ``torch.optim.Adam``'s own
            default of 1e-3 — confirm it is intentional.
    """

    def __init__(self, model, lr=0.1):
        super().__init__()
        self.model = model
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss of one ``(inputs, targets)`` batch."""
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return ``val_acc`` / ``val_loss`` for one batch."""
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        """Log and return ``test_acc`` / ``test_loss`` for one batch."""
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        """Compute ``(loss, accuracy)`` for one batch.

        Shared by validation and test so both report identical metrics.
        """
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)
        return loss, acc

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the model output for the inputs of one batch.

        The batch is unpacked as ``(inputs, targets)``; the targets are
        ignored. ``dataloader_idx`` defaults to ``None`` so the hook also
        works when it is invoked without that argument.
        """
        x, _ = batch
        # The result must be returned, otherwise prediction yields None.
        return self.model(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
        

## Train

In [None]:
# model
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): the argument is presumably the number of input channels
# (1 for grayscale MNIST) — confirm against models.ResNet18.
net = ResNet18(1)
if device == 'cuda':
    # Let cuDNN auto-tune its convolution algorithms.
    cudnn.benchmark = True

# init model
# The raw network is wrapped as-is: device placement is left to the Trainer.
# Moving it to CUDA and wrapping it in DataParallel by hand conflicts with a
# Trainer that is configured for CPU and leads to device mismatches.
resnet = LitResnet18Adam(net)

# Initialize a trainer
# Use one GPU when CUDA is available, otherwise stay on CPU.
trainer = pl.Trainer(
    gpus=1 if device == 'cuda' else 0,
    max_epochs=200,
    progress_bar_refresh_rate=20,
)

# Train the model
# The third argument is the validation loader; here it is the test split's
# loader, so the logged val_* metrics are computed on the test split.
trainer.fit(resnet, trainloader, testloader)

## Test

In [None]:
# Run the test loop over `testloader`, the MNIST test-split loader built in
# the Data section.
trainer.test(test_dataloaders=testloader)

## Visualise