# 2. 训练U-net模型

## 摘要

* 加载CT扫描图像，去除掩膜面积不超过25像素（等效半径约2.8像素）的样本
* 训练一个部分基于ResNet18预训练模型构建的U-net模型
* 测试模型

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Folder holding the preprocessed arrays, and folder where model weights are saved.
datafolder = 'processeddata'
weightsfolder = 'modelpths'

# Lung image slices and their matching nodule masks.
noduleimages = np.load(f"{datafolder}/noduleimages.npy")
nodulemasks = np.load(f"{datafolder}/nodulemasks.npy")


In [None]:
# Pick the compute device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Selecting GPU 0 is only valid when a GPU exists; calling set_device
    # unconditionally raises on CPU-only machines.
    torch.cuda.set_device(0)
# cuDNN benchmark mode; harmless to set when no GPU is in use.
torch.backends.cudnn.benchmark = True


In [None]:
# Plot the distribution of nodule areas (sum of each mask).
nodulesize = list(map(np.sum, nodulemasks))

# Only areas below 300 are drawn so the bulk of the histogram stays readable.
areas_below_300 = [area for area in nodulesize if area < 300]
plt.hist(areas_below_300, bins=50)
plt.xlabel("Area")
plt.ylabel("frequency")
plt.show()


In [None]:
# Keep only samples whose mask area is strictly greater than 25
# (an area of 25 corresponds to a radius of roughly 2.8).
filteredindicies = [idx for idx, area in enumerate(nodulesize) if area > 25]
noduleimages = noduleimages[filteredindicies]
nodulemasks = nodulemasks[filteredindicies]

In [None]:
# Display one HU image with an arrow drawn from the figure centre to a
# fixed data coordinate (presumably the nodule location -- verify).
arrow_style = {"arrowstyle": "->"}

plt.figure()
plt.imshow(noduleimages[42])  # sample at index 42
plt.annotate(
    '',
    xy=(317, 367),
    xycoords='data',
    xytext=(0.5, 0.5),
    textcoords='figure fraction',
    arrowprops=arrow_style,
)
# plt.savefig("images/test.png", dpi=300)
plt.show()

In [None]:
# Shape and dtype conversion.
# Insert a channel axis so every sample becomes (1, 512, 512); the reshape
# also fails loudly if a sample does not hold exactly 512*512 values.
noduleimages = noduleimages.reshape(noduleimages.shape[0], 1, 512, 512)
nodulemasks = nodulemasks.reshape(nodulemasks.shape[0], 1, 512, 512)

# Normalise negative zero to plain zero (-0.0 == 0 is True, so this matches both).
noduleimages[noduleimages == 0] = 0

# Binarise the masks: anything positive is foreground (1), the rest background (0).
nodulemasks[nodulemasks <= 0] = 0
nodulemasks[nodulemasks > 0] = 1

# 70/30 train/test split. The arrays are cast straight to float32 (the dtype the
# tensors end up with anyway) instead of going through float64, which halves the
# peak memory of the split.
# NOTE(review): the split is unseeded, so it differs between runs.
imagestrain, imagestest, maskstrain, maskstest = train_test_split(
    noduleimages.astype(np.float32),
    nodulemasks.astype(np.float32),
    test_size=0.3,
)

# Wrap the numpy arrays as float32 tensors (.float() is a no-op on float32 data).
imagestrain = torch.from_numpy(imagestrain).float()
maskstrain = torch.from_numpy(maskstrain).float()
imagestest = torch.from_numpy(imagestest).float()
maskstest = torch.from_numpy(maskstest).float()

# The full arrays are no longer needed once the split tensors exist.
del noduleimages, nodulemasks


In [None]:
# U-Net Model
import torchvision
from torchvision.models.resnet import ResNet18_Weights

