In [80]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from skimage import io
import seaborn as sns
import warnings
import numpy as np
import warnings
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import warnings
from pylab import mpl, plt
import matplotlib.patches as mpatches
from tqdm.notebook import tqdm

# best font and style settings for notebook 
warnings.filterwarnings('ignore')
sns.set_style("white")
mpl.rcParams['font.family'] = 'MiSans'

model_path = r"./Qwen3-0.6B"  # modify to your Qwen Path
# model_path = r"./Qwen3-1.7B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")

In [81]:
from delta_trainer import train_delta_from_H, generate_by_H, evaluate_slot_ceval, evaluate_slot_ceval_eos, \
    evaluate_slot_ceval_eos_2

# 构造 prompt & 得到 H_state
prompt = "请写一段关于AI教育的引言。"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, return_dict=True)
H = outputs.hidden_states[-1]

# 调用 delta 训练
delta_3 = train_delta_from_H(model, tokenizer, prompt, H, step=3)
delta_10 = train_delta_from_H(model, tokenizer, prompt, H, step=10)
delta_30 = train_delta_from_H(model, tokenizer, prompt, H, step=30)


In [None]:
# generate_by_H(model=model, prompt=prompt, tokenizer=tokenizer, delta=delta_3, answer_len=200)

In [82]:
from datasets import get_dataset_config_names

# 获取本地路径 "./ceval-exam" 中可用的所有子数据集名称（config names）
dataset_path = "./ceval-exam"
dataset_names = get_dataset_config_names(path=dataset_path)
dataset_names

['accountant',
 'advanced_mathematics',
 'art_studies',
 'basic_medicine',
 'business_administration',
 'chinese_language_and_literature',
 'civil_servant',
 'clinical_medicine',
 'college_chemistry',
 'college_economics',
 'college_physics',
 'college_programming',
 'computer_architecture',
 'computer_network',
 'discrete_mathematics',
 'education_science',
 'electrical_engineer',
 'environmental_impact_assessment_engineer',
 'fire_engineer',
 'high_school_biology',
 'high_school_chemistry',
 'high_school_chinese',
 'high_school_geography',
 'high_school_history',
 'high_school_mathematics',
 'high_school_physics',
 'high_school_politics',
 'ideological_and_moral_cultivation',
 'law',
 'legal_professional',
 'logic',
 'mao_zedong_thought',
 'marxism',
 'metrology_engineer',
 'middle_school_biology',
 'middle_school_chemistry',
 'middle_school_geography',
 'middle_school_history',
 'middle_school_mathematics',
 'middle_school_physics',
 'middle_school_politics',
 'modern_chinese_histor

In [84]:
from datasets import load_dataset

dataset = load_dataset(r"./ceval-exam", name="computer_network")
print(dataset['val'][10])

{'id': 10, 'question': '____采用链路状态算法', 'A': 'RIP', 'B': 'OSPF', 'C': 'BGP-4', 'D': 'EGP', 'answer': 'B', 'explanation': ''}


In [None]:
def evaluate_slot_ceval_eos(model, tokenizer, delta, example, max_len=20, verbose=True):
    """
    基于 generate_by_H_eos 的评估函数，用于 C-Eval 单选题目。

    返回：
    - predict_option: 预测选项，如 'A'
    - is_correct: 是否预测正确
    """
    prompt = f"""以下是一道单项选择题，请你阅读题目并选择最合适的选项。

题目：{example['question']}

选项：
A. {example['A']}
B. {example['B']}
C. {example['C']}
D. {example['D']}

答案是："""

    output_text = generate_by_H_eos(model, prompt, tokenizer, delta, answer_len=max_len)

    if verbose:
        print("🔍 模型生成结果:\n", output_text)

    predict_option = None
    for option in ['A', 'B', 'C', 'D']:
        if option in output_text:
            predict_option = option
            break

    is_correct = (predict_option == example['answer'])
    # return predict_option, is_correct
    return output_text, predict_option, example['answer'], is_correct

In [85]:
from delta_trainer import generate_by_H_eos

import torch
from tqdm.notebook import tqdm


def generate_by_H_eos_fast(model, prompt, tokenizer, delta, answer_len=100):
    """
    使用 past_key_values 加速，支持 eos 截断的 H 层扰动生成。

    参数：
    - model: 支持 use_cache 的 decoder-only 模型（如 GPT 系列）
    - prompt: 输入文本
    - tokenizer: 分词器
    - delta: shape=[1, 1, hidden_size] 的扰动张量
    - answer_len: 最多生成 token 数

    返回：
    - record_txt: 解码后的文本（不含 prompt 部分）
    """
    eos_token_id = tokenizer.eos_token_id
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]  # [1, L_prompt]

    with torch.no_grad():
        # 初始化推理，缓存 key_values
        outputs = model(input_ids=input_ids, return_dict=True, output_hidden_states=True, use_cache=True)
        past_key_values = outputs.past_key_values

        # 首个扰动 + 生成
        H_last = outputs.hidden_states[-1][:, -1, :] + delta.squeeze(1)  # [1, hidden_size]
        logits = torch.matmul(H_last, model.lm_head.weight.T)
        next_token_id = torch.argmax(logits, dim=-1, keepdim=True)  # [1, 1]

    record = [next_token_id]  # 收集生成 token

    for _ in range(answer_len - 1):  # 已生成 1 个，最多生成 answer_len 个
        if next_token_id.item() == eos_token_id:
            break

        with torch.no_grad():
            outputs = model(
                input_ids=next_token_id,
                past_key_values=past_key_values,
                return_dict=True,
                output_hidden_states=True,
                use_cache=True
            )
            past_key_values = outputs.past_key_values
            H_last = outputs.hidden_states[-1][:, -1, :] + delta.squeeze(1)
            logits = torch.matmul(H_last, model.lm_head.weight.T)
            next_token_id = torch.argmax(logits, dim=-1, keepdim=True)  # [1, 1]

        record.append(next_token_id)

    # 拼接生成序列（不含 prompt）
    gen_ids = torch.cat(record, dim=-1)  # [1, T]
    record_txt = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return record_txt


