# Import Packages

In [None]:
from typing import Union, Sequence, Optional, Any

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
!pip install pytorch-lightning
from pytorch_lightning import LightningModule, LightningDataModule, Trainer

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/48/5e/19c817ad2670c1d822642ed7bfc4d9d4c30c2f8eaefebcd575a3188d7319/pytorch_lightning-1.3.8-py3-none-any.whl (813kB)
[K     |████████████████████████████████| 819kB 29.8MB/s 
[?25hCollecting PyYAML<=5.4.1,>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/7a/a5/393c087efdc78091afa2af9f1378762f9821c9c1d7a22c5753fb5ac5f97a/PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636kB)
[K     |████████████████████████████████| 645kB 34.2MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 45.9MB/s 
Collecting torchmetrics>=0.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/4d/8b/de8df9044ca2ac5dfc6b13b9ad3b3ebe6b3a45807311102b569d680e811f/torchmetrics-0.4.1-py3-none-any.whl (234kB)
[K     |███████

# Setup

### Model Defined by Pytorch Module

The basic way to define a neural network in PyTorch is to subclass the `nn.Module` class.
A Module subclass implements two methods:

- __init__
  - Inherit and initialize Module class
  - Build-up your model architecture in this method
- forward
  - Define the forward propagation of your model
  - Note that backward propagation is handled automatically by autograd, as long as the forward pass is built from differentiable PyTorch operations

In [None]:
class Model(nn.Module):
    """Convolutional variational auto-encoder (VAE) for 1x28x28 images.

    ``forward`` returns a dict holding the latent mean ``mu`` and latent
    log-variance ``logvar`` (both 32 wide) plus the decoded image ``x_hat``,
    which has the same shape as the input.
    """

    def __init__(self):
        super().__init__()

        def down(c_in, c_out):
            # A stride-2 convolution halves the height and width.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.PReLU(),
            ]

        def up(c_in, c_out):
            # A stride-2 transposed convolution doubles the height and width;
            # output_padding=1 restores the even size lost by the encoder.
            return [
                nn.ConvTranspose2d(
                    c_in, c_out, kernel_size=3, stride=2,
                    padding=1, output_padding=1,
                ),
                nn.BatchNorm2d(c_out),
                nn.PReLU(),
            ]

        # 1x28x28 -> 8x14x14 -> 16x7x7
        self.encoder = nn.Sequential(*down(1, 8), *down(8, 16))
        # 16x7x7 -> 8x14x14 -> 1x28x28
        self.decoder = nn.Sequential(*up(16, 8), *up(8, 1))
        self.flatten = nn.Flatten()
        # 784 == 16 * 7 * 7 flattened encoder features; the latent width is 32.
        self.mu = nn.Linear(784, 32)
        self.logvar = nn.Linear(784, 32)
        self.bridge = nn.Linear(32, 784)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z ~ N(mu, exp(logvar)) in training mode; return mu in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x: torch.Tensor) -> dict:
        """Encode ``x``, sample a latent code and decode it back to an image."""
        features = self.flatten(self.encoder(x))
        mu, logvar = self.mu(features), self.logvar(features)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decoder(self.bridge(z).view(-1, 16, 7, 7))
        return dict(mu=mu, logvar=logvar, x_hat=x_hat)

# Note: 如果你在 forward 的 return 寫
# return mu, logvar, x_hat
# 沒意外的話你過幾天就忘了哪一個位置對應哪一個變數

In [None]:
# Build the model and push one random batch through it to check output shapes.
model = Model()

image  = torch.rand(16, 1, 28, 28)  # batch x channels x height x width
output = model(image)

for key, tensor in output.items():
    print(f"shape of {key}:", tensor.shape)

shape of mu: torch.Size([16, 32])
shape of logvar: torch.Size([16, 32])
shape of x_hat: torch.Size([16, 1, 28, 28])


### Typical Training Process

- Preparation
  - Initialize optimizer
  - Initialize dataset and data loader
  - Initialize learning rate scheduler (optional)
- Training (for each epoch, loop over the batches)
  - Clear gradients (optimizer.zero_grad()) and get batch data from data loader
  - Forward propagation and compute loss
  - Backward propagation (loss.backward()) and update parameters (optimizer.step())
  - Evaluate on validation set (optional but highly recommended)
  - Decay learning rate (optional)

### PyTorch Lightning

A training loop contains a lot of boilerplate that is repeated in every project.
For instance: optimizer.zero_grad(), loss.backward(), optimizer.step(), with torch.no_grad():, etc.
PyTorch Lightning handles these routines for you.

In a LightningModule, you only implement the following methods and Pytorch Lightning Trainer will do others for you.
- __init__: same in Pytorch Module
- forward: same in Pytorch Module
- training_step: define how you compute your loss
- configure_optimizers: define your optimizer(s)
- validation_step: define how you evaluate your model (optional but highly recommended)
- validation_epoch_end: define how you summarize your evaluations (optional but highly recommended)

In [None]:
class LitModel(LightningModule):
    """Convolutional variational auto-encoder (VAE) as a LightningModule.

    The network maps 1x28x28 images to a 32-dimensional latent Gaussian and
    back. Besides ``forward`` it implements the hooks the Lightning
    ``Trainer`` calls: ``configure_optimizers``, ``training_step``,
    ``validation_step`` and ``validation_epoch_end``.

    Args:
        lr: Learning rate passed to Adam.
        weight_decay: Weight-decay coefficient passed to Adam.
    """

    def __init__(self, lr=1e-4, weight_decay=1e-4):
        # Required method, same as in a PyTorch Module.
        super(LitModel, self).__init__()
        # Records lr / weight_decay in self.hparams (read by configure_optimizers).
        self.save_hyperparameters()
        # Encoder: 1x28x28 -> 8x14x14 -> 16x7x7 (each stride-2 conv halves H and W).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, padding=1, kernel_size=3, stride=2),
            nn.BatchNorm2d(8),
            nn.PReLU(),
            nn.Conv2d(in_channels=8, out_channels=16, padding=1, kernel_size=3, stride=2),
            nn.BatchNorm2d(16),
            nn.PReLU()
        )
        # Decoder mirrors the encoder: 16x7x7 -> 8x14x14 -> 1x28x28.
        # Its output is treated as logits by compute_vae_loss.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=16, out_channels=8, padding=1, output_padding=1, kernel_size=3, stride=2),
            nn.BatchNorm2d(8),
            nn.PReLU(),
            nn.ConvTranspose2d(in_channels=8, out_channels=1, padding=1, output_padding=1, kernel_size=3, stride=2),
            nn.BatchNorm2d(1),
            nn.PReLU()
        )
        self.flatten = nn.Flatten()
        # 784 = 16 * 7 * 7 flattened encoder features; the latent size is 32.
        self.mu      = nn.Linear(784, 32)
        self.logvar  = nn.Linear(784, 32)
        self.bridge  = nn.Linear(32, 784)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z = mu + exp(0.5 * logvar) * eps while training; return mu otherwise.

        This is a convenience helper, not a Lightning hook.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std).mul(std)
            z   = mu + eps

        else:
            z   = mu

        return z

    def forward(self, x: torch.Tensor) -> dict:
        """Encode ``x``, sample a latent code and decode it.

        Required method, same as in a PyTorch Module.

        Args:
            x: Image batch of shape (B, 1, 28, 28) (implied by the
               ``view(-1, 16, 7, 7)`` bridge reshape).

        Returns:
            dict with ``"mu"`` and ``"logvar"`` of shape (B, 32) and the
            reconstruction logits ``"x_hat"`` of shape (B, 1, 28, 28).
        """
        x      = self.encoder(x)
        x      = self.flatten(x)
        mu     = self.mu(x)
        logvar = self.logvar(x)
        z      = self.reparameterize(mu, logvar)
        x_hat  = self.decoder(self.bridge(z).view(-1, 16, 7, 7))

        return {"mu": mu, "logvar": logvar, "x_hat": x_hat}

    def configure_optimizers(self):
        """Required hook: return the optimizer used by the Trainer."""
        return Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )

    def compute_vae_loss(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        x_hat: torch.Tensor
    ) -> dict:
        """Return the VAE loss terms, each averaged over the batch.

        Convenience helper, not a Lightning hook.

        Args:
            x: Target images, presumably pixel intensities in [0, 1].
            mu: Latent means, shape (B, 32).
            logvar: Latent log-variances, shape (B, 32).
            x_hat: Reconstruction logits, same shape as ``x``.

        Returns:
            dict with ``"loss"`` (rec_loss + kld_loss), ``"rec_loss"``
            (binary cross-entropy summed over pixels) and ``"kld_loss"``
            (KL divergence to N(0, I) summed over latent dimensions).
        """
        # The fused logits version equals
        # binary_cross_entropy(sigmoid(x_hat), x) mathematically. It is
        # computed with the log-sum-exp trick, so it stays accurate when
        # |x_hat| is large instead of saturating the sigmoid.
        bce = F.binary_cross_entropy_with_logits(
            x_hat,
            x,
            reduction="none"
        ).sum(dim=(1, 2, 3)).mean()
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)).
        kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1).mean()

        return {"loss": bce + kld,
                "rec_loss": bce,
                "kld_loss": kld}

    def _compute_batch_losses(self, batch: Any) -> dict:
        """Run the model on one (images, labels) batch and return its VAE losses."""
        # Labels are unused: the VAE only reconstructs its input.
        X, _   = batch
        output = self(X)
        return self.compute_vae_loss(
            x=X,
            mu=output["mu"],
            logvar=output["logvar"],
            x_hat=output["x_hat"]
        )

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Required hook: return the loss to back-propagate for one batch."""
        total_loss = self._compute_batch_losses(batch)["loss"]

        # self.log records the metric in the logger (TensorBoard).
        # Passing prog_bar=True to self.log would also show it in the progress bar;
        # the training loss already appears there by default.
        self.log("train_loss", total_loss)

        return total_loss

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Optional hook: return this batch's loss terms for validation_epoch_end."""
        return self._compute_batch_losses(batch)

    def validation_epoch_end(self, outputs: Sequence) -> Any:
        """Optional hook: average the per-batch validation losses and log them.

        ``outputs`` is the sequence of dicts returned by ``validation_step``.
        Nothing needs to be returned.
        """
        total_loss = torch.mean(torch.stack([o["loss"] for o in outputs]))
        kld_loss   = torch.mean(torch.stack([o["kld_loss"] for o in outputs]))
        rec_loss   = torch.mean(torch.stack([o["rec_loss"] for o in outputs]))

        # Validation has no default progress-bar metrics, so every metric to
        # be shown there needs prog_bar=True.
        self.log("val_loss", total_loss, prog_bar=True)
        self.log("val_kld", kld_loss, prog_bar=True)
        self.log("val_rec", rec_loss, prog_bar=True)

