# 测试数据生成、训练、测试流程

这个notebook用于测试整个pipeline：
1. 数据生成
2. 模型创建
3. 训练流程
4. 测试流程
5. 训练损失可视化
6. 使用`train`函数的完整训练流程


In [None]:
# 导入必要的库
import sys
import os

# Put the project root on sys.path so the project modules can be imported.
# The notebook is expected to live in <project_root>/notebooks/.
current_dir = os.getcwd()
# Compare only the final path component. A substring test on the whole path
# would also match unrelated ancestors such as ".../my_notebooks_project" and
# then pick the wrong directory as the project root.
if os.path.basename(current_dir) == 'notebooks':
    project_root = os.path.dirname(current_dir)
else:
    project_root = current_dir

# Insert at the front so project modules take precedence over same-named
# modules found later on the path; skip if it is already present.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("导入基础库...")
import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import DataLoader

print("导入项目模块...")
# Import the core project modules first (these imports are fast).
from data import get_loaders, GridDataGenerator, GridDataset, grid_collate
from models import get_model

# The train and test imports are deferred: they import analyze, which in turn
# imports heavy libraries. They are imported later, only when they are needed.

print("基础导入完成！")
print(f"项目根目录: {project_root}")
print("注意: train和test函数会在需要时导入（因为它们依赖analyze模块，导入较慢）")


## 1. 设置参数


In [None]:
# Simple stand-in for the argument object that the project functions expect.
class Args:
    """Plain settings container used by this notebook.

    The defaults are deliberately small (CPU device, small batch size, few
    training steps) so the pipeline can be exercised quickly. Any default can
    be replaced at construction time with a keyword argument, for example
    ``Args(n_steps=10, bs=4)``.

    Raises:
        TypeError: If a keyword argument does not name one of the settings
            defined in ``__init__`` (this catches typos such as ``n_stpes``).
    """

    def __init__(self, **overrides):
        # Device settings
        self.use_cuda = False
        self.device = "cpu"
        self.seed = 0

        # Dataset settings
        self.use_images = False  # use indices instead of images
        self.image_dir = 'images/faces16'
        self.training_regime = 'ungrouped'  # 'ungrouped' mode is the simpler one
        self.grid_size = 4
        self.ctx_order = 'first'
        self.inner_4x4 = False

        # Training settings
        self.bs = 16  # small batch size, for testing
        self.lr = 0.001
        self.n_steps = 100  # small number of steps, for testing
        self.print_every = 10
        self.test_every = 25
        self.analyze_every = 50

        # Model settings
        self.model_name = 'rnn'
        self.ctx_scale = 1.0
        self.measure_grad_norm = False

        # Analysis settings
        self.dim_red_method = 'pca'

        # Apply caller-supplied overrides last so they win over the defaults.
        # Only existing settings may be overridden; an unknown name is almost
        # certainly a typo and would otherwise be silently ignored downstream.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected setting {name!r}")
            setattr(self, name, value)

    def __repr__(self):
        """Return ``Args(name=value, ...)`` listing every current setting."""
        settings = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({settings})"

# Instantiate the settings object used by every later cell.
args = Args()

# Seed each random number generator used here (PyTorch, the stdlib random
# module, NumPy) with the same value, in that order.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(args.seed)

# Summarize the key settings, one per line.
print(
    "参数设置完成！",
    f"模型: {args.model_name}",
    f"训练步数: {args.n_steps}",
    f"Batch size: {args.bs}",
    sep="\n",
)


## 2. 测试数据生成


In [None]:
# Exercise the data generator, configured from the dataset settings in args.
print("生成数据...")
grid_kwargs = {
    'training_regime': args.training_regime,
    'size': args.grid_size,
    'use_images': args.use_images,
    'image_dir': args.image_dir,
    'inner_4x4': args.inner_4x4,
}
grid = GridDataGenerator(**grid_kwargs)

# Report how many samples each split contains.
print(
    f"训练样本数: {len(grid.train)}",
    f"测试样本数: {len(grid.test)}",
    f"分析样本数: {len(grid.analyze)}",
    sep="\n",
)