def evaluate_slot_ceval_eos(model, tokenizer, delta, example, prompt, max_len=20, verbose=True):
    """
    基于 generate_by_H_eos 的评估函数，用于 C-Eval 单选题目。

    返回：
    - predict_option: 预测选项，如 'A'
    - is_correct: 是否预测正确
    # """

    # output_text = generate_by_H_eos(model, prompt, tokenizer, delta, answer_len=max_len)
    output_text = generate_by_H_eos_fast(model, prompt, tokenizer, delta, answer_len=max_len)

    if verbose:
        print("🔍 模型生成结果:\n", output_text)

    predict_option = None
    for option in ['A', 'B', 'C', 'D']:
        if option in output_text:
            predict_option = option
            break

    is_correct = (predict_option == example['answer'])
    # return predict_option, is_correct
    return output_text, predict_option, example['answer'], is_correct


def eval_dataset(dataset_name, step=3, max_len=50, lr=1e-2):
    # dataset_name = "computer_network"
    dataset = load_dataset(r"./ceval-exam", name=dataset_name)

    correct = 0
    total = 0

    answer_sheet = []
    for ex in tqdm(dataset['val']):
        # === 构造每道题的 Prompt ===
        prompt = f"""以下是一道单项选择题，请你阅读题目，选择最合适的选项。
        题目：{ex['question']}
        选项：
        A. {ex['A']}
        B. {ex['B']}
        C. {ex['C']}
        D. {ex['D']}
        答案是："""
        prompt = f"""请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。
        
        题目：
        {ex['question']}
        
        选项：
        A. {ex['A']}
        B. {ex['B']}
        C. {ex['C']}
        D. {ex['D']}
        
        答案是："""

