In [None]:
!git clone https://github.com/776lucky/9517_project.git

In [None]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights


def split_file_lists(rgb_dir, nrg_dir, ratios=(0.64, 0.16, 0.20), seed=42):
    """
    Split the paired RGB/NRG image files into train/val/test file lists.

    Only IDs that occur in both directories are used, where an ID is the file
    name with the ``RGB_`` / ``NRG_`` marker removed. The IDs are sorted,
    shuffled reproducibly and then cut according to ``ratios``. The test split
    receives everything left after train and val, so ``ratios[2]`` is not read.

    Args:
        rgb_dir: Directory containing files named ``RGB_<id>``.
        nrg_dir: Directory containing files named ``NRG_<id>``.
        ratios: ``(train_ratio, val_ratio, test_ratio)``; should sum to 1.0.
        seed: Seed of the shuffle.

    Returns:
        ``(train_files, val_files, test_files)``; every element has the form
        ``"RGB_<id>"``.
    """
    # 1. Collect the IDs that exist as both an RGB and an NRG file.
    rgb_ids = {f.replace("RGB_", "") for f in os.listdir(rgb_dir) if f.startswith("RGB_")}
    nrg_ids = {f.replace("NRG_", "") for f in os.listdir(nrg_dir) if f.startswith("NRG_")}
    common = sorted(rgb_ids & nrg_ids)

    # 2. Shuffle with a private RandomState: this yields the same permutation
    #    as np.random.seed(seed) + np.random.shuffle, but does not reseed the
    #    process-wide NumPy RNG as a hidden side effect.
    np.random.RandomState(seed).shuffle(common)

    # 3. Cut into the three parts (test takes the remainder).
    N = len(common)
    n_train = int(ratios[0] * N)
    n_val = int(ratios[1] * N)
    train_ids = common[:n_train]
    val_ids = common[n_train:n_train + n_val]
    test_ids = common[n_train + n_val:]

    # 4. Restore the prefixed file names.
    train_files = ["RGB_" + i for i in train_ids]
    val_files = ["RGB_" + i for i in val_ids]
    test_files = ["RGB_" + i for i in test_ids]
    print(f"✅ 共 {N} 张图，划分为 train={len(train_files)}, val={len(val_files)}, test={len(test_files)}")
    return train_files, val_files, test_files

