In [1]:
import os
import cv2
import torch
import numpy as np
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from diceLoss import dice_loss
device = "cpu" if not torch.cuda.is_available() else "cuda"

数据预处理

In [3]:
# #查看label详情

# # 读取灰度图像
# image = cv2.imread('pancreas\\mask\\1c7493217d62.png', cv2.IMREAD_GRAYSCALE)

# # 将图像转换为 NumPy 数组
# image_array = np.array(image)

# # 打印图像矩阵的形状
# print(f'Image shape: {image_array.shape}')

# # 打印图像矩阵
# print(image_array)


In [2]:
#load images，resize images，and split as train and test set

def load_image_and_resize(folder):
    images = []
    for filename in os.listdir(folder):
        filepath = os.path.join(folder, filename)
        if os.path.isfile(filepath) and filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                if "mask" in folder:
                    img = Image.open(filepath).convert('L').resize((320, 320))
                    images.append(img)
                else:
                    img = Image.open(filepath).resize((320,320))
                    images.append(img)
            except IOError:
                print(f"Unable to open image file: {filename}")
    return images

def load_image_and_resize_origin_mask(folder):
    images = []
    for filename in os.listdir(folder):
        filepath = os.path.join(folder, filename)
        if os.path.isfile(filepath) and filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                img = Image.open(filepath).resize((320,320))
                images.append(img)
            except IOError:
                print(f"Unable to open image file: {filename}")
    return images
Image_for_pancreas = load_image_and_resize("pancreas\\image")
Mask_for_pancreas = load_image_and_resize("pancreas\\mask")
Image_for_stomach = load_image_and_resize("stomach\\image")
Mask_for_stomach = load_image_and_resize("stomach\\mask")
image_train_pancreas, image_test_pancreas, mask_train_pancreas, mask_test_pancreas = train_test_split(Image_for_pancreas, Mask_for_pancreas, test_size=0.3333, random_state=42)
image_train_stomach, image_test_stomach, mask_train_stomach, mask_test_stomach = train_test_split(Image_for_stomach, Mask_for_stomach, test_size=0.3333, random_state=42)
image_train = image_train_pancreas + image_train_stomach
image_test = image_test_pancreas + image_test_stomach
mask_train = mask_train_pancreas + mask_train_stomach
mask_test = mask_test_pancreas + mask_test_stomach


verify_Image_colorectum = load_image_and_resize("colorectum\\image")
verify_mask_colorectum = load_image_and_resize("colorectum\\mask")



