# TrackML Data Exploration & TrackFormer Training

这个notebook用于:
1. 探索TrackML数据集
2. 演示TrackFormer训练过程
3. 可视化结果

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from src.dataset import TrackMLDataset
from src.trackformer import create_trackformer_600mev
from src.losses import LossModule
from src.trainer import train
from src.visual import plot_event_3d, plot_event_rz

# Configure matplotlib so the Chinese titles used in this notebook render.
# matplotlib tries the families in order. The first two entries are the
# original choices (commonly present on macOS / Windows respectively); the
# remaining ones add CJK-capable fonts often installed on other systems so
# the plots do not fall back to missing-glyph boxes there.
plt.rcParams['font.sans-serif'] = [
    'Arial Unicode MS',
    'SimHei',
    'Microsoft YaHei',
    'PingFang SC',
    'Noto Sans CJK SC',
    'WenQuanYi Micro Hei',
    'DejaVu Sans',  # matplotlib's bundled default, kept as the last resort
]
# Draw minus signs with the ASCII hyphen: many CJK fonts lack U+2212.
plt.rcParams['axes.unicode_minus'] = False

## 1. 数据集探索

In [None]:
# Load the data: sample-event directory and the detector geometry table.
data_dir = '../data/train_sample'
detectors = pd.read_csv('../data/detectors.csv')

# Pick one event to explore; its three tables share a common file prefix
# and differ only in the "-hits" / "-particles" / "-truth" suffix.
event_name = 'event000001000'
event_prefix = f'{data_dir}/{event_name}'
hits, particles, truth = (
    pd.read_csv(f'{event_prefix}-{table}.csv')
    for table in ('hits', 'particles', 'truth')
)

# Print basic statistics for the selected event.
print(f"事件 {event_name} 数据统计:")
print(f"总hits数: {len(hits)}")
print(f"总粒子数: {len(particles)}")
# Count the particles whose 'nhits' column is at least 3.
n_valid_tracks = int((particles['nhits'] >= 3).sum())
print(f"有效轨迹数 (nhits≥3): {n_valid_tracks}")
print(f"真实轨迹-hit关联数: {len(truth)}")

In [None]:
# Visualise the hit distribution in two side-by-side panels.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Transverse radius of every hit, used as the vertical axis of the RZ panel.
r = np.sqrt(hits['x']**2 + hits['y']**2)

# One row per panel: axes, horizontal data, vertical data, labels, title.
panels = [
    (ax1, hits['x'], hits['y'], 'X (mm)', 'Y (mm)', 'Hit分布 (XY平面)'),
    (ax2, hits['z'], r, 'Z (mm)', 'R (mm)', 'Hit分布 (RZ平面)'),
]
for ax, horizontal, vertical, x_label, y_label, title in panels:
    # Small, semi-transparent markers keep dense regions readable.
    ax.scatter(horizontal, vertical, alpha=0.6, s=1)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 2. TrackFormer模型演示

In [None]:
# Build the demo dataset.
import os

# Event ids are the part of each "*-hits.csv" file name before the first '-';
# the set removes duplicates and sorting makes the order deterministic.
all_event_ids = sorted({
    file_name.split('-')[0]
    for file_name in os.listdir(data_dir)
    if file_name.endswith('-hits.csv')
})

# Only the first five events (in sorted order) are used for the demo.
demo_ids = all_event_ids[:5]
dataset = TrackMLDataset(data_dir, detectors, demo_ids)

print("演示数据集统计:")
print(f"事件数: {len(dataset)}")
print(f"特征维度: {dataset.feature_dim}")

# Inspect the first sample: list every entry together with its shape.
# NOTE(review): assumes every value in the sample exposes `.shape` — confirm
# against TrackMLDataset.__getitem__.
sample = dataset[0]
print("\n样本数据形状:")
for field_name, field_value in sample.items():
    print(f"{field_name}: {field_value.shape}")