### 簡介 Namespace

In [None]:
from argparse import Namespace

# Start from an empty namespace; attributes can be attached to it afterwards.
params_namespace = Namespace()
params_namespace  # the last expression of a cell is echoed by the notebook

Namespace()

In [None]:
# Attach the hyper-parameters as attributes by updating the namespace's
# attribute dictionary in one call (same insertion order as three assignments).
vars(params_namespace).update(lr=1e-4, weight_decay=1e-4, batch_size=32)
params_namespace

Namespace(batch_size=32, lr=0.0001, weight_decay=0.0001)

In [None]:
getattr(params_namespace, "lr")  # same as params_namespace.lr

0.0001

In [None]:
# The same hyper-parameters as a plain dictionary (keyword form of dict()).
params_dict = dict(lr=1e-4, weight_decay=1e-4, batch_size=32)
params_dict

{'batch_size': 32, 'lr': 0.0001, 'weight_decay': 0.0001}

In [None]:
# Dictionary to namespace: ** unpacks every key/value pair into a keyword
# argument, so each key becomes an attribute.
Namespace(**dict(params_dict))

Namespace(batch_size=32, lr=0.0001, weight_decay=0.0001)

In [None]:
# Namespace to dictionary: the attribute dict, i.e. the object that
# vars(params_namespace) returns.
params_namespace.__dict__

