<a href="https://colab.research.google.com/github/Tokisaki-Galaxy/PterygiumSeg/blob/master/work2_basemode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 翼状胬肉区域分割模型

这是项目的第二个任务：实现对眼部裂隙灯检查图片中翼状胬肉区域的精准分割。我们将使用U-Net模型解决这一问题。

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import pandas as pd
import os
import cv2
import numpy as np
from PIL import Image
import zipfile
import sys
import platform
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.font_manager
%matplotlib inline

# Choose the DataLoader worker count per OS; on non-Windows systems also register a Chinese font for plots.
if platform.system() == "Windows":
    # Windows: DataLoader runs in the main process (no worker processes).
    num_workers = 0
    print(f"检测到 Windows 系统，将 DataLoader 的 num_workers 设置为 {num_workers}。")
else:
    # Non-Windows systems (e.g. Linux / Colab)
    num_workers = 4
    print(f"检测到非 Windows 系统 ({platform.system()})，将 DataLoader 的 num_workers 设置为 {num_workers}。")
    # Set up a Chinese font: download SimHei once, then register it with matplotlib.
    # NOTE: the `!wget` line is an IPython shell escape and only works inside Jupyter/Colab/Kaggle.
    if not os.path.exists('simhei.ttf'):
        !wget -O simhei.ttf "https://www.wfonts.com/download/data/2014/06/01/simhei/chinese.simhei.ttf"
    matplotlib.font_manager.fontManager.addfont('simhei.ttf')
    matplotlib.rc('font', family='SimHei')
# Use SimHei for CJK text in every figure, and draw minus signs with the ASCII hyphen.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# ================== Training data paths =================
# Local training data directory
image_dir =          r"f:/train"
# Colab: training archive on Google Drive and the directory it is extracted into
colab_zip_path = "/content/drive/My Drive/train.zip"
colab_extract_path = "/content/trains/"
# Kaggle: dataset input directory, and the working directory used for files the notebook writes
kaggle_extract_path = "/kaggle/input/pterygium/train/"
kaggle_temp_path = "/kaggle/working/"

# =================== Validation set paths =================
# Local validation image directory
val_image_dir =      r"f:/val"
# Colab: no dedicated validation path is configured
# Kaggle
kaggle_val_path = "/kaggle/input/pterygium/val_img/"

# ================== Predicted-mask output paths ================
output_maskfiles = r"f:/mask"
# Colab
output_maskfiles_colab = "/content/mask"
# Kaggle
output_maskfiles_kaggle = "/kaggle/working/mask"

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA 可用: {torch.cuda.is_available()}")
print(f"使用的设备: {device}")

# 读取和准备数据
我们需要读取原始图像和对应的分割标签（mask）。标签中像素值为128的区域表示翼状胬肉，像素值为0的区域表示背景（训练代码按标签图像R通道的取值是否为128进行判断）。

In [None]:
# In a cloud environment (Colab or Kaggle) switch to that environment's data locations.
if 'google.colab' in sys.modules or os.path.exists("/kaggle/working"):

    if 'google.colab' in sys.modules:
        print('在 Google Colab 环境中运行')
        image_dir = os.path.join(colab_extract_path,"train")
        label_file = os.path.join(image_dir,"train_classification_label.xlsx")
        zip_path = colab_zip_path
        extract_path = colab_extract_path

        output_dir_to_zip = output_maskfiles_colab
        zip_file_path = f"{output_dir_to_zip}.zip"
        print(f"Colab 环境：验证结果将验证压缩 {output_dir_to_zip} 到 {zip_file_path}")

        # Mount Google Drive (the training archive lives there)
        from google.colab import drive
        drive.mount('/content/drive')

        # Extract the training archive on first use only. Only the Colab branch defines an
        # archive path, so extraction must live here rather than after both branches.
        if not os.path.exists(label_file):
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
        # NOTE: no Colab validation path is configured; val_image_dir keeps its default value.
    else:
        print('在 Kaggle 环境中运行')
        image_dir = os.path.join(kaggle_extract_path,"train")
        label_file = os.path.join(image_dir,"train_classification_label.xlsx")
        val_image_dir = os.path.join(kaggle_val_path,"val_img")

        output_dir_to_zip = output_maskfiles_kaggle
        zip_file_path = f"{output_dir_to_zip}.zip"
        print(f"Kaggle 环境：验证结果将压缩 {output_dir_to_zip} 到 {zip_file_path}")

        # Kaggle data is used in place; there is no archive to unpack, so only report a
        # missing label file instead of failing on undefined archive paths.
        if not os.path.exists(label_file):
            print(f"警告: 未找到标签文件 {label_file}，请检查 Kaggle 数据集路径。")
else:
    print(f'不在云端环境中运行,使用本地数据路径{image_dir}')
label_file = os.path.join(image_dir,"train_classification_label.xlsx")