In [None]:
# Create the model.
# Device preference: CUDA first, then Apple's MPS backend, then CPU.
# `torch.backends.mps` is looked up with getattr because PyTorch builds that
# predate the MPS backend do not have that submodule, and an unconditional
# attribute access would raise AttributeError there.
mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device_name = 'cuda'
elif mps_backend is not None and mps_backend.is_available():
    device_name = 'mps'
else:
    device_name = 'cpu'
device = torch.device(device_name)
print(f"使用设备: {device}")

# The model's input width is taken from the dataset's feature dimension.
model = create_trackformer_600mev(input_dim=dataset.feature_dim)
model = model.to(device)

# Report the total parameter count and the subset that receives gradients.
print("\nTrackFormer模型参数统计:")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"总参数数: {total_params:,}")
print(f"可训练参数数: {trainable_params:,}")

## 3. 快速训练演示

In [None]:
# Prepare training.
# Single source of truth for the demo length: the banner, the loop, the
# progress line and the plot's x-axis all derive from this value.
n_epochs = 5

# One dataset item per batch, reshuffled every epoch.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
loss_fn = LossModule()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print(f"开始快速训练演示 ({n_epochs}个epoch)...")

# Train for a few epochs as a demonstration, recording the value returned
# by `train` for each epoch.
demo_losses = []
for epoch in range(n_epochs):
    epoch_loss = train(model, loss_fn, dataloader, optimizer, device)
    demo_losses.append(epoch_loss)
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss:.4f}")

# Plot the loss curve; the x-axis is 1-based and sized from the recorded
# losses so it always matches the number of points.
plt.figure(figsize=(8, 5))
plt.plot(range(1, len(demo_losses) + 1), demo_losses, 'b-o', linewidth=2, markersize=8)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('TrackFormer训练Loss (演示)')
plt.grid(True, alpha=0.3)
plt.show()

print("\n演示完成! 完整训练请使用: python main_train.py")

## 4. 模型推理演示

In [None]:
# Run inference with the model trained above.
model.eval()
with torch.no_grad():
    # Take one batch and drop its leading batch dimension.
    sample_batch = next(iter(dataloader))
    X = sample_batch['X'].squeeze(0).to(device)

    # Extract coordinates and convert them to cylindrical (r, phi, z).
    # NOTE(review): assumes feature columns 1-3 of X hold x, y, z — confirm
    # against TrackMLDataset.
    x, y, z = X[:, 1], X[:, 2], X[:, 3]
    r = torch.sqrt(x**2 + y**2)
    phi = torch.atan2(y, x)
    coords = torch.stack([r, phi, z], dim=1)

    # Forward pass.
    output = model(X, coords)

# Describe every output entry: tensors by shape, anything else by type.
print("模型输出:")
for name, item in output.items():
    description = item.shape if isinstance(item, torch.Tensor) else type(item)
    print(f"{name}: {description}")

# A track counts as predicted when the sigmoid of its logit exceeds 0.5.
has_track_logits = 'track_logits' in output and output['track_logits'].numel() > 0
if not has_track_logits:
    print("\n模型未预测到有效轨迹 (需要更多训练)")
else:
    track_probabilities = torch.sigmoid(output['track_logits'])
    n_predicted_tracks = (track_probabilities > 0.5).sum().item()
    print(f"\n预测的轨迹数: {n_predicted_tracks}")

## 总结

本notebook演示了:
1. **数据探索**: TrackML数据集的基本统计和可视化
2. **模型构建**: TrackFormer模型的创建和参数统计
3. **训练演示**: 快速训练过程和loss可视化
4. **推理测试**: 模型输出格式和预测结果

### 下一步:
- 运行 `python main_train.py --mode single --epochs 100` 进行完整单次训练
- 运行 `python main_train.py --mode kfold --epochs 50 --folds 5` 进行K折交叉验证
- 使用 `src/visual.py` 中的功能进行详细可视化分析