### Environment

PyTorch, torchvision and PyTorch Lightning must be installed. With conda, use:

```
$ conda install pytorch torchvision torchaudio -c pytorch
$ conda install -c conda-forge pytorch-lightning
```


In [1]:

# Code adapted from https://learnopencv.com/getting-started-with-pytorch-lightning/

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import os


### Defining the model object

PyTorch Lightning runs the training loop for us; we only need to define the methods (hooks) that it calls.


In [2]:
class lightningModel(pl.LightningModule):
    """Small convolutional MNIST classifier packaged as a LightningModule.

    The network is two conv/ReLU/max-pool stages followed by two linear
    layers with dropout.  ``forward`` returns class probabilities; the
    training loss is computed from the raw logits so that it is a proper
    cross-entropy (``F.nll_loss`` on softmax probabilities is not).
    """

    # Model definition - same as standard pytorch
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 28, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(28, 10, kernel_size=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.dropout1 = torch.nn.Dropout(0.25)
        # For 1x28x28 inputs the conv stack yields 10 channels of 5x5,
        # i.e. 250 flattened features.
        self.fc1 = torch.nn.Linear(250, 18)
        self.dropout2 = torch.nn.Dropout(0.08)
        self.fc2 = torch.nn.Linear(18, 10)

        # Filled in once by _split_train_val(); kept on the instance so the
        # train and validation loaders always see the same, disjoint split.
        self.train_set = None
        self.val_set = None

    # One-time operations like downloading data, etc.
    def prepare_data(self):
        """Download the MNIST train and test files into the working directory."""
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    def _split_train_val(self):
        """Split the MNIST training data 55000/5000, at most once per instance.

        Doing the split lazily here (instead of inside ``train_dataloader``)
        means ``val_dataloader`` works regardless of call order, and repeated
        ``train_dataloader`` calls cannot reshuffle validation samples into
        the training subset.
        """
        if self.train_set is None or self.val_set is None:
            mnist_train = MNIST(os.getcwd(), train=True, download=False,
                                transform=transforms.ToTensor())
            self.train_set, self.val_set = random_split(mnist_train,
                                                        [55000, 5000])

    # REQUIRED FUNCTION
    # Load and split dataset into train/val/test sets
    def train_dataloader(self):
        """Return a DataLoader over the training subset (batch size 128)."""
        self._split_train_val()
        return DataLoader(self.train_set, batch_size=128)

    # OPTIONAL FUNCTION
    def val_dataloader(self):
        """Return a DataLoader over the validation subset (batch size 128)."""
        self._split_train_val()
        return DataLoader(self.val_set, batch_size=128)

    # OPTIONAL FUNCTION
    def test_dataloader(self):
        """Return a DataLoader over the MNIST test set (batch size 128)."""
        mnist_test = MNIST(os.getcwd(),
                           train=False,
                           download=False,
                           transform=transforms.ToTensor())
        return DataLoader(mnist_test, batch_size=128)

    def _logits(self, x):
        """Run the network and return the unnormalised class scores."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.dropout1(x)
        # Flatten everything except the batch dimension before the linear layers.
        x = torch.relu(self.fc1(x.view(x.size(0), -1)))
        # Kept from the original architecture; its input is already
        # non-negative (ReLU followed by dropout), so it passes values through.
        x = F.leaky_relu(self.dropout2(x))
        return self.fc2(x)

    # REQUIRED FUNCTION
    # Forward pass - same as standard pytorch
    def forward(self, x):
        """Return per-class probabilities (softmax over dim 1) for a batch."""
        return F.softmax(self._logits(x), dim=1)

    # REQUIRED FUNCTION
    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, with default settings."""
        return torch.optim.Adam(self.parameters())

    # REQUIRED FUNCTION
    # This is called for every batch in the training dataset
    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch.

        Args:
            batch: An ``(inputs, labels)`` pair from the training DataLoader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            A dict whose ``"loss"`` entry is the tensor Lightning backpropagates.
        """
        x, labels = batch

        # cross_entropy applies log_softmax itself, so it must be fed logits,
        # not the probabilities that forward() returns.
        loss = F.cross_entropy(self._logits(x), labels)

        # self.log is the supported logging API; the old "log" key in the
        # returned dict is no longer interpreted by Lightning.
        self.log("train_loss", loss)

        return {"loss": loss}


### Using the Lightning Trainer

Once the `model` class and all required functions have been implemented, the Lightning trainer takes over execution. 

The trainer handles just about everything we could want, including:

- Ensuring input data type matches model data type
- Managing batch ingestion and processing
- Calculating loss and calling the optimizer
- Iterating over all epochs
- Saving, logging and checkpointing the models


In [None]:
# Build the model first; it holds the network plus all data/optimizer hooks.
model = lightningModel()

# Configure the trainer for two epochs. Many more options are documented at
# https://pytorch-lightning.readthedocs.io/en/latest/trainer.html
Trainer = pl.Trainer(max_epochs=2)

# Hand control to pytorch-lightning, which runs the whole training loop.
Trainer.fit(model=model)