In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [2]:
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning import Trainer

In [11]:
# Hyperparameters
input_size = 784  # 28x28 MNIST images, flattened
hidden_size = 500
num_classes = 10
num_epochs = 2
batch_size = 100
learning_rate = 0.001

# Seed torch's RNG (weight initialisation and DataLoader shuffling) so that
# Restart-and-Run-All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)


class LitNeuralNet(pl.LightningModule):
    """Two-layer fully connected MNIST classifier (Linear -> ReLU -> Linear).

    Parameters
    ----------
    input_size : int
        Number of input features per sample; image batches are flattened to
        ``(-1, input_size)`` before the first layer.
    hidden_size : int
        Width of the hidden layer.
    num_classes : int
        Number of output logits.
    learning_rate : float, default 0.001
        Adam learning rate. Stored as ``self.learning_rate`` so the Lightning
        learning-rate finder (``auto_lr_find``) can overwrite it.
    batch_size : int, default 100
        Batch size used by both the train and validation dataloaders.
    """

    def __init__(self, input_size, hidden_size, num_classes,
                 learning_rate=0.001, batch_size=100):
        super().__init__()
        self.input_size = input_size
        # Hyperparameters live on the instance rather than being read from
        # notebook globals, so the model is self-contained and tunable.
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return unnormalised class logits for a flattened batch ``x``."""
        # in lightning, forward defines the prediction/inference actions
        out = self.l1(x)
        out = self.relu(out)
        return self.l2(out)

    def _compute_loss(self, batch):
        """Flatten the images in an ``(images, labels)`` batch and return the cross-entropy loss."""
        images, labels = batch
        images = images.reshape(-1, self.input_size)
        outputs = self(images)
        return F.cross_entropy(outputs, labels)

    def _mnist_loader(self, train, shuffle, num_workers=0):
        """Build a DataLoader over the MNIST train (``train=True``) or test split."""
        # download=True on both splits: the validation loader may be requested
        # before the train loader, so it must not rely on a prior download.
        dataset = torchvision.datasets.MNIST(root='./data',
                                             train=train,
                                             transform=transforms.ToTensor(),
                                             download=True)
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=self.batch_size,
                                           num_workers=num_workers,
                                           shuffle=shuffle)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        # Lightning 1.x does not log a returned {'log': ...} dict; self.log()
        # is what actually reaches the logger (TensorBoard by default).
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters, using the (possibly tuned) ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        """Shuffled loader over the MNIST training split."""
        return self._mnist_loader(train=True, shuffle=True)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch; averaged in validation_epoch_end."""
        return {'val_loss': self._compute_loss(batch)}

    def val_dataloader(self):
        """Unshuffled loader over the MNIST test split, used for validation."""
        return self._mnist_loader(train=False, shuffle=False, num_workers=4)

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the result."""
        if not outputs:
            # torch.stack([]) would raise; nothing to aggregate.
            return
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # 'val_loss' is the callback-visible metric the original returned;
        # 'avg_val_loss' is the TensorBoard key it intended to log.
        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('avg_val_loss', avg_loss)


# fast_dev_run=True is a smoke test: it runs a single train/val batch and
# overrides max_epochs. Set it to False for a real num_epochs training run.
trainer = Trainer(auto_lr_find=True, max_epochs=num_epochs, fast_dev_run=True)
model   = LitNeuralNet(input_size, hidden_size, num_classes,
                       learning_rate=learning_rate, batch_size=batch_size)
# auto_lr_find only takes effect through trainer.tune(); fit() alone ignores it.
# Lightning skips the LR finder (with a warning) while fast_dev_run is enabled.
trainer.tune(model)
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 392 K 
1 | relu | ReLU   | 0     
2 | l2   | Linear | 5.0 K 
--------------------------------
397 K     Trainable params
0         Non-trainable params
397 K     Total params
1.590     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]