In [None]:
# main.ipynb

# Core libraries
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import os
import random

# Import the project-local modules this notebook is built on.
# dataset.py, model.py and trainer.py must sit in the same directory as this
# notebook, or in a directory that is on Python's module search path.
try:
    import dataset
    import model
    import trainer
except ImportError as e:
    print(f"Error importing local modules: {e}")
    print("Please ensure dataset.py, model.py, and trainer.py are in the same directory or in PYTHONPATH.")
    # In environments such as Colab, where the files may live on a mounted drive,
    # mount the drive and extend sys.path BEFORE this import block runs, e.g.:
    # import sys
    # sys.path.append('/content/drive/My Drive/your_project_folder') # example
    #
    # Every later cell dereferences these modules, so stop here with the real
    # ImportError rather than letting the next cell die with a NameError.
    raise

# Report the runtime environment and choose the compute device.
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed every random number generator in use so runs are repeatable (optional).
SEED = 42
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(SEED)

if device.type == "cuda":
    # Seed the current GPU and all other GPUs (relevant for multi-GPU runs).
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN settings may cost some performance but help
    # reproducibility; benchmark mode is disabled for the same reason.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
# --- Configuration ---

# Dataset settings
DATA_ROOT = './data_voc'  # root directory where PASCAL VOC is downloaded / stored
# Image size and class count are taken from dataset.py's defaults
# (the original notes give 256 x 256 and 21 classes — confirm in dataset.py).
IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES = (
    dataset.IMG_HEIGHT,
    dataset.IMG_WIDTH,
    dataset.NUM_CLASSES,
)

# Model settings
# The notebook trains SimpleFCN with NUM_CLASSES output classes. Extra options,
# such as a pretrained checkpoint path, could be declared here if needed.

# Training settings
BATCH_SIZE = 8             # tune to the available GPU memory
NUM_EPOCHS = 25            # total epochs (try 5-10 first for a quick test run)
LEARNING_RATE = 1e-4       # optimizer learning rate
WEIGHT_DECAY = 1e-5        # optimizer weight decay (L2 regularisation)
OPTIMIZER_TYPE = 'Adam'    # 'Adam' or 'SGD'
SCHEDULER_TYPE = 'StepLR'  # 'StepLR', 'ReduceLROnPlateau', or None
STEP_LR_STEP_SIZE = 10     # StepLR: lower the learning rate every N epochs
STEP_LR_GAMMA = 0.1        # StepLR: multiplicative learning-rate decay factor

# Checkpoint location
MODEL_SAVE_DIR = './saved_models'
BEST_MODEL_NAME = 'best_fcn_voc.pth'
BEST_MODEL_PATH = os.path.join(MODEL_SAVE_DIR, BEST_MODEL_NAME)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)  # create the directory up front

# Miscellaneous
NUM_WORKERS = 2       # number of DataLoader worker processes
DOWNLOAD_DATA = True  # True on the first run; may be set to False once the data is on disk

