In [None]:
# notebooks/local_test.ipynb - 完整的本地测试代码

# %% [markdown]
# # 胸部X光疾病分类 - 本地完整测试
# 
# 这个笔记本用于在本地全面测试所有代码模块

# %%
# 安装必要依赖（如果还没安装）
# !pip install torch torchvision numpy pandas matplotlib scikit-learn

# %%
import sys
import os
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Add the project root (the parent of the current working directory) to the
# Python path so the project packages can be imported from this notebook.
project_root = os.path.abspath('..')
# Guard the append: re-running this cell must not stack duplicate entries
# onto sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

print(f"项目根目录: {project_root}")

# %%
# 1. Configuration loading test
# Reads config/config.yaml under the project root into the global `config`
# dict, which the later cells index into.
print("="*50)
print("1. 测试配置加载")

try:
    config_path = os.path.join(project_root, 'config/config.yaml')
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which may not be able to decode a UTF-8 config file.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    
    # Use the local-machine paths as the active 'paths' section
    config['paths'] = config['paths_local']
    
    print("✓ 配置加载成功")
    print(f"  配置内容:")
    print(f"  - 图像大小: {config['data']['image_size']}")
    print(f"  - 批量大小: {config['training']['batch_size']}")
    print(f"  - 模型: {config['model']['backbone']}")
    
