### 检查嵌入分布

In [None]:
import numpy as np
import matplotlib.pyplot as plt


# Load the exercise embedding matrix and report its shape.
arr = np.load('data/NIPS34/all/exer_embeds.npy')
print(arr.shape)

# L2 norm taken along axis 1, i.e. one value per row of the matrix.
row_norms = np.linalg.norm(arr, axis=1)

# Histogram of the row norms, drawn through the object-oriented Axes API.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(row_norms, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
ax.set_title('Row-wise L2 Norm Distribution')
ax.set_xlabel('Norm Value')
ax.set_ylabel('Frequency')
ax.grid(axis='y', alpha=0.3)
plt.show()

# Show the first five norms as a quick sanity check.
print("前5行的范数值：", row_norms[:5], sep="\n")

### 将csv格式的数据转换为NCDM标准输入格式

In [5]:
import pandas as pd
import json
import ast


# Scenario directory under data/NIPS34 and the split to convert.
# Other scenarios used with this cell: 'Algebra', 'GeometryandMeasure', 'Number'.
senario = 'Algebra_cold'
# Other splits used with this cell: 'train', 'val'.
file_name = 'test'

root_dir = f'data/NIPS34/{senario}'
file_in = f'{root_dir}/{file_name}.csv'
file_out = f'{root_dir}/{file_name}.json'

# Read the CSV, rename the kept columns, then drop the columns that are not exported.
df = (
    pd.read_csv(file_in)
    .rename(columns={'QuestionId': 'exer_id', 'UserId': 'user_id', 'IsCorrect': 'score', 'Kc': 'knowledge_code'})
    .drop(columns=['Time', 'AnswerValue', 'CorrectAnswer', 'AnswerId'])
)

# One dict per row; 'knowledge_code' is stored as a Python-literal string in the
# CSV, so parse it back into the corresponding Python object.
dict_list = [
    {**record, 'knowledge_code': ast.literal_eval(record['knowledge_code'])}
    for record in df.to_dict(orient='records')
]

# Write the records out as pretty-printed UTF-8 JSON.
with open(file_out, 'w', encoding='utf-8') as sink:
    json.dump(dict_list, sink, indent=4, ensure_ascii=False)


### 逆操作：从json转csv

In [8]:
import pandas as pd
import json


# Scenario directory under data/NIPS34 and the split to convert ('train' is the other split).
senario = 'longtail'
file_name = 'test'

root_dir = f'data/NIPS34/{senario}'
file_in = f'{root_dir}/{file_name}.json'
file_out = f'{root_dir}/{file_name}.csv'

# Load the list of response records.
with open(file_in, 'r', encoding='utf-8') as src:
    dict_list = json.load(src)

df = pd.DataFrame(dict_list)

# Turn the knowledge-code lists back into their string form.
# NOTE: this column is not part of column_order below, so it is not written to the CSV.
df['knowledge_code'] = df['knowledge_code'].astype(str)

# Only 'exer_id' actually changes name; 'user_id' and 'score' are kept as they are.
# Columns written to the CSV, in output order.
column_order = ['user_id', 'item_id', 'score']
df = df.rename(columns={'exer_id': 'item_id'})[column_order]

# Save as CSV without the index column.
df.to_csv(file_out, index=False, encoding='utf-8')

### 统计高频和长尾题目（按训练集中 item_id 的出现次数划分）

In [9]:
import pandas as pd


# Input / output locations.
senario = 'longtail'
file_name = 'test'
root_dir = f'data/NIPS34/{senario}'
csv_train = f'{root_dir}/train.csv'
csv_test = f'{root_dir}/{file_name}.csv'
output_high = f'{root_dir}/{file_name}_highfreq.csv'  # rows of frequently seen items
output_low = f'{root_dir}/{file_name}_longtail.csv'   # rows of rarely seen items

df = pd.read_csv(csv_train)
df_test = pd.read_csv(csv_test)

# How many times each item_id occurs in the training split.
question_counts = df['item_id'].value_counts()

# Item ids seen more than 10 times, and item ids seen at most 3 times.
high_freq = question_counts[question_counts.gt(10)].index.tolist()
low_freq = question_counts[question_counts.le(3)].index.tolist()

print(f'[高频题目] 出现超过10次的QuestionId：{len(high_freq)}个')
print(high_freq)

print(f'\n[低频题目] 出现不超过3次的QuestionId：{len(low_freq)}个')
print(low_freq)

# Select the test rows whose item belongs to each group.
high_df = df_test.loc[df_test['item_id'].isin(high_freq)]
low_df = df_test.loc[df_test['item_id'].isin(low_freq)]

# Write both subsets with the same columns as the test CSV.
high_df.to_csv(output_high, index=False)
low_df.to_csv(output_low, index=False)

print(f'高频数据已保存至：{output_high}（共 {len(high_df)} 行）')
print(f'低频数据已保存至：{output_low}（共 {len(low_df)} 行）')


