# TCN模型训练与评估

本notebook实现了基于时序卷积网络(TCN)的农作物制图模型训练、评估和推理流程。

## 1. 环境设置与导入

In [1]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import json
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# 导入自定义模块
from TCN.model import create_tcn_model
from TCN.dataset import prepare_data, save_data_info
from TCN.utils import (
    save_checkpoint, 
    calculate_metrics,
    plot_training_history,
    plot_confusion_matrix,
    plot_class_performance
)
from TCN.train import FocalLoss

# 设置matplotlib中文显示
plt.rcParams['font.sans-serif'] = ['PingFang SC', 'SimHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.7.1
CUDA available: False


## 2. 配置参数

In [2]:
# 训练配置

# 首先，确定可用的最佳设备（GPU > MPS > CPU）
if torch.cuda.is_available():
    device = 'cuda'
# 检查macOS的MPS后端是否可用
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

config = {
    # 数据参数
    'data_path': '../dataset',
    'patch_size': 32,      # 从 64 降低到 32
    'stride': 16,
    'test_size': 0.2,
    'val_size': 0.1,
    'batch_size': 8,       # 从 16/8 降低到 4
    'num_workers': 2,      # 从 4 降低到 2
    
    # 模型参数
    'input_channels': 8,
    'temporal_steps': 28,
    'tcn_channels': [64, 128, 256],
    'kernel_size': 3,
    'dropout': 0.2,
    
    # 训练参数
    'epochs': 50,
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'use_focal_loss': True,
    
    # 其他参数
    'save_dir': './checkpoints',
    'device': device
}

print("Training configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

# 创建保存目录
Path(config['save_dir']).mkdir(parents=True, exist_ok=True)

Training configuration:
  data_path: ../dataset
  patch_size: 32
  stride: 16
  test_size: 0.2
  val_size: 0.1
  batch_size: 8
  num_workers: 2
  input_channels: 8
  temporal_steps: 28
  tcn_channels: [64, 128, 256]
  kernel_size: 3
  dropout: 0.2
  epochs: 50
  learning_rate: 0.001
  weight_decay: 0.0001
  use_focal_loss: True
  save_dir: ./checkpoints
  device: mps


## 3. 数据加载与预处理

In [3]:
# 准备数据
print("Preparing data loaders...")
train_loader, val_loader, test_loader, data_info = prepare_data(
    data_path=config['data_path'],
    patch_size=config['patch_size'],
    stride=config['stride'],
    test_size=config['test_size'],
    val_size=config['val_size'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers']
)

# 保存数据信息
save_data_info(data_info, f"{config['save_dir']}/data_info.pkl")

print(f"Number of classes: {data_info['num_classes']}")
print(f"Class names: {data_info['class_names']}")
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
print(f"Class weights: {data_info['class_weights']}")

Preparing data loaders...
Original data shape - X: (326, 1025, 28, 8), Y: (326, 1025)
Dataset split - Train: 837, Val: 120, Test: 240
Data info saved to ./checkpoints/data_info.pkl
Number of classes: 9
Class names: {0: 'Background', 1: 'Corn', 2: 'Wheat', 3: 'Sunflower', 4: 'Pumpkin', 5: 'Artificial_Surface', 6: 'Water', 7: 'Road', 8: 'Other'}
Train samples: 837
Validation samples: 120
Test samples: 240
Class weights: tensor([1.8564e+04, 2.0179e-01, 5.6892e+00, 6.0102e-01, 3.8582e+00, 4.9530e+00,
        3.9785e+00, 1.5193e+00, 1.1988e+00])


## 4. 创建模型

In [4]:
# 创建模型
device = torch.device(config['device'])
print(f"Using device: {device}")

model = create_tcn_model(
    input_channels=config['input_channels'],
    temporal_steps=config['temporal_steps'],
    num_classes=data_info['num_classes'],
    tcn_channels=config['tcn_channels'],
    kernel_size=config['kernel_size'],
    dropout=config['dropout']
).to(device)

# 打印模型信息
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model created successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# 测试模型前向传播
dummy_input = torch.randn(2, 64, 64, 28, 8).to(device)
with torch.no_grad():
    output = model(dummy_input)
print(f"Test input shape: {dummy_input.shape}")
print(f"Test output shape: {output.shape}")

Using device: mps
Model created successfully!
Total parameters: 459,337
Trainable parameters: 459,337
Test input shape: torch.Size([2, 64, 64, 28, 8])
Test output shape: torch.Size([2, 64, 64, 9])


## 5. 设置损失函数和优化器

In [5]:
# 损失函数
if config['use_focal_loss']:
    criterion = FocalLoss(alpha=data_info['class_weights'].to(device), gamma=2.0)
    print("Using Focal Loss")
else:
    criterion = nn.CrossEntropyLoss(weight=data_info['class_weights'].to(device))
    print("Using Cross Entropy Loss")

# 优化器
optimizer = optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

# 学习率调度器
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2
)

print(f"Optimizer: AdamW (lr={config['learning_rate']}, weight_decay={config['weight_decay']})")
print(f"Scheduler: CosineAnnealingWarmRestarts")

Using Focal Loss
Optimizer: AdamW (lr=0.001, weight_decay=0.0001)
Scheduler: CosineAnnealingWarmRestarts


## 6. 训练函数定义

In [6]:
from tqdm.notebook import tqdm
from torch.cuda.amp import GradScaler, autocast

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, scaler):
    """训练一个epoch"""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch} - Training')
    
    for batch_idx, (data, targets) in enumerate(pbar):
        data, targets = data.to(device), targets.to(device)
        
        optimizer.zero_grad()
        
        # 混合精度
        with autocast():
            outputs = model(data)
            outputs = outputs.permute(0, 3, 1, 2)  # (batch, num_classes, height, width)
            loss = criterion(outputs, targets)
        
        # 反向传播
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # 统计
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.numel()
        correct += predicted.eq(targets).sum().item()
        
        # 更新进度条
        accuracy = 100. * correct / total
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{accuracy:.2f}%'
        })
    
    epoch_loss = total_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device, num_classes):
    """验证模型"""
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        
        for data, targets in pbar:
            data, targets = data.to(device), targets.to(device)
            
            # 混合精度验证
            with autocast():
                outputs = model(data)
                outputs = outputs.permute(0, 3, 1, 2)
                loss = criterion(outputs, targets)
            
            total_loss += loss.item()
            
            _, predicted = outputs.max(1)
            all_preds.append(predicted.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
    
    # 计算指标
    all_preds = np.concatenate(all_preds).flatten()
    all_targets = np.concatenate(all_targets).flatten()
    
    metrics = calculate_metrics(all_preds, all_targets, num_classes)
    val_loss = total_loss / len(val_loader)
    
    return val_loss, metrics['overall_accuracy'], metrics

print("Training functions defined successfully!")

Training functions defined successfully!


## 7. 模型训练

In [None]:
# 训练历史记录
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_miou': [],
    'learning_rates': []
}

