In [1]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

---
# PyTorch Lightning

1. Model
2. Optimizer
3. Data
4. Training loop ("the magic")
5. Validation loop ("the validation magic")

In [13]:
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

class ResNet(pl.LightningModule):
  """
  Small fully connected classifier with one skip connection (h2 + h1).

  This is exactly the same as an nn module, just with some extra optional
  ingredients: optimizer configuration, data loaders, and the bodies of the
  training / validation steps.

  NOTE: no need for .cuda() - lightning does that for us
  """
  def __init__(self):
    super().__init__()
    self.l1 = nn.Linear(28*28, 64)
    self.l2 = nn.Linear(64, 64)
    self.l3 = nn.Linear(64, 10)
    self.do = nn.Dropout(0.1)
    self.loss = nn.CrossEntropyLoss()

    # Dataset splits, created lazily by _prepare_splits().
    # They are deliberately NOT called `train` / `val`: an instance attribute
    # named `train` shadows nn.Module.train(), so any later model.train()
    # call would fail with "'Subset' object is not callable".
    self.train_set = None
    self.val_set = None

  def forward(self, x):
    """
    Map flattened inputs of shape (batch, 28*28) to class logits (batch, 10).
    """
    h1 = nn.functional.relu(self.l1(x))
    h2 = nn.functional.relu(self.l2(h1))
    # skip connection: add the first hidden activation back before dropout
    do = self.do(h2+h1)
    logits = self.l3(do)
    return logits

  def configure_optimizers(self):
    """
    pl function- can configure as many optimizers as we want
    pl gives us a train loop for each optimizer
    """
    optimizer = optim.SGD(self.parameters(), lr=1e-2)
    return optimizer

  def _shared_step(self, batch):
    """
    Forward pass shared by the training and validation steps.

    Returns a (loss, accuracy) pair for one batch. The loss keeps its graph
    attached so the training step can hand it back to Lightning.
    """
    x, y = batch

    # x: b x 1 x 28 x 28 -> b x 784
    b = x.size(0)
    x = x.view(b, -1)

    # 1 forward
    logits = self(x)

    # 2 compute objective function
    J = self.loss(logits, y)

    # metrics can be automatically calculated across all gpus for multi-gpu training
    acc = accuracy(logits, y)
    return J, acc

  ### training loop
  def training_step(self, batch, batch_idx):
    """
    pl function - implements the body of the training loop.
    this is the magic

    Returns a dict holding the loss (graph attached) and the batch accuracy
    under 'progress_bar'.
    """
    J, acc = self._shared_step(batch)

    # reserved words used by this notebook: 'log', 'loss', 'progress_bar'
    # NOTE(review): newer Lightning releases favour self.log(...) over these
    # dict keys - confirm against the installed version.
    return {'loss': J, 'progress_bar': {'train_acc': acc}}

  # backward(self, trainer, loss, optimizer, optimizer_idx) is implemented for
  # us by Lightning; override it only for custom functionality
  # (the default behaviour amounts to loss.backward()).

  def _prepare_splits(self):
    """
    Load MNIST from the local 'data' directory and split it 55000 / 5000.

    The split is done once and cached, so both loaders always refer to the
    same disjoint subsets no matter which loader is requested first or how
    often the loaders are rebuilt.
    """
    if self.train_set is None or self.val_set is None:
      train_data = datasets.MNIST('data', train=True, download=False, transform=transforms.ToTensor())
      self.train_set, self.val_set = random_split(train_data, [55000, 5000])

  def train_dataloader(self):
    """
    use this if we need to figure out the number of classes
    """
    self._prepare_splits()
    train_loader = DataLoader(self.train_set, batch_size=32)
    return train_loader

  def val_dataloader(self):
    """
    Loader over the held-out validation subset.

    Calls _prepare_splits() itself because the validation data may be
    requested before train_dataloader() has run (e.g. for the validation
    sanity check).
    """
    self._prepare_splits()
    val_loader = DataLoader(self.val_set, batch_size=32)
    return val_loader

  ### 2 methods for validation loop: validation_step, validation_epoch_end
  def validation_step(self, batch, batch_idx):
    """
    We generally don't want metrics for every batch. plot for whole validation set.
    For every single batch in the validation loop, get the accuracy & loss. Lightning will cache it all for us
    """
    J, acc = self._shared_step(batch)
    return {'loss': J, 'progress_bar': {'val_acc': acc}}

  def validation_epoch_end(self, val_step_outputs):
    """
    Average the per-batch validation loss and accuracy.

    val_step_outputs is the list of dicts returned by validation_step:
    [results, results, results, ...]
    """
    # calculate avg val loss / accuracy over all val outputs
    avg_val_loss = torch.tensor([x['loss'] for x in val_step_outputs]).mean()
    avg_val_acc = torch.tensor([x['progress_bar']['val_acc'] for x in val_step_outputs]).mean()
    # note: early stopping is implemented automatically
    pbar = {'avg_val_acc': avg_val_acc}
    return {'val_loss': avg_val_loss, 'progress_bar': pbar} # val loss is all we care about for early stopping / checkpoint

# Instantiate the LightningModule; this is the object passed to trainer.fit().
model = ResNet()