[高频题目] 出现超过10次的QuestionId：594个
[199, 911, 625, 855, 520, 83, 528, 547, 533, 676, 856, 460, 50, 815, 47, 727, 178, 761, 502, 634, 599, 836, 494, 421, 312, 91, 22, 670, 639, 449, 391, 862, 409, 844, 342, 635, 941, 209, 463, 185, 290, 939, 236, 943, 446, 372, 283, 527, 749, 887, 337, 592, 638, 461, 293, 134, 831, 596, 637, 40, 813, 278, 265, 295, 45, 583, 614, 525, 664, 75, 745, 605, 333, 349, 311, 49, 39, 56, 101, 263, 885, 190, 587, 84, 360, 183, 335, 282, 457, 195, 626, 150, 868, 325, 585, 138, 850, 148, 808, 184, 417, 118, 249, 177, 923, 210, 38, 858, 297, 740, 751, 830, 790, 119, 8, 601, 932, 52, 383, 90, 212, 475, 804, 673, 82, 889, 732, 644, 834, 480, 522, 472, 787, 260, 852, 376, 945, 92, 629, 292, 193, 133, 882, 506, 451, 789, 53, 186, 24, 495, 154, 924, 767, 129, 645, 213, 816, 896, 439, 34, 611, 76, 361, 368, 402, 31, 16, 539, 707, 559, 88, 624, 615, 435, 304, 37, 211, 493, 800, 365, 327, 328, 158, 145, 554, 805, 369, 160, 423, 579, 908, 485, 403, 392, 147, 513, 86, 838, 116, 6

### check冷启动情景

In [2]:
import pandas as pd


# Input locations.
senario = 'student_all'
root_dir = f'data/NIPS34/{senario}'
csv_train = f'{root_dir}/train.csv'
csv_test = f'{root_dir}/test.csv'

df = pd.read_csv(csv_train)
df_test = pd.read_csv(csv_test)

# Every distinct QuestionId that occurs in the training split (no frequency threshold).
pid_train = df['QuestionId'].value_counts().index.tolist()

# Every distinct QuestionId that occurs in the test split.
question_counts = df_test['QuestionId'].value_counts()
pid_test = question_counts.index.tolist()

# Question ids shared by both splits; an empty set means the splits share no question.
print(set(pid_train).intersection(pid_test))


set()


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
import numpy as np
from utils.loss import cl_and_reg_loss


class MyCDM_MLP_FFT(nn.Module):
    """Cognitive-diagnosis model with a fully fine-tuned BERT exercise encoder.

    Exercise text must be supplied already tokenized, as a mapping with
    ``input_ids`` and ``attention_mask`` entries, because it is passed straight
    to BERT. This class freezes no BERT parameter, so the encoder is trained
    together with the student embeddings and the MLP prediction head.

    Args:
        num_students: Number of rows in each student embedding table.
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        tau: Temperature forwarded to ``cl_and_reg_loss``.
        lambda_reg: Weight of the regularisation term in the total loss.
        lambda_cl: Weight of the contrastive term in the total loss.
    """

    def __init__(self, num_students, bert_model_name='bert-base-uncased', tau=0.1, lambda_reg=1.0, lambda_cl=0.5):
        super().__init__()
        self.tau = tau
        self.lambda_reg = lambda_reg
        self.lambda_cl = lambda_cl

        # Pre-trained BERT; its hidden size fixes the width of every other layer.
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.d_model = self.bert.config.hidden_size

        # Two embeddings per student (u+ and u-), same width as the BERT output.
        self.stu_pos = nn.Embedding(
            num_embeddings=num_students,
            embedding_dim=self.d_model
        )
        self.stu_neg = nn.Embedding(
            num_embeddings=num_students,
            embedding_dim=self.d_model
        )

        # MLP prediction head: 3*d_model -> 2*d_model -> d_model -> 1.
        self.prednet = nn.Sequential(
            nn.Linear(3 * self.d_model, 2 * self.d_model),
            nn.Sigmoid(),
            nn.Dropout(p=0.5),
            nn.Linear(2 * self.d_model, self.d_model),
            nn.Sigmoid(),
            nn.Dropout(p=0.5),
            nn.Linear(self.d_model, 1)
        )

        self.initialize()

    def initialize(self):
        """Initialise the student embeddings and the linear layers of the head.

        Embeddings are drawn from N(0, 0.1^2); linear weights use Xavier-uniform
        and linear biases are zeroed. BERT keeps its pre-trained weights.
        """
        nn.init.normal_(self.stu_pos.weight, mean=0.0, std=0.1)
        nn.init.normal_(self.stu_neg.weight, mean=0.0, std=0.1)
        for module in self.prednet:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, stu_ids, exer_in):
        """Predict the probability of a correct response.

        Args:
            stu_ids: Student indices used to look up both embedding tables.
            exer_in: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors for the exercise text.

        Returns:
            Tuple ``(output, exer_emb, u_pos, u_neg)``: the sigmoid
            probabilities with the trailing singleton dimension squeezed, the
            exercise embedding, and the two student embeddings.
        """
        # Positive / negative student representations: [batch_size, d_model].
        u_pos = self.stu_pos(stu_ids)
        u_neg = self.stu_neg(stu_ids)

        # Exercise representation: hidden state of the first token of each
        # sequence, [batch_size, d_model].
        bert_output = self.bert(
            input_ids=exer_in["input_ids"],
            attention_mask=exer_in["attention_mask"])
        exer_emb = bert_output.last_hidden_state[:, 0, :]

        # Head variant (1), no interaction with the exercise (kept for reference):
        # stu_emb = u_pos - u_neg
        # logits = self.prednet(torch.cat([exer_emb, stu_emb], dim=1))
        # Head variant (2), element-wise interaction with the exercise embedding:
        features = torch.cat(
            [exer_emb, torch.multiply(exer_emb, u_pos), torch.multiply(exer_emb, u_neg)],
            dim=1,
        )
        logits = self.prednet(features)                 # [batch_size, 1]
        output = torch.sigmoid(logits).squeeze(-1)      # [batch_size]

        return output, exer_emb, u_pos, u_neg

    def get_loss(self, output, labels):
        """Combine the prediction, contrastive and regularisation losses.

        Args:
            output: The tuple returned by :meth:`forward`.
            labels: Ground-truth responses with as many elements as the
                predictions.

        Returns:
            Tuple ``(total_loss, bce_loss, loss_contrast, loss_reg)``.
        """
        preds, exer_emb, u_pos, u_neg = output

        # Give the targets exactly the shape of the predictions. A plain
        # ``labels.squeeze()`` turns a batch of one into a 0-d tensor, which no
        # longer matches ``preds`` of shape [1] and makes BCELoss fail.
        targets = labels.reshape(preds.shape)
        # Prediction loss; mean reduction makes this a scalar.
        bce_loss = nn.BCELoss(reduction='mean')(preds, targets)

        # Contrastive loss and regularisation term come from the project helper.
        loss_contrast, loss_reg = cl_and_reg_loss(exer_emb, u_pos, u_neg, labels, self.tau, delta=0.1, norm=True)

        # lambda_cl < 1 blends BCE and contrastive loss as a convex combination;
        # otherwise lambda_cl is a plain multiplier on the contrastive loss.
        if self.lambda_cl < 1:
            total_loss = (1-self.lambda_cl)*bce_loss + self.lambda_cl*loss_contrast + self.lambda_reg*loss_reg
        else:
            total_loss = bce_loss + self.lambda_cl*loss_contrast + self.lambda_reg*loss_reg
        return total_loss, bce_loss, loss_contrast, loss_reg