# 混合精度训练scaler
scaler = GradScaler()

best_miou = 0
patience = 10
patience_counter = 0

print(f"Starting training for {config['epochs']} epochs...")
print("="*60)

for epoch in range(1, config['epochs'] + 1):
    print(f"\nEpoch {epoch}/{config['epochs']}")
    print("-"*40)
    
    # 训练
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device, epoch, scaler
    )
    
    # 验证
    val_loss, val_acc, val_metrics = validate(
        model, val_loader, criterion, device, dat_info['num_classes']
    )
    
    # 学习率调度
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    
    # 记录历史
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_miou'].append(val_metrics['mean_iou'])
    history['learning_rates'].append(current_lr)
    
    # 打印结果
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"Val mIoU: {val_metrics['mean_iou']:.4f}")
    print(f"Learning Rate: {current_lr:.6f}")
    
    # 保存最佳模型
    if val_metrics['mean_iou'] > best_miou:
        best_miou = val_metrics['mean_iou']
        patience_counter = 0
        
        save_checkpoint({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_miou': best_miou,
            'config': config,
            'metrics': val_metrics
        }, f"{config['save_dir']}/best_model.pth")
        
        print(f"🎉 New best model saved! (mIoU: {best_miou:.4f})")
    else:
        patience_counter += 1
    
    # 早停检查
    if patience_counter >= patience:
        print(f"\n⏰ Early stopping triggered after {patience} epochs without improvement")
        break

print(f"\n✅ Training completed! Best mIoU: {best_miou:.4f}")

Starting training for 50 epochs...

Epoch 1/50
----------------------------------------


Epoch 1 - Training:   0%|          | 0/105 [00:00<?, ?it/s]

NameError: name 'dat' is not defined

## 8. 训练历史可视化

In [None]:
# 绘制训练历史
plot_training_history(history, save_path=f"{config['save_dir']}/training_curves.png")

# 保存训练历史
with open(f"{config['save_dir']}/history.json", 'w') as f:
    json.dump(history, f, indent=4)

print("Training history plots generated and saved!")

## 9. 模型评估

In [None]:
# 加载最佳模型
checkpoint = torch.load(f"{config['save_dir']}/best_model.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model (epoch {checkpoint['epoch']}, mIoU: {checkpoint['best_miou']:.4f})")

# 在测试集上评估
test_loss, test_acc, test_metrics = validate(
    model, test_loader, criterion, device, data_info['num_classes']
)

print(f"\nTest Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")
print(f"Test mIoU: {test_metrics['mean_iou']:.4f}")
print(f"Mean Class Accuracy: {test_metrics['mean_accuracy']:.4f}")

# 详细的每类指标
print(f"\nPer-class metrics:")
for i, class_name in data_info['class_names'].items():
    if i < len(test_metrics['per_class_accuracy']):
        acc = test_metrics['per_class_accuracy'][i]
        iou = test_metrics['per_class_iou'][i]
        print(f"  {class_name}: Acc={acc:.3f}, IoU={iou:.3f}")

## 10. 评估结果可视化

In [None]:
# 绘制混淆矩阵
class_names_list = [data_info['class_names'][i] for i in range(data_info['num_classes'])]
plot_confusion_matrix(
    test_metrics['confusion_matrix'],
    class_names_list,
    save_path=f"{config['save_dir']}/confusion_matrix.png"
)

