In [None]:
# CNN模型训练
import os
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import yaml

# Make the parent directory importable before the local `src` imports below.
sys.path.append('..')

from src.data.dataset import create_dataloaders
from src.models.resnet_se import ResNetSE
from src.models.efficientnet_cbam import EfficientNetCBAM
from src.training.trainer import Trainer
from src.training.metrics import MetricsCalculator, FineGrainedMetrics
from src.utils.visualization import plot_training_history
from src.utils.logger import ExperimentTracker

# 1. Settings: dataset location, config file and output directory.
DATA_DIR = "/kaggle/input/plantvillage-tomato/PlantVillage/Tomato"
CONFIG_PATH = "../configs/cnn.yaml"
OUTPUT_DIR = "/kaggle/working/outputs/cnn"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 2. Pick the compute device and, when a GPU is present, report its details.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    # `total_memory` is the device's total capacity, not the amount currently
    # free, so the label says "Total" rather than "Available".
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Total memory: {total_gb:.2f} GB")

# 3. Load the YAML configuration.
# The encoding is given explicitly: without it open() uses the platform's
# locale encoding, which can fail on non-ASCII text in the config file.
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# safe_load returns None for an empty document and may return a list or a
# scalar for a malformed one; every later step indexes `config` by key, so
# fail here with a clear message instead of a TypeError further down.
if not isinstance(config, dict):
    raise ValueError(
        f"Config file {CONFIG_PATH} must contain a YAML mapping, "
        f"got {type(config).__name__}"
    )

# 4. Initialise the experiment tracker in a timestamped sub-directory of
# OUTPUT_DIR and record the configuration used for this run.
experiment_name = f"CNN_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
tracker = ExperimentTracker(os.path.join(OUTPUT_DIR, experiment_name))
tracker.log_config(config)

# 5. Build the train / validation / test data loaders.
# Batch size and worker count come from the `data` section of the config.
print("Creating data loaders...")
data_cfg = config['data']
loaders = create_dataloaders(
    DATA_DIR,
    batch_size=data_cfg['batch_size'],
    num_workers=data_cfg['num_workers'],
)
train_loader, val_loader, test_loader, class_names = loaders

print(f"Classes: {class_names}")
print(f"Number of classes: {len(class_names)}")

# 6. Build the model named by `config['model']['model_name']`.
print("\nCreating model...")
model_type = config['model']['model_name']

# Supported model names paired with the class that implements each one.
supported_models = (
    ('resnet50_se', ResNetSE),
    ('efficientnet_cbam', EfficientNetCBAM),
)
model_cls = next(
    (cls for name, cls in supported_models if model_type == name),
    None,
)
if model_cls is None:
    raise ValueError(f"Unknown model type: {model_type}")

# Both model classes are constructed with the same two keyword arguments.
model = model_cls(
    num_classes=len(class_names),
    pretrained=config['model']['pretrained'],
)

print(f"Model created: {model_type}")
tracker.log_message(f"Model type: {model_type}")

# 7. Create the trainer from the model, the three loaders, the device,
# the config and the class names.
print("\nCreating trainer...")
trainer_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    config=config,
    class_names=class_names,
)
trainer = Trainer(**trainer_kwargs)

# 8. Train for the configured number of epochs and keep the returned history.
print("\nStarting training...")
epochs = config['training']['epochs']
history = trainer.train(epochs)

# 9. Plot the training history, save the image under the tracker's log
# directory and register it as an artifact.
print("\nPlotting training history...")
history_path = os.path.join(tracker.log_dir, 'training_history.png')
fig_history = plot_training_history(history, figsize=(15, 10))
# savefig acts on pyplot's current figure, presumably the one that
# plot_training_history just created. Confirm against that helper.
history_save_opts = dict(dpi=300, bbox_inches='tight')
plt.savefig(history_path, **history_save_opts)
plt.show()
tracker.log_artifact('figure', history_path, 'Training history plot')

# 10. Run the trainer's test pass.
print("\nTesting model...")
test_results = trainer.test()

# 11. Compute detailed metrics from the true labels, the predictions and
# the probabilities held in `test_results`.
print("\nComputing detailed metrics...")
metrics_calc = MetricsCalculator(len(class_names), class_names)
metric_inputs = (
    test_results['true_labels'],
    test_results['predictions'],
    test_results['probabilities'],
)
detailed_metrics = metrics_calc.compute_metrics(*metric_inputs)

# Record the metrics with the tracker under the "test" label.
tracker.log_metrics(detailed_metrics, "test")

# 12. Fine-grained analysis.
print("\nPerforming fine-grained analysis...")
fg_metrics = FineGrainedMetrics()

# Confusion between similarly named classes: collect every index pair
# (i, j) with i < j where one class name contains "early" and the other
# contains "late", ignoring case (e.g. early blight vs late blight).
lowered_names = [name.lower() for name in class_names]
similar_classes = [
    (i, j)
    for i, low_i in enumerate(lowered_names)
    for j, low_j in enumerate(lowered_names)
    if i < j and (
        ('early' in low_i and 'late' in low_j)
        or ('late' in low_i and 'early' in low_j)
    )
]