# Inspect the first training sample; it unpacks into five fields.
sample = grid.train[0]
ctx, loc1, loc2, y, info = sample
print(
    f"\n示例样本:",
    f"  上下文: {ctx}",
    f"  位置1: {loc1}, 位置2: {loc2}",
    f"  标签: {y}",
    f"  信息: {info}",
    sep="\n",
)


In [None]:
# Wrap the train and test splits in datasets that share the same lookups.
train_set, test_set = (
    GridDataset(split, grid.loc2idx, grid.idx2tensor)
    for split in (grid.train, grid.test)
)

# Both loaders use the same batch size and collate function; only the
# training loader shuffles.
loader_kwargs = {'batch_size': args.bs, 'collate_fn': grid_collate}
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

# Pull a single batch and show the shape of each component.
batch = next(iter(train_loader))
ctx, f1, f2, y, info = batch
print(
    f"Batch形状:",
    f"  上下文: {ctx.shape}",
    f"  葡萄酒1: {f1.shape}",
    f"  葡萄酒2: {f2.shape}",
    f"  标签: {y.shape}",
    f"  信息键: {info.keys()}",
    sep="\n",
)


## 3. 测试模型创建


In [None]:
# Build the model selected by args and move it to the configured device.
print("创建模型...")
model = get_model(args)
model.to(args.device)

# Report the model name and its total parameter count.
print(
    f"\n模型: {args.model_name}",
    f"参数数量: {sum(p.numel() for p in model.parameters())}",
    sep="\n",
)

# Forward-pass check on the first four items of the batch loaded earlier,
# in eval mode and with gradient tracking disabled.
model.eval()
with torch.no_grad():
    test_ctx, test_f1, test_f2 = (
        tensor[:4].to(args.device) for tensor in (ctx, f1, f2)
    )

    # The model returns predictions plus a dict of intermediate representations.
    y_hat, reps = model(test_ctx, test_f1, test_f2)
    print(
        f"\n前向传播测试:",
        f"  输入形状: ctx={test_ctx.shape}, f1={test_f1.shape}, f2={test_f2.shape}",
        f"  输出形状: {y_hat.shape}",
        f"  表示形状: {[(k, v.shape) for k, v in reps.items()]}",
        f"  预测: {torch.argmax(y_hat, dim=1)}",
        f"  真实标签: {y[:4]}",
        sep="\n",
    )


## 4. 测试训练流程


In [None]:
# Build the loaders through get_loaders (per the original note, this also
# sets some attributes on args).
data = get_loaders(args)
train_loader, test_loader, analyze_loader = data

# Create a fresh model for training.
model = get_model(args)
model.to(args.device)

# Optimizer and loss function.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
loss_fn = nn.CrossEntropyLoss()


def _cycle_batches(loader):
    """Yield batches from ``loader`` indefinitely, restarting it when exhausted.

    The loader is re-iterated on every pass rather than caching its batches,
    so a loader that shuffles can produce a new order each pass.

    Raises:
        ValueError: If a full pass over ``loader`` yields no batches, which
            would otherwise make this generator spin forever.
    """
    while True:
        produced_any = False
        for item in loader:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError("train_loader produced no batches")


print("开始训练...")
model.train()

losses = []
# Run exactly args.n_steps optimization steps. Iterating the loader only once
# would stop early whenever it holds fewer than n_steps batches. range() comes
# first in zip() so that no extra batch is fetched after the last step.
for step, batch in zip(range(args.n_steps), _cycle_batches(train_loader)):
    optimizer.zero_grad()

    ctx, f1, f2, y, info = batch
    ctx = ctx.to(args.device)
    f1 = f1.to(args.device)
    f2 = f2.to(args.device)
    y = y.to(args.device)

    # Forward pass (the second return value is not needed here).
    y_hat, _ = model(ctx, f1, f2)
    loss = loss_fn(y_hat, y)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    # Periodically report the mean loss over the most recent print_every steps.
    if (step + 1) % args.print_every == 0:
        avg_loss = np.mean(losses[-args.print_every:])
        print(f"Step {step+1}/{args.n_steps}, Loss: {avg_loss:.4f}")