#         prompt = f"""请阅读以下单项选择题，并从四个选项中选出一个最合适的答案。
# 
# 题目：
# {ex['question']}
# 
# 选项：
# A. {ex['A']}
# B. {ex['B']}
# C. {ex['C']}
# D. {ex['D']}
# 
# 请直接回答，答案是："""

        # === 获取 H_state ===
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
        H = outputs.hidden_states[-1]

        # === 训练 delta（例如3步）===
        delta = train_delta_from_H(model, tokenizer, prompt, H, step=step, lr=lr)

        # === 推理与评估 ===
        pred_txt, pre_answer, answer, is_correct = evaluate_slot_ceval_eos(model=model, tokenizer=tokenizer,
                                                                           delta=delta,
                                                                           example=ex, max_len=max_len, prompt=prompt,
                                                                           verbose=False)
        correct += int(is_correct)
        total += 1
        answer_sheet.append([prompt, pred_txt, pre_answer, answer, is_correct, dataset_name])
    print(f"🎯 {dataset_name} Accuracy (per-question delta): {correct}/{total} = {correct / total:.2%}")
    return answer_sheet

In [88]:
model_path.split("/")[-1]

'Qwen3-0.6B'

In [ ]:
for step in [3, 0, 6, 9]:
    max_len = 5
    answer_sheet = []
    for dataset_name in tqdm(dataset_names[:]):
        answer_sheet += eval_dataset(dataset_name=dataset_name, step=step, max_len=max_len, lr=1e-2)
        df_answer = pd.DataFrame(answer_sheet)
        df_answer.to_csv(f"./eval_result/0_6B/answer_step_{step}.csv", index=False)
pd.DataFrame(answer_sheet)

  0%|          | 0/52 [00:00<?, ?it/s]

  0%|          | 0/49 [00:00<?, ?it/s]

🎯 accountant Accuracy (per-question delta): 25/49 = 51.02%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 advanced_mathematics Accuracy (per-question delta): 4/19 = 21.05%


  0%|          | 0/33 [00:00<?, ?it/s]

🎯 art_studies Accuracy (per-question delta): 17/33 = 51.52%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 basic_medicine Accuracy (per-question delta): 13/19 = 68.42%


  0%|          | 0/33 [00:00<?, ?it/s]

🎯 business_administration Accuracy (per-question delta): 15/33 = 45.45%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 chinese_language_and_literature Accuracy (per-question delta): 10/23 = 43.48%


  0%|          | 0/47 [00:00<?, ?it/s]

🎯 civil_servant Accuracy (per-question delta): 20/47 = 42.55%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 clinical_medicine Accuracy (per-question delta): 9/22 = 40.91%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 college_chemistry Accuracy (per-question delta): 9/24 = 37.50%


  0%|          | 0/55 [00:00<?, ?it/s]

🎯 college_economics Accuracy (per-question delta): 27/55 = 49.09%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 college_physics Accuracy (per-question delta): 11/19 = 57.89%


  0%|          | 0/37 [00:00<?, ?it/s]

🎯 college_programming Accuracy (per-question delta): 15/37 = 40.54%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 computer_architecture Accuracy (per-question delta): 8/21 = 38.10%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 computer_network Accuracy (per-question delta): 7/19 = 36.84%


  0%|          | 0/16 [00:00<?, ?it/s]

🎯 discrete_mathematics Accuracy (per-question delta): 5/16 = 31.25%


  0%|          | 0/29 [00:00<?, ?it/s]

🎯 education_science Accuracy (per-question delta): 19/29 = 65.52%


  0%|          | 0/37 [00:00<?, ?it/s]

🎯 electrical_engineer Accuracy (per-question delta): 16/37 = 43.24%


  0%|          | 0/31 [00:00<?, ?it/s]

🎯 environmental_impact_assessment_engineer Accuracy (per-question delta): 21/31 = 67.74%


  0%|          | 0/31 [00:00<?, ?it/s]

🎯 fire_engineer Accuracy (per-question delta): 14/31 = 45.16%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_biology Accuracy (per-question delta): 9/19 = 47.37%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_chemistry Accuracy (per-question delta): 7/19 = 36.84%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_chinese Accuracy (per-question delta): 7/19 = 36.84%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_geography Accuracy (per-question delta): 8/19 = 42.11%


  0%|          | 0/20 [00:00<?, ?it/s]

