<img src="https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/docs/source/_static/images/logo.png" alt="PyTorch Lightning" width="500">

# Rapid prototyping notebook
Use this to prototype quick ideas, then move to a script to scale up!

[Remember! We're always available for support on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-12iz3cds1-uyyyBYJLiaL2bqVmMN7n~A)

---
## Setup

In [None]:
%%capture
! pip install -U pytorch-lightning

In [None]:
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics.functional import accuracy

---
## Data

In [None]:
class DummyDataset(Dataset):
    """Map-style dataset that serves freshly drawn random tensors.

    Each item is a list holding one ``torch.rand`` tensor per requested shape.
    Values are drawn on every access, so reading the same index twice returns
    different numbers.

    Args:
        *shapes: One shape (an iterable of ints) for each tensor in an item.
        num_samples: Number of items reported by ``len()``.
    """

    def __init__(self, *shapes, num_samples=10000):
        super().__init__()
        self.shapes = shapes
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a list of random tensors, one per configured shape.

        Raises:
            IndexError: If ``idx`` is an int outside
                ``[-num_samples, num_samples)``.
        """
        # Plain iteration (`for item in dataset`, `list(dataset)`) falls back to
        # the sequence protocol, which only stops when __getitem__ raises
        # IndexError. Without this check such a loop would never terminate.
        # Non-int indices are left unchecked, exactly as before.
        if isinstance(idx, int) and not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a dataset with {self.num_samples} samples"
            )

        return [torch.rand(*shape) for shape in self.shapes]

In [None]:
# Training loader: batches of 32 items, each item a (1, 28, 28) tensor and a (1,) tensor.
train = DataLoader(DummyDataset((1, 28, 28), (1,)), batch_size=32)

In [None]:
# Validation loader: batches of 32 items, each item a (1, 28, 28) tensor and a (1,) tensor.
val = DataLoader(DummyDataset((1, 28, 28), (1,)), batch_size=32)

In [None]:
# Test loader: batches of 32 items, each item a (1, 28, 28) tensor and a (1,) tensor.
test = DataLoader(DummyDataset((1, 28, 28), (1,)), batch_size=32)

---

## Model

In [None]:
class LitAutoEncoder(pl.LightningModule):
    """Fully connected autoencoder with a 3-feature bottleneck.

    The encoder maps flattened inputs of size 28 * 28 down to 3 features and the
    decoder maps those 3 features back to 28 * 28 values. The training,
    validation and test steps all compute the same MSE reconstruction loss and
    log it under 'train_loss', 'val_loss' and 'test_loss' respectively.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        """Return the encoder output for ``x``.

        ``x`` is passed to the encoder as-is (no flattening happens here), so
        its last dimension must be 28 * 28.
        """
        return self.encoder(x)

    def _reconstruction_loss(self, batch):
        """Compute the MSE between a flattened input batch and its reconstruction.

        Args:
            batch: A pair whose first element is the input tensor; the second
                element is ignored.

        Returns:
            The scalar ``F.mse_loss`` between the decoder output and the
            flattened input.
        """
        # ---------------------------
        # REPLACE WITH YOUR OWN LOGIC

        x, _ = batch
        # Flatten everything but the batch dimension to match the first Linear layer.
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        return F.mse_loss(x_hat, x)
        # --------------------------

    def training_step(self, batch, batch_idx):
        """Log the reconstruction loss as 'train_loss' and return it."""
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss as 'val_loss' (nothing is returned)."""
        self.log('val_loss', self._reconstruction_loss(batch))

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss as 'test_loss' (nothing is returned)."""
        self.log('test_loss', self._reconstruction_loss(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

---
## Train
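NOTE: in Colab, use a high progress bar refresh rate or the screen will freeze because of the rapid tqdm update speed. In recent Lightning versions pass `callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=20)]` to the `Trainer`; older versions used the `progress_bar_refresh_rate` argument instead.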

In [None]:
# init model
ae = LitAutoEncoder()

# Initialize a trainer.
# accelerator='auto' lets Lightning use a GPU when one is available and fall back
# to the CPU otherwise, so this cell also runs on machines without a GPU
# (a hard-coded 'gpu' fails there).
trainer = pl.Trainer(devices=1, accelerator='auto', max_epochs=5)

# Train the model ⚡
trainer.fit(ae, train_dataloaders=train, val_dataloaders=val)

---
## Test

In [None]:
# Evaluate the trained autoencoder on the test loader; runs test_step on every batch.
trainer.test(model=ae, dataloaders=test)

---
## Visualize

In [None]:
# Start tensorboard.
# %reload_ext loads (or reloads) the notebook extension that provides the
# %tensorboard magic. --logdir points it at lightning_logs/, presumably the
# directory the Trainer's default logger wrote to; adjust if a custom logger is used.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

---
## Observations
Do your analysis and notes here!