In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models.segmentation as models
import numpy as np
import matplotlib.pyplot as plt
from dataset import get_loaders
from torchvision.transforms.functional import to_pil_image
from PIL import Image

# ⚙️ 1. Basic setup
# Use the GPU whenever PyTorch can see one; models and batches are moved to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Destination for the per-fold best checkpoints; exist_ok keeps reruns from failing.
os.makedirs("checkpoints/fcn", exist_ok=True)

# ⚙️ 2. Mixed Dice + BCE loss
def dice_loss(pred, target, eps=1e-6):
    """Return the soft Dice loss for raw logits ``pred`` against ``target``.

    The logits are squashed with a sigmoid. The loss is then
    ``1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)``.

    The sums run over the *entire* tensor (every sample and pixel together),
    so this is one batch-level Dice score, not a mean of per-sample scores.
    ``eps`` keeps the ratio defined when both sums are zero.
    """
    probs = pred.sigmoid()
    overlap = (probs * target).sum()
    denominator = probs.sum() + target.sum()
    dice_score = (2.0 * overlap + eps) / (denominator + eps)
    return 1 - dice_score

def mixed_loss(pred, target):
    """Return ``0.3 * BCE-with-logits + 0.7 * soft Dice`` for logits ``pred``.

    The BCE term uses the default mean reduction over all elements. The Dice
    term is the batch-level ``dice_loss`` defined above.

    ``target`` must be a float tensor with the same shape as ``pred``, because
    BCE-with-logits requires that.
    """
    bce_term = nn.functional.binary_cross_entropy_with_logits(pred, target)
    dice_term = dice_loss(pred, target)
    return 0.3 * bce_term + 0.7 * dice_term

# ⚙️ 3. Evaluation metrics
def compute_iou(preds, masks, threshold=0.5):
    """Return the batch-mean IoU between thresholded predictions and ``masks``.

    ``preds`` are logits. A pixel counts as foreground when
    ``sigmoid(logit) > threshold``. Sums run over dims 1-3, so both tensors
    are expected to be rank 4 (presumably ``(N, C, H, W)``).

    Each sample's IoU is ``intersection / (union + 1e-6)``, and the per-sample
    values are averaged and returned as a Python float.

    Note: a sample whose prediction and mask are both empty scores ~0, not 1.
    """
    reduce_dims = (1, 2, 3)
    binary = torch.sigmoid(preds).gt(threshold).float()
    inter = (binary * masks).sum(dim=reduce_dims)
    union = (binary + masks).gt(0).float().sum(dim=reduce_dims)
    per_sample = inter / (union + 1e-6)
    return per_sample.mean().item()

def evaluate_prediction_error(preds, masks, px_per_cm=72):
    """Compare the vertical centre of predicted vs. true foreground per sample.

    Predictions are logits and are thresholded at ``sigmoid > 0.5``. Only
    channel 0 of each sample is used; indexing implies ``(N, C, H, W)``.

    For each sample, the error is the absolute difference, in pixels, between
    the mean row index of predicted foreground pixels and the mean row index
    of ground-truth foreground pixels.

    Samples where either mask is empty are skipped. As a result, an empty
    prediction is not penalised here.

    Returns:
        ``(mean_error_px, pct_within_0.5cm, pct_within_1cm)``, computed over the
        samples that were compared, or ``(0, 0, 0)`` when no sample was compared.

    NOTE(review): the default ``px_per_cm=72`` matches the usual pixels-per-*inch*
    convention. Confirm it really is pixels per centimetre for this data.
    """
    binary = (torch.sigmoid(preds) > 0.5).float()
    offsets = []
    for idx in range(binary.size(0)):
        pred_rows = torch.nonzero(binary[idx, 0], as_tuple=False)[:, 0]
        true_rows = torch.nonzero(masks[idx, 0], as_tuple=False)[:, 0]
        if pred_rows.numel() == 0 or true_rows.numel() == 0:
            continue
        offset = (pred_rows.float().mean() - true_rows.float().mean()).abs()
        offsets.append(offset.item())

    if len(offsets) == 0:
        return 0, 0, 0
    offset_arr = np.array(offsets)
    mean_offset = np.mean(offsets)
    within_half_cm = np.mean(offset_arr <= 0.5 * px_per_cm) * 100
    within_one_cm = np.mean(offset_arr <= 1.0 * px_per_cm) * 100
    return mean_offset, within_half_cm, within_one_cm

# ⚙️ 4. Model construction
def get_model():
    """Build a new torchvision FCN-ResNet50 with a single output channel.

    ``weights=None`` means no pretrained segmentation weights are loaded. The
    single output channel is treated as a logit by the sigmoid-based losses
    and metrics in this file.

    NOTE(review): torchvision's ``weights_backbone`` argument is left at its
    default. Confirm whether that pulls ImageNet backbone weights, which needs
    network or cache access.
    """
    return models.fcn_resnet50(weights=None, num_classes=1)


# ⚙️ 5. Cross-validation driver.
# For each fold: train a fresh model, keep the checkpoint with the best
# validation IoU, then dump (image | ground truth | prediction) PNGs produced
# by that best checkpoint.
MAX_EPOCHS = 100      # hard cap on training epochs per fold
PATIENCE = 30         # early-stop after this many epochs without a val-IoU improvement
BATCH_SIZE = 8
LEARNING_RATE = 1e-4

