# Swin Transformer 农作物制图训练和评估

本notebook演示如何使用Swin Transformer模型进行多光谱时序数据的农作物分类任务。Swin Transformer采用层次化设计和移位窗口注意力机制，在保持高精度的同时实现了线性计算复杂度。

## 🔄 双模式设计

**📋 本notebook支持两种运行模式：**

### 🧪 TEST模式
- **目的**: 本地快速测试代码正确性
- **配置**: 轻量级模型，小数据集，少量epochs
- **用时**: ~3分钟
- **适用**: 开发调试，验证代码逻辑
- **环境**: 本地CPU/GPU，低内存要求

### 🚀 TRAIN模式
- **目的**: 云端GPU完整训练获得最佳性能
- **配置**: 标准模型，完整数据集，完整epochs
- **用时**: ~8-16小时
- **适用**: 正式训练，获得部署模型
- **环境**: 云端GPU，高性能要求

**💡 使用建议：**
1. 本地开发时使用TEST模式验证代码
2. 确认无误后切换到TRAIN模式进行完整训练
3. 在云端GPU上运行TRAIN模式以获得最佳效果

---

## 目录
1. [运行模式配置](#运行模式配置)
2. [环境设置和导入](#环境设置和导入)
3. [数据加载和探索](#数据加载和探索)
4. [模型创建和配置](#模型创建和配置)
5. [训练过程](#训练过程)
6. [模型评估](#模型评估)
7. [结果可视化](#结果可视化)
8. [窗口注意力分析](#窗口注意力分析)
9. [模型推理](#模型推理)

In [None]:
# 🔧 Run-mode switch for the whole notebook.
# RUNNING_MODE selects which configuration the later cells build:
#   - "TEST":  quick local smoke test of the code path
#   - "TRAIN": full training run intended for a cloud GPU

RUNNING_MODE = "TEST"  # 👈 edit this value to switch modes: "TEST" or "TRAIN"

# Banner lines echoed for each supported mode.
_MODE_BANNERS = {
    "TEST": (
        "🧪 TEST模式 - 用于本地测试:",
        "  ✓ 使用小数据集",
        "  ✓ 快速训练 (5 epochs)",
        "  ✓ 小模型配置",
        "  ✓ 快速验证代码正确性",
    ),
    "TRAIN": (
        "🚀 TRAIN模式 - 用于云端GPU训练:",
        "  ✓ 使用完整数据集",
        "  ✓ 完整训练 (50+ epochs)",
        "  ✓ 标准模型配置",
        "  ✓ 最佳性能优化",
    ),
}

print(f"🔄 当前运行模式: {RUNNING_MODE}")

# Anything other than the two supported modes is rejected (after the value
# has been echoed above).
if RUNNING_MODE not in _MODE_BANNERS:
    raise ValueError("RUNNING_MODE 必须是 'TEST' 或 'TRAIN'")

for _banner_line in _MODE_BANNERS[RUNNING_MODE]:
    print(_banner_line)

## 1. 运行模式配置

选择运行模式来适配不同的使用场景：
- **TEST模式**：快速验证代码正确性，适合本地开发
- **TRAIN模式**：完整训练流程，适合云端GPU训练

# Swin Transformer 农作物制图训练和评估

本notebook演示如何使用Swin Transformer模型进行多光谱时序数据的农作物分类任务。Swin Transformer采用层次化设计和移位窗口注意力机制，在保持高精度的同时实现了线性计算复杂度。

## 目录
1. [环境设置和导入](#环境设置和导入)
2. [数据加载和探索](#数据加载和探索)
3. [模型创建和配置](#模型创建和配置)
4. [训练过程](#训练过程)
5. [模型评估](#模型评估)
6. [结果可视化](#结果可视化)
7. [窗口注意力分析](#窗口注意力分析)
8. [模型推理](#模型推理)

## 1. 环境设置和导入

In [None]:
# 导入必要的库
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import time
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Check that einops can be imported and report the result; a missing package
# is only reported here, it does not stop the cell.
try:
    import einops
except ImportError:
    print("❌ einops not found. Please install: pip install einops")
else:
    print("✓ einops available")

# Matplotlib font setup so CJK text and the minus sign render in figures.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Seed both random number generators with the same fixed value.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Report the PyTorch build and, when a CUDA device is present, its name and
# total memory.
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU设备: {torch.cuda.get_device_name(0)}")
    _gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU内存: {_gpu_total_bytes / 1e9:.1f} GB")

In [None]:
# Import the Swin Transformer project modules
import sys
# Add the parent directory to the import path before importing the project
# modules below (presumably that is where model/dataset/train/utils live —
# verify against the repository layout).
sys.path.append('..')

from model import create_swin_model, CropMappingSwinTransformer
from dataset import prepare_data, save_data_info, load_data_info
from train import (
    train_epoch, validate, FocalLoss, DiceLoss, CombinedLoss
)
from utils import (
    save_checkpoint, load_checkpoint, EarlyStopping,
    calculate_metrics, plot_training_history, plot_class_performance,
    plot_confusion_matrix, visualize_predictions, plot_window_attention
)

print("✅ Swin Transformer模块导入成功!")

## 2. 数据加载和探索

In [None]:
# 📋 Build the `config` dictionary for the selected run mode.
# Both branches define exactly the same keys in the same order; only the
# values differ.
if RUNNING_MODE == "TEST":
    # 🧪 TEST mode: lightweight settings for a quick code check
    config = dict(
        # --- data ---
        data_path='../dataset',
        patch_size=32,                # half the TRAIN patch size
        stride=16,                    # half the TRAIN stride
        test_size=0.3,
        val_size=0.2,
        batch_size=2,
        num_workers=2,

        # --- model (lightweight) ---
        input_channels=8,
        temporal_steps=28,
        swin_patch_size=4,
        embed_dim=48,
        depths=[1, 1, 2, 1],
        num_heads=[2, 4, 8, 16],
        window_size=4,
        mlp_ratio=2.0,
        dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.05,

        # --- optimisation (short run) ---
        epochs=5,
        learning_rate=1e-3,
        weight_decay=0.01,
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,

        # --- loss ---
        use_combined_loss=False,      # plain cross-entropy is used downstream
        focal_gamma=2.0,
        label_smoothing=0.05,

        # --- learning-rate schedule ---
        scheduler_type='none',        # no scheduler is created downstream
        min_lr=1e-5,

        # --- misc ---
        patience=5,
        save_dir='../Models/swin-transformer/test',
        augment_train=False,
    )
    print("🧪 使用TEST模式配置 - 适合快速验证代码")

elif RUNNING_MODE == "TRAIN":
    # 🚀 TRAIN mode: full-size settings for a complete training run
    config = dict(
        # --- data ---
        data_path='../dataset',
        patch_size=64,
        stride=32,
        test_size=0.2,
        val_size=0.1,
        batch_size=8,
        num_workers=4,

        # --- model (standard) ---
        input_channels=8,
        temporal_steps=28,
        swin_patch_size=4,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        dropout=0.0,
        attn_dropout=0.0,
        drop_path_rate=0.1,

        # --- optimisation (full run) ---
        epochs=100,
        learning_rate=1e-4,
        weight_decay=0.05,
        gradient_accumulation_steps=2,
        max_grad_norm=5.0,

        # --- loss ---
        use_combined_loss=True,       # CombinedLoss is used downstream
        focal_gamma=2.5,
        label_smoothing=0.1,

        # --- learning-rate schedule ---
        scheduler_type='cosine',
        min_lr=1e-6,

        # --- misc ---
        patience=15,
        save_dir='../Models/swin-transformer/train',
        augment_train=True,
    )
    print("🚀 使用TRAIN模式配置 - 适合完整训练")

# Make sure the output directory exists before anything is written to it.
Path(config['save_dir']).mkdir(parents=True, exist_ok=True)

# Echo the active configuration, one "key: value" line per entry.
print(f"\n📋 当前配置 ({RUNNING_MODE}模式):")
for option_name in config:
    print(f"  {option_name}: {config[option_name]}")

mode_note = (
    "\n⚠️ 注意: TEST模式使用简化配置，仅用于验证代码正确性!"
    if RUNNING_MODE == "TEST"
    else "\n💪 TRAIN模式: 使用完整配置进行最佳性能训练!"
)
print(mode_note)

print(f"\n💾 模型输出目录: {config['save_dir']}")

In [None]:
# 🔄 Load and prepare the data for the selected run mode.
# (This cell was stored as a single JSON-escaped line; it is restored here as
# regular Python source.)
print(f"🔄 加载数据 ({RUNNING_MODE}模式)...")

# prepare_data returns the three loaders plus a `data_info` dictionary that
# the rest of the notebook reads (num_classes, class_names, class_weights, ...).
train_loader, val_loader, test_loader, data_info = prepare_data(
    data_path=config['data_path'],
    patch_size=config['patch_size'],
    stride=config['stride'],
    test_size=config['test_size'],
    val_size=config['val_size'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    augment_train=config['augment_train'],
    swin_patch_size=config['swin_patch_size']
)

print(f"\n📊 数据集信息 ({RUNNING_MODE}模式):")
print(f"  类别数量: {data_info['num_classes']}")
print(f"  训练批次: {len(train_loader)}")
print(f"  验证批次: {len(val_loader)}")
print(f"  测试批次: {len(test_loader)}")
print(f"  调整后patch大小: {data_info['patch_size']}")
print(f"  Swin patch大小: {data_info['swin_patch_size']}")
print(f"  输入形状: {data_info['input_shape']}")

# Rough sample counts: number of batches times the batch size. This is an
# upper bound whenever the last batch is not full, hence the "~" in the output.
total_train_samples = len(train_loader) * config['batch_size']
total_val_samples = len(val_loader) * config['batch_size']
print(f"  训练样本数: ~{total_train_samples}")
print(f"  验证样本数: ~{total_val_samples}")

if RUNNING_MODE == "TEST":
    print(f"\n🧪 TEST模式数据量: 使用小数据集快速测试")
    print(f"  预计单epoch时间: ~30秒")
    print(f"  总训练时间: ~3分钟")
else:
    print(f"\n🚀 TRAIN模式数据量: 使用完整数据集")
    print(f"  预计单epoch时间: ~5-10分钟")
    print(f"  总训练时间: ~8-16小时")

# Persist the dataset description next to the model outputs.
save_data_info(data_info, f"{config['save_dir']}/data_info.pkl")

# List the classes; class_names is iterated as an index -> name mapping.
print(f"\n🏷️ 类别信息:")
for idx, name in data_info['class_names'].items():
    print(f"  {idx}: {name}")

print(f"\n⚖️ 类别权重: {data_info['class_weights'].numpy()}")

In [None]:
# 📊 Visualise one training batch (RGB composite and label maps).
# (This cell was stored as a single JSON-escaped line; it is restored here as
# regular Python source.)

# Pull a single batch from the training loader.
data_iter = iter(train_loader)
sample_batch_x, sample_batch_y = next(data_iter)

print(f"样本批次形状 - X: {sample_batch_x.shape}, Y: {sample_batch_y.shape}")
print(f"X数值范围: [{sample_batch_x.min():.3f}, {sample_batch_x.max():.3f}]")
print(f"Y唯一值: {torch.unique(sample_batch_y)}")

# Number of samples to draw depends on the mode, capped by the batch size.
if RUNNING_MODE == "TEST":
    num_samples = min(4, sample_batch_x.shape[0])
    print(f"\n🧪 TEST模式: 显示 {num_samples} 个样本进行快速验证")
else:
    num_samples = min(8, sample_batch_x.shape[0])
    print(f"\n🚀 TRAIN模式: 显示 {num_samples} 个样本进行详细分析")

# Grid layout: at most 4 columns, a second row once there are more than 4 samples.
rows = 2 if num_samples > 4 else 1
cols = min(4, num_samples)

# --- RGB composites ---
# squeeze=False always yields a 2-D axes array, so one flattening step covers
# the 1x1, 1xN and 2xN layouts alike.
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
fig.suptitle(f'Swin Transformer Input Samples - RGB Visualization ({RUNNING_MODE} Mode)', fontsize=16)
axes = axes.ravel()

# Switch every axis off up front so grid cells without a sample stay blank.
for ax in axes:
    ax.axis('off')

for i in range(num_samples):
    # The slice takes index 0 on the fourth axis and the first three entries
    # of the last axis — assumes a (batch, H, W, time, channel) layout, TODO confirm.
    rgb_img = sample_batch_x[i, :, :, 0, :3].numpy()

    # Min-max scale to [0, 1] for display; a constant patch would otherwise
    # divide by zero, so it is drawn as all zeros instead.
    value_span = rgb_img.max() - rgb_img.min()
    if value_span > 0:
        rgb_img = (rgb_img - rgb_img.min()) / value_span
    else:
        rgb_img = np.zeros_like(rgb_img)

    axes[i].imshow(rgb_img)
    axes[i].set_title(f'Sample {i+1} - RGB (t=0)')

plt.tight_layout()
plt.show()

# --- Matching label maps ---
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
fig.suptitle(f'Ground Truth Labels ({RUNNING_MODE} Mode)', fontsize=16)
axes = axes.ravel()

for ax in axes:
    ax.axis('off')

for i in range(num_samples):
    # Fixed 0..7 colour range so the same class id maps to the same colour in
    # every panel.
    im = axes[i].imshow(sample_batch_y[i].numpy(), cmap='tab10', vmin=0, vmax=7)
    axes[i].set_title(f'Sample {i+1} - Labels')

plt.tight_layout()
plt.show()

if RUNNING_MODE == "TEST":
    print("✓ 数据加载测试完成，代码运行正常!")
else:
    print("✓ 数据加载完成，准备开始完整训练!")

## 3. 模型创建和配置

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# torch.backends.mps is looked up defensively because it does not exist on
# every PyTorch build; unguarded access would raise AttributeError on those
# builds whenever CUDA is unavailable.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(f"🖥️ 使用设备: {device}")

# Build the Swin Transformer from the active configuration and move it to the
# selected device.
model = create_swin_model(
    input_channels=config['input_channels'],
    temporal_steps=config['temporal_steps'],
    num_classes=data_info['num_classes'],
    patch_size=config['swin_patch_size'],
    embed_dim=config['embed_dim'],
    depths=config['depths'],
    num_heads=config['num_heads'],
    window_size=config['window_size'],
    mlp_ratio=config['mlp_ratio'],
    drop_rate=config['dropout'],
    attn_drop_rate=config['attn_dropout'],
    drop_path_rate=config['drop_path_rate']
).to(device)

# Parameter counts (all parameters vs. those with requires_grad set).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n🔧 Swin Transformer模型信息:")
print(f"  总参数量: {total_params:,}")
print(f"  可训练参数: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter (FP32).
print(f"  模型大小: {total_params * 4 / (1024**2):.1f} MB (FP32)")
print(f"  嵌入维度: {config['embed_dim']}")
print(f"  层数配置: {config['depths']}")
print(f"  注意力头数: {config['num_heads']}")
print(f"  窗口大小: {config['window_size']}")

# Smoke-test the forward pass on up to two samples of the batch fetched by the
# data-visualisation cell (`sample_batch_x`).
model.eval()
with torch.no_grad():
    test_input = sample_batch_x[:2].to(device)
    test_output = model(test_input)
    print(f"\n🧪 前向传播测试:")
    print(f"  输入形状: {test_input.shape}")
    print(f"  输出形状: {test_output.shape}")
    print(f"  输出数值范围: [{test_output.min():.3f}, {test_output.max():.3f}]")

print("\n✅ Swin Transformer模型创建成功!")

In [None]:
# Loss function: CombinedLoss when enabled in the config, otherwise weighted
# cross-entropy. Both receive the class weights from data_info.
if config['use_combined_loss']:
    criterion = CombinedLoss(
        alpha=data_info['class_weights'].to(device),
        gamma=config['focal_gamma'],
        label_smoothing=config['label_smoothing']
    )
    print("🎯 使用组合损失函数 (Focal + Dice + CrossEntropy)")
else:
    criterion = nn.CrossEntropyLoss(
        weight=data_info['class_weights'].to(device),
        label_smoothing=config['label_smoothing']
    )
    print("🎯 使用交叉熵损失函数")

# Optimizer: AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
    betas=(0.9, 0.999),
    eps=1e-8
)

print(f"⚙️ 优化器: AdamW (lr={config['learning_rate']}, wd={config['weight_decay']})")

# Learning-rate scheduler: cosine annealing with warm restarts, or none.
if config['scheduler_type'] == 'cosine':
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2,
        eta_min=config['min_lr']
    )
    print(f"📈 学习率调度器: CosineAnnealingWarmRestarts (min_lr={config['min_lr']})")
else:
    scheduler = None
    print("📈 无学习率调度器")

# Mixed precision: gradient scaling only applies on CUDA. The scaler is
# created explicitly disabled on other devices so it becomes a pass-through
# and the status line reflects what is actually in effect.
from torch.cuda.amp import GradScaler
use_amp = device.type == 'cuda'
scaler = GradScaler(enabled=use_amp)
if use_amp:
    print("⚡ 启用混合精度训练")
else:
    print("⚡ 混合精度训练未启用 (需要CUDA)")

# Early stopping.
# NOTE(review): the training loop writes its own full checkpoint to
# best_model.pth, and the evaluation cell reads 'epoch' from that file.
# EarlyStopping is therefore given a separate file so the two writers cannot
# overwrite each other. What EarlyStopping writes is defined in utils —
# confirm there.
early_stopping = EarlyStopping(
    patience=config['patience'],
    verbose=True,
    save_path=f"{config['save_dir']}/early_stopping_model.pth"
)
print(f"⏹️ 早停机制 (patience={config['patience']})")

print("\n✅ 训练组件配置完成!")

## 4. 训练过程

In [None]:
# 🚀 Training setup for the selected run mode.
# (This cell was stored as a single JSON-escaped line; it is restored here as
# regular Python source.)

# Per-epoch metric history, filled in by the training loop.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_miou': [],
    'learning_rates': []
}

# Best validation mIoU seen so far and the wall-clock start of training.
best_miou = 0.0
start_time = time.time()

if RUNNING_MODE == "TEST":
    print(f"🧪 开始TEST模式训练 - Swin Transformer ({config['epochs']} epochs)")
    print(f"⚡ 目标: 快速验证代码正确性")
    print(f"🔧 模型配置: 轻量级 (embed_dim={config['embed_dim']}, depths={config['depths']})")
else:
    print(f"🚀 开始TRAIN模式训练 - Swin Transformer ({config['epochs']} epochs)")
    print(f"🎯 目标: 获得最佳模型性能")
    print(f"💪 模型配置: 标准配置 (embed_dim={config['embed_dim']}, depths={config['depths']})")

print(f"⏰ 开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"🎯 特征: {config['depths']}层深度, {config['embed_dim']}嵌入维度, {config['window_size']}窗口大小")
print("=" * 80)

In [None]:
# Training loop: one training pass and one validation pass per epoch, followed
# by LR scheduling, history bookkeeping, best-mIoU checkpointing, the
# early-stopping check and a periodic progress estimate.
for epoch in range(1, config['epochs'] + 1):
    epoch_start_time = time.time()

    print(f"\nEpoch {epoch}/{config['epochs']}")
    print("-" * 50)

    # Training phase. The scheduler is deliberately not handed to train_epoch
    # (scheduler=None); it is stepped once per epoch below.
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device, scaler, epoch,
        scheduler=None,
        scheduler_step_per_batch=False,
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm']
    )

    # Validation phase.
    val_loss, val_acc, val_metrics = validate(
        model, val_loader, criterion, device, data_info['num_classes']
    )

    # Read the learning rate BEFORE stepping the scheduler, so the value that
    # is logged and stored for this epoch is the optimizer's rate as it stood
    # during this epoch rather than the next epoch's rate.
    current_lr = optimizer.param_groups[0]['lr']

    # Advance the learning-rate schedule once per epoch.
    if scheduler is not None:
        scheduler.step()

    # Record this epoch's metrics.
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_miou'].append(val_metrics['mean_iou'])
    history['learning_rates'].append(current_lr)

    epoch_time = time.time() - epoch_start_time

    # Per-epoch report.
    print(f"Epoch {epoch} 结果:")
    print(f"  训练 - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
    print(f"  验证 - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
    print(f"  验证 mIoU: {val_metrics['mean_iou']:.4f}")
    print(f"  学习率: {current_lr:.2e}")
    print(f"  耗时: {epoch_time:.1f}s")

    # Save a full checkpoint whenever the validation mIoU improves.
    if val_metrics['mean_iou'] > best_miou:
        best_miou = val_metrics['mean_iou']
        save_checkpoint({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
            'scaler_state_dict': scaler.state_dict(),
            'best_miou': best_miou,
            'config': config,
            'metrics': val_metrics
        }, f"{config['save_dir']}/best_model.pth")
        print(f"  💾 新的最佳模型已保存! (mIoU: {best_miou:.4f})")

    # Early-stopping check, driven by the validation loss.
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("\n⏹️ 触发早停!")
        break

    # Every 5 epochs, extrapolate the remaining time from the average epoch
    # duration so far.
    if epoch % 5 == 0:
        total_time = time.time() - start_time
        avg_epoch_time = total_time / epoch
        estimated_total = avg_epoch_time * config['epochs']
        print(f"  📊 进度: {epoch/config['epochs']*100:.1f}%, 预计剩余时间: {(estimated_total - total_time)/60:.1f}分钟")

total_training_time = time.time() - start_time
print(f"\n🎉 Swin Transformer训练完成!")
print(f"⏱️ 总训练时间: {total_training_time/60:.1f}分钟")
print(f"🏆 最佳 mIoU: {best_miou:.4f}")

In [None]:
# Save the training history as JSON.
# The values appended by the training loop come from train_epoch/validate and
# may be NumPy or tensor scalars rather than Python floats (not visible from
# this cell); json.dump rejects those, so every entry is coerced to float
# first. Plain floats are left unchanged by the coercion.
serializable_history = {
    name: [float(value) for value in values]
    for name, values in history.items()
}
with open(f"{config['save_dir']}/history.json", 'w', encoding='utf-8') as f:
    json.dump(serializable_history, f, indent=4)

print("💾 训练历史已保存")

# Plot the training curves and write them next to the history file.
plot_training_history(
    history, 
    save_path=f"{config['save_dir']}/training_curves.png",
    title="Swin Transformer Training History"
)

print("📈 训练曲线已保存")

## 5. 模型评估

In [None]:
# Load the best checkpoint into a freshly built model for evaluation.
print("📂 加载最佳Swin Transformer模型进行评估...")

# Rebuild the architecture with the training configuration; only
# drop_path_rate differs (set to 0.0 for evaluation).
eval_model = create_swin_model(
    input_channels=config['input_channels'],
    temporal_steps=config['temporal_steps'],
    num_classes=data_info['num_classes'],
    patch_size=config['swin_patch_size'],
    embed_dim=config['embed_dim'],
    depths=config['depths'],
    num_heads=config['num_heads'],
    window_size=config['window_size'],
    mlp_ratio=config['mlp_ratio'],
    drop_rate=config['dropout'],
    attn_drop_rate=config['attn_dropout'],
    drop_path_rate=0.0
).to(device)

# Restore the saved weights; load_checkpoint also returns the checkpoint
# dictionary, which is read below.
checkpoint = load_checkpoint(f"{config['save_dir']}/best_model.pth", eval_model)
eval_model.eval()

print(f"✅ 模型加载成功 (训练epoch: {checkpoint['epoch']})")

# 'best_miou' may be absent from the checkpoint. The ".4f" format spec must
# only be applied to a number: applying it to the 'N/A' fallback string
# raises ValueError.
checkpoint_best_miou = checkpoint.get('best_miou')
if checkpoint_best_miou is None:
    print("🏆 训练时最佳mIoU: N/A")
else:
    print(f"🏆 训练时最佳mIoU: {checkpoint_best_miou:.4f}")

In [None]:
# Evaluate the restored model on the test set.
print("🧪 在测试集上评估Swin Transformer模型...")

eval_model.eval()
all_preds = []
all_targets = []

# Only predicted labels and targets are collected. The per-pixel class
# probabilities are not needed by calculate_metrics and were never read, so
# they are no longer computed or held in memory for the whole test set.
with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets_cpu = targets.numpy()

        # The class dimension is last: the prediction is the argmax over it.
        outputs = eval_model(data)
        _, predicted = outputs.max(-1)

        all_preds.append(predicted.cpu().numpy())
        all_targets.append(targets_cpu)

# Merge all batches into flat 1-D arrays of labels.
all_preds = np.concatenate(all_preds).flatten()
all_targets = np.concatenate(all_targets).flatten()

# Compute the detailed metrics and attach the class names for later plots.
metrics = calculate_metrics(all_preds, all_targets, data_info['num_classes'])
metrics['class_names'] = data_info['class_names']

print(f"\n📊 Swin Transformer测试集评估结果:")
print(f"  总体准确率: {metrics['overall_accuracy']:.4f}")
print(f"  平均准确率: {metrics['mean_accuracy']:.4f}")
print(f"  平均IoU: {metrics['mean_iou']:.4f}")
print(f"  平均精确率: {metrics['mean_precision']:.4f}")
print(f"  平均召回率: {metrics['mean_recall']:.4f}")

In [None]:
# Per-class metric table followed by the best and worst class by IoU.
print(f"\n📋 各类别详细指标 (Swin Transformer):")
print(f"{'类别':<15} {'准确率':<10} {'IoU':<10} {'精确率':<10} {'召回率':<10}")
print("-" * 65)

# Metric columns in the order they appear in the table header.
_per_class_columns = (
    'per_class_accuracy',
    'per_class_iou',
    'per_class_precision',
    'per_class_recall',
)

for class_idx in range(data_info['num_classes']):
    label = data_info['class_names'][class_idx]
    row_cells = " ".join(
        f"{metrics[column][class_idx]:<10.4f}" for column in _per_class_columns
    )
    print(f"{label:<15} {row_cells}")

# Classes with the highest and the lowest IoU.
iou_by_class = metrics['per_class_iou']
best_class_idx = np.argmax(iou_by_class)
worst_class_idx = np.argmin(iou_by_class)

print(f"\n🥇 表现最好的类别: {data_info['class_names'][best_class_idx]} (IoU: {metrics['per_class_iou'][best_class_idx]:.4f})")
print(f"🥉 表现最差的类别: {data_info['class_names'][worst_class_idx]} (IoU: {metrics['per_class_iou'][worst_class_idx]:.4f})")

## 6. 结果可视化

In [None]:
# Draw the confusion matrix and save it into the output directory.
confusion = metrics['confusion_matrix']
class_labels = list(data_info['class_names'].values())
confusion_png = f"{config['save_dir']}/confusion_matrix.png"

plot_confusion_matrix(
    confusion,
    class_labels,
    save_path=confusion_png,
    title="Swin Transformer Confusion Matrix",
)

print("📊 混淆矩阵已保存")

In [None]:
# Draw the per-class performance chart and save it into the output directory.
class_performance_png = f"{config['save_dir']}/class_performance.png"

plot_class_performance(
    metrics,
    save_path=class_performance_png,
    title="Swin Transformer Per-Class Performance",
)

print("📈 类别性能图表已保存")

In [None]:
# Visualise predictions against ground truth for one test batch.
print("🎨 生成Swin Transformer预测可视化...")

# Fetch one batch from the test loader.
test_iter = iter(test_loader)
test_batch_x, test_batch_y = next(test_iter)

# Predict: the class dimension is last, so take the argmax over it.
eval_model.eval()
with torch.no_grad():
    test_batch_x_device = test_batch_x.to(device)
    test_outputs = eval_model(test_batch_x_device)
    _, test_predictions = test_outputs.max(-1)

# Show up to 6 samples, but never more than the batch actually holds
# (TEST mode uses batch_size=2), so num_samples always matches the number of
# samples passed in.
num_vis_samples = min(6, test_batch_x.shape[0])

visualize_predictions(
    test_batch_x[:num_vis_samples],
    test_batch_y[:num_vis_samples],
    test_predictions[:num_vis_samples].cpu(),
    list(data_info['class_names'].values()),
    num_samples=num_vis_samples,
    save_path=f"{config['save_dir']}/predictions_visualization.png",
    title="Swin Transformer Predictions vs Ground Truth"
)

print("🖼️ 预测可视化已保存")

## 7. 窗口注意力分析

In [None]:
# Swin-specific window analysis for a single test sample: prediction,
# confidence, window partition grid, patch-embedding features and a summary.
print("🔍 分析Swin Transformer的窗口注意力模式...")

# Pick one sample from the test batch fetched by the prediction cell.
sample_idx = 0
sample_input = test_batch_x[sample_idx:sample_idx+1].to(device)
sample_target = test_batch_y[sample_idx]

eval_model.eval()
with torch.no_grad():
    # Prediction and per-pixel confidence (max softmax probability).
    sample_output = eval_model(sample_input)
    sample_probs = torch.softmax(sample_output, dim=-1)
    confidence, prediction = sample_probs.max(-1)

    # Best-effort extraction of the patch-embedding output. Any failure is
    # reported and the corresponding panel is replaced by a placeholder below.
    try:
        patch_features, (patch_h, patch_w) = eval_model.patch_embed(sample_input)
        print(f"✓ Patch features shape: {patch_features.shape}")
        print(f"✓ Patch resolution: {patch_h} x {patch_w}")
    except Exception as e:
        print(f"⚠️ Cannot extract patch features: {e}")
        patch_features = None

# 2x4 figure holding all panels.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
fig.suptitle('Swin Transformer Window Analysis', fontsize=16)

# Input RGB composite (index 0 on the fourth axis, first three entries of the
# last axis — assumes a (batch, H, W, time, channel) layout, TODO confirm).
rgb_img = sample_input[0, :, :, 0, :3].cpu().numpy()
# Min-max scale for display; a constant image would divide by zero, so it is
# drawn as all zeros instead.
rgb_span = rgb_img.max() - rgb_img.min()
if rgb_span > 0:
    rgb_img = (rgb_img - rgb_img.min()) / rgb_span
else:
    rgb_img = np.zeros_like(rgb_img)
axes[0, 0].imshow(rgb_img)
axes[0, 0].set_title('Input RGB (t=0)')
axes[0, 0].axis('off')

# Ground truth and prediction share one fixed colour range (0..7, the same
# range the other label plots in this notebook use), so a class id has the
# same colour in both panels.
axes[0, 1].imshow(sample_target.numpy(), cmap='tab10', vmin=0, vmax=7)
axes[0, 1].set_title('Ground Truth')
axes[0, 1].axis('off')

pred_img = prediction[0].cpu().numpy()
axes[0, 2].imshow(pred_img, cmap='tab10', vmin=0, vmax=7)
axes[0, 2].set_title('Swin Prediction')
axes[0, 2].axis('off')

# Prediction confidence map.
conf_img = confidence[0].cpu().numpy()
im1 = axes[0, 3].imshow(conf_img, cmap='viridis')
axes[0, 3].set_title('Prediction Confidence')
axes[0, 3].axis('off')
plt.colorbar(im1, ax=axes[0, 3])

# Window partition grid.
# NOTE(review): the grid is tiled in input pixels using config['window_size']
# directly. If the model's windows are defined over patch tokens, one window
# would cover window_size * swin_patch_size pixels — confirm against model.py.
window_size = config['window_size']
H, W = sample_input.shape[1:3]

window_grid = np.zeros((H, W))
for win_row, i in enumerate(range(0, H, window_size)):
    for win_col, j in enumerate(range(0, W, window_size)):
        # Colour index chosen so that horizontally, vertically and diagonally
        # adjacent windows never share a colour. A plain running counter
        # modulo 8 repeats the same colours in every row when a row holds
        # 8 windows, which hides the row boundaries. Slices past the array
        # edge are clipped automatically.
        window_grid[i:i + window_size, j:j + window_size] = (win_row * 3 + win_col) % 8

im2 = axes[1, 0].imshow(window_grid, cmap='Set3')
axes[1, 0].set_title(f'Window Partition ({window_size}x{window_size})')
axes[1, 0].axis('off')
plt.colorbar(im2, ax=axes[1, 0])

# Patch-embedding features (first channel), when extraction succeeded.
if patch_features is not None and patch_features.shape[1] > 0:
    # reshape (rather than view) because the channel slice is strided.
    feature_map = patch_features[0, :, 0].reshape(patch_h, patch_w).cpu().numpy()
    im3 = axes[1, 1].imshow(feature_map, cmap='coolwarm')
    axes[1, 1].set_title('Patch Features (Channel 0)')
    axes[1, 1].axis('off')
    plt.colorbar(im3, ax=axes[1, 1])
else:
    axes[1, 1].text(0.5, 0.5, 'Patch Features\nNot Available', 
                   ha='center', va='center', transform=axes[1, 1].transAxes)
    axes[1, 1].axis('off')

# Pixel accuracy of this sample.
sample_accuracy = (pred_img == sample_target.numpy()).mean()

# Distribution of predicted classes.
unique, counts = np.unique(pred_img, return_counts=True)
axes[1, 2].bar(unique, counts, color='lightblue', edgecolor='navy')
axes[1, 2].set_xlabel('Class')
axes[1, 2].set_ylabel('Pixel Count')
axes[1, 2].set_title('Predicted Class Distribution')
axes[1, 2].grid(True, alpha=0.3)

# Text summary panel.
axes[1, 3].text(0.1, 0.8, f'Sample Accuracy: {sample_accuracy:.3f}', transform=axes[1, 3].transAxes, fontsize=12)
axes[1, 3].text(0.1, 0.7, f'Mean Confidence: {conf_img.mean():.3f}', transform=axes[1, 3].transAxes, fontsize=12)
axes[1, 3].text(0.1, 0.6, f'Window Size: {window_size}x{window_size}', transform=axes[1, 3].transAxes, fontsize=12)
axes[1, 3].text(0.1, 0.5, f'Embed Dim: {config["embed_dim"]}', transform=axes[1, 3].transAxes, fontsize=12)
axes[1, 3].text(0.1, 0.4, f'Depths: {config["depths"]}', transform=axes[1, 3].transAxes, fontsize=12)
axes[1, 3].set_title('Model Configuration')
axes[1, 3].axis('off')

plt.tight_layout()
plt.savefig(f"{config['save_dir']}/window_analysis.png", dpi=300, bbox_inches='tight')
plt.show()

print("🔍 窗口注意力分析完成")

## 8. 模型推理

In [None]:
# Single-sample inference demo with timing and confidence statistics.
print("🔮 演示Swin Transformer模型推理...")

# Take the first sample of the test batch fetched by the prediction cell.
sample_x = test_batch_x[0:1].to(device)
sample_y = test_batch_y[0:1]

eval_model.eval()
with torch.no_grad():
    # CUDA kernels run asynchronously, so the device is synchronised before
    # starting and before stopping the clock; otherwise only the kernel launch
    # would be timed. perf_counter is used because it is a monotonic,
    # high-resolution clock.
    if sample_x.is_cuda:
        torch.cuda.synchronize()
    inference_start = time.perf_counter()
    outputs = eval_model(sample_x)
    if sample_x.is_cuda:
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - inference_start

    # Predicted class and its softmax probability per pixel.
    probs = torch.softmax(outputs, dim=-1)
    confidence, predictions = probs.max(-1)

    # Pixel accuracy against the ground truth of this sample.
    accuracy = (predictions.cpu() == sample_y).float().mean().item()

print(f"\n⚡ Swin Transformer推理性能:")
print(f"  推理时间: {inference_time*1000:.2f}ms")
print(f"  输入形状: {sample_x.shape}")
print(f"  输出形状: {outputs.shape}")
print(f"  预测准确率: {accuracy:.4f}")
print(f"  平均置信度: {confidence.mean():.4f}")
print(f"  最小置信度: {confidence.min():.4f}")
print(f"  最大置信度: {confidence.max():.4f}")
print(f"  窗口大小: {config['window_size']}x{config['window_size']}")
print(f"  层次结构: {config['depths']} layers")

In [None]:
# Figure for the single-sample inference result: input, ground truth,
# prediction, confidence map, error map and confidence histogram.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Swin Transformer Single Sample Analysis', fontsize=16)

# Input RGB composite (index 0 on the fourth axis, first three entries of the
# last axis — assumes a (batch, H, W, time, channel) layout, TODO confirm).
rgb_img = sample_x[0, :, :, 0, :3].cpu().numpy()
# Min-max scale for display; a constant image would divide by zero, so it is
# drawn as all zeros instead.
rgb_span = rgb_img.max() - rgb_img.min()
if rgb_span > 0:
    rgb_img = (rgb_img - rgb_img.min()) / rgb_span
else:
    rgb_img = np.zeros_like(rgb_img)
axes[0, 0].imshow(rgb_img)
axes[0, 0].set_title('Input RGB (t=0)', fontsize=14)
axes[0, 0].axis('off')

# Ground-truth labels.
gt_img = sample_y[0].numpy()
im1 = axes[0, 1].imshow(gt_img, cmap='tab10', vmin=0, vmax=7)
axes[0, 1].set_title('Ground Truth', fontsize=14)
axes[0, 1].axis('off')

# Predicted labels, drawn with the same colour range as the ground truth.
pred_img = predictions[0].cpu().numpy()
im2 = axes[0, 2].imshow(pred_img, cmap='tab10', vmin=0, vmax=7)
axes[0, 2].set_title(f'Swin Prediction (Acc: {accuracy:.3f})', fontsize=14)
axes[0, 2].axis('off')

# Confidence map.
conf_img = confidence[0].cpu().numpy()
im3 = axes[1, 0].imshow(conf_img, cmap='viridis')
axes[1, 0].set_title('Prediction Confidence', fontsize=14)
axes[1, 0].axis('off')
plt.colorbar(im3, ax=axes[1, 0], label='Confidence')

# Error map: 1.0 where the prediction differs from the ground truth.
error_map = (pred_img != gt_img).astype(float)
im4 = axes[1, 1].imshow(error_map, cmap='Reds')
axes[1, 1].set_title(f'Prediction Errors ({error_map.mean()*100:.1f}%)', fontsize=14)
axes[1, 1].axis('off')
plt.colorbar(im4, ax=axes[1, 1], label='Error')

# Histogram of per-pixel confidence with the mean marked.
axes[1, 2].hist(conf_img.flatten(), bins=30, color='skyblue', alpha=0.7, edgecolor='black')
axes[1, 2].axvline(conf_img.mean(), color='red', linestyle='--', 
                  label=f'Mean: {conf_img.mean():.3f}')
axes[1, 2].set_xlabel('Confidence')
axes[1, 2].set_ylabel('Frequency')
axes[1, 2].set_title('Confidence Distribution')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{config['save_dir']}/single_sample_analysis.png", dpi=300, bbox_inches='tight')
plt.show()

print("🎨 单样本Swin Transformer推理可视化完成")

## 总结

### 🎯 Swin Transformer训练结果总结

In [None]:
# Final summary of the run: mode, model configuration, training statistics,
# test metrics, saved files and next-step suggestions.
# (This cell was stored as JSON-escaped text with one f-string split across
# two physical lines; it is restored here as regular Python source.)
print("=" * 80)
print(f"🎯 SWIN TRANSFORMER 农作物制图模型训练总结 ({RUNNING_MODE}模式)")
print("=" * 80)

print(f"\n🔄 运行模式: {RUNNING_MODE}")
if RUNNING_MODE == "TEST":
    print("  ✓ 快速验证代码正确性")
    print("  ✓ 使用轻量级模型配置")
    print("  ✓ 少量训练轮数")
    print("  ✓ 适合本地开发环境")
else:
    print("  ✓ 完整训练流程")
    print("  ✓ 标准模型配置")
    print("  ✓ 充足训练轮数")
    print("  ✓ 适合云端GPU环境")

print(f"\n📊 模型配置:")
print(f"  模型类型: Swin Transformer")
print(f"  嵌入维度: {config['embed_dim']}")
print(f"  层数配置: {config['depths']}")
print(f"  注意力头数: {config['num_heads']}")
print(f"  窗口大小: {config['window_size']}")
print(f"  总参数量: {total_params:,}")

# Number of epochs actually run (early stopping may end training sooner).
epochs_run = len(history['train_loss'])

print(f"\n🏃 训练过程:")
print(f"  训练epochs: {epochs_run}")
print(f"  总训练时间: {total_training_time/60:.1f}分钟")
# max(..., 1) keeps the average defined even if no epoch was recorded.
print(f"  平均每epoch: {total_training_time/max(epochs_run, 1):.1f}秒")
print(f"  最佳验证mIoU: {best_miou:.4f}")
if len(history['learning_rates']) > 0:
    print(f"  最终学习率: {history['learning_rates'][-1]:.2e}")

print(f"\n📈 最终性能:")
print(f"  测试集总体准确率: {metrics['overall_accuracy']:.4f}")
print(f"  测试集平均IoU: {metrics['mean_iou']:.4f}")
print(f"  测试集平均精确率: {metrics['mean_precision']:.4f}")
print(f"  测试集平均召回率: {metrics['mean_recall']:.4f}")

print(f"\n🎯 Swin Transformer特色:")
print(f"  ✓ 层次化特征提取: 4个阶段，分辨率逐步降低")
print(f"  ✓ 移位窗口注意力: 建立跨窗口连接")
print(f"  ✓ 线性复杂度: 相对于输入大小呈线性关系")
print(f"  ✓ 多尺度融合: 通过Patch Merging实现特征融合")

# List everything currently present in the output directory.
print(f"\n💾 保存的文件:")
save_dir = Path(config['save_dir'])
saved_files = list(save_dir.glob('*'))
for file in saved_files:
    print(f"  {file.name}")

print(f"\n🚀 使用建议:")
if RUNNING_MODE == "TEST":
    print(f"  ✅ 代码测试完成，可以切换到TRAIN模式进行完整训练")
    print(f"  📝 修改第一个cell: RUNNING_MODE = 'TRAIN'")
    print(f"  🔄 重新运行notebook进行完整训练")
    print(f"  ☁️ 建议在云端GPU上运行TRAIN模式")
else:
    print(f"  1. 模型已保存到: {config['save_dir']}/best_model.pth")
    print(f"  2. 用于推理: python -m Swin-Transformer.inference --model-path {config['save_dir']}/best_model.pth")
    print(f"  3. 窗口分析: python -m Swin-Transformer.evaluate --analyze-windows")
    print(f"  4. 模型部署: 考虑导出为ONNX格式")

print(f"\n📊 与其他模型对比:")
print(f"  vs TCN: 更强的多尺度特征提取能力")
print(f"  vs Vision Transformer: 更高的计算效率和更好的归纳偏置")
print(f"  特点: 结合CNN的层次化设计和Transformer的全局建模能力")

print("\n" + "=" * 80)
if RUNNING_MODE == "TEST":
    print("🧪 Swin Transformer代码测试完成!")
    print("✅ 所有组件运行正常，可以进行完整训练!")
else:
    print("🎉 Swin Transformer农作物制图模型训练完成!")
    print("🌟 层次化窗口注意力，线性复杂度，SOTA性能!")
print("=" * 80)