In [1]:
# # 训练unet模型
# 1.搭建unet模型
# 2.自定义loss 函数
# 3.开始训练

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import glob
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

In [3]:
class SegmentDataset(Dataset):
    def __init__(self, where='train', seq=None):  # 根据需要是否做数据增强（seq）
        # 获取npy文件
        # 图片列表
        self.img_list = glob.glob('processed/{}/*/img_*'.format(where))
        # 数据增强的处理
        self.seq = seq

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        # 获取具体某一个数据
        img_file = self.img_list[idx]
        # 获取标注文件名
        mask_file = img_file.replace('img', 'label')

        # 加载
        img = np.load(img_file)
        mask = np.load(mask_file)

        # 数据增强
        if self.seq:
            segmap = SegmentationMapsOnImage(mask, mask.shape)
            img, mask = self.seq(image=img, segmentation_maps=segmap)
            # 获取数组内容
            mask = mask.get_arr()
        # 扩张维度变为张量
        return np.expand_dims(img, 0), np.expand_dims(mask, 0)

In [4]:
# 数据增强处理流程
seq = iaa.Sequential([
    iaa.Affine(scale=(0.8, 1.2),  # 缩放
               rotate=(-45, 45)),  # 旋转
    iaa.ElasticTransformation()  # 变换
])

In [5]:
# 使用dataloader加载
batch_size = 12
num_workers = 0

train_dataset = SegmentDataset('train', seq)
test_dataset = SegmentDataset('test', None)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [6]:
# 导入UNet模型
# 两次卷积操作
class ConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.step = torch.nn.Sequential(
            # 第一次卷积 (不改变大小，只改变输出通道数)
            torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1),
            #ReLU
            torch.nn.ReLU(),
            # 第二次卷积 (不改变大小，不改变输出通道数)
            torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1),
            #ReLU
            torch.nn.ReLU()
        )

    def forward(self, x):
        return self.step(x)

In [7]:
# 定义网络架构 (下采样：最大池化 上采样：双线性插值 特征融合：)
class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 定义左半部分
        self.layer1 = ConvBlock(1, 64)
        self.layer2 = ConvBlock(64, 128)
        self.layer3 = ConvBlock(128, 256)
        self.layer4 = ConvBlock(256, 512)

        # 定义右半部分
        self.layer5 = ConvBlock(256 + 512, 256)
        self.layer6 = ConvBlock(128 + 256, 128)
        self.layer7 = ConvBlock(64 + 128, 64)

        # 最后一个卷积
        self.layer8 = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)

        #池化
        self.Maxpool = torch.nn.MaxPool2d(kernel_size=2)
        # 上采样 -- scale_factor:放大倍数
        self.UpSample = torch.nn.Upsample(scale_factor=2, mode='bilinear')

        #sigmoid
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # 对输入数据 x进行处理 (下采样)
        # input:(1*256*256) output:(64*256*256)
        x1 = self.layer1(x)
        # 池化
        # input:(64*256*256) output:(64*128*128)
        x1_mp = self.Maxpool(x1)

        # input:(64*128*128) output:(128*128*128)
        x2 = self.layer2(x1_mp)
        # input:(128*128*128) output:(128*64*64)
        x2_mp = self.Maxpool(x2)

        # input:(128*64*64) output:(256*64*64)
        x3 = self.layer3(x2_mp)
        # input:(256*64*64) output:(256*32*32)
        x3_mp = self.Maxpool(x3)

        # input:(256*32*32) output:(512*32*32)
        x4 = self.layer4(x3_mp)

        # 上采样部分
        # input:(512*32*32) output:(512*64*64)
        x5 = self.UpSample(x4)
        # 特征拼接 x3 和 x5
        x5 = torch.cat([x5, x3], dim=1)  # 在通道维度上拼接   output:(768*64*64)
        # 卷积 intput:(768*64*64)  output:(256*64*64)
        x5 = self.layer5(x5)

        # intput:(256*64*64)  output:(256*128*128)
        x6 = self.UpSample(x5)
        # 拼接 在通道维度上拼接 output:(384,128,128)
        x6 = torch.cat([x6, x2], dim=1)
        # 卷积 intput:(384*128*128)  output:(128*128*128)
        x6 = self.layer6(x6)

        # intput:(128*128*128) output:(128*256*256)
        x7 = self.UpSample(x6)
        # 拼接 在通道维度上拼接 output:(64+ 128*256*256)
        x7 = torch.cat([x7, x1], dim=1)
        # 卷积 input:(192*256*256) output:(64*256*256)
        x7 = self.layer7(x7)

        # 最后一次卷积
        # input:(64*256*256) output:(1*256*256)
        x8 = self.layer8(x7)

        #sigmoid
        x9 = self.sigmoid(x8)

        return x9

In [8]:
# 测试模型
from torchsummary import summary

In [9]:
device = torch.device('cuda:0')

In [10]:
model = UNet().to(device)

