# Python U-Net 网络图像分割实验 - 以眼底 OCT 图像为例

> Copyright (c) 2023 Vincent SHI | 史文朔

本实验是基于 Python 语言的 U-Net 网络图像分割实验，以眼底 OCT 图像为例，实现了 U-Net 网络的训练、测试等功能，并以此对 Python 语言的基本语法、图像处理等知识进行了综合性的实践。

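本实验依赖眼底 OCT 图像数据集，建议同学们通过参加相关竞赛自行获取。数据需按如下目录放置（相对于程序运行目录）：`work/Train/Image`（训练原图）、`work/Train/Layer_Masks`（与原图同名的分割标签）、`work/Validation/Image`（验证集原图）；预测结果将写入 `work/Validation/Output_Marks`。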


## 实验目的

通过本实验，你必须要掌握：

- Python 语言的基本语法
- Python 语言的基本图像处理
  - 图像读取
  - 图像显示
  - 图像标注
  - 图像保存
- 诸如 U-Net 等深度学习网络的基本定义
- Python 机器学习生态库的基本调用

同时，建议你掌握：

- Python 语言的基本科学计算库
- Python 语言的基本深度学习库
- U-Net 等深度学习网络的基本训练方法
- U-Net 等深度学习网络的基本原理

你可以先不必掌握：

- Python 图像处理库的内部实现
- U-Net 等深度学习网络的数学理论基础


## 实验步骤

- 引用相关库

- 图片的读取和预处理

- 划分训练和测试集

- 定义 U-Net 网络

- 进行训练和测试

- 对验证集进行预测

- 输出预测结果


## 实验代码样例


In [None]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision.transforms import ToTensor
from sklearn.model_selection import train_test_split
from pathlib import Path
from typing import *

# Root for every data path below: the current working directory.
PATH = Path.cwd()

# ---- Experiment configuration ----
_WORK_DIR = PATH / "work"
IMAGE_PATH = _WORK_DIR / "Train" / "Image"  # training images
MASK_PATH = _WORK_DIR / "Train" / "Layer_Masks"  # masks, paired with images by file name
VAL_PATH = _WORK_DIR / "Validation" / "Image"  # validation-set images
OPT_PATH = _WORK_DIR / "Validation" / "Output_Marks"  # directory for predicted masks
FILE_TYPE = "*.png"  # glob pattern selecting image files
IMAGE_SIZE = 256  # images/masks are resized to IMAGE_SIZE x IMAGE_SIZE
TEST_RATIO = 0.2  # fraction of the labelled data held out as the test split
BATCH_SIZE = 8  # samples per batch
ITERS = 3000  # number of full passes (epochs) over the training loader
OPTIMIZER_TYPE = "adam"  # optimizer name used by main(), e.g. "adam", "sgd" or "rmsprop"
NUM_WORKERS = 4  # DataLoader worker processes
INIT_LR = 1e-3  # initial learning rate

# 定义数据集读取函数
FileList = List[Tuple[Path, Path]]


def init_filelist(image: Path, mask: Path, ftype: str) -> FileList:
    # 读取数据文件夹下所有文件名
    file_paths = image.glob(ftype)
    return [(p, mask / p.name) for p in file_paths]


# 定义数据集类
class Dataset(data.Dataset):
    def __init__(self, img_list: FileList, mode: str, crop=None, transform=None):
        # 断言，确保数据集模式为'train', 'test', 'val'之一
        assert mode in ["train", "test", "val"], "数据集模式必须为train, test, val之一"
        self.img_list = img_list
        if len(self.img_list) == 0:
            raise RuntimeError("找不到任何图像文件，请检查数据集路径")
        self.mode = mode
        self.transform = transform
        self.crop = crop

    def __getitem__(self, index: int):
        img_path, mask_path = self.img_list[index]
        file_name = img_path.name
        img = cv2.imread(str(img_path), 0)
        mask = cv2.imread(str(mask_path), 0)
        if self.transform is not None:
            img = self.transform(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)))
            mask = self.transform(cv2.resize(mask, (IMAGE_SIZE, IMAGE_SIZE)))
        return (img, mask), file_name

    def __len__(self):
        return len(self.img_list)


# 划分训练集和验证集
train_filelists, test_filelists = train_test_split(
    init_filelist(IMAGE_PATH, MASK_PATH, FILE_TYPE), test_size=TEST_RATIO
)

# 数据集加载
center_crop = None
input_transform = ToTensor()
train_set = Dataset(
    train_filelists, "train", crop=center_crop, transform=input_transform
)
train_loader = data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
test_set = Dataset(test_filelists, "test", crop=center_crop, transform=input_transform)
test_loader = data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)


# The ten building blocks ("actions") defined for the U-Net network.
# Conv <3x3, stride 2> with Batch Normalization and ReLU.
class Conv32BnRelu(nn.Module):
    """3x3 stride-2 convolution (halves H and W) -> BatchNorm -> in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.bn(features)
        return self.relu(normalized)


# ReLU activation as a standalone module.
class Relu(nn.Module):
    """In-place ReLU wrapped as a module."""

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        activated = self.relu(x)
        return activated


# Conv <3x3> with Batch Normalization and ReLU.
class Conv3BnRelu(nn.Module):
    """Size-preserving 3x3 convolution -> BatchNorm -> in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.bn(features)
        return self.relu(normalized)


# Conv <3x3> with Batch Normalization (no activation).
class Conv3Bn(nn.Module):
    """Size-preserving 3x3 convolution followed by BatchNorm."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        features = self.conv(x)
        return self.bn(features)


# 2x2 max pooling.
class MaxPool(nn.Module):
    """2x2 max pooling with stride 2 (halves H and W)."""

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        pooled = self.pool(x)
        return pooled


# Element-wise addition of two inputs.
class Add(nn.Module):
    """Return the element-wise sum ``x1 + x2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        summed = x1 + x2
        return summed


# ConvTranspose2d <2x2, stride 2> with Batch Normalization and ReLU.
class ConvTranspose2dBnRelu(nn.Module):
    """2x2 stride-2 transposed convolution (doubles H and W) -> BatchNorm -> in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        upsampled = self.conv_transpose(x)
        normalized = self.bn(upsampled)
        return self.relu(normalized)


# ConvTranspose2d <2x2, stride 2> with Batch Normalization (no activation).
class ConvTranspose2dBn(nn.Module):
    """2x2 stride-2 transposed convolution (doubles H and W) followed by BatchNorm."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        upsampled = self.conv_transpose(x)
        return self.bn(upsampled)


# 2x bilinear upsampling.
class Upsample(nn.Module):
    """Bilinear 2x upsampling with ``align_corners=True`` (corner pixels are preserved)."""

    def __init__(self):
        super().__init__()
        self.upsample = nn.Upsample(
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )

    def forward(self, x):
        enlarged = self.upsample(x)
        return enlarged


# Conv <1x1>: per-pixel linear projection across channels.
class Conv1(nn.Module):
    """1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


# U-Net: encoder/decoder with skip connections.
class UNet(nn.Module):
    """Standard U-Net producing per-pixel probabilities.

    Encoder: one stage per entry of ``widths``. Each stage is two
    ``3x3 conv -> BN -> ReLU`` layers, and consecutive stages are separated by
    2x2 max-pooling.

    Decoder: one stage per encoder stage except the deepest. Each stage does a
    2x2 stride-2 transposed convolution, concatenates the result with the
    matching encoder feature map, and applies two ``3x3 conv -> BN -> ReLU``
    layers.

    A final 1x1 convolution maps to ``out_channels``, and a sigmoid turns the
    result into values in [0, 1]. The output therefore has the same spatial
    size as the input, which makes it directly comparable with a 0.5
    threshold.

    Input height/width need not be divisible by ``2 ** (len(widths) - 1)``.
    When pooling floors an odd size, the upsampled map is resized to the skip
    connection's size before concatenation. Each side must still be at least
    ``2 ** (len(widths) - 1)`` pixels (16 with the default widths).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the predicted map.
        widths: feature channels of each encoder stage, shallow to deep.

    Raises:
        ValueError: if ``widths`` is empty.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        widths: Sequence[int] = (32, 64, 128, 256, 512),
    ):
        super().__init__()
        widths = tuple(widths)
        if not widths:
            raise ValueError("widths must contain at least one channel count")

        self.encoders = nn.ModuleList()
        channels = in_channels
        for width in widths:
            self.encoders.append(self._double_conv(channels, width))
            channels = width
        self.pool = nn.MaxPool2d(2)

        self.ups = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.ups.append(nn.ConvTranspose2d(channels, width, kernel_size=2, stride=2))
            # Input is the skip map (width) concatenated with the upsampled map (width).
            self.decoders.append(self._double_conv(2 * width, width))
            channels = width

        self.head = nn.Conv2d(channels, out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Two size-preserving ``3x3 conv -> BN -> ReLU`` layers."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Contracting path: keep every stage's output for the skip connections.
        skips = []
        for depth, encoder in enumerate(self.encoders):
            if depth > 0:
                x = self.pool(x)
            x = encoder(x)
            skips.append(x)

        # The deepest feature map is the bottleneck, not a skip connection.
        x = skips.pop()

        # Expansive path: upsample, align to the skip's size, concatenate, refine.
        for up, decoder in zip(self.ups, self.decoders):
            skip = skips.pop()
            x = up(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(
                    x, size=skip.shape[-2:], mode="bilinear", align_corners=False
                )
            x = decoder(torch.cat([skip, x], dim=1))

        return torch.sigmoid(self.head(x))
    
    
# Loss function.
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * sum(P * T) + s) / (sum(P) + sum(T) + s)``.

    ``pred`` and ``target`` are flattened over all elements, so the whole
    batch is scored as a single set. ``pred`` is used as-is and should already
    be in [0, 1] (e.g. sigmoid probabilities). ``smooth`` (``s``) keeps the
    ratio defined when both inputs are all zeros.

    Args:
        smooth: additive smoothing term; defaults to 1.0.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # reshape (unlike view) also works for non-contiguous tensors.
        pred_flat = pred.reshape(-1)
        target_flat = target.reshape(-1)
        intersection = (pred_flat * target_flat).sum()
        denominator = pred_flat.sum() + target_flat.sum()
        return 1 - (2.0 * intersection + self.smooth) / (denominator + self.smooth)
    

# Training loop.
def train(
    model: nn.Module,
    train_loader: data.DataLoader,
    test_loader: data.DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    iters: int,
    device: torch.device,
):
    """Train ``model`` for ``iters`` full passes (epochs) over ``train_loader``.

    Every 10th pass, starting with the first, the function:
    - prints the mean training loss of that pass,
    - saves the weights to ``model.pth``,
    - runs ``test`` on ``test_loader``.

    The final weights are saved to ``model.pth`` when training finishes.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    model.train()
    for i in range(iters):
        epoch_loss = 0.0
        n_batches = 0
        for (img, mask), _ in train_loader:
            img = img.to(device)
            mask = mask.to(device)
            optimizer.zero_grad()
            loss = criterion(model(img), mask)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            # Without batches there is no loss to report or optimise.
            raise RuntimeError("训练集为空，无法训练")
        if i % 10 == 0:
            # Report the epoch mean rather than just the last batch's loss.
            print(f"第{i}次迭代的损失为: {epoch_loss / n_batches}")
            torch.save(model.state_dict(), "model.pth")
            test(model, test_loader, criterion, device)
    torch.save(model.state_dict(), "model.pth")
    print("训练完成")


# Evaluation: compute the test loss and write binarised predictions to OPT_PATH.
def test(
    model: nn.Module,
    test_loader: data.DataLoader,
    criterion: nn.Module,
    device: torch.device,
):
    """Evaluate ``model`` on ``test_loader`` and save the predicted masks.

    Prints the mean ``criterion`` value over all batches. Each prediction is
    thresholded at 0.5, scaled to {0, 255} and written as an 8-bit image named
    after its source file in ``OPT_PATH``; the directory is created if it is
    missing. Images are written at the resolution of the loader's tensors.
    The model's train/eval mode is restored afterwards.
    """
    OPT_PATH.mkdir(parents=True, exist_ok=True)
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for (img, mask), file_names in test_loader:
                img = img.to(device)
                mask = mask.to(device)
                pred = model(img)
                total_loss += criterion(pred, mask).item()
                n_batches += 1
                # (B, 1, H, W) -> (B, H, W) uint8 in {0, 255}. cv2.imwrite cannot
                # encode int64 arrays and would treat a (1, H, W) array as a
                # 1-row image with W channels. Assumes a single output channel
                # (main builds UNet(1, 1)).
                binary = ((pred > 0.5).squeeze(1).to(torch.uint8) * 255).cpu().numpy()
                for name, out in zip(file_names, binary):
                    cv2.imwrite(str(OPT_PATH / name), out)
    finally:
        model.train(was_training)
    if n_batches:
        print(f"测试集损失为: {total_loss / n_batches}")


# Entry point.
def main():
    """Build the model, loss and optimizer, then run ``train``.

    ``OPTIMIZER_TYPE`` is matched case-insensitively against ``"adam"``,
    ``"sgd"`` and ``"rmsprop"``.

    Raises:
        ValueError: if ``OPTIMIZER_TYPE`` is not one of the supported names
            (previously any typo silently fell back to RMSprop).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(1, 1).to(device)
    criterion = DiceLoss()
    optimizers = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
        "rmsprop": torch.optim.RMSprop,
    }
    try:
        optimizer_cls = optimizers[OPTIMIZER_TYPE.lower()]
    except KeyError:
        raise ValueError(
            f"不支持的优化器类型: {OPTIMIZER_TYPE}，可选: {', '.join(optimizers)}"
        ) from None
    optimizer = optimizer_cls(model.parameters(), lr=INIT_LR)
    train(
        model,
        train_loader,
        test_loader,
        criterion,
        optimizer,
        ITERS,
        device,
    )


if __name__ == "__main__":
    main()
    

(tensor([[[0.0941, 0.0510, 0.0667,  ..., 0.0784, 0.0863, 0.0941],
          [0.0902, 0.0863, 0.0980,  ..., 0.0902, 0.0902, 0.0902],
          [0.0980, 0.0902, 0.1020,  ..., 0.0941, 0.0941, 0.0941],
          ...,
          [0.2588, 0.2431, 0.2471,  ..., 0.2471, 0.2549, 0.2510],
          [0.2196, 0.2235, 0.2039,  ..., 0.2353, 0.2314, 0.2196],
          [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039]]]),
 tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]))