In [None]:
# 注意力可视化
import sys
import os
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import yaml
import glob

from src.models.resnet_se import ResNetSE
from src.models.efficientnet_cbam import EfficientNetCBAM
from src.data.transforms import get_val_transforms
from src.utils.visualization import visualize_attention, create_attention_comparison

# 1. Paths: dataset root, checkpoint glob (the newest matching run is loaded
#    below) and the folder that receives every figure and the report.
DATA_DIR = "/kaggle/input/plantvillage-tomato/PlantVillage/Tomato"
MODEL_PATH = "/kaggle/working/outputs/cnn/CNN_*/best_model.pth"
OUTPUT_DIR = "/kaggle/working/outputs/attention"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 2. Compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# 3. Load the most recent checkpoint matched by MODEL_PATH.
model_files = glob.glob(MODEL_PATH)
if not model_files:
    # Every later section needs `model`, `class_names` and `model_type`; only
    # printing a message here would let the notebook run on into a NameError.
    raise FileNotFoundError(
        f"No model found matching {MODEL_PATH}! Please run training first."
    )

# The lexicographically greatest path is taken as the newest run, assuming the
# CNN_* run directories embed a sortable timestamp (NOTE(review): confirm).
latest_model = max(model_files)
print(f"Loading model: {latest_model}")

# NOTE(review): on torch >= 2.6 `weights_only` defaults to True; this only works
# if 'config' / 'class_names' hold plain Python containers and primitives.
checkpoint = torch.load(latest_model, map_location=device)
class_names = checkpoint['class_names']
config = checkpoint['config']

# Rebuild the architecture recorded in the training config. pretrained=False
# because the weights are restored from the checkpoint right after.
model_type = config['model']['model_name']
if model_type == 'resnet50_se':
    model = ResNetSE(num_classes=len(class_names), pretrained=False)
elif model_type == 'efficientnet_cbam':
    model = EfficientNetCBAM(num_classes=len(class_names), pretrained=False)
else:
    raise ValueError(f"Unknown model type: {model_type}")

model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference mode for all sections below

print(f"Model loaded: {model_type}")
print(f"Classes: {class_names}")

# 4. Preprocessing used for every image below (from get_val_transforms).
transform = get_val_transforms()

# 5. Pick one sample image for each of the first five classes.
print("\nSelecting sample images...")
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
sample_images = []

for class_name in class_names[:5]:
    class_dir = os.path.join(DATA_DIR, class_name)
    if not os.path.isdir(class_dir):
        # Report instead of silently skipping, so a wrong DATA_DIR is visible.
        print(f"  {class_name}: directory not found, skipped")
        continue
    # os.listdir returns entries in arbitrary order; sort so the same sample is
    # chosen on every run / filesystem.
    images = sorted(
        f for f in os.listdir(class_dir) if f.lower().endswith(IMAGE_EXTENSIONS)
    )
    if images:
        img_path = os.path.join(class_dir, images[0])
        sample_images.append(img_path)
        print(f"  {class_name}: {os.path.basename(img_path)}")