In [14]:
# Train for 5 epochs, refreshing the progress bar every 20 batches.
# Use one GPU when CUDA is available and fall back to CPU otherwise, so the
# cell also runs on machines without a GPU instead of raising.
trainer = pl.Trainer(progress_bar_refresh_rate=20,
                     max_epochs=5,
                     gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type             | Params
------------------------------------------
0 | l1   | Linear           | 50.2 K
1 | l2   | Linear           | 4.2 K 
2 | l3   | Linear           | 650   
3 | do   | Dropout          | 0     
4 | loss | CrossEntropyLoss | 0     
------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

In [6]:
# List the run directories that have been written under lightning_logs/
! ls lightning_logs/

version_0  version_1  version_2  version_3  version_4  version_5


Lightning saved the best checkpoint for us, plus the logs, under `lightning_logs/`.

In [None]:
class ImageClassifier(nn.Module):
  """Plain nn.Module that owns a ResNet instance as its `resnet` submodule."""
  def __init__(self):
    # nn.Module has to be initialised before a submodule can be assigned;
    # without this call `self.resnet = ...` raises
    # "cannot assign module before Module.__init__() call".
    super().__init__()
    self.resnet = ResNet()

In [2]:
# train, val split
# Read the MNIST training set from the local 'data' directory (no download)
# and convert every image to a tensor on access.
train_data = datasets.MNIST(
  root='data',
  train=True,
  download=False,
  transform=transforms.ToTensor(),
)
# Random 55000 / 5000 partition into training and validation subsets.
train, val = random_split(dataset=train_data, lengths=[55000, 5000])
# Both loaders serve mini-batches of 32 samples.
train_loader = DataLoader(dataset=train, batch_size=32)
val_loader = DataLoader(dataset=val, batch_size=32)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
# define model
# Feed-forward stack 784 -> 64 -> 64 -> 10 with ReLU activations and
# dropout (p=0.1) in front of the output layer.
model = nn.Sequential(*[
  nn.Linear(in_features=28 * 28, out_features=64),
  nn.ReLU(),
  nn.Linear(in_features=64, out_features=64),
  nn.ReLU(),
  nn.Dropout(p=0.1),
  nn.Linear(in_features=64, out_features=10),
])

In [5]:
# Define optimizer
# Plain SGD over every parameter of `model`, learning rate 0.01.
params = model.parameters()
optimiser = optim.SGD(params=params, lr=0.01)

In [6]:
# Define loss
# Cross-entropy on raw logits, averaged over the batch (the default reduction).
loss = nn.CrossEntropyLoss(reduction='mean')

In [7]:
# training and validation loops
def run_epoch(model, loader, loss_fn, optimiser=None):
  """
  Run one full pass over `loader` and return (mean loss, mean accuracy).

  With an `optimiser` the model is put in train mode and updated after every
  batch; without one the model is put in eval mode and evaluated with
  gradients disabled.

  Both metrics are weighted by batch size, so a smaller final batch does not
  count as much as a full one. An empty loader yields (nan, nan).
  """
  training = optimiser is not None
  model.train(training)

  total_loss = 0.0
  total_correct = 0
  total_seen = 0

  with torch.set_grad_enabled(training):
    for x, y in loader:
      # x: b x 1 x 28 x 28 -> b x 784
      batch_size = x.size(0)
      x = x.view(batch_size, -1)

      # 1 forward
      logits = model(x)

      # 2 compute objective function
      J = loss_fn(logits, y)

      if training:
        # 3 clear the gradients left over from the previous step
        optimiser.zero_grad()

        # 4 accumulate the partial derivatives of J wrt params
        J.backward()

        # 5 step in the opposite direction of the gradient
        optimiser.step()

      # J is a per-batch mean, so scale it back up to a per-batch sum
      total_loss += J.item() * batch_size
      total_correct += (logits.detach().argmax(dim=1) == y).sum().item()
      total_seen += batch_size

  if total_seen == 0:
    return float('nan'), float('nan')
  return total_loss / total_seen, total_correct / total_seen


nb_epochs = 5
for epoch in range(nb_epochs):
  train_loss, train_acc = run_epoch(model, train_loader, loss, optimiser)
  print(f'Epoch {epoch+1}, training loss: {train_loss:.2f}, training accuracy: {train_acc:.2f}')

  val_loss, val_acc = run_epoch(model, val_loader, loss)
  print(f'Epoch {epoch+1}, validation loss: {val_loss:.2f}, validation accuracy: {val_acc:.2f}')

Epoch 1, training loss: 0.84, training accuracy: 0.78
Epoch 1, validation loss: 0.40, validation accuracy: 0.89
Epoch 2, training loss: 0.38, training accuracy: 0.89
Epoch 2, validation loss: 0.33, validation accuracy: 0.91
Epoch 3, training loss: 0.31, training accuracy: 0.91
Epoch 3, validation loss: 0.28, validation accuracy: 0.92
Epoch 4, training loss: 0.27, training accuracy: 0.92
Epoch 4, validation loss: 0.25, validation accuracy: 0.93
Epoch 5, training loss: 0.24, training accuracy: 0.93
Epoch 5, validation loss: 0.23, validation accuracy: 0.93
