# Vision Transformer 农作物制图训练和评估

本 notebook 演示如何使用 Vision Transformer 模型，对多光谱时序数据进行逐像素的农作物分类（农作物制图），内容涵盖数据加载、模型训练、评估、可视化与推理。

## 目录
1. [环境设置和导入](#环境设置和导入)
2. [数据加载和探索](#数据加载和探索)
3. [模型创建和配置](#模型创建和配置)
4. [训练过程](#训练过程)
5. [模型评估](#模型评估)
6. [结果可视化](#结果可视化)
7. [模型推理](#模型推理)

## 1. 环境设置和导入

In [1]:
# 导入必要的库
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import time
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Matplotlib: font fallbacks so CJK text renders in figures, and an ASCII
# minus sign so negative tick labels do not show up as missing glyphs.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Fixed seeds for the NumPy and PyTorch random number generators.
np.random.seed(42)
torch.manual_seed(42)

# Report the runtime environment.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {cuda_ok}")
if cuda_ok:
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU设备: {torch.cuda.get_device_name(0)}")
    print(f"GPU内存: {gpu_total_bytes / 1e9:.1f} GB")

PyTorch版本: 2.7.1
CUDA可用: False


In [2]:
# Import the project-local Transformer package.
import sys
# Make the parent directory importable. The entry is relative, so it is
# resolved against the current working directory of the kernel.
sys.path.append('..')

# Model factory and model class.
from Transformer.model import create_transformer_model, CropMappingTransformer
# Data preparation and (de)serialisation of the dataset description.
from Transformer.dataset import prepare_data, save_data_info, load_data_info
# Per-epoch train/validate routines and the loss functions.
from Transformer.train import (
    train_epoch, validate, FocalLoss, DiceLoss, CombinedLoss
)
# Checkpointing, early stopping, metrics and plotting helpers.
from Transformer.utils import (
    save_checkpoint, load_checkpoint, EarlyStopping,
    calculate_metrics, plot_training_history, plot_class_performance,
    plot_confusion_matrix, visualize_predictions
)

print("✅ Transformer模块导入成功!")

✅ Transformer模块导入成功!


## 2. 数据加载和探索

In [3]:
# Run configuration: one flat dict, listed in the order it is printed below.
config = dict(
    # --- data ---
    data_path='../dataset',
    patch_size=64,
    stride=32,
    test_size=0.2,
    val_size=0.1,
    batch_size=8,  # Transformers usually need a smaller batch size
    num_workers=4,

    # --- model ---
    input_channels=8,
    temporal_steps=28,
    model_patch_size=8,  # patch size used inside the Transformer
    embed_dim=256,
    num_layers=6,
    num_heads=8,
    mlp_ratio=4.0,
    dropout=0.1,

    # --- training ---
    epochs=50,  # kept small for the demo
    learning_rate=1e-4,  # Transformers usually use a small learning rate
    weight_decay=1e-4,
    gradient_accumulation_steps=2,
    max_grad_norm=1.0,

    # --- loss ---
    use_combined_loss=True,
    focal_gamma=2.5,
    label_smoothing=0.1,

    # --- learning-rate schedule ---
    scheduler_type='cosine',
    min_lr=1e-6,

    # --- misc ---
    patience=10,
    save_dir='./checkpoints_notebook',
    augment_train=True,
)

# Make sure the output directory exists before anything is written into it.
Path(config['save_dir']).mkdir(parents=True, exist_ok=True)

# Echo the full configuration, one "name: value" line per entry.
print("📋 训练配置:")
for key, value in config.items():
    print(f"  {key}: {value}")

📋 训练配置:
  data_path: ../dataset
  patch_size: 64
  stride: 32
  test_size: 0.2
  val_size: 0.1
  batch_size: 8
  num_workers: 4
  input_channels: 8
  temporal_steps: 28
  model_patch_size: 8
  embed_dim: 256
  num_layers: 6
  num_heads: 8
  mlp_ratio: 4.0
  dropout: 0.1
  epochs: 50
  learning_rate: 0.0001
  weight_decay: 0.0001
  gradient_accumulation_steps: 2
  max_grad_norm: 1.0
  use_combined_loss: True
  focal_gamma: 2.5
  label_smoothing: 0.1
  scheduler_type: cosine
  min_lr: 1e-06
  patience: 10
  save_dir: ./checkpoints_notebook
  augment_train: True


In [None]:
# Build the train/val/test loaders. Every argument of prepare_data is taken
# from the config entry of the same name.
print("🔄 加载数据...")
data_arg_names = (
    'data_path', 'patch_size', 'stride', 'test_size',
    'val_size', 'batch_size', 'num_workers', 'augment_train',
)
train_loader, val_loader, test_loader, data_info = prepare_data(
    **{arg_name: config[arg_name] for arg_name in data_arg_names}
)

# Dataset summary.
print(f"\n📊 数据集信息:")
print(f"  类别数量: {data_info['num_classes']}")
print(f"  训练批次: {len(train_loader)}")
print(f"  验证批次: {len(val_loader)}")
print(f"  测试批次: {len(test_loader)}")
print(f"  输入形状: {data_info['input_shape']}")

# Persist the dataset description next to the checkpoints.
data_info_path = f"{config['save_dir']}/data_info.pkl"
save_data_info(data_info, data_info_path)

# Class index -> class name listing.
print(f"\n🏷️ 类别信息:")
for idx, name in data_info['class_names'].items():
    print(f"  {idx}: {name}")

# Per-class weights, printed as a NumPy array.
class_weights_np = data_info['class_weights'].numpy()
print(f"\n⚖️ 类别权重: {class_weights_np}")

In [None]:
# Visual sanity check of one training batch: input previews and label maps.
# Pull a single batch from the training loader.
data_iter = iter(train_loader)
sample_batch_x, sample_batch_y = next(data_iter)

print(f"样本批次形状 - X: {sample_batch_x.shape}, Y: {sample_batch_y.shape}")
print(f"X数值范围: [{sample_batch_x.min():.3f}, {sample_batch_x.max():.3f}]")
print(f"Y唯一值: {torch.unique(sample_batch_y)}")

# The 2x4 grids below hold at most 8 samples.
num_shown = min(8, sample_batch_x.shape[0])

# --- input previews ---
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.ravel()

for i in range(num_shown):
    # First three channels of time step 0, displayed as an RGB composite.
    # NOTE(review): the indexing assumes x is laid out as
    # (batch, H, W, time, channel) and that channels 0-2 are meaningful as
    # RGB -- confirm against the dataset.
    rgb_img = sample_batch_x[i, :, :, 0, :3].numpy()
    value_range = rgb_img.max() - rgb_img.min()
    if value_range > 0:
        rgb_img = (rgb_img - rgb_img.min()) / value_range
    else:
        # Constant (e.g. fully masked) patch: min-max scaling would divide by
        # zero and fill the image with NaN, so show it as black instead.
        rgb_img = np.zeros_like(rgb_img)

    axes[i].imshow(rgb_img)
    axes[i].set_title(f'Sample {i+1} - RGB (t=0)')
    axes[i].axis('off')

# Blank the panels left over when the batch has fewer than 8 samples.
for unused_ax in axes[num_shown:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

# --- matching label maps ---
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.ravel()

# Colour range follows the number of classes instead of a hard-coded 7, so
# every class index keeps its own colour if the class count changes.
label_vmax = max(data_info['num_classes'] - 1, 1)

for i in range(min(8, sample_batch_y.shape[0])):
    axes[i].imshow(sample_batch_y[i].numpy(), cmap='tab10', vmin=0, vmax=label_vmax)
    axes[i].set_title(f'Sample {i+1} - Labels')
    axes[i].axis('off')

for unused_ax in axes[min(8, sample_batch_y.shape[0]):]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## 3. 模型创建和配置

In [None]:
# Choose the compute backend: CUDA first, then Apple MPS, otherwise CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f"🖥️ 使用设备: {device}")

# Build the Transformer from the config and move it to the chosen device.
model = create_transformer_model(
    num_classes=data_info['num_classes'],
    patch_size=config['model_patch_size'],
    input_channels=config['input_channels'],
    temporal_steps=config['temporal_steps'],
    embed_dim=config['embed_dim'],
    num_layers=config['num_layers'],
    num_heads=config['num_heads'],
    mlp_ratio=config['mlp_ratio'],
    dropout=config['dropout']
)
model = model.to(device)

# Parameter counts: all parameters, and those with requires_grad set.
param_sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

# Size estimate at 4 bytes per parameter (FP32), in MiB.
fp32_size_mb = total_params * 4 / (1024**2)

print(f"\n🔧 模型信息:")
print(f"  总参数量: {total_params:,}")
print(f"  可训练参数: {trainable_params:,}")
print(f"  模型大小: {fp32_size_mb:.1f} MB (FP32)")

# Smoke test: run two samples of the previously fetched batch through the
# model in eval mode, without tracking gradients.
model.eval()
with torch.no_grad():
    test_input = sample_batch_x[:2].to(device)
    test_output = model(test_input)
    print(f"\n🧪 前向传播测试:")
    print(f"  输入形状: {test_input.shape}")
    print(f"  输出形状: {test_output.shape}")
    print(f"  输出数值范围: [{test_output.min():.3f}, {test_output.max():.3f}]")

print("\n✅ 模型创建成功!")

In [None]:
# --- loss ---
# Both loss variants receive the per-class weights on the training device.
class_weights = data_info['class_weights'].to(device)

if not config['use_combined_loss']:
    criterion = nn.CrossEntropyLoss(
        weight=class_weights,
        label_smoothing=config['label_smoothing']
    )
    print("🎯 使用交叉熵损失函数")
else:
    criterion = CombinedLoss(
        alpha=class_weights,
        gamma=config['focal_gamma'],
        label_smoothing=config['label_smoothing']
    )
    print("🎯 使用组合损失函数 (Focal + Dice + CrossEntropy)")

# --- optimizer ---
adamw_options = {
    'lr': config['learning_rate'],
    'weight_decay': config['weight_decay'],
    'betas': (0.9, 0.999),
    'eps': 1e-8,
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_options)

print(f"⚙️ 优化器: AdamW (lr={config['learning_rate']}, wd={config['weight_decay']})")

# --- learning-rate schedule ---
# Only 'cosine' is recognised; any other value runs without a scheduler.
scheduler = None
if config['scheduler_type'] == 'cosine':
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
    # First cycle lasts 10 scheduler steps, each later cycle twice as long.
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2,
        eta_min=config['min_lr']
    )
    print(f"📈 学习率调度器: CosineAnnealingWarmRestarts")
else:
    print("📈 无学习率调度器")

# Mixed-precision gradient scaling.
# Loss scaling is only active on CUDA. On CPU/MPS the scaler is created
# disabled, which per the PyTorch docs turns scale()/update() into no-ops and
# makes step() call optimizer.step() directly, so downstream code can keep
# using `scaler` unconditionally.
use_amp = torch.cuda.is_available()
try:
    # Device-agnostic entry point; replaces the deprecated torch.cuda.amp class.
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
except AttributeError:
    # Older PyTorch releases only ship the CUDA-namespaced class.
    from torch.cuda.amp import GradScaler
    scaler = GradScaler(enabled=use_amp)

# Report the real state instead of always claiming mixed precision is on.
if use_amp:
    print("⚡ 启用混合精度训练")
else:
    print("⚡ 未检测到CUDA, 混合精度训练已禁用")

# Early stopping on validation loss.
# The training loop writes its full checkpoint dict (selected by best mIoU) to
# `best_model.pth`. EarlyStopping is given a save path and is later called
# with (val_loss, model), so it presumably writes the model to that path too
# -- NOTE(review): confirm against Transformer.utils.EarlyStopping. Pointing
# both writers at the same file would let one overwrite the other, so early
# stopping gets its own file.
early_stopping = EarlyStopping(
    patience=config['patience'],
    verbose=True,
    save_path=f"{config['save_dir']}/early_stopping_model.pth"
)
print(f"⏹️ 早停机制 (patience={config['patience']})")

print("\n✅ 训练组件配置完成!")

## 4. 训练过程

In [None]:
# Per-epoch training history: one independent, initially empty list per series.
history = {
    series_name: []
    for series_name in (
        'train_loss', 'train_acc', 'val_loss',
        'val_acc', 'val_miou', 'learning_rates',
    )
}

# Best validation mIoU seen so far, and the wall-clock start of training.
best_miou = 0.0
start_time = time.time()

# Banner announcing the run.
print(f"🚀 开始训练 Transformer 模型 ({config['epochs']} epochs)")
print(f"⏰ 开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 80)

In [None]:
# Main training loop: train, validate, step the LR schedule, record history,
# checkpoint on best validation mIoU, and stop early on stalled validation loss.
for epoch in range(1, config['epochs'] + 1):
    epoch_start_time = time.time()
    
    print(f"\nEpoch {epoch}/{config['epochs']}")
    print("-" * 50)
    
    # Training phase. The scheduler is deliberately not handed to train_epoch;
    # it is stepped once per epoch further down.
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device, scaler, epoch,
        scheduler=None,
        scheduler_step_per_batch=False,
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm']
    )
    
    # Validation phase.
    val_loss, val_acc, val_metrics = validate(
        model, val_loader, criterion, device, data_info['num_classes']
    )
    
    # Learning rate that was in effect for THIS epoch. It must be read before
    # scheduler.step(), which replaces it with the next epoch's value;
    # reading it afterwards shifts the logged curve by one epoch.
    current_lr = optimizer.param_groups[0]['lr']
    
    # Advance the learning-rate schedule by one epoch.
    if scheduler is not None:
        scheduler.step()
    
    # Record history.
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_miou'].append(val_metrics['mean_iou'])
    history['learning_rates'].append(current_lr)
    
    # Wall-clock duration of this epoch.
    epoch_time = time.time() - epoch_start_time
    
    # Report the epoch.
    print(f"Epoch {epoch} 结果:")
    print(f"  训练 - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
    print(f"  验证 - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
    print(f"  验证 mIoU: {val_metrics['mean_iou']:.4f}")
    print(f"  学习率: {current_lr:.2e}")
    print(f"  耗时: {epoch_time:.1f}s")
    
    # Checkpoint whenever validation mIoU improves.
    if val_metrics['mean_iou'] > best_miou:
        best_miou = val_metrics['mean_iou']
        save_checkpoint({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'scaler_state_dict': scaler.state_dict(),
            'best_miou': best_miou,
            'config': config,
            'metrics': val_metrics
        }, f"{config['save_dir']}/best_model.pth")
        print(f"  💾 新的最佳模型已保存! (mIoU: {best_miou:.4f})")
    
    # Early-stopping check, driven by validation loss.
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("\n⏹️ 触发早停!")
        break
    
    # Every 5 epochs: progress and a remaining-time estimate extrapolated from
    # the average epoch duration so far.
    if epoch % 5 == 0:
        total_time = time.time() - start_time
        avg_epoch_time = total_time / epoch
        estimated_total = avg_epoch_time * config['epochs']
        print(f"  📊 进度: {epoch/config['epochs']*100:.1f}%, 预计剩余时间: {(estimated_total - total_time)/60:.1f}分钟")

total_training_time = time.time() - start_time
print(f"\n🎉 训练完成!")
print(f"⏱️ 总训练时间: {total_training_time/60:.1f}分钟")
print(f"🏆 最佳 mIoU: {best_miou:.4f}")

In [None]:
# Save the training history as JSON.
# json cannot encode NumPy or torch scalar types, which metric helpers often
# return, so every entry is coerced to a built-in float on a copy. `history`
# itself is left untouched for the plotting call below.
serializable_history = {
    series_name: [float(entry) for entry in series]
    for series_name, series in history.items()
}
with open(f"{config['save_dir']}/history.json", 'w', encoding='utf-8') as f:
    json.dump(serializable_history, f, indent=4)

print("💾 训练历史已保存")

# Plot the training curves and write them next to the checkpoints.
plot_training_history(
    history, 
    save_path=f"{config['save_dir']}/training_curves.png",
    title="Transformer Training History"
)

print("📈 训练曲线已保存")

## 5. 模型评估

In [None]:
# Rebuild the architecture and restore the best checkpoint for evaluation.
print("📂 加载最佳模型进行评估...")

# Hyper-parameters whose config key matches the factory argument name.
shared_arg_names = (
    'input_channels', 'temporal_steps', 'embed_dim',
    'num_layers', 'num_heads', 'mlp_ratio', 'dropout',
)
eval_model = create_transformer_model(
    num_classes=data_info['num_classes'],
    patch_size=config['model_patch_size'],
    **{arg_name: config[arg_name] for arg_name in shared_arg_names}
).to(device)

# Load the saved weights into the fresh model and switch to eval mode.
best_model_path = f"{config['save_dir']}/best_model.pth"
checkpoint = load_checkpoint(best_model_path, eval_model)
eval_model.eval()

print(f"✅ 模型加载成功 (训练epoch: {checkpoint['epoch']})")

In [None]:
# Evaluate the restored model on the test set.
print("🧪 在测试集上评估模型...")

eval_model.eval()
all_preds = []
all_targets = []

# Only the arg-max class per pixel is needed for the metrics. The original
# also kept the full softmax volume for every test batch in a list that was
# never read; it grew with the test set for no benefit and has been dropped.
with torch.no_grad():
    for data, targets in test_loader:
        outputs = eval_model(data.to(device))  # (batch, height, width, num_classes)
        predicted = outputs.argmax(dim=-1)
        
        all_preds.append(predicted.cpu().numpy())
        all_targets.append(targets.numpy())

# Merge all batches into flat per-pixel vectors.
all_preds = np.concatenate(all_preds).flatten()
all_targets = np.concatenate(all_targets).flatten()

# Compute the detailed metrics and attach the class names for later plots.
metrics = calculate_metrics(all_preds, all_targets, data_info['num_classes'])
metrics['class_names'] = data_info['class_names']

print(f"\n📊 测试集评估结果:")
print(f"  总体准确率: {metrics['overall_accuracy']:.4f}")
print(f"  平均准确率: {metrics['mean_accuracy']:.4f}")
print(f"  平均IoU: {metrics['mean_iou']:.4f}")
print(f"  平均精确率: {metrics['mean_precision']:.4f}")
print(f"  平均召回率: {metrics['mean_recall']:.4f}")

In [None]:
# Per-class metric table: header, rule, then one row per class index.
print(f"\n📋 各类别详细指标:")
print(f"{'类别':<15} {'准确率':<10} {'IoU':<10} {'精确率':<10} {'召回率':<10}")
print("-" * 60)

# Metric series in the column order of the table.
per_class_keys = (
    'per_class_accuracy', 'per_class_iou',
    'per_class_precision', 'per_class_recall',
)

for i in range(data_info['num_classes']):
    class_name = data_info['class_names'][i]
    acc, iou, precision, recall = (metrics[metric_key][i] for metric_key in per_class_keys)
    
    print(f"{class_name:<15} {acc:<10.4f} {iou:<10.4f} {precision:<10.4f} {recall:<10.4f}")

## 6. 结果可视化

In [None]:
# Draw the confusion matrix; axis labels follow the stored order of class names.
confusion_labels = list(data_info['class_names'].values())
confusion_path = f"{config['save_dir']}/confusion_matrix.png"

plot_confusion_matrix(
    metrics['confusion_matrix'],
    confusion_labels,
    title="Transformer Confusion Matrix",
    save_path=confusion_path
)

print("📊 混淆矩阵已保存")

In [None]:
# Draw the per-class performance chart and save it with the other artifacts.
performance_path = f"{config['save_dir']}/class_performance.png"

plot_class_performance(
    metrics,
    title="Transformer Per-Class Performance",
    save_path=performance_path
)

print("📈 类别性能图表已保存")

In [None]:
# Side-by-side view of predictions and ground truth for one test batch.
print("🎨 生成预测可视化...")

# Fetch the first batch the test loader yields.
test_iter = iter(test_loader)
test_batch_x, test_batch_y = next(test_iter)

# Predict: class index with the highest score along the last dimension.
eval_model.eval()
with torch.no_grad():
    test_batch_x_device = test_batch_x.to(device)
    test_outputs = eval_model(test_batch_x_device)
    test_predictions = test_outputs.argmax(dim=-1)

# Show up to six samples.
num_to_show = 6
class_label_list = list(data_info['class_names'].values())
visualization_path = f"{config['save_dir']}/predictions_visualization.png"

visualize_predictions(
    test_batch_x[:num_to_show],
    test_batch_y[:num_to_show],
    test_predictions[:num_to_show].cpu(),
    class_label_list,
    num_samples=num_to_show,
    save_path=visualization_path,
    title="Transformer Predictions vs Ground Truth"
)

print("🖼️ 预测可视化已保存")

## 7. 模型推理

In [None]:
# Demonstrate inference on a single patch and time the forward pass.
print("🔮 演示模型推理...")


def _wait_for_device(dev):
    """Block until work queued on *dev* has finished (no-op on CPU)."""
    if dev.type == 'cuda':
        torch.cuda.synchronize()
    elif dev.type == 'mps' and hasattr(torch, 'mps'):
        torch.mps.synchronize()


# Take the first sample of the previously fetched test batch.
sample_x = test_batch_x[0:1].to(device)
sample_y = test_batch_y[0:1]

# Inference
eval_model.eval()
with torch.no_grad():
    # GPU kernels are launched asynchronously. Without synchronising before
    # and after the call, the timer would measure launch overhead (plus any
    # still-pending earlier work) rather than the forward pass itself.
    _wait_for_device(device)
    inference_start = time.perf_counter()  # monotonic, high-resolution clock
    outputs = eval_model(sample_x)
    _wait_for_device(device)
    inference_time = time.perf_counter() - inference_start
    
    # Per-pixel class probabilities, winning class and its probability.
    probs = torch.softmax(outputs, dim=-1)
    confidence, predictions = probs.max(-1)
    
    # Pixel accuracy of this sample against its label map.
    accuracy = (predictions.cpu() == sample_y).float().mean().item()

print(f"\n⚡ 推理性能:")
print(f"  推理时间: {inference_time*1000:.2f}ms")
print(f"  输入形状: {sample_x.shape}")
print(f"  输出形状: {outputs.shape}")
print(f"  预测准确率: {accuracy:.4f}")
print(f"  平均置信度: {confidence.mean():.4f}")
print(f"  最小置信度: {confidence.min():.4f}")
print(f"  最大置信度: {confidence.max():.4f}")

In [None]:
# Four-panel figure for the single-sample inference: input preview, ground
# truth, prediction and confidence map.
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

# Input preview: first three channels of time step 0 as an RGB composite.
# NOTE(review): assumes x is (batch, H, W, time, channel) and channels 0-2 are
# meaningful as RGB -- confirm against the dataset.
rgb_img = sample_x[0, :, :, 0, :3].cpu().numpy()
value_range = rgb_img.max() - rgb_img.min()
if value_range > 0:
    rgb_img = (rgb_img - rgb_img.min()) / value_range
else:
    # Constant patch: min-max scaling would divide by zero and yield NaN.
    rgb_img = np.zeros_like(rgb_img)
axes[0].imshow(rgb_img)
axes[0].set_title('Input RGB (t=0)', fontsize=14)
axes[0].axis('off')

# Shared colour range for both label panels, derived from the class count
# rather than a hard-coded 7 so the two maps stay comparable for any dataset.
label_vmax = max(data_info['num_classes'] - 1, 1)

# Ground-truth labels
gt_img = sample_y[0].numpy()
axes[1].imshow(gt_img, cmap='tab10', vmin=0, vmax=label_vmax)
axes[1].set_title('Ground Truth', fontsize=14)
axes[1].axis('off')

# Predicted labels
pred_img = predictions[0].cpu().numpy()
axes[2].imshow(pred_img, cmap='tab10', vmin=0, vmax=label_vmax)
axes[2].set_title(f'Prediction (Acc: {accuracy:.3f})', fontsize=14)
axes[2].axis('off')

# Confidence map (probability of the winning class per pixel) with a colorbar
conf_img = confidence[0].cpu().numpy()
im3 = axes[3].imshow(conf_img, cmap='viridis')
axes[3].set_title('Confidence Map', fontsize=14)
axes[3].axis('off')
plt.colorbar(im3, ax=axes[3], label='Confidence')

plt.tight_layout()
plt.show()

print("🎨 单样本推理可视化完成")

## 总结

### 🎯 训练结果总结

In [None]:
# Final report: model configuration, training statistics, test metrics, the
# files written to the save directory, and follow-up suggestions.
banner_rule = "=" * 80

print(banner_rule)
print("🎯 TRANSFORMER 农作物制图模型训练总结")
print(banner_rule)

# Model configuration
print("\n📊 模型配置:")
print("  模型类型: Vision Transformer")
print(f"  嵌入维度: {config['embed_dim']}")
print(f"  层数: {config['num_layers']}")
print(f"  注意力头数: {config['num_heads']}")
print(f"  总参数量: {total_params:,}")

# Training statistics (total time in minutes, per-epoch average in seconds)
print("\n🏃 训练过程:")
print(f"  训练epochs: {len(history['train_loss'])}")
print(f"  总训练时间: {total_training_time/60:.1f}分钟")
print(f"  平均每epoch: {total_training_time/len(history['train_loss']):.1f}秒")
print(f"  最佳验证mIoU: {best_miou:.4f}")

# Test-set metrics
print("\n📈 最终性能:")
print(f"  测试集总体准确率: {metrics['overall_accuracy']:.4f}")
print(f"  测试集平均IoU: {metrics['mean_iou']:.4f}")
print(f"  测试集平均精确率: {metrics['mean_precision']:.4f}")
print(f"  测试集平均召回率: {metrics['mean_recall']:.4f}")

# Everything currently present in the save directory
print("\n💾 保存的文件:")
save_dir = Path(config['save_dir'])
for file in list(save_dir.glob('*')):
    print(f"  {file.name}")

# Suggested next steps
print("\n🚀 使用建议:")
print(f"  1. 模型已保存到: {config['save_dir']}/best_model.pth")
print(f"  2. 用于推理: python -m Transformer.inference --model-path {config['save_dir']}/best_model.pth")
print("  3. 继续训练: 加载checkpoint并调整学习率")
print("  4. 模型部署: 考虑量化或剪枝以减少计算需求")

print("\n" + banner_rule)
print("🎉 Transformer农作物制图模型训练完成!")
print(banner_rule)