In [None]:
# !git clone https://github.com/776lucky/9517_project.git

In [None]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt


def save_prediction_images(model, dataset, epoch, indices=(0, 1, 2), save_dir="predictions"):
    """Save RGB / ground-truth / prediction comparison figures for a few samples.

    For every index in ``indices`` the sample is fetched from ``dataset``, run
    through ``model`` on the module-level ``DEVICE``, and a three-panel figure
    is written to ``<save_dir>/epoch_<epoch>_idx_<idx>.png``.

    Args:
        model: Network returning logits; it is switched to eval mode and left
            in eval mode when the function returns.
        dataset: Indexable object yielding ``(image_tensor, mask_tensor)``
            pairs; the first three image channels are displayed as RGB.
        epoch: Value embedded in the output file name.
        indices: Iterable of dataset indices to render. The default is an
            immutable tuple so it cannot be mutated between calls.
        save_dir: Output directory, created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    for idx in indices:
        img, mask = dataset[idx]
        with torch.no_grad():
            pred = model(img.unsqueeze(0).to(DEVICE))
            # Logits -> probabilities, then drop the batch/channel dimensions.
            pred = torch.sigmoid(pred).squeeze().cpu().numpy()

        # Channel-first tensor -> channel-last array for matplotlib.
        img_rgb = img[:3].permute(1, 2, 0).numpy()
        mask_np = mask.squeeze().numpy()

        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(img_rgb)
        axs[0].set_title("RGB Image")
        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title("Ground Truth")
        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Prediction")
        for ax in axs:
            ax.axis("off")

        save_path = os.path.join(save_dir, f"epoch_{epoch}_idx_{idx}.png")
        fig.tight_layout()
        fig.savefig(save_path)
        # Close this specific figure so repeated calls do not accumulate
        # open figures, regardless of which figure is "current".
        plt.close(fig)
        print(f"📸 预测图已保存至 {save_path}")



# ===== Data preprocessing =====
def build_file_lists(rgb_dir, nrg_dir, train_ratio=0.8, seed=42):
    """Pair RGB and NRG images by ID and split them into train/test lists.

    Files named ``RGB_<id>`` in ``rgb_dir`` and ``NRG_<id>`` in ``nrg_dir``
    are matched on ``<id>``. Only the leading prefix is removed, so an ID that
    itself contains ``RGB_`` or ``NRG_`` is still matched correctly.

    Args:
        rgb_dir: Directory containing the ``RGB_*`` files.
        nrg_dir: Directory containing the ``NRG_*`` files.
        train_ratio: Fraction of the paired files placed in the training list.
        seed: Seed of the private random generator used for shuffling. The
            default reproduces the split obtained from ``np.random.seed(42)``
            followed by ``np.random.shuffle``.

    Returns:
        ``(train_files, test_files)``: two lists of RGB file names
        (including the ``RGB_`` prefix).
    """
    rgb_prefix, nrg_prefix = "RGB_", "NRG_"
    rgb_ids = {f[len(rgb_prefix):] for f in os.listdir(rgb_dir) if f.startswith(rgb_prefix)}
    nrg_ids = {f[len(nrg_prefix):] for f in os.listdir(nrg_dir) if f.startswith(nrg_prefix)}
    # Sort before shuffling so the split does not depend on directory order.
    common_ids = sorted(rgb_ids & nrg_ids)
    print(f"✅ 找到 {len(common_ids)} 张 RGB 和 NRG 对应图像")

    rgb_filenames = [rgb_prefix + fid for fid in common_ids]
    # A local generator keeps the split reproducible without reseeding the
    # process-wide NumPy RNG as a hidden side effect.
    np.random.RandomState(seed).shuffle(rgb_filenames)
    n_train = int(train_ratio * len(rgb_filenames))
    return rgb_filenames[:n_train], rgb_filenames[n_train:]

class FourChannelSegmentationDataset(Dataset):
    def __init__(self, rgb_dir, nrg_dir, mask_dir, file_list, image_size=(256, 256)):
        self.rgb_dir = rgb_dir
        self.nrg_dir = nrg_dir
        self.mask_dir = mask_dir
        self.file_list = file_list
        self.image_size = image_size

        self.transform_img = T.Compose([
            T.ToPILImage(),
            T.Resize(self.image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        filename = self.file_list[idx]
        img_id = filename.replace("RGB_", "")
        rgb_path = os.path.join(self.rgb_dir, filename)
        nrg_path = os.path.join(self.nrg_dir, "NRG_" + img_id)
        mask_path = os.path.join(self.mask_dir, "mask_" + img_id)

        rgb = cv2.imread(rgb_path)
        rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
        nrg = cv2.imread(nrg_path)
        nir_2d = nrg[:, :, 0]
        img_4ch = np.concatenate((rgb, nir_2d[..., None]), axis=-1)

        rgb_tensor = self.transform_img(img_4ch[:, :, :3])
        nir_tensor = self.transform_img(nir_2d)
        img_tensor = torch.cat([rgb_tensor, nir_tensor], dim=0)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, self.image_size)
        mask = (mask > 127).astype(np.float32)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)

        return img_tensor, mask_tensor

def get_datasets(base_dir="9517_project/USA_segmentation", image_size=(256, 256)):
    """Build the train and test datasets from a dataset root directory.

    Expects ``RGB_images``, ``NRG_images`` and ``masks`` sub-directories under
    ``base_dir``. The train/test split comes from ``build_file_lists``.

    Returns:
        ``(train_dataset, test_dataset)`` as two
        ``FourChannelSegmentationDataset`` instances.
    """
    rgb_dir, nrg_dir, mask_dir = (
        os.path.join(base_dir, sub_dir)
        for sub_dir in ("RGB_images", "NRG_images", "masks")
    )
    train_files, test_files = build_file_lists(rgb_dir, nrg_dir)

    def make_split(files):
        # Both splits share directories and image size; only the file list differs.
        return FourChannelSegmentationDataset(rgb_dir, nrg_dir, mask_dir, files, image_size)

    train_dataset = make_split(train_files)
    test_dataset = make_split(test_files)
    return train_dataset, test_dataset

# ===== Model training =====
# === 1. Hyperparameters ===
IMG_SIZE = (256, 256)
IN_CHANNELS, OUT_CHANNELS = 4, 1
BATCH_SIZE = 16
EPOCHS = 100
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 10  # epochs without improvement before early stopping
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === 2. Load data ===
train_dataset, test_dataset = get_datasets(
    base_dir="9517_project/USA_segmentation",
    image_size=IMG_SIZE,
)
# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

class AttentionBlock(nn.Module):
    """Attention gate that rescales skip features ``x`` using a gating signal ``g``.

    Both inputs are projected to ``F_int`` channels with 1x1 convolutions,
    summed, passed through ReLU, and reduced to a one-channel sigmoid map that
    multiplies ``x``.

    Args:
        F_g: Number of channels of the gating signal ``g``.
        F_l: Number of channels of the skip features ``x``.
        F_int: Number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch):
            # 1x1 convolution followed by batch normalisation.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=1), nn.BatchNorm2d(out_ch)]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        fused = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(fused)
        # The one-channel map broadcasts over the channels of x.
        return x * attention