In [3]:
#Dateset and Dataloader
class CustomDataset(Dataset):
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        # 如果调用时没有指定任何转换，我们将使用默认的转换
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),  # 将PIL图像或NumPy ndarray转换为tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化处理
        ])
        self.mask_transform = transforms.Compose([
            transforms.ToTensor()  # 通常情况下，遮罩只需要转换为tensor，不需要归一化
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        mask = self.masks[index]

        if self.transform:
            image = self.transform(image).to(device)
        if self.mask_transform:
            mask = (self.mask_transform(mask).to(device) > 0).float()

        return image, mask

    
# 创建训练集和测试集的 Dataset 对象
train_dataset = CustomDataset(image_train, mask_train)
test_dataset = CustomDataset(image_test, mask_test)
verify_dataset = CustomDataset(verify_Image_colorectum, verify_mask_colorectum)

# 可以使用 DataLoader 加载 Dataset 对象
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
verification_loader = DataLoader(verify_dataset, batch_size=1,shuffle=False)


In [4]:
# 定义模型
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import unetConv2
from init_weights import init_weights


class UNet3Plus(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4,
                 is_deconv=True, is_batchnorm=True):
        super(UNet3Plus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm
        filters = [64, 128, 256, 512, 1024]

        ## -------------Encoder--------------
        self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu4d_1 = nn.ReLU(inplace=True)

        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd4->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd4->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu3d_1 = nn.ReLU(inplace=True)

        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu2d_1 = nn.ReLU(inplace=True)

        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14
        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu1d_1 = nn.ReLU(inplace=True)

        # output
        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')


    def forward(self, inputs):
        ## -------------Encoder-------------
        h1 = self.conv1(inputs)  # h1->320*320*64

        h2 = self.maxpool1(h1)
        h2 = self.conv2(h2)  # h2->160*160*128

        h3 = self.maxpool2(h2)
        h3 = self.conv3(h3)  # h3->80*80*256

        h4 = self.maxpool3(h3)
        h4 = self.conv4(h4)  # h4->40*40*512

        h5 = self.maxpool4(h4)
        hd5 = self.conv5(h5)  # h5->20*20*1024

        ## -------------Decoder-------------
        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels

        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels

        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels

        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels

        d1 = self.outconv1(hd1)  # d1->320*320*n_classes
        return F.sigmoid(d1)


In [8]:
# #define criterion
# # 创建 Dice Loss 函数
# def dice_loss(pred, target, smooth=1e-2):
#     """
#     Compute the DICE loss, which is 1 - Dice coefficient.
#     Args:
#         pred (tensor): the model's output, raw logits that have not been normalized.
#         target (tensor): the ground truth labels.
#         smooth (float): a smoothing constant to avoid division by zero.

#     Returns:
#         float: dice loss.
#     """
#     intersection = (pred * target).sum(dim=(1, 2, 3))  # 计算每个样本的交集
#     union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))  # 计算每个样本的并集

#     dice = (2. * intersection + smooth) / (union + smooth)  # 计算Dice系数
#     dice_loss = 1 - dice  # 计算Dice损失
#     return dice_loss.mean()  # 返回批量的平均Dice损失



def train(model, train_loader, optimizer, criterion, epoch):
    model.train()  # 设置模型为训练模式
    running_loss = 0.0

    for data, target in train_loader:
        optimizer.zero_grad()  # 清除之前的梯度

        output = model(data)  # 前向传播，获取模型输出

        loss = criterion(output, target)  # 计算损失

        loss.backward()  # 反向传播，计算梯度
        optimizer.step()  # 更新模型参数

        running_loss += loss.item() * data.size(0)  # 累积损失

    # 计算平均损失
    running_loss /= len(train_loader.dataset)

    print(f'Epoch: {epoch + 1}, Training Loss: {running_loss:.4f}')
    
def iou(input, target, eps=1e-6):
    """
    Compute the Intersection over Union (IOU) for two float tensors.

    Args:
        input (torch.Tensor): Predicted tensor.
        target (torch.Tensor): Ground truth tensor.
        eps (float): A small constant to avoid division by zero.

    Returns:
        float: IOU.
    """
    input_flat = input.view(-1)
    target_flat = target.view(-1)

    intersection = torch.sum(input_flat * target_flat)
    union = torch.sum(input_flat) + torch.sum(target_flat) - intersection

    iou = (intersection + eps) / (union + eps)
    
    return iou.item()

def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    correct = 0
    total_iou = 0.0

    with torch.no_grad():
        for data, target in test_loader:

            output = model(data)
            loss = criterion(output, target)
            
            test_loss += loss * data.size(0)
            # 计算准确率
            pred = (output > 0.5).float().squeeze()
            target = target.squeeze()
            correct += (pred.type(torch.int8) == target.type(torch.int8)).sum().item()


            total_iou += iou(output, target)

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / (len(test_loader.dataset)*320*320)
    avg_iou = total_iou / len(test_loader)

    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)*320*320} ({accuracy:.2f}%), Avg IoU: {avg_iou:.4f}')



In [10]:
# 模型初始化

model = UNet3Plus().to(device)

# 优化器配置
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 损失函数配置
criterion = nn.BCELoss()

# 训练周期数
epochs = 300

# 执行训练
for epoch in range(epochs):
    train(model, train_loader, optimizer, criterion, epoch)
    if (epoch + 1) % 10 == 0:
        if os.path.exists('./models') is False:
            os.makedirs('./models')
        model_path = os.path.join("./models", f'model_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, model_path)
        print(f'Model saved to {model_path} at epoch {epoch+1}')


Epoch: 1, Training Loss: 0.7113
Epoch: 2, Training Loss: 0.5859
Epoch: 3, Training Loss: 0.5939
Epoch: 4, Training Loss: 0.5808
Epoch: 5, Training Loss: 0.6367
Epoch: 6, Training Loss: 0.5881
Epoch: 7, Training Loss: 0.5777
Epoch: 8, Training Loss: 0.6074
Epoch: 9, Training Loss: 0.5630
Epoch: 10, Training Loss: 0.5516
Model saved to ./models\model_epoch_10.pth at epoch 10
Epoch: 11, Training Loss: 0.5509
Epoch: 12, Training Loss: 0.5444
Epoch: 13, Training Loss: 0.5474
Epoch: 14, Training Loss: 0.5433
Epoch: 15, Training Loss: 0.5170
Epoch: 16, Training Loss: 0.5420
Epoch: 17, Training Loss: 0.5201
Epoch: 18, Training Loss: 0.5247
Epoch: 19, Training Loss: 0.5354
Epoch: 20, Training Loss: 0.5334
Model saved to ./models\model_epoch_20.pth at epoch 20
Epoch: 21, Training Loss: 0.5210
Epoch: 22, Training Loss: 0.5075
Epoch: 23, Training Loss: 0.4844
Epoch: 24, Training Loss: 0.5434
Epoch: 25, Training Loss: 0.4940
Epoch: 26, Training Loss: 0.5193
Epoch: 27, Training Loss: 0.5014
Epoch: 2

In [9]:

for model_name in os.listdir('models'):
    if model_name.endswith('.pth'):
        checkpoint_path = os.path.join('models', model_name)
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        test(model, test_loader, dice_loss)
        print('Model {} tested.'.format(model_name))





Test set: Average loss: 0.6364, Accuracy: 3215370/4096000 (78.50%), Avg IoU: 0.2536
Model model_epoch_10.pth tested.
Test set: Average loss: 0.5069, Accuracy: 3160247/4096000 (77.15%), Avg IoU: 0.3941
Model model_epoch_100.pth tested.
Test set: Average loss: 0.6459, Accuracy: 3155326/4096000 (77.03%), Avg IoU: 0.2441
Model model_epoch_20.pth tested.
Test set: Average loss: 0.5822, Accuracy: 3021929/4096000 (73.78%), Avg IoU: 0.3102
Model model_epoch_30.pth tested.
Test set: Average loss: 0.5643, Accuracy: 3028421/4096000 (73.94%), Avg IoU: 0.3359
Model model_epoch_40.pth tested.
Test set: Average loss: 0.5698, Accuracy: 2687927/4096000 (65.62%), Avg IoU: 0.3277
Model model_epoch_50.pth tested.
Test set: Average loss: 0.6359, Accuracy: 3142163/4096000 (76.71%), Avg IoU: 0.2664
Model model_epoch_60.pth tested.
Test set: Average loss: 0.6743, Accuracy: 2892368/4096000 (70.61%), Avg IoU: 0.2415
Model model_epoch_70.pth tested.
Test set: Average loss: 0.6089, Accuracy: 3100031/4096000 (75.6