In [9]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from dataset import get_loaders
from PIL import Image
import segmentation_models_pytorch as smp

# ------------------------------
# 1. 損失與評估函數
# ------------------------------
def dice_loss(pred, target, eps=1e-6):
    """Return the soft Dice loss between logits ``pred`` and ``target``.

    ``pred`` is passed through a sigmoid first. The overlap and the
    denominator are summed over every element of the tensors, so the whole
    batch is scored as a single region rather than per sample.

    Args:
        pred: Raw (pre-sigmoid) model outputs.
        target: Ground-truth tensor multiplied element-wise with the
            probabilities (same shape as ``pred``, or broadcastable to it).
        eps: Added to numerator and denominator so the ratio stays defined
            when both sums are zero.

    Returns:
        A scalar tensor equal to ``1 - dice``.
    """
    probs = torch.sigmoid(pred)
    overlap = torch.sum(probs * target)
    total = torch.sum(probs) + torch.sum(target)
    dice = (2. * overlap + eps) / (total + eps)
    return 1 - dice

def mixed_loss(pred, target):
    """Return the sum of BCE-with-logits loss and Dice loss.

    Both terms are computed from the same raw logits ``pred`` against
    ``target``; the BCE term uses the default mean reduction.
    """
    bce_term = nn.functional.binary_cross_entropy_with_logits(pred, target)
    dice_term = dice_loss(pred, target)
    return bce_term + dice_term

def compute_iou(preds, masks, threshold=0.5):
    """Return the batch-mean intersection-over-union as a Python float.

    Args:
        preds: Raw logits; reduced over dims 1, 2 and 3, so a 4-D tensor
            is required.
        masks: Ground-truth tensor with the same shape as ``preds``.
        threshold: Cut-off applied to ``sigmoid(preds)`` to binarise the
            prediction.

    Returns:
        Mean over the batch of ``intersection / (union + 1e-6)``. A sample
        where both prediction and mask are empty therefore scores 0.
    """
    reduce_dims = (1, 2, 3)
    binary = torch.sigmoid(preds).gt(threshold).float()
    overlap = torch.sum(binary * masks, dim=reduce_dims)
    # A pixel belongs to the union when either map is non-zero there.
    covered = torch.sum((binary + masks).gt(0).float(), dim=reduce_dims)
    per_sample = overlap / (covered + 1e-6)
    return per_sample.mean().item()

def evaluate_prediction_error(preds, masks, px_per_cm=72):
    """Measure the vertical offset between predicted and true mask centroids.

    For each sample, channel 0 of the thresholded prediction
    (``sigmoid(preds) > 0.5``) and channel 0 of ``masks`` are reduced to the
    mean index along their first spatial axis of the non-zero pixels; the
    error is the absolute difference of the two means, in pixels. Samples
    where either map has no non-zero pixel are skipped.

    Args:
        preds: Raw logits, indexed as ``[sample, channel, ...]``.
        masks: Ground-truth tensor indexed the same way.
        px_per_cm: Pixels per centimetre used to convert the 0.5 cm and
            1.0 cm tolerances into pixel thresholds.

    Returns:
        ``(mean_error_px, pct_within_0_5cm, pct_within_1_0cm)``. When every
        sample was skipped the result is ``(0, 0, 0)``.
    """
    binary = (torch.sigmoid(preds) > 0.5).float()
    offsets = []
    for idx, pred_map in enumerate(binary[:, 0]):
        pred_rows = torch.nonzero(pred_map, as_tuple=False)[:, 0]
        true_rows = torch.nonzero(masks[idx, 0], as_tuple=False)[:, 0]
        # Only samples with both a prediction and a ground truth are scored.
        if pred_rows.numel() > 0 and true_rows.numel() > 0:
            gap = pred_rows.float().mean() - true_rows.float().mean()
            offsets.append(gap.abs().item())

    if len(offsets) == 0:
        return 0, 0, 0

    offsets_px = np.array(offsets)
    within_half_cm = (offsets_px <= (0.5 * px_per_cm)).mean() * 100
    within_one_cm = (offsets_px <= (1.0 * px_per_cm)).mean() * 100
    return offsets_px.mean(), within_half_cm, within_one_cm

# ------------------------------
# 2. 訓練與預測流程（使用預訓練 Unet）
# ------------------------------
# Per-fold training script: for each fold, fine-tune an ImageNet-pretrained
# ResNet34 U-Net, keep the checkpoint with the best validation IoU, then
# write side-by-side (input | ground truth | prediction) images for the
# whole validation loader.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
os.makedirs("checkpoints/unet", exist_ok=True)

