In [32]:
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

# === 1. 构造可配对的图像列表 ===
def build_file_lists(rgb_dir, nrg_dir, train_ratio=0.8, seed=42):
    """Pair RGB and NRG images by shared ID and split the pairs into train/test.

    A file ``RGB_<id>`` in ``rgb_dir`` is paired with ``NRG_<id>`` in
    ``nrg_dir``. Only IDs present in both directories are kept. The number of
    pairs found is printed.

    Args:
        rgb_dir: Directory containing files named ``RGB_<id>``.
        nrg_dir: Directory containing files named ``NRG_<id>``.
        train_ratio: Fraction of the paired files assigned to the training
            list; the count is truncated with ``int()``.
        seed: Seed for the shuffle. The default of 42 reproduces the split
            previously obtained via ``np.random.seed(42)``.

    Returns:
        A ``(train_files, test_files)`` tuple of lists of RGB file names
        (each still carrying its ``RGB_`` prefix).
    """
    rgb_prefix, nrg_prefix = "RGB_", "NRG_"

    # Strip only the leading prefix. str.replace() would also delete any later
    # occurrence of the marker inside the name, so the IDs would no longer
    # match between the directories or map back to real files.
    rgb_ids = {f[len(rgb_prefix):] for f in os.listdir(rgb_dir) if f.startswith(rgb_prefix)}
    nrg_ids = {f[len(nrg_prefix):] for f in os.listdir(nrg_dir) if f.startswith(nrg_prefix)}

    # Sort before shuffling so the split does not depend on os.listdir() order.
    common_ids = sorted(rgb_ids & nrg_ids)
    print(f"✅ 找到 {len(common_ids)} 张 RGB 和 NRG 对应图像")

    rgb_filenames = [rgb_prefix + fid for fid in common_ids]

    # A private RandomState draws the same permutation as
    # np.random.seed(seed) + np.random.shuffle(...), but leaves NumPy's global
    # random state untouched for the rest of the notebook.
    np.random.RandomState(seed).shuffle(rgb_filenames)

    n_train = int(train_ratio * len(rgb_filenames))
    return rgb_filenames[:n_train], rgb_filenames[n_train:]