{'batch_size': 32, 'lr': 0.0001, 'weight_decay': 0.0001}

In [None]:
# Example of Python unpacking: *a spreads the items of `a` into a new list.
a = list(range(1, 4))
[*a, 4]

[1, 2, 3, 4]

In [None]:
# Build the Lightning model and push one random batch through it to check shapes.
lit_model = LitModel()

image  = torch.rand(16, 1, 28, 28)  # batch x channels x height x width
output = lit_model(image)

for key, tensor in output.items():
    print(f"shape of {key}:", tensor.shape)

shape of mu: torch.Size([16, 32])
shape of logvar: torch.Size([16, 32])
shape of x_hat: torch.Size([16, 1, 28, 28])


### Typical Pytorch Data Loader

A typical Pytorch data loader needs a dataset as its parameter.
Recall that a (map-style) dataset is an indexable object in Python: `dataset[i]` returns the i-th sample.
We usually implement 3 methods for a PyTorch dataset.

- __init__
- __getitem__
- __len__

Given a batch size, the data loader automatically groups samples into batches, and it can load them in parallel with worker processes (`num_workers`).

In [None]:
# 以前的寫法
# for batch in dataloader:
#     ...
# for batch_idx, batch in enumerate(dataloader):
#     ...

In [None]:
# A plain loop yields only the values...
steps = range(100, 200, 30)
for i in steps:
    print(i)

