In [1]:
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

In [2]:
class ConvNet(pl.LightningModule):

    def __init__(self):
        super(ConvNet, self).__init__()
        self.cn1 = nn.Conv2d(1, 16, 3, 1)
        self.cn2 = nn.Conv2d(16, 32, 3, 1)
        self.dp1 = nn.Dropout2d(0.10)
        self.dp2 = nn.Dropout2d(0.25)
        self.fc1 = nn.Linear(4608, 64) # 4608 is basically 12 X 12 X 32
        self.fc2 = nn.Linear(64, 10)
 
    def forward(self, x):
        x = self.cn1(x)
        x = F.relu(x)
        x = self.cn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dp2(x)
        x = self.fc2(x)
        op = F.log_softmax(x, dim=1)
        return op

    def training_step(self, batch, batch_num):
        train_x, train_y = batch
        y_pred = self(train_x)
        training_loss = F.cross_entropy(y_pred, train_y)
        # optional
        self.log('train_loss', training_loss, on_step=True, on_epoch=True, prog_bar=True)
        return training_loss

#     def validation_step(self, batch, batch_num):
#         # optional
#         val_x, val_y = batch
#         y_pred = self(val_x)
#         val_loss = F.cross_entropy(y_pred, val_y)
#         # optional
#         self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True)
#         return val_loss

#     def validation_epoch_end(self, outputs):
#         # optional
#         avg_loss = torch.stack(outputs).mean()
#         self.log('val_avg_loss', avg_loss, on_epoch=True, prog_bar=True)
#         return avg_loss

    def test_step(self, batch, batch_num):
        # optional
        test_x, test_y = batch
        y_pred = self(test_x)
        test_loss = F.cross_entropy(y_pred, test_y)
        # optional
        self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True)
        return test_loss

    def test_epoch_end(self, outputs):
        # optional
        avg_loss = torch.stack(outputs).mean()
        self.log('test_avg_loss', avg_loss, on_epoch=True, prog_bar=True)
        return avg_loss

    def configure_optimizers(self):
        return torch.optim.Adadelta(self.parameters(), lr=0.5)

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, 
                                transform=transforms.Compose([transforms.ToTensor(),
                                                              transforms.Normalize((0.1302,), (0.3069,))])), 
                                batch_size=32, num_workers=4)

    def val_dataloader(self):
        # optional
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, 
                                transform=transforms.Compose([transforms.ToTensor(),
                                                              transforms.Normalize((0.1302,), (0.3069,))])), 
                                batch_size=32, num_workers=4)

    def test_dataloader(self):
        # optional
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, 
                                transform=transforms.Compose([transforms.ToTensor(),
                                                              transforms.Normalize((0.1302,), (0.3069,))])), 
                                batch_size=32, num_workers=4)

In [3]:
model = ConvNet()

# most basic trainer, uses good defaults (1 gpu)
trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=10)    
trainer.fit(model)   

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type      | Params
-----------------------------------
0 | cn1  | Conv2d    | 160   
1 | cn2  | Conv2d    | 4 K   
2 | dp1  | Dropout2d | 0     
3 | dp2  | Dropout2d | 0     
4 | fc1  | Linear    | 294 K 
5 | fc2  | Linear    | 650   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [4]:
trainer.test()

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_avg_loss': tensor(0.0430),
 'test_loss': tensor(8.1433e-06),
 'test_loss_epoch': tensor(0.0430),
 'train_loss': tensor(0.0542),
 'train_loss_epoch': tensor(0.0327),
 'train_loss_step': tensor(0.0542),
 'val_loss': tensor(5.6977e-05),
 'val_loss_epoch': tensor(0.0167)}
--------------------------------------------------------------------------------





[{'train_loss_step': 0.05422795191407204,
  'train_loss': 0.05422795191407204,
  'val_loss_epoch': 0.016706781461834908,
  'val_loss': 5.697713640984148e-05,
  'train_loss_epoch': 0.03269123658537865,
  'test_avg_loss': 0.042974330484867096,
  'test_loss_epoch': 0.04304307699203491,
  'test_loss': 8.14332815934904e-06}]

In [5]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

  from IPython.utils import traitlets as _traitlets


Reusing TensorBoard on port 6006 (pid 54047), started 1 day, 21:21:12 ago. (Use '!kill 54047' to kill it.)

<IPython.core.display.Javascript object>