all_folds = ['Fold1', 'Fold2', 'Fold3', 'Fold4', 'Fold5']
for fold in all_folds:
    print(f"=== 開始訓練 {fold} ===")
    train_loader, val_loader = get_loaders('processed_data', fold, batch_size=BATCH_SIZE)
    n_val_batches = len(val_loader)
    if n_val_batches == 0:
        # The per-batch metric averages below would otherwise divide by zero,
        # but only after a full (wasted) training epoch.
        raise ValueError(f"Validation loader for {fold} is empty")
    model = get_model().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    ckpt_path = f"checkpoints/fcn/fcn_{fold}_best.pth"
    best_val_iou = 0
    # Tracks whether *this run* wrote ckpt_path. A file left behind by an
    # earlier run must not be reloaded below.
    saved_checkpoint = False
    early_stop_count = 0

    for epoch in range(MAX_EPOCHS):
        # --- training pass ---
        model.train()
        total_loss = 0  # sum (not mean) of the per-batch losses for this epoch
        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)['out']  # the torchvision FCN returns a dict; 'out' holds the logits
            loss = mixed_loss(outputs, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # --- validation pass ---
        model.eval()
        val_iou = val_err = acc_0_5cm = acc_1_0cm = 0
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs, masks = imgs.to(device), masks.to(device)
                outputs = model(imgs)['out']
                val_iou += compute_iou(outputs, masks)
                mean_err, a05, a10 = evaluate_prediction_error(outputs, masks)
                val_err += mean_err
                acc_0_5cm += a05
                acc_1_0cm += a10
        # NOTE: these are means of per-batch values. A smaller final batch weighs
        # as much as a full one. A batch in which evaluate_prediction_error
        # skipped every sample contributes 0 to both the error and the accuracies.
        val_iou /= n_val_batches
        val_err /= n_val_batches
        acc_0_5cm /= n_val_batches
        acc_1_0cm /= n_val_batches

        print(f"[{fold} - Epoch {epoch+1}] Loss: {total_loss:.2f} | IOU: {val_iou:.4f} | 平均誤差: {val_err:.2f}px | "
              f"誤差<0.5cm: {acc_0_5cm:.2f}% | 誤差<1cm: {acc_1_0cm:.2f}%")

        if val_iou > best_val_iou:
            best_val_iou = val_iou
            early_stop_count = 0
            torch.save(model.state_dict(), ckpt_path)
            saved_checkpoint = True
        else:
            early_stop_count += 1

        if early_stop_count >= PATIENCE:
            print(f"⏹️ 提前停止訓練（early stopping），epoch = {epoch+1}")
            break

    # Restore the best weights before dumping predictions. After early stopping
    # the in-memory model is PATIENCE epochs past the best epoch, so the images
    # would otherwise not reflect the saved checkpoint.
    if saved_checkpoint:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print(f"⚠️ {fold}: no checkpoint was saved (val IoU never exceeded 0); "
              f"using final-epoch weights for predictions")

    # ⚙️ Save side-by-side images: input channel 0 | ground-truth mask | predicted mask
    out_dir = f"predictions/fcn/{fold}"
    os.makedirs(out_dir, exist_ok=True)
    model.eval()
    count = 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            # Only the images need to go to the model's device; masks are only rendered.
            logits = model(imgs.to(device))['out']

            # Sigmoid, then binarise at 0.5 (same threshold as the metrics).
            preds = (torch.sigmoid(logits) > 0.5).float()

            for i in range(imgs.size(0)):
                # NOTE(review): to_pil_image scales float tensors by 255 before
                # the byte conversion. Confirm the loader's image values lie in
                # [0, 1], otherwise the rendered input will wrap.
                orig = to_pil_image(imgs[i][0].cpu())
                gt = to_pil_image(masks[i][0].cpu())
                pred = to_pil_image(preds[i][0].cpu())

                w, h = orig.width, orig.height
                combined = Image.new('L', (w * 3, h))
                combined.paste(orig, (0, 0))
                combined.paste(gt, (w, 0))
                combined.paste(pred, (2 * w, 0))
                combined.save(f"{out_dir}/pred_{count}.png")
                count += 1




=== 開始訓練 Fold1 ===
[Fold1 - Epoch 1] Loss: 86.97 | IOU: 0.1795 | 平均誤差: 11.40px | 誤差<0.5cm: 95.83% | 誤差<1cm: 100.00%
[Fold1 - Epoch 2] Loss: 68.80 | IOU: 0.2465 | 平均誤差: 7.91px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 3] Loss: 61.15 | IOU: 0.2404 | 平均誤差: 6.25px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 4] Loss: 54.26 | IOU: 0.2541 | 平均誤差: 5.27px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 5] Loss: 46.80 | IOU: 0.2399 | 平均誤差: 7.27px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 6] Loss: 40.67 | IOU: 0.2554 | 平均誤差: 6.29px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 7] Loss: 35.15 | IOU: 0.1881 | 平均誤差: 6.25px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 8] Loss: 30.22 | IOU: 0.2044 | 平均誤差: 7.58px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 9] Loss: 25.58 | IOU: 0.2258 | 平均誤差: 8.21px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - Epoch 10] Loss: 22.18 | IOU: 0.2047 | 平均誤差: 6.06px | 誤差<0.5cm: 100.00% | 誤差<1cm: 100.00%
[Fold1 - E