# Pytorch Lightning

为什么要用 Pytorch Lightning ?

在使用Pytorch进行深度学习研究时，随着做的项目开始出现了一些稍微高阶的要求，会发现总是不断地在相似工程代码上花费大量时间，Debug也是这些代码花的时间最多，而且渐渐产生了一个矛盾之处：如果想要更多更好的功能，如TensorBoard支持，Early Stop，LR Scheduler，分布式训练，快速测试等，代码就无可避免地变得越来越长，看起来也越来越乱，同时核心的训练逻辑也渐渐被这些工程代码盖过。那么有没有更好的解决方案，Pytorch Lightning就是希望来解决这些问题。

Pytorch-Lighting 的核心设计思想是”自给自足”。每个网络也同时包含了如何训练、如何测试、优化器定义等内容。

In [1]:
import shutil

shutil.rmtree("./lightning_logs", ignore_errors=True)

# Lightning Module

Lightning的`fit`数据流伪代码：

```python
def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    setup("fit")
    configure_optimizers()
    on_fit_start()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    model.train()
    torch.set_grad_enabled(True)

    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        out = training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end(out, batch, batch_idx)

        # 每个epoch的最后一个batch后，执行val_loop()
        if should_check_val:
            val_loop()

    on_train_epoch_end()


def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(out, batch, batch_idx)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)
```

## `training_step`

`training_step`对一个min-batch进行前向计算和loss计算，并将计算的loss返回，可以以`Tensor`的形式将loss返回，也可以以一个字典的形式返回，但字典中必须有一个名为'loss'的key。

# Simple Linear NN

In [2]:
from torch import nn, optim
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics


class LightningNN(pl.LightningModule):
    def __init__(self, input_size, num_classes, learning_rate):
        super().__init__()
        self.save_hyperparameters()
        # model
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)
        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        self.learning_rate = learning_rate
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.f1_score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        loss, accuracy, f1_score = self._common_step(batch, batch_idx)
        self.log_dict(
            {
                "train_loss": loss,
                "train_accuracy": accuracy,
                "train_f1_score": f1_score,
            },
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy, f1_score = self._common_step(batch, batch_idx)
        self.log_dict(
            {
                "val_loss": loss,
                "val_accuracy": accuracy,
                "val_f1_score": f1_score,
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy, f1_score = self._common_step(batch, batch_idx)
        self.log_dict(
            {
                "test_loss": loss,
                "test_accuracy": accuracy,
                "test_f1_score": f1_score,
            }
        )
        return loss

    def _common_step(self, batch, batch_idx):
        x, y = batch
        x = x.reshape((x.shape[0], -1))
        logits = self.forward(x)
        loss = self.loss_fn(logits, y)

        accuracy = self.accuracy(logits, y)
        f1_score = self.f1_score(logits, y)

        return loss, accuracy, f1_score

    def predict_step(self, batch, batch_idx):
        x, y = batch
        x = x.reshape((x.shape[0], -1))
        logits = self.forward(x)
        preds = logits.argmax(dim=1)
        return preds

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

# DataModule

In [3]:
from torchvision import datasets, transforms
from torch.utils import data
from torch.utils.data import random_split


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_workers):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    # Download data on single GPU
    def prepare_data(self):
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage):
        entire_dataset = datasets.MNIST(
            self.data_dir,
            train=True,
            transform=transforms.ToTensor(),
            download=False,
        )
        self.train_ds, self.val_ds = random_split(entire_dataset, [50000, 10000])
        self.test_ds = datasets.MNIST(
            self.data_dir,
            train=False,
            transform=transforms.ToTensor(),
            download=False,
        )

    def train_dataloader(self):
        return data.DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return data.DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        return data.DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

# Config

In [4]:
import torch

torch.set_float32_matmul_precision("medium")

data_dir = "../data/"
accelerator = "gpu"
devices = [0]
input_size = 784
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 3
num_threads = 4
precision = "16-mixed"

# CallBack

Callbacks 使得我们可以对整个执行流中的一些桩点插入一些自定义代码。桩点可以参考，前面整个数据流的伪代码部分。

https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html


**Best Practices**

The following are best practices when using/designing callbacks.

* Callbacks should be isolated in their functionality.
* Your callback should not rely on the behavior of other callbacks in order to work properly.
* Do not manually call methods from the callback.
* Directly calling methods (eg. on_validation_end) is strongly discouraged.
* Whenever possible, your callbacks should not depend on the order in which they are executed.

In [5]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback

checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    filename="simple-mnist-{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    mode="min",
    save_last=True,
)

early_stoppping_cb = EarlyStopping(monitor="val_loss")


class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")


custom_printing_cb = MyPrintingCallback()

# Trainer

https://lightning.ai/docs/pytorch/stable/common/trainer.html

The Lightning `Trainer` does much more than just “training”. Under the hood, it handles all loop details for you, some examples include:

* Automatically enabling/disabling grads
* Running the training, validation and test dataloaders
* Calling the Callbacks at the appropriate times
* Putting batches and computations on the correct devices

In [6]:
dm = MNISTDataModule(data_dir, batch_size, num_threads)
model = LightningNN(input_size, num_classes, learning_rate)
trainer = pl.Trainer(
    accelerator=accelerator,
    devices=devices,
    min_epochs=1,
    max_epochs=num_epochs,
    precision=precision,
    callbacks=[checkpoint_cb, early_stoppping_cb, custom_printing_cb],
)
trainer.fit(model, dm)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/yangyansheng/workspace/pyml/pytorch-tutorials/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8]

  | Name     | Type               | Params
------------------------------------------------
0 | fc1      | Linear             | 39.2 K
1 | fc2      | Linear             | 510   
2 | loss_fn  | CrossEntropyLoss   | 0     
3 | accuracy | MulticlassAccuracy | 0     
4 | f1_score | MulticlassF1Score  | 0     
------------------------------------------------
39.8 K    Trainable params
0         Non-trainable params
39.8 K    Total params
0.159     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training is starting


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


Training is ending


# Validation and Test

In [7]:
model = LightningNN.load_from_checkpoint(
    "./lightning_logs/version_0/checkpoints/last.ckpt"
)
trainer = pl.Trainer(
    accelerator=accelerator,
    num_nodes=1,
    devices=1,
    precision=precision,
)
trainer.validate(model, dm)
trainer.test(model, dm)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8]


Validation: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8]


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.16122230887413025,
  'test_accuracy': 0.9538999795913696,
  'test_f1_score': 0.9538999795913696}]

# Auto Resuming Trainer

```python
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    dirpath=config.checkpoint_path,
    save_top_k=1,
    mode="min",
    save_last=True,
)
resume_ckpt = os.path.join(config.checkpoint_path, "last.ckpt")
if not os.path.exists(resume_ckpt):
    resume_ckpt = None

trainer.fit(model, data_module, ckpt_path=resume_ckpt)
```