In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import torch
import wandb
from sklearn.metrics import accuracy_score, roc_auc_score
import numpy as np
from datetime import datetime
import itertools
import json
from pathlib import Path

from utils.load_data import MyDataloader
from models.model import Baseline_IRT, Baseline_MLP, MyCDM_MLP, IRT, MyCDM_MSA, MyCDM_IRT


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what a no-argument call
            has always done.

    Returns:
        The parsed ``argparse.Namespace``. ``checkpoint_dir`` is filled in as
        ``../checkpoints/{proj_name}`` when it was not given explicitly.
    """
    parser = argparse.ArgumentParser(description='模型训练参数配置')

    # Experiment configuration
    parser.add_argument('--mode', choices=['baseline', 'freeze', 'fine-tune'], default='freeze', help='实验模式')
    parser.add_argument('--proj_name', type=str, default='freeze_250221_00', help='项目名称，用于保存检查点')
    parser.add_argument('--data', type=str, default='NIPS34', choices=['NIPS34'], help='使用的数据集名称')
    parser.add_argument('--scenario', type=str, default='all', choices=['all', 'Algebra', 'Algebra_cold', 'GeometryandMeasure', 'Number'], help='情景')

    # Training hyper-parameters
    parser.add_argument('--bs', type=int, default=256, help='批次大小')
    parser.add_argument('--epoch', type=int, default=100, help='最大训练轮数')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, help='学习率')

    # Model configuration
    parser.add_argument('--bert_path', type=str, help='BERT预训练模型路径',
                        default='/mnt/new_pfs/liming_team/auroraX/songchentao/llama/bert-base-uncased')
    parser.add_argument('--tau', type=float, default=0.1, help='温度系数')
    # The loss weights are fractional (defaults 0.5 and 0.1), so they must be
    # parsed as floats; with type=int a value such as "0.3" is rejected.
    parser.add_argument('--lambda_cl', type=float, default=0.5, help='对比损失权重')
    parser.add_argument('--lambda_reg', type=float, default=0.1, help='正则损失权重')

    # Training control
    parser.add_argument('--grid_search', action='store_true', help='格点搜索调参')
    parser.add_argument('-esp', '--early_stop_patience', type=int, default=10, help='早停等待轮数')
    parser.add_argument('-ckpt', '--checkpoint_dir', type=str, default=None, help='检查点保存目录 (默认: ../checkpoints/{proj_name})')
    parser.add_argument('--verbose', type=int, default=0, help='是否显示epoch内进度')

    _args = parser.parse_args(argv)

    # Derive options that depend on other options.
    if _args.checkpoint_dir is None:
        _args.checkpoint_dir = f'../checkpoints/{_args.proj_name}'

    return _args


def train(_model, _train_loader, _optimizer, _device, mode='baseline', verbose=0):
    """Run one training epoch and return the per-batch average losses.

    Args:
        _model: Callable as ``_model(stu_ids, exer_in)`` and exposing
            ``get_loss(output, labels)`` that returns four loss tensors.
        _train_loader: Sized iterable of batch dicts.
        _optimizer: Optimizer stepped once per batch.
        _device: Device the batch tensors are moved to.
        mode: One of ``'baseline'``, ``'freeze'``, ``'fine-tune'``. In
            ``'fine-tune'`` the exercise input is a dict of token tensors;
            otherwise it is the exercise id tensor.
        verbose: When truthy, print a timestamped progress line every 200 batches.

    Returns:
        ``(avg_total_loss, avg_pred_loss, avg_cl_loss, avg_reg_loss)``. The last
        two stay 0.0 in ``'baseline'`` mode because they are not accumulated.
    """
    _model.train()

    feeds_tokens = mode == 'fine-tune'
    tracks_aux_losses = mode in ('freeze', 'fine-tune')
    sum_total = 0.0
    sum_pred = 0.0
    sum_cl = 0.0
    sum_reg = 0.0

    for step, batch in enumerate(_train_loader, start=1):
        if verbose and step % 200 == 0:
            stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print(f'{stamp}, {step} of {len(_train_loader)}')

        # Move the batch to the target device.
        stu_ids = batch['stu_id'].to(_device)
        labels = batch['label'].to(_device).float()
        if feeds_tokens:
            # Tokenized exercise text, packed the way the BERT-based model expects.
            exer_in = {
                'input_ids': batch['input_ids'].to(_device),
                'attention_mask': batch['attention_mask'].to(_device),
            }
        else:
            exer_in = batch['exer_id'].to(_device)

        # Standard step: clear gradients, forward, loss, backward, update.
        _optimizer.zero_grad()
        output = _model(stu_ids, exer_in)
        loss, loss_bce, loss_cl, loss_reg = _model.get_loss(output, labels)
        loss.backward()
        _optimizer.step()

        sum_total += loss.item()
        sum_pred += loss_bce.item()
        if tracks_aux_losses:
            sum_cl += loss_cl.item()
            sum_reg += loss_reg.item()

    # Average every accumulated loss over the number of batches.
    n_batches = len(_train_loader)
    return (
        sum_total / n_batches,
        sum_pred / n_batches,
        sum_cl / n_batches,
        sum_reg / n_batches,
    )


def val_or_test(_model, _data_loader, _device, mode='baseline', verbose=0):
    """Evaluate the model on a validation or test loader.

    Only the prediction (BCE) loss is tracked here; the contrastive and
    regularisation terms returned by ``get_loss`` are ignored.

    Args:
        _model: Callable as ``_model(stu_ids, exer_in)`` and exposing ``get_loss``.
        _data_loader: Sized iterable of batch dicts.
        _device: Device the batch tensors are moved to.
        mode: ``'fine-tune'`` feeds a dict of token tensors; any other value
            feeds the exercise id tensor.
        verbose: When truthy, print a timestamped progress line every 200 batches.

    Returns:
        ``(avg_pred_loss, acc, auc, all_preds, all_labels)`` where accuracy uses
        a 0.5 threshold and the last two items are numpy arrays.
    """
    _model.eval()

    feeds_tokens = mode == 'fine-tune'
    bce_sum = 0.0
    scores = []
    targets = []

    with torch.no_grad():
        for step, batch in enumerate(_data_loader, start=1):
            if verbose and step % 200 == 0:
                stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print(f'{stamp}, {step} of {len(_data_loader)}')

            # Move the batch to the target device.
            stu_ids = batch['stu_id'].to(_device)
            labels = batch['label'].to(_device).float()
            if feeds_tokens:
                exer_in = {
                    'input_ids': batch['input_ids'].to(_device),
                    'attention_mask': batch['attention_mask'].to(_device),
                }
            else:
                exer_in = batch['exer_id'].to(_device)

            # Forward pass; keep only the second item returned by get_loss.
            output = _model(stu_ids, exer_in)
            _, loss_bce, _, _ = _model.get_loss(output, labels)
            bce_sum += loss_bce.item()

            # output[0] holds the predicted probabilities.
            scores.extend(output[0].detach().cpu().numpy())
            targets.extend(labels.cpu().numpy())

    avg_pred_loss = bce_sum / len(_data_loader)
    all_preds = np.array(scores)
    all_labels = np.array(targets)

    # Accuracy on thresholded predictions, AUC on the raw probabilities.
    hard_preds = (all_preds >= 0.5).astype(int)
    acc = accuracy_score(all_labels, hard_preds)
    auc = roc_auc_score(all_labels, all_preds)

    return avg_pred_loss, acc, auc, all_preds, all_labels


def my_gridsearch(_args):
    """Run ``main`` once per hyper-parameter combination and keep the best one.

    ``_args.param_grid`` maps an option name to the list of values to try.
    Every value of a combination is written onto ``_args`` under its own name
    before ``main(_args)`` is called, and each combination gets its own
    checkpoint directory below ``../checkpoints/{proj_name}``.

    Args:
        _args: Namespace carrying ``param_grid`` and ``proj_name`` plus
            everything ``main`` needs. It is modified in place.

    Returns:
        ``(best_params, best_metrics)`` for the combination with the lowest
        ``val_loss`` reported by ``main``.

    Raises:
        ValueError: If ``_args.param_grid`` is empty.
    """
    param_grid = _args.param_grid
    if not param_grid:
        raise ValueError("param_grid must contain at least one parameter to search over")

    # Cartesian product of all candidate values.
    keys, values = zip(*param_grid.items())
    param_combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

    # NaN placeholders keep the final report printable even when no run
    # produced a finite validation loss.
    best_metrics = {'val_loss': float('inf'), 'val_acc': float('nan'), 'val_auc': float('nan')}
    best_params = {}

    for i, params in enumerate(param_combinations, start=1):
        print(f"\n=== 正在训练参数组合 {i}/{len(param_combinations)} ===")
        print("当前参数:", json.dumps(params, indent=2))

        # Apply the combination to the run configuration.
        for name, value in params.items():
            setattr(_args, name, value)

        # Name the directory after the parameter values themselves. The built-in
        # hash() of string-containing tuples changes between interpreter runs,
        # so a hash-based name could never be found again when resuming.
        param_tag = "_".join(
            f"{name}-{str(params[name]).replace(os.sep, '-')}" for name in sorted(params)
        )
        _args.checkpoint_dir = f"../checkpoints/{_args.proj_name}/grid_{param_tag}"
        Path(_args.checkpoint_dir).mkdir(parents=True, exist_ok=True)

        current_metrics = main(_args)

        # NOTE(review): selection here is by validation loss, whereas main()
        # picks its best epoch by validation AUC — confirm this is intended.
        if current_metrics['val_loss'] < best_metrics['val_loss']:
            best_metrics = current_metrics
            best_params = params.copy()

    print("\n=== 网格搜索完成 ===")
    print(f"最佳参数组合: {json.dumps(best_params, indent=2)}")
    print(f"对应验证指标: loss={best_metrics['val_loss']:.4f}, acc={best_metrics['val_acc']:.4f}, auc={best_metrics['val_auc']:.4f}")
    return best_params, best_metrics


def main(args):
    """Train, validate and test ``MyCDM_MLP_FFT`` with resume and early stopping.

    The best epoch is chosen by validation AUC. After every epoch the latest
    state is written to ``last_checkpoint.pt`` so an interrupted run can
    continue from the following epoch with its best-so-far metrics and
    early-stopping counter restored.

    Args:
        args: Namespace produced by ``parse_args`` (data, scenario,
            hyper-parameters, checkpoint directory, ...).

    Returns:
        Dict with ``val_loss``, ``val_acc`` and ``val_auc`` of the best epoch,
        used by the grid search to compare runs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data locations.
    data_root = f'../data/{args.data}/{args.scenario}'
    train_path = f'{data_root}/train.json'
    val_path = f'{data_root}/val.json'
    test_path = f'{data_root}/test.json'
    exer_embeds_path = f'{data_root}/exer_embeds_bert.npy'  # not used by this fine-tuning setup
    exer_tokens_path = f'{data_root}/exer_tokens.json'

    # config.txt: the first line is skipped, the second holds three
    # comma-separated counts. They are parsed with int() instead of eval() so
    # that text read from a file is never executed.
    with open(f'{data_root}/config.txt') as i_f:
        i_f.readline()
        student_n, exer_n, knowledge_n = (int(tok) for tok in i_f.readline().split(','))

    # Checkpoint locations.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    best_model_path = f'{args.checkpoint_dir}/best_model.pt'
    last_checkpoint_path = f'{args.checkpoint_dir}/last_checkpoint.pt'

    # Training state (overwritten below when a checkpoint is found).
    best_val_loss = float('inf')
    best_val_acc = 0.0
    best_val_auc = 0.0
    early_stop_counter = 0
    start_epoch = 0

    # Model.
    # NOTE(review): MyCDM_MLP_FFT.forward indexes exer_in["input_ids"], and
    # train()/val_or_test() only build that dict when mode == 'fine-tune';
    # confirm args.mode is set accordingly for this model.
    dict_token = exer_tokens_path
    model = MyCDM_MLP_FFT(num_students=student_n,
                          bert_model_name=args.bert_path,
                          tau=args.tau,
                          lambda_reg=args.lambda_reg,
                          lambda_cl=args.lambda_cl,
                          ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Data loaders.
    train_loader = MyDataloader(
        batch_size=args.bs,
        id_to_token=dict_token,  # None or path (of json)
        data_set=train_path,
        offset=0,  # dataset uses the original ids
        pid_zero=0
    )
    val_loader = MyDataloader(
        batch_size=args.bs,
        id_to_token=dict_token,
        data_set=val_path,
        offset=0,
        pid_zero=0
    )
    test_loader = MyDataloader(
        batch_size=args.bs,
        id_to_token=dict_token,
        data_set=test_path,
        offset=0,
        pid_zero=0
    )

    # Resume from the latest checkpoint when one exists.
    if os.path.exists(last_checkpoint_path):
        checkpoint = torch.load(last_checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        start_epoch = checkpoint['epoch'] + 1
        # .get() keeps checkpoints written by earlier versions of this script
        # loadable: those stored 'best_val_auc' but never 'best_val_loss' or
        # 'best_val_acc'.
        best_val_auc = checkpoint.get('best_val_auc', best_val_auc)
        best_val_acc = checkpoint.get('best_val_acc', best_val_acc)
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        early_stop_counter = checkpoint.get('early_stop_counter', early_stop_counter)
        print(f"加载检查点：从epoch {start_epoch}恢复训练，当前最佳val_loss={best_val_loss:.4f}")

    wandb.init(
        project=args.proj_name,
        config={**vars(args)},
        resume=start_epoch > 0,
        reinit=bool(args.grid_search)  # allow repeated init inside a grid search
    )

    def _snapshot(epoch_idx):
        """Collect everything needed to resume after ``epoch_idx``."""
        return {
            'epoch': epoch_idx,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'best_val_auc': best_val_auc,
            'best_val_acc': best_val_acc,
            'best_val_loss': best_val_loss,
            'early_stop_counter': early_stop_counter
        }

    # Training loop; starts at start_epoch so a resumed run does not repeat
    # the epochs it has already completed.
    for epoch in range(start_epoch, args.epoch):
        print(f"\nEpoch {epoch + 1}/{args.epoch}:")

        now = datetime.now()
        print(now.strftime("%Y-%m-%d %H:%M:%S"), f', training epoch {epoch + 1}')
        train_total_loss, train_pred_loss, train_cl_loss, train_reg_loss = train(
            model, train_loader, optimizer, device, mode=args.mode, verbose=args.verbose)
        print(
            f"  Train Pred Loss: {train_pred_loss:.4f}, total Loss: {train_total_loss:.4f}, CL Loss: {train_cl_loss:.4f}, Reg Loss: {train_reg_loss:.4f} ")

        now = datetime.now()
        print(f'{now.strftime("%Y-%m-%d %H:%M:%S")}, validating epoch {epoch + 1}')
        val_pred_loss, val_acc, val_auc, _, _ = val_or_test(model, val_loader, device, mode=args.mode, verbose=args.verbose)
        print(f"  Val Pred Loss: {val_pred_loss:.4f} Acc: {val_acc:.4f} AUC: {val_auc:.4f}")

        now = datetime.now()
        print(f'{now.strftime("%Y-%m-%d %H:%M:%S")}, testing epoch {epoch + 1}')
        test_pred_loss, test_acc, test_auc, _, _ = val_or_test(model, test_loader, device, mode=args.mode, verbose=args.verbose)
        print(f"  Test Pred Loss: {test_pred_loss:.4f} Acc: {test_acc:.4f} AUC: {test_auc:.4f}")

        # Model selection and early stopping are driven by validation AUC.
        if val_auc > best_val_auc:
            # Plain Python floats keep the checkpoint free of numpy scalars.
            best_val_auc = float(val_auc)
            best_val_acc = float(val_acc)
            best_val_loss = float(val_pred_loss)
            early_stop_counter = 0
            torch.save(_snapshot(epoch), best_model_path)
            print(f"发现新最佳模型，val_auc={best_val_auc:.4f}，已保存至{best_model_path}")
        else:
            early_stop_counter += 1

        # Latest state, used for resuming.
        torch.save(_snapshot(epoch), last_checkpoint_path)

        wandb.log({
            "train/total_loss": train_total_loss,
            "train/pred_loss": train_pred_loss,
            "train/cl_loss": train_cl_loss,
            "train/reg_loss": train_reg_loss,
            "val/pred_loss": val_pred_loss,
            "val/acc": val_acc,
            "val/auc": val_auc,
            "epoch": epoch + 1,
        })

        if early_stop_counter >= args.early_stop_patience:
            print(f"\n早停触发！连续{args.early_stop_patience}个epoch验证集无改进")
            break

    # Final test with the best model, when one was saved.
    if os.path.exists(best_model_path):
        print("\n加载最佳模型进行测试...")
        model.load_state_dict(torch.load(best_model_path, map_location=device)['model_state'])

    print('========================================================================')
    now = datetime.now()
    print(f'{now.strftime("%Y-%m-%d %H:%M:%S")}, testing...')
    test_pred_loss, test_acc, test_auc, y_pred, y_true = val_or_test(model, test_loader, device, mode=args.mode,
                                                                     verbose=args.verbose)
    print(f"\nFinal Test Results:")
    print(f"  Test Pred Loss: {test_pred_loss:.4f} Acc: {test_acc:.4f} AUC: {test_auc:.4f}")

    now = datetime.now()
    print(f'{now.strftime("%Y-%m-%d %H:%M:%S")}, finish.')
    print('========================================================================')

    # TODO: persist y_true / y_pred (not implemented yet).

    wandb.log({
        "test/pred_loss": test_pred_loss,
        "test/acc": test_acc,
        "test/auc": test_auc
    })

    wandb.finish()

    # Validation metrics of the best epoch, for grid-search comparison.
    return {
        'val_loss': best_val_loss,
        'val_acc': best_val_acc,
        'val_auc': best_val_auc
    }


In [None]:
from accelerate import Accelerator
from accelerate.utils import set_seed
from deepspeed.ops.adam import DeepSpeedCPUAdam

def main(args):
    """Sketch of the training entry point ported to Hugging Face Accelerate.

    This is a skeleton, not runnable code: several steps are elided with
    ``...`` and a number of names it uses (``student_n``,
    ``last_checkpoint_path``, ``best_model_path``, ``start_epoch``,
    ``best_val_auc``, ``early_stop_counter``) are not assigned anywhere in
    this function and are expected to come from the elided sections.
    """
    # Initialise the accelerator (the core change).
    # NOTE(review): deepspeed_plugin is given as a plain dict here — confirm
    # that Accelerator accepts this form rather than a DeepSpeedPlugin object.
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,  # read from args, or pin to 'fp16'
        gradient_accumulation_steps=args.grad_accum_steps,
        deepspeed_plugin={
            "zero_stage": 2,
            "offload_optimizer_device": "cpu",
            "offload_param_device": "none",
            "zero_force_ds_cpu_optimizer": False
        }
    )

    # Seed every process (for consistency across devices).
    set_seed(args.seed)

    # The device is managed by the accelerator.
    device = accelerator.device

    # ... [original data-path configuration code unchanged] ...

    # Build the model (note: no .to(device) at this point).
    model = MyCDM_MLP_FFT(
        num_students=student_n,
        bert_model_name=args.bert_path,
        tau=args.tau,
        lambda_reg=args.lambda_reg,
        lambda_cl=args.lambda_cl
    )

    # Optimizer switched to the DeepSpeed CPU Adam implementation.
    optimizer = DeepSpeedCPUAdam(model.parameters(), lr=args.learning_rate)

    # Data loaders (arguments elided; same logic as before).
    train_loader = MyDataloader(...)
    val_loader = MyDataloader(...)
    test_loader = MyDataloader(...)

    # Let the accelerator wrap every component (key change).
    model, optimizer, train_loader, val_loader, test_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader
    )

    # Resume logic: model weights are loaded into the unwrapped model.
    if os.path.exists(last_checkpoint_path):
        checkpoint = torch.load(last_checkpoint_path, map_location='cpu')
        accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # ... [remaining state-restoration logic unchanged] ...

    # Initialise wandb on the main process only.
    if accelerator.is_main_process:
        wandb.init(...)

    # Training loop.
    for epoch in range(start_epoch, args.epoch):
        model.train()
        total_loss = 0

        for batch in train_loader:
            with accelerator.accumulate(model):
                # No explicit .to(device) here; the prepared loader is relied on
                # for device placement.
                # NOTE(review): the keys "stu_ids" / "labels" differ from the
                # 'stu_id' / 'label' keys read by train() and train_parallel() —
                # confirm which names MyDataloader actually produces.
                stu_ids = batch["stu_ids"]
                exer_in = {
                    "input_ids": batch["input_ids"],
                    "attention_mask": batch["attention_mask"]
                }
                labels = batch["labels"]

                outputs = model(stu_ids, exer_in)
                # NOTE(review): get_loss is called on the prepared model, not on
                # accelerator.unwrap_model(model) — verify the wrapper exposes it.
                loss, pred_loss, cl_loss, reg_loss = model.get_loss(outputs, labels)

                accelerator.backward(loss)

                # Clip gradients only on steps where they are synchronised.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad()

                # Average the loss gathered from all processes.
                total_loss += accelerator.gather(loss.detach()).mean().item()

        # Printing and evaluation happen on the main process only.
        # NOTE(review): evaluate() calls accelerator.gather(), yet it is invoked
        # only inside this is_main_process branch — verify this cannot block
        # when more than one process is running.
        if accelerator.is_main_process:
            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

            # Validation and test are wrapped in a helper function.
            val_pred_loss, val_acc, val_auc = evaluate(
                accelerator, model, val_loader, "val"
            )
            test_pred_loss, test_acc, test_auc = evaluate(
                accelerator, model, test_loader, "test"
            )

            # Best-model saving.
            # NOTE(review): in the code shown, best_val_auc and
            # early_stop_counter are never updated and no early-stop break
            # exists — presumably left out of this sketch.
            if val_auc > best_val_auc:
                # Save through the accelerator interface.
                accelerator.save({
                    'epoch': epoch,
                    'model_state': accelerator.unwrap_model(model).state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'best_val_auc': val_auc
                }, best_model_path)

            # Save the latest checkpoint.
            accelerator.save({
                'epoch': epoch,
                'model_state': accelerator.unwrap_model(model).state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'best_val_auc': best_val_auc,
                'early_stop_counter': early_stop_counter
            }, last_checkpoint_path)

            # wandb logging (arguments elided).
            wandb.log(...)

    # Final test (main process only).
    # NOTE(review): best_model_path is written above with accelerator.save();
    # confirm accelerator.load_state() can read that file. Also, evaluate()
    # returns a 3-tuple, which is passed to wandb.log() as-is.
    if accelerator.is_main_process:
        accelerator.load_state(best_model_path)
        final_test_results = evaluate(accelerator, model, test_loader, "final_test")
        wandb.log(final_test_results)
        wandb.finish()

    # Placeholder return value (the literal Ellipsis).
    return ...