🎯 high_school_history Accuracy (per-question delta): 16/20 = 80.00%


  0%|          | 0/18 [00:00<?, ?it/s]

🎯 high_school_mathematics Accuracy (per-question delta): 6/18 = 33.33%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_physics Accuracy (per-question delta): 11/19 = 57.89%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_politics Accuracy (per-question delta): 12/19 = 63.16%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 ideological_and_moral_cultivation Accuracy (per-question delta): 14/19 = 73.68%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 law Accuracy (per-question delta): 6/24 = 25.00%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 legal_professional Accuracy (per-question delta): 8/23 = 34.78%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 logic Accuracy (per-question delta): 13/22 = 59.09%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 mao_zedong_thought Accuracy (per-question delta): 19/24 = 79.17%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 marxism Accuracy (per-question delta): 12/19 = 63.16%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 metrology_engineer Accuracy (per-question delta): 14/24 = 58.33%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 middle_school_biology Accuracy (per-question delta): 14/21 = 66.67%


  0%|          | 0/20 [00:00<?, ?it/s]

🎯 middle_school_chemistry Accuracy (per-question delta): 13/20 = 65.00%


  0%|          | 0/12 [00:00<?, ?it/s]

🎯 middle_school_geography Accuracy (per-question delta): 6/12 = 50.00%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 middle_school_history Accuracy (per-question delta): 12/22 = 54.55%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 middle_school_mathematics Accuracy (per-question delta): 6/19 = 31.58%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 middle_school_physics Accuracy (per-question delta): 10/19 = 52.63%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 middle_school_politics Accuracy (per-question delta): 15/21 = 71.43%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 modern_chinese_history Accuracy (per-question delta): 11/23 = 47.83%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 operating_system Accuracy (per-question delta): 9/19 = 47.37%


  0%|          | 0/49 [00:00<?, ?it/s]

🎯 physician Accuracy (per-question delta): 24/49 = 48.98%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 plant_protection Accuracy (per-question delta): 13/22 = 59.09%


  0%|          | 0/18 [00:00<?, ?it/s]

🎯 probability_and_statistics Accuracy (per-question delta): 4/18 = 22.22%


  0%|          | 0/29 [00:00<?, ?it/s]

🎯 professional_tour_guide Accuracy (per-question delta): 15/29 = 51.72%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 sports_science Accuracy (per-question delta): 7/19 = 36.84%


  0%|          | 0/49 [00:00<?, ?it/s]

🎯 tax_accountant Accuracy (per-question delta): 24/49 = 48.98%


  0%|          | 0/44 [00:00<?, ?it/s]

🎯 teacher_qualification Accuracy (per-question delta): 28/44 = 63.64%


  0%|          | 0/46 [00:00<?, ?it/s]

🎯 urban_and_rural_planner Accuracy (per-question delta): 29/46 = 63.04%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 veterinary_medicine Accuracy (per-question delta): 10/23 = 43.48%


  0%|          | 0/52 [00:00<?, ?it/s]

  0%|          | 0/49 [00:00<?, ?it/s]

🎯 accountant Accuracy (per-question delta): 25/49 = 51.02%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 advanced_mathematics Accuracy (per-question delta): 4/19 = 21.05%


  0%|          | 0/33 [00:00<?, ?it/s]

🎯 art_studies Accuracy (per-question delta): 17/33 = 51.52%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 basic_medicine Accuracy (per-question delta): 13/19 = 68.42%


  0%|          | 0/33 [00:00<?, ?it/s]

🎯 business_administration Accuracy (per-question delta): 15/33 = 45.45%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 chinese_language_and_literature Accuracy (per-question delta): 10/23 = 43.48%


  0%|          | 0/47 [00:00<?, ?it/s]

🎯 civil_servant Accuracy (per-question delta): 20/47 = 42.55%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 clinical_medicine Accuracy (per-question delta): 9/22 = 40.91%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 college_chemistry Accuracy (per-question delta): 9/24 = 37.50%


  0%|          | 0/55 [00:00<?, ?it/s]