class Decoder(nn.Module):
    """One decoder stage: upsample, concatenate a skip connection, convolve.

    Args:
        in_channels: channels of the tensor that gets upsampled.
        middle_channels: channels after concatenating the upsampled tensor
            with the skip tensor (input width of the 3x3 convolution).
        out_channels: channels produced by both the transposed convolution
            and the 3x3 convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # 2x2 transposed convolution with stride 2 doubles the spatial size.
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        fuse = nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1)
        self.conv_relu = nn.Sequential(fuse, nn.ReLU(inplace=True))

    def forward(self, x1, x2):
        """Upsample ``x1``, join it with skip tensor ``x2`` on the channel axis, and fuse."""
        upsampled = self.up(x1)
        merged = torch.cat((upsampled, x2), dim=1)
        return self.conv_relu(merged)

class UNet(nn.Module):
    """U-Net style segmentation network whose encoder reuses a pretrained ResNet-18.

    The network takes a single-channel image and returns a single-channel
    map squashed to (0, 1) by a sigmoid. The shape notes in ``forward`` are
    the original author's annotations and appear to assume 512x512 inputs.

    NOTE(review): ``base_model`` is kept as an attribute, so the full
    ResNet-18 (including children not used in ``forward``) stays part of
    this module's parameters and state_dict.
    """

    def __init__(self):
        super().__init__()
        self.n_class = 1

        # Pretrained backbone; its children are reused below as encoder stages.
        self.base_model = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)  # or .IMAGENET1K_V1
        self.base_layers = list(self.base_model.children())
        backbone = self.base_layers

        # Encoder. The first convolution is new (1 input channel instead of the
        # backbone's own first layer); children 1 and 2 are reused after it
        # (presumably the backbone's norm and activation -- verify).
        stem_conv = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.layer1 = nn.Sequential(stem_conv, backbone[1], backbone[2])
        self.layer2 = nn.Sequential(*backbone[3:5])
        self.layer3 = backbone[5]
        self.layer4 = backbone[6]
        self.layer5 = backbone[7]

        # Decoder stages: (channels in, channels after skip concat, channels out).
        self.decode4 = Decoder(512, 256 + 256, 256)
        self.decode3 = Decoder(256, 256 + 128, 256)
        self.decode2 = Decoder(256, 128 + 64, 128)
        self.decode1 = Decoder(128, 64 + 64, 64)

        # Final 2x bilinear upsampling followed by two bias-free 3x3 convolutions.
        self.decode0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
        )
        # 1x1 convolution down to the number of output classes.
        self.conv_last = nn.Conv2d(64, self.n_class, 1)

    def forward(self, input):
        """Run the encoder, the skip-connected decoder, and the sigmoid head."""
        # Encoder path (channels, height, width per the original annotations).
        enc1 = self.layer1(input)   # 64,256,256
        enc2 = self.layer2(enc1)    # 64,128,128
        enc3 = self.layer3(enc2)    # 128,64,64
        enc4 = self.layer4(enc3)    # 256,32,32
        bottleneck = self.layer5(enc4)  # 512,16,16

        # Decoder path; each stage consumes the matching encoder output.
        dec4 = self.decode4(bottleneck, enc4)  # 256,32,32
        dec3 = self.decode3(dec4, enc3)        # 256,64,64
        dec2 = self.decode2(dec3, enc2)        # 128,128,128
        dec1 = self.decode1(dec2, enc1)        # 64,256,256
        dec0 = self.decode0(dec1)              # 64,512,512
        logits = self.conv_last(dec0)          # 1,512,512

        return torch.sigmoid(logits)


In [None]:
from torch.utils.data import Dataset

class LIDCDataset(Dataset):
    """Dataset pairing each image with the label stored at the same index.

    Args:
        images: indexable collection of images.
        labels: indexable collection of labels, aligned with ``images``.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``."""
        return self.images[idx], self.labels[idx]
    

In [None]:
# Wrap the split tensors as datasets; the held-out split doubles as the validation set.
train_dataset = LIDCDataset(images=imagestrain, labels=maskstrain)
val_dataset = LIDCDataset(images=imagestest, labels=maskstest)

In [None]:
from torch.utils.data import DataLoader

# Batches of 32; only the training data is reshuffled each epoch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
)


In [None]:
# 评价指标

# def dice_coef(y_pred, y_true, smooth=1e-5):
#     multi = y_true * y_pred
#     intersection = torch.sum(multi)
#     multi = 0
#     sumyt = torch.sum(y_true)
#     sumyp = torch.sum(y_pred)
#     union = sumyt + sumyp + smooth
#     intersection = 2.0 * intersection + smooth
#     dice = intersection.float() / union.float()
#     return dice

# def dice_loss(y_pred, y_true, smooth=1e-5):
#     dice_loss = 1 - dice_coef(y_pred, y_true, smooth)
#     return dice_loss


In [None]:
# 评价指标

def dice_coef(pred, gt, smooth=1):
    """Return the smoothed Dice coefficient between ``pred`` and ``gt``.

    Both tensors are flattened and treated as one long vector, so the score
    is computed over the whole batch rather than averaged per sample.

    Args:
        pred: predicted mask tensor.
        gt: ground-truth mask tensor with the same number of elements.
        smooth: constant added to numerator and denominator; it keeps the
            ratio defined (equal to 1) when both masks are empty.

    Returns:
        A zero-dimensional tensor
        ``(2 * sum(pred * gt) + smooth) / (sum(pred) + sum(gt) + smooth)``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors such as
    # transposed or sliced inputs.
    pred_flat = pred.reshape(-1)
    gt_flat = gt.reshape(-1)

    intersection = (pred_flat * gt_flat).sum()

    return (2 * intersection + smooth) / (pred_flat.sum() + gt_flat.sum() + smooth)