# 6. Attention visualization for individual images (only the first three
#    samples are rendered, although the header counts all of them).
print(f"\nVisualizing attention for {len(sample_images)} images...")
for idx, sample_path in enumerate(sample_images[:3], start=1):
    print(f"\nProcessing image {idx}: {os.path.basename(sample_path)}")

    fig = visualize_attention(sample_path, model, transform, device, figsize=(15, 5))

    # Save the current figure before showing it.
    save_path = os.path.join(OUTPUT_DIR, f'attention_{idx}.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"  Saved to: {save_path}")

# 7. Attention-mechanism comparison (needs checkpoints of at least two
#    different architectures).
print("\nComparing attention mechanisms...")

# Checkpoint config model_name -> (display label, model class).
COMPARISON_REGISTRY = {
    'resnet50_se': ('ResNet-SE', ResNetSE),
    'efficientnet_cbam': ('EfficientNet-CBAM', EfficientNetCBAM),
}

# A state dict only fits the architecture it was trained with, so each
# architecture is restored from its own newest checkpoint. Loading the single
# `checkpoint` into both classes always failed for one of them, which meant
# the comparison below could never run.
loaded_by_type = {}
for ckpt_path in sorted(model_files, reverse=True):  # newest first
    if len(loaded_by_type) == len(COMPARISON_REGISTRY):
        break  # newest checkpoint of every architecture already loaded
    if ckpt_path == latest_model:
        ckpt = checkpoint  # already in memory
    else:
        # Load onto the CPU; only the kept model is moved to `device`.
        ckpt = torch.load(ckpt_path, map_location='cpu')
    ckpt_type = ckpt['config']['model']['model_name']
    if ckpt_type not in COMPARISON_REGISTRY or ckpt_type in loaded_by_type:
        continue

    label, model_cls = COMPARISON_REGISTRY[ckpt_type]
    candidate = model_cls(num_classes=len(ckpt['class_names']), pretrained=False)
    try:
        candidate.load_state_dict(ckpt['model_state_dict'])
    except RuntimeError as err:  # missing/unexpected keys or shape mismatch
        print(f"Warning: Could not load {label} for comparison: {err}")
        continue
    candidate.to(device)
    candidate.eval()
    loaded_by_type[ckpt_type] = candidate

# Fixed display order (ResNet-SE first), independent of checkpoint file order.
models = []
model_names = []
for ckpt_type, (label, _) in COMPARISON_REGISTRY.items():
    if ckpt_type in loaded_by_type:
        models.append(loaded_by_type[ckpt_type])
        model_names.append(label)

if len(models) > 1:
    print(f"\nComparing {len(models)} attention mechanisms...")

    # A couple of sample images is enough for a side-by-side figure.
    comparison_images = sample_images[:2]

    fig = create_attention_comparison(
        comparison_images,
        models,
        model_names,
        transform,
        device,
        figsize=(20, 10)
    )

    output_path = os.path.join(OUTPUT_DIR, 'attention_comparison.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Comparison saved to: {output_path}")
else:
    print("Skipping comparison: checkpoints for at least two attention models are required.")

# 8. Attention heatmap analysis
print("\nAnalyzing attention heatmaps...")
def analyze_attention_patterns(model, image_paths, transform, device, num_samples=5):
    """Group the first ``num_samples`` images by attention pattern (placeholder).

    Each selected image is rendered through ``visualize_attention`` and the
    resulting figure is closed immediately. No classification logic exists yet,
    so every category list in the returned dict stays empty.

    Args:
        model: model forwarded to ``visualize_attention``.
        image_paths: sliceable sequence of image file paths.
        transform: preprocessing forwarded to ``visualize_attention``.
        device: device forwarded to ``visualize_attention``.
        num_samples: maximum number of leading paths to process.

    Returns:
        A new dict with keys 'focused' (concentrated attention), 'diffuse'
        (spread-out attention), 'edge' (attends to edges) and 'center'
        (attends to the center), each mapped to an empty list.
    """
    category_names = ('focused', 'diffuse', 'edge', 'center')
    patterns = {name: [] for name in category_names}

    for path in image_paths[:num_samples]:
        attention_fig = visualize_attention(path, model, transform, device, figsize=(15, 5))
        plt.close(attention_fig)  # rendered only; never displayed
        # TODO: compute statistics of the attention map (e.g. entropy,
        # centrality) and append `path` to the matching category.

    return patterns

# 9. Disease-region detection: print the top prediction and plot the class
#    probability distribution for the first two samples.
print("\nAnalyzing disease region detection...")
for i, img_path in enumerate(sample_images[:2]):
    # Same preprocessing as above, as a batch of one image.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    batch = transform(img).unsqueeze(0).to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
    pred_class = int(probs.argmax(dim=1))
    confidence = probs[0, pred_class].item()
    image_name = os.path.basename(img_path)

    print(f"\nImage: {image_name}")
    print(f"Predicted class: {class_names[pred_class]}")
    print(f"Confidence: {confidence:.3f}")

    # One bar per class, labelled with the class names.
    class_positions = range(len(class_names))
    plt.figure(figsize=(10, 4))
    plt.bar(class_positions, probs[0].cpu().numpy())
    plt.xlabel('Class')
    plt.ylabel('Probability')
    plt.title(f'Prediction Probabilities - {image_name}')
    plt.xticks(class_positions, class_names, rotation=45, ha='right')
    plt.tight_layout()

    output_path = os.path.join(OUTPUT_DIR, f'probabilities_{i+1}.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"  Probabilities saved to: {output_path}")

# 10. Write a plain-text attention analysis report to OUTPUT_DIR.
print("\nGenerating attention analysis report...")
report_content = f"""
Attention Mechanism Analysis Report
{'='*60}

Model Information:
- Model Type: {model_type}
- Number of Classes: {len(class_names)}
- Device: {device}

Sample Analysis:
{'-'*40}

"""
for i, img_path in enumerate(sample_images[:3]):
    # Samples were drawn from DATA_DIR/<class_name>/, so the parent folder is the
    # sample's true class. Indexing class_names by position mislabels samples
    # whenever a class directory was skipped during sample selection.
    true_class = os.path.basename(os.path.dirname(img_path)) or 'Unknown'
    report_content += f"Sample {i+1}:\n"
    report_content += f"  Path: {os.path.basename(img_path)}\n"
    report_content += f"  Class: {true_class}\n\n"

# NOTE(review): the observations/recommendations below are fixed text; they are
# not derived from any computation in this notebook.
report_content += f"""
Observations:
{'-'*40}
1. SE注意力机制能够有效聚焦于病害区域
2. CBAM注意力同时关注通道和空间信息
3. 早疫病和晚疫病的注意力模式有显著差异
4. 模型能够区分健康叶片和病害叶片

Recommendations:
{'-'*40}
1. 对于细粒度分类，建议使用混合注意力机制
2. 可以尝试在更深层网络中添加注意力模块
3. 考虑使用可解释性方法验证注意力区域
"""

# Save the report. The text contains non-ASCII (Chinese) characters, so the
# encoding must be explicit: the platform default may not be able to encode it.
report_path = os.path.join(OUTPUT_DIR, 'attention_analysis_report.txt')
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(report_content)

print("Attention analysis completed!")
print(f"All outputs saved to: {OUTPUT_DIR}")
print(f"Report saved to: {report_path}")