In [None]:
! pip install pytorch-lightning --quiet

[K     |████████████████████████████████| 563kB 2.8MB/s 
[K     |████████████████████████████████| 276kB 13.2MB/s 
[K     |████████████████████████████████| 92kB 7.7MB/s 
[K     |████████████████████████████████| 829kB 14.1MB/s 
[?25h  Building wheel for PyYAML (setup.py) ... [?25l[?25hdone
  Building wheel for future (setup.py) ... [?25l[?25hdone


In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy 
from pytorch_lightning.metrics.functional.classification import to_categorical 
from torch.utils.data import random_split
from sklearn.metrics import classification_report , confusion_matrix

In [None]:
class MnistClassifier(pl.LightningModule):
    """Two-layer MLP that maps a flattened 28x28 image to 10 class logits.

    Args:
        hidden_dim: width of the hidden layer (default 128, as originally hard-coded).
        lr: learning rate for the Adam optimizer (default 1e-3, as originally hard-coded).
    """

    def __init__(self, hidden_dim=128, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.layer_1 = nn.Linear(28 * 28, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        """Flatten each sample to 784 features and return raw logits of shape (batch, 10)."""
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        return self.layer_2(x)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Forward one (inputs, targets) batch; return (loss, accuracy, predicted labels, targets)."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # argmax over the class dimension is exactly what Accuracy()/to_categorical did for
        # logits. Computing it directly avoids building a new metric object per step and
        # the copy-construct warning from torch.tensor(<tensor>).
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()
        return loss, acc, preds, y

    def training_step(self, batch, batch_idx):
        """Return the training loss plus batch accuracy for the progress bar."""
        loss, acc, _, _ = self._shared_step(batch)
        return {'loss': loss, 'progress_bar': {'training_acc': acc}}

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and batch accuracy; aggregated in validation_epoch_end."""
        loss, acc, _, _ = self._shared_step(batch)
        return {'loss': loss, 'progress_bar': {'val_acc': acc}}

    def validation_epoch_end(self, val_step_outputs):
        """Average per-batch validation loss/accuracy and report them as val_loss / val_acc.

        NOTE: this is an unweighted mean of per-batch values, so a smaller final batch
        counts as much as a full one.
        """
        avg_loss = torch.stack([out['loss'] for out in val_step_outputs]).mean()
        avg_acc = torch.stack([out['progress_bar']['val_acc'] for out in val_step_outputs]).mean()
        metrics = {'val_loss': avg_loss, 'val_acc': avg_acc}
        return {
            'progress_bar': metrics,
            'log': dict(metrics),
        }

    def test_step(self, batch, batch_idx):
        """Return test loss/accuracy and print sklearn's classification report and confusion matrix.

        NOTE: the report is printed once per test batch. It describes the whole test set
        only when the test loader yields a single batch, as MNISTDataModule.test_dataloader does.
        """
        loss, acc, preds, y = self._shared_step(batch)
        y_true = y.cpu().numpy()
        y_pred = preds.cpu().numpy()
        print(classification_report(y_true, y_pred))
        print(confusion_matrix(y_true, y_pred))
        return {'loss': loss, 'progress_bar': {'test_acc': acc}}
    

In [None]:
class MNISTDataModule(pl.LightningDataModule):
    """MNIST data: a random train/val split of the official training set plus the official test set.

    Images are converted with ToTensor and normalised with mean 0.1307 / std 0.3081.

    Args:
        batch_size: batch size for the train and validation loaders.
        data_dir: directory MNIST is downloaded to / read from (default: current working directory).
        val_size: number of training images held out for validation (default 5000, i.e. a 55000/5000 split).
    """

    def __init__(self, batch_size=64, data_dir=None, val_size=5000):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = os.getcwd() if data_dir is None else data_dir
        self.val_size = val_size

    def prepare_data(self):
        """Download both MNIST splits.

        Lightning may run prepare_data on a single process only, so no state is assigned
        here. Datasets are built in setup().
        """
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets for ``stage`` ('fit', 'validate', 'test' or None for all)."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # The train/val split is only needed for fitting/validation. Skipping it for the
        # test stage avoids reloading the training set and re-drawing a different split.
        if stage in (None, 'fit', 'validate'):
            mnist_full = MNIST(self.data_dir, train=True, transform=transform)
            n_train = len(mnist_full) - self.val_size
            self.mnist_train, self.mnist_val = random_split(mnist_full, [n_train, self.val_size])
        # The test set is cheap to build and is created on every call, so test_dataloader
        # works whichever stage triggered setup.
        self.mnist_test = MNIST(self.data_dir, train=False, transform=transform)

    def train_dataloader(self):
        # Reshuffle every epoch; without shuffle=True each epoch sees the same fixed order.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        # Deliberately one batch holding the whole test set, so the report printed in
        # MnistClassifier.test_step covers every test image at once.
        return DataLoader(self.mnist_test, batch_size=len(self.mnist_test))
    

In [None]:
SEED = 1234
BATCH_SIZE = 32
MAX_EPOCHS = 20

pl.seed_everything(SEED)

dm = MNISTDataModule(batch_size=BATCH_SIZE)
model = MnistClassifier()

# Use the GPU when one is present so the notebook also runs on a CPU-only runtime
# (gpus=1 raises when no CUDA device is available).
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=MAX_EPOCHS,
    progress_bar_refresh_rate=50,
)
trainer.fit(model, dm)
# Pass the datamodule explicitly instead of relying on the one fit() attached to the trainer.
trainer.test(model, datamodule=dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 1 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



Please use self.log(...) inside the lightningModule instead.

# log on a step or aggregate epoch metric to the logger and/or progress bar
# (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
Please use self.log(...) inside the lightningModule instead.

# log on a step or aggregate epoch metric to the logger and/or progress bar
# (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

              precision    recall  f1-score   support

           0       0.99      0.98      0.98       980
           1       0.98      0.99      0.99      1135
           2       0.98      0.97      0.98      1032
           3       0.99      0.94      0.96      1010
           4       0.96      0.98      0.97       982
           5       0.93      0.99      0.96       892
           6       0.97      0.97      0.97       958
           7       0.97      0.98      0.98      1028
           8       0.98      0.96      0.97       974
           9       0.97      0.96      0.96      1009

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000

[[ 963    1    0    0    1    6    4    1    1    3]
 [   0 1129    2    0    0    0    2    1    1    0]
 [   3    1 1005    3    5    0    2    6    6    1]
 [   2    2    4  951    1   27    0    7    4   12]
 [   1    1    2    1  962   



[{'loss': 0.19195282459259033, 'test_acc': 0.9732000231742859}]