def dice_loss(pred, gt, smooth=1):
    """Return ``1 - dice_coef(pred, gt, smooth)`` so that better overlap means lower loss."""
    overlap = dice_coef(pred, gt, smooth)
    return 1 - overlap


In [None]:
# Hyperparameters, model, optimizer, LR scheduler and metric history.
import os

max_lr = 0.0003                     # learning rate at the start of each cosine cycle
min_lr = 0.0001                     # learning rate floor reached at the end of a cycle
momentum = 0.9                      # only used by the SGD alternative kept below
weight_decay = 0.0                  # previously experimented with 0.0001
save_weight = weightsfolder + '/'
weight_name = 'unet'
num_epochs = 300
scheduler_step = num_epochs // 10   # epochs per cosine cycle / snapshot

# Create the weights folder up front so the first snapshot save cannot fail
# on a missing directory after many epochs of training.
os.makedirs(weightsfolder, exist_ok=True)

model = UNet()
model = model.to(device)  # GPU when available, otherwise CPU

# Define the optimizer and the scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# SGD alternative:
#optimizer = torch.optim.SGD(model.parameters(), lr=max_lr, momentum=momentum, weight_decay=weight_decay)
# Cosine annealing from max_lr down to min_lr over one snapshot cycle.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_step, eta_min=min_lr)

# Per-epoch metrics; the "accuracy" entries hold Dice coefficients.
history = {
    'train_loss': [],
    'val_loss': [],
    'train_accuracy': [],
    'val_accuracy': []
}


In [None]:
# Training loop

from tqdm.notebook import tqdm

import copy

# Anomaly detection is a debugging aid that adds overhead to every backward
# pass, so it is switched off for the full training run.
torch.autograd.set_detect_anomaly(False)

num_snapshot = 0
best_param = []
best_val_loss, best_loss = float('inf'), float('inf')  # Initialize best loss as infinity