# np.mean([]) emits a RuntimeWarning; report nan directly when no step ran.
mean_loss = np.mean(losses) if losses else float("nan")
print(f"\n训练完成！平均损失: {mean_loss:.4f}")


## 5. 测试测试流程


In [None]:
# Evaluate the trained model with the project's test() function.
# NOTE(review): a module named `test` can collide with the standard library's
# `test` package; this import presumably relies on the project root being
# ahead of it on sys.path — confirm.
print("导入test模块（可能需要几秒钟）...")
from test import test

print("测试模型...")
test_results = test(model, test_loader, args)

# Print overall, congruent and incongruent accuracy, followed by the two
# per-context accuracy lists formatted to three decimals.
print(
    f"\n测试结果:",
    f"  准确率: {test_results['acc']:.4f}",
    f"  一致准确率: {test_results['cong_acc']:.4f}",
    f"  不一致准确率: {test_results['incong_acc']:.4f}",
    f"  上下文0各等级准确率: {[f'{acc:.3f}' for acc in test_results['loc1_ctx0_acc']]}",
    f"  上下文1各等级准确率: {[f'{acc:.3f}' for acc in test_results['loc1_ctx1_acc']]}",
    sep="\n",
)


## 6. 可视化训练过程


In [None]:
# Plot the per-step training loss, then the same curve with a moving average.
import matplotlib.pyplot as plt


def _open_loss_figure():
    """Start a new 10x5 figure for a loss plot."""
    plt.figure(figsize=(10, 5))


def _close_loss_figure(title, show_legend=False):
    """Label the axes, set the title, optionally add a legend, then display."""
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.title(title)
    if show_legend:
        plt.legend()
    plt.grid(True)
    plt.show()


# Figure 1: raw loss per step.
_open_loss_figure()
plt.plot(losses)
_close_loss_figure('Training Loss')

# Figure 2: raw loss (faded) overlaid with a moving average. It is drawn only
# when there are at least `window` points to average.
window = 10
if len(losses) >= window:
    averaging_kernel = np.ones(window)/window
    moving_avg = np.convolve(losses, averaging_kernel, mode='valid')
    _open_loss_figure()
    plt.plot(losses, alpha=0.3, label='Raw')
    # 'valid' mode drops the first window-1 positions, so the x values start there.
    plt.plot(range(window-1, len(losses)), moving_avg, label=f'Moving Average (window={window})')
    _close_loss_figure('Training Loss (with moving average)', show_legend=True)


## 7. 完整训练流程（使用train函数）


In [None]:
# Train with the project's full train() function.
print("导入train模块（可能需要几秒钟，因为它会导入analyze模块）...")
from train import train

print("使用完整训练函数...")

# Re-create the model and the data for this run.
model = get_model(args)
model.to(args.device)
data = get_loaders(args)

# Run one complete training loop.
# NOTE(review): the first positional argument (0) looks like a run index —
# confirm against train()'s signature.
results, analyses = train(0, model, data, args)

print(f"\n训练结果:")
print(f"  训练损失: {len(results['train_losses'])} 个记录点")
print(f"  训练准确率: {len(results['train_accs'])} 个记录点")
print(f"  测试准确率: {len(results['test_accs'])} 个记录点")

# Guard each list on its own. Checking only train_accs and then indexing
# test_accs[-1] would raise IndexError if test_accs were empty.
has_train_accs = bool(results['train_accs'])
has_test_accs = bool(results['test_accs'])
if has_train_accs or has_test_accs:
    print(f"\n最终准确率:")
    if has_train_accs:
        print(f"  训练: {results['train_accs'][-1]['acc']:.4f}")
    if has_test_accs:
        print(f"  测试: {results['test_accs'][-1]['acc']:.4f}")


## 总结

这个notebook展示了：
1. ✅ 数据生成和加载
2. ✅ 模型创建和前向传播
3. ✅ 训练流程
4. ✅ 测试流程
5. ✅ 完整训练函数的使用

如果以上所有单元格都能无报错地运行完毕，说明各组件工作正常。
