## Set Up

In [1]:
!pip install pytorch-lightning
!pip install pytorch-lightning-bolts
!pip install torch torchvision







In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
from pytorch_lightning.metrics.functional import accuracy

In [3]:
import sys
sys.path.append("..")
from models import *

## Data

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Model

In [7]:
class LitResnet18Adam(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)
        return loss, acc

    def predict_step(self, batch, batch_idx, dataloader_idx):
        x, y = batch
        y_hat = self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.1)
        

## Train

In [9]:
# model
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = ResNet18(1)
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
    
# init model
resnet = LitResnet18Adam(net)

# Initialize a trainer
trainer = pl.Trainer(gpus=0, max_epochs=200, progress_bar_refresh_rate=20)

# Train the model
trainer.fit(resnet, trainloader, testloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torchvision/datasets/mnist.py", line 134, in __getitem__
    img = self.transform(img)
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
    img = t(img)
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 221, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/Users/venkateshw/Work/miniconda3/envs/adaswarm/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 335, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]


## Test

In [None]:
trainer.test(test_dataloaders=test)

## Visualise