🎯 college_economics Accuracy (per-question delta): 26/55 = 47.27%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 college_physics Accuracy (per-question delta): 11/19 = 57.89%


  0%|          | 0/37 [00:00<?, ?it/s]

🎯 college_programming Accuracy (per-question delta): 16/37 = 43.24%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 computer_architecture Accuracy (per-question delta): 8/21 = 38.10%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 computer_network Accuracy (per-question delta): 7/19 = 36.84%


  0%|          | 0/16 [00:00<?, ?it/s]

🎯 discrete_mathematics Accuracy (per-question delta): 5/16 = 31.25%


  0%|          | 0/29 [00:00<?, ?it/s]

🎯 education_science Accuracy (per-question delta): 19/29 = 65.52%


  0%|          | 0/37 [00:00<?, ?it/s]

🎯 electrical_engineer Accuracy (per-question delta): 15/37 = 40.54%


  0%|          | 0/31 [00:00<?, ?it/s]

🎯 environmental_impact_assessment_engineer Accuracy (per-question delta): 21/31 = 67.74%


  0%|          | 0/31 [00:00<?, ?it/s]

🎯 fire_engineer Accuracy (per-question delta): 14/31 = 45.16%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_biology Accuracy (per-question delta): 9/19 = 47.37%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_chemistry Accuracy (per-question delta): 7/19 = 36.84%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_chinese Accuracy (per-question delta): 7/19 = 36.84%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_geography Accuracy (per-question delta): 10/19 = 52.63%


  0%|          | 0/20 [00:00<?, ?it/s]

🎯 high_school_history Accuracy (per-question delta): 17/20 = 85.00%


  0%|          | 0/18 [00:00<?, ?it/s]

🎯 high_school_mathematics Accuracy (per-question delta): 6/18 = 33.33%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_physics Accuracy (per-question delta): 11/19 = 57.89%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_politics Accuracy (per-question delta): 12/19 = 63.16%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 ideological_and_moral_cultivation Accuracy (per-question delta): 14/19 = 73.68%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 law Accuracy (per-question delta): 6/24 = 25.00%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 legal_professional Accuracy (per-question delta): 8/23 = 34.78%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 logic Accuracy (per-question delta): 12/22 = 54.55%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 mao_zedong_thought Accuracy (per-question delta): 18/24 = 75.00%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 marxism Accuracy (per-question delta): 12/19 = 63.16%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 metrology_engineer Accuracy (per-question delta): 14/24 = 58.33%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 middle_school_biology Accuracy (per-question delta): 14/21 = 66.67%


  0%|          | 0/20 [00:00<?, ?it/s]

🎯 middle_school_chemistry Accuracy (per-question delta): 13/20 = 65.00%


  0%|          | 0/12 [00:00<?, ?it/s]

🎯 middle_school_geography Accuracy (per-question delta): 6/12 = 50.00%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 middle_school_history Accuracy (per-question delta): 12/22 = 54.55%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 middle_school_mathematics Accuracy (per-question delta): 6/19 = 31.58%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 middle_school_physics Accuracy (per-question delta): 10/19 = 52.63%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 middle_school_politics Accuracy (per-question delta): 15/21 = 71.43%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 modern_chinese_history Accuracy (per-question delta): 11/23 = 47.83%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 operating_system Accuracy (per-question delta): 9/19 = 47.37%


  0%|          | 0/49 [00:00<?, ?it/s]

🎯 physician Accuracy (per-question delta): 24/49 = 48.98%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 plant_protection Accuracy (per-question delta): 13/22 = 59.09%


  0%|          | 0/18 [00:00<?, ?it/s]

🎯 probability_and_statistics Accuracy (per-question delta): 5/18 = 27.78%


  0%|          | 0/29 [00:00<?, ?it/s]

In [None]:
for step in [3, 0, 6, 9]:
    max_len = 5
    answer_sheet = []
    for dataset_name in tqdm(dataset_names[:]):
        answer_sheet += eval_dataset(dataset_name=dataset_name, step=step, max_len=max_len, lr=1e-2)
        df_answer = pd.DataFrame(answer_sheet)
        df_answer.to_csv(f"./eval_result/1_7B/answer_step_{step}.csv", index=False)