In [11]:
summary(model, (1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
         ConvBlock-5         [-1, 64, 256, 256]               0
         MaxPool2d-6         [-1, 64, 128, 128]               0
            Conv2d-7        [-1, 128, 128, 128]          73,856
              ReLU-8        [-1, 128, 128, 128]               0
            Conv2d-9        [-1, 128, 128, 128]         147,584
             ReLU-10        [-1, 128, 128, 128]               0
        ConvBlock-11        [-1, 128, 128, 128]               0
        MaxPool2d-12          [-1, 128, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         295,168
             ReLU-14          [-1, 256,

In [12]:
# 模拟输入输出
random_input = torch.randn((1, 1, 256, 256)).to(device)  # 第一个1代表当前批次数

In [13]:
output = model(random_input)

In [14]:
output

tensor([[[[0.5272, 0.5278, 0.5274,  ..., 0.5283, 0.5276, 0.5276],
          [0.5258, 0.5260, 0.5273,  ..., 0.5282, 0.5286, 0.5284],
          [0.5259, 0.5280, 0.5285,  ..., 0.5283, 0.5267, 0.5285],
          ...,
          [0.5263, 0.5283, 0.5263,  ..., 0.5279, 0.5268, 0.5287],
          [0.5272, 0.5279, 0.5276,  ..., 0.5262, 0.5279, 0.5293],
          [0.5270, 0.5271, 0.5273,  ..., 0.5277, 0.5280, 0.5281]]]],
       device='cuda:0', grad_fn=<SigmoidBackward0>)

In [15]:
# 训练
# 定义Dice loss系数


In [16]:
class Diceloss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, mask):
        # 把数值拉成向量
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        # 交集面积 （两个向量相乘的结果大致为其相交面积）
        overlap = (pred * mask).sum()
        # 分母
        denum = pred.sum() + mask.sum() + 1e-8

        dice = (2 * overlap) / denum

        return 1 - dice


In [17]:
# loss
loss_fn = Diceloss()

In [18]:
# 优化去
optimizer = torch.optim.Adam(model.parameters(),lr = 0.0001)

In [19]:
# 动态减少学习率
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler =  ReduceLROnPlateau(optimizer,'min')

In [20]:
# 开始训练
from torch.utils.tensorboard import SummaryWriter

In [21]:
writer = SummaryWriter(log_dir='./log')

In [25]:
import time

In [26]:
# 计算测试集loss
def check_test_loss(loader,model):
    # 记录loss
    loss =0
    # 不记录梯度
    with torch.no_grad():
        # 遍历测试数据
        for i, (x,y) in enumerate(loader):
            # 获取图像
            x = x.to(device,dtype = torch.float32)
            # 获取标注
            y = y.to(device,dtype = torch.float32)

            # 获取预测值
            y_pred = model(x)
            loss_batch = loss_fn(y_pred,y)

            loss += loss_batch
    loss = loss / len(loader)
    return loss

In [27]:
EPOCH = 200
# 记录最小测试loss
best_test_loss = 100
for epoch in range(EPOCH):
    # 获取每一批次的图像信息

    # 计算整批数据的loss
    loss = 0
    #记录一个epoch运行的时间
    start_time = time.time()
    for i, (x,y) in enumerate(train_loader):
        #每次update更新梯度
        model.zero_grad()
        # 获取图像
        x = x.to(device,dtype = torch.float32)
        # 获取标注
        y = y.to(device,dtype = torch.float32)

        # 获取预测值
        y_pred = model(x)
        loss_batch = loss_fn(y_pred,y)

        #梯度
        loss_batch.backward()

        # 计算梯度
        optimizer.step()
        optimizer.zero_grad()

        # 获取每个batch的训练loss
        loss_batch = loss_batch.detach().to('cpu')
        print(loss_batch)
        loss += loss_batch

    # 计算loss
    loss = loss/(len(train_loader))  # 这是整个一轮的整批数据的loss

    # 降低lr，如果在连续10个epoch上损失都不再下降，则降低lr
    scheduler.step(loss)

    # 计算测试集loss
    test_loss = check_test_loss(test_loader,model)


    # 记录到tensorboard可视化
    writer.add_scalar("LOSS/train",loss,epoch)
    writer.add_scalar("LOSS/test",test_loss,epoch)

    # 保存最佳模型
    if best_test_loss > test_loss:
        best_test_loss = test_loss
        # 保存模型
        torch.save(model.state_dict(),'./saved_model/unet_best.pt')
        print("第{}个EPOCH达到最低loss".format(epoch))

    print("第{}个Epoch的训练时间{}s, TrainLoss:{}, TestLoss:{}".format(epoch,time.time() - start_time,loss,test_loss))

tensor(0.4848)
tensor(0.5506)
tensor(0.6036)
tensor(0.5632)
tensor(0.8292)
tensor(0.6965)
tensor(0.7852)
tensor(0.5116)
tensor(0.5295)
tensor(0.4348)
tensor(0.4554)
tensor(0.4753)
tensor(0.7221)
tensor(0.7945)
tensor(0.4631)
tensor(0.6561)
tensor(0.5972)
tensor(0.4419)
tensor(0.4936)
tensor(0.6998)
tensor(0.7411)
tensor(0.4426)
tensor(0.6162)
tensor(0.5234)
tensor(0.4936)
tensor(0.4858)
tensor(0.5425)
tensor(0.4645)
tensor(0.8117)
tensor(0.5412)
tensor(0.6749)
tensor(0.8836)
tensor(0.6422)
tensor(0.6120)
tensor(0.7668)
tensor(0.6734)
tensor(0.6571)
tensor(0.6429)
tensor(0.6776)
tensor(0.6695)
tensor(0.5692)
tensor(0.4809)
tensor(0.7584)
tensor(0.8257)
tensor(0.4827)
tensor(0.4992)
tensor(0.5700)
tensor(0.6650)
tensor(0.5146)
tensor(0.6895)
tensor(0.4380)
tensor(0.5415)
tensor(0.6955)
tensor(0.4438)
tensor(0.6065)
tensor(0.4873)
tensor(0.5547)
tensor(0.5084)
tensor(0.5525)
tensor(0.3543)
tensor(0.6670)
tensor(0.6543)
tensor(0.4019)
tensor(0.5134)
tensor(0.4351)
tensor(0.5240)
tensor(0.5

KeyboardInterrupt: 

In [23]:
len(train_loader)

161