# Custom dataset returning (image, mask) pairs with independent image / mask transforms
class PterygiumSegmentationDataset(Dataset):
    """
    Segmentation dataset for pterygium samples (rows whose 'Pterygium' label is > 0).

    Image and mask are transformed independently; use PterygiumSegDataset when the same
    random geometric augmentation must be applied to both.
    """

    def __init__(self, label_file, image_dir, transform=None, mask_transform=None):
        """
        Initialise the dataset.

        :param label_file: Excel file with at least the 'Image' and 'Pterygium' columns
        :param image_dir: root folder containing one sub-folder per image id (zero-padded to 4 digits)
        :param transform: optional transform applied to the RGB PIL image
        :param mask_transform: optional transform applied to the binary (0/255, mode 'L') PIL mask
        """
        self.labels_df = pd.read_excel(label_file)
        # Keep only pterygium samples (labels 1 and 2)
        self.labels_df = self.labels_df[self.labels_df['Pterygium'] > 0].reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        """
        Load the image and segmentation mask at *idx*.

        :param idx: index into the filtered label table
        :return: (image, mask). Without ``mask_transform`` the mask is a (1, H, W) float tensor of
                 0/1; with it, a tensor result is re-binarized at 0.5.
        """
        row = self.labels_df.iloc[idx]
        image_folder = f"{int(row['Image']):04d}"
        sample_dir = os.path.join(self.image_dir, image_folder)

        # Load the image; the context manager closes the underlying file handle
        with Image.open(os.path.join(sample_dir, f"{image_folder}.png")) as img:
            image = img.convert("RGB")

        # Load the label and binarize it with the same rule as PterygiumSegDataset:
        # pixels whose R channel equals 128 are pterygium. (A grayscale conversion would map a
        # coloured label, e.g. pure red, to a different value than 128.)
        with Image.open(os.path.join(sample_dir, f"{image_folder}_label.png")) as lbl:
            label_rgb = np.array(lbl.convert("RGB"))
        mask_binary = (label_rgb[:, :, 0] == 128).astype(np.uint8)

        # Apply the image transform
        if self.transform:
            image = self.transform(image)

        # Apply the mask transform
        if self.mask_transform:
            mask = self.mask_transform(Image.fromarray(mask_binary * 255))
            if isinstance(mask, torch.Tensor):
                # Interpolating transforms may introduce intermediate values; force strict 0/1
                mask = (mask > 0.5).float()
        else:
            # Add a channel dimension so the mask matches the model's (N, 1, H, W) logits
            mask = torch.from_numpy(mask_binary).float().unsqueeze(0)

        return image, mask

# 数据 Resize
将图像和掩码调整为统一的大小，以适应模型输入要求。我们使用256x256的分辨率以提高分割精度。

In [None]:
from torchvision.transforms.functional import to_pil_image
target_size = (256, 256)  # target size passed to transforms.Resize
output_format = "PNG"  # file format (and lower-cased extension) for the saved resized files