# 新增评估函数
def evaluate(accelerator, model, dataloader, mode):
    """Evaluate ``model`` on ``dataloader`` and return ``(loss, acc, auc)``.

    Predictions and labels are gathered through ``accelerator.gather`` after
    every batch. The metrics are computed on the main process only; every
    other process gets ``(None, None, None)``.

    Args:
        accelerator: Object providing ``gather`` and ``is_main_process``.
        model: Callable as ``model(stu_ids, exer_in)``; the first item of its
            output is taken as the predictions.
        dataloader: Iterable of batches with ``"stu_ids"``, ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` entries.
        mode: Label of the split being evaluated; not used in the body.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    for batch in dataloader:
        with torch.no_grad():
            stu_ids = batch["stu_ids"]
            exer_in = {
                "input_ids": batch["input_ids"],
                "attention_mask": batch["attention_mask"],
            }
            labels = batch["labels"]
            preds = model(stu_ids, exer_in)[0]

        # Collect this batch's predictions and labels via the accelerator.
        pred_chunks.append(accelerator.gather(preds))
        label_chunks.append(accelerator.gather(labels))

    all_preds = torch.cat(pred_chunks)
    all_labels = torch.cat(label_chunks)

    # Metrics are only computed on the main process.
    if not accelerator.is_main_process:
        return None, None, None

    acc = compute_accuracy(all_preds, all_labels)
    auc = compute_auc(all_preds, all_labels)
    loss = F.binary_cross_entropy(all_preds, all_labels)
    return loss.item(), acc, auc

In [None]:
def train_parallel(_accelerator, _model, _train_loader, _optimizer):
    """Run one training epoch under an accelerator and return average losses.

    Args:
        _accelerator: Object providing ``device``, ``is_local_main_process``
            and ``backward``.
        _model: Callable as ``_model(stu_ids, exer_in)`` and exposing
            ``get_loss(output, labels)`` that returns four loss tensors.
        _train_loader: Sized iterable of batch dicts with ``'stu_id'``,
            ``'label'``, ``'input_ids'`` and ``'attention_mask'`` tensors.
        _optimizer: Optimizer stepped once per batch.

    Returns:
        ``(avg_total_loss, avg_pred_loss, avg_cl_loss, avg_reg_loss)``, each
        averaged over the number of batches in ``_train_loader``.
    """
    _model.train()
    total_loss = 0.0
    pred_loss = 0.0
    cl_loss = 0.0
    reg_loss = 0.0
    # The progress bar is only rendered on the local main process.
    progress_bar = tqdm(_train_loader, desc="Training", disable=not _accelerator.is_local_main_process)
    _device = _accelerator.device

    # An explicit step counter is the denominator of the running averages.
    # tqdm's own ``.n`` is only advanced when the bar refreshes, so dividing by
    # ``progress_bar.n + 1`` can overstate the displayed averages.
    for step, batch in enumerate(progress_bar, start=1):
        # Move the batch to the device and pack the BERT inputs.
        stu_ids = batch['stu_id'].to(_device)
        labels = batch['label'].to(_device).float()
        input_ids = batch['input_ids'].to(_device)
        attention_mask = batch['attention_mask'].to(_device)
        exer_in = {'input_ids': input_ids, 'attention_mask': attention_mask}

        _optimizer.zero_grad()
        output = _model(stu_ids, exer_in)
        loss, loss_bce, loss_cl, loss_reg = _model.get_loss(output, labels)
        _accelerator.backward(loss)      # backward pass goes through the accelerator
        _optimizer.step()

        total_loss += loss.item()
        pred_loss += loss_bce.item()
        cl_loss += loss_cl.item()
        reg_loss += loss_reg.item()

        # Running averages shown next to the progress bar.
        progress_bar.set_postfix(
            loss=total_loss / step,
            bce=pred_loss / step,
            cl=cl_loss / step,
            reg=reg_loss / step
        )

    # Final per-batch averages for the epoch.
    n_batches = len(_train_loader)
    avg_total_loss = total_loss / n_batches
    avg_pred_loss = pred_loss / n_batches
    avg_cl_loss = cl_loss / n_batches
    avg_reg_loss = reg_loss / n_batches

    return avg_total_loss, avg_pred_loss, avg_cl_loss, avg_reg_loss