for epoch in tqdm(range(num_epochs)):
    print(f'Epoch {epoch+1}/{num_epochs}')

    # Training
    model.train()
    print('Training...')
    train_loss_sum, train_dice_sum, train_count = 0.0, 0.0, 0
    for inputs, labels in tqdm(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass (call the module, not .forward, so hooks run).
        outputs = model(inputs)
        batch_loss = dice_loss(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Accumulate sample-weighted sums so the epoch metric covers every
        # batch instead of only the last one.
        n_samples = inputs.size(0)
        train_loss_sum += batch_loss.item() * n_samples
        train_dice_sum += dice_coef(outputs.detach(), labels).item() * n_samples
        train_count += n_samples
    loss = train_loss_sum / max(train_count, 1)
    accuracy = train_dice_sum / max(train_count, 1)

    # Validation
    model.eval()
    print('Validating...')
    val_loss_sum, val_dice_sum, val_count = 0.0, 0.0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(val_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            n_samples = inputs.size(0)
            val_loss_sum += dice_loss(outputs, labels).item() * n_samples
            val_dice_sum += dice_coef(outputs, labels).item() * n_samples
            val_count += n_samples
    val_loss = val_loss_sum / max(val_count, 1)
    val_accuracy = val_dice_sum / max(val_count, 1)

    print(f'Done.\nTraining Loss: {loss:.4f}, Validation Loss: {val_loss:.4f}\nTraining Accuracy: {accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}')

    lr_scheduler.step()  # Update learning rate
    print(f'Learning rate is updated to {lr_scheduler.get_last_lr()}')

    # Remember the weights when both the validation and training loss improved
    # within the current snapshot cycle.
    if val_loss < best_val_loss and loss < best_loss:
        print(f'Validation loss decreased from {best_val_loss} to {val_loss}.')
        print(f'Training loss decreased from {best_loss} to {loss}.')
        print('Update best parameters.\n')
        best_val_loss = val_loss
        best_loss = loss
        # state_dict() returns references to the live parameters; deep-copy it
        # so later optimizer steps cannot overwrite the remembered weights.
        best_param = copy.deepcopy(model.state_dict())
        # torch.save(best_param, save_weight + weight_name + '_' + str(epoch) + '.pth')

    # End of a cosine cycle: write the snapshot, then restart the optimizer,
    # the scheduler and the best-so-far tracking for the next cycle.
    if (epoch + 1) % scheduler_step == 0:
        print('Num_snapshot reached, Save the model and reset the optimizer.\n')
        torch.save(best_param, save_weight + weight_name + str(num_snapshot) + '.pth')
        optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_step, min_lr)
        num_snapshot = num_snapshot + 1
        best_val_loss = float('inf')
        best_loss = float('inf')
        best_param = []

    history['train_loss'].append(loss)
    history['val_loss'].append(val_loss)
    history['train_accuracy'].append(accuracy)
    history['val_accuracy'].append(val_accuracy)


In [None]:
# Plot the per-epoch Dice coefficient curves; history stores them under
# the 'train_accuracy' / 'val_accuracy' keys.
plt.plot(history['train_accuracy'], color='b', label="Train")
plt.plot(history['val_accuracy'], color='g', label="Test")
plt.xlabel("Epoch")
plt.ylabel("dice_coef")
plt.legend()
plt.show()

In [None]:
# 保存历史数据

import pickle

# Persist the metric history to disk.
with open('history.pkl', mode='wb') as history_file:
    pickle.dump(history, history_file)

# Last recorded train / validation values (shown as the cell output).
history['train_accuracy'][-1], history['val_accuracy'][-1]


In [None]:
# Load the most recently saved snapshot (index num_snapshot - 1).
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU also loads on a CPU-only machine.
best_param = torch.load(save_weight + weight_name + str(num_snapshot-1) + '.pth', map_location=device)
model.load_state_dict(best_param)


In [None]:
# Evaluate the model on the whole held-out set.
model.eval()
with torch.no_grad():
    inputs = imagestest
    labels = maskstest
    # The test tensors live on the CPU while the model sits on `device`, so
    # each chunk is moved to the device for the forward pass and the
    # prediction is brought back; chunking also avoids pushing the entire
    # test set through the network in a single pass.
    outputs = torch.cat([model(chunk.to(device)).cpu() for chunk in inputs.split(32)])
    accuracy = dice_coef(outputs, labels)
    loss = 1 - accuracy

print(f"Loss: {loss.item()}, Accuracy: {accuracy.item()}")


In [None]:
# Fraction of test samples whose predicted mask overlaps the ground truth.

model.eval()
num_test = imagestest.shape[0]
imgs_mask_test = np.empty([num_test, 1, 512, 512], dtype=np.float32)
# Inference only: no_grad avoids building an autograd graph per sample.
with torch.no_grad():
    for i in range(num_test):
        inputs = imagestest[i:i+1].to(device)
        outputs = model(inputs)
        imgs_mask_test[i] = outputs.cpu().numpy()[0]

# Per-sample overlap: sum of (ground-truth mask * predicted probability),
# computed purely in numpy instead of mixing torch tensors with numpy arrays.
masks_np = maskstest.numpy()
sumoverlap = [float(np.sum(masks_np[i, 0] * imgs_mask_test[i, 0])) for i in range(num_test)]

# A sample counts as overlapping when its summed overlap exceeds 1.
overlap_ratio = sum(ov > 1 for ov in sumoverlap) / len(sumoverlap)
overlap_ratio


In [None]:
# Show the predicted mask, the ground-truth mask and the input image
# for one test sample, each as its own grayscale figure.

index = 5
panels = (
    ("Predicted", imgs_mask_test[index, 0]),
    ("Ground Truth", maskstest[index, 0]),
    ("Image", imagestest[index, 0]),
)
for caption, panel in panels:
    print(caption)
    plt.imshow(panel, cmap="gray")
    plt.show()