In [1]:
%%capture
!pip install -qqq wandb pytorch-lightning

In [2]:
# numpy for non-GPU array math
import numpy as np

import torch
from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split

from torchvision.datasets import MNIST
from torchvision import transforms

In [3]:
# Import pytorch lightning
import pytorch_lightning as pl

# Fixed integer seed. The built-in hash() of a str is salted per interpreter
# process, so deriving the seed from hash("...") gave a different value on
# every kernel restart (and `h % 2**32 - 1` could even come out as -1).
SEED = 42
pl.seed_everything(SEED)

# weights and biases
import wandb

# Use the wandb logger
from pytorch_lightning.loggers import WandbLogger

# login to wandb
wandb.login()

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/azureuser/.netrc


True

In [5]:
class LitMLP(pl.LightningModule):
    """Fully-connected classifier (two hidden layers) with W&B logging and ONNX export.

    Args:
        in_dims: shape of ONE input sample, without the batch dimension
            (e.g. ``(1, 28, 28)``); a bare int is accepted for flat inputs.
        n_classes: number of output classes.
        n_layer_1: width of the first hidden layer.
        n_layer_2: width of the second hidden layer.
        lr: learning rate passed to Adam in ``configure_optimizers``.
    """

    def __init__(self, in_dims, n_classes=10,
                 n_layer_1=128, n_layer_2=256, lr=1e-4):
        super().__init__()

        # we flatten the input Tensors and pass them through an MLP
        # (np.prod returns a numpy scalar; cast so nn.Linear receives a plain int)
        self.layer_1 = nn.Linear(int(np.prod(in_dims)), n_layer_1)
        self.layer_2 = nn.Linear(n_layer_1, n_layer_2)
        self.layer_3 = nn.Linear(n_layer_2, n_classes)

        # log hyperparameters (saves to self.hparams, which is logged to wandb as the config)
        self.save_hyperparameters()

        # compute the accuracy -- no need to roll your own!
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        """
        Defines a forward pass using the Stem-Learner-Task
        design pattern from Deep Learning Design Patterns:
        https://www.manning.com/books/deep-learning-design-patterns

        Returns the log-softmax over the class dimension, i.e. log-probabilities;
        the rest of this class refers to that output as "logits".
        """
        batch_size, *dims = x.size()

        # stem: flatten everything except the leading (batch) dimension
        x = x.view(batch_size, -1)

        # learner: two fully-connected layers
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))

        # task: compute per-class scores and normalise them to log-probabilities
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)

        return x

    # convenient method to get the loss on a batch
    def loss(self, xs, ys):
        """Run the model on ``xs`` and return ``(model_output, nll_loss)`` against ``ys``."""
        logits = self(xs)  # this calls self.forward
        # forward already applied log_softmax, so nll_loss is the matching criterion
        loss = F.nll_loss(logits, ys)
        return logits, loss

    # takes a batch and computes the loss; backprop goes through it
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy for one batch; returns the loss."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)

        # logging metrics we calculated by hand
        # Here's the docs for reference https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log
        # this takes a name and value, and under the hood it uses wandb.log
        # on_epoch=True additionally logs the epoch-wise average; pass on_step=False
        # to log only that average (see test_step below)
        self.log('train/loss', loss, on_epoch=True)
        # logging a pl.Metric
        self.train_acc(logits, ys)
        self.log('train/acc', self.train_acc, on_epoch=True)

        return loss

    # returns the torch.optim.Optimizer to apply after the training_step
    def configure_optimizers(self):
        """Adam over all parameters, using the ``lr`` stored in the hyperparameters."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

    def _export_onnx(self, model_filename):
        """Export the model to ``model_filename`` in ONNX format and hand the file to W&B."""
        in_dims = self.hparams["in_dims"]
        sample_shape = (in_dims,) if isinstance(in_dims, int) else tuple(in_dims)
        # forward() treats the leading axis as the batch axis, so the dummy input
        # needs an explicit batch dimension in front of the per-sample shape;
        # without it the export only works when in_dims happens to start with 1
        dummy_input = torch.zeros((1, *sample_shape), device=self.device)
        torch.onnx.export(self, dummy_input, model_filename)
        wandb.save(model_filename)

    # validation_step and test_step will trigger on each batch
    # {training, validation, test}_epoch_end will trigger at end of epoch, or a full pass over a given dataset

    def test_step(self, batch, batch_idx):
        """Log epoch-level test loss and accuracy for one batch."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        self.test_acc(logits, ys)
        self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
        self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)

    # save the model after we are done with testing, we will use ONNX format (https://onnx.ai/)
    # because it lets us use nice things like the Netron model viewer in W&B (https://github.com/lutzroeder/netron)
    def test_epoch_end(self, test_step_outputs):  # args are defined as part of pl API
        """Export the final model once testing has finished."""
        self._export_onnx("model_final.onnx")

    # return the model outputs so they can be used by validation_epoch_end
    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy for one batch and return the model outputs."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        self.valid_acc(logits, ys)

        self.log("valid/loss_epoch", loss)  # default on val/test is on_epoch only
        self.log('valid/acc_epoch', self.valid_acc)

        return logits

    # example of how to log the model outputs as a histogram
    def validation_epoch_end(self, validation_step_outputs):
        """Export a step-stamped ONNX snapshot and log a histogram of the validation outputs."""
        # save the model on every epoch end, so we can roll it back if needed
        self._export_onnx(f"model_{str(self.global_step).zfill(5)}.onnx")

        flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
        self.logger.experiment.log(
            {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
             "global_step": self.global_step})

