<a href="https://colab.research.google.com/github/Tokisaki-Galaxy/PterygiumSeg/blob/master/work2_basemode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 翼状胬肉区域分割模型

这是项目的第二个任务：实现对眼部裂隙灯检查图片中翼状胬肉区域的精准分割。我们将使用U-Net模型解决这一问题。

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import collections
import torch.backends.cudnn as cudnn
from torchvision import transforms, models
from torchvision.transforms import functional as F
from skimage.morphology import binary_opening, binary_closing, disk, square
import albumentations as A # type: ignore
from albumentations.pytorch import ToTensorV2 # type: ignore
import torch.nn.functional
from copy import deepcopy
import pandas as pd
import os
import shutil
import numpy as np
from PIL import Image
import zipfile
import sys
import platform
import random
import time
import glob
from tqdm.autonotebook import tqdm # 好看！
%matplotlib inline
import matplotlib.pyplot as plt
from skimage.measure import label
import matplotlib.font_manager

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    _xla_available = True
    print("torch_xla 导入成功。")
except ImportError:
    _xla_available = False
    print("torch_xla 未安装或导入失败。将使用 CUDA 或 CPU。")

In [None]:
# Notebook-wide configuration: DataLoader worker count, matplotlib font setup,
# patch geometry, dataset/output paths, batch sizes and the compute device.
if platform.system() == "Windows":
    # DataLoader worker count used on Windows
    num_workers = 0
    print(f"检测到 Windows 系统，将 DataLoader 的 num_workers 设置为 {num_workers}。")
else:
    # Non-Windows systems (e.g. Linux/Colab)
    num_workers = 4
    print(f"检测到非 Windows 系统 ({platform.system()})，将 DataLoader 的 num_workers 设置为 {num_workers}。")
    # Set up a Chinese font: download SimHei once, then register it with matplotlib
    if not os.path.exists('simhei.ttf'):
        !wget -O simhei.ttf "https://cdn.jsdelivr.net/gh/Haixing-Hu/latex-chinese-fonts/chinese/%E9%BB%91%E4%BD%93/SimHei.ttf"
    matplotlib.font_manager.fontManager.addfont('simhei.ttf')
    matplotlib.rc('font', family='SimHei')
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# When torch_xla was imported successfully, the worker count is overridden to 0
if _xla_available:
    num_workers = 0

patch_size = 512  # patch edge length
patch_stride = patch_size // 2 # stride used when cutting training patches (50% overlap)
predict_stride = patch_size // 2 # stride used at prediction time (50% overlap)
target_input_size = (patch_size, patch_size) # U-Net input size equals the patch size

# ================== Dataset paths =================
# Local data path
image_dir =          r"f:/train"
# Colab paths
colab_zip_path = "/content/drive/My Drive/train.zip"
colab_extract_path = "/content/trains/"
# Kaggle paths
kaggle_extract_path = "/kaggle/input/pterygium/train/"
kaggle_temp_path = "/kaggle/working/"

# =================== Validation set paths =================
# Local validation image path
val_image_dir =      r"f:/val_img"
# Colab path (none defined)
# Kaggle path
kaggle_val_path = "/kaggle/input/pterygium/val_img/"

# ================== Mask output paths ================
output_mask_dir = r"f:/mask"
# Colab path
output_maskfiles_colab = "/content/mask"
# Kaggle path
output_maskfiles_kaggle = "/kaggle/working/mask"

# ================== Training parameters ==================
MIN_FOREGROUND_RATIO = 0.01 # offline patching: minimum fraction of foreground pixels (value 128) a mask patch needs to be kept; 0 disables filtering
tpu_batch_size = 32 # batch size for a single TPU core; adjustable
cuda_batch_size = 40 # batch size on CUDA
windows_batch_size = 6 # batch size on Windows

# Select the compute device: TPU (via torch_xla) if available, else CUDA, else CPU
if _xla_available:
    # Obtain the XLA device (TPU)
    # Per the original author's note, xm.xla_device() picks the TPU core available to the current process
    device = xm.xla_device()
    print(f"检测到 TPU，使用的设备: {device}")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        cudnn.benchmark = True
        print("cuDNN benchmark 模式已启用")
    print(f"CUDA 可用: {torch.cuda.is_available()}")
    print(f"使用的设备: {device}")

# 读取和准备数据
我们需要读取原始图像和对应的分割标签（mask）。标签中像素值为128的区域表示翼状胬肉，像素值为0的区域表示背景。

In [None]:
# ========== Per-environment settings ==========
# R2 credentials default to environment variables; the Colab/Kaggle branches
# below overwrite them with values from the platform's secret store.
R2_ACCOUNT_ID = os.environ.get('R2_ACCOUNT_ID', 'YOUR_R2_ACCOUNT_ID')
R2_ACCESS_KEY_ID = os.environ.get('R2_ACCESS_KEY_ID', 'YOUR_R2_ACCESS_KEY_ID')
R2_SECRET_ACCESS_KEY = os.environ.get('R2_SECRET_ACCESS_KEY', 'YOUR_R2_SECRET_ACCESS_KEY')
R2_BUCKET_NAME = os.environ.get('R2_BUCKET_NAME', 'YOUR_R2_BUCKET_NAME')
R2_ENDPOINT_URL = f'https://{R2_ACCOUNT_ID}.r2.cloudflarestorage.com'

# Archive to unpack when the label file is missing. Only the Colab branch
# provides one; it stays None elsewhere so the extraction step is skipped
# instead of failing with a NameError (as happened on Kaggle).
zip_path = None
extract_path = None

if 'google.colab' in sys.modules or os.path.exists("/kaggle/working"):
    if 'google.colab' in sys.modules:
        print('在 Google Colab 环境中运行')
        image_dir = os.path.join(colab_extract_path,"train")
        label_file = os.path.join(image_dir,"train_classification_label.xlsx")
        zip_path = colab_zip_path
        extract_path = colab_extract_path
        BASE_PATCH_DIR = "/content/train_patches_gpu"

        output_mask_dir = output_maskfiles_colab
        print(f"Colab 环境：验证结果将验证压缩 {output_mask_dir} 到 {output_mask_dir}.zip")

        # Mount Google Drive and read the R2 secrets from Colab user data
        from google.colab import drive # type: ignore
        from google.colab import userdata # type: ignore
        drive.mount('/content/drive')
        R2_ACCESS_KEY_ID = userdata.get("R2_ACCESS_KEY_ID")
        R2_SECRET_ACCESS_KEY = userdata.get("R2_SECRET_ACCESS_KEY")
        R2_BUCKET_NAME = userdata.get("R2_BUCKET_NAME")
        R2_ENDPOINT_URL = userdata.get("R2_ENDPOINT_URL")
    else:
        print('在 Kaggle 环境中运行')
        image_dir = os.path.join(kaggle_extract_path,"train")
        label_file = os.path.join(image_dir,"train_classification_label.xlsx")
        val_image_dir = os.path.join(kaggle_val_path,"val_img")
        BASE_PATCH_DIR = "/kaggle/working/train_patches_gpu"

        output_mask_dir = output_maskfiles_kaggle
        print(f"Kaggle 环境：验证结果将压缩 {output_mask_dir} 到 {output_mask_dir}.zip")

        # Read the R2 secrets from Kaggle's secret store
        from kaggle_secrets import UserSecretsClient # type: ignore
        user_secrets = UserSecretsClient()
        R2_ACCESS_KEY_ID = user_secrets.get_secret("R2_ACCESS_KEY_ID")
        R2_SECRET_ACCESS_KEY = user_secrets.get_secret("R2_SECRET_ACCESS_KEY")
        R2_BUCKET_NAME = user_secrets.get_secret("R2_BUCKET_NAME")
        R2_ENDPOINT_URL = user_secrets.get_secret("R2_ENDPOINT_URL")

    if not os.path.exists(label_file):
        if zip_path is not None:
            # Unpack the dataset archive
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)
        else:
            print(f"警告: 标签文件 {label_file} 不存在，且当前环境没有可解压的数据压缩包。")
else:
    print(f'不在云端环境中运行,使用本地数据路径{image_dir}')
    BASE_PATCH_DIR = "data/train_patches_gpu"
# Recomputed unconditionally so the local (non-cloud) branch also gets a label_file
label_file = os.path.join(image_dir,"train_classification_label.xlsx")

# 离线Patching函数 (利用GPU加速)
用于将原始的大尺寸图像和对应的分割掩码离线切割成指定大小的Patches，并保存到磁盘。
该函数会尝试将大图加载到GPU进行裁剪和过滤，以加速处理过程（需要注意GPU显存）。

In [None]:
def create_offline_patches_gpu(
    input_image_dir,
    label_file_path,
    output_image_patch_dir,
    output_mask_patch_dir,  # if None, only image patches are produced
    patch_size,
    stride,
    device,
    min_foreground_ratio=0.01
    ):
    """
    Cut labelled source images (and their masks) into fixed-size patches on disk.

    Only images whose 'Pterygium' label in the Excel file is > 0 are processed.
    Each image is expected at ``<input_image_dir>/<NNNN>/<NNNN>.png`` with its
    mask at ``<input_image_dir>/<NNNN>/<NNNN>_label.png``. Cropping and the
    foreground filter run on ``device``; if moving an image there raises
    ``RuntimeError`` the image is processed on the CPU instead.

    Args:
        input_image_dir (str): Directory holding one sub-folder per image
            (0001, 0002, ...).
        label_file_path (str): Path of the Excel file with 'Image' and
            'Pterygium' columns.
        output_image_patch_dir (str): Directory the image patches are saved to.
        output_mask_patch_dir (str or None): Directory the mask patches are
            saved to. If None, masks are neither read nor written.
        patch_size (int): Edge length of the square patches.
        stride (int): Step between consecutive patch origins.
        device (torch.device): Device used for cropping/filtering.
        min_foreground_ratio (float): Minimum fraction of foreground pixels
            (mask red channel == 128) a patch must contain to be kept.
            0 disables the filter. Ignored when no mask is processed.

    Returns:
        tuple: (number of source images processed, number of patches created).
        (0, 0) is returned when the output directories cannot be created or
        the label file cannot be read.
    """
    start_time = time.time()
    # Compare on the device *type*: torch.device("cuda:0") != torch.device("cuda"),
    # so an equality test would skip cache cleanup for indexed CUDA devices.
    is_cuda = getattr(device, "type", None) == "cuda"

    try:
        os.makedirs(output_image_patch_dir, exist_ok=True)
        if output_mask_patch_dir:
            os.makedirs(output_mask_patch_dir, exist_ok=True)
    except OSError as e:
        # Without the output directories every save below would fail.
        print(f"创建输出目录时出错: {e}")
        return 0, 0

    # Read the label file and keep only the pterygium samples (labels 1 and 2)
    try:
        labels_df = pd.read_excel(label_file_path)
        pterygium_df = labels_df[labels_df['Pterygium'] > 0].reset_index(drop=True)
        # Folder names are the zero-padded image ids (0001, 0002, ...)
        image_folders_to_process = pterygium_df['Image'].astype(int).apply(lambda x: f"{x:04d}").tolist()
        print(f"从 {os.path.basename(label_file_path)} 读取标签，找到 {len(image_folders_to_process)} 个翼状胬肉样本进行Patching。")
    except FileNotFoundError:
        print(f"错误: 标签文件未找到 {label_file_path}。无法进行Patching。")
        return 0, 0
    except Exception as e:
        print(f"读取或处理标签文件 {label_file_path} 时出错: {e}")
        return 0, 0

    processed_files = 0
    created_patches = 0

    for folder_name in tqdm(image_folders_to_process, desc="处理带标签的大图"):
        image_path = os.path.join(input_image_dir, folder_name, f"{folder_name}.png")
        # The mask is only looked up when mask patches were requested
        mask_path = None
        if output_mask_patch_dir:
            mask_path = os.path.join(input_image_dir, folder_name, f"{folder_name}_label.png")

        if not os.path.exists(image_path):
            print(f"警告: 图像文件未找到 {image_path}，跳过文件夹 {folder_name}。")
            continue
        if mask_path is not None and not os.path.exists(mask_path):
            print(f"警告: 掩码文件未找到 {mask_path} (但需要处理掩码)，跳过文件夹 {folder_name}。")
            continue

        img_tensor = mask_tensor = img_patch = mask_patch = None
        try:
            # 1. Load the full image (and mask) on the CPU; close the files promptly
            with Image.open(image_path) as img_file:
                img_pil = img_file.convert('RGB')
            img_w, img_h = img_pil.size

            mask_pil = None
            if mask_path is not None:
                with Image.open(mask_path) as mask_file:
                    mask_pil = mask_file.convert('RGB')
                mask_w, mask_h = mask_pil.size
                if (img_w, img_h) != (mask_w, mask_h):
                    print(f"警告: 图像和掩码尺寸不匹配 {folder_name}，跳过。 ({img_w}x{img_h} vs {mask_w}x{mask_h})")
                    continue

            # 2. Move to the target device, falling back to the CPU on RuntimeError
            try:
                img_tensor, mask_tensor = _load_patch_source_tensors(img_pil, mask_pil, device)
            except RuntimeError as e:
                print(f"\n错误: 将图像/掩码 {folder_name} 移至GPU时出错 (可能显存不足): {e}")
                print("尝试在CPU上处理此图像...")
                img_tensor, mask_tensor = _load_patch_source_tensors(
                    img_pil, mask_pil, torch.device("cpu")
                )

            # 3. Crop and filter on the device; only full patches are produced
            for y in range(0, img_h - patch_size + 1, stride):
                for x in range(0, img_w - patch_size + 1, stride):
                    img_patch = img_tensor[:, y:y + patch_size, x:x + patch_size]
                    mask_patch = None

                    if mask_tensor is not None:
                        mask_patch = mask_tensor[:, y:y + patch_size, x:x + patch_size]
                        # The mask is 0/1, so its mean is the foreground fraction
                        if min_foreground_ratio > 0 and torch.mean(mask_patch).item() < min_foreground_ratio:
                            continue

                    # 4. Bring the kept patch back to the CPU and save it
                    patch_filename = f"{folder_name}_y{y}_x{x}.png"
                    F.to_pil_image(img_patch.cpu()).save(
                        os.path.join(output_image_patch_dir, patch_filename)
                    )

                    if mask_patch is not None:
                        # Re-encode 0/1 as an RGB image with R in {0, 128}, G = B = 0,
                        # matching the encoding read from the source masks
                        mask_channel = (mask_patch.cpu().squeeze(0).numpy() * 128).astype(np.uint8)
                        mask_rgb = np.zeros(mask_channel.shape + (3,), dtype=np.uint8)
                        mask_rgb[..., 0] = mask_channel
                        Image.fromarray(mask_rgb).save(
                            os.path.join(output_mask_patch_dir, patch_filename)
                        )

                    created_patches += 1

            processed_files += 1

        except Exception as e:
            print(f"\n处理文件 {folder_name} 时发生未预料错误: {e}")
        finally:
            # 5. Drop every reference before emptying the cache: the patch slices
            # are views that keep the full-size tensor's storage alive.
            img_tensor = mask_tensor = img_patch = mask_patch = None
            if is_cuda:
                torch.cuda.empty_cache()

    end_time = time.time()
    print("-" * 30)
    print("离线Patching完成！")
    print(f"成功处理（基于标签过滤后）大图数量: {processed_files} / {len(image_folders_to_process)}")
    print(f"总共创建Patch数量: {created_patches}") # with masks enabled this counts image+mask pairs
    print(f"总耗时: {end_time - start_time:.2f} 秒")
    print("-" * 30)

    return processed_files, created_patches


def _load_patch_source_tensors(img_pil, mask_pil, device):
    """Convert a PIL image and an optional RGB mask into tensors on ``device``.

    Returns ``(image_tensor, mask_tensor)``. The mask tensor has a leading
    channel axis and holds 1.0 where the mask's red channel equals 128 and
    0.0 elsewhere; it is None when ``mask_pil`` is None.
    """
    img_tensor = F.to_tensor(img_pil).to(device)
    mask_tensor = None
    if mask_pil is not None:
        mask_binary = (np.array(mask_pil)[:, :, 0] == 128).astype(np.float32)
        mask_tensor = torch.from_numpy(mask_binary).unsqueeze(0).to(device)
    return img_tensor, mask_tensor

def get_folder_size(folder_path):
    """Return the combined size in bytes of every file under *folder_path*.

    The tree is walked recursively with ``os.walk``; a path that yields no
    files gives 0.
    """
    return sum(
        os.path.getsize(os.path.join(parent, entry))
        for parent, _subdirs, entries in os.walk(folder_path)
        for entry in entries
    )

## Cloudflare R2 缓存patching

In [None]:
import boto3
import botocore
def create_r2_client():
    """Build a boto3 S3 client pointed at Cloudflare R2.

    Reads R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY and
    R2_BUCKET_NAME from this module's globals. Also sets the module-level
    flag ``r2_configured`` once the settings are found (and clears it again
    if client creation fails).

    Returns:
        tuple: ``(client, True)`` on success, ``(None, False)`` when a setting
        is missing/empty or the client could not be created.
    """
    global r2_configured

    module_vars = globals()
    needed = ('R2_ENDPOINT_URL', 'R2_ACCESS_KEY_ID', 'R2_SECRET_ACCESS_KEY', 'R2_BUCKET_NAME')
    # Every setting must exist and be truthy
    if any(not module_vars.get(setting) for setting in needed):
        print("R2 配置不完整（缺少 Endpoint URL, Access Key, Secret Key 或 Bucket Name）。跳过 R2 缓存。")
        return None, False

    r2_configured = True

    try:
        print("正在创建 R2 (boto3 S3) 客户端...")
        client_kwargs = dict(
            service_name='s3',
            endpoint_url=R2_ENDPOINT_URL,
            aws_access_key_id=R2_ACCESS_KEY_ID,
            aws_secret_access_key=R2_SECRET_ACCESS_KEY,
            region_name='auto',
            # pin the signature version explicitly
            config=botocore.config.Config(signature_version='s3v4'),
        )
        client = boto3.client(**client_kwargs)
        print("R2 客户端创建成功。")
        return client, True
    except Exception as e:
        print(f"创建 R2 客户端时出错: {e}")
        r2_configured = False
        return None, False

def check_r2_cache(s3_client, bucket_name, cache_key):
    """Return True when *cache_key* exists in *bucket_name*, else False.

    A falsy client, a '404' error code, any other client error and any
    unexpected exception all yield False; every failure except the '404'
    case is also printed.
    """
    if not s3_client:
        return False

    try:
        s3_client.head_object(Bucket=bucket_name, Key=cache_key)
    except botocore.exceptions.ClientError as e:
        # '404' simply means the object is absent; anything else
        # (e.g. a permission problem) is reported
        if e.response['Error']['Code'] != '404':
            print(f"检查 R2 缓存时出错 (Key: {cache_key}): {e}")
        return False
    except Exception as e:
        print(f"检查 R2 缓存时发生未知错误: {e}")
        return False
    return True

def download_from_r2(s3_client, bucket_name, cache_key, local_path):
    """Download *cache_key* from *bucket_name* to *local_path* with a progress bar.

    Args:
        s3_client: boto3 S3 client; a falsy value makes the call a no-op.
        bucket_name (str): Bucket to read from.
        cache_key (str): Object key to download.
        local_path (str): Destination file path.

    Returns:
        bool: True on success. On any failure the error is printed, a
        possibly incomplete *local_path* is removed, and False is returned.
    """
    if not s3_client:
        return False

    try:
        # Fetch the object size first so the progress bar has a total
        response = s3_client.head_object(Bucket=bucket_name, Key=cache_key)
        total_size = int(response.get('ContentLength', 0))

        print(f"正在从 R2 下载 {cache_key} 到 {local_path} ({total_size / (1024*1024):.2f} MB)...")
        with tqdm(total=total_size, unit='B', unit_scale=True, desc=cache_key, leave=False) as pbar:
            s3_client.download_file(
                Bucket=bucket_name,
                Key=cache_key,
                Filename=local_path,
                Callback=pbar.update
            )
        print(f"文件 {cache_key} 下载完成。")
        return True
    except botocore.exceptions.ClientError as e:
        print(f"从 R2 下载文件时出错 (Key: {cache_key}): {e}")
    except Exception as e:
        print(f"下载 R2 文件时发生未知错误: {e}")

    # Only reached after a failure: remove a possibly incomplete local file.
    try:
        os.remove(local_path)
    except FileNotFoundError:
        pass  # nothing was written
    except OSError as cleanup_error:
        print(f"警告: 无法删除不完整的本地文件 {local_path}: {cleanup_error}")
    return False

def upload_to_r2(s3_client, bucket_name, local_path, cache_key):
    """Upload *local_path* to *bucket_name* under *cache_key* with a progress bar.

    Returns True on success. Returns False (after printing the reason) when
    the client is falsy, the local file does not exist, or the upload raises.
    """
    if not (s3_client and os.path.exists(local_path)):
        print(f"上传 R2 失败：客户端未初始化或本地文件不存在 ({local_path})。")
        return False

    try:
        size_bytes = os.path.getsize(local_path)
        print(f"正在上传 {local_path} ({size_bytes / (1024*1024):.2f} MB) 到 R2 作为 {cache_key}...")
        progress = tqdm(total=size_bytes, unit='B', unit_scale=True, desc=cache_key, leave=False)
        with progress as bar:

            def _advance(bytes_transferred):
                # forward boto3's per-chunk byte count to the progress bar
                return bar.update(bytes_transferred)

            s3_client.upload_file(
                Filename=local_path,
                Bucket=bucket_name,
                Key=cache_key,
                Callback=_advance,
            )
        print(f"文件 {cache_key} 上传完成。")
    except botocore.exceptions.ClientError as e:
        print(f"上传文件到 R2 时出错 (Key: {cache_key}): {e}")
        return False
    except Exception as e:
        print(f"上传 R2 文件时发生未知错误: {e}")
        return False
    return True

def zip_directory(folder_path, zip_path):
    """Compress the contents of *folder_path* into the zip file *zip_path*.

    Files are stored under paths relative to *folder_path*. If *zip_path*
    itself lies inside *folder_path* it is excluded, so the archive never
    tries to contain itself.

    Returns:
        bool: True on success. False when *folder_path* is not a directory or
        compression fails; a partially written archive is removed.
    """
    if not os.path.isdir(folder_path):
        print(f"错误：要压缩的文件夹不存在 {folder_path}")
        return False
    print(f"正在压缩目录 {folder_path} 到 {zip_path}...")

    archive_abs = os.path.abspath(zip_path)
    try:
        # Collect the file list up front so the progress bar has a total
        file_paths = [
            os.path.join(root, filename)
            for root, _dirs, files in os.walk(folder_path)
            for filename in files
            if os.path.abspath(os.path.join(root, filename)) != archive_abs
        ]

        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Progress is counted in files
            with tqdm(total=len(file_paths), desc="压缩文件", unit="file", leave=False) as pbar:
                for file in file_paths:
                    # Store each file under its path relative to the folder
                    zipf.write(file, os.path.relpath(file, folder_path))
                    pbar.update(1)
        print("目录压缩完成。")
        return True
    except Exception as e:
        print(f"压缩目录时出错: {e}")
        # Remove a possibly incomplete archive
        try:
            os.remove(zip_path)
        except FileNotFoundError:
            pass  # the archive was never created
        except OSError as cleanup_error:
            print(f"警告: 无法删除不完整的压缩文件 {zip_path}: {cleanup_error}")
        return False

def unzip_directory(zip_path, extract_to_folder):
    """Extract the archive *zip_path* into *extract_to_folder*.

    The target folder is created if needed and members are extracted one at a
    time so a per-file progress bar can be shown. Returns True on success;
    returns False when the archive is missing or extraction raises (a
    partially extracted folder is left in place).
    """
    if not os.path.exists(zip_path):
        print(f"错误：要解压的 zip 文件不存在 {zip_path}")
        return False

    print(f"正在解压缩文件 {zip_path} 到 {extract_to_folder}...")
    try:
        os.makedirs(extract_to_folder, exist_ok=True)
        with zipfile.ZipFile(zip_path, 'r') as archive:
            entries = archive.infolist()
            with tqdm(total=len(entries), desc="解压缩文件", unit="file", leave=False) as bar:
                # Member-by-member extraction keeps the progress bar accurate
                for entry in entries:
                    archive.extract(entry, extract_to_folder)
                    bar.update(1)
        print("文件解压缩完成。")
    except Exception as e:
        print(f"解压缩文件时出错: {e}")
        return False
    return True


# 执行离线Patching
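调用上面定义的函数执行离线 Patching：若 R2 已配置且存在匹配的缓存，则下载并解压缓存以替换本地旧目录；否则在本地生成 Patch，并在 R2 已配置时将结果压缩上传作为缓存。
注意：命中 R2 缓存时，已存在的本地输出目录会被删除后重建；如需保留旧的 Patch，请先备份。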

In [None]:
# Pick the input/output locations for the current environment.
if os.path.exists("/kaggle/input/pterygium/train/"):
    INPUT_IMAGE_DIR = "/kaggle/input/pterygium/train/train"
    OUTPUT_PATCH_DIR = "/kaggle/working/train_patches"
elif 'google.colab' in sys.modules:
    INPUT_IMAGE_DIR = "/content/trains/train"
    OUTPUT_PATCH_DIR = "/content/train_patches"
elif os.name == 'nt':
    INPUT_IMAGE_DIR = "f:/train"
    OUTPUT_PATCH_DIR = "f:/train_patches"
else:
    INPUT_IMAGE_DIR = "data/train/train"
    OUTPUT_PATCH_DIR = "data/train_patches"

OUTPUT_IMAGE_PATCH_DIR = os.path.join(OUTPUT_PATCH_DIR, "images")
OUTPUT_MASK_PATCH_DIR = os.path.join(OUTPUT_PATCH_DIR, "masks")
# Temporary zip used for both download and upload; lives next to the patch dir
LOCAL_TEMP_ZIP_PATH = os.path.join(os.path.dirname(OUTPUT_PATCH_DIR), "patch_cache_temp.zip")

print(f"使用设备: {device}")
print(f"输入图像目录: {INPUT_IMAGE_DIR}")
print(f"输出 Patch 目录: {OUTPUT_PATCH_DIR}")
print(f"Patch 大小: {patch_size}x{patch_size}")
print(f"步长: {patch_stride}")
print(f"最小前景比例阈值: {MIN_FOREGROUND_RATIO}")

# --- Cache key ---
# The key is built from directory/file names and patch parameters only, so it
# does NOT change when the image or label *contents* change.
try:
    cache_key = f"v1-{os.path.basename(INPUT_IMAGE_DIR)}-{os.path.basename(label_file)}-{patch_size}-{patch_stride}-{MIN_FOREGROUND_RATIO}"
    R2_CACHE_KEY = f"patch_cache_{cache_key}.zip"
    print(f"生成的 R2 缓存键: {R2_CACHE_KEY}")
except Exception as e:
    # Leave the key defined so the code below can skip R2 instead of raising NameError
    R2_CACHE_KEY = None
    print(f"生成缓存键时出错: {e}。无法使用 R2 缓存。")

# --- R2 cache lookup ---
r2_client, r2_configured = create_r2_client()
use_r2 = bool(r2_client and r2_configured and R2_CACHE_KEY)
run_patching = True  # patch locally unless a cache is restored
patch_count = 0

if use_r2:
    print("\n--- 正在检查 R2 缓存 ---")
    if check_r2_cache(r2_client, R2_BUCKET_NAME, R2_CACHE_KEY):
        print(f"在 R2 上找到缓存文件: {R2_CACHE_KEY}")
        # Download first; existing local patches are only discarded once the
        # archive is safely on disk.
        if download_from_r2(r2_client, R2_BUCKET_NAME, R2_CACHE_KEY, LOCAL_TEMP_ZIP_PATH):
            if os.path.exists(OUTPUT_PATCH_DIR):
                print(f"本地目录 {OUTPUT_PATCH_DIR} 已存在。将删除旧目录以解压最新缓存。")
                try:
                    shutil.rmtree(OUTPUT_PATCH_DIR)
                except OSError as e:
                    print(f"删除本地旧目录 {OUTPUT_PATCH_DIR} 时出错: {e}。继续尝试解压...")
            else:
                print(f"本地目录 {OUTPUT_PATCH_DIR} 不存在。准备解压缓存。")

            if unzip_directory(LOCAL_TEMP_ZIP_PATH, OUTPUT_PATCH_DIR):
                print("成功从 R2 下载并解压缓存。")
                run_patching = False
                # Count the restored image patches
                patch_count = len(glob.glob(os.path.join(OUTPUT_IMAGE_PATCH_DIR, "*.png")))
                print(f"使用 R2 缓存中的 {patch_count} 个 patches。")
            else:
                print("解压 R2 缓存失败。将尝试本地 Patching。")
                # Remove a possibly incomplete extraction
                shutil.rmtree(OUTPUT_PATCH_DIR, ignore_errors=True)

            # The downloaded archive is no longer needed either way
            try:
                os.remove(LOCAL_TEMP_ZIP_PATH)
                print(f"已删除临时文件: {LOCAL_TEMP_ZIP_PATH}")
            except OSError as e:
                print(f"删除临时文件 {LOCAL_TEMP_ZIP_PATH} 时出错: {e}")
        else:
            print("从 R2 下载缓存失败。将尝试本地 Patching。")
    else:
        print(f"在 R2 上未找到对应的缓存文件 ({R2_CACHE_KEY})。")
else:
    print("\nR2 客户端未配置或创建失败，将跳过 R2 缓存检查。")

# --- Local patching (when no cache was restored) ---
if run_patching:
    print("\n--- 执行本地 Patching ---")
    os.makedirs(OUTPUT_PATCH_DIR, exist_ok=True)

    processed_count, patch_count = create_offline_patches_gpu(
        input_image_dir=INPUT_IMAGE_DIR,
        label_file_path=label_file,
        output_image_patch_dir=OUTPUT_IMAGE_PATCH_DIR,
        output_mask_patch_dir=OUTPUT_MASK_PATCH_DIR,
        patch_size=patch_size,
        stride=patch_stride,
        device=device,
        min_foreground_ratio=MIN_FOREGROUND_RATIO
    )

    # --- Upload the result to R2 (only when patches exist and R2 is usable) ---
    if patch_count > 0 and use_r2:
        print("\n--- 准备上传 Patching 结果到 R2 ---")
        if zip_directory(OUTPUT_PATCH_DIR, LOCAL_TEMP_ZIP_PATH):
            if upload_to_r2(r2_client, R2_BUCKET_NAME, LOCAL_TEMP_ZIP_PATH, R2_CACHE_KEY):
                print(f"成功上传缓存 {R2_CACHE_KEY} 到 R2。")
            else:
                print(f"上传缓存 {R2_CACHE_KEY} 到 R2 失败。")
            # Remove the local zip regardless of the upload outcome
            try:
                os.remove(LOCAL_TEMP_ZIP_PATH)
                print(f"已删除本地临时 zip 文件: {LOCAL_TEMP_ZIP_PATH}")
            except OSError as e:
                print(f"删除本地临时 zip 文件 {LOCAL_TEMP_ZIP_PATH} 时出错: {e}")
        else:
            print("压缩 Patching 结果失败，无法上传到 R2。")
    elif patch_count == 0:
        print("本地 Patching 未生成任何文件，跳过上传。")
    else:
        print("R2 未配置，跳过上传缓存步骤。")
else:
    print("\nPatching 步骤已跳过（使用 R2 缓存）。")

# 验证生成的Patches
随机抽查几个生成的图像和掩码Patch，确保它们是对应的并且格式正确。

In [None]:
def verify_patches(image_patch_dir, mask_patch_dir, num_samples=5):
    """Show randomly chosen image patches next to their mask patches.

    Up to *num_samples* ``*.png`` files are sampled from *image_patch_dir*;
    for each, the file with the same name in *mask_patch_dir* is displayed in
    the second column. A missing mask is reported and its row left empty.
    Prints an error and returns when no image patches are found.
    """
    candidates = glob.glob(os.path.join(image_patch_dir, "*.png"))
    if len(candidates) == 0:
        print("错误: 找不到生成的图像Patches。")
        return

    print(f"\n随机抽查 {num_samples} 个生成的Patch对...")
    chosen = random.sample(candidates, min(num_samples, len(candidates)))
    row_count = len(chosen)

    # squeeze=False always yields a 2-D axes grid, including for a single row
    _fig, axes = plt.subplots(row_count, 2, figsize=(8, 4 * row_count), squeeze=False)

    for row_axes, image_file in zip(axes, chosen):
        base_name = os.path.basename(image_file)
        mask_file = os.path.join(mask_patch_dir, base_name)

        if not os.path.exists(mask_file):
            print(f"错误: 找不到对应的掩码Patch {mask_file}")
            continue

        image_ax, mask_ax = row_axes[0], row_axes[1]

        image_ax.imshow(Image.open(image_file))
        image_ax.set_title(f"图像 Patch:\n{base_name}")
        image_ax.axis('off')

        mask_ax.imshow(Image.open(mask_file), cmap='gray')
        mask_ax.set_title(f"掩码 Patch (0/128):\n{base_name}")
        mask_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Run the verification. verify_patches reports the "no patches" case itself,
# so anything caught here is a different failure: print the actual error
# instead of masking it.
try:
    verify_patches(OUTPUT_IMAGE_PATCH_DIR, OUTPUT_MASK_PATCH_DIR, num_samples=5)
except Exception as e:
    print(f"验证 Patch 时出错，跳过验证: {e}")

# 数据增强

In [None]:
# Training-time augmentation pipeline (albumentations)
train_transform = A.Compose([
    # --- Spatial transforms ---
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.0625, # shift range (fraction of the image size)
        scale_limit=0.1,    # scale range (+/- 10%)
        rotate_limit=15,    # rotation range (+/- 15 degrees)
        interpolation=1,    # cv2.INTER_LINEAR for image
        border_mode=0,      # cv2.BORDER_CONSTANT (padding mode)
        value=0,            # padding value (for image)
        mask_value=0,       # padding value (for mask)
        p=0.7 # probability of applying this combined transform
    ),
    # Elastic deformation
    A.ElasticTransform(
        alpha=1,        # strength parameter
        sigma=50,       # Gaussian kernel standard deviation
        alpha_affine=50,# strength of the affine part
        interpolation=1,
        border_mode=0,
        value=0,
        mask_value=0,
        p=0.5
    ),

    # --- Intensity/colour transforms (intended to act on the image only) ---
    # OneOf: apply one randomly chosen transform from the list
    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
        A.GaussianBlur(blur_limit=(3, 7), p=0.5),
    ], p=0.3), # probability of applying one of the noise/blur transforms
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05, p=0.5),
    A.RandomGamma(gamma_limit=(80, 120), p=0.4), # per the original author: gamma between 0.8 and 1.2

    # --- Occlusion ---
    # NOTE(review): the original comment claimed mask_fill_value=0 makes the
    # dropout affect the image only. The albumentations docs describe
    # mask_fill_value=None as the setting that leaves the mask untouched, so
    # with 0 the mask is presumably zeroed in the dropped regions too — confirm intent.
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, # randomly cut holes
                    min_holes=1, min_height=8, min_width=8,
                    fill_value=0, mask_fill_value=0, p=0.3),

    # --- Normalisation & tensor conversion ---
    # NOTE(review): the original comment said Normalize may go before or after
    # ToTensorV2. A.Normalize presumably expects the NumPy image, so it should
    # stay before ToTensorV2 as written here — confirm before reordering.
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2() # convert NumPy [H,W,C] to PyTorch [C,H,W]
])

# Validation/test transforms (usually only Resize, Normalize, ToTensor)
# The model input is already a patch, so no Resize is needed here
val_transform_alb = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

# 创建数据加载器
设置训练和验证数据加载器，包括数据增强策略。

In [None]:
class PterygiumSegDataset(Dataset):
    """Dataset of pre-cut (offline) image / mask patch pairs.

    An image patch and its mask patch share the same file name and live in
    two separate directories. Masks are RGB PNGs in which the lesion is
    marked by a red-channel value of 128; each mask is converted to a
    single-channel float32 map containing 0.0 / 1.0.
    """

    # Side length of the all-zero placeholder sample returned when a patch
    # cannot be loaded. An instance attribute ``patch_size`` overrides it.
    _FALLBACK_PATCH_SIZE = 256

    def __init__(self, image_patch_dir, mask_patch_dir, file_list=None, transform=None):
        """Index the patch files.

        :param image_patch_dir: directory containing the image patch files
        :param mask_patch_dir: directory containing the matching mask patches
        :param file_list: optional list of file names to restrict the dataset
            to; when None every ``.png`` file in ``image_patch_dir`` is used
        :param transform: optional albumentations-style callable invoked as
            ``transform(image=ndarray, mask=ndarray)``; it is expected to end
            with a to-tensor step (as the pipelines in this file do)
        :raises FileNotFoundError: if no file was selected and
            ``image_patch_dir`` is not a directory
        """
        self.image_patch_dir = image_patch_dir
        self.mask_patch_dir = mask_patch_dir
        self.transform = transform

        # Collect the patch file names.
        if file_list is None:
            self.image_filenames = sorted([
                f for f in os.listdir(image_patch_dir)
                if os.path.isfile(os.path.join(image_patch_dir, f)) and f.endswith('.png')
            ])
        else:
            self.image_filenames = sorted([f for f in file_list if f.endswith('.png')])

        if not self.image_filenames:
            # Explain why the dataset is empty; only a missing directory is fatal.
            if not os.path.isdir(image_patch_dir):
                raise FileNotFoundError(f"图像 patch 目录不存在: {image_patch_dir}")
            elif not os.listdir(image_patch_dir) and file_list is None:
                print(f"警告: 目录 {image_patch_dir} 为空。")
            elif file_list is not None:
                print(f"警告: 提供的 file_list 为空或不包含 .png 文件。")
            else:
                print(f"警告: 在目录 {image_patch_dir} 中未找到匹配的 .png 文件 (可能检查 file_list)。")

        # Image normalization used only when no ``transform`` is supplied.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        print(f"数据集初始化: 找到 {len(self.image_filenames)} 个 patches in {image_patch_dir}" + (f" (来自 file_list)" if file_list is not None else ""))

    def __len__(self):
        """Return the number of patches in the dataset."""
        return len(self.image_filenames)

    @staticmethod
    def _binarize_mask(mask_rgb):
        """Convert an (H, W, C) RGB mask array to an (H, W) float32 0/1 map.

        A pixel is foreground when its red channel equals 128.
        """
        return (mask_rgb[:, :, 0] == 128).astype(np.float32)

    def _placeholder_sample(self):
        """Return an all-zero (image, mask) pair used when loading fails."""
        size = self.patch_size if hasattr(self, 'patch_size') else self._FALLBACK_PATCH_SIZE
        return torch.zeros((3, size, size)), torch.zeros((1, size, size))

    def __getitem__(self, idx):
        """Load one image patch and its mask patch.

        :param idx: index into the file list
        :return: ``(image_tensor, mask_tensor)`` where the mask tensor is
            float with shape [1, H, W]; an all-zero placeholder pair is
            returned (after printing the error) if the files cannot be loaded
        """
        patch_filename = self.image_filenames[idx]
        img_path = os.path.join(self.image_patch_dir, patch_filename)
        mask_path = os.path.join(self.mask_patch_dir, patch_filename)

        # Load both files; ``with`` closes the underlying file handles as soon
        # as the pixel data has been converted.
        try:
            with Image.open(img_path) as img_file:
                image_patch = img_file.convert("RGB")
            with Image.open(mask_path) as mask_file:
                mask_patch = mask_file.convert("RGB")  # lesion is stored as (128,0,0)
        except FileNotFoundError:
            print(f"错误: 文件未找到 {img_path} 或 {mask_path}")
            return self._placeholder_sample()
        except Exception as e:
            print(f"加载 patch 时出错 {patch_filename}: {e}")
            return self._placeholder_sample()

        # The binary mask is needed by both branches below, so build it first.
        # (Previously it was only defined in the no-transform branch, which
        # made the transform branch fail with a NameError.)
        mask_float = self._binarize_mask(np.array(mask_patch))

        if self.transform is not None:
            # albumentations operates on NumPy arrays, not PIL images.
            augmented = self.transform(image=np.array(image_patch), mask=mask_float)
            image_tensor = augmented['image']
            # [H, W] -> [1, H, W]; force float so the loss receives float targets.
            mask_tensor = augmented['mask'].float().unsqueeze(0)
        else:
            # No transform supplied: convert to tensor and normalize only.
            image_tensor = self.normalize(F.to_tensor(image_patch))
            mask_tensor = torch.from_numpy(mask_float).unsqueeze(0)

        return image_tensor, mask_tensor

In [None]:
# Make sure the patch directories exist and contain files.
if not os.path.exists(OUTPUT_IMAGE_PATCH_DIR) or not os.listdir(OUTPUT_IMAGE_PATCH_DIR):
    raise FileNotFoundError(f"错误：无法找到生成的图像 Patches 于 {OUTPUT_IMAGE_PATCH_DIR}。请先成功执行 Patching 步骤。")
if not os.path.exists(OUTPUT_MASK_PATCH_DIR) or not os.listdir(OUTPUT_MASK_PATCH_DIR):
    raise FileNotFoundError(f"错误：无法找到生成的掩码 Patches 于 {OUTPUT_MASK_PATCH_DIR}。请先成功执行 Patching 步骤。")

# 1. Collect every patch file name.
all_patch_files = sorted(
    f for f in os.listdir(OUTPUT_IMAGE_PATCH_DIR)
    if os.path.isfile(os.path.join(OUTPUT_IMAGE_PATCH_DIR, f)) and f.endswith('.png')
)

if not all_patch_files:
    raise ValueError(f"错误：在 {OUTPUT_IMAGE_PATCH_DIR} 中未找到任何 Patch 文件。")

# 2. Group patch file names by the ID of the image they were cut from.
#    File names are assumed to look like "XXXX_yYYY_xZZZ.png", where XXXX is
#    the original image ID.
patches_by_original_image = collections.defaultdict(list)
for filename in all_patch_files:
    # Slicing never raises, so a malformed name is detected solely by the
    # digit check (the former ``except IndexError`` branch was unreachable).
    original_image_id = filename[:4]
    if original_image_id.isdigit():
        patches_by_original_image[original_image_id].append(filename)
    else:
        print(f"警告：无法从文件名 {filename} 中提取有效的原始图像 ID，跳过此文件。")

if not patches_by_original_image:
    raise ValueError("错误：无法根据文件名对任何 Patch 进行分组。请检查文件名格式。")

# 3. Unique original image IDs.
unique_original_ids = sorted(patches_by_original_image)
print(f"从 {len(all_patch_files)} 个 Patches 中识别出 {len(unique_original_ids)} 个唯一的原始图像来源。")

# 4. Split train / validation at the level of original image IDs so that all
#    patches of one source image end up in the same subset.
val_split_ratio = 0.2  # fraction of original image IDs used for validation
try:
    train_ids, val_ids = train_test_split(
        unique_original_ids,
        test_size=val_split_ratio,
        random_state=42  # reproducible split
    )
except ValueError as e:
    print(f"错误：无法进行 train_test_split。可能是唯一 ID 数量过少。错误信息: {e}")
    if len(unique_original_ids) < 2:
        # Too few source images to hold any out: train on everything.
        print("唯一原始图像 ID 少于 2 个，无法划分验证集。将使用所有数据进行训练。")
        train_ids = unique_original_ids
        val_ids = []
    else:
        # Any other failure is unexpected; re-raise with the original traceback.
        raise

print(f"划分结果：{len(train_ids)} 个原始图像用于训练，{len(val_ids)} 个原始图像用于验证。")

# 5. Expand the ID split into patch file name lists.
train_filenames = []
for img_id in train_ids:
    train_filenames.extend(patches_by_original_image[img_id])

val_filenames = []
for img_id in val_ids:
    val_filenames.extend(patches_by_original_image[img_id])

print(f"划分后的 Patch 数量：训练集 {len(train_filenames)} Patches，验证集 {len(val_filenames)} Patches。")
# Note: the patch-count ratio need not equal val_split_ratio exactly, because
# different source images may yield different numbers of patches.

# 6. Build independent training and validation Dataset instances.
train_dataset_offline = PterygiumSegDataset(
    image_patch_dir=OUTPUT_IMAGE_PATCH_DIR,
    mask_patch_dir=OUTPUT_MASK_PATCH_DIR,
    file_list=train_filenames,  # patch names of the training source images
    transform=train_transform
)
# Only build the validation set when there are validation file names.
if val_filenames:
    val_dataset_offline = PterygiumSegDataset(
        image_patch_dir=OUTPUT_IMAGE_PATCH_DIR,
        mask_patch_dir=OUTPUT_MASK_PATCH_DIR,
        file_list=val_filenames,  # patch names of the validation source images
        # No augmentation at validation time: this is expressed by passing the
        # validation transform. The former ``augment=False`` argument was
        # removed: it lacked a trailing comma (SyntaxError) and is not a
        # parameter of PterygiumSegDataset.
        transform=val_transform_alb
    )
else:
    val_dataset_offline = None  # no validation IDs -> no validation set
    print("警告：验证集为空。")


# 7. Build the DataLoaders.
# Batch size depends on the accelerator: TPU first, then CUDA, else the
# CPU/Windows setting.
if _xla_available:
    train_loader_batch_size = tpu_batch_size
elif torch.cuda.is_available():
    train_loader_batch_size = cuda_batch_size
else:
    train_loader_batch_size = windows_batch_size
val_loader_batch_size = train_loader_batch_size

# Options shared by both loaders.
_loader_options = dict(
    num_workers=num_workers,
    pin_memory=not _xla_available and torch.cuda.is_available(),
    drop_last=bool(_xla_available),
)

train_loader = DataLoader(
    train_dataset_offline,
    batch_size=train_loader_batch_size,
    shuffle=True,  # shuffle during training
    **_loader_options
)

print(f"\n训练集大小 (Patches): {len(train_dataset_offline)}")
# A validation loader is only created when a validation set exists.
if val_dataset_offline:
    val_loader = DataLoader(
        val_dataset_offline,
        batch_size=val_loader_batch_size,
        shuffle=False,  # no shuffling needed for validation
        **_loader_options
    )
    print(f"验证集大小 (Patches): {len(val_dataset_offline)}")
    print(f"训练 DataLoader 批次数: {len(train_loader)}")
    print(f"验证 DataLoader 批次数: {len(val_loader)}")
else:
    # Downstream training / validation code must skip validation in this case.
    val_loader = None
    print("验证集为空，无法创建验证 DataLoader。")

# Drop the large intermediate lists that are no longer needed.
del all_patch_files, patches_by_original_image, unique_original_ids
del train_ids, val_ids #, train_filenames, val_filenames # if those lists are not needed later either

# 测试数据加载器
取一个批次的数据出来看看形状和内容是否正确。

In [None]:
# Pull one batch from the training loader, print its shapes / dtypes / value
# range and display the first sample next to its mask.
try:
    images_batch, masks_batch = next(iter(train_loader))
    print("\n从训练加载器获取一个批次的数据:")
    print(f"图像批次形状: {images_batch.shape}") # expected to look like [BATCH_SIZE, 3, PATCH_SIZE, PATCH_SIZE]
    print(f"掩码批次形状: {masks_batch.shape}")   # expected to look like [BATCH_SIZE, 1, PATCH_SIZE, PATCH_SIZE]
    print(f"图像数据类型: {images_batch.dtype}")
    print(f"掩码数据类型: {masks_batch.dtype}")
    print(f"掩码最小值: {masks_batch.min()}") # expected to be 0.0
    print(f"掩码最大值: {masks_batch.max()}") # expected to be 1.0

    # Visualize one sample
    sample_idx = 0
    img_sample = images_batch[sample_idx].permute(1, 2, 0).numpy() # CHW -> HWC
    mask_sample = masks_batch[sample_idx].squeeze().numpy()       # 1HW -> HW

    # Undo the normalization so the image can be displayed
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_sample = std * img_sample + mean
    img_sample = np.clip(img_sample, 0, 1)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(img_sample)
    axes[0].set_title("示例图像 Patch (反归一化)")
    axes[0].axis('off')
    axes[1].imshow(mask_sample, cmap='gray')
    axes[1].set_title("示例掩码 Patch (0/1)")
    axes[1].axis('off')
    plt.show()

except StopIteration:
    # The loader produced no batch at all
    print("错误：无法从 DataLoader 获取数据，请检查数据集是否为空或配置是否正确。")
except Exception as e:
    # Best-effort check: report any other failure without aborting the notebook
    print(f"测试 DataLoader 时出错: {e}")

# 构建U-Net分割模型
U-Net是一种经典的图像分割模型，其结构包括下采样路径（编码器）和上采样路径（解码器），以及连接两者的跳跃连接。本实现使用 ImageNet 预训练的 ResNet-18 作为编码器。

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv -> BatchNorm -> ReLU) stages.

    :param in_channels: channels of the input feature map
    :param out_channels: channels of the output feature map
    :param mid_channels: channels between the two stages; falls back to
        ``out_channels`` when not given (or falsy)
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        # Same layer order as a hand-written Sequential, so the sub-module
        # indices (0..5) are unchanged.
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Down-sampling stage: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and run the double convolution on the result."""
        reduced = self.maxpool_conv(x)
        return reduced

class Up(nn.Module):
    """Up-sampling stage: upsample, pad to the skip size, concat, DoubleConv.

    :param in_channels: channel count of the tensor fed to the DoubleConv,
        i.e. after the upsampled input and the skip connection are
        concatenated
    :param out_channels: channels produced by this stage
    :param bilinear: use bilinear interpolation (True) or a transposed
        convolution that halves the channel count (False)
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # Up-sample either by bilinear interpolation or by transposed convolution.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse them.

        :param x1: decoder feature map to upsample, (N, C1, h, w)
        :param x2: skip-connection feature map, (N, C2, H, W)
        :return: output of the DoubleConv applied to ``cat([x2, x1], dim=1)``
        """
        x1 = self.up(x1)
        # The skip tensor may be slightly larger when the input size is not an
        # exact multiple of the down-sampling factor; pad x1 to match.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # Use torch.nn.functional.pad explicitly: its padding list is ordered
        # (left, right, top, bottom). The module-level alias ``F`` in this
        # file is torchvision.transforms.functional, whose ``pad`` expects
        # (left, top, right, bottom) and would misplace the padding.
        x1 = torch.nn.functional.pad(
            x1,
            [diffX // 2, diffX - diffX // 2,
             diffY // 2, diffY - diffY // 2],
        )
        # Concatenate skip features and upsampled features along channels.
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net style segmentation network with a pretrained ResNet-18 encoder.

    The encoder reuses the stem and the four residual stages of an
    ImageNet-pretrained ResNet-18; the decoder consists of four ``Up`` stages
    with skip connections followed by a 1x1 output convolution.

    :param n_channels: stored on the instance only; the encoder stem is taken
        from ResNet-18 unchanged
    :param n_classes: number of output channels (logits)
    :param bilinear: forwarded to every ``Up`` stage
    :param dropout_p: dropout probability applied to the bottleneck features

    NOTE(review): with ``bilinear=False`` the ``Up`` stages appear to receive
    channel counts that do not match the encoder outputs — verify before use.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=True, dropout_p=0.5):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.dropout_p = dropout_p

        # Pretrained ResNet-18 supplies the encoder layers.
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Encoder (output channels: stem 64, then 64 / 128 / 256 / 512).
        self.inc = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.maxpool = backbone.maxpool
        self.down1 = backbone.layer1
        self.down2 = backbone.layer2
        self.down3 = backbone.layer3
        self.down4 = backbone.layer4

        # Dropout applied after the bottleneck.
        self.dropout = nn.Dropout(self.dropout_p)

        # Decoder: each Up receives (decoder channels + skip channels).
        halve = 2 if bilinear else 1
        self.up1 = Up(512 + 256, 512 // halve, bilinear)  # down4 + down3
        self.up2 = Up(256 + 128, 256 // halve, bilinear)  # up1 + down2
        self.up3 = Up(128 + 64, 128 // halve, bilinear)   # up2 + down1
        self.up4 = Up(64 + 64, 64, bilinear)              # up3 + stem
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return segmentation logits for the image batch ``x``.

        The last decoder stage is fused with the stem features; per the
        original author's notes these are at half the input resolution, so
        the logits are presumably (N, n_classes, H/2, W/2).
        """
        # Encoder path, keeping every stage output as a skip connection.
        stem = self.inc(x)
        stage1 = self.down1(self.maxpool(stem))
        stage2 = self.down2(stage1)
        stage3 = self.down3(stage2)
        bottleneck = self.dropout(self.down4(stage3))

        # Decoder path: walk the skips from deepest to shallowest.
        decoded = bottleneck
        for up_stage, skip in (
            (self.up1, stage3),
            (self.up2, stage2),
            (self.up3, stage1),
            (self.up4, stem),
        ):
            decoded = up_stage(decoded, skip)
        return self.outc(decoded)

# Instantiate the model and move it to the target device
model = UNet(n_classes=1, bilinear=True, dropout_p=0.5).to(device)
if torch.cuda.device_count() > 1:
    # Multiple GPUs are only reported here; the DataParallel wrapping below is
    # left disabled.
    print(f"检测到 {torch.cuda.device_count()} 块GPU, 由于多卡存在问题，只使用GPU0")
    #model = nn.DataParallel(model)

# 定义损失函数和评估指标
我们使用组合损失函数：二元交叉熵损失和Dice损失的组合，以更好地处理类别不平衡问题。

In [None]:
# Dice loss
class DiceLoss(nn.Module):
    """Soft Dice loss computed per sample from logits and averaged.

    :param smooth: additive smoothing term for numerator and denominator;
        avoids division by zero when both prediction and target are empty
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - mean(dice)`` over the batch.

        :param logits: raw model outputs, first dimension is the batch
        :param targets: ground-truth masks with the same number of elements
            per sample as ``logits``
        """
        # Turn logits into probabilities.
        probs = torch.sigmoid(logits)

        # Flatten everything but the batch dimension. ``reshape`` is used
        # instead of ``view`` so non-contiguous inputs (e.g. transposed or
        # sliced tensors) are accepted as well.
        batch_size = targets.size(0)
        probs = probs.reshape(batch_size, -1)
        targets = targets.reshape(batch_size, -1)

        # Soft intersection per sample.
        intersection = (probs * targets).sum(dim=1)

        # Dice coefficient per sample.
        dice = (2. * intersection + self.smooth) / (
            probs.sum(dim=1) + targets.sum(dim=1) + self.smooth)

        return 1 - dice.mean()

# Combined loss
class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and Dice loss.

    :param bce_weight: weight of the BCE term
    :param dice_weight: weight of the Dice term
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        # Kept so existing code that accesses this attribute keeps working;
        # ``forward`` no longer mutates it.
        self.bce_loss_fn = nn.BCEWithLogitsLoss()
        self.dice_loss_fn = DiceLoss()

    def forward(self, logits, targets, pos_weight=None):
        """Compute the combined loss.

        :param logits: model output (N, 1, H, W)
        :param targets: ground-truth mask (N, 1, H, W)
        :param pos_weight: optional positive-class weight tensor for the BCE
            term; None means unweighted
        :return: ``bce_weight * bce + dice_weight * dice``
        """
        # Use the functional form so the per-call pos_weight is not written
        # into a shared module buffer, where it would leak into later calls
        # and into the module's state_dict.
        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=pos_weight
        )
        dice = self.dice_loss_fn(logits, targets)
        return self.bce_weight * bce + self.dice_weight * dice

# Evaluation metric: Dice coefficient
def dice_coefficient(y_pred, y_true, threshold=0.5, smooth=1e-6):
    """Return the Dice coefficient between thresholded predictions and targets.

    :param y_pred: logits; sigmoid is applied and the result is binarized
        with ``threshold``
    :param y_true: ground-truth mask of the same shape as ``y_pred``
    :param threshold: probability above which a pixel counts as foreground
    :param smooth: smoothing term added to numerator and denominator
    :return: Dice coefficient over all elements as a Python float
    """
    assert y_pred.shape == y_true.shape, f"预测形状 {y_pred.shape} 与目标形状 {y_true.shape} 不匹配"

    # Binarize the probabilities and flatten both tensors.
    hard_pred = torch.sigmoid(y_pred).gt(threshold).float().reshape(-1)
    flat_true = y_true.reshape(-1)

    overlap = torch.sum(hard_pred * flat_true)
    total = hard_pred.sum() + flat_true.sum() + smooth
    score = (2. * overlap + smooth) / total
    return score.item()

# Instantiate the loss: 0.6 * BCE + 0.4 * Dice
criterion = CombinedLoss(bce_weight=0.6, dice_weight=0.4)

# 配置优化器和训练参数
设置 AdamW 优化器和余弦退火（CosineAnnealingLR）学习率调度器，为模型训练做准备。

In [None]:
# Training parameters
num_epochs = 30
log_interval = 5

# Optimizer: AdamW over all model parameters
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=7e-5)

# Learning-rate scheduler (the ReduceLROnPlateau alternative is kept disabled)
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6) # uses the baseline hyperparameters

# 训练模型
实现训练循环，包括前向传播、损失计算、反向传播、参数更新，并记录训练过程中的指标。同时实现早停机制。

In [None]:
class EarlyStopping:
    """Track a validation score, keep the best weights and signal when to stop."""

    def __init__(self, patience=7, min_delta=0.0, mode='max', verbose=True):
        """
        Args:
            patience (int): number of consecutive non-improving calls after
                which ``early_stop`` is set.
            min_delta (float): minimum change that counts as an improvement.
            mode (str): 'min' or 'max' — whether a smaller or a larger score
                is better.
            verbose (bool): whether to print progress messages.

        Raises:
            ValueError: if ``mode`` is neither 'min' nor 'max'.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_model_state_dict_cpu = None  # best state_dict, kept on the CPU

        if self.mode not in ['min', 'max']:
            raise ValueError("mode 必须是 'min' 或 'max'")

        # Multiplying scores by this sign lets 'min' and 'max' share one comparison.
        self.delta_sign = 1 if mode == 'max' else -1

    def __call__(self, val_score, model_state_dict_cpu):
        """
        Args:
            val_score (float): current validation score.
            model_state_dict_cpu (dict): current model state_dict (already on
                the CPU); a deep copy is stored when the score is the best so far.
        """
        # First call: whatever we get is the best so far.
        if self.best_score is None:
            self.best_score = val_score
            self.best_model_state_dict_cpu = deepcopy(model_state_dict_cpu)
            if self.verbose:
                tqdm.write(f"EarlyStopping: 初始化最佳分数为 {self.best_score:.4f}")
            return

        signed_new = val_score * self.delta_sign
        signed_best = self.best_score * self.delta_sign
        if signed_new > signed_best + self.min_delta:
            # Sufficient improvement: remember it and reset the counter.
            self.best_score = val_score
            self.best_model_state_dict_cpu = deepcopy(model_state_dict_cpu)
            self.counter = 0
            if self.verbose:
                tqdm.write(f"EarlyStopping: 发现改进。最佳分数更新为 {self.best_score:.4f}。计数器重置。")
            return

        # No (sufficient) improvement.
        self.counter += 1
        if self.verbose:
            tqdm.write(f'EarlyStopping计数器: {self.counter}/{self.patience}。最佳分数仍为 {self.best_score:.4f}。')
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                tqdm.write("EarlyStopping: 已达到耐心值，触发早停。")

    def load_best_weights(self, model):
        """Load the stored best weights back into ``model``."""
        if not self.best_model_state_dict_cpu:
            if self.verbose:
                print("警告：未找到可加载的最佳模型权重。")
            return

        # Move the stored tensors to the device the model lives on.
        target_device = next(model.parameters()).device
        restored = {}
        for key, tensor in self.best_model_state_dict_cpu.items():
            restored[key] = tensor.to(target_device)
        model.load_state_dict(restored)
        if self.verbose:
            print("已将最佳模型权重加载回模型。")

In [None]:
def _state_dict_to_cpu(model):
    """Return ``model.state_dict()`` with every tensor moved to the CPU.

    Tensors that already live on the CPU are returned as-is (``Tensor.cpu``
    does not copy in that case).
    """
    return {k: v.cpu() for k, v in model.state_dict().items()}


def train_validate_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """
    Train the model and validate it after every epoch (single-core TPU, CUDA or CPU).

    Masks are resized (nearest neighbour) to the spatial size of the model
    output before the loss and the Dice score are computed. Early stopping
    monitors the validation Dice (patience 7, min_delta 0.001).

    Args:
        model (nn.Module): model to train (already moved to ``device``).
        train_loader (DataLoader): training data loader.
        val_loader (DataLoader): validation data loader; may be None (or
            empty), in which case validation and early stopping are skipped.
        criterion (nn.Module): loss called as ``criterion(outputs, masks, pos_weight=...)``.
        optimizer: optimizer.
        scheduler: learning-rate scheduler; ``step()`` is called once per epoch.
        num_epochs (int): maximum number of epochs.
        device (torch.device or str): target device ('cpu', 'cuda', or xm.xla_device()).

    Returns:
        tuple: (final validation Dice, training Dice history, validation Dice
            history, best model state_dict on the CPU).
            Without a validation loader the final validation Dice is 0 and the
            validation history is empty.
    """
    start_time = time.time()
    device_type = str(device).split(':')[0]
    is_tpu = 'xla' in str(device)
    print(f"\n--- 开始训练 ---")
    print(f"设备: {device}")
    print(f"轮数: {num_epochs}")
    print(f"优化器: {type(optimizer).__name__}")
    print(f"学习率调度器: {type(scheduler).__name__}")
    print(f"损失函数: {type(criterion).__name__}")

    # Early stopping is always created to keep the code path uniform; without
    # a validation loader it is never called and training runs all epochs.
    early_stopping = EarlyStopping(patience=7, mode='max', min_delta=0.001, verbose=True)

    # Mixed precision is used on CUDA only. The decision is derived from the
    # actual target device (not from mere CUDA availability) and the same flag
    # drives both autocast and the gradient scaler so they cannot disagree.
    use_amp = device_type == 'cuda'
    scaler = torch.amp.GradScaler(enabled=use_amp)
    print(f"使用混合精度 (AMP): {use_amp}")

    train_dice_history = []
    val_dice_history = []
    best_model_state_dict_cpu = None  # best weights, stored on the CPU

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        # Read the learning rate up front so the epoch summary below also
        # works when the training loader yields no batch.
        current_lr = optimizer.param_groups[0]['lr']

        # --- Training phase ---
        model.train()
        train_loss = 0.0
        train_dice = 0.0
        train_samples = 0
        train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)

        for images, masks in train_loader_tqdm:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()

            # --- Per-batch pos_weight = #negative / #positive pixels ---
            num_pixels = masks.numel()
            num_pos = torch.sum(masks)
            num_neg = num_pixels - num_pos
            # Epsilon avoids division by zero; clamping to >= 1 keeps the
            # foreground from being down-weighted.
            pos_weight_value = torch.clamp(num_neg / (num_pos + 1e-6), min=1.0)
            # Already a tensor on the right device: just give it shape [1].
            pos_weight_tensor = pos_weight_value.reshape(1)

            # --- Forward pass (under autocast when AMP is enabled) ---
            with torch.amp.autocast(device_type=device_type, enabled=use_amp):
                outputs = model(images)
                # Match the mask size to the model output size.
                masks_downsampled = torch.nn.functional.interpolate(masks, size=outputs.shape[2:], mode='nearest')
                loss = criterion(outputs, masks_downsampled, pos_weight=pos_weight_tensor)

            # --- Backward pass and optimizer step ---
            if is_tpu:
                loss.backward()
                # xm.optimizer_step performs the weight update on XLA devices.
                xm.optimizer_step(optimizer)
            else:  # CPU or CUDA (the scaler is a no-op when disabled)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            # --- Accumulate sample-weighted metrics ---
            batch_size = images.size(0)
            train_loss += loss.item() * batch_size
            current_dice = dice_coefficient(outputs.detach(), masks_downsampled.detach())
            train_dice += current_dice * batch_size
            train_samples += batch_size

            # --- Progress bar ---
            current_lr = optimizer.param_groups[0]['lr']
            train_loader_tqdm.set_postfix({
                'loss': f'{loss.item():.4f}',
                'dice': f'{current_dice:.4f}',
                'lr': f'{current_lr:.1e}',
                'pw': f'{pos_weight_value.item():.2f}'
            })

        # --- Average training metrics ---
        if train_samples > 0:
            train_loss /= train_samples
            train_dice /= train_samples
        else:
            train_loss, train_dice = 0.0, 0.0  # empty loader
        train_dice_history.append(train_dice)

        # --- Validation phase (only when a validation loader exists) ---
        current_val_dice = 0.0
        if val_loader:
            model.eval()
            val_loss = 0.0
            val_dice = 0.0
            val_samples = 0
            val_loader_tqdm = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)

            with torch.no_grad():
                for images, masks in val_loader_tqdm:
                    images, masks = images.to(device), masks.to(device)

                    # Validation runs without autocast.
                    outputs = model(images)
                    masks_downsampled = torch.nn.functional.interpolate(masks, size=outputs.shape[2:], mode='nearest')
                    # No pos_weight during validation.
                    loss = criterion(outputs, masks_downsampled, pos_weight=None)

                    batch_size = images.size(0)
                    val_loss += loss.item() * batch_size
                    current_dice_val = dice_coefficient(outputs, masks_downsampled)
                    val_dice += current_dice_val * batch_size
                    val_samples += batch_size

                    val_loader_tqdm.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'dice': f'{current_dice_val:.4f}'
                        })

            # --- Average validation metrics ---
            if val_samples > 0:
                val_loss /= val_samples
                val_dice /= val_samples
            else:
                val_loss, val_dice = 0.0, 0.0  # empty loader
            val_dice_history.append(val_dice)
            current_val_dice = val_dice  # score monitored by early stopping

            epoch_duration = time.time() - epoch_start_time
            tqdm.write(f"Epoch [{epoch + 1}/{num_epochs}], "
                    f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, "
                    f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, "
                    f"LR: {current_lr:.1e}, Duration: {epoch_duration:.2f}s")

            # --- Early-stopping check on a CPU snapshot of the weights ---
            early_stopping(current_val_dice, _state_dict_to_cpu(model))
            if early_stopping.early_stop:
                best_model_state_dict_cpu = early_stopping.best_model_state_dict_cpu
                break  # leave the epoch loop (the scheduler is not stepped)

        else:
            # No validation set: report training metrics only.
            epoch_duration = time.time() - epoch_start_time
            tqdm.write(f"Epoch [{epoch + 1}/{num_epochs}], "
                    f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, "
                    f"LR: {current_lr:.1e}, Duration: {epoch_duration:.2f}s")
            # Without validation there is no "best" epoch; keep the last one.
            if epoch == num_epochs - 1:
                best_model_state_dict_cpu = _state_dict_to_cpu(model)

        # --- Learning-rate update (once per epoch) ---
        scheduler.step()

    # --- End of training ---
    total_training_time = time.time() - start_time
    print("\n--- 训练完成 ---")

    if early_stopping.early_stop:
        # best_model_state_dict_cpu was assigned right before the break.
        print(f"训练因早停而在第 {epoch + 1} 轮结束。")
    elif early_stopping.best_model_state_dict_cpu is not None:
        # Ran all epochs with validation: use the best validated weights.
        best_model_state_dict_cpu = early_stopping.best_model_state_dict_cpu
        print(f"训练完成 {num_epochs} 轮。使用验证集找到的最佳模型权重。")
    elif best_model_state_dict_cpu is None:
        # Nothing recorded at all (e.g. zero epochs): fall back to the
        # current weights.
        print("警告：训练完成，但没有记录最佳模型权重（可能因为没有验证集）。将使用最后一轮的模型权重。")
        best_model_state_dict_cpu = _state_dict_to_cpu(model)

    # --- Final evaluation with the best weights (needs a validation loader) ---
    final_val_dice = 0.0
    if val_loader and best_model_state_dict_cpu:
        print("\n使用最佳权重在验证集上进行最终评估...")
        # Load the best weights back into the model on the target device.
        best_state_device = {k: v.to(device) for k, v in best_model_state_dict_cpu.items()}
        model.load_state_dict(best_state_device)
        model.eval()
        val_samples_final = 0
        val_dice_final = 0.0
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc="Final Validation", leave=False):
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                masks_downsampled = torch.nn.functional.interpolate(masks, size=outputs.shape[2:], mode='nearest')
                batch_size = images.size(0)
                val_dice_final += dice_coefficient(outputs, masks_downsampled) * batch_size
                val_samples_final += batch_size
        if val_samples_final > 0:
            final_val_dice = val_dice_final / val_samples_final
        else:
            final_val_dice = 0.0
        print(f"最终(最佳)验证 Dice 系数: {final_val_dice:.4f}")
    elif not val_loader:
        print("没有提供验证集，跳过最终验证评估。")
    else:  # validation loader present but no weights recorded
        print("警告：无法获取最佳模型权重，无法进行最终验证评估。")

    print(f"总训练耗时: {total_training_time:.2f} 秒")

    # Final validation Dice, both histories and the best state_dict (CPU).
    return final_val_dice, train_dice_history, val_dice_history, best_model_state_dict_cpu

# Start training; the returned values are the final validation Dice, the two
# Dice histories and the best weights (CPU state_dict).
best_dice, train_dice_history, val_dice_history, best_model_weights = train_validate_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler, # with ReduceLROnPlateau the loop would need scheduler.step(val_dice) instead
    num_epochs=num_epochs,
    device=device
)

# 评估模型性能
可视化学习曲线和分割结果，计算Dice系数和95% Hausdorff距离等评估指标。

In [None]:
# Plot the learning curves: training and validation Dice per epoch
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(train_dice_history) + 1), train_dice_history, label='训练Dice系数')
plt.plot(range(1, len(val_dice_history) + 1), val_dice_history, label='验证Dice系数')
plt.title('训练和验证Dice系数')
plt.xlabel('轮次')
plt.ylabel('Dice系数')
plt.legend()
plt.grid(True)
plt.show()

# Visualize segmentation results
def visualize_segmentation(model, dataloader, num_samples=5):
    """Show image / ground-truth mask / predicted mask for one batch.

    Takes the first batch of ``dataloader``, runs the model on it and plots
    up to ``num_samples`` rows. The Dice score in each title is computed
    against the mask resized to the model output size.

    :param model: segmentation model returning logits
    :param dataloader: loader yielding ``(images, masks)`` batches; None is
        tolerated and skips the visualization
    :param num_samples: maximum number of samples to display
    """
    if dataloader is None:
        print("未提供数据加载器，跳过可视化。")
        return

    model.eval()

    # Fetch one batch.
    try:
        images, masks = next(iter(dataloader))
    except StopIteration:
        print("数据集太小，无法获取足够的样本。")
        return

    # Never show more samples than the batch contains.
    num_samples = min(num_samples, images.size(0))
    if num_samples <= 0:
        print("数据集太小，无法获取足够的样本。")
        return

    # Predict, and resize the masks once for the per-sample Dice below.
    with torch.no_grad():
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        pred_masks = (torch.sigmoid(outputs) > 0.5).float()
        masks_downsampled = torch.nn.functional.interpolate(masks, size=outputs.shape[2:], mode='nearest')

    # Undo the normalization so the images can be displayed.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    images_np = [
        np.clip(std * img.cpu().numpy().transpose(1, 2, 0) + mean, 0, 1)  # CHW -> HWC
        for img in images[:num_samples]
    ]

    masks_np = masks[:num_samples].cpu().numpy().squeeze(1)  # (N, H, W)
    pred_masks_np = pred_masks[:num_samples].cpu().numpy().squeeze(1)  # (N, H, W)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row, so
    # ``axes[i, j]`` indexing also works when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 4 * num_samples), squeeze=False)

    for i in range(num_samples):
        # Original image
        axes[i, 0].imshow(images_np[i])
        axes[i, 0].set_title('原始图像')
        axes[i, 0].axis('off')

        # Ground-truth mask
        axes[i, 1].imshow(masks_np[i], cmap='gray')
        axes[i, 1].set_title('真实掩码')
        axes[i, 1].axis('off')

        # Predicted mask with its Dice score
        axes[i, 2].imshow(pred_masks_np[i], cmap='gray')
        dice = dice_coefficient(outputs[i:i+1], masks_downsampled[i:i+1])
        axes[i, 2].set_title(f'预测掩码 (Dice: {dice:.4f})')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

visualize_segmentation(model, val_loader, num_samples=5) # visualize results on the validation set (after resizing)

# 模型保存和加载
保存训练好的模型，以便将来加载并用于预测。

In [None]:
# 保存模型参数
def save_model(model, path):
    """Write the model's ``state_dict`` to ``path`` and print a confirmation.

    :param model: module whose parameters are saved
    :param path: destination file path
    """
    state = model.state_dict()
    torch.save(state, path)
    print(f"模型参数已保存到 {path}")
# The timestamp uses '-' instead of ':' because colons are not valid in Windows
# file names; this also matches the dash-separated checkpoint name that the
# prediction cell expects on Windows.
save_model_path = f'/kaggle/working/work2model_dice-{best_dice}_{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}.pth'
save_model(model, save_model_path)

# 模型预测与应用

遍历测试图像，加载它们，进行预处理，然后使用加载的模型进行预测，最后将预测的掩码保存下来。

In [None]:
# --- 0. Cascade stage: locate the classification model ---
if platform.system() != "Windows":
    # Path used when running on Kaggle
    classification_model_path = '/kaggle/input/pterygium_classifier/pytorch/default/2/resnet18_pterygium_classifier.pth'
else:
    # Local Windows run: classifier weights plus a saved segmentation checkpoint
    classification_model_path = 'w1.pth'
    model_save_path = 'work2model_dice-0.8811619244894272_2025-04-13-15-31-38.pth'

# Preprocessing for validation/test images: resize, tensor conversion and
# normalization only (no data augmentation)
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# 构建 ResNet18 模型
from torchvision.models import ResNet18_Weights
class ResNet18Classifier(nn.Module):
    """ResNet18 backbone with its final layer replaced by Dropout + Linear.

    :param num_classes: number of output classes of the new linear layer
    :param dropout_rate: dropout probability applied before the linear layer
    :param weights: torchvision weights used to initialise the backbone.
        Defaults to the ImageNet weights (the original behaviour). Pass
        ``None`` to build the network without fetching pretrained weights,
        e.g. when a full checkpoint is restored with ``load_state_dict``
        right after construction.
    """

    def __init__(self, num_classes=3, dropout_rate=0.5,
                 weights=ResNet18_Weights.IMAGENET1K_V1):
        super().__init__()
        # Backbone, optionally initialised from pretrained weights
        self.resnet18 = models.resnet18(weights=weights)
        # Swap the original fully connected layer for a classification head
        in_features = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Sequential(
            nn.Dropout(p=dropout_rate),  # regularisation before the classifier
            nn.Linear(in_features, num_classes)  # new output layer
        )

    def forward(self, x):
        """Return the class logits produced by the wrapped ResNet18."""
        return self.resnet18(x)

def predict_image(model, image_path, transform, device):
    """Classify a single image file with a trained model.

    :param model: trained classification model
    :param image_path: path of the image to classify
    :param transform: preprocessing applied to the loaded RGB image
    :param device: device the input tensor is moved to (CPU or GPU)
    :return: index of the highest-scoring class as a Python int
    """
    # Inference only: switch the model to evaluation mode
    model.eval()

    # Load as RGB, preprocess, then add the batch dimension
    rgb_image = Image.open(image_path).convert("RGB")
    batch = transform(rgb_image)
    batch = batch.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # max over dim 1 returns (values, indices); the index is the predicted class
    top_index = logits.max(1)[1]
    return top_index.item()

# Build the classifier, restore its saved weights, then move it to the
# target device in evaluation mode
classification_model = ResNet18Classifier(num_classes=3)
classification_model.load_state_dict(
    torch.load(classification_model_path, map_location=device, weights_only=True)
)
classification_model = classification_model.to(device).eval()

In [None]:
# --- 1. Load the trained segmentation model ---

# Prefer the checkpoint path produced by the training cells. If training was
# not run in this session, keep any model_save_path configured earlier, or
# fall back to None so the checks below do not raise NameError.
try:
    model_save_path = save_model_path
except NameError:
    print("model_save_path 未定义。请确保在训练后保存模型。")
    model_save_path = globals().get('model_save_path')

loaded_model = UNet(n_channels=3, n_classes=1, bilinear=True).to(device)
# Load the best weights (assumes best_model_weights holds a state_dict)
if 'best_model_weights' in locals() and best_model_weights is not None:
    loaded_model.load_state_dict(best_model_weights)
    print("成功加载训练好的模型权重。")
elif model_save_path and os.path.exists(model_save_path):
    # No in-memory weights: restore them from the saved checkpoint file
    loaded_model.load_state_dict(torch.load(model_save_path, map_location=device, weights_only=True))
    print(f"从 {model_save_path} 加载模型权重。")
else:
    print("警告: 未找到训练好的模型权重 (best_model_weights 或文件)。模型将使用随机初始化的权重。")

loaded_model.eval(); # switch to evaluation mode

In [None]:
# --- 2. 定义 Tiling 预测函数 ---
def _postprocess_mask(mask, opening_radius=6, closing_radius=3):
    """Clean a binary mask: opening -> largest connected component -> closing.

    :param mask: 2-D binary mask (0/1)
    :param opening_radius: radius of the disk used to break thin connections;
        0 disables the step
    :param closing_radius: radius of the disk used to fill small holes in the
        kept component; 0 disables the step
    :return: cleaned mask as a uint8 array of the same shape
    """
    mask = mask.astype(np.uint8)

    # 1. Opening removes thin bridges and small specks before component selection
    if opening_radius and mask.any():
        mask = binary_opening(mask, disk(opening_radius)).astype(np.uint8)

    # Opening may have removed every foreground pixel
    if not mask.any():
        return np.zeros_like(mask)

    # 2. Keep only the largest connected component
    labeled = label(mask)
    component_sizes = np.bincount(labeled.ravel())
    component_sizes[0] = 0  # label 0 is background, never select it
    mask = (labeled == component_sizes.argmax()).astype(np.uint8)

    # 3. Closing fills small holes inside the selected component
    if closing_radius:
        mask = binary_closing(mask, disk(closing_radius)).astype(np.uint8)

    return mask


def predict_with_tiling(model, image_path, patch_size, stride, device, batch_size=4, threshold=0.5,
                        opening_radius=6, closing_radius=3):
    """
    Predict the mask of a large image with overlapping tiles and averaging.

    The image is reflect-padded on the right/bottom so that square patches of
    ``patch_size`` taken every ``stride`` pixels cover it completely. Patch
    probabilities are accumulated and divided by the per-pixel patch count,
    cropped back to the original size, thresholded and post-processed.

    :param model: trained model returning logits
    :param image_path: path of the full-size image
    :param patch_size: side length of the square patches (int)
    :param stride: step between neighbouring patches (int)
    :param device: device used for inference
    :param batch_size: number of patches per forward pass
    :param threshold: probability threshold for binarisation
    :param opening_radius: disk radius of the morphological opening (0 = skip)
    :param closing_radius: disk radius of the morphological closing (0 = skip)
    :return: binary mask (uint8 numpy array, H x W), or None if the image
             could not be loaded
    """
    try:
        with Image.open(image_path) as source:
            img = source.convert('RGB')
        img_w, img_h = img.size
    except FileNotFoundError:
        print(f"错误: 图像文件未找到 {image_path}")
        return None
    except Exception as e:
        print(f"错误: 加载图像时出错 {image_path}: {e}")
        return None

    model.eval()  # make sure the model is in evaluation mode

    # Per-patch preprocessing (ToTensor and Normalize only)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Pad so the image is at least one patch in each dimension ...
    pad_h = max(patch_size - img_h, 0)
    pad_w = max(patch_size - img_w, 0)
    target_h = img_h + pad_h
    target_w = img_w + pad_w
    # ... and so that (size - patch_size) is a multiple of stride, which makes
    # the last patch end exactly on the padded border.
    pad_h_stride = (stride - (target_h - patch_size) % stride) % stride
    pad_w_stride = (stride - (target_w - patch_size) % stride) % stride

    padding = (0, 0, pad_w + pad_w_stride, pad_h + pad_h_stride)  # (left, top, right, bottom)
    img_padded = F.pad(img, padding, padding_mode='reflect')
    padded_w, padded_h = img_padded.size  # PIL size is (W, H)

    # Accumulators: summed probabilities and number of patches covering each pixel
    pred_prob_map = torch.zeros((1, padded_h, padded_w), dtype=torch.float32, device=device)
    count_map = torch.zeros((1, padded_h, padded_w), dtype=torch.float32, device=device)

    # Top-left corner of every patch
    patch_coords = [
        (x, y)
        for y in range(0, padded_h - patch_size + 1, stride)
        for x in range(0, padded_w - patch_size + 1, stride)
    ]

    # Process the patches in batches
    with torch.no_grad():
        for i in tqdm(range(0, len(patch_coords), batch_size), desc="Tiling Prediction", leave=False):
            batch_coords = patch_coords[i:i+batch_size]
            batch_patches_pil = [img_padded.crop((x, y, x + patch_size, y + patch_size)) for x, y in batch_coords]
            batch_patches_tensor = torch.stack([preprocess(p) for p in batch_patches_pil]).to(device)

            # Model outputs logits; convert to probabilities, shape (B, 1, h, w)
            batch_outputs_probs = torch.sigmoid(model(batch_patches_tensor))

            # If the model output is not at patch resolution, upsample the
            # whole batch once instead of once per patch.
            if tuple(batch_outputs_probs.shape[-2:]) != (patch_size, patch_size):
                batch_outputs_probs = torch.nn.functional.interpolate(
                    batch_outputs_probs,
                    size=(patch_size, patch_size),
                    mode='bilinear',
                    align_corners=False
                )

            # Accumulate each patch at its position in the full-size maps
            for j, (x, y) in enumerate(batch_coords):
                pred_prob_map[:, y:y+patch_size, x:x+patch_size] += batch_outputs_probs[j]
                count_map[:, y:y+patch_size, x:x+patch_size] += 1

    # Guard against division by zero for pixels no patch covered
    # (should not happen with the padding above)
    count_map[count_map == 0] = 1
    avg_prob_map = pred_prob_map / count_map

    # Crop away the padding to get back to the original image size
    avg_prob_map_cropped = avg_prob_map[:, :img_h, :img_w]

    # Threshold; squeeze(0) drops only the channel axis so that images that
    # are one pixel high or wide still yield a 2-D mask.
    final_mask = (avg_prob_map_cropped > threshold).squeeze(0).cpu().numpy().astype(np.uint8)

    return _postprocess_mask(final_mask, opening_radius, closing_radius)

In [None]:
# --- 3. Run prediction on the test images ---

def _write_blank_mask(source_path, target_path):
    """Save an all-black RGB mask with the pixel size of the image at ``source_path``."""
    with Image.open(source_path) as source_image:
        width, height = source_image.size
    blank = np.zeros((height, width, 3), dtype=np.uint8)
    Image.fromarray(blank, mode='RGB').save(target_path)


# Start from an empty output directory: remove an existing one first
if os.path.exists(output_mask_dir):
    print(f"输出目录 {output_mask_dir} 已存在，删除旧目录。")
    shutil.rmtree(output_mask_dir)
os.makedirs(output_mask_dir, exist_ok=True)

# Collect every PNG image to predict
test_image_paths = glob.glob(os.path.join(val_image_dir, "*.png"))

print(f"找到 {len(test_image_paths)} 张待预测图像于 {val_image_dir}")

# Counters for the summary printed at the end
skipped_healthy_count = 0
processed_count = 0

for img_path in tqdm(test_image_paths, desc="处理预测图像"):
    base_name = os.path.basename(img_path)
    save_name = os.path.splitext(base_name)[0] + ".png"  # masks are always saved as PNG
    save_path = os.path.join(output_mask_dir, save_name)

    # --- Cascade: ask the classifier first ---
    predicted_class = predict_image(classification_model, img_path, val_transform, device)

    # Class 0 means "healthy": no segmentation, just an empty mask
    if predicted_class == 0:
        try:
            _write_blank_mask(img_path, save_path)
            skipped_healthy_count += 1
        except Exception as e:
            print(f"为健康图像 {base_name} 创建或保存空掩码时出错: {e}")
        continue

    # --- Not healthy: segment the image with tiled prediction ---
    predicted_mask_np = predict_with_tiling(
        model=loaded_model,
        image_path=img_path,
        patch_size=patch_size,
        stride=predict_stride,  # stride defined earlier in the notebook
        device=device,
        batch_size=32  # adjust to the available GPU memory
    )

    # predict_with_tiling returns None when it could not process the image;
    # fall back to an empty mask in that case
    if predicted_mask_np is None:
        try:
            print(f"警告：predict_with_tiling未能处理图像 {base_name}，将保存空掩码。")
            _write_blank_mask(img_path, save_path)
        except Exception as e:
            print(f"为处理失败的图像 {base_name} 创建或保存空掩码时出错: {e}")
        continue

    # Turn the post-processed 0/1 mask into an RGB image:
    # foreground (1) -> (128, 0, 0), background (0) -> (0, 0, 0)
    mask_h, mask_w = predicted_mask_np.shape
    rgb_mask = np.zeros((mask_h, mask_w, 3), dtype=np.uint8)
    rgb_mask[:, :, 0] = np.where(predicted_mask_np == 1, 128, 0)
    mask_image = Image.fromarray(rgb_mask, mode='RGB')

    try:
        mask_image.save(save_path)
        processed_count += 1
    except Exception as e:
        print(f"保存分割掩码时出错 {save_path}: {e}")

print("\n预测处理完成。")
print(f"总共处理图像: {len(test_image_paths)}")
print(f"其中 {skipped_healthy_count} 张被分类为健康并保存了空掩码。")
print(f"对其余 {processed_count} 张图像进行了分割并保存了掩码。")
# Note: processed_count + skipped_healthy_count can be lower than the total
# when predict_with_tiling failed or a mask could not be saved.

# Visualize results
# Upper bound on the number of samples to display
num_visualize = 10
if test_image_paths:
    # Never ask for more samples than there are images
    num_visualize = min(num_visualize, len(test_image_paths))

if not test_image_paths:
    print("错误：找不到测试图像路径，无法进行可视化。")
elif num_visualize <= 0:
    print("没有测试图像可供可视化。")
else:
    # Pick random images to display
    sample_paths = random.sample(test_image_paths, num_visualize)
    print(f"\n可视化 {num_visualize} 个预测结果...")

    for img_path in sample_paths:
        try:
            # Load the original image
            original_image = Image.open(img_path).convert('RGB')

            # The saved mask shares the image's file stem and is always a PNG
            base_name = os.path.basename(img_path)
            file_stem = os.path.splitext(base_name)[0]
            mask_filename = file_stem + ".png"
            predicted_mask_path = os.path.join(output_mask_dir, mask_filename)

            # Nothing to show if no mask was written for this image
            if not os.path.exists(predicted_mask_path):
                print(f"警告：找不到预测掩码文件 {predicted_mask_path}，无法可视化此样本。")
                continue

            # Load the previously saved RGB mask
            predicted_mask_image = Image.open(predicted_mask_path)

            # Original on the left, predicted mask on the right
            fig, axes = plt.subplots(1, 2, figsize=(10, 5))
            image_axis, mask_axis = axes[0], axes[1]

            image_axis.imshow(original_image)
            image_axis.set_title(f"原始图像: {base_name}")
            image_axis.axis('off')

            mask_axis.imshow(predicted_mask_image)
            mask_axis.set_title("预测掩码 (RGB 128,0,0)")
            mask_axis.axis('off')

            plt.tight_layout()  # avoid overlapping titles
            plt.show()

        except Exception as e:
            print(f"可视化图像 {img_path} 时出错: {e}")

In [None]:
# --- 4. Compress the results ---
# Only done on hosted environments (Google Colab or Kaggle)
if 'google.colab' in sys.modules or os.path.exists("/kaggle/working"):
    zip_file_path = f"{output_mask_dir}.zip"

    if output_mask_dir and os.path.exists(output_mask_dir) and os.listdir(output_mask_dir):
        print(f"开始压缩目录: {output_mask_dir}")
        try:
            # Check for files before opening the archive so that no empty
            # zip file is created when there is nothing to add.
            files_to_zip = glob.glob(os.path.join(output_mask_dir, '*.*'))
            if not files_to_zip:
                print(f"警告: 目录 {output_mask_dir} 为空，无需压缩。")
            else:
                with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                    for file in tqdm(files_to_zip, desc="压缩文件"):
                        # Store files flat, without the directory prefix
                        zipf.write(file, arcname=os.path.basename(file))
                print(f"预测结果已成功压缩到: {zip_file_path}")

                # Delete the originals only after the archive has been closed
                # (and therefore fully written); a failure while finalizing
                # the zip must not cost us the mask files.
                print(f"删除原始掩码文件于: {output_mask_dir}")
                shutil.rmtree(output_mask_dir)

        except Exception as e:
            print(f"压缩或删除文件时发生错误: {e}")
    elif output_mask_dir:
        print(f"目录 {output_mask_dir} 不存在或为空，跳过压缩和删除步骤。")


print("\n预测处理完成。")