# --- Transformation Definition ---
# Image transform: antialiased bilinear resize, tensor conversion, ImageNet mean/std normalization
image_transform = transforms.Compose([
    transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Mask transform - no normalization, only resize (followed by tensor conversion)
mask_transform = transforms.Compose([
    transforms.Resize(target_size, interpolation=transforms.InterpolationMode.NEAREST),  # nearest-neighbour so no new label values are introduced
    transforms.ToTensor()
])

# --- Helpers that resize an image/mask pair and save it ---
def _apply_resize_steps(transform, pil_image):
    """
    Apply only the ``transforms.Resize`` steps of *transform* to a PIL image.

    Tensor conversion / normalization steps are skipped on purpose. The result is saved as an
    image file, so it must stay a PIL image with its original pixel values. ``None`` (or a
    transform without a Resize step) returns the image unchanged.
    """
    if transform is None:
        return pil_image
    steps = transform.transforms if isinstance(transform, transforms.Compose) else [transform]
    for step in steps:
        if isinstance(step, transforms.Resize):
            pil_image = step(pil_image)
    return pil_image


def resize_and_save_image_mask_pair(img_info, base_input_dir, base_output_dir, image_transform, mask_transform, device):
    """
    Read one image/mask pair, resize both and save them under *base_output_dir*.

    Resizing is done directly on PIL images using the Resize steps of the given transforms.
    This avoids two problems with the tensor route: ToTensor cannot be applied to a tensor,
    and a float->uint8 round trip truncates label value 128 to 127.

    :param img_info: row with an 'Image' field (numeric id; the folder name is the id zero-padded to 4 digits)
    :param base_input_dir: root folder of the original data
    :param base_output_dir: root folder for the resized copies (same folder layout)
    :param image_transform: transform whose Resize step(s) are applied to the RGB image
    :param mask_transform: transform whose Resize step(s) are applied to the mask (should be NEAREST)
    :param device: unused; kept so existing callers keep working
    :return: True on success, False on failure (the error is printed)
    """
    try:
        image_folder = f"{int(img_info['Image']):04d}"

        # Input paths
        image_path = os.path.join(base_input_dir, image_folder, f"{image_folder}.png")
        mask_path = os.path.join(base_input_dir, image_folder, f"{image_folder}_label.png")

        # Output paths
        output_folder_path = os.path.join(base_output_dir, image_folder)
        output_image_path = os.path.join(output_folder_path, f"{image_folder}.{output_format.lower()}")
        output_mask_path = os.path.join(output_folder_path, f"{image_folder}_label.{output_format.lower()}")

        # Load both files first so a missing input does not leave an empty output folder behind
        with Image.open(image_path) as img:
            img_pil = img.convert("RGB")
        # Keep the mask in RGB: the training dataset reads the label from the R channel
        with Image.open(mask_path) as msk:
            mask_pil = msk.convert("RGB")

        resized_img_pil = _apply_resize_steps(image_transform, img_pil)
        resized_mask_pil = _apply_resize_steps(mask_transform, mask_pil)

        os.makedirs(output_folder_path, exist_ok=True)
        resized_img_pil.save(output_image_path, format=output_format)
        resized_mask_pil.save(output_mask_path, format=output_format)

        return True  # success

    except FileNotFoundError as e:
        print(f"错误: 文件未找到 {str(e)}")
        return False
    except Exception as e:
        print(f"错误处理图像和掩码: {str(e)}")
        return False

# Pre-resize the training images/masks into a cache directory (Colab / Kaggle only).
# On Windows, or in an unrecognised (e.g. local) environment, the original images under
# image_dir are used directly; the dataset transforms resize them on the fly.
output_dir = None
if platform.system() != "Windows":
    if 'google.colab' in sys.modules:
        original_image_dir = os.path.join(colab_extract_path,"train")
        output_dir = os.path.join(colab_extract_path,"train_seg_resized")
    elif os.path.exists("/kaggle/working"):
        original_image_dir = os.path.join(kaggle_extract_path,"train")
        output_dir = os.path.join(kaggle_temp_path,"train_seg_resized")
    else:
        # Unrecognised environment: keep the original data instead of terminating the run
        print(f"警告: 无法识别的环境，跳过resize步骤，直接使用原始数据 {image_dir}")

if output_dir is not None:
    print(f"输入目录: {original_image_dir}")
    print(f"输出目录: {output_dir}")
    print(f"目标尺寸: {target_size}")

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Read the label file to know which images to process
    try:
        labels_df = pd.read_excel(label_file)
    except FileNotFoundError:
        print(f"错误: 标签文件未找到 {label_file}")
        sys.exit(1)
    # Only pterygium samples (labels 1 and 2) are used for segmentation training
    pterygium_df = labels_df[labels_df['Pterygium'] > 0].reset_index(drop=True)

    # Process only pairs whose resized image or mask is missing. This way an interrupted or
    # partially failed earlier run is completed rather than skipped because the directory is non-empty.
    ext = output_format.lower()
    pending_rows = []
    for _, row in pterygium_df.iterrows():
        folder = f"{int(row['Image']):04d}"
        expected_files = (
            os.path.join(output_dir, folder, f"{folder}.{ext}"),
            os.path.join(output_dir, folder, f"{folder}_label.{ext}"),
        )
        if not all(os.path.exists(path) for path in expected_files):
            pending_rows.append(row)

    error_count = 0
    if not pending_rows:
        print("检测到已存在的resize数据，跳过resize步骤")
    else:
        success_count = 0
        for row in tqdm(pending_rows, desc="Resizing Images and Masks"):
            if resize_and_save_image_mask_pair(row, original_image_dir, output_dir, image_transform, mask_transform, device):
                success_count += 1
            else:
                error_count += 1

        print(f"\n处理完成!")
        print(f"成功处理图像和掩码对的数量: {success_count}")
        print(f"处理失败的图像和掩码对的数量: {error_count}")
        print(f"处理后的图像和掩码保存在: {output_dir}")

    # Train from the cache only when it is complete; otherwise fall back to the original data
    if error_count == 0:
        image_dir = output_dir
    else:
        image_dir = original_image_dir
        print(f"警告: 有 {error_count} 对图像和掩码处理失败，将直接使用原始数据 {original_image_dir}")

# 创建数据加载器
设置训练和验证数据加载器，包括数据增强策略。

In [None]:
# 数据增强变换
from torchvision.transforms import functional as F
import random

class SegmentationTransform:
    """
    Paired training augmentation: the image and its mask receive identical random geometry.

    Pipeline: resize to ``base_size`` x ``base_size``, random ``crop_size`` crop, optional
    horizontal flip (probability ``flip_prob``), random rotation in
    [-rotation_degrees, rotation_degrees]. Images use bilinear and masks nearest-neighbour
    interpolation. Returns a normalized image tensor and a (1, H, W) float mask of 0/1.
    """

    def __init__(self, base_size=256, crop_size=224, flip_prob=0.5, rotation_degrees=15):
        self.base_size = base_size
        self.crop_size = crop_size
        self.flip_prob = flip_prob
        self.rotation_degrees = rotation_degrees

    def __call__(self, image, mask):
        bilinear = transforms.InterpolationMode.BILINEAR
        nearest = transforms.InterpolationMode.NEAREST
        base = (self.base_size, self.base_size)

        # Resize
        image = F.resize(image, base, interpolation=bilinear)
        mask = F.resize(mask, base, interpolation=nearest)

        # Random crop: one set of coordinates shared by image and mask
        top, left, height, width = transforms.RandomCrop.get_params(
            image, output_size=(self.crop_size, self.crop_size))
        image = F.crop(image, top, left, height, width)
        mask = F.crop(mask, top, left, height, width)

        # Random horizontal flip
        if random.random() < self.flip_prob:
            image, mask = F.hflip(image), F.hflip(mask)

        # Random rotation by the same angle
        angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
        image = F.rotate(image, angle, interpolation=bilinear)
        mask = F.rotate(mask, angle, interpolation=nearest)

        # Image -> normalized tensor
        image = F.normalize(F.to_tensor(image), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Mask -> binary float tensor with a channel dimension (no normalization)
        scaled = torch.from_numpy(np.array(mask)).float() / 255.0
        mask = (scaled > 0.2).float().unsqueeze(0)

        return image, mask

# Validation image transform - resize only (default interpolation), no random augmentation,
# then tensor conversion and ImageNet mean/std normalization
val_transform = transforms.Compose([
    transforms.Resize((256, 256)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation mask transform - nearest-neighbour resize (keeps label values), then tensor conversion
val_mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor()
])

# Dataset variant whose transform receives (image, mask) together, enabling paired augmentation
class PterygiumSegDataset(Dataset):
    """
    Pterygium segmentation dataset (rows with 'Pterygium' > 0) with an optional paired transform.

    The mask is built from the label image's R channel (value 128 -> foreground) and handed to
    ``transform(image, mask)`` as a 0/255 grayscale PIL image. Without a transform, the image is
    normalized with ImageNet statistics and the mask becomes a (1, H, W) tensor of 0/1.
    """

    def __init__(self, label_file, image_dir, transform=None):
        table = pd.read_excel(label_file)
        # Only pterygium samples (labels 1 and 2)
        self.labels_df = table[table['Pterygium'] > 0].reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        """
        Return the (image, mask) pair for *idx*.

        :param idx: index into the filtered label table
        :return: image tensor and matching mask tensor (or whatever ``transform`` returns)
        """
        folder = f"{int(self.labels_df.iloc[idx]['Image']):04d}"
        sample_dir = os.path.join(self.image_dir, folder)

        image = Image.open(os.path.join(sample_dir, f"{folder}.png")).convert("RGB")

        mask_path = os.path.join(sample_dir, f"{folder}_label.png")
        if not os.path.exists(mask_path):
            # Missing label: fall back to an all-background mask of the image's size
            print(f"掩码文件未找到 {mask_path}, 将使用全零掩码。")
            mask = Image.new('L', image.size, 0)
        else:
            red = np.array(Image.open(mask_path).convert("RGB"))[:, :, 0]
            # Foreground where R == 128; stored as 0/255 so PIL-based transforms can consume it
            mask = Image.fromarray((red == 128).astype(np.uint8) * 255)

        if not self.transform:
            # Default: plain tensor conversion + ImageNet normalization, strict 0/1 mask
            image = F.normalize(F.to_tensor(image), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            mask = (transforms.ToTensor()(mask) > 0.5).float()
            return image, mask

        image, mask = self.transform(image, mask)
        return image, mask

# Split into training and validation sets
from sklearn.model_selection import train_test_split

# Read the label file and keep only pterygium samples
labels_df = pd.read_excel(label_file)
pterygium_df = labels_df[labels_df['Pterygium'] > 0].reset_index(drop=True)

# 80/20 split, stratified on the 'Pterygium' label so both subsets keep the same label mix
train_df, val_df = train_test_split(pterygium_df, test_size=0.2, random_state=42, stratify=pterygium_df['Pterygium'])

# Persist the split: on Kaggle into the working directory, elsewhere next to the images
split_dir = kaggle_temp_path if os.path.exists("/kaggle/working") else image_dir
train_label_file = os.path.join(split_dir, "train_segmentation_label_train.xlsx")
val_label_file = os.path.join(split_dir, "train_segmentation_label_val.xlsx")
train_df.to_excel(train_label_file, index=False)
val_df.to_excel(val_label_file, index=False)


def val_pair_transform(image, mask):
    """
    Deterministic paired transform for validation (no augmentation).

    Resizes the image with ``val_transform`` and the mask with ``val_mask_transform``
    (nearest-neighbour). Every validation sample then has the same size and can be batched,
    which is not the case when the original-resolution images are used as-is.
    Returns the normalized image tensor and a (1, H, W) float mask of 0/1.
    """
    image = val_transform(image)
    mask = (val_mask_transform(mask) > 0.5).float()
    return image, mask


# Datasets: random paired augmentation for training, deterministic resize for validation
seg_transform = SegmentationTransform(base_size=256, crop_size=224, flip_prob=0.5, rotation_degrees=15)
train_dataset = PterygiumSegDataset(train_label_file, image_dir, transform=seg_transform)
val_dataset = PterygiumSegDataset(val_label_file, image_dir, transform=val_pair_transform)

# Pinned host memory only helps host-to-GPU copies; it is kept off on Windows and without CUDA
pin_memory = platform.system() != "Windows" and torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

# 构建U-Net分割模型
U-Net是一种经典的图像分割模型，其结构包括下采样路径（编码器）、上采样路径（解码器），以及连接二者对应层级的跳跃连接。本实现的编码器复用ImageNet预训练的ResNet-18各层，解码器通过双线性上采样和双卷积块逐级恢复分辨率。

In [None]:
class DoubleConv(nn.Module):
    """
    Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    :param in_channels: input channels
    :param out_channels: output channels
    :param mid_channels: channels between the two convolutions (defaults to ``out_channels``)
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        layers = []
        # Build conv -> BN -> ReLU twice; the Sequential indices (0..5) define the state_dict keys
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: halve the resolution with 2x2 max-pooling, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name (and Sequential order) fixes the state_dict keys
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """
    Decoder step: upsample ``x1``, align it to the skip tensor ``x2``, concatenate, DoubleConv.

    :param in_channels: channels entering the DoubleConv, i.e. channels(x2) + channels(upsampled x1)
    :param out_channels: output channels of the DoubleConv
    :param bilinear: upsample by bilinear interpolation (True) or a transposed convolution (False)
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # NOTE(review): this transposed conv expects x1 to have `in_channels` channels.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Sizes need not be exact multiples of 2; pad (or crop, if negative) x1 to match x2
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        # Call torch.nn.functional.pad explicitly. In this notebook the name `F` is rebound to
        # torchvision.transforms.functional, whose `pad` reads a 4-element list as
        # (left, top, right, bottom), not torch's (left, right, top, bottom).
        x1 = nn.functional.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                                    diff_y // 2, diff_y - diff_y // 2])
        # Concatenate skip features and upsampled features along the channel axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping decoder features to per-pixel logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """
    U-Net style segmentation network with an ImageNet-pretrained ResNet-18 encoder.

    :param n_channels: stored for reference only; the encoder stem is ResNet-18's pretrained
        ``conv1``, which expects 3-channel input
    :param n_classes: number of output logit channels
    :param bilinear: decoder upsampling mode passed to ``Up``
        (NOTE(review): the channel arithmetic below only fits ``bilinear=True``)

    ``forward`` returns raw logits with the same spatial size as the input.
    """
    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Load the pretrained ResNet-18
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Encoder (ResNet-18 layers)
        self.inc = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu
        ) # output channels: 64, stride 2
        self.maxpool = resnet.maxpool
        self.down1 = resnet.layer1 # output channels: 64
        self.down2 = resnet.layer2 # output channels: 128
        self.down3 = resnet.layer3 # output channels: 256
        self.down4 = resnet.layer4 # output channels: 512

        # Decoder (input channel counts = upsampled features + skip features)
        factor = 2 if bilinear else 1
        self.up1 = Up(512 + 256, 512 // factor, bilinear) # down4(512) + down3(256) -> 256
        self.up2 = Up(256 + 128, 256 // factor, bilinear) # up1(256) + down2(128) -> 128
        self.up3 = Up(128 + 64, 128 // factor, bilinear)  # up2(128) + down1(64) -> 64
        self.up4 = Up(64 + 64, 64, bilinear)             # up3(64) + inc(64) -> 64
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        input_size = x.shape[-2:]

        # Encoder path (ResNet)
        x1 = self.inc(x)          # (N, 64, H/2, W/2)
        x_pool = self.maxpool(x1) # (N, 64, H/4, W/4)
        x2 = self.down1(x_pool)   # (N, 64, H/4, W/4)
        x3 = self.down2(x2)       # (N, 128, H/8, W/8)
        x4 = self.down3(x3)       # (N, 256, H/16, W/16)
        x5 = self.down4(x4)       # (N, 512, H/32, W/32)

        # Decoder path with skip connections
        x = self.up1(x5, x4) # -> (N, 256, H/16, W/16)
        x = self.up2(x, x3)  # -> (N, 128, H/8, W/8)
        x = self.up3(x, x2)  # -> (N, 64, H/4, W/4)
        x = self.up4(x, x1)  # -> (N, 64, H/2, W/2)
        logits = self.outc(x)

        # The stride-2 stem halves the resolution and no decoder stage undoes it, so the
        # logits are H/2 x W/2. Resize them to the input size so they align pixel-for-pixel
        # with the target masks. torch.nn.functional is used explicitly because `F` is bound
        # to torchvision.transforms.functional in this notebook.
        if logits.shape[-2:] != input_size:
            logits = nn.functional.interpolate(logits, size=input_size, mode='bilinear', align_corners=False)
        return logits

# Instantiate the ResNet-18-encoder U-Net with one output logit channel and move it to the selected device
model = UNet(n_classes=1, bilinear=True).to(device)

# 定义损失函数和评估指标
我们使用组合损失函数：二元交叉熵损失和Dice损失的组合，以更好地处理类别不平衡问题。

In [None]:
# Soft Dice loss
class DiceLoss(nn.Module):
    """
    Soft Dice loss on sigmoid probabilities.

    Dice is computed per sample over all remaining dimensions, averaged over the batch, and
    returned as ``1 - mean_dice``. ``smooth`` is added to numerator and denominator.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        n = targets.size(0)
        # Flatten everything except the batch dimension
        prob_flat = torch.sigmoid(logits).view(n, -1)
        target_flat = targets.view(n, -1)

        numerator = 2. * (prob_flat * target_flat).sum(dim=1) + self.smooth
        denominator = prob_flat.sum(dim=1) + target_flat.sum(dim=1) + self.smooth
        return 1 - (numerator / denominator).mean()

# Combined loss: weighted binary cross-entropy + Dice loss
class CombinedLoss(nn.Module):
    """
    Weighted sum ``bce_weight * BCEWithLogits + dice_weight * DiceLoss``.

    Both terms take raw logits and targets of the same shape.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self, logits, targets):
        weighted_bce = self.bce_weight * self.bce_loss(logits, targets)
        weighted_dice = self.dice_weight * self.dice_loss(logits, targets)
        return weighted_bce + weighted_dice

# Evaluation metric: Dice coefficient
def dice_coefficient(y_pred, y_true, threshold=0.5, smooth=1e-6):
    """
    Hard Dice coefficient between thresholded predictions and ground truth.

    :param y_pred: raw logits; sigmoid(y_pred) > threshold is counted as foreground
    :param y_true: binary ground truth with the same number of elements
    :param threshold: probability threshold for foreground
    :param smooth: small constant avoiding 0/0 when both masks are empty
    :return: Dice as a Python float, pooled over all elements (including the batch)
    """
    pred_flat = (torch.sigmoid(y_pred) > threshold).float().reshape(-1)
    true_flat = y_true.reshape(-1)

    overlap = torch.sum(pred_flat * true_flat)
    total = torch.sum(pred_flat) + torch.sum(true_flat)
    return ((2. * overlap + smooth) / (total + smooth)).item()

# Training loss: 0.6 * BCE-with-logits + 0.4 * Dice loss
criterion = CombinedLoss(bce_weight=0.6, dice_weight=0.4)

# 配置优化器和训练参数
设置Adam优化器和余弦退火（CosineAnnealingLR）学习率调度器，为模型训练做准备。

In [None]:
# Training parameters
num_epochs = 30
log_interval = 5  # NOTE(review): not referenced by the training code shown here

# Optimizer: Adam with a small L2 weight decay
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Learning-rate scheduler.
# Alternative (commented out): ReduceLROnPlateau, which must be stepped with the validation metric.
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6) # cosine decay towards eta_min over T_max scheduler steps (baseline hyperparameters)

# 训练模型
实现训练循环，包括前向传播、损失计算、反向传播、参数更新，并记录训练过程中的指标。同时实现早停机制。

In [None]:
# 定义早停类
from copy import deepcopy

class EarlyStopping:
    """
    Signal that training should stop once the monitored score stops improving.

    A deep copy of the model's state_dict is kept from the best call so far.

    :param patience: consecutive non-improving calls before ``early_stop`` becomes True
    :param min_delta: margin the score must beat the best score by to count as an improvement
    :param mode: 'max' if larger scores are better, 'min' if smaller scores are better
    """

    def __init__(self, patience=7, min_delta=0.0, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_model_weights = None
        # Flipping the sign in 'min' mode lets a single "greater than" test serve both modes
        self.delta_sign = -1 if mode == 'min' else 1

    def __call__(self, val_score, model):
        """Record *val_score* for one epoch and snapshot *model* when it improves."""
        if self.best_score is None:
            self._save_best(val_score, model)
            tqdm.write(f"EarlyStopping: 初始化最佳分数为 {self.best_score:.4f}")
            return

        signed_score = val_score * self.delta_sign
        signed_best = self.best_score * self.delta_sign
        if signed_score > signed_best + self.min_delta:
            self._save_best(val_score, model)
            self.counter = 0
            tqdm.write(f"EarlyStopping: 发现改进。最佳分数更新为 {self.best_score:.4f}。计数器重置。")
            return

        self.counter += 1
        tqdm.write(f'EarlyStopping计数器: {self.counter} (共 {self.patience})。最佳分数仍为 {self.best_score:.4f}。')
        if self.counter >= self.patience:
            self.early_stop = True
            tqdm.write("EarlyStopping: 已达到耐心值。")

    def _save_best(self, score, model):
        """Store *score* as the best so far together with a deep copy of the model weights."""
        self.best_score = score
        self.best_model_weights = deepcopy(model.state_dict())

In [None]:
import time

def _evaluate_on_loader(model, data_loader, device, criterion=None):
    """
    Run *model* over *data_loader* without gradients.

    Returns (mean_loss, mean_dice), both weighted by batch size. mean_loss is 0.0 when
    *criterion* is None; both are 0.0 for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    total_dice = 0.0
    seen = 0
    with torch.no_grad():
        for images, masks in data_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            batch_size = images.size(0)
            if criterion is not None:
                total_loss += criterion(outputs, masks).item() * batch_size
            total_dice += dice_coefficient(outputs, masks) * batch_size
            seen += batch_size
    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, total_dice / seen


def _step_scheduler(scheduler, metric):
    """Advance *scheduler* once per epoch; only ReduceLROnPlateau consumes the metric."""
    if scheduler is None:
        return
    if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(metric)
    else:
        # Epoch-based schedulers (e.g. CosineAnnealingLR) take no metric; a positional
        # argument would be interpreted as the epoch index.
        scheduler.step()


def train_validate_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """
    Train *model* for up to *num_epochs* epochs, validating after each one, with early stopping.

    Mixed precision (autocast + GradScaler) is enabled only when *device* is a CUDA device.
    The scheduler is stepped once per epoch. After training, the best weights (by validation
    Dice) are loaded back into *model* and evaluated once more on *val_loader*.

    Returns:
        - float: validation Dice of the best weights (0.0 if none were recorded)
        - list: per-epoch training Dice
        - list: per-epoch validation Dice
        - dict: state_dict of the best model (None if none were recorded)
    """
    start_time = time.time()
    print("\n--- 开始训练 ---")

    # torch.amp.autocast requires an explicit device_type
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    early_stopping = EarlyStopping(patience=7, mode='max')
    scaler = torch.amp.GradScaler(enabled=use_amp)
    train_dice_history = []
    val_dice_history = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_dice = 0.0
        train_samples = 0
        train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)

        for images, masks in train_loader_tqdm:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()

            # Mixed-precision forward pass
            with torch.amp.autocast(device_type=device_type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, masks)

            # Backward pass and optimizer step (GradScaler is a pass-through when disabled)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_size = images.size(0)
            batch_loss = loss.item()
            # Computed once per batch and reused for both the epoch total and the progress bar
            batch_dice = dice_coefficient(outputs.detach(), masks)
            train_loss += batch_loss * batch_size
            train_dice += batch_dice * batch_size
            train_samples += batch_size

            current_lr = optimizer.param_groups[0]['lr']
            train_loader_tqdm.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'dice': f'{batch_dice:.4f}',
                'lr': f'{current_lr:.1e}'
            })

        train_loss /= max(train_samples, 1)
        train_dice /= max(train_samples, 1)
        train_dice_history.append(train_dice)

        # Validation phase
        val_loss, val_dice = _evaluate_on_loader(model, val_loader, device, criterion)
        val_dice_history.append(val_dice)

        # Update the learning rate
        _step_scheduler(scheduler, val_dice)

        tqdm.write(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, "
                  f"Train Dice: {train_dice:.4f}, Val Loss: {val_loss:.4f}, "
                  f"Val Dice: {val_dice:.4f}")

        # Early-stopping check
        early_stopping(val_dice, model)
        if early_stopping.early_stop:
            tqdm.write(f"早停触发于第 {epoch + 1} 轮。")
            break

    if not early_stopping.early_stop:
        tqdm.write(f"训练在 {num_epochs} 轮后完成。")
    best_model_state_dict = early_stopping.best_model_weights

    # Final evaluation with the best weights
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)
        _, final_val_dice = _evaluate_on_loader(model, val_loader, device)
    else:
        final_val_dice = 0.0
        tqdm.write("警告：无法获取最佳模型权重。")

    end_time = time.time()
    print("--- 训练完成 ---")
    print(f"最终验证Dice系数: {final_val_dice:.4f}")
    print(f"训练耗时: {end_time - start_time:.2f} 秒")

    return final_val_dice, train_dice_history, val_dice_history, best_model_state_dict

# Start training. Returns (final validation Dice of the best weights, per-epoch train Dice,
# per-epoch validation Dice, best state_dict).
best_dice, train_dice_history, val_dice_history, best_model_weights = train_validate_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    device=device
)