# ...while enumerate() pairs each value with its 0-based position.
for i, data in enumerate(steps):
    print(i, data)

100
130
160
190
0 100
1 130
2 160
3 190


### Pytorch Lightning DataModule
A LightningDataModule can be viewed as an end-to-end wrapper around data loaders.
Users simply call a method of a LightningDataModule object to obtain a data loader.
More precisely, we usually implement the following methods for a LightningDataModule.

- __init__: inherit LightningDataModule and save hyperparameters
- prepare_data: download or otherwise prepare the data on disk (optional but highly recommended)
- train_dataloader: return a training data loader (optional but highly recommended)
- val_dataloader: return a validation data loader (optional but highly recommended)
- test_dataloader: return a test data loader (optional but highly recommended)

In [None]:
class DataModule(LightningDataModule):
    """MNIST training / validation data for the Lightning Trainer.

    Args:
        batch_size: Number of images per batch for both data loaders.
        data_dir: Directory MNIST is downloaded to and read from.
    """

    def __init__(self, batch_size: int = 256, data_dir: str = "data"):
        super(DataModule, self).__init__()
        self.batch_size = batch_size
        self.data_dir   = data_dir

    def prepare_data(self):
        # Download only. Lightning may call prepare_data on a single process
        # in distributed training, so state assigned here would be missing on
        # the other processes. The datasets are built in setup() instead.
        MNIST(self.data_dir, train=False, download=True)
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage: Optional[str] = None):
        # Called by Lightning on every process after prepare_data.
        self.val_data   = MNIST(self.data_dir, train=False, transform=ToTensor())
        self.train_data = MNIST(self.data_dir, train=True, transform=ToTensor())

    def _ensure_setup(self):
        """Build the datasets if setup() has not run yet.

        This keeps manual use (prepare_data() followed by *_dataloader())
        working outside the Trainer.
        """
        if not (hasattr(self, "train_data") and hasattr(self, "val_data")):
            self.setup()

    def train_dataloader(self):
        self._ensure_setup()
        # Reshuffle every epoch so batches are not always drawn in file order.
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        self._ensure_setup()
        return DataLoader(self.val_data, batch_size=self.batch_size)

# Main

In [None]:
# Build the model, the data module and the trainer, then start training.
lit_model  = LitModel()
datamodule = DataModule()
# Use one GPU when available. A bare Trainer() stays on the CPU and warns
# "GPU available but not used. Set the gpus flag in your trainer".
trainer    = Trainer(gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(lit_model, datamodule=datamodule)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  "GPU available but not used. Set the gpus flag in your trainer"


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 1.3 K 
1 | decoder | Sequential | 1.3 K 
2 | flatten | Flatten    | 0     
3 | mu      | Linear     | 25.1 K
4 | logvar  | Linear     | 25.1 K
5 | bridge  | Linear     | 25.9 K
---------------------------------------
78.7 K    Trainable params
0         Non-trainable params
78.7 K    Total params
0.315     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…