# 绘制类别性能图
plot_class_performance(
    test_metrics, 
    save_path=f"{config['save_dir']}/class_performance.png"
)

## 11. 预测样本可视化

In [None]:
from TCN.utils import visualize_predictions

# 获取一批测试数据
data_iter = iter(test_loader)
test_batch_data, test_batch_targets = next(data_iter)

# 预测
model.eval()
with torch.no_grad():
    test_batch_data_device = test_batch_data.to(device)
    test_outputs = model(test_batch_data_device)
    _, test_predictions = test_outputs.max(-1)

# 可视化前5个样本的预测结果
visualize_predictions(
    test_batch_data[:5],
    test_batch_targets[:5],
    test_predictions[:5].cpu(),
    class_names_list,
    num_samples=5,
    save_path=f"{config['save_dir']}/predictions_visualization.png"
)

## 12. 模型推理示例

In [None]:
from TCN.inference import CropMappingInference

# 初始化推理器
inferencer = CropMappingInference(
    model_path=f"{config['save_dir']}/best_model.pth",
    device=config['device']
)

# 加载完整数据集进行推理演示（使用较小的区域）
full_x_data = np.load('../dataset/x.npy')
full_y_data = np.load('../dataset/y.npy')

# 选择一个较小的区域进行快速推理演示
crop_size = 128
x_crop = full_x_data[:crop_size, :crop_size]
y_crop = full_y_data[:crop_size, :crop_size]

print(f"Running inference on {x_crop.shape[:2]} region...")

# 进行推理
predictions = inferencer.predict_full_image(
    x_crop,
    patch_size=32,  # 使用较小的patch以加快推理
    overlap=8,
    batch_size=4
)

# 可视化推理结果
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# 显示RGB合成图（使用第一个时间步的前三个波段）
rgb = x_crop[:, :, 0, :3]
rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min())
axes[0].imshow(rgb)
axes[0].set_title('Input (RGB Composite)')
axes[0].axis('off')

# 显示真实标签
im1 = axes[1].imshow(y_crop, cmap='tab10', vmin=0, vmax=8)
axes[1].set_title('Ground Truth')
axes[1].axis('off')

# 显示预测结果
axes[2].imshow(predictions, cmap='tab10', vmin=0, vmax=8)
axes[2].set_title('TCN Predictions')
axes[2].axis('off')

# 添加颜色条
cbar = plt.colorbar(im1, ax=axes, orientation='horizontal', fraction=0.05, pad=0.1)
cbar.set_ticks(range(len(class_names_list)))
cbar.set_ticklabels([name.split('(')[0].strip() if '(' in name else name for name in class_names_list])

plt.tight_layout()
plt.savefig(f"{config['save_dir']}/inference_example.png", dpi=300, bbox_inches='tight')
plt.show()

# 计算推理准确率
inference_acc = np.mean(predictions == y_crop)
print(f"\nInference accuracy on cropped region: {inference_acc:.4f}")

# 统计预测分布
unique, counts = np.unique(predictions, return_counts=True)
print("\nPrediction distribution:")
for label, count in zip(unique, counts):
    percentage = count / predictions.size * 100
    class_name = data_info['class_names'].get(label, f"Unknown({label})")
    print(f"  {class_name}: {count:,} pixels ({percentage:.1f}%)")

## 13. 总结与结论

In [None]:
# 生成最终总结报告
summary_report = f"""
=== TCN模型训练总结报告 ===
训练时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
设备: {device}

模型配置:
- 输入通道数: {config['input_channels']}
- 时间步数: {config['temporal_steps']}
- TCN通道: {config['tcn_channels']}
- 类别数: {data_info['num_classes']}
- 总参数量: {total_params:,}

训练配置:
- 训练轮数: {len(history['train_loss'])}
- 批次大小: {config['batch_size']}
- 学习率: {config['learning_rate']}
- 切片大小: {config['patch_size']}

最佳性能 (验证集):
- mIoU: {best_miou:.4f}
- 准确率: {max(history['val_acc']):.2f}%

测试集性能:
- mIoU: {test_metrics['mean_iou']:.4f}
- 总体准确率: {test_metrics['overall_accuracy']:.4f}
- 平均类别准确率: {test_metrics['mean_accuracy']:.4f}

文件输出:
- 最佳模型: {config['save_dir']}/best_model.pth
- 训练历史: {config['save_dir']}/history.json
- 数据信息: {config['save_dir']}/data_info.pkl
- 训练曲线: {config['save_dir']}/training_curves.png
- 混淆矩阵: {config['save_dir']}/confusion_matrix.png

=== 报告结束 ===
"""

print(summary_report)

# 保存报告
with open(f"{config['save_dir']}/training_summary.txt", 'w', encoding='utf-8') as f:
    f.write(summary_report)

print(f"\n🎯 训练完成！所有结果已保存到: {config['save_dir']}/")
print(f"📊 最佳模型mIoU: {best_miou:.4f}")
print(f"📈 测试集mIoU: {test_metrics['mean_iou']:.4f}")