<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/perceptron/mlp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Multilayer Perceptron for MNIST

## Installing and Importing Libraries

In [None]:
%%capture
!pip install pytorch-lightning wandb

repo_url = "https://raw.githubusercontent.com/wandb/edu/main/"
utils_path = "lightning/perceptron/utils.py"
# Download a util file of helper methods for this notebook
!curl {repo_url + utils_path} --output utils.py

import math

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets
import wandb

import utils

## Defining the `Model`

In [None]:
class FullyConnected(pl.LightningModule):
  """A single dense layer: a linear map followed by an activation."""

  def __init__(self, in_features, out_features, activation=lambda xs: xs):
    """Creates the linear map and stores the activation.

    Args:
      in_features: size of each input vector.
      out_features: size of each output vector.
      activation: callable applied to the output of the linear map;
        the default hands its input back unchanged.
    """
    super().__init__()
    self.linear = torch.nn.Linear(in_features=in_features,
                                  out_features=out_features)
    self.activation = activation

  def forward(self, x):
    """Returns activation(linear(x))."""
    pre_activation = self.linear(x)
    return self.activation(pre_activation)

class LitMLP(utils.LoggedImageClassifierModule):
  """A small multilayer perceptron: one hidden layer plus a 10-way read-out.

  Logging and metrics behavior comes from the parent class,
  utils.LoggedImageClassifierModule (wandb + pytorch-lightning under the hood).
  """

  def __init__(self, config, max_images_to_display=32):
    """Assembles the network and stores the training settings.

    Args:
      config: mapping providing "fc1.size", "activation", "loss",
        "optimizer", and "optimizer.params".
      max_images_to_display: forwarded unchanged to the parent class.
    """
    super().__init__(max_images_to_display=max_images_to_display)

    hidden_size = config["fc1.size"]

    # the building blocks, in the order they are applied;
    # to grow the network, create more layers and add them to the list below
    hidden = FullyConnected(in_features=28 * 28, out_features=hidden_size,
                            activation=config["activation"])
    read_out = FullyConnected(in_features=hidden_size, out_features=10)
    self.layers = torch.nn.ModuleList([hidden, read_out])

    self.loss = config["loss"]
    self.optimizer = config["optimizer"]
    self.optimizer_params = config["optimizer.params"]

  def forward(self, x):
    """Flattens each example, applies every layer in turn, returns log-softmax."""
    activations = torch.flatten(x, start_dim=1)  # keep dim 0, merge the rest
    for layer in self.layers:
      activations = layer(activations)
    # log of softmax rather than softmax itself, for numerical reasons
    return F.log_softmax(activations, dim=1)

  def configure_optimizers(self):
    """⚡: called by Lightning while setting up .fit; builds the optimizer."""
    optimizer_factory, optimizer_kwargs = self.optimizer, self.optimizer_params
    return optimizer_factory(self.parameters(), **optimizer_kwargs)

## Defining the `DataModule` & `DataLoader`

In [None]:
class MNISTDataModule(pl.LightningDataModule):
  """Downloads MNIST and serves it as training and validation DataLoaders."""

  def __init__(self, batch_size=64):
    """Stores the batch size used by the training DataLoader."""
    super().__init__()  # ⚡: we inherit from LightningDataModule
    self.batch_size = batch_size

  def prepare_data(self, validation_size=10_000): # ⚡: how do we set up the data?
    """Downloads the MNIST training split and carves off a validation set.

    The last `validation_size` examples become the validation set;
    everything before them is used for training.

    NOTE: Lightning's documentation advises against assigning state in
    prepare_data, because in distributed runs it is not called on every
    process. The state is kept here so that calling prepare_data() directly
    is enough to get working DataLoaders.

    Args:
      validation_size: number of examples to hold out for validation.

    Raises:
      ValueError: if validation_size would leave the training or the
        validation set empty.
    """
    # download the data from the internet
    mnist = torchvision.datasets.MNIST(".", train=True, download=True)

    # set up shapes and types
    self.digits, self.labels = mnist.data.float(), mnist.targets
    self.digits = torch.divide(self.digits, 255.)

    total = len(self.digits)
    if not 0 < validation_size < total:
      raise ValueError(
        f"validation_size must be between 1 and {total - 1}, got {validation_size}")

    # split on an explicit index: negative slicing misbehaves at the edges
    # (x[:-0] is empty, and a negative size would swap the two sets)
    split = total - validation_size
    self.training_data = torch.utils.data.TensorDataset(self.digits[:split],
                                                        self.labels[:split])
    self.validation_data = torch.utils.data.TensorDataset(self.digits[split:],
                                                          self.labels[split:])
    self.validation_size = validation_size

  def train_dataloader(self):  # ⚡: how do we go from dataset to dataloader?
    """The DataLoaders returned by a DataModule produce data for a model.

    This DataLoader is used during training."""
    # reshuffle every epoch so SGD does not see the same batches in the same order
    return DataLoader(self.training_data, batch_size=self.batch_size, shuffle=True)

  def val_dataloader(self):  # ⚡: what about during validation?
    """The DataLoaders returned by a DataModule produce data for a model.

    This DataLoader is used during validation, at the end of each epoch."""
    # the whole validation set is served as a single batch
    return DataLoader(self.validation_data, batch_size=self.validation_size)

## Building and Training the `Model`

In [None]:
# every setting for this run lives in one dictionary
config = {
    # data
    "batch_size": 256,

    # training length
    "max_epochs": 10,

    # architecture
    "fc1.size": 32,
    "activation": torch.nn.ReLU(),

    # objective and optimization
    "loss": torch.nn.NLLLoss(),
    "optimizer": torch.optim.SGD,
    "optimizer.params": {"lr": 0.01},
}

# build the data module and the model from those settings
dmodule = MNISTDataModule(batch_size=config["batch_size"])
lmlp = LitMLP(config, max_images_to_display=32)

### Model Information

In [None]:
# show the module structure, then the parameter count on its own line
print(lmlp, f"Parameter Count: {lmlp.count_params()}", sep="\n")

In [None]:
# for debugging purposes (checking shapes, etc.), make these available
dmodule.prepare_data()
dloader = dmodule.train_dataloader()

# fall back to the CPU so this check still runs on a machine without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

example_batch = next(iter(dloader))
example_x, example_y = example_batch[0].to(device), example_batch[1].to(device)

print(f"Input Shape: {example_x.shape}")
print(f"Target Shape: {example_y.shape}")

lmlp.to(device)
outputs = lmlp(example_x)  # call the module rather than .forward so hooks run
print(f"Output Shape: {outputs.shape}")
print(f"Loss : {lmlp.loss(outputs, example_y)}")

### Running `.fit`

In [None]:
# 👟 configure Trainer
wandb_logger = pl.loggers.WandbLogger(
  project="lit-mlp", entity="wandb", config=config,
  save_code=True)  # log to Weights & Biases

# `gpus=1` was deprecated in Lightning 1.7 and removed in 2.0;
# accelerator/devices is the supported way to request one GPU
trainer = pl.Trainer(accelerator="gpu", devices=1,  # use the GPU for .forward
                     logger=wandb_logger,
                     max_epochs=config["max_epochs"], log_every_n_steps=10)

# 🏃‍♀️ run the Trainer on the model
trainer.fit(lmlp, dmodule)

# 💾 save the model
# example_x is the sample input passed to the ONNX export
lmlp.to_onnx("model.onnx", example_x, export_params=True)
wandb.save("model.onnx")

# 🏁 close out the run
wandb.finish()