# 评估模型性能
可视化学习曲线和分割结果，并计算验证样本的Dice系数。（95% Hausdorff距离等其他评估指标尚未在本笔记本中实现。）

In [None]:
# Learning curves: per-epoch training vs. validation Dice
plt.figure(figsize=(12, 6))
for _history, _label in ((train_dice_history, '训练Dice系数'), (val_dice_history, '验证Dice系数')):
    plt.plot(range(1, len(_history) + 1), _history, label=_label)
plt.title('训练和验证Dice系数')
plt.xlabel('轮次')
plt.ylabel('Dice系数')
plt.legend()
plt.grid(True)
plt.show()

# Visualize segmentation results
def visualize_segmentation(model, dataloader, num_samples=5, device=None):
    """
    Plot input image, ground-truth mask and predicted mask for samples of the first batch.

    :param model: segmentation model returning per-pixel logits
    :param dataloader: yields (ImageNet-normalized image batch, mask batch of shape (N, 1, H, W))
    :param num_samples: maximum number of rows to plot (capped at the batch size)
    :param device: device to run on; defaults to the device of the model's parameters
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    # Take one batch
    try:
        images, masks = next(iter(dataloader))
    except StopIteration:
        print("数据集太小，无法获取足够的样本。")
        return

    num_samples = min(num_samples, images.size(0))
    if num_samples < 1:
        return

    # Predict
    with torch.no_grad():
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        pred_masks = (torch.sigmoid(outputs) > 0.5).float()

    # Undo ImageNet normalization for display (CHW -> HWC, clipped to [0, 1])
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    images_np = [
        np.clip(std * img.cpu().numpy().transpose(1, 2, 0) + mean, 0, 1)
        for img in images[:num_samples]
    ]

    # Masks and predictions as (N, H, W)
    masks_np = masks[:num_samples].cpu().numpy().squeeze(1)
    pred_masks_np = pred_masks[:num_samples].cpu().numpy().squeeze(1)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1, so axes[i, j] always works
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 4 * num_samples), squeeze=False)

    for i in range(num_samples):
        # Input image
        axes[i, 0].imshow(images_np[i])
        axes[i, 0].set_title('原始图像')
        axes[i, 0].axis('off')

        # Ground-truth mask
        axes[i, 1].imshow(masks_np[i], cmap='gray')
        axes[i, 1].set_title('真实掩码')
        axes[i, 1].axis('off')

        # Predicted mask with its per-sample Dice
        axes[i, 2].imshow(pred_masks_np[i], cmap='gray')
        dice = dice_coefficient(outputs[i:i+1], masks[i:i+1])
        axes[i, 2].set_title(f'预测掩码 (Dice: {dice:.4f})')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

# Plot a few validation samples with their predicted masks
visualize_segmentation(model, val_loader, num_samples=5)

# 模型保存和加载
定义保存模型参数（state_dict）的函数，以便将来加载训练好的模型并用于预测。注意：本单元只定义了 `save_model`，需要显式调用它才会写出模型文件。

In [None]:
# Save model parameters
def save_model(model, path):
    """Write ``model.state_dict()`` to *path* with ``torch.save`` and report the location."""
    state = model.state_dict()
    torch.save(state, path)
    print(f"模型参数已保存到 {path}")


# 模型预测与应用

现在我们已经训练并保存了最佳模型，可以在新的图像上使用它来进行翼状胬肉区域的分割预测。

测试数据的组织方式应与 `work1` 类似，通常包含一个图像文件夹。我们需要遍历测试图像，加载它们，进行预处理，然后使用加载的模型进行预测，最后将预测的掩码保存下来。

In [None]:
# --- 1. Load the trained model ---

# Requires the model classes (UNet, DoubleConv, Down, Up, OutConv) to be defined above.
# Build a fresh model with the same architecture that was trained.


loaded_model = UNet(n_channels=3, n_classes=1, bilinear=True).to(device)

# Load the best weights (assumes best_model_weights holds the state_dict returned by training)
if 'best_model_weights' in locals() and best_model_weights is not None:
    loaded_model.load_state_dict(best_model_weights)
    print("成功加载训练好的模型权重。")
else:
    # No in-memory weights: try a checkpoint file instead (it must have been saved beforehand)
    model_save_path = "pterygium_unet_model_best.pth" # assumed checkpoint file name
    if os.path.exists(model_save_path):
        loaded_model.load_state_dict(torch.load(model_save_path, map_location=device))
        print(f"从 {model_save_path} 加载模型权重。")
    else:
        print("警告: 未找到训练好的模型权重 (best_model_weights 或文件)。模型将使用随机初始化的权重。")

loaded_model.eval() # evaluation mode (BatchNorm uses running statistics)

# --- 2. Helper functions for prediction ---

# Preprocess inputs like the validation set: 256x256 antialiased bilinear resize, tensor conversion, ImageNet normalization
predict_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), # match the training size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def preprocess_single_image(image_path):
    """
    Load one image as RGB and apply ``predict_transform``.

    :return: a (1, C, H, W) tensor, or None if the file is missing or processing fails
             (the error is printed)
    """
    try:
        rgb = Image.open(image_path).convert("RGB")
        batch = predict_transform(rgb).unsqueeze(0)  # add the batch dimension
    except FileNotFoundError:
        print(f"错误: 图像文件未找到 {image_path}")
        return None
    except Exception as e:
        print(f"错误: 处理图像时出错 {image_path}: {e}")
        return None
    return batch

def predict_single_mask(model, image_tensor, device, threshold=0.5):
    """
    Predict a binary mask for one preprocessed image batch.

    :param model: callable returning per-pixel logits
    :param image_tensor: (1, C, H, W) tensor, or None (propagated as None)
    :param device: device to run the model on
    :param threshold: probability threshold for foreground
    :return: numpy float array of 0/1 with the batch dimension removed, or None
    """
    if image_tensor is None:
        return None
    with torch.no_grad():
        probs = torch.sigmoid(model(image_tensor.to(device)))
        binary = (probs > threshold).float()
    return binary.squeeze(0).cpu().numpy()

# --- 3. Predict masks for the test images ---

# Where the masks go: on Colab/Kaggle the directory archived by the zip step below
# (output_dir_to_zip, defined only in those environments), locally output_maskfiles.
mask_output_dir = globals().get("output_dir_to_zip") or output_maskfiles
os.makedirs(mask_output_dir, exist_ok=True)

# Test images are read from the validation image directory, never from the output directory
# (which would overwrite the inputs with masks). The recursive glob accepts a flat folder of
# PNGs as well as one sub-folder per image. "*_label.png" files are skipped in case
# ground-truth masks sit next to the images.
test_image_paths = sorted(
    path
    for path in glob.glob(os.path.join(val_image_dir, "**", "*.png"), recursive=True)
    if not os.path.basename(path).endswith("_label.png")
)

print(f"找到 {len(test_image_paths)} 张测试图像。")

# Only the first few predictions are plotted; one figure per image does not scale to a full test set
num_preview = 5

for index, img_path in enumerate(tqdm(test_image_paths, desc="处理测试图像")):
    tqdm.write(f"处理图像: {img_path}")
    input_tensor = preprocess_single_image(img_path)
    if input_tensor is None:
        continue

    predicted_mask_np = predict_single_mask(loaded_model, input_tensor, device)
    if predicted_mask_np is None:
        continue

    with Image.open(img_path) as src:
        original_image = src.convert("RGB")

    # predicted_mask_np is a (1, H, W) array of 0/1 at the model input size (256x256). Scale it
    # to 0/255 and resize back to the source resolution (nearest keeps it binary).
    mask_image = Image.fromarray((predicted_mask_np.squeeze() * 255).astype(np.uint8))
    if mask_image.size != original_image.size:
        mask_image = mask_image.resize(original_image.size, Image.NEAREST)

    base_name = os.path.basename(img_path)
    save_path = os.path.join(mask_output_dir, f"{os.path.splitext(base_name)[0]}.png")
    mask_image.save(save_path)
    tqdm.write(f"预测掩码已保存到: {save_path}")

    # Optional: show the original image next to its predicted mask
    if index < num_preview:
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        axes[0].imshow(original_image)
        axes[0].set_title("原始图像")
        axes[0].axis('off')
        axes[1].imshow(mask_image, cmap='gray')
        axes[1].set_title("预测掩码")
        axes[1].axis('off')
        plt.show()

# On Colab/Kaggle: archive the files in output_dir_to_zip into zip_file_path, then delete the originals.
if 'google.colab' in sys.modules or os.path.exists("/kaggle/working"):
    if output_dir_to_zip and os.path.exists(output_dir_to_zip) and os.listdir(output_dir_to_zip):
        print(f"开始压缩目录: {output_dir_to_zip}")
        try:
            with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                files_to_zip = glob.glob(os.path.join(output_dir_to_zip, '*.*')) # top-level entries whose name contains a dot (no recursion)
                if not files_to_zip:
                    print(f"警告: 目录 {output_dir_to_zip} 为空，无需压缩。")
                else:
                    for file in tqdm(files_to_zip, desc="压缩文件"):
                        # arcname stores only the file name (no directory path) inside the archive
                        zipf.write(file, arcname=os.path.basename(file))
                    print(f"预测结果已成功压缩到: {zip_file_path}")

                    # Archive written: delete the original files (per-file errors are reported, not raised)
                    print(f"开始删除原始掩码文件于: {output_dir_to_zip}")
                    delete_count = 0
                    for file in tqdm(files_to_zip, desc="删除原始文件"):
                        try:
                            os.remove(file)
                            delete_count += 1
                        except OSError as e:
                            print(f"删除文件 {file} 时出错: {e}")
                    print(f"已成功删除 {delete_count} 个原始掩码文件。")
                    # Remove the directory; os.rmdir fails (and is reported) if anything is left in it
                    try:
                        os.rmdir(output_dir_to_zip)
                        print(f"已删除空目录: {output_dir_to_zip}")
                    except OSError as e:
                        print(f"删除目录 {output_dir_to_zip} 时出错: {e}")

        except Exception as e:
            print(f"压缩或删除文件时发生错误: {e}")
    elif output_dir_to_zip:
        print(f"目录 {output_dir_to_zip} 不存在或为空，跳过压缩和删除步骤。")


print("\n预测处理完成。")