In [None]:
# Build the train / validation loaders and preview one training sample.
print("Loading PASCAL VOC 2012 dataset...")
train_loader, val_loader = dataset.get_dataloaders(
    data_root=DATA_ROOT,
    download_data=DOWNLOAD_DATA,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Guard clause: nothing below can work without data, so halt the notebook here.
if not (train_loader and val_loader):
    print("Failed to load dataloaders. Please check dataset path and download status.")
    raise RuntimeError("Data loading failed. Exiting notebook.")

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

# Show one training image next to its ground-truth mask (optional).
# Best effort: a failure here is reported but does not stop the notebook.
print("\nVisualizing a sample from training data...")
try:
    sample_images, sample_masks = next(iter(train_loader))
    img_pil = dataset.tensor_to_pil(sample_images[0])
    mask_pil = dataset.mask_to_pil_color(sample_masks[0])

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    previews = ((img_pil, "Sample Image"), (mask_pil, "Sample Ground Truth Mask"))
    for ax, (picture, caption) in zip(axes, previews):
        ax.imshow(picture)
        ax.set_title(caption)
        ax.axis('off')
    plt.show()
except Exception as e:
    print(f"Error during sample visualization: {e}")

In [None]:
# Build the segmentation network and move it to the selected device.
print(f"Initializing model: SimpleFCN with {NUM_CLASSES} classes")
seg_model = model.SimpleFCN(num_classes=NUM_CLASSES, init_weights=True).to(device)

# Report parameter counts (optional); uncomment the print to dump the layers.
# print(seg_model)
total_params = sum(p.numel() for p in seg_model.parameters())
trainable_params = sum(p.numel() for p in seg_model.parameters() if p.requires_grad)
print(f"Total model parameters: {total_params:,}")
print(f"Trainable model parameters: {trainable_params:,}")

# Loss function.
# Per the original author's note, dataset.py remaps the PASCAL VOC boundary
# label 255 to background class 0, so targets should lie in
# [0, NUM_CLASSES - 1] — verify against dataset.py. CrossEntropyLoss skips
# target pixels equal to its ignore_index; if the remapping is ever removed,
# construct the loss with ignore_index=255 instead.
criterion = nn.CrossEntropyLoss()
print(f"Using loss function: {criterion.__class__.__name__}")

# Optimizer, selected by OPTIMIZER_TYPE (case-insensitive).
if OPTIMIZER_TYPE.lower() == 'adam':
    optimizer = optim.Adam(seg_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
elif OPTIMIZER_TYPE.lower() == 'sgd':
    optimizer = optim.SGD(seg_model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=WEIGHT_DECAY)
else:
    raise ValueError(f"Unsupported optimizer type: {OPTIMIZER_TYPE}")
print(f"Using optimizer: {optimizer.__class__.__name__} with LR={LEARNING_RATE}")

# Learning-rate scheduler (optional). An unrecognised name is reported and
# training continues without a scheduler.
scheduler = None
if SCHEDULER_TYPE:
    if SCHEDULER_TYPE.lower() == 'steplr':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_LR_STEP_SIZE, gamma=STEP_LR_GAMMA)
        print(f"Using StepLR scheduler: step_size={STEP_LR_STEP_SIZE}, gamma={STEP_LR_GAMMA}")
    elif SCHEDULER_TYPE.lower() == 'reducelronplateau':
        # Driven by a validation metric; mode='max' because a higher value is
        # better (the message below names validation MIoU).
        # The `verbose` keyword is intentionally not passed: it is deprecated
        # (since PyTorch 2.2) and may be rejected by newer releases. Read the
        # current rate from optimizer.param_groups[0]['lr'] if it must be logged.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3)
        print(f"Using ReduceLROnPlateau scheduler (monitors val MIoU).")
    else:
        print(f"Scheduler type '{SCHEDULER_TYPE}' not recognized or None. No scheduler will be used.")

In [None]:
# Wire the model, data, loss, optimizer and scheduler into the trainer.
segmentation_trainer = trainer.SemanticSegmenterTrainer(
    model=seg_model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,  # may be None when no scheduler is configured
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=NUM_CLASSES,
    device=device,
)

# Run the training loop; the trainer is given the path for the best checkpoint.
print("\n--- Starting Training ---")
segmentation_trainer.train(num_epochs=NUM_EPOCHS, model_save_path_best=BEST_MODEL_PATH)
print("--- Training Finished ---")

# Left panel: training and validation loss per epoch (optional).
fig = plt.figure(figsize=(12, 5))
loss_ax = fig.add_subplot(1, 2, 1)
loss_ax.plot(segmentation_trainer.train_losses, label='Train Loss')
loss_ax.plot(segmentation_trainer.val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Training and Validation Loss')

# Right panel: validation mean IoU per epoch, drawn only when the trainer
# recorded any validation metrics (optional).
if segmentation_trainer.val_metrics_history:
    val_mious = [epoch_metrics['MeanIoU'] for epoch_metrics in segmentation_trainer.val_metrics_history]
    miou_ax = fig.add_subplot(1, 2, 2)
    miou_ax.plot(val_mious, label='Validation MIoU')
    miou_ax.set_xlabel('Epoch')
    miou_ax.set_ylabel('MIoU')
    miou_ax.legend()
    miou_ax.set_title('Validation Mean IoU')

fig.tight_layout()
plt.show()

In [None]:
# Evaluate the best checkpoint on the validation set.
# NOTE(review): the original notes say trainer.evaluate() / trainer.train() may
# already load the best weights themselves — that lives in trainer.py and is not
# visible here, so the weights are loaded explicitly from disk below.
print("\n--- Evaluating Best Model on Validation Set ---")

# A more explicit alternative is to rebuild the model and load the checkpoint:
# best_seg_model = model.SimpleFCN(num_classes=NUM_CLASSES).to(device)
# best_seg_model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
# evaluation_trainer = trainer.SemanticSegmenterTrainer(model=best_seg_model, ...) # or reuse segmentation_trainer
# segmentation_trainer.model = best_seg_model # swap the model held by the trainer

# Load the checkpoint into the trainer's model when it exists; otherwise fall
# back to whatever weights the model currently holds.
if os.path.exists(BEST_MODEL_PATH):
    print(f"Loading best model weights from: {BEST_MODEL_PATH}")
    segmentation_trainer.model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
else:
    print(f"Warning: Best model path {BEST_MODEL_PATH} not found. Evaluating with current model weights.")


# Run the evaluation on the validation loader.
# Note: PASCAL VOC 2012 is commonly reported on the 'val' split. With a separate
# test set, build a DataLoader for it and pass that to evaluate() instead.
val_metrics, val_confusion_matrix, val_viz_samples = segmentation_trainer.evaluate(
    data_loader=val_loader,
    checkpoint_path=None  # weights were already loaded above
)

# Format the score only when it is present: applying ':.4f' to the 'N/A'
# fallback string would raise ValueError instead of printing the placeholder.
final_miou = val_metrics.get('MeanIoU')
final_miou_text = 'N/A' if final_miou is None else f"{final_miou:.4f}"
print(f"\nFinal Validation MIoU: {final_miou_text}")

# (Optional) print the confusion matrix
# print("\nValidation Confusion Matrix:")
# print(val_confusion_matrix)

In [None]:
# Show image / ground truth / prediction triplets returned by the evaluation.
print("\n--- Visualizing Evaluation Samples ---")
eval_images, eval_gt_masks, eval_pred_masks = val_viz_samples

# Colour map from dataset.py as a tensor on the active device.
# NOTE(review): this tensor is not used anywhere else in this cell — confirm
# whether it is still needed.
voc_colormap_torch = dataset.voc_colormap_to_tensor(dataset.VOC_COLORMAP).to(device)

num_samples_to_show = min(len(eval_images), 5)  # show at most 5 samples

if num_samples_to_show <= 0:
    print("No samples available for visualization from evaluation.")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # axes[i][j] indexing works without special-casing one sample.
    fig, axes = plt.subplots(
        num_samples_to_show, 3, figsize=(12, num_samples_to_show * 4), squeeze=False
    )

    for i in range(num_samples_to_show):
        img_pil = dataset.tensor_to_pil(eval_images[i])
        gt_mask_pil = dataset.mask_to_pil_color(eval_gt_masks[i], colormap=dataset.VOC_COLORMAP)
        # Per the original note, the predicted masks already hold class indices.
        pred_mask_pil = dataset.mask_to_pil_color(eval_pred_masks[i], colormap=dataset.VOC_COLORMAP)

        row_panels = (
            (img_pil, f"Image {i+1}"),
            (gt_mask_pil, f"Ground Truth {i+1}"),
            (pred_mask_pil, f"Prediction {i+1}"),
        )
        for ax, (picture, caption) in zip(axes[i], row_panels):
            ax.imshow(picture)
            ax.set_title(caption)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

print("\n--- End of Notebook ---")