# JiT Lightning 功能测试 Notebook

本 notebook 用于测试和演示 JiT Lightning 版本的各个模块功能：
1. DataModule - 数据加载和预处理
2. ModelModule - 模型创建和前向传播
3. Callbacks - 各种回调函数
4. Config - 配置文件加载
5. 完整训练流程测试
6. 图像生成测试


In [2]:
import sys
sys.path.append('/home/lick/project/JiT')

import os
import torch
import numpy as np
import yaml
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image

In [4]:
def test_datamodule():
    """Smoke-test ``JiTDataModule``: build it, set it up and fetch one batch.

    The module is constructed with fixed test settings, ``setup(stage='fit')``
    is called, the dataset info is printed and the first training batch is
    inspected. Any exception raised after construction is reported on stdout
    instead of being propagated.

    Returns:
        ``(datamodule, images, labels)`` when the first training batch was
        loaded, otherwise ``(None, None, None)``.
    """
    # Project-local import, resolved only when the test is actually run.
    from datas.datamodule import JiTDataModule

    # Fixed settings used for this smoke test.
    dm_kwargs = dict(
        data_path='./data/imagenet',
        img_size=256,
        batch_size=4,
        num_workers=2,
        pin_memory=True,
    )
    dm = JiTDataModule(**dm_kwargs)

    print("✓ DataModule 创建成功")
    print(f"  - 图像尺寸: {dm.img_size}")
    print(f"  - 批次大小: {dm.batch_size}")

    # Everything below is best-effort: a failure is reported, not raised.
    try:
        dm.setup(stage='fit')
        dataset_info = dm.get_dataset_info()

        print("\n数据集信息:")
        for field, field_value in dataset_info.items():
            print(f"  - {field}: {field_value}")

        # Pull only the first batch from the training loader.
        first_batch = next(iter(dm.train_dataloader()))
        images, labels = first_batch

        print("\n第一个 batch:")
        print(f"  - 图像形状: {images.shape}")
        print(f"  - 图像数据类型: {images.dtype}")
        lowest, highest = images.min(), images.max()
        print(f"  - 图像范围: [{lowest:.0f}, {highest:.0f}]")
        print(f"  - 标签形状: {labels.shape}")
        print(f"  - 标签示例: {labels[:4].tolist()}")

        print("\n✓ DataModule 测试成功")
        return dm, images, labels

    except Exception as err:
        print(f"\n⚠️  数据集未找到或加载失败: {err}")
        print("   请确保 ImageNet 数据集位于 ./data/imagenet/train/")
        return None, None, None


def test_modelmodule():
    """Smoke-test ``JiTLightningModule``: build it, run a forward pass and one training step.

    Prints the parameter counts, the output shape/range of a forward pass on
    random inputs, and the loss returned by ``training_step``.

    Returns:
        The constructed model (left in training mode by the final step).
    """
    # Project-local import, resolved only when the test is actually run.
    from models.modelmodule import JiTLightningModule

    # Fixed settings used for this smoke test.
    model_kwargs = dict(
        model_name='JiT-B/16',
        img_size=256,
        num_classes=1000,
        learning_rate=1e-4,
        num_sampling_steps=10,
        cfg_scale=2.9,
    )
    model = JiTLightningModule(**model_kwargs)

    print("✓ ModelModule 创建成功")
    print(f"  - 模型名称: {model.hparams.model_name}")
    print(f"  - 图像尺寸: {model.hparams.img_size}")

    # Count all parameters and the trainable subset in a single pass.
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        n_elements = param.numel()
        n_total += n_elements
        if param.requires_grad:
            n_trainable += n_elements

    print("\n模型参数统计:")
    print(f"  - 总参数量: {n_total / 1e6:.2f}M")
    print(f"  - 可训练参数: {n_trainable / 1e6:.2f}M")

    # Forward pass on random inputs, without gradient tracking.
    model.eval()
    batch_size = 2
    test_images = torch.randn(batch_size, 3, 256, 256)
    test_timesteps = torch.rand(batch_size)
    test_labels = torch.randint(0, 1000, (batch_size,))

    print("\n测试前向传播:")
    with torch.no_grad():
        output = model(test_images, test_timesteps, test_labels)

    print(f"  - 输入形状: {test_images.shape}")
    print(f"  - 输出形状: {output.shape}")
    out_low, out_high = output.min(), output.max()
    print(f"  - 输出范围: [{out_low:.4f}, {out_high:.4f}]")

    # One training step on the same random images/labels, as batch index 0.
    model.train()
    loss = model.training_step((test_images, test_labels), 0)

    print("\n测试训练步骤:")
    print(f"  - 损失值: {loss.item():.6f}")

    print("\n✓ ModelModule 测试成功")
    return model