# === 2. Dataset 定义 ===
class FourChannelSegmentationDataset(Dataset):
    def __init__(self, rgb_dir, nrg_dir, mask_dir, file_list, image_size=(256, 256)):
        self.rgb_dir = rgb_dir
        self.nrg_dir = nrg_dir
        self.mask_dir = mask_dir
        self.file_list = file_list
        self.image_size = image_size

        self.transform_img = T.Compose([
            T.ToPILImage(),
            T.Resize(self.image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        filename = self.file_list[idx]
        img_id = filename.replace("RGB_", "")
        rgb_path = os.path.join(self.rgb_dir, filename)
        nrg_path = os.path.join(self.nrg_dir, "NRG_" + img_id)
        mask_path = os.path.join(self.mask_dir, "mask_" + img_id)

        rgb = cv2.imread(rgb_path)
        nrg = cv2.imread(nrg_path)
        nir = nrg[:, :, 0][:, :, np.newaxis]

        img_4ch = np.concatenate((rgb, nir), axis=-1)
        rgb_tensor = self.transform_img(img_4ch[:, :, :3])
        nir_tensor = self.transform_img(nir.squeeze(-1))
        img_tensor = torch.cat([rgb_tensor, nir_tensor], dim=0)  # shape: (4, H, W)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, self.image_size)
        mask = (mask > 127).astype(np.float32)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)

        return img_tensor, mask_tensor

def build_file_lists(rgb_dir, nrg_dir, train_ratio=0.8, seed=42):
    """Pair RGB and NRG images by shared ID and split the pairs into train/test.

    NOTE: this redefinition replaces the same-named function defined earlier
    in this cell (which additionally prints the number of pairs); only one of
    the two definitions should be kept.

    Args:
        rgb_dir: Directory containing files named ``RGB_<id>``.
        nrg_dir: Directory containing files named ``NRG_<id>``.
        train_ratio: Fraction of the paired files assigned to the training
            list; the count is truncated with ``int()``.
        seed: Seed for the shuffle. The default of 42 reproduces the split
            previously obtained via ``np.random.seed(42)``.

    Returns:
        A ``(train_files, test_files)`` tuple of lists of RGB file names
        (each still carrying its ``RGB_`` prefix).
    """
    rgb_prefix, nrg_prefix = "RGB_", "NRG_"

    # Strip only the leading prefix; str.replace() would also delete later
    # occurrences of the marker inside the name and break the pairing.
    rgb_ids = {f[len(rgb_prefix):] for f in os.listdir(rgb_dir) if f.startswith(rgb_prefix)}
    nrg_ids = {f[len(nrg_prefix):] for f in os.listdir(nrg_dir) if f.startswith(nrg_prefix)}

    # Sorted so the split does not depend on os.listdir() ordering.
    rgb_filenames = [rgb_prefix + fid for fid in sorted(rgb_ids & nrg_ids)]

    # Same permutation as np.random.seed(seed) + np.random.shuffle(...), but
    # without resetting NumPy's global random state.
    np.random.RandomState(seed).shuffle(rgb_filenames)

    n_train = int(train_ratio * len(rgb_filenames))
    return rgb_filenames[:n_train], rgb_filenames[n_train:]

def get_datasets(base_dir="USA_segmentation", image_size=(256, 256)):
    """Build the train and test datasets from a dataset root directory.

    Args:
        base_dir: Root directory containing the ``RGB_images``, ``NRG_images``
            and ``masks`` sub-directories.
        image_size: Size forwarded to each dataset.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of
        ``FourChannelSegmentationDataset`` instances.
    """
    rgb_dir, nrg_dir, mask_dir = (
        os.path.join(base_dir, sub_dir)
        for sub_dir in ("RGB_images", "NRG_images", "masks")
    )
    # build_file_lists returns the (train, test) file-name lists, in that order.
    splits = build_file_lists(rgb_dir, nrg_dir)
    train_dataset, test_dataset = (
        FourChannelSegmentationDataset(rgb_dir, nrg_dir, mask_dir, split_files, image_size)
        for split_files in splits
    )
    return train_dataset, test_dataset


In [33]:

# === 3. Usage example ===
rgb_dir, nrg_dir, mask_dir = (
    "USA_segmentation/RGB_images",
    "USA_segmentation/NRG_images",
    "USA_segmentation/masks",
)

# Reproducible train/test split of the paired file names.
train_files, test_files = build_file_lists(rgb_dir, nrg_dir)

# One dataset per split, both using the default image size.
train_dataset, test_dataset = (
    FourChannelSegmentationDataset(rgb_dir, nrg_dir, mask_dir, split_files)
    for split_files in (train_files, test_files)
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)

# === 4. Check that a sample really is a 4-channel image ===
sample_img, sample_mask = train_dataset[0]
print("✅ 图像 shape:", sample_img.shape)      # expected: (4, 256, 256)
print("✅ 掩码 shape:", sample_mask.shape)    # expected: (1, 256, 256)


✅ 图像 shape: torch.Size([4, 256, 256])
✅ 掩码 shape: torch.Size([1, 256, 256])


In [34]:
# ==== Start training ====
# NOTE(review): EPOCHS, PATIENCE, DEVICE, model, optimizer, scheduler,
# combined_loss and compute_iou are not defined in the cells shown above this
# one; the recorded output of this cell is
# "NameError: name 'EPOCHS' is not defined". They must be defined in an
# earlier cell for the notebook to survive Restart Kernel -> Run All.
best_iou = 0          # best epoch-mean IoU seen so far
patience_counter = 0  # consecutive epochs without a new best IoU

for epoch in range(EPOCHS):
    model.train()
    # Running sums over the batches of this epoch.
    total_loss, total_iou = 0, 0

    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        preds = model(imgs)
        loss = combined_loss(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm to 0.5 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        # NOTE(review): compute_iou is not visible here; if it returns a tensor
        # still attached to the autograd graph, this accumulation keeps those
        # graphs alive for the whole epoch - confirm it returns a float.
        total_iou += compute_iou(preds, masks)

    # Per-batch means; both divisions raise ZeroDivisionError for an empty loader.
    avg_loss = total_loss / len(train_loader)
    avg_iou = total_iou / len(train_loader)

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | IoU: {avg_iou:.4f}")

    # The metric driving the scheduler, checkpointing and early stopping is
    # the IoU measured on the *training* batches (train_loader); no held-out
    # data is evaluated in this loop.
    # NOTE(review): presumably a metric-driven scheduler (e.g.
    # ReduceLROnPlateau in "max" mode) - confirm where `scheduler` is created.
    scheduler.step(avg_iou)
    if avg_iou > best_iou:
        best_iou = avg_iou
        patience_counter = 0
        # Overwrites the previous checkpoint file on every improvement.
        torch.save(model.state_dict(), "best_unet_model.pth")
        print(f"✅ 保存新最佳模型 (IoU = {best_iou:.4f})")
    else:
        patience_counter += 1
        # Stop once PATIENCE consecutive epochs brought no improvement.
        if patience_counter >= PATIENCE:
            print("🛑 Early stopping.")
            break


NameError: name 'EPOCHS' is not defined