def save_prediction_images(model, dataset, epoch, indices=(0, 1, 2), save_dir="predictions"):
    """
    Save side-by-side figures (RGB image, ground truth, prediction) to disk.

    For every index in ``indices`` the sample is fetched from ``dataset``, run
    through ``model`` on the module-level ``DEVICE``, and a three-panel figure
    is written to ``<save_dir>/epoch_<epoch>_idx_<idx>.png``.

    Args:
        model: Network whose output is a dict with an ``'out'`` logit tensor.
        dataset: Indexable returning ``(image, mask)`` tensors.
        epoch: Epoch number, used only in the output file name.
        indices: Dataset indices to render.
        save_dir: Output directory; created if it does not exist.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    # The image tensor is normalised; undo that for the first three channels
    # so imshow receives values in [0, 1] instead of clipping them.
    # Assumes the ImageNet statistics used by FourChannelSegmentationDataset.
    rgb_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    rgb_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for idx in indices:
        img, mask = dataset[idx]
        with torch.no_grad():
            logits = model(img.unsqueeze(0).to(DEVICE))['out']
            pred = torch.sigmoid(logits).squeeze().cpu().numpy()

        img_rgb = (img[:3] * rgb_std + rgb_mean).clamp(0, 1).permute(1, 2, 0).numpy()
        mask_np = mask.squeeze().numpy()

        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(img_rgb)
        axs[0].set_title("RGB Image")
        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title("Ground Truth")
        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Prediction")
        for ax in axs:
            ax.axis("off")

        save_path = os.path.join(save_dir, f"epoch_{epoch}_idx_{idx}.png")
        fig.tight_layout()
        fig.savefig(save_path)
        plt.close(fig)
        print(f"📸 预测图已保存至 {save_path}")

# ===== 数据预处理部分 =====
def build_file_lists(rgb_dir, nrg_dir, train_ratio=0.8):
    """
    Split the paired RGB/NRG image files into a train list and a rest list.

    Only IDs present in both directories (file name with the ``RGB_`` /
    ``NRG_`` marker removed) are kept. The prefixed names are shuffled with a
    fixed seed of 42 and cut at ``train_ratio``.

    Args:
        rgb_dir: Directory containing files named ``RGB_<id>``.
        nrg_dir: Directory containing files named ``NRG_<id>``.
        train_ratio: Fraction of the paired files placed in the first list.

    Returns:
        ``(train_files, remaining_files)``; elements look like ``"RGB_<id>"``.
    """
    rgb_ids = {f.replace("RGB_", "") for f in os.listdir(rgb_dir) if f.startswith("RGB_")}
    nrg_ids = {f.replace("NRG_", "") for f in os.listdir(nrg_dir) if f.startswith("NRG_")}
    common_ids = sorted(rgb_ids & nrg_ids)
    print(f"✅ 找到 {len(common_ids)} 张 RGB 和 NRG 对应图像")

    rgb_filenames = ["RGB_" + fid for fid in common_ids]
    # Private RandomState: same permutation as np.random.seed(42) followed by
    # np.random.shuffle, without reseeding the process-wide NumPy RNG.
    np.random.RandomState(42).shuffle(rgb_filenames)
    n_train = int(train_ratio * len(rgb_filenames))
    return rgb_filenames[:n_train], rgb_filenames[n_train:]


class FourChannelSegmentationDataset(Dataset):
    """
    Segmentation dataset yielding 4-channel images and binary masks.

    Each item is built from three files sharing one ID: ``RGB_<id>`` (three
    colour channels), ``NRG_<id>`` (one channel is appended as the fourth
    input channel) and ``mask_<id>`` (binarised at 127).

    Args:
        rgb_dir: Directory with the ``RGB_<id>`` files.
        nrg_dir: Directory with the ``NRG_<id>`` files.
        mask_dir: Directory with the ``mask_<id>`` files.
        file_list: File names of the form ``"RGB_<id>"`` to serve.
        image_size: Size passed to ``T.Resize`` for image and mask.
        augment: If true, apply a random horizontal flip and a random
            rotation of up to 15 degrees, identically to image and mask.
    """

    def __init__(self, rgb_dir, nrg_dir, mask_dir, file_list, image_size=(256,256), augment=False):
        self.rgb_dir = rgb_dir
        self.nrg_dir = nrg_dir
        self.mask_dir = mask_dir
        self.file_list = file_list
        self.image_size = image_size
        self.augment = augment

        # Normalisation statistics: ImageNet values for RGB; the fourth
        # channel reuses the values of the R channel.
        mean = [0.485, 0.456, 0.406, 0.485]
        std  = [0.229, 0.224, 0.225, 0.229]

        def build_pipeline(normalize):
            # Image and mask share the same geometric steps in the same order
            # so that they consume identical random draws; only the image is
            # normalised.
            steps = [T.ToPILImage(), T.Resize(self.image_size)]
            if augment:
                steps += [T.RandomHorizontalFlip(), T.RandomRotation(15)]
            steps.append(T.ToTensor())
            if normalize:
                steps.append(T.Normalize(mean, std))
            return T.Compose(steps)

        self.img_transform = build_pipeline(normalize=True)
        self.mask_transform = build_pipeline(normalize=False)

    def __len__(self):
        """Return the number of samples (entries in ``file_list``)."""
        return len(self.file_list)

    @staticmethod
    def _read_image(path, flags=None):
        """Read ``path`` with OpenCV and raise instead of returning ``None``."""
        image = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            ``(img_tensor, mask_tensor)``: the normalised 4-channel image and
            the mask as a float tensor containing only 0.0 and 1.0.

        Raises:
            FileNotFoundError: If any of the three files cannot be read.
        """
        filename = self.file_list[idx]
        img_id = filename.replace("RGB_", "")
        rgb_path = os.path.join(self.rgb_dir, filename)
        nrg_path = os.path.join(self.nrg_dir, "NRG_" + img_id)
        mask_path = os.path.join(self.mask_dir, "mask_" + img_id)

        # Read the images and stack them into one H x W x 4 array.
        rgb = cv2.cvtColor(self._read_image(rgb_path), cv2.COLOR_BGR2RGB)
        nrg = self._read_image(nrg_path)
        # NOTE(review): cv2.imread returns channels in BGR order, so index 0
        # is the file's *last* stored channel. If the NRG files store NIR as
        # their first channel, index 2 would be the NIR band - confirm.
        nir = nrg[:, :, 0]
        img_4ch = np.concatenate((rgb, nir[..., None]), axis=-1)

        # Binarise the mask.
        mask = (self._read_image(mask_path, cv2.IMREAD_GRAYSCALE) > 127).astype(np.float32)

        if self.augment:
            # Replay the same RNG state for the mask so it receives exactly the
            # flip/rotation drawn for the image. Unlike reseeding with
            # torch.manual_seed, this does not collapse the global generator to
            # a small set of seeds; it simply advances as if used once.
            rng_state = torch.get_rng_state()
            img_tensor = self.img_transform(img_4ch)
            torch.set_rng_state(rng_state)
            mask_tensor = self.mask_transform(mask)
        else:
            # Deterministic pipelines: leave the global RNG untouched.
            img_tensor = self.img_transform(img_4ch)
            mask_tensor = self.mask_transform(mask)

        # Resizing/rotation can interpolate the mask; threshold it back to 0/1.
        return img_tensor, (mask_tensor > 0.5).float()