def test_callbacks():
    """Smoke-test the callback module: import it, build two callbacks and the default set.

    Returns:
        The list returned by ``create_default_callbacks``.
    """
    # Importing every public callback is itself part of the check, so the
    # names that are not used below are imported on purpose.
    from callbacks.jit_callbacks import (
        EMACallback,
        JiTModelCheckpoint,
        LearningRateSchedulerCallback,
        MetricLoggerCallback,
        create_default_callbacks,
    )

    print("✓ 成功导入所有 Callbacks")

    # EMA callback: construct it and echo back its two decay settings.
    ema = EMACallback(ema_decay1=0.9999, ema_decay2=0.9996)
    print("\nEMA Callback:")
    print(f"  - EMA Decay 1: {ema.ema_decay1}")
    print(f"  - EMA Decay 2: {ema.ema_decay2}")

    # Checkpoint callback: construct it and echo back its settings.
    ckpt = JiTModelCheckpoint(
        dirpath='./test_outputs/checkpoints',
        save_last_freq=5,
    )
    print("\nCheckpoint Callback:")
    print(f"  - 保存目录: {ckpt.dirpath}")
    print(f"  - Last 频率: {ckpt.save_last_freq}")

    # Default callback set, with FID evaluation switched off.
    default_kwargs = dict(
        save_dir='./test_outputs',
        eval_freq=40,
        img_size=256,
        epochs=600,
        enable_fid_eval=False,
    )
    callbacks = create_default_callbacks(**default_kwargs)

    print(f"\n创建了 {len(callbacks)} 个默认 callbacks:")
    for position, callback in enumerate(callbacks, start=1):
        print(f"  {position}. {type(callback).__name__}")

    print("\n✓ Callbacks 测试成功")
    return callbacks


def test_config():
    """Load ``conf/config.yaml`` and print a summary of its main sections.

    The file is parsed with ``yaml.safe_load``. The summary reads the
    top-level ``trainer``, ``model``, ``data`` and ``callbacks`` entries.
    Nothing is caught here: a missing file or a missing required key
    propagates as the underlying exception.

    Returns:
        The parsed configuration object.
    """
    config_path = 'conf/config.yaml'

    # Decode explicitly as UTF-8 so that parsing does not depend on the
    # platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    print(f"✓ 成功加载配置文件: {config_path}")
    print(f"\n配置文件包含的顶层键:")
    for key in config:
        print(f"  - {key}")

    # Trainer section: only the keys that are actually present are shown.
    print("\nTrainer 配置:")
    trainer_cfg = config['trainer']
    for key in ['max_epochs', 'devices', 'strategy', 'precision']:
        if key in trainer_cfg:
            print(f"  - {key}: {trainer_cfg[key]}")

    # Model section.
    print("\nModel 配置:")
    print(f"  - class_path: {config['model']['class_path']}")
    print(f"  - model_name: {config['model']['init_args']['model_name']}")
    print(f"  - img_size: {config['model']['init_args']['img_size']}")

    # Data section.
    print("\nData 配置:")
    print(f"  - class_path: {config['data']['class_path']}")
    print(f"  - batch_size: {config['data']['init_args']['batch_size']}")

    # Callbacks section: show the last component of each dotted class path.
    print(f"\nCallbacks 配置（共 {len(config['callbacks'])} 个）:")
    for i, cb_config in enumerate(config['callbacks'], 1):
        cb_name = cb_config['class_path'].split('.')[-1]
        print(f"  {i}. {cb_name}")

    print("\n✓ 配置文件测试成功")
    return config