# Only run the analysis when at least one such pair exists.
if similar_classes:
    confusion_results = fg_metrics.compute_similarity_confusion(
        detailed_metrics['confusion_matrix'], similar_classes
    )
    print("\nSimilar class confusion analysis:")
    for pair, rate in confusion_results.items():
        print(f"  {pair}: {rate:.4f}")

    tracker.log_metrics(confusion_results, "fine_grained")

# Hard-sample analysis over the test results (top_k=5).
hard_sample_inputs = (
    test_results['predictions'],
    test_results['probabilities'],
    test_results['true_labels'],
    test_results['image_paths'],
)
hard_samples = fg_metrics.analyze_hard_samples(*hard_sample_inputs, top_k=5)

# NOTE(review): nothing visible here writes `hard_samples` anywhere, so the
# message below only holds if analyze_hard_samples persists its own output.
# Confirm against that method.
print("\nHard sample analysis saved")

# 13. Result visualisations. Each figure is saved under the tracker's log
# directory, shown, and registered as an artifact.
print("\nGenerating visualizations...")
figure_save_opts = dict(dpi=300, bbox_inches='tight')

# Confusion matrix.
cm_path = os.path.join(tracker.log_dir, 'confusion_matrix.png')
fig_cm = metrics_calc.plot_confusion_matrix(
    detailed_metrics['confusion_matrix'],
    title=f'{model_type} - Confusion Matrix',
    figsize=(12, 10),
)
plt.savefig(cm_path, **figure_save_opts)
plt.show()
tracker.log_artifact('figure', cm_path, 'Confusion matrix')

# ROC curves.
roc_path = os.path.join(tracker.log_dir, 'roc_curves.png')
fig_roc = metrics_calc.plot_roc_curves(
    test_results['true_labels'],
    test_results['probabilities'],
    figsize=(12, 10),
)
plt.savefig(roc_path, **figure_save_opts)
plt.show()
tracker.log_artifact('figure', roc_path, 'ROC curves')

# Prediction-sample visualisation.
# `visualize_predictions` is not imported anywhere in the visible part of
# this file, so the call below raised NameError after training had finished
# and before the checkpoint was written.
# NOTE(review): the import location is an assumption (the same module that
# provides plot_training_history). Confirm where the helper is defined.
from src.utils.visualization import visualize_predictions

fig_pred = visualize_predictions(
    test_results,
    class_names,
    num_samples=12,
    figsize=(20, 15)
)
pred_path = os.path.join(tracker.log_dir, 'predictions.png')
plt.savefig(pred_path, dpi=300, bbox_inches='tight')
plt.show()
tracker.log_artifact('figure', pred_path, 'Prediction samples')

# 14. Save a checkpoint holding the model weights, the class names, the
# config and the test accuracy, then register it as an artifact.
best_model_path = os.path.join(tracker.log_dir, 'best_model.pth')
checkpoint = {
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    'config': config,
    'test_accuracy': detailed_metrics['accuracy'],
}
# NOTE(review): this stores whatever weights `model` holds at this point.
# Whether those are the best-epoch weights depends on Trainer. Confirm.
torch.save(checkpoint, best_model_path)

tracker.log_artifact('model', best_model_path, 'Best model checkpoint')

# 15. Save the experiment record and print a short summary of the run.
experiment_file = tracker.save_experiment()
summary_lines = (
    "\nExperiment completed!",
    f"Test Accuracy: {detailed_metrics['accuracy']:.4f}",
    f"Results saved to: {tracker.log_dir}",
    f"Experiment file: {experiment_file}",
)
print(*summary_lines, sep="\n")

# 16. Model comparison table (one row per model; this script adds only the
# model trained above). The table is built with pandas (`pd`), which the
# original script used here without importing, so this step raised NameError.
print("\nGenerating model comparison...")
comparison_data = {
    'Model': [model_type],
    'Accuracy': [detailed_metrics['accuracy']],
    'Precision': [detailed_metrics['precision']],
    'Recall': [detailed_metrics['recall']],
    'F1-Score': [detailed_metrics['f1']],
}

# ROC-AUC is treated as optional: the column is added only when the metrics
# dictionary contains it.
if 'roc_auc' in detailed_metrics:
    comparison_data['ROC-AUC'] = [detailed_metrics['roc_auc']]

df_comparison = pd.DataFrame(comparison_data)
print("\nModel Performance:")
print(df_comparison.to_string(index=False))

# Save the table as CSV under the tracker's log directory and register it.
comparison_path = os.path.join(tracker.log_dir, 'model_comparison.csv')
df_comparison.to_csv(comparison_path, index=False)
tracker.log_artifact('data', comparison_path, 'Model comparison CSV')