class AttentionUNet(nn.Module):
    """U-Net with attention gates on the skip connections.

    The encoder has four double-convolution stages (64, 128, 256, 512
    channels) separated by 2x2 max pooling, followed by a 512-channel
    bottleneck. Each decoder stage upsamples with a transposed convolution,
    gates the matching encoder features with an ``AttentionBlock``,
    concatenates both and applies a double convolution. A final 1x1
    convolution produces ``out_channels`` logit maps.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output logit channels.
    """

    @staticmethod
    def _conv_block(in_c, out_c):
        """Two 3x3 conv + BatchNorm + ReLU layers followed by 10% channel dropout."""
        layers = [
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
        ]
        return nn.Sequential(*layers)

    def __init__(self, in_channels=4, out_channels=1):
        super().__init__()
        block = self._conv_block

        self.pool = nn.MaxPool2d(2)

        # Encoder path and bottleneck.
        self.enc1 = block(in_channels, 64)
        self.enc2 = block(64, 128)
        self.enc3 = block(128, 256)
        self.enc4 = block(256, 512)
        self.bottleneck = block(512, 512)

        # One attention gate per skip connection, deepest first.
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)

        # Decoder path: each dec* block receives the upsampled features
        # concatenated with the gated skip, hence the doubled input width.
        self.upconv4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec4 = block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = block(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
        # A negative output bias makes the initial sigmoid outputs small.
        nn.init.constant_(self.final.bias, -2.0)

    def forward(self, x):
        """Return raw logits for input ``x`` (no sigmoid is applied here)."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))
        features = self.bottleneck(self.pool(skip4))

        decoder_stages = (
            (self.upconv4, self.att4, self.dec4, skip4),
            (self.upconv3, self.att3, self.dec3, skip3),
            (self.upconv2, self.att2, self.dec2, skip2),
            (self.upconv1, self.att1, self.dec1, skip1),
        )
        for upsample, gate, decode, skip in decoder_stages:
            features = upsample(features)
            # The upsampled decoder features act as the gating signal.
            gated_skip = gate(g=features, x=skip)
            features = decode(torch.cat([features, gated_skip], dim=1))

        return self.final(features)


# === 4. Initialise model, optimizer and LR scheduler ===
model = AttentionUNet()
model = model.to(DEVICE)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# mode="max": the monitored quantity is expected to increase; the learning
# rate is halved after 3 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3,
)

# === 5. Loss functions and IoU ===
def focal_loss(pred, target, alpha=0.75, gamma=2.5):
    """Binary focal loss computed from logits.

    Each element's cross-entropy is scaled by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the probability assigned to the true class, so confidently
    correct elements (positive or negative) contribute little and
    misclassified ones dominate. Positives are weighted by ``alpha`` and
    negatives by ``1 - alpha``.

    Args:
        pred: Logits.
        target: Tensor of the same shape as ``pred`` with values in [0, 1].
        alpha: Weight of the positive class.
        gamma: Focusing exponent; larger values suppress easy examples more.

    Returns:
        Scalar tensor: the mean focal loss over all elements.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    prob = torch.sigmoid(pred)
    # Probability of the true class: prob for positives, 1 - prob for
    # negatives. Using (1 - prob) for every element would down-weight hard
    # negatives instead of easy ones.
    p_t = prob * target + (1 - prob) * (1 - target)
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    loss = alpha_t * (1 - p_t) ** gamma * bce
    return loss.mean()

def tversky_loss(pred, target, alpha=0.7, beta=0.3, smooth=1e-6):
    """Tversky loss computed from logits, averaged over the batch.

    Per sample, the soft true positives (TP), false positives (FP) and false
    negatives (FN) are summed over dims 1-3 and combined as
    ``(TP + smooth) / (TP + alpha * FP + beta * FN + smooth)``. The result is
    one minus the batch mean of that index.

    Args:
        pred: Logits; must have four dimensions.
        target: Tensor of the same shape as ``pred``.
        alpha: Weight of the false-positive term.
        beta: Weight of the false-negative term.
        smooth: Constant added to numerator and denominator.
    """
    prob = torch.sigmoid(pred)
    per_sample_dims = (1, 2, 3)
    true_pos = torch.sum(prob * target, dim=per_sample_dims)
    false_pos = torch.sum((1 - target) * prob, dim=per_sample_dims)
    false_neg = torch.sum(target * (1 - prob), dim=per_sample_dims)
    tversky_index = (true_pos + smooth) / (true_pos + alpha * false_pos + beta * false_neg + smooth)
    return 1 - tversky_index.mean()


def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss computed from logits, averaged over the batch.

    Per sample, the Dice coefficient is
    ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``, with sums taken
    over dims 1-3 and ``p = sigmoid(pred)``. The result is one minus the batch
    mean of that coefficient.

    Args:
        pred: Logits; must have four dimensions.
        target: Tensor of the same shape as ``pred``.
        smooth: Constant added to numerator and denominator.
    """
    prob = torch.sigmoid(pred)
    per_sample_dims = (1, 2, 3)
    overlap = torch.sum(prob * target, dim=per_sample_dims)
    total = prob.sum(per_sample_dims) + target.sum(per_sample_dims)
    dice_coeff = (2. * overlap + smooth) / (total + smooth)
    return 1 - dice_coeff.mean()

def combined_loss(pred, target):
    """Return the equally weighted sum of the focal and Tversky losses.

    The focal term uses ``alpha=0.85`` and ``gamma=3.0``; the Tversky term
    uses its default parameters.
    """
    focal_term = focal_loss(pred, target, alpha=0.85, gamma=3.0)
    tversky_term = tversky_loss(pred, target)
    return 0.5 * focal_term + 0.5 * tversky_term


def compute_iou(pred, target, threshold=0.3, eps=1e-8):
    """Mean intersection-over-union between thresholded predictions and targets.

    The logits are passed through a sigmoid and binarised at ``threshold``.
    IoU is computed per sample and channel by summing over dims 2 and 3, then
    averaged over all of them.

    ``eps`` is added to both numerator and denominator, so a sample whose
    target and prediction are both empty scores 1.0 (perfect agreement)
    instead of 0.0. For non-empty unions the effect is negligible.

    Args:
        pred: Logits; must have four dimensions.
        target: Binary tensor of the same shape as ``pred``.
        threshold: Probability above which a pixel counts as foreground.
        eps: Smoothing constant guarding against division by zero.

    Returns:
        The mean IoU as a Python float.
    """
    with torch.no_grad():
        pred_bin = (torch.sigmoid(pred) > threshold).float()
        intersection = (pred_bin * target).sum((2, 3))
        # Clamping the sum to [0, 1] gives the element-wise logical OR.
        union = (pred_bin + target).clamp(0, 1).sum((2, 3))
        return ((intersection + eps) / (union + eps)).mean().item()

def visualize_prediction(model, dataset, index=0):
    """Display the RGB image, ground-truth mask and predicted map for one sample.

    The model is switched to eval mode, run on ``dataset[index]`` on the
    module-level ``DEVICE``, and the three panels are shown with matplotlib.

    Args:
        model: Network returning logits.
        dataset: Indexable object yielding ``(image_tensor, mask_tensor)``.
        index: Index of the sample to display.
    """
    model.eval()
    image, target = dataset[index]
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(DEVICE))
        prob_map = torch.sigmoid(logits).squeeze().cpu().numpy()

    # (title, data, colormap) for each panel, left to right.
    panels = (
        ("RGB Image", image[:3].permute(1, 2, 0).numpy(), None),
        ("Ground Truth", target.squeeze().numpy(), 'gray'),
        ("Prediction", prob_map, 'gray'),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, data, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
    plt.tight_layout()
    plt.show()

# === 6. Main training loop ===
# Trains for at most EPOCHS epochs. After each epoch the mean training IoU
# drives the LR scheduler, best-checkpoint saving and early stopping.
# NOTE(review): model selection and early stopping use the IoU measured on
# train_loader batches; test_loader is not consulted in this loop. Confirm
# whether a held-out validation metric was intended.
best_iou = 0
patience_counter = 0

print("🚀 开始训练 U-Net")
for epoch in range(EPOCHS):
    # Re-enable train mode each epoch; the visualisation helpers called below
    # switch the model to eval mode.
    model.train()
    total_loss, total_iou = 0, 0
    start_time = time.time()

    for batch_idx, (imgs, masks) in enumerate(train_loader):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        preds = model(imgs)
        loss = combined_loss(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # IoU of the predictions made before this step's weight update.
        batch_iou = compute_iou(preds, masks)
        total_loss += loss.item()
        total_iou += batch_iou

    # Per-batch averages (each batch counts equally, regardless of its size).
    avg_loss = total_loss / len(train_loader)
    avg_iou = total_iou / len(train_loader)
    current_lr = optimizer.param_groups[0]["lr"]
    elapsed = time.time() - start_time

    print(f"✅ Epoch {epoch+1}/{EPOCHS} 完成 | Loss: {avg_loss:.4f} | IoU: {avg_iou:.4f} | LR: {current_lr:.6f} | 耗时: {elapsed:.2f}s")
    # Every 10th epoch, save and display predictions for a few training samples.
    if (epoch + 1) % 10 == 0:
        save_prediction_images(model, train_dataset, epoch + 1, indices=[0, 1, 2])
        visualize_prediction(model, train_dataset, index=0)  # display a sample

    # The scheduler monitors the training IoU.
    scheduler.step(avg_iou)

    if avg_iou > best_iou:
        # New best: reset the patience counter and overwrite the checkpoint.
        best_iou = avg_iou
        patience_counter = 0
        torch.save(model.state_dict(), "best_unet_model.pth")
        print(f"🎉 最佳IoU更新为 {best_iou:.4f}，模型已保存")
    else:
        patience_counter += 1
        print(f"🕒 未提升（{patience_counter}/{PATIENCE}），等待中...")
        # Stop once PATIENCE consecutive epochs have failed to improve.
        if patience_counter >= PATIENCE:
            print(f"🛑 Early stopping 触发。最佳 IoU = {best_iou:.4f}")
            break

In [None]:
# Rebuild the network and restore the best checkpoint saved during training.
model = AttentionUNet().to(DEVICE)  # or UNet(), depending on which model was trained
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("best_unet_model.pth", map_location=DEVICE))
model.eval()


total_iou = 0

# Evaluate one sample at a time so a per-sample IoU can be reported.
for idx, (img, mask) in enumerate(test_dataset):
    # Add the batch dimension expected by the model and by compute_iou.
    img = img.unsqueeze(0).to(DEVICE)
    mask = mask.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        pred = model(img)
        iou = compute_iou(pred, mask, threshold=0.3)  # threshold is tunable
        total_iou += iou

    print(f"📍 Test Sample {idx:03d} | IoU: {iou:.4f}")

avg_iou = total_iou / len(test_dataset)
print(f"\n✅ 测试集平均 IoU: {avg_iou:.4f}")
