In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from unet.dataloader import LoveDA
from unet.loss import DiceFocalLoss
from unet.net import UNet
from unet.train import validate_model

In [None]:
# Hyper-parameters shared by the cells below.
lr = 1e-4  # initial learning rate handed to AdamW
batch_size = 32  # batch size used by both the training and validation DataLoader
num_epochs = 20  # upper bound on training epochs; early stopping may end the run sooner

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) so that weight
# initialisation, shuffling, augmentation and the sample picked for visualisation
# start from the same state on every "Restart & Run All".
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
print("Initializing datasets...")

# Per-channel statistics passed to v2.Normalize by both pipelines.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random geometric and photometric augmentation, then normalisation.
train_transform = v2.Compose(
    [
        # `scale` is a fraction of the source image area, so its upper bound is
        # capped at 1.0: a crop larger than the image can never be sampled and
        # such draws only end up in the centre-crop fallback.
        v2.RandomResizedCrop((256, 256), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
        v2.RandomRotation(degrees=(-180.0, 180.0)),
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        v2.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)
# Validation pipeline: deterministic resize and normalisation only.
val_transform = v2.Compose(
    [
        v2.Resize((256, 256), antialias=True),
        v2.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)


def _loveda_split(split, domain, transform):
    """Build the LoveDA dataset stored under dataset/<split>/<domain>/."""
    root = Path("dataset") / split / domain
    return LoveDA(
        root / "images_png_resized",
        root / "masks_png_resized",
        transform=transform,
    )


# Each split is the concatenation of its Rural and Urban parts.
train_rural_dataset = _loveda_split("Train", "Rural", train_transform)
train_urban_dataset = _loveda_split("Train", "Urban", train_transform)
train_dataset = torch.utils.data.ConcatDataset(
    [train_rural_dataset, train_urban_dataset]
)

val_rural_dataset = _loveda_split("Val", "Rural", val_transform)
val_urban_dataset = _loveda_split("Val", "Urban", val_transform)
val_dataset = torch.utils.data.ConcatDataset([val_rural_dataset, val_urban_dataset])

In [None]:
# Wrap both datasets in DataLoaders; only the training loader shuffles its samples.
print("Initializing dataloaders...")
_loader_options = {"batch_size": batch_size, "num_workers": 4}
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, **_loader_options)

In [None]:
print("Initializing model...")
# Pick the compute device once; every later cell moves tensors and modules to it.
has_cuda = torch.cuda.is_available()
print("Using GPU..." if has_cuda else "Using CPU...")
device = torch.device("cuda" if has_cuda else "cpu")
# cuDNN autotuning is only switched on when running on a GPU with cuDNN present.
if has_cuda and torch.backends.cudnn.is_available():
    print("Using cuDNN...")
    torch.backends.cudnn.benchmark = True

In [None]:
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Undo a per-channel normalisation so an image tensor can be displayed.

    Computes ``tensor * std + mean`` channel-wise and returns a new tensor;
    the input is left untouched.

    Args:
        tensor: image tensor whose third-from-last dimension is the channel
            axis, i.e. ``(C, H, W)`` or a batch ``(..., C, H, W)``.
        mean: per-channel means, one entry per channel.
        std: per-channel standard deviations, one entry per channel.

    Returns:
        A tensor of the same shape as ``tensor``. Integer inputs are promoted
        to floating point.

    Raises:
        ValueError: if ``tensor`` has fewer than three dimensions or its
            channel count does not match ``len(mean)`` / ``len(std)``.
    """
    if tensor.dim() < 3:
        raise ValueError(
            f"expected a (..., C, H, W) tensor, got shape {tuple(tensor.shape)}"
        )
    channels = tensor.shape[-3]
    if len(mean) != channels or len(std) != channels:
        raise ValueError(
            f"tensor has {channels} channels but mean/std have "
            f"{len(mean)}/{len(std)} entries"
        )

    # Statistics are built in a floating dtype so integer inputs are promoted
    # rather than having the fractional statistics truncated to zero.
    stats_dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    # Shape (C, 1, 1) broadcasts over H and W and over any leading batch dims.
    mean_t = torch.as_tensor(mean, dtype=stats_dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=stats_dtype, device=tensor.device).view(-1, 1, 1)
    return tensor * std_t + mean_t


def visualize_prediction(model, dataset, device, epoch):
    """
    Plot one random sample next to its ground truth and the model's prediction.

    Args:
        model: network called on a batch holding a single image; the arg-max
            over dim 1 of its output is used as the per-pixel class map.
        dataset: indexable collection whose items unpack to ``(image, mask)``
            tensors. The image is assumed to be a normalised ``(C, H, W)``
            tensor on the CPU — TODO confirm against the dataset class.
        device: device the image is moved to for the forward pass.
        epoch: value shown in the title of the first panel.

    Side effects:
        Shows a three-panel matplotlib figure and leaves ``model`` in training
        mode, whatever mode it was in on entry.
    """
    # Pick a random sample.
    sample_idx = random.randint(0, len(dataset) - 1)
    image, mask = dataset[sample_idx]

    # Forward pass on a batch of one, without tracking gradients.
    model.eval()
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device))
        # Multi-class output: the predicted class is the arg-max over dim 1.
        prediction = torch.argmax(output, dim=1).cpu().numpy()[0]
    # Deliberately unconditional: existing callers may depend on getting the
    # model back in training mode after this call.
    model.train()

    # Undo the normalisation and move channels last for imshow.
    image_display = denormalize(image).permute(1, 2, 0).numpy()

    # Both label maps share one fixed colour scale so that class k is drawn with
    # tab10 colour k in either panel; without vmin/vmax each panel would be
    # rescaled to its own min/max and the colours would not be comparable.
    # Nearest-neighbour rendering keeps class indices from being blended.
    label_style = {"cmap": "tab10", "vmin": 0, "vmax": 9, "interpolation": "nearest"}

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Input image.
    axes[0].imshow(np.clip(image_display, 0, 1))
    axes[0].set_title(f"Original Image (Epoch {epoch})")
    axes[0].axis('off')

    # Ground-truth labels.
    axes[1].imshow(mask.numpy(), **label_style)
    axes[1].set_title("Ground Truth")
    axes[1].axis('off')

    # Predicted labels.
    axes[2].imshow(prediction, **label_style)
    axes[2].set_title("Prediction")
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()
    # This function runs once per epoch; release the figure so they do not pile up.
    plt.close(fig)

In [None]:
# Label values used in the masks:
#   0  invalid (ignored by the loss)
#   1  background
#   2  building
#   3  road
#   4  water
#   5  barren land
#   6  forest
#   7  agricultural land
num_classes = 8  # label 0 is counted, so there are eight classes in total

# Presumably UNet(in_channels, out_classes) — confirm against unet.net.
model = UNet(3, num_classes).to(device)

print("Initializing loss...")
# Label 0 is passed as ignore_index but still counts towards num_classes.
loss = DiceFocalLoss(num_classes=num_classes, ignore_index=0).to(device)

print("Initializing optimizer...")
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)

print("Training...")
# The model was already moved to `device` when it was created.
model.train()
# Halve the learning rate when the monitored value stops decreasing
# (patience of 2 epochs), never going below 1e-7.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=2, factor=0.5, min_lr=1e-7
)

# Early-stopping bookkeeping.
best_val_loss = float("inf")  # lowest validation loss seen so far
patience_counter = 0  # consecutive epochs without improvement
patience_limit = 5  # stop once the counter reaches this value

for epoch in range(num_epochs):
    # Re-assert training mode at the start of every epoch so the loop does not
    # depend on validation/visualisation helpers restoring it as a side effect.
    model.train()
    for batch_idx, (data, target) in enumerate(train_dataloader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Compute the loss and back-propagate.
        loss_value = loss(output, target)
        loss_value.backward()
        # Clip the global gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Log every tenth batch only.
        if batch_idx % 10 == 0:
            print(
                f"Epoch: {epoch}, Batch: {batch_idx}, Training Loss: {loss_value.item():.4f}"
            )

    val_loss = validate_model(model, val_dataloader, loss, device)

    # Let the scheduler react to the validation loss, then read back the LR.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]["lr"]

    # Scientific notation: the scheduler may go down to 1e-7, which a
    # fixed six-decimal format would print as 0.000000.
    print(f"Epoch: {epoch}, Validation Loss: {val_loss:.4f}, LR: {current_lr:.2e}")

    # Show the prediction for one random validation sample.
    visualize_prediction(model, val_dataset, device, epoch)

    # Early stopping: checkpoint on improvement, otherwise count the stale epoch.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Keep the weights of the best model seen so far.
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1

    if patience_counter >= patience_limit:
        print(f"Early stopping at epoch {epoch}")
        break