pd.DataFrame(answer_sheet)

In [79]:
for step in [15]:
    max_len = 5
    answer_sheet = []
    for dataset_name in tqdm(dataset_names[:]):
        answer_sheet += eval_dataset(dataset_name=dataset_name, step=step, max_len=max_len, lr=1e-2)
        df_answer = pd.DataFrame(answer_sheet)
        df_answer.to_csv(f"./eval_result/1_7B/answer_step_{step}.csv", index=False)
pd.DataFrame(answer_sheet)

  0%|          | 0/52 [00:00<?, ?it/s]

  0%|          | 0/49 [00:00<?, ?it/s]

🎯 accountant Accuracy (per-question delta): 25/49 = 51.02%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 advanced_mathematics Accuracy (per-question delta): 8/19 = 42.11%


  0%|          | 0/33 [00:00<?, ?it/s]

🎯 art_studies Accuracy (per-question delta): 21/33 = 63.64%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 basic_medicine Accuracy (per-question delta): 13/19 = 68.42%


  0%|          | 0/33 [00:00<?, ?it/s]

🎯 business_administration Accuracy (per-question delta): 20/33 = 60.61%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 chinese_language_and_literature Accuracy (per-question delta): 14/23 = 60.87%


  0%|          | 0/47 [00:00<?, ?it/s]

🎯 civil_servant Accuracy (per-question delta): 25/47 = 53.19%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 clinical_medicine Accuracy (per-question delta): 10/22 = 45.45%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 college_chemistry Accuracy (per-question delta): 12/24 = 50.00%


  0%|          | 0/55 [00:00<?, ?it/s]

🎯 college_economics Accuracy (per-question delta): 25/55 = 45.45%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 college_physics Accuracy (per-question delta): 8/19 = 42.11%


  0%|          | 0/37 [00:00<?, ?it/s]

🎯 college_programming Accuracy (per-question delta): 26/37 = 70.27%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 computer_architecture Accuracy (per-question delta): 12/21 = 57.14%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 computer_network Accuracy (per-question delta): 10/19 = 52.63%


  0%|          | 0/16 [00:00<?, ?it/s]

🎯 discrete_mathematics Accuracy (per-question delta): 4/16 = 25.00%


  0%|          | 0/29 [00:00<?, ?it/s]

🎯 education_science Accuracy (per-question delta): 20/29 = 68.97%


  0%|          | 0/37 [00:00<?, ?it/s]

🎯 electrical_engineer Accuracy (per-question delta): 17/37 = 45.95%


  0%|          | 0/31 [00:00<?, ?it/s]

🎯 environmental_impact_assessment_engineer Accuracy (per-question delta): 21/31 = 67.74%


  0%|          | 0/31 [00:00<?, ?it/s]

🎯 fire_engineer Accuracy (per-question delta): 18/31 = 58.06%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_biology Accuracy (per-question delta): 8/19 = 42.11%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_chemistry Accuracy (per-question delta): 9/19 = 47.37%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_chinese Accuracy (per-question delta): 9/19 = 47.37%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_geography Accuracy (per-question delta): 12/19 = 63.16%


  0%|          | 0/20 [00:00<?, ?it/s]

🎯 high_school_history Accuracy (per-question delta): 15/20 = 75.00%


  0%|          | 0/18 [00:00<?, ?it/s]

🎯 high_school_mathematics Accuracy (per-question delta): 5/18 = 27.78%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_physics Accuracy (per-question delta): 11/19 = 57.89%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 high_school_politics Accuracy (per-question delta): 14/19 = 73.68%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 ideological_and_moral_cultivation Accuracy (per-question delta): 14/19 = 73.68%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 law Accuracy (per-question delta): 10/24 = 41.67%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 legal_professional Accuracy (per-question delta): 9/23 = 39.13%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 logic Accuracy (per-question delta): 12/22 = 54.55%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 mao_zedong_thought Accuracy (per-question delta): 20/24 = 83.33%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 marxism Accuracy (per-question delta): 14/19 = 73.68%


  0%|          | 0/24 [00:00<?, ?it/s]