def get_datasets(base_dir="9517_project/USA_segmentation", image_size=(256,256)):
    """
    Build the train, validation and test datasets below ``base_dir``.

    The sub-directories ``RGB_images``, ``NRG_images`` and ``masks`` are
    expected inside ``base_dir``. The file split comes from
    ``split_file_lists`` with its default ratios and seed; augmentation is
    enabled for the training dataset only.

    Returns:
        ``(train_ds, val_ds, test_ds)``.
    """
    rgb_dir, nrg_dir, mask_dir = (
        os.path.join(base_dir, sub) for sub in ("RGB_images", "NRG_images", "masks")
    )

    # Three file lists in train / val / test order.
    splits = split_file_lists(rgb_dir, nrg_dir)

    # One dataset per split; only the first (train) one is augmented.
    train_ds, val_ds, test_ds = (
        FourChannelSegmentationDataset(rgb_dir, nrg_dir, mask_dir, files, image_size, augment=use_aug)
        for files, use_aug in zip(splits, (True, False, False))
    )
    return train_ds, val_ds, test_ds



# ===== Model training =====
# === 1. Hyper-parameters ===
# Spatial size every image and mask is resized to.
IMG_SIZE = (256, 256)
# Input channels of the network (three RGB channels plus one from the NRG image).
IN_CHANNELS = 4
# Output channels of the network (a single logit map).
OUT_CHANNELS = 1
BATCH_SIZE = 16
# Upper bound on training epochs; early stopping may end training sooner.
EPOCHS = 50
# Adam learning rate and weight decay.
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
# Early stopping: number of consecutive epochs without validation-IoU improvement.
PATIENCE = 10
# Use the GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === 2. Load the data ===
train_ds, val_ds, test_ds = get_datasets("9517_project/USA_segmentation", image_size=IMG_SIZE)
# Only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)


