## Pipeline for model training, evaluation, and visualization

In [3]:
# 导入标准库和第三方库
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch.cuda.amp import autocast, GradScaler
import os

# 导入同级目录的自定义模块
from mamba2_icl import Mamba2ICL
from generate_icl_data import (generate_linear_data, generate_gaussian_kernel_data,
                              generate_nonlinear_dynamical_data)
from evaluate_icl import evaluate

# Fix the RNG seeds used by this notebook so runs are reproducible.
SEED = 42
has_cuda = torch.cuda.is_available()
np.random.seed(SEED)
torch.manual_seed(SEED)
if has_cuda:
    torch.cuda.manual_seed_all(SEED)

# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"Using device: {device}")

# Directory that receives every artifact produced below (datasets, weights,
# loss history, evaluation results, plots). Created if it does not exist.
output_dir = "experiment_outputs"
os.makedirs(output_dir, exist_ok=True)

ImportError: /root/code/mamba_icl_env/lib/python3.10/site-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorESt8optionalIN3c1010ScalarTypeEES5_INS6_6LayoutEES5_INS6_6DeviceEES5_IbES5_INS6_12MemoryFormatEE

In [None]:
# Data-generation settings (taken from the proposal).
num_train_prompts = 10000
num_test_prompts = 1000
context_size = 20
d = 20

# Task family -> generator. Iterating family-major, then train before test,
# calls the generators in the same order as writing the six calls out by
# hand, so the random draws are unchanged.
generators = (
    ("linear", generate_linear_data),
    ("gaussian", generate_gaussian_kernel_data),
    ("dynamical", generate_nonlinear_dynamical_data),
)
split_sizes = (("train", num_train_prompts), ("test", num_test_prompts))

print("Generating datasets...")
all_splits = {
    f"{split}_{family}": make(count, context_size, d)
    for family, make in generators
    for split, count in split_sizes
}

# Persist each split as <split>_<family>.pt inside output_dir.
print("Saving datasets...")
for key, prompts in all_splits.items():
    torch.save(prompts, os.path.join(output_dir, f"{key}.pt"))

# Report the number of prompts in every split.
for key, prompts in all_splits.items():
    split, family = key.split("_")
    print(f"{split.capitalize()} {family} prompts: {len(prompts)}")

# Expose each split under the names the later cells use.
train_linear = all_splits["train_linear"]
test_linear = all_splits["test_linear"]
train_gaussian = all_splits["train_gaussian"]
test_gaussian = all_splits["test_gaussian"]
train_dynamical = all_splits["train_dynamical"]
test_dynamical = all_splits["test_dynamical"]

In [None]:
# Initialise the Mamba2 ICL model on the selected device.
# NOTE(review): d_model=20 equals `d`, but each token built in the training
# loop appears to carry d + 1 features (x concatenated with y). Confirm that
# Mamba2ICL projects its inputs internally.
model = Mamba2ICL(d_model=20, d_state=64, d_conv=4, expand=2).to(device)

# Optimizer and mixed-precision loss scaler. Loss scaling only applies on
# CUDA. Enabling it explicitly per device avoids the warning that
# torch.cuda.amp.GradScaler emits (before disabling itself) when CUDA is
# unavailable.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=device.type == "cuda")

# Print the model architecture.
print(model)

In [None]:
# Training hyperparameters.
num_epochs = 10
datasets = [
    ("linear", train_linear),
    ("gaussian", train_gaussian),
    ("dynamical", train_dynamical)
]

# Mixed precision only applies on CUDA. Disabling autocast explicitly on CPU
# avoids the "CUDA is not available" warning from torch.cuda.amp.autocast.
use_amp = device.type == "cuda"

# Pool the (X, Y, x_query, y_query) prompts of all task families so that each
# epoch visits them in a fresh random order. Walking the families back to
# back would always end an epoch on the "dynamical" prompts and bias the
# updates toward whichever family was seen last.
train_pool = [prompt for _, train_data in datasets for prompt in train_data]
if not train_pool:
    raise ValueError("No training prompts available; check the data-generation step.")

# Mean training loss per epoch.
loss_history = []

print("Starting training...")
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    # np.random is seeded at notebook start, so the shuffle order is reproducible.
    for idx in np.random.permutation(len(train_pool)):
        X, Y, x_query, y_query = train_pool[idx]

        # Input sequence: context rows [x_i, y_i] followed by one query row
        # [x_query, 0]. The 0 is a placeholder for the unknown label.
        # new_zeros keeps the placeholder in x_query's dtype/device.
        # NOTE(review): assumes X is (context_size, d), Y is (context_size,)
        # and x_query is (d,), as the unsqueeze calls suggest. Confirm against
        # the generators.
        input_seq = torch.cat([
            torch.cat([X, Y.unsqueeze(-1)], dim=-1),
            torch.cat([x_query.unsqueeze(0), x_query.new_zeros(1, 1)], dim=-1),
        ], dim=0).to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            # The prediction is read from the last (query) position.
            output = model(input_seq.unsqueeze(0))[:, -1, :]
            # NOTE(review): output has shape (1, out_dim). Confirm that
            # y_query's shape matches, so mse_loss does not silently broadcast.
            loss = torch.nn.functional.mse_loss(output, y_query.to(device))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1

    avg_loss = epoch_loss / num_batches
    loss_history.append({"epoch": epoch + 1, "loss": avg_loss})
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.6f}")

# Save the per-epoch loss history next to the other artifacts.
loss_df = pd.DataFrame(loss_history)
loss_path = os.path.join(output_dir, "loss_history.csv")
loss_df.to_csv(loss_path, index=False)
print(f"Loss history saved to {loss_path}")

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
model_path = os.path.join(output_dir, "mamba2_icl_model.pth")
weights = model.state_dict()
torch.save(weights, model_path)
print(f"Model weights saved to {model_path}")

In [None]:
# Evaluate the trained model on each held-out task family.
print("Evaluating model...")
# The training cell leaves the model in train mode. Switch to eval mode and
# disable autograd for the evaluation passes.
# NOTE(review): assumes evaluate() does not itself need gradients. Confirm in
# evaluate_icl.
model.eval()
with torch.no_grad():
    results = {
        "Dataset": ["Linear", "Gaussian", "Dynamical"],
        "MSE": [
            evaluate(model, test_linear),
            evaluate(model, test_gaussian),
            evaluate(model, test_dynamical),
        ],
    }

# Save the evaluation table next to the other artifacts.
results_df = pd.DataFrame(results)
results_path = os.path.join(output_dir, "evaluation_results.csv")
results_df.to_csv(results_path, index=False)
print(f"Evaluation results saved to {results_path}")
print(results_df)

In [None]:
# Plot the mean training loss per epoch, save it, then display it.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(loss_df["epoch"], loss_df["loss"], marker='o', label="Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Average MSE Loss")
ax.set_title("Training Loss over Epochs")
ax.legend()
ax.grid(True)

# Write the figure to disk before show(), which may clear it in some backends.
plot_path = os.path.join(output_dir, "loss_plot.png")
fig.savefig(plot_path)
plt.show()
print(f"Loss plot saved to {plot_path}")