🎯 metrology_engineer Accuracy (per-question delta): 18/24 = 75.00%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 middle_school_biology Accuracy (per-question delta): 18/21 = 85.71%


  0%|          | 0/20 [00:00<?, ?it/s]

🎯 middle_school_chemistry Accuracy (per-question delta): 19/20 = 95.00%


  0%|          | 0/12 [00:00<?, ?it/s]

🎯 middle_school_geography Accuracy (per-question delta): 7/12 = 58.33%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 middle_school_history Accuracy (per-question delta): 16/22 = 72.73%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 middle_school_mathematics Accuracy (per-question delta): 6/19 = 31.58%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 middle_school_physics Accuracy (per-question delta): 17/19 = 89.47%


  0%|          | 0/21 [00:00<?, ?it/s]

🎯 middle_school_politics Accuracy (per-question delta): 17/21 = 80.95%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 modern_chinese_history Accuracy (per-question delta): 12/23 = 52.17%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 operating_system Accuracy (per-question delta): 8/19 = 42.11%


  0%|          | 0/49 [00:00<?, ?it/s]

🎯 physician Accuracy (per-question delta): 29/49 = 59.18%


  0%|          | 0/22 [00:00<?, ?it/s]

🎯 plant_protection Accuracy (per-question delta): 15/22 = 68.18%


  0%|          | 0/18 [00:00<?, ?it/s]

🎯 probability_and_statistics Accuracy (per-question delta): 7/18 = 38.89%


  0%|          | 0/29 [00:00<?, ?it/s]

🎯 professional_tour_guide Accuracy (per-question delta): 13/29 = 44.83%


  0%|          | 0/19 [00:00<?, ?it/s]

🎯 sports_science Accuracy (per-question delta): 9/19 = 47.37%


  0%|          | 0/49 [00:00<?, ?it/s]

🎯 tax_accountant Accuracy (per-question delta): 22/49 = 44.90%


  0%|          | 0/44 [00:00<?, ?it/s]

🎯 teacher_qualification Accuracy (per-question delta): 36/44 = 81.82%


  0%|          | 0/46 [00:00<?, ?it/s]

🎯 urban_and_rural_planner Accuracy (per-question delta): 34/46 = 73.91%


  0%|          | 0/23 [00:00<?, ?it/s]

🎯 veterinary_medicine Accuracy (per-question delta): 13/23 = 56.52%


Unnamed: 0,0,1,2,3,4,5
0,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,D\n 你的,D,D,True,accountant
1,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,C\n 请,C,C,True,accountant
2,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,D\n 请,D,D,True,accountant
3,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,A\n 你的,A,A,True,accountant
4,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,C\n 请,C,C,True,accountant
...,...,...,...,...,...,...
1341,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,A\n 请,A,A,True,veterinary_medicine
1342,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,D\n 请,D,D,True,veterinary_medicine
1343,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,C\n 请,C,A,False,veterinary_medicine
1344,请你扮演一位专业的考试助手，阅读下面的单项选择题，并根据内容在四个选项中选出最合适的一个。\...,C\n 请,C,C,True,veterinary_medicine


In [ ]:
for step in [3, 0, 6, 12]:
    max_len = 30
    answer_sheet = []
    for i in tqdm(dataset_names[:]):
        answer_sheet += eval_dataset(i, step=step, max_len=max_len, lr=1e-3)
        df_answer = pd.DataFrame(answer_sheet)
        df_answer.to_csv(f"./eval_result/1_7B/answer_step_{step}.csv", index=False)

In [None]:
for step in [3]:
    max_len = 50
    answer_sheet = []
    for i in tqdm(dataset_names[8]):
        answer_sheet += eval_dataset(i, step=step, max_len=max_len, lr=1e-3)
        df_answer = pd.DataFrame(answer_sheet)
        df_answer.to_csv(f"./eval_result/1_7B/answer_step_{step}.csv", index=False)

In [None]:
pd.DataFrame(answer_sheet)