# Tutorial: PyTorch Lightning (Deep Learning Framework) and MONAI (Medical Image Analysis Framework)

### PyTorch Lightning

<img alt="Lightning" src="https://pl-public-data.s3.amazonaws.com/assets_lightning/LightningColor.png" width="400px" style="max-width: 50%;">

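PyTorch Lightning is a lightweight, high-level interface built on PyTorch. It offers a simple way to organize PyTorch code so that it becomes more modular, readable, and maintainable. Its design goal is to let researchers focus on designing models rather than on implementing the training procedure.

Its core idea is to split the training workflow into five parts: `LightningModule`, `LightningDataModule`, `Trainer`, callbacks (`Callback`), and loggers (`Logger`).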

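- `LightningModule` is the core of the model: it holds the model definition, the forward pass, the loss function, and the optimizer configuration;
- `LightningDataModule` is the core of the data pipeline: it handles loading, preprocessing, and splitting the data;
- `Trainer` drives the training process: it holds the training-loop configuration, such as the number of epochs, the devices, and the numerical precision. The optimizer and learning-rate schedule themselves are defined in the `LightningModule` (in `configure_optimizers`);
- callbacks are hooks into the training process, used for tasks such as saving checkpoints, early stopping, and monitoring the learning rate;
- loggers record what happens during training, such as metric values, so that the run can be inspected and visualized later.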

### MONAI

<img alt="MONAI" src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" width="400px" style="max-width: 50%;">

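MONAI (Medical Open Network for AI) is an open-source framework for medical image analysis. It aims to provide an end-to-end deep learning framework for medical imaging in order to accelerate AI research and applications in this field.

MONAI provides tools for data loading, preprocessing, dataset splitting, model definition, training, evaluation, and inference. All of them are implemented on top of PyTorch, so they combine easily with regular PyTorch code.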


## PyTorch Lightning

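Starting with version 2.0, PyTorch Lightning is published under the name Lightning (the `lightning` package), and its documentation moved from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) to [Lightning](https://lightning.ai/docs/pytorch/latest/). To avoid confusion, the rest of this tutorial refers to PyTorch Lightning simply as Lightning.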

### 安装

Installing Lightning is straightforward.

To install with pip, run:

```bash
pip install lightning
```

To install with conda, run:

```bash
conda install lightning -c conda-forge
```

After installation, check the installed version with:

```bash
lightning --version
```

In [1]:
!lightning --version

lightning, version 2.0.3



### 从 PyTorch 迁移到 Lightning

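Lightning is designed to let researchers focus on model design rather than on the mechanics of training. As a result, adopting it only requires modest changes to existing PyTorch code.

Below, we train a simple MNIST classifier in plain PyTorch and then show how to migrate it to Lightning.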

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

# Set the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the neural network architecture
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # Flatten the input tensor
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Define the training parameters
batch_size = 64
learning_rate = 0.01
num_epochs = 10
validation_split = 0.2

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root=os.getcwd(), train=True, transform=transform, download=True)

# Split the dataset into training and validation sets
train_size = int((1 - validation_split) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

test_dataset = datasets.MNIST(root=os.getcwd(), train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model and move it to the device
model = Net().to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)

            # Update the progress bar
            pbar.update(1)
            pbar.set_postfix({"Train Loss": train_loss / ((pbar.n - 1) * train_loader.batch_size + images.size(0))})

    # Validation
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for val_images, val_labels in val_loader:
            val_images, val_labels = val_images.to(device), val_labels.to(device)

            val_outputs = model(val_images)
            val_loss += criterion(val_outputs, val_labels).item()

            val_predicted = val_outputs.argmax(dim=1)  # predicted class per sample
            val_total += val_labels.size(0)
            val_correct += (val_predicted == val_labels).sum().item()

    val_accuracy = 100 * val_correct / val_total
    val_loss /= len(val_loader)

    # Print validation metrics
    tqdm.write(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

# Test the model
model.eval()
total_correct = 0
total_samples = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # predicted class per sample

        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

# Print the accuracy on the test set
accuracy = 100 * total_correct / total_samples
print(f"Test Accuracy: {accuracy:.2f}%")

# Save the model checkpoint
torch.save(model.state_dict(), "mnist.ckpt")

# Load the model checkpoint
model.load_state_dict(torch.load("mnist.ckpt", map_location=device))


Epoch 1/10: 100%|██████████| 750/750 [00:05<00:00, 133.99batch/s, Train Loss=1.21]


Validation Loss: 0.5129, Validation Accuracy: 85.13%


Epoch 2/10: 100%|██████████| 750/750 [00:05<00:00, 136.30batch/s, Train Loss=0.425]


Validation Loss: 0.3663, Validation Accuracy: 88.98%


Epoch 3/10: 100%|██████████| 750/750 [00:05<00:00, 135.68batch/s, Train Loss=0.351]


Validation Loss: 0.3373, Validation Accuracy: 89.54%


Epoch 4/10: 100%|██████████| 750/750 [00:05<00:00, 135.83batch/s, Train Loss=0.318]


Validation Loss: 0.2997, Validation Accuracy: 90.92%


Epoch 5/10: 100%|██████████| 750/750 [00:05<00:00, 135.64batch/s, Train Loss=0.296]


Validation Loss: 0.2806, Validation Accuracy: 91.36%


Epoch 6/10: 100%|██████████| 750/750 [00:05<00:00, 133.13batch/s, Train Loss=0.277]


Validation Loss: 0.2771, Validation Accuracy: 91.31%


Epoch 7/10: 100%|██████████| 750/750 [00:05<00:00, 133.94batch/s, Train Loss=0.259]


Validation Loss: 0.2523, Validation Accuracy: 92.33%


Epoch 8/10: 100%|██████████| 750/750 [00:05<00:00, 135.96batch/s, Train Loss=0.243]


Validation Loss: 0.2302, Validation Accuracy: 92.88%


Epoch 9/10: 100%|██████████| 750/750 [00:05<00:00, 134.25batch/s, Train Loss=0.226]


Validation Loss: 0.2173, Validation Accuracy: 93.35%


Epoch 10/10: 100%|██████████| 750/750 [00:05<00:00, 135.07batch/s, Train Loss=0.211]


Validation Loss: 0.2018, Validation Accuracy: 93.82%
Test Accuracy: 94.11%


<All keys matched successfully>
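
The `<All keys matched successfully>` line is the return value of `load_state_dict`. The round trip can be tried in isolation with the following sketch, which uses a single linear layer and a temporary file (both chosen for illustration) instead of the trained network and `mnist.ckpt`.

```python
import os
import tempfile

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(784, 10).to(device)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pt")
    torch.save(model.state_dict(), path)

    # map_location remaps tensors saved on one device onto the current one,
    # so a checkpoint written on a GPU machine also loads on a CPU-only one.
    state_dict = torch.load(path, map_location=device)

restored = nn.Linear(784, 10).to(device)
result = restored.load_state_dict(state_dict)
print(result)  # <All keys matched successfully>

weights_match = torch.equal(model.weight, restored.weight)
print("weights identical:", weights_match)
```

Only the parameters are stored this way. The model class itself must still be available in code when the weights are loaded.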

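This PyTorch code implements a neural network for handwritten digit recognition. It consists of the following parts:

1. **Import the required libraries**: `os`, `torch` and its submodules, the torchvision dataset utilities, and the tqdm progress bar.
2. **Set the device**: choose CPU or GPU for computation, depending on the available hardware.
3. **Define the network architecture**: a model with three fully connected layers, an input size of 784, and an output size of 10.
4. **Define the training parameters**: batch size, learning rate, number of epochs, and the fraction of data held out for validation.
5. **Load the MNIST dataset**: load MNIST through the `datasets` module of torchvision and apply the preprocessing transforms.
6. **Split the dataset into training and validation sets**: use `random_split` to divide the training data.
7. **Create the data loaders**: use `DataLoader` to build loaders for the training, validation, and test sets so that data is served in batches.
8. **Initialize the model and move it to the device**: create a model instance and move it to the device selected earlier.
9. **Define the loss function and optimizer**: cross-entropy loss and stochastic gradient descent (SGD).
10. **Training loop**: in every epoch, iterate over the training batches, compute the outputs, the loss, and the gradients, then update the model parameters.
11. **Validation**: after every epoch, evaluate the model on the validation set and compute the validation loss and accuracy.
12. **Test the model**: evaluate the model on the test set and compute the test accuracy.
13. **Save and load the model**: save the trained weights to a file and load them back.

In short, the code covers training, validation, and testing of a PyTorch handwritten digit classifier, plus saving and loading its weights.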
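
Steps 9 and 10 are the heart of the script. The sketch below isolates a single optimization step on a random batch with MNIST-like shapes, so it runs without downloading anything. The one-layer model is a deliberately simplified stand-in for `Net`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Simplified stand-in for Net: flatten the image, then one linear layer.
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Random batch with MNIST shapes: 64 grayscale 28x28 images, labels in 0..9.
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))

weights_before = model[1].weight.detach().clone()

# Forward pass, loss, backward pass, parameter update.
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Accuracy on this batch, computed the same way as in the validation loop.
predicted = outputs.argmax(dim=1)
batch_accuracy = (predicted == labels).float().mean().item()
weights_changed = not torch.equal(weights_before, model[1].weight)

print(f"loss={loss.item():.4f}, accuracy={batch_accuracy:.2%}, weights changed: {weights_changed}")
```

Everything else in the script (the epoch loop, device transfers, progress bar, and metric bookkeeping) is scaffolding around these few lines, and that scaffolding is what Lightning takes over.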

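As the PyTorch code above shows, even a simple MNIST classifier requires a lot of code, and with data handling, model, and training loop interleaved in one script, readability and maintainability suffer. Let us now refactor this model with Lightning.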

### LightningDataModule

First, we wrap data loading, preprocessing, and splitting in a `LightningDataModule`, which makes these steps modular, readable, and easy to maintain.

In [3]:
import os
import torch
import torchvision
from torchvision import datasets, transforms
import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = os.getcwd(), batch_size: int = 64, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Shuffle the training data every epoch, as in the plain PyTorch version
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

This separates data loading, preprocessing, and splitting from the model, which keeps the model definition concise.

### LightningModule

Next, we wrap the model definition, forward pass, loss function, and optimizer in a `LightningModule`, so that everything describing the model and its optimization lives in one class.

In [4]:
import lightning.pytorch as pl
import torch.nn as nn
import torch


# Define the neural network architecture
class MNISTModule(pl.LightningModule):
    def __init__(self, learning_rate=0.01):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.loss = nn.CrossEntropyLoss()
        self.lr = learning_rate

    def forward(self, x):
        x = x.view(-1, 784)  # Flatten the input tensor
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        predicted = logits.argmax(dim=1)  # predicted class per sample
        correct = (predicted == y).sum().item()
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('test_acc', correct/len(y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {'test_loss': loss, 'correct': correct, 'total': len(y)}


### Trainer

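With the `LightningDataModule` and the `LightningModule` defined, we can train the model with a `Trainer`.

`Trainer` is the class Lightning provides for running training. It encapsulates many details of the training loop, such as device placement, the epoch and batch loops, and checkpointing, which makes training much simpler to set up.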

In [5]:
# Initialize data module and model
data_module = MNISTDataModule()
model = MNISTModule()

callbacks = [
    pl.callbacks.EarlyStopping(monitor='val_loss'),
    pl.callbacks.ModelCheckpoint(monitor='val_loss')
]

# Initialize trainer
trainer = pl.Trainer(
    max_epochs=20,
    precision="16-mixed",  # mixed precision; on CPU Lightning falls back to bfloat16
    callbacks=callbacks,
)

# from lightning.pytorch.tuner import Tuner
# tuner = Tuner(trainer)

# # tuner.scale_batch_size(model, data_module)
# tuner.lr_find(model, data_module)

# Fit the model
trainer.fit(model, data_module)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type             | Params
------------------------------------------
0 | fc1  | Linear           | 100 K 
1 | fc2  | Linear           | 8.3 K 
2 | fc3  | Linear           | 650   
3 | loss | CrossEntropyLoss | 0     
------------------------------------------
109 K     Trainable params
0         Non-trainable params
109 K     Total params
0.438     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [6]:

# Test the model
trainer.test(model, data_module)

Testing: 0it [00:00, ?it/s]

[{'test_loss_epoch': 0.0872071385383606, 'test_acc_epoch': 0.9739000201225281}]
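
`trainer.test` returns a list with one dictionary of logged metrics per test dataloader. The short sketch below formats the result shown above. The dictionary literal is copied from that output rather than produced by a new run.

```python
# Literal copied from the output above (one dict per test dataloader).
results = [{"test_loss_epoch": 0.0872071385383606, "test_acc_epoch": 0.9739000201225281}]

metrics = results[0]
summary = (
    f"Test loss: {metrics['test_loss_epoch']:.4f}, "
    f"test accuracy: {metrics['test_acc_epoch']:.2%}"
)
print(summary)  # Test loss: 0.0872, test accuracy: 97.39%
```

The `_epoch` suffix appears because `test_step` logs with both `on_step=True` and `on_epoch=True`, so Lightning distinguishes the per-step values from the epoch-level aggregate.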