# === 3. 修改 DeepLabV3+ 模型以接受 4 通道输入 ===
class ModifiedDeepLabV3(nn.Module):
    """
    torchvision ``deeplabv3_resnet101`` adapted to multi-channel input.

    Two parts of the pretrained model are replaced:

    * ``backbone.conv1`` becomes a convolution with ``in_channels`` inputs.
      The pretrained RGB filters are copied; every additional channel is
      initialised with a copy of the pretrained R-channel filters.
    * ``classifier`` becomes a small conv head mapping 2048 channels to
      ``out_channels`` logits.

    NOTE(review): because ``classifier`` is replaced wholesale, the stock
    torchvision head (which, per the torchvision docs, contains the ASPP
    module) is no longer part of the forward path. The network is therefore
    neither DeepLabV3+ nor a standard DeepLabV3 head - confirm this is
    intended. It is kept as-is so existing checkpoints stay loadable.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output logit channels.
    """

    def __init__(self, in_channels=4, out_channels=1):
        super().__init__()
        # Load the pretrained torchvision DeepLabV3 (ResNet-101 backbone).
        self.deeplab = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)

        # Rebuild the stem convolution for `in_channels` inputs.
        original_conv1 = self.deeplab.backbone.conv1
        has_bias = original_conv1.bias is not None
        new_conv1 = nn.Conv2d(
            in_channels,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # nn.Conv2d expects a bool here, not the bias tensor (or None).
            bias=has_bias,
        )

        with torch.no_grad():
            # Copy as many pretrained RGB filters as the new layer has room for.
            n_rgb = min(3, in_channels)
            new_conv1.weight[:, :n_rgb] = original_conv1.weight[:, :n_rgb]
            if in_channels > 3:
                # Extra channels start from the pretrained R-channel filters
                # (broadcast over all extra channels), not from random values.
                new_conv1.weight[:, 3:] = original_conv1.weight[:, :1]
            if has_bias:
                new_conv1.bias.copy_(original_conv1.bias)
        self.deeplab.backbone.conv1 = new_conv1

        # Replacement head operating on a 2048-channel feature map.
        self.deeplab.classifier = nn.Sequential(
            nn.Conv2d(2048, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, out_channels, kernel_size=1)
        )

        # Negative bias on the final layer: initial sigmoid outputs start
        # near sigmoid(-2), i.e. biased towards the background class.
        nn.init.constant_(self.deeplab.classifier[-1].bias, -2.0)

    def forward(self, x):
        """Return the wrapped model's output; callers read the ``'out'`` entry."""
        return self.deeplab(x)

