In [1]:
# MNIST ANN Pytorch Lightning practice (2020 by patrick loeber, with error)

In [7]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib as plt

import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning import Trainer


In [8]:
# define hyperparameters
input_size = 784 # image pixel 28 x 28
hidden_size = 100  # the higher the more accurate
output_size = 10  # number of classes
epochs = 2
batch_size = 100
learning_rate = 0.01

In [9]:
class LitAnn(pl.LightningModule):
    def __init__(self, input_size, hidden_size, output_size):
        super(LitAnn, self).__init__()
        self.input_size = input_size
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        return out

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        images, labels = batch
        images = images.reshape(-1, 28*28)
        
        #forward pass
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}  # can add tensorboard_logs
     
    
    #load training data
    def train_dataloader(self):
        train_dataset = torchvision.datasets.MNIST(root="./data", 
                                                   train=True, 
                                                   transform=transforms.ToTensor(), 
                                                   download=True)
        
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                                   batch_size=batch_size,
                                                   num_workers=4, # increase training speed
                                                   shuffle=True)
        return train_loader

    # test validation       
    def validation_step(self, batch, batch_idx):
        images, labels = batch
        images = images.reshape(-1, 28*28)
        
        #forward pass
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        return {'val_loss': loss}

        #validation dataset, used to fine tune the hyper parameters of the model before test.
    def val_dataloader(self):
        val_dataset = torchvision.datasets.MNIST(root="./data", 
                                                   train=False, 
                                                   transform=transforms.ToTensor())
        
        val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 
                                                   batch_size=batch_size,
                                                   num_workers=4,
                                                   shuffle=False)
        return val_loader
    
    def on_validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#       tensorboard_logs = {'val_loss': avg_loss}
        return{'val_loss': avg_loss} #error?
#[docs]    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the val epoch ends."""
        
    def configure_optimizers(self):
        return torch.optim.Adam(model.parameters(), lr=learning_rate)
    
if __name__ == '__main__':
    
    # inside Trainer, can add fast_dev_run=True: send single batch to test if the model works
    # to find the best learning rate : auto_lr_find=True in Trainer
    trainer = Trainer(max_epochs=epochs) # use gpu by gpu= 1 or 2 etc. or ddp for distributed backed
    model = LitAnn(input_size, hidden_size, output_size)
    trainer.fit(model)  
    trainer.test()


    

        

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 78.5 K
1 | relu | ReLU   | 0     
2 | l2   | Linear | 1.0 K 
--------------------------------
79.5 K    Trainable params
0         Non-trainable params
79.5 K    Total params
0.318     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

TypeError: on_validation_epoch_end() missing 1 required positional argument: 'outputs'