In [5]:
import torch
from torch.utils.data import DataLoader
from torch import nn
from preprocessing import get_datasets

# ==== Hyperparameters ====
IMG_SIZE = (256, 256)   # passed to get_datasets; UNet needs H and W divisible by 16 (four 2x poolings)
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 5            # epochs without IoU improvement before early stopping
POS_WEIGHT = 5.0        # BCEWithLogitsLoss pos_weight: weight of the positive class
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNG (all devices) so weight init, dropout and DataLoader
# shuffling are reproducible across runs.
# NOTE(review): any python `random` / numpy augmentation inside get_datasets
# is not covered by this seed — confirm if that matters.
torch.manual_seed(SEED)

ModuleNotFoundError: No module named 'preprocessing'

In [None]:
# ==== Load datasets ====
# get_datasets comes from the project-local `preprocessing` module, which must
# be importable from the notebook's working directory.
train_dataset, test_dataset = get_datasets("USA_segmentation", IMG_SIZE)

# Only the training loader is shuffled; the held-out loader keeps a fixed order.
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False),
)

# ==== U-Net definition ====
class UNet(nn.Module):
    """Four-level U-Net that maps an image batch to per-pixel logits.

    Args:
        in_channels: channels of the input image (default 4 — presumably
            RGB + one extra band; TODO confirm against the dataset).
        out_channels: channels of the output logit map.

    Input is (N, in_channels, H, W); H and W must be divisible by 16 because
    there are four 2x max-poolings whose outputs are later concatenated with
    the upsampled decoder features. Output is (N, out_channels, H, W) raw
    logits (no sigmoid applied).
    """

    def __init__(self, in_channels=4, out_channels=1):
        super().__init__()

        def double_conv(c_in, c_out):
            # Two 3x3 conv -> BN -> ReLU stages followed by channel dropout.
            # The layer order fixes the state_dict indices (0 conv, 1 bn, 3 conv, 4 bn).
            layers = [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Conv2d(c_out, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Dropout2d(0.1),
            ]
            return nn.Sequential(*layers)

        widths = (64, 128, 256, 512)

        # Encoder: enc1..enc4. Modules are registered in the same order as the
        # explicit attribute assignments would, so parameter init and
        # state_dict key order are unchanged.
        c_prev = in_channels
        for level, width in enumerate(widths, start=1):
            setattr(self, f"enc{level}", double_conv(c_prev, width))
            c_prev = width
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = double_conv(c_prev, c_prev)

        # Decoder: upconv4/dec4 down to upconv1/dec1. Each dec block consumes
        # the upsampled features concatenated with the matching skip (2 * width).
        for level in range(len(widths), 0, -1):
            width = widths[level - 1]
            setattr(self, f"upconv{level}", nn.ConvTranspose2d(c_prev, width, 2, stride=2))
            setattr(self, f"dec{level}", double_conv(2 * width, width))
            c_prev = width

        self.final = nn.Conv2d(c_prev, out_channels, 1)
        # Start the output bias at -2 so initial logits lean toward background
        # (sigmoid(-2) ~= 0.12).
        nn.init.constant_(self.final.bias, -2.0)

    def forward(self, x):
        skips = []
        h = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            h = encoder(h)
            skips.append(h)
            h = self.pool(h)
        h = self.bottleneck(h)

        stages = (
            (self.upconv4, self.dec4),
            (self.upconv3, self.dec3),
            (self.upconv2, self.dec2),
            (self.upconv1, self.dec1),
        )
        for (upsample, decoder), skip in zip(stages, reversed(skips)):
            h = decoder(torch.cat([upsample(h), skip], dim=1))
        return self.final(h)


In [None]:
# ==== Training setup ====
model = UNet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# mode='max': the scheduler is stepped with an IoU score, where higher is better.
# `verbose=True` is not passed: the scheduler `verbose` argument is deprecated
# since PyTorch 2.2 and emits a warning. Read optimizer.param_groups[0]['lr']
# to log the current learning rate instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
# pos_weight > 1 up-weights positive (foreground) pixels in the BCE term;
# the tensor is created directly on DEVICE.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(POS_WEIGHT, device=DEVICE))

def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss on logits, averaged over the batch.

    Args:
        pred: raw logits; converted to probabilities with a sigmoid.
        target: mask broadcastable to `pred`; reductions run over dims 1-3,
            so a 4-D (N, C, H, W) layout is assumed.
        smooth: added to numerator and denominator so empty-vs-empty samples
            give a Dice score of 1 (loss 0) instead of 0/0.

    Returns:
        Scalar tensor: 1 - mean per-sample Dice coefficient.
    """
    probs = torch.sigmoid(pred)
    reduce_dims = (1, 2, 3)
    overlap = torch.sum(probs * target, dim=reduce_dims)
    total = torch.sum(probs, dim=reduce_dims) + torch.sum(target, dim=reduce_dims)
    dice = (2. * overlap + smooth) / (total + smooth)
    return 1 - torch.mean(dice)

def combined_loss(pred, target):
    """Unweighted sum of the module-level BCE `criterion` and `dice_loss`.

    Both terms are evaluated on the raw logits `pred` against `target`.
    """
    bce_term = criterion(pred, target)
    dice_term = dice_loss(pred, target)
    return bce_term + dice_term

def compute_iou(pred, target, threshold=0.5):
    """Mean IoU between thresholded predictions and a target mask.

    Args:
        pred: raw logits of shape (N, C, H, W); a sigmoid is applied and
            pixels with probability > `threshold` count as positive.
        target: mask with the same shape as `pred` (assumed 0/1 — TODO
            confirm the preprocessing never yields soft mask values).
        threshold: probability cut-off for a positive pixel.

    Returns:
        Python float: IoU computed per (sample, channel) over the spatial
        dims, then averaged. A (sample, channel) pair whose prediction and
        target are both empty counts as a perfect match (IoU 1.0).
    """
    with torch.no_grad():
        pred_bin = (torch.sigmoid(pred) > threshold).float()
        intersection = (pred_bin * target).sum((2, 3))
        union = (pred_bin + target).clamp(0, 1).sum((2, 3))
        # An empty prediction on an empty mask is correct. Dividing 0 by eps
        # scored it 0 and dragged the mean down on background-only tiles.
        iou = torch.where(
            union > 0,
            intersection / union.clamp_min(1e-8),
            torch.ones_like(union),
        )
        return iou.mean().item()

In [None]:
# ==== Training ====
# Checkpointing, LR scheduling and early stopping use the IoU on test_loader,
# measured in eval mode (dropout off, BatchNorm running statistics).
# The running training IoU is not used for these decisions: it is measured
# with dropout active, on weights that change within the epoch, and it keeps
# rising as the model overfits.
# NOTE(review): test_loader doubles as the validation set, so best_iou is an
# optimistic estimate for that split; a dedicated validation split is cleaner.


def evaluate_iou(loader):
    """Return the mean per-batch IoU of the global `model` over `loader`.

    Puts `model` in eval mode; the epoch loop calls `model.train()` again
    before the next training pass.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            total += compute_iou(model(imgs), masks)
    return total / len(loader)


best_iou = float('-inf')  # -inf so the first epoch always writes a checkpoint
patience_counter = 0

for epoch in range(EPOCHS):
    model.train()
    total_loss, total_iou = 0.0, 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        preds = model(imgs)
        loss = combined_loss(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm to 0.5 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()
        total_iou += compute_iou(preds, masks)

    avg_loss = total_loss / len(train_loader)
    avg_iou = total_iou / len(train_loader)
    val_iou = evaluate_iou(test_loader)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | Train IoU: {avg_iou:.4f} "
          f"| Val IoU: {val_iou:.4f} | LR: {current_lr:.2e}")

    # The scheduler was built with mode='max', so it is fed the score to maximise.
    scheduler.step(val_iou)
    if val_iou > best_iou:
        best_iou = val_iou
        patience_counter = 0
        torch.save(model.state_dict(), "best_unet_model.pth")
        print(f"✅ 保存新最佳模型 (IoU = {best_iou:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("🛑 Early stopping.")
            break