# === 4. Instantiate model, optimizer and LR scheduler ===
model = ModifiedDeepLabV3(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# mode='max': the scheduler is stepped with validation IoU (higher is better)
# and halves the learning rate after 3 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

# === 5. 损失函数与 IoU ===
def focal_loss(pred, target, alpha=0.75, gamma=2):
    """
    Binary focal loss on logits, averaged over all elements.

    Computes ``alpha_t * (1 - p_t) ** gamma * BCE`` where ``p_t`` is the
    predicted probability of the *true* class. The modulating factor must
    depend on the target: using ``(1 - p)`` for every pixel would shrink the
    loss of confident false positives towards zero instead of emphasising
    them.

    Args:
        pred: Logits.
        target: Float tensor of the same shape as ``pred`` with values in {0, 1}.
        alpha: Weight of positive elements; negatives are weighted ``1 - alpha``.
        gamma: Focusing exponent; larger values down-weight easy elements more.

    Returns:
        Scalar tensor with the mean focal loss.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    prob = torch.sigmoid(pred)
    # Probability assigned to the correct class of each element.
    p_t = prob * target + (1 - prob) * (1 - target)
    # Class-balancing weight: alpha for positives, (1 - alpha) for negatives.
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    loss = alpha_t * (1 - p_t) ** gamma * bce
    return loss.mean()

def tversky_loss(pred, target, alpha=0.7, beta=0.3, smooth=1e-6):
    """
    Tversky loss on logits: one minus the batch mean of the Tversky index.

    Per sample the soft counts are summed over dimensions 1-3, so 4-D inputs
    are required. ``alpha`` weights the false positives and ``beta`` the
    false negatives; ``smooth`` is added to numerator and denominator.

    Returns:
        Scalar tensor.
    """
    reduce_dims = (1, 2, 3)
    prob = torch.sigmoid(pred)

    true_pos = torch.sum(prob * target, dim=reduce_dims)
    false_pos = torch.sum((1 - target) * prob, dim=reduce_dims)
    false_neg = torch.sum(target * (1 - prob), dim=reduce_dims)

    denominator = true_pos + alpha * false_pos + beta * false_neg + smooth
    tversky_index = (true_pos + smooth) / denominator
    return 1 - tversky_index.mean()

def dice_loss(pred, target, smooth=1e-6):
    """
    Soft Dice loss on logits: one minus the batch mean Dice coefficient.

    Sums are taken per sample over dimensions 1-3, so 4-D inputs are
    required. ``smooth`` is added to numerator and denominator.

    Returns:
        Scalar tensor.
    """
    reduce_dims = (1, 2, 3)
    prob = torch.sigmoid(pred)

    overlap = torch.sum(prob * target, dim=reduce_dims)
    total = prob.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)

    dice = (2. * overlap + smooth) / (total + smooth)
    return 1 - dice.mean()

def combined_loss(pred, target):
    """
    Training loss: 0.3 x focal loss (alpha=0.85, gamma=1.5) + 0.7 x Tversky loss.

    Both terms are computed from the same logits ``pred`` and ``target``.
    """
    focal_term = focal_loss(pred, target, alpha=0.85, gamma=1.5)
    tversky_term = tversky_loss(pred, target)
    return 0.3 * focal_term + 0.7 * tversky_term

def compute_iou(pred, target, threshold=0.1, empty_score=0.0):
    """
    Mean intersection-over-union of thresholded predictions.

    The logits are passed through a sigmoid and binarised at ``threshold``;
    IoU is computed per sample and channel over dimensions 2 and 3 and then
    averaged. No gradients are tracked.

    Args:
        pred: Logits of shape ``[B, C, H, W]``.
        target: Binary float tensor of the same shape.
        threshold: Probability above which an element counts as foreground.
        empty_score: IoU assigned when prediction and target are both empty
            (union of zero). The default of 0.0 reproduces the historical
            behaviour; pass 1.0 to count "nothing predicted, nothing there"
            as a perfect match.

    Returns:
        The mean IoU as a Python float.
    """
    with torch.no_grad():
        pred_bin = (torch.sigmoid(pred) > threshold).float()
        intersection = (pred_bin * target).sum((2, 3))
        union = (pred_bin + target).clamp(0, 1).sum((2, 3))
        iou = intersection / (union + 1e-8)
        # An empty union makes the ratio 0/eps; replace it by the chosen score.
        iou = torch.where(union > 0, iou, torch.full_like(iou, empty_score))
        return iou.mean().item()

def visualize_prediction(model, dataset, index=0):
    """
    Show the RGB image, ground-truth mask and predicted probability map.

    The sample ``dataset[index]`` is run through ``model`` on the module-level
    ``DEVICE`` and the three panels are displayed with ``plt.show()``.

    Args:
        model: Network whose output is a dict with an ``'out'`` logit tensor.
        dataset: Indexable returning ``(image, mask)`` tensors.
        index: Index of the sample to display.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()
    img, mask = dataset[index]
    with torch.no_grad():
        logits = model(img.unsqueeze(0).to(DEVICE))['out']
        pred = torch.sigmoid(logits).squeeze().cpu().numpy()

    # The image tensor is normalised; undo that for the first three channels
    # so imshow receives values in [0, 1] instead of clipping them.
    # Assumes the ImageNet statistics used by FourChannelSegmentationDataset.
    rgb_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    rgb_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img_rgb = (img[:3] * rgb_std + rgb_mean).clamp(0, 1).permute(1, 2, 0).numpy()
    mask_np = mask.squeeze().numpy()

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(img_rgb)
    plt.title("RGB Image")
    plt.subplot(1, 3, 2)
    plt.imshow(mask_np, cmap='gray')
    plt.title("Ground Truth")
    plt.subplot(1, 3, 3)
    plt.imshow(pred, cmap='gray')
    plt.title("Prediction")
    plt.tight_layout()
    plt.show()

# === 6. Main training loop (revised) ===
# Early-stopping state: best validation IoU so far and the number of
# consecutive epochs without improvement.
best_val_iou = 0.0
patience_counter = 0




# Sanity check: display one training sample with its mask overlaid.
# A single index is used for both the sample and the printed file name so
# the two always refer to the same item.
sample_idx = 6
img, mask = train_ds[sample_idx]   # img: [4,H,W], mask: [1,H,W]
# Undo the normalisation of the first three (RGB) channels.
mean = torch.tensor([0.485,0.456,0.406]).view(3,1,1)
std  = torch.tensor([0.229,0.224,0.225]).view(3,1,1)
rgb = img[:3] * std + mean
# Clamp so rounding/interpolation cannot push values outside imshow's [0, 1] range.
rgb = rgb.clamp(0, 1).permute(1,2,0).numpy()

m = mask.squeeze().numpy()
print(train_ds.file_list[sample_idx])      # e.g. 'RGB_000123.png'

plt.figure(figsize=(8,4))
plt.subplot(1,2,1)
plt.imshow(rgb); plt.title("RGB")
plt.axis("off")
plt.subplot(1,2,2)
plt.imshow(rgb);
plt.imshow(m, cmap="jet", alpha=0.5);
plt.title("Overlay Mask")
plt.axis("off")
plt.show()

# Value ranges of one batch: images are normalised, masks should be within [0, 1].
imgs, masks = next(iter(train_loader))
print("img min/max:", imgs.min().item(), imgs.max().item())
print("mask min/max:", masks.min().item(), masks.max().item())




# Train for at most EPOCHS epochs. After every epoch the model is evaluated on
# the validation set; validation IoU drives the LR scheduler, checkpointing
# and early stopping.
print("🚀 开始训练 DeepLabV3+（含验证集评估）")
for epoch in range(1, EPOCHS + 1):
    # --- 1) Training ---
    model.train()
    train_loss, train_iou = 0.0, 0.0
    t0 = time.time()
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        # The model returns a dict; 'out' holds the segmentation logits.
        preds = model(imgs)['out']
        loss = combined_loss(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 0.5 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        train_loss += loss.item()
        train_iou  += compute_iou(preds, masks)
    # Average the per-batch values over the number of batches.
    train_loss /= len(train_loader)
    train_iou  /= len(train_loader)

    # --- 2) Validation ---
    model.eval()
    val_iou = 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            preds = model(imgs)['out']
            val_iou += compute_iou(preds, masks)
    val_iou /= len(val_loader)

    # Elapsed time covers both the training and the validation pass.
    elapsed = time.time() - t0
    print(f"✅ Epoch {epoch}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | Train IoU: {train_iou:.4f} | "
          f" Val IoU: {val_iou:.4f} | Time: {elapsed:.1f}s")

    # --- 3) LR scheduling & early stopping (both based on val_iou) ---
    scheduler.step(val_iou)

    if val_iou > best_val_iou:
        # New best: reset the patience counter and overwrite the checkpoint.
        best_val_iou = val_iou
        patience_counter = 0
        torch.save(model.state_dict(), "best_deeplabv3plus_model.pth")
        print(f"🎉 验证集 IoU 提升到 {best_val_iou:.4f}，模型已保存")
    else:
        patience_counter += 1
        print(f"🕒 验证集 IoU 未提升（{patience_counter}/{PATIENCE}）")
        if patience_counter >= PATIENCE:
            print(f"🛑 Early stopping 触发，最佳 Val IoU = {best_val_iou:.4f}")
            # Leaves the loop immediately, so the visualisation below is skipped.
            break

    # --- 4) (Optional) visualise validation predictions every 5 epochs ---
    if epoch % 5 == 0:
        save_prediction_images(model, val_ds, epoch, indices=[0,1,2], save_dir="predictions/val")
        visualize_prediction(model, val_ds, index=0)