<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/perceptron/perceptron_fives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Perceptron for Detecting Fives

## Installing and Importing Libraries

In [None]:
%%capture
# Install this notebook's dependencies; %%capture suppresses the cell's output.
# NOTE(review): versions are unpinned, so each fresh runtime gets the latest
# releases (e.g. pytorch-lightning 2.x). Pinning (`%pip install pkg==X.Y`)
# would make reruns reproducible.
!pip install pytorch-lightning torchviz wandb

repo_url = "https://raw.githubusercontent.com/wandb/edu/main/"
utils_path = "lightning/utils.py"
# Download a util file of helper methods for this notebook
# (it supplies utils.LoggedImageClassifierModule, the model's base class).
# NOTE(review): this tracks the repo's `main` branch, so the helper can change
# between runs. Also, curl without -f saves an HTTP error page as utils.py on
# failure, and %%capture hides that, so a bad download only shows up later as
# an error at `import utils`.
!curl {repo_url + utils_path} --output utils.py

import math

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets
import wandb

import utils

## Defining the `Model`

In [None]:
class LitPerceptronModel(utils.LoggedImageClassifierModule):
  """A perceptron: one linear unit scoring each flattened 28x28 image.

  Logging, metrics and the train/eval step logic are presumably supplied by
  `utils.LoggedImageClassifierModule` (defined in the downloaded utils.py).
  This class only contributes the layer, the loss and the optimizer.
  """

  def __init__(self, max_images_to_display=32):
    """Create the linear layer and the loss function.

    Args:
      max_images_to_display: passed straight through to the base class.
    """
    super().__init__(max_images_to_display=max_images_to_display)
    n_pixels = 28 * 28
    # a single output unit -> one real-valued score per image
    self.perceptron = torch.nn.Linear(n_pixels, 1)
    # squared error; the data module's targets are 0.0 / 1.0 floats
    self.loss = torch.nn.MSELoss()

  def forward(self, x):
    """Score a batch of images.

    Everything after the batch axis is flattened into one feature vector,
    so this assumes x is (batch, 28, 28) or otherwise holds 784 values per
    example. TODO confirm for other input layouts.
    """
    flat = x.flatten(start_dim=1)
    return self.perceptron(flat)

  def configure_optimizers(self):  # ⚡: setup for .fit
    """Hand Lightning an Adam optimizer over all model parameters."""
    learning_rate = 1e-3
    return torch.optim.Adam(params=self.parameters(), lr=learning_rate)

## Defining a `LightningDataModule` (which supplies the `DataLoader`)

In [None]:
class PerceptronDataModule(pl.LightningDataModule):
  """MNIST training images paired with binary "is this digit a 5?" targets.

  Each example is (image, target). The image is the tensor stored by
  torchvision's MNIST, cast to float but not rescaled. The target is 1.0 for
  a five and 0.0 otherwise, with a trailing singleton axis so it lines up
  with a model that emits one score per image.
  """

  def __init__(self, batch_size=64, data_dir=".", shuffle=True):
    """Store loader settings; no data is touched until prepare_data().

    Args:
      batch_size: number of examples per training batch.
      data_dir: directory MNIST is downloaded into (defaults to the current
        working directory, matching the previous hard-coded ".").
      shuffle: reshuffle the training set every epoch. Defaults to True, the
        usual choice for training; pass False for a fixed, repeatable order.
    """
    super().__init__()  # ⚡: we inherit from LightningDataModule
    self.batch_size = batch_size
    self.data_dir = data_dir
    self.shuffle = shuffle

  def prepare_data(self):  # ⚡: how do we set up the data?
    """Download the MNIST train split and build an in-memory TensorDataset.

    NOTE(review): Lightning recommends assigning state such as `self.dataset`
    in `setup()`, since `prepare_data()` is only run on one process. Doing it
    here works for this notebook's single-device run. It also lets
    `dmodule.prepare_data(); dmodule.train_dataloader()` work without an
    explicit `setup()` call.
    """
    # download the data from the internet
    mnist = torchvision.datasets.MNIST(self.data_dir, train=True, download=True)

    # set up shapes and types: float inputs, float 0/1 targets with an extra axis
    self.digits = mnist.data.float()
    self.is_5 = (mnist.targets == 5).unsqueeze(1).float()
    self.dataset = torch.utils.data.TensorDataset(self.digits, self.is_5)

  def train_dataloader(self):  # ⚡: how do we go from dataset to dataloader?
    """The DataLoaders returned by a DataModule produce data for a model.

    This DataLoader is used during training. It requires prepare_data() to
    have run first, because that is where `self.dataset` is created.
    Shuffling (on by default) avoids replaying the exact same batch order
    every epoch.
    """
    return DataLoader(self.dataset, batch_size=self.batch_size,
                      shuffle=self.shuffle)

## Building and Training the `Model`

In [None]:
# Build the data module and fetch MNIST right away, so the debugging cell
# below can pull a batch from it before training.
dmodule = PerceptronDataModule(batch_size=256)
dmodule.prepare_data()

# The perceptron that trainer.fit will optimize.
lp = LitPerceptronModel(max_images_to_display=32)

### Debugging Code

In [None]:
# For debugging (checking shapes, etc.): pull one batch and push it through the
# model by hand before handing everything to the Trainer.
# Use the GPU when there is one, otherwise fall back to the CPU, so this cell
# doesn't crash on CPU-only runtimes.
device = "cuda" if torch.cuda.is_available() else "cpu"

dloader = dmodule.train_dataloader()

example_batch = next(iter(dloader))
example_x, example_y = example_batch[0].to(device), example_batch[1].to(device)

print(f"Input Shape: {example_x.shape}")
print(f"Target Shape: {example_y.shape}")

lp.to(device)
with torch.no_grad():  # inspection only: no autograd graph needed
  outputs = lp(example_x)  # call the module itself (not .forward) so hooks run
  example_loss = lp.loss(outputs, example_y)
print(f"Output Shape: {outputs.shape}")
print(f"Loss : {example_loss}")

### Running `.fit`

In [None]:
# 👟 configure Trainer
# `gpus=` was deprecated in favor of `accelerator`/`devices` and removed in
# Lightning 2.0, which is what the unpinned install pulls in. accelerator="auto"
# uses the GPU when one is available and otherwise falls back to the CPU.
# NOTE(review): entity="wandb" presumably needs write access to that W&B team.
# Change it (or pass None for your default entity) when running elsewhere.
trainer = pl.Trainer(accelerator="auto", devices=1,  # one device for .forward
                     logger=pl.loggers.WandbLogger(
                       project="lit-perceptron", entity="wandb",
                       save_code=True),  # log to Weights & Biases
                     max_epochs=10, log_every_n_steps=1)

# 🏃‍♀️ run the Trainer on the model
trainer.fit(lp, dmodule)

# 🏁 close out the run
wandb.finish()