folds = ['Fold1', 'Fold2', 'Fold3', 'Fold4', 'Fold5']
for fold in folds:
    print(f"=== 開始訓練 {fold} ===")
    # A fresh model and optimizer per fold so no weights leak between folds.
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # NOTE(review): presumably returns (train, validation) DataLoaders for
    # this fold yielding (image, mask) batches — confirm in dataset.py.
    train_loader, val_loader = get_loaders('processed_data', fold, batch_size=4)

    best_val_iou = 0
    patience = 30
    early_stop_counter = 0
    best_ckpt_path = f"checkpoints/unet/unet_{fold}_best.pth"
    # Tracks whether THIS run wrote the checkpoint, so a stale file left by
    # an earlier run is never loaded by mistake.
    best_ckpt_saved = False

    for epoch in range(100):
        model.train()
        total_loss = 0
        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = mixed_loss(outputs, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Summed (not averaged) over batches; printed as-is below.
            total_loss += loss.item()

        # Validation
        model.eval()
        val_iou = val_err = acc_0_5cm = acc_1_0cm = 0
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs, masks = imgs.to(device), masks.to(device)
                outputs = model(imgs)
                val_iou += compute_iou(outputs, masks)
                # NOTE(review): evaluate_prediction_error returns (0, 0, 0)
                # for a batch with no scorable sample, which still counts
                # in the per-batch averages below.
                mean_err, a05, a10 = evaluate_prediction_error(outputs, masks)
                val_err += mean_err
                acc_0_5cm += a05
                acc_1_0cm += a10

        # Metrics are averaged per batch, not per sample.
        val_iou /= len(val_loader)
        val_err /= len(val_loader)
        acc_0_5cm /= len(val_loader)
        acc_1_0cm /= len(val_loader)

        print(f"[{fold} - Epoch {epoch+1}] Loss: {total_loss:.2f} | IOU: {val_iou:.4f} | 平均誤差: {val_err:.2f}px | "
              f"誤差<0.5cm: {acc_0_5cm:.2f}% | 誤差<1cm: {acc_1_0cm:.2f}%")

        if val_iou > best_val_iou:
            best_val_iou = val_iou
            early_stop_counter = 0  # reset patience
            torch.save(model.state_dict(), best_ckpt_path)
            best_ckpt_saved = True
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print(f"🛑 Early stopping at epoch {epoch+1} for {fold}")
                break

    # Restore the best-IoU weights before exporting predictions. Without
    # this the images come from the last epoch, which can be up to
    # `patience` epochs past the best model.
    if best_ckpt_saved:
        model.load_state_dict(torch.load(best_ckpt_path, map_location=device))

    # Save a prediction image for every sample in the validation set
    out_dir = f"predictions/unet/{fold}"
    os.makedirs(out_dir, exist_ok=True)
    model.eval()
    count = 1

    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            preds = model(imgs)
            preds = (torch.sigmoid(preds) > 0.5).float()

            for i in range(imgs.size(0)):
                # Only channel 0 of each tensor is rendered.
                # NOTE(review): if the loader normalizes images outside
                # [0, 1], the rendered input will look wrong — confirm.
                orig = to_pil_image(imgs[i][0].cpu())
                gt = to_pil_image(masks[i][0].cpu())
                pred = to_pil_image(preds[i][0].cpu())

                # Grayscale canvas: input | ground truth | prediction.
                w, h = orig.width, orig.height
                combined = Image.new('L', (w * 3, h))
                combined.paste(orig, (0, 0))
                combined.paste(gt, (w, 0))
                combined.paste(pred, (2 * w, 0))
                combined.save(f"{out_dir}/pred_{count}.png")
                count += 1

print("✅ 使用預訓練 U-Net 完成訓練！")


=== 開始訓練 Fold1 ===
[Fold1 - Epoch 1] Loss: 104.59 | IOU: 0.0963 | 平均誤差: 24.61px | 誤差<0.5cm: 75.00% | 誤差<1cm: 95.83%
[Fold1 - Epoch 2] Loss: 87.50 | IOU: 0.1057 | 平均誤差: 6.32px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 3] Loss: 81.53 | IOU: 0.1702 | 平均誤差: 7.68px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 4] Loss: 77.02 | IOU: 0.1538 | 平均誤差: 7.21px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 5] Loss: 73.18 | IOU: 0.1897 | 平均誤差: 6.17px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 6] Loss: 68.74 | IOU: 0.1254 | 平均誤差: 7.54px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 7] Loss: 64.31 | IOU: 0.2510 | 平均誤差: 4.99px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 8] Loss: 60.80 | IOU: 0.2271 | 平均誤差: 8.38px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 9] Loss: 55.62 | IOU: 0.2552 | 平均誤差: 7.96px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 10] Loss: 50.29 | IOU: 0.2911 | 平均誤差: 7.19px | 誤差<0.5cm: 97.92% | 誤差<1cm: 100.00%
[Fold1 - Ep