### Environment

Pytorch and pytorch-lightning must be installed. Use:

```
$ conda install pytorch torchvision torchaudio -c pytorch
$ conda install -c conda-forge pytorch-lightning
```


In [1]:
# Full documentation: https://pytorch-lightning.readthedocs.io/en/latest/
# Code adapted from https://learnopencv.com/getting-started-with-pytorch-lightning/
# and https://learnopencv.com/tensorboard-with-pytorch-lightning/

# Import requirements
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import os

### Defining the model object

Pytorch-lightning will handle everything we need to do, we just need to define the functions.

Below, we define the following *required* functions just as we would using regular pytorch:

- prepare_data(): For one-time operations (downloading mnist)
- train_dataloader(): Loading train/validation data into dataloader objects
- forward(): Standard forward pass definition
- configure_optimizers(): Instantiate optimizer
- training_step(): Called on every batch, feeds data through the model

And the following optional functions for logging and visualization:

- custom_histogram_adder(): Visualize the changes in the weight distributions v epoch
- training_epoch_end(): Called after each epoch completes, primarily for logging

We get started by inheriting from `pl.LightningModule`:



In [2]:
class lightningModel(pl.LightningModule):
    
    # Sample model definition - same as standard pytorch
    def __init__(self):
        super(lightningModel, self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1,28,kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(28,10,kernel_size=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.dropout1=torch.nn.Dropout(0.25)
        self.fc1=torch.nn.Linear(250,18)
        self.dropout2=torch.nn.Dropout(0.08)
        self.fc2=torch.nn.Linear(18,10)

 
    # One-time operations like downloading data, etc.
    def prepare_data(self):
        # The following will download raw MNIST data into ./MNIST
        MNIST(os.getcwd(), train = True, download = True)
        MNIST(os.getcwd(), train = False, download = True)
    
    # REQUIRED FUNCTION
    # Load and split dataset into train/val/test sets
    def train_dataloader(self):
        mnist_train = MNIST(os.getcwd(), train = True, download = False,transform = transforms.ToTensor())
        self.train_set, self.val_set = random_split(mnist_train,[55000,5000])
        return DataLoader(self.train_set,batch_size = 128)
    
    # OPTIONAL FUNCTION
    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size = 128)

    # OPTIONAL FUNCTION
    def test_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), 
                          train = False,
                          download = False,
                          transform = transforms.ToTensor()),
                          batch_size = 128)

    # REQUIRED FUNCTION
    # Forward pass - same as standard pytorch
    def forward(self,x):
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.dropout1(x)
          x = torch.relu(self.fc1(x.view(x.size(0), -1)))
          x = F.leaky_relu(self.dropout2(x))
          
          return F.softmax(self.fc2(x), dim=1)

    # REQUIRED FUNCTION
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    # REQUIRED FUNCTION
    # This is called for every batch in the training dataset
    def training_step(self,batch,batch_idx):

        # Standard forward pass
        x, labels = batch
        pred = self.forward(x)
        train_loss = F.cross_entropy(pred, labels)
        
        # identifying number of correct predections in a given batch
        correct = pred.argmax(dim = 1).eq(labels).sum().item()

        # identifying total number of labels in a given batch
        total = len(labels)


        
        # Log dictionary 
        logs = {"train_loss": train_loss}

        batch_dictionary={
            #REQUIRED: must at minimum return the loss
            "loss": train_loss,
            
            # Optional for batch logging purposes
            "log": logs,

            # To be used for logging at the end of each epoch
            "correct": correct,
            "total": total
        }

        return batch_dictionary

    # OPTIONAL FUNCTION - Add custom histogram of the weights
    def custom_histogram_adder(self):
       
        # Iterate through all parameters, log histogram of weights
        for name,params in self.named_parameters():
            self.logger.experiment.add_histogram(name,params,self.current_epoch)

            
    # OPTIONAL FUNCTION - Called after every epoch is completed
    def training_epoch_end(self,outputs):

        # Calculating average loss and accuracy
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        correct = sum([x["correct"] for  x in outputs])
        total = sum([x["total"] for  x in outputs])
        
        # Log custom scalar values
        self.logger.experiment.add_scalar("Loss/Train",
                                            avg_loss,
                                            self.current_epoch)

        self.logger.experiment.add_scalar("Accuracy/Train",
                                            correct/total,
                                            self.current_epoch) 
        
        # Only log the graph once
        if(self.current_epoch == 1):
            sampleImg = torch.rand((1,1,28,28))
            self.logger.experiment.add_graph(lightningModel(), sampleImg)  # Returns a SummaryWriter object

        # Add custom histogram using the custom_histogram_adder() function defined above
        self.custom_histogram_adder()
        
        # creating log dictionary
        tensorboard_logs = {'loss': avg_loss,"Accuracy": correct/total}

        epoch_dictionary = {
            # Required
            'loss': avg_loss,
            'log': tensorboard_logs}

        return epoch_dictionary

### Using the Lightning Trainer

Once the `model` class and all required functions have been implemented, the Lightning trainer takes over execution. 

The trainer handles just about everything we could want, including:

- Automatically enabling/disabling grads
- Running the training, validation and test dataloaders
- Calling the Callbacks at the appropriate times
- Putting batches and computations on the correct devices
- Saving, logging and checkpointing the models
- Gracefully shutting down after an abort
- Even handling distributed execution over [multiple GPUs on multiple machines](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#trainer-flags)

Implementing all this can be done in four lines:

In [5]:
# To use pytorch-lightning logs, we just need to pass a logger object to the Trainer
logger = TensorBoardLogger('tb_logs', name='model_run')

# Instantiate the trainer, which is highly customizable: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html
Trainer = pl.Trainer(max_epochs = 3, logger = logger)

# Instantiate the model
model = lightningModel()

# Let pytorch-lighting manage execution
Trainer.fit(model)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name     | Type       | Params
----------------------------------------
0 | layer1   | Sequential | 728   
1 | layer2   | Sequential | 1.1 K 
2 | dropout1 | Dropout    | 0     
3 | fc1      | Linear     | 4.5 K 
4 | dropout2 | Dropout    | 0     
5 | fc2      | Linear     | 190   
----------------------------------------
6.6 K     Trainable params
0         Non-trainable params
6.6 K     Total params


Epoch 2: 100%|██████████| 430/430 [00:25<00:00, 17.08it/s, loss=1.55, v_num=26]


1

### Visualization with Tensorboard

Finally, we can visualize all the logging via tensorboard just as we would with standard pytorch:


In [6]:
# Load tensorboard - this may look different on your machine
%load_ext tensorboard
%tensorboard --logdir tb_logs/model_run

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 15186), started 1:00:07 ago. (Use '!kill 15186' to kill it.)