In [6]:
# custom callback that logs input images and output predictions
class ImagePredictionLogger(pl.Callback):
    """Callback that logs a fixed set of validation images with their predictions to W&B.

    Args:
        val_samples: an ``(images, labels)`` pair to draw the examples from.
        num_samples: how many leading examples of that pair to keep.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        images, labels = val_samples
        # keep only the first `num_samples` examples
        self.val_imgs = images[:num_samples]
        self.val_labels = labels[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run the stored images through ``pl_module`` and log them captioned with prediction and label."""
        # move the stored images onto whatever device the module lives on
        images = self.val_imgs.to(device=pl_module.device)

        # predicted class = index of the largest output along the last axis
        predictions = torch.argmax(pl_module(images), -1)

        captioned = []
        for image, predicted, label in zip(images, predictions, self.val_labels):
            captioned.append(
                wandb.Image(image, caption=f"Pred:{predicted}, Label:{label}"))

        payload = {"examples": captioned, "global_step": trainer.global_step}
        trainer.logger.experiment.log(payload)

In [7]:
# Data loader
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule providing train/val/test DataLoaders for MNIST.

    Args:
        data_dir: directory the dataset is downloaded to and read from.
        batch_size: training batch size; validation and test use 10x this value.
        val_size: number of examples held out of the training set for validation.
    """

    def __init__(self, data_dir='./', batch_size=128, val_size=5000):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        # presumably the MNIST pixel mean and standard deviation -- TODO confirm
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

    # optional, only called once and on 1 GPU, typically for something like data download
    def prepare_data(self):
        """Download the train and test splits into ``data_dir``."""
        # download data, train then test
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    # called on each GPU separately and accepts stage to define if we are at fit or test step
    def setup(self, stage=None):
        """Build the datasets for ``stage`` (``'fit'``, ``'test'``, or ``None`` for both)."""

        # we set up only relevant datasets when stage is specified
        if stage == 'fit' or stage is None:
            mnist = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
            # hold out `val_size` examples; everything else is used for training
            n_train = len(mnist) - self.val_size
            self.mnist_train, self.mnist_val = random_split(mnist, [n_train, self.val_size])
        if stage == 'test' or stage is None:
            # download=True here as well, so setup('test') also works when
            # prepare_data() has not been run beforehand
            self.mnist_test = MNIST(self.data_dir, train=False, download=True,
                                    transform=self.transform)

    # we define a separate DataLoader for each of train/val/test
    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
        return mnist_train

    def val_dataloader(self):
        """Validation loader; unshuffled, with 10x the training batch size."""
        mnist_val = DataLoader(self.mnist_val, batch_size=10 * self.batch_size)
        return mnist_val

    def test_dataloader(self):
        """Test loader; unshuffled, with 10x the training batch size."""
        mnist_test = DataLoader(self.mnist_test, batch_size=10 * self.batch_size)
        return mnist_test

In [8]:
# build the data module and materialise its datasets
mnist = MNISTDataModule()
mnist.setup()

# take the first validation batch; the prediction-logging callback draws its examples from it
val_batches = iter(mnist.val_dataloader())
samples = next(val_batches)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)






In [9]:
# Lightning logger that sends metrics to Weights & Biases
# (integration docs: https://docs.wandb.com/integrations/lightning)
wandb_logger = WandbLogger(
    project="lit-wandb",
)

In [10]:
# create the trainer
trainer = pl.Trainer(
    logger=wandb_logger,    # W&B integration
    log_every_n_steps=50,   # set the logging frequency
    # use all GPUs when CUDA is available; otherwise fall back to the CPU
    # instead of failing on a machine without a GPU
    gpus=-1 if torch.cuda.is_available() else None,
    max_epochs=1,           # number of epochs
    deterministic=True,     # keep it deterministic
    callbacks=[ImagePredictionLogger(samples)] # see Callbacks section
    )

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [11]:

# setup model
model = LitMLP(in_dims=(1, 28, 28))

try:
    # fit the model
    trainer.fit(model, mnist)

    # evaluate the model on a test set
    # ckpt_path=None: test the weights as they stand after the last epoch
    # instead of reloading a saved checkpoint
    trainer.test(datamodule=mnist,
                 ckpt_path=None)
finally:
    # Note: When visiting your run page, it is recommended to use global_step as x-axis to correctly superimpose metrics logged in different stages.
    # Close the W&B run even if fit/test raised, so a failed cell does not
    # leave a run open in the kernel.
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mphylliida[0m (use `wandb login --relogin` to force relogin)



  | Name      | Type     | Params
---------------------------------------
0 | layer_1   | Linear   | 100 K 
1 | layer_2   | Linear   | 33.0 K
2 | layer_3   | Linear   | 2.6 K 
3 | train_acc | Accuracy | 0     
4 | valid_acc | Accuracy | 0     
5 | test_acc  | Accuracy | 0     
---------------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc_epoch': tensor(0.9096, device='cuda:0'),
 'test/loss_epoch': tensor(0.3162, device='cuda:0'),
 'valid/acc_epoch': tensor(0.9050, device='cuda:0'),
 'valid/loss_epoch': tensor(0.3346, device='cuda:0')}
--------------------------------------------------------------------------------


VBox(children=(Label(value=' 1.07MB of 1.59MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.67260250199…

0,1
global_step,429.0
_step,859.0
_runtime,25.0
_timestamp,1607960427.0
train/loss_step,0.39951
train/acc_step,0.86719
valid/loss_epoch,0.3346
valid/acc_epoch,0.905
epoch,0.0
train/loss_epoch,0.7545


0,1
global_step,▁▁██
_step,▁▁▁▂▂▃▃▃▄▄▄▄█
_runtime,▁▁▂▃▄▄▅▅▆▇▇██
_timestamp,▁▁▂▃▄▄▅▅▆▇▇██
train/loss_step,█▄▂▂▂▁▁▁
train/acc_step,▁▃▇▇▆██▇
valid/loss_epoch,▁
valid/acc_epoch,▁
epoch,▁▁
train/loss_epoch,▁


> _Note_: In notebooks, we need to call `wandb.finish()` to indicate that we've finished our run. This isn't necessary in scripts.