def test_training(model, datamodule, callbacks):
    """Run a three-batch, single-epoch CPU training pass as a smoke test.

    Args:
        model: Module passed to ``Trainer.fit``.
        datamodule: Data module passed to ``Trainer.fit``; when ``None`` the
            test is skipped with a message.
        callbacks: Sequence of callbacks; only the first two are used.

    Returns:
        ``None``. A failure inside ``fit`` is printed with its traceback
        rather than raised.
    """
    if datamodule is None:
        print("⚠️  跳过训练测试（数据集未加载）")
        return

    import traceback

    import lightning.pytorch as pl

    print("创建 Trainer（快速测试模式）...")

    # Quick-test settings: one epoch, three batches, CPU only, and no
    # checkpointing, model summary or logger.
    trainer_kwargs = dict(
        max_epochs=1,
        limit_train_batches=3,
        callbacks=callbacks[:2],  # only the first two callbacks
        accelerator='cpu',
        enable_checkpointing=False,
        enable_progress_bar=True,
        enable_model_summary=False,
        logger=False,
    )
    trainer = pl.Trainer(**trainer_kwargs)

    print("✓ Trainer 创建成功")
    print(f"  - 最大 epochs: {trainer.max_epochs}")
    print(f"  - 限制 batches: {trainer.limit_train_batches}")

    print("\n开始训练测试（运行 3 个 batch）...")
    print("⚠️  这可能需要几分钟，请耐心等待...\n")

    try:
        trainer.fit(model, datamodule)
        print("\n✓ 训练流程测试成功")
    except Exception as err:
        # Report the failure with full context instead of aborting the notebook.
        print(f"\n✗ 训练测试失败: {err}")
        traceback.print_exc()


def test_generation(model):
    """Generate four class-conditional images with ``model`` and save them as PNGs.

    The model is switched to eval mode, ``model.generate`` is called without
    gradient tracking for four fixed class labels, and each result is
    written to ``./test_outputs/generated``.

    Args:
        model: Object providing ``eval()``, ``generate(labels, use_ema=...)``
            and the ``sampling_method``, ``num_sampling_steps`` and
            ``cfg_scale`` attributes.

    Returns:
        ``(generated_images, gen_labels)``: the raw output of
        ``model.generate`` and the label tensor used to produce it.
    """
    model.eval()

    print("准备生成图像...")
    print(f"  - 采样方法: {model.sampling_method}")
    print(f"  - 采样步数: {model.num_sampling_steps}")
    print(f"  - CFG 强度: {model.cfg_scale}")

    # Fixed class labels to condition on.
    gen_labels = torch.tensor([1, 207, 360, 500])

    print(f"\n生成 {len(gen_labels)} 张图像...")
    print("⚠️  由于模型未训练，生成的图像将是噪声\n")

    with torch.no_grad():
        generated_images = model.generate(gen_labels, use_ema=False)

    print(f"✓ 图像生成完成")
    print(f"  - 生成图像形状: {generated_images.shape}")
    print(f"  - 图像范围: [{generated_images.min():.2f}, {generated_images.max():.2f}]")

    # Save the generated images.
    save_dir = Path('./test_outputs/generated')
    save_dir.mkdir(parents=True, exist_ok=True)

    for idx, (img, label) in enumerate(zip(generated_images, gen_labels)):
        # Maps -1 -> 0 and +1 -> 255, i.e. presumes the model outputs values
        # in [-1, 1] (TODO confirm against `generate`); anything outside is clamped.
        img_array = ((img + 1.0) * 127.5).clamp(0, 255).byte()
        # Channels-first -> channels-last for PIL. `.cpu()` is required because
        # `Tensor.numpy()` only works on CPU tensors; it is a no-op when the
        # tensor is already on the CPU.
        img_array = img_array.permute(1, 2, 0).cpu().numpy()

        pil_img = Image.fromarray(img_array)
        save_path = save_dir / f'generated_class{label.item()}_{idx}.png'
        pil_img.save(save_path)

    print(f"\n✓ 图像已保存到: {save_dir}")
    print("\n✓ 图像生成测试成功")

    return generated_images, gen_labels

In [5]:
# Test 1: DataModule
# All three values are None when the dataset could not be loaded.
datamodule, images, labels = test_datamodule()

  from .autonotebook import tqdm as notebook_tqdm


✓ DataModule 创建成功
  - 图像尺寸: 256
  - 批次大小: 4

⚠️  数据集未找到或加载失败: [Errno 2] No such file or directory: './data/imagenet/train'
   请确保 ImageNet 数据集位于 ./data/imagenet/train/


In [None]:
# Test 2: ModelModule
# The returned model is reused by the training and generation tests below.
model = test_modelmodule()

In [None]:
# Test 3: Callbacks
# The returned callback list is reused by the training test below.
callbacks = test_callbacks()

In [None]:
# Test 4: Config
config = test_config()

In [None]:
# Test 5: training loop (optional, requires the dataset)
# Skipped entirely when test 1 could not load the dataset.
if datamodule is not None:
    test_training(model, datamodule, callbacks)

In [None]:
# Test 6: image generation
generated_images, gen_labels = test_generation(model)