In [19]:
# Sanity check: the notebook uses paths relative to the project root, so the
# working directory must be that root for the data paths below to resolve.
import os
print(os.getcwd())  # current working directory
print(os.path.exists(os.path.join("data", "train", "images")))  # True means the expected data layout was found


/Users/zwy/Desktop/9517_project
True


In [None]:
import os
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau


# ==== 1. 数据集与数据加载 (Dataset & data loading) ====
class SegmentationDataset(Dataset):
    """Paired RGB image / binary mask dataset read from two directories.

    Every file in ``image_dir`` is paired with the mask in ``mask_dir`` whose
    name is obtained by replacing ``"RGB_"`` with ``"mask_"``.

    Args:
        image_dir: directory containing the RGB images.
        mask_dir: directory containing the grayscale masks.
        image_size: target size as ``(height, width)`` -- the convention of
            ``torchvision.transforms.Resize``. Images and masks are both resized
            to it.
    """

    def __init__(self, image_dir, mask_dir, image_size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_size = image_size
        # Skip hidden files (e.g. macOS ".DS_Store"), which are not images and
        # would make cv2.imread fail.
        self.filenames = sorted(
            name for name in os.listdir(image_dir) if not name.startswith(".")
        )

        self.transform_img = T.Compose([
            T.ToPILImage(),
            T.Resize(self.image_size),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the ``(image_tensor, mask_tensor)`` pair for sample ``idx``.

        Returns:
            image_tensor: output of ``T.ToTensor`` on the resized RGB image.
            mask_tensor: float32 tensor ``[1, H, W]`` with values in {0, 1}.

        Raises:
            FileNotFoundError: if the image or its mask cannot be read.
        """
        image_filename = self.filenames[idx]
        mask_filename = image_filename.replace("RGB_", "mask_")

        img_path = os.path.join(self.image_dir, image_filename)
        mask_path = os.path.join(self.mask_dir, mask_filename)

        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"❌ 无法读取图像: {img_path}")
        # OpenCV loads BGR; the model expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"❌ 无法读取 mask 图像: {mask_path}")

        # Binarize: pixels brighter than 127 are foreground.
        mask = (mask > 127).astype(np.float32)

        img_tensor = self.transform_img(img)

        # cv2.resize takes dsize as (width, height), the reverse of the
        # (height, width) order used by T.Resize; swap so image and mask stay
        # aligned for non-square sizes. Nearest-neighbour keeps the mask binary
        # (bilinear would create fractional edge values).
        height, width = self.image_size
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)  # [1, H, W]

        return img_tensor, mask_tensor

# 创建数据集和数据加载器
# Training split; its loader reshuffles every epoch.
train_dataset = SegmentationDataset(
    image_dir="data/train/images",
    mask_dir="data/train/masks",
    image_size=(256, 256),
)

# NOTE(review): the "validation" set is built from data/test. If it also drives
# model selection / early stopping, scores reported on it are optimistic.
val_dataset = SegmentationDataset(
    image_dir="data/test/images",
    mask_dir="data/test/masks",
    image_size=(256, 256),
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

def evaluate(model, val_loader, device=None, loss_fn=None, iou_fn=None):
    """Compute the mean loss and mean IoU of ``model`` over ``val_loader``.

    Metrics are averaged per sample (each batch weighted by its size), so a
    smaller final batch does not skew the result.

    Args:
        model: segmentation network returning logits.
        val_loader: iterable of ``(images, masks)`` batches.
        device: device batches are moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        loss_fn: ``loss_fn(logits, masks) -> scalar tensor``; defaults to the
            notebook's ``combined_loss``.
        iou_fn: ``iou_fn(logits, masks) -> float`` batch-mean IoU; defaults to
            the notebook's ``compute_iou``.

    Returns:
        ``(mean_loss, mean_iou)`` as Python floats.

    Raises:
        ValueError: if ``val_loader`` yields no samples.

    Side effect: leaves ``model`` in eval mode.
    """
    if loss_fn is None:
        loss_fn = combined_loss
    if iou_fn is None:
        iou_fn = compute_iou
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss, total_iou, n_samples = 0.0, 0.0, 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            preds = model(imgs)
            batch_size = imgs.size(0)
            # Both metrics are batch means; re-weight by batch size so the
            # final average is over samples, not batches.
            total_loss += loss_fn(preds, masks).item() * batch_size
            total_iou += iou_fn(preds, masks) * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("evaluate() received an empty val_loader")
    return total_loss / n_samples, total_iou / n_samples



# ==== 2. U-Net 模型 (U-Net model) ====
class UNet(nn.Module):
    """U-Net for binary segmentation that outputs raw logits.

    Four 2x max-pool encoder stages, a 512-channel bottleneck, and four
    transposed-convolution decoder stages with skip connections. Inputs whose
    height/width are not divisible by 16 are supported: each upsampled map is
    zero-padded to the spatial size of its skip connection before concatenation.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output logit channels.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        def conv_block(in_c, out_c):
            # Two (3x3 conv -> BN -> ReLU) layers followed by channel dropout
            # to reduce overfitting.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Dropout2d(0.1),
            )

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Reduced bottleneck width: 512 channels instead of the classic 1024.
        self.bottleneck = conv_block(512, 512)

        self.upconv4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec4 = conv_block(1024, 512)  # 512 upsampled + 512 skip

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)

        # Kept as Sequential(conv, BN) so parameter names ("final.0.*",
        # "final.1.*") match checkpoints saved from this architecture.
        # NOTE(review): in training mode BatchNorm subtracts the per-batch mean,
        # so the conv's -2.0 bias has no effect on the logits; confirm whether
        # a BatchNorm on the logits is intended.
        self.final = nn.Sequential(
            nn.Conv2d(64, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-init every Conv2d with zero bias; the output conv gets bias -2.0.

        ConvTranspose2d is not a Conv2d subclass, so upsampling layers keep
        PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Applied after the generic loop so it is not overwritten by it.
        final_conv = self.final[0]
        nn.init.xavier_uniform_(final_conv.weight, gain=1.0)
        nn.init.constant_(final_conv.bias, -2.0)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` so its height/width (dims 2, 3) equal those of ``ref``."""
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h or diff_w:
            # F.pad order starts at the last dim: (left, right, top, bottom).
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        return x

    def forward(self, x):
        """Map an ``[N, in_channels, H, W]`` batch to ``[N, out_channels, H, W]`` logits."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))

        d4 = self._match_size(self.upconv4(b), e4)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))

        d3 = self._match_size(self.upconv3(d4), e3)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self._match_size(self.upconv2(d3), e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self._match_size(self.upconv1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)


def focal_loss(pred, target, alpha=0.25, gamma=2.0):
    """Binary focal loss (Lin et al., 2017) on logits, averaged over all elements.

    FL = alpha_t * (1 - p_t) ** gamma * BCE, where alpha_t = alpha for positive
    targets and 1 - alpha for negative targets.

    Args:
        pred: logits, same shape as ``target``.
        target: targets, assumed binary {0, 1} (soft targets make ``p_t`` an
            approximation).
        alpha: weight of the positive class; negatives get ``1 - alpha``.
            ``None`` disables class weighting.
        gamma: focusing exponent; ``0`` reduces to (alpha-weighted) BCE.

    Returns:
        Scalar tensor.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    # For binary targets exp(-BCE) is the predicted probability of the true class.
    p_t = torch.exp(-bce)
    loss = (1 - p_t) ** gamma * bce
    if alpha is not None:
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        loss = alpha_t * loss
    return loss.mean()

# ==== 4. IoU计算函数 ====
@torch.no_grad()
def compute_iou(pred, target, threshold=0.5):
    """Mean IoU of thresholded logits against ``target``.

    ``pred`` is passed through a sigmoid; pixels whose probability exceeds
    ``threshold`` count as foreground. IoU is computed per entry over dims 2
    and 3, then averaged. An entry where both prediction and target are empty
    scores 1.0.

    Returns:
        Python float.
    """
    fg_pred = (torch.sigmoid(pred) > threshold).float()
    gt = target.float()

    overlap = torch.sum(fg_pred * gt, dim=(2, 3))
    combined = torch.sum(torch.clamp(fg_pred + gt, 0, 1), dim=(2, 3))

    per_item = torch.where(
        combined == 0,
        torch.tensor(1.0, device=combined.device),
        overlap / (combined + 1e-8),
    )
    return per_item.mean().item()

# ==== 5. 训练代码 ====
# Seed torch's global RNG so weight initialisation, dropout and DataLoader
# shuffling are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

model = UNet().to(device)

# BCE with positive-class weight 5 (presumably to counter the foreground /
# background imbalance of the masks).
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(5.0, device=device))


def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss on logits.

    Dice is computed per sample over dims 1-3 on sigmoid probabilities, with
    ``smooth`` added to numerator and denominator; returns ``1 - mean(dice)``.
    """
    probs = torch.sigmoid(pred)
    tgt = target.float()
    dims = (1, 2, 3)

    numerator = 2. * torch.sum(probs * tgt, dim=dims) + smooth
    denominator = torch.sum(probs, dim=dims) + torch.sum(tgt, dim=dims) + smooth
    return 1 - torch.mean(numerator / denominator)

def combined_loss(pred, target):
    """Sum of the weighted BCE ``criterion`` and the soft Dice loss (both on logits)."""
    return criterion(pred, target) + dice_loss(pred, target)


# Conservative optimiser: small learning rate plus L2 weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Halve the LR when validation IoU has not improved for 3 epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

# Early stopping: stop after `patience` epochs without a new best validation IoU.
# Start from -inf so the first epoch always writes a checkpoint; with 0 a run
# whose IoU stays at 0.0 would never save and the load cell would fail.
best_iou = float("-inf")
patience_counter = 0
patience = 5

epochs = 50
for epoch in range(epochs):
    print(f"\n🚀 Epoch {epoch+1}/{epochs} 开始")
    start_time = time.time()

    model.train()
    total_loss, total_iou = 0, 0

    for batch_idx, (imgs, masks) in enumerate(train_loader):
        imgs = imgs.to(device)
        masks = masks.to(device)

        preds = model(imgs)
        loss = combined_loss(preds, masks)

        optimizer.zero_grad()
        loss.backward()

        # Clip the global gradient norm for stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        # `preds` come from the forward pass before this optimiser step.
        batch_iou = compute_iou(preds, masks)
        loss_value = loss.item()
        total_loss += loss_value
        total_iou += batch_iou

        # Log every 10 batches and on the last batch.
        if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(train_loader):
            with torch.no_grad():
                probs = torch.sigmoid(preds)
            print(f"  🌀 Batch {batch_idx+1:>3}/{len(train_loader)} | Loss: {loss_value:.4f} | IoU: {batch_iou:.4f}")
            print(f"    🔍 Pred min/max: {preds.min().item():.4f} ~ {preds.max().item():.4f}")
            print(f"    📊 Pred after sigmoid: {probs.min().item():.4f} ~ {probs.max().item():.4f}")

    avg_loss = total_loss / len(train_loader)
    avg_iou = total_iou / len(train_loader)
    elapsed = time.time() - start_time  # training time only; excludes validation
    current_lr = optimizer.param_groups[0]['lr']

    val_loss, val_iou = evaluate(model, val_loader)
    print(f"🧪 验证集 Loss: {val_loss:.4f} | IoU: {val_iou:.4f}")

    print(f"✅ Epoch {epoch+1}/{epochs} 完成 | Loss: {avg_loss:.4f} | IoU: {avg_iou:.4f} | LR: {current_lr:.6f} | 耗时: {elapsed:.2f}s")

    # LR schedule driven by validation IoU (mode='max').
    scheduler.step(val_iou)

    # Early-stopping check; checkpoint the best model so far.
    if val_iou > best_iou:
        best_iou = val_iou
        patience_counter = 0
        torch.save(model.state_dict(), 'best_unet_model.pth')
        print(f"🎉 新的最佳IoU: {best_iou:.4f}, 模型已保存")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"🛑 早停触发，最佳IoU: {best_iou:.4f}")
            break

print(f"\n🎯 训练完成！最佳IoU: {best_iou:.4f}")

使用设备: cpu

🚀 Epoch 1/50 开始


KeyboardInterrupt: 

In [28]:
# === Load the best checkpoint written by the training loop ===
# NOTE: a "Missing/Unexpected key(s)" error here means the UNet class in the
# kernel differs from the one that saved the checkpoint (e.g. a stale class
# definition); re-run the model-definition cell first.
model = UNet().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from trusted sources.
model.load_state_dict(torch.load("best_unet_model.pth", map_location=device))
model.eval()

# === Evaluate on the validation/test set ===
val_loss, val_iou = evaluate(model, val_loader)
print(f"🎯 最终验证集表现 | Loss: {val_loss:.4f} | IoU: {val_iou:.4f}")

RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict: "final.weight", "final.bias". 
	Unexpected key(s) in state_dict: "final.0.weight", "final.0.bias", "final.1.weight", "final.1.bias", "final.1.running_mean", "final.1.running_var", "final.1.num_batches_tracked". 