except Exception as e:
    # Best-effort cell: report the failure (with traceback, like the other
    # cells) instead of aborting the notebook run.
    print(f"✗ 配置加载失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 2. Data loading test
# Checks that the label CSV and the image directory named in the config
# exist, then loads the CSV and prints a short summary of it.
print("\n" + "="*50)
print("2. 测试数据加载")

try:
    # Imported only to verify that the preprocessing module is importable;
    # the function itself is not called in this cell.
    from data.preprocess import load_and_preprocess_data
    
    # Check whether the data files exist
    csv_path = config['paths']['csv_path']
    images_dir = config['paths']['images_dir']
    
    if os.path.exists(csv_path):
        print(f"✓ 标签文件存在: {csv_path}")
    else:
        print(f"✗ 标签文件不存在: {csv_path}")
        # Help locate the file: list the CSVs next to the expected path.
        # dirname() is '' for a bare file name, and the directory itself may
        # be missing, so guard both cases instead of letting listdir() raise.
        csv_dir = os.path.dirname(csv_path) or '.'
        if os.path.isdir(csv_dir):
            csv_files = [f for f in os.listdir(csv_dir) if f.endswith('.csv')]
        else:
            csv_files = []
        print(f"  找到的CSV文件: {csv_files}")
    
    if os.path.exists(images_dir):
        print(f"✓ 图像目录存在: {images_dir}")
        image_files = [f for f in os.listdir(images_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
        # Report the real total; only the preview list is limited to 5 names.
        print(f"  找到图像数量: {len(image_files)} (显示前5个: {image_files[:5]})")
    else:
        print(f"✗ 图像目录不存在: {images_dir}")
    
    # Try loading the label table
    if os.path.exists(csv_path):
        df = pd.read_csv(csv_path)
        print(f"\n✓ 成功加载CSV文件")
        print(f"  数据形状: {df.shape}")
        print(f"  列名: {list(df.columns)}")
        
        # Show the first few rows
        print(f"\n  数据预览:")
        print(df.head())
        
        # Summarise the columns after the first one
        # (presumably the first column is the image identifier - confirm
        # against the CSV schema)
        if len(df.columns) > 1:
            print(f"\n  标签分布:")
            for col in df.columns[1:5]:  # only the first few columns
                # A mean/sum only makes sense for numeric columns; skip text
                # columns instead of aborting the whole summary.
                if not pd.api.types.is_numeric_dtype(df[col]):
                    continue
                pos_rate = df[col].mean()
                print(f"  {col}: {pos_rate:.3%} ({df[col].sum()} 正样本)")
    
except Exception as e:
    print(f"✗ 数据加载失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 3. Dataset class test
# Builds a ChestXRayDataset from the first rows of `df`, fetches one sample
# and plots the image next to its label vector.
print("\n" + "="*50)
print("3. 测试数据集类")

try:
    from data.dataset import ChestXRayDataset
    # One import is enough; the alias keeps the name used below.
    dataset_class = ChestXRayDataset
    
    # Build a small sample DataFrame for the test. At notebook top level
    # locals() is the global namespace, so this checks whether an earlier
    # cell managed to define `df`.
    if 'df' in locals() and len(df) > 0:
        test_df = df.head(5).copy()
        
        # Create the dataset instance
        test_dataset = dataset_class(
            test_df,
            config['paths']['images_dir'],
            transform=dataset_class.get_transforms(config, 'train'),
            phase='train'
        )
        
        print(f"✓ 数据集类创建成功")
        print(f"  数据集大小: {len(test_dataset)}")
        
        # Fetch a single sample
        try:
            image, labels = test_dataset[0]
            print(f"✓ 成功获取样本")
            print(f"  图像形状: {image.shape}")
            print(f"  标签形状: {labels.shape}")
            print(f"  标签值: {labels.numpy()}")
            
            # Visualise the sample
            plt.figure(figsize=(10, 4))
            
            # Left panel: the image
            plt.subplot(1, 2, 1)
            # Undo the normalisation so the image is displayable.
            # NOTE(review): assumes `image` is channel-first (C, H, W) and was
            # normalised with config['data']['mean'/'std'] - confirm.
            mean = np.array(config['data']['mean'])
            std = np.array(config['data']['std'])
            img_np = image.numpy().transpose(1, 2, 0)
            img_np = std * img_np + mean
            img_np = np.clip(img_np, 0, 1)
            plt.imshow(img_np)
            plt.title(f"图像示例")
            plt.axis('off')
            
            # Right panel: the label vector as a horizontal bar chart
            plt.subplot(1, 2, 2)
            labels_np = labels.numpy()
            classes = test_dataset.class_names[:len(labels_np)]
            # Red for values above 0.5, blue otherwise
            colors = ['red' if l > 0.5 else 'blue' for l in labels_np]
            plt.barh(classes, labels_np, color=colors)
            plt.xlabel('概率')
            plt.title('标签')
            plt.xlim([0, 1])
            plt.tight_layout()
            
            plt.savefig('../test_sample.png', dpi=100, bbox_inches='tight')
            plt.show()
            
        except Exception as e:
            print(f"✗ 获取样本失败: {e}")
            # Print the traceback like every other handler in this notebook,
            # so the failing line is visible.
            import traceback
            traceback.print_exc()
            # Diagnostic hint: show the path of the first image. Guarded so a
            # CSV without an 'Image Index' column cannot raise here and hide
            # the original error.
            if 'Image Index' in test_df.columns:
                print(f"  检查图像路径: {os.path.join(config['paths']['images_dir'], test_df.iloc[0]['Image Index'])}")
            
    else:
        print("✗ 没有可用的数据用于测试")
        
except Exception as e:
    print(f"✗ 数据集类测试失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 4. Model creation test
# Instantiates DenseNet121MultiLabel, runs one forward pass on random input
# and prints parameter counts.
print("\n" + "="*50)
print("4. 测试模型创建")

try:
    import torch
    from models.model import DenseNet121MultiLabel
    
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"  使用设备: {device}")
    
    # Create the model
    model = DenseNet121MultiLabel(
        num_classes=config['model']['num_classes'],
        pretrained=False  # skip the pretrained-weight download to keep the test fast
    )
    # Actually run on the device reported above; previously `device` was
    # printed but the model and input stayed on the CPU.
    model = model.to(device)
    
    print(f"✓ 模型创建成功")
    
    # Forward-pass smoke test
    batch_size = 2
    # NOTE(review): 512x512 is hard-coded here; presumably it should follow
    # config['data']['image_size'] - confirm.
    test_input = torch.randn(batch_size, 3, 512, 512).to(device)
    
    with torch.no_grad():
        model.eval()
        output = model(test_input)
        
        print(f"✓ 前向传播测试成功")
        print(f"  输入形状: {test_input.shape}")
        print(f"  输出形状: {output.shape}")
        print(f"  输出范围: {output.min().item():.4f} 到 {output.max().item():.4f}")
        # .cpu() is required before .numpy() when the tensor lives on a GPU
        print(f"  输出示例: {output[0, :5].cpu().numpy()}")  # first 5 outputs
        
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"\n  模型参数统计:")
    print(f"  总参数量: {total_params:,}")
    print(f"  可训练参数量: {trainable_params:,}")
    print(f"  不可训练参数量: {total_params - trainable_params:,}")
    
except Exception as e:
    print(f"✗ 模型测试失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 5. DataLoader test
# Wraps the `test_dataset` created by the dataset cell in a DataLoader and
# inspects the first batch.
print("\n" + "="*50)
print("5. 测试数据加载器")

try:
    from torch.utils.data import DataLoader
    
    # At notebook top level locals() is the global namespace, so this checks
    # whether the dataset cell succeeded.
    if 'test_dataset' in locals():
        # Create the data loader
        dataloader = DataLoader(
            test_dataset,
            batch_size=min(2, len(test_dataset)),  # small batch
            shuffle=True,
            num_workers=0  # load in the main process to avoid worker issues while testing
        )
        
        print(f"✓ 数据加载器创建成功")
        
        # Inspect one batch
        for batch_idx, (images, labels) in enumerate(dataloader):
            print(f"  批次 {batch_idx}:")
            print(f"    图像形状: {images.shape}")
            print(f"    标签形状: {labels.shape}")
            
            # Per-class count of positive labels in this batch
            print(f"    标签统计 - 每个类别的正样本数:")
            for i in range(labels.shape[1]):
                pos_count = labels[:, i].sum().item()
                if pos_count > 0:
                    class_name = test_dataset.class_names[i] if i < len(test_dataset.class_names) else f"Class_{i}"
                    print(f"      {class_name}: {pos_count}/{labels.shape[0]}")
            
            # Only the first batch is tested
            break
        
        print(f"\n✓ 数据加载器测试通过")
    else:
        # Without this branch the cell finished silently when the dataset
        # cell had failed, which looked like a pass.
        print("✗ 没有可用的数据集，跳过数据加载器测试")
        
except Exception as e:
    print(f"✗ 数据加载器测试失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 6. Loss function and optimizer test
# Runs a single optimisation step on random data, then (if present) tries
# the project's custom loss functions on the same output.
print("\n" + "="*50)
print("6. 测试训练组件")

try:
    # Import everything this cell uses, so it does not depend on names that
    # another cell happened to import inside its own try block.
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from models.model import DenseNet121MultiLabel
    
    # Small model and random data
    test_model = DenseNet121MultiLabel(num_classes=5, pretrained=False)
    test_input = torch.randn(2, 3, 224, 224)
    test_target = torch.randint(0, 2, (2, 5)).float()
    
    # Loss function and optimizer.
    # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably the
    # model ends in a sigmoid - confirm against models/model.py.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(test_model.parameters(), lr=0.001)
    
    # One training step
    optimizer.zero_grad()
    output = test_model(test_input)
    loss = criterion(output, test_target)
    loss.backward()
    optimizer.step()
    
    print(f"✓ 训练步骤测试成功")
    print(f"  损失值: {loss.item():.6f}")
    
    # Custom loss functions (optional). Only the import is guarded by the
    # ImportError handler, so an ImportError raised while *using* the losses
    # is reported as a failure rather than as "not found".
    try:
        from training.losses import WeightedBCELoss, FocalLoss
    except ImportError:
        print(f"⚠ 自定义损失函数未找到，跳过测试")
    else:
        # Weighted BCE loss
        pos_weight = torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0])
        weighted_criterion = WeightedBCELoss(pos_weight=pos_weight)
        weighted_loss = weighted_criterion(output, test_target)
        print(f"✓ 加权BCE损失测试成功: {weighted_loss.item():.6f}")
        
        # Focal loss
        focal_criterion = FocalLoss(alpha=0.25, gamma=2.0)
        focal_loss = focal_criterion(output, test_target)
        print(f"✓ Focal Loss测试成功: {focal_loss.item():.6f}")
    
except Exception as e:
    print(f"✗ 训练组件测试失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 7. Evaluation metrics test
# Computes per-class AUC and F1 on seeded random labels/predictions, then
# tries the project's own metric helper if it can be imported.
print("\n" + "="*50)
print("7. 测试评估指标")

try:
    from sklearn.metrics import roc_auc_score, f1_score
    
    # Seeded random labels and prediction scores
    np.random.seed(42)
    n_samples = 100
    n_classes = 5
    
    y_true = np.random.randint(0, 2, (n_samples, n_classes))
    y_pred = np.random.rand(n_samples, n_classes)
    
    # AUC is only computed for classes that contain both label values
    auc_scores = [
        roc_auc_score(y_true[:, k], y_pred[:, k])
        for k in range(n_classes)
        if len(np.unique(y_true[:, k])) > 1
    ]
    # F1 uses predictions binarised at 0.5
    f1_scores = [
        f1_score(y_true[:, k], (y_pred[:, k] > 0.5).astype(int))
        for k in range(n_classes)
    ]
    
    mean_auc = np.mean(auc_scores)
    mean_f1 = np.mean(f1_scores)
    print(f"✓ 评估指标计算成功")
    print(f"  AUC分数: {mean_auc:.4f}")
    print(f"  F1分数: {mean_f1:.4f}")
    
    # Project metric helper (optional)
    try:
        from training.metrics import calculate_metrics
        metrics = calculate_metrics(y_true, y_pred, threshold=0.5)
        print(f"✓ 项目指标函数测试成功")
        print(f"  平均AUC: {metrics['auc_mean']:.4f}")
        print(f"  平均F1: {metrics['f1_mean']:.4f}")
    except ImportError:
        print(f"⚠ 项目指标函数未找到，跳过测试")
    
except Exception as e:
    print(f"✗ 评估指标测试失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 8. Full training pipeline test (minimal version)
# Runs exactly one optimisation step of a tiny CNN on random tensors to
# check that data loading, forward, loss, backward and step fit together.
print("\n" + "="*50)
print("8. 测试完整训练流程")

try:
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    
    # Synthetic data: 10 random 224x224 RGB images with 5 binary labels each
    n_samples = 10
    n_classes = 5
    
    X = torch.randn(n_samples, 3, 224, 224)
    y = torch.randint(0, 2, (n_samples, n_classes)).float()
    
    # Dataset and data loader
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    
    class SimpleModel(torch.nn.Module):
        """Tiny CNN: conv -> global average pool -> linear -> sigmoid.

        Args:
            num_classes: Number of output units. Defaults to the value of
                ``n_classes`` at class-definition time, so ``SimpleModel()``
                behaves as before without reading a global at call time.
        """

        def __init__(self, num_classes=n_classes):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
            self.pool = torch.nn.AdaptiveAvgPool2d(1)
            self.fc = torch.nn.Linear(16, num_classes)
            self.sigmoid = torch.nn.Sigmoid()
        
        def forward(self, x):
            """Return per-class sigmoid outputs for a batch of 3-channel images."""
            x = self.pool(self.conv(x))
            # Pooling leaves (N, 16, 1, 1); flatten everything after the batch axis
            x = torch.flatten(x, 1)
            return self.sigmoid(self.fc(x))
    
    model = SimpleModel()
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Train on a single batch (not a full epoch): this is only a smoke test
    model.train()
    for batch_idx, (data, target) in enumerate(dataloader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        print(f"  批次 {batch_idx}: 损失 = {loss.item():.6f}")
        
        # Stop after the first batch
        break
    
    print(f"✓ 最小训练流程测试成功")
    
except Exception as e:
    print(f"✗ 训练流程测试失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 9. Visualization function test
# Feeds a hand-written training history to the project's plotting helper
# (if importable) and previews the image file it is expected to write.
print("\n" + "="*50)
print("9. 测试可视化函数")

try:
    # Seed NumPy's global RNG
    np.random.seed(42)
    
    # Mock training history: ten values per metric
    history = dict(
        train_loss=[1.0, 0.8, 0.6, 0.5, 0.45, 0.4, 0.38, 0.36, 0.35, 0.34],
        val_loss=[1.1, 0.9, 0.7, 0.6, 0.55, 0.5, 0.48, 0.47, 0.46, 0.45],
        train_auc=[0.5, 0.6, 0.7, 0.75, 0.78, 0.8, 0.82, 0.83, 0.84, 0.85],
        val_auc=[0.5, 0.58, 0.65, 0.7, 0.73, 0.75, 0.77, 0.78, 0.79, 0.8],
        train_f1=[0.4, 0.5, 0.55, 0.6, 0.63, 0.65, 0.67, 0.68, 0.69, 0.7],
        val_f1=[0.38, 0.45, 0.5, 0.55, 0.58, 0.6, 0.62, 0.63, 0.64, 0.65],
    )
    
    # Exercise the project's plotting helper (optional)
    try:
        from utils.visualization import plot_training_history
        
        # Make sure the output directory exists
        output_dir = "../test_outputs"
        os.makedirs(output_dir, exist_ok=True)
        
        # Plot the training history
        plot_training_history(history, output_dir)
        print(f"✓ 训练历史可视化成功")
        print(f"  图像保存到: {output_dir}")
        
        # Preview the file if it exists at the expected name
        history_path = os.path.join(output_dir, 'training_history.png')
        if os.path.exists(history_path):
            print(f"✓ 训练历史图像文件已生成")
            
            # Show the saved image without axes
            img = plt.imread(history_path)
            preview_fig, preview_ax = plt.subplots(figsize=(10, 8))
            preview_ax.imshow(img)
            preview_ax.axis('off')
            preview_ax.set_title('训练历史图像预览')
            plt.show()
        
    except ImportError:
        print(f"⚠ 可视化函数未找到，跳过测试")
    
except Exception as e:
    print(f"✗ 可视化测试失败: {e}")
    import traceback
    traceback.print_exc()

# %%
# 10. Summary of the test run
print("\n" + "="*50)
print("测试完成总结")
print("="*50)

print("\n✓ 本地测试已完成！")
print("\n下一步建议:")
print("1. 如果所有测试都通过，可以开始完整训练")
print("2. 如果有测试失败，请根据错误信息修复代码")
print("3. 确保数据路径配置正确")
print("4. 确保所有依赖包已安装")

# Dependency check
import importlib

print("\n依赖包检查:")
required_packages = ['torch', 'torchvision', 'numpy', 'pandas', 
                     'matplotlib', 'scikit-learn', 'scikit-image',
                     'Pillow', 'pyyaml', 'tqdm']

# pip distribution names whose importable module has a different name.
# Importing the distribution name itself always fails, so these packages
# were reported as missing even when installed.
import_names = {
    'scikit-learn': 'sklearn',
    'scikit-image': 'skimage',
    'Pillow': 'PIL',
    'pyyaml': 'yaml',
}


def is_package_installed(package):
    """Return True if the module provided by pip package *package* imports.

    *package* is a pip distribution name; it is translated through
    ``import_names`` when the module name differs, otherwise used as-is.
    """
    try:
        importlib.import_module(import_names.get(package, package))
    except ImportError:
        return False
    return True


for package in required_packages:
    if is_package_installed(package):
        print(f"  ✓ {package}")
    else:
        print(f"  ✗ {package} 未安装")

print("\n运行以下命令安装缺少的包:")
print("pip install torch torchvision numpy pandas matplotlib scikit-learn scikit-image Pillow pyyaml tqdm")

SyntaxError: invalid syntax (ipython-input-2404118028.py, line 1)