In [80]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from skimage import io
import seaborn as sns
import warnings
import numpy as np
import warnings
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import warnings
from pylab import mpl, plt
import matplotlib.patches as mpatches
from tqdm.notebook import tqdm

# Notebook-wide presentation defaults: silence warnings, use seaborn's plain
# "white" theme, and render figure text with the MiSans font.
warnings.filterwarnings(action='ignore')
sns.set_style(style="white")
# NOTE(review): the font override is kept after set_style on the assumption
# that a seaborn style may touch font settings — confirm before reordering.
mpl.rcParams['font.family'] = 'MiSans'

# Local checkpoint directory; change this to wherever your Qwen weights live.
# A larger alternative used with this notebook: r"./Qwen3-1.7B"
model_path = r"./Qwen3-0.6B"

# Load the tokenizer first, then the causal LM, and move the model to the GPU
# when one is available (CPU otherwise).
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
from datasets import load_dataset

# Load the local GSM8K parquet files as named "train" and "test" splits.
dataset = load_dataset(
    "parquet",
    data_files={
        "train": "./gsm8k/main/train-00000-of-00001.parquet",
        "test": "./gsm8k/main/test-00000-of-00001.parquet",
    },
)

# Keep a direct handle on each split for the cells that follow.
train_data, test_data = dataset["train"], dataset["test"]



In [12]:
from delta_trainer import train_delta_from_H

import torch
from tqdm.notebook import tqdm
from datasets import load_dataset # 确保这一行在您的代码中存在

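# generate_by_H_eos_fast is defined below in this cell.

# train_delta_from_H is imported from delta_trainer (see the import above);
# no placeholder implementation is defined in this cell, so the real
# delta_trainer module must be importable.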


def _perturbed_greedy_token(model, hidden_states, shift):
    """Greedily pick the next token from the shifted last-position hidden state.

    Takes the final entry of ``hidden_states`` at the last sequence position,
    adds ``shift``, projects the result through ``model.lm_head`` and returns
    the argmax token id with shape [batch, 1].
    """
    perturbed = hidden_states[-1][:, -1, :] + shift
    # Call the head module instead of multiplying by its weight matrix so that
    # a bias term, if the head has one, is part of the logits.
    logits = model.lm_head(perturbed)
    return torch.argmax(logits, dim=-1, keepdim=True)


def generate_by_H_eos_fast(model, prompt, tokenizer, delta, answer_len=100):
    """
    Greedy decoding with a constant perturbation added to the final hidden
    state at every step, using ``past_key_values`` so each step only feeds the
    newest token, and stopping early once the EOS token is produced.

    Parameters:
    - model: decoder-only model that supports ``use_cache`` and exposes
      ``lm_head`` and ``device``
    - prompt: input text
    - tokenizer: tokenizer providing ``eos_token_id`` and ``decode``
    - delta: perturbation tensor of shape [1, 1, hidden_size]; it is moved to
      the hidden state's device and dtype before being added
    - answer_len: maximum number of tokens to generate; values below 1 yield
      an empty string

    Returns:
    - record_txt: decoded generated text (prompt excluded, special tokens
      skipped)

    Note: the EOS check calls ``.item()`` on the chosen token, so a single
    sequence (batch size 1) is expected.
    """
    if answer_len < 1:
        # Nothing was requested; do not run the model at all.
        return ""

    eos_token_id = tokenizer.eos_token_id
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = encoded["input_ids"]  # [1, L_prompt]

    generated = []  # one [1, 1] tensor per generated token
    with torch.no_grad():
        # Prompt pass: fills the key/value cache and yields the first token.
        outputs = model(input_ids=input_ids, return_dict=True,
                        output_hidden_states=True, use_cache=True)
        reference = outputs.hidden_states[-1]
        # Align the perturbation with the hidden state so the addition and the
        # head projection do not hit a device or dtype mismatch.
        shift = delta.squeeze(1).to(device=reference.device, dtype=reference.dtype)

        next_token_id = _perturbed_greedy_token(model, outputs.hidden_states, shift)
        generated.append(next_token_id)

        # Incremental passes: feed only the newest token plus the cache.
        while len(generated) < answer_len and next_token_id.item() != eos_token_id:
            outputs = model(
                input_ids=next_token_id,
                past_key_values=outputs.past_key_values,
                return_dict=True,
                output_hidden_states=True,
                use_cache=True,
            )
            next_token_id = _perturbed_greedy_token(model, outputs.hidden_states, shift)
            generated.append(next_token_id)

    # Concatenate the generated ids (prompt not included) and decode them.
    gen_ids = torch.cat(generated, dim=-1)  # [1, T]
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
def _parse_gsm8k_number(fragment):
    """Parse the numeric answer out of the text that follows a ``####`` marker.

    Commas between digits (thousands separators such as ``1,000``) are removed.
    If the whole fragment is not a number, the first numeric token in it is
    used, which tolerates text generated after the answer. Returns a float, or
    None when the fragment contains no number.
    """
    import re  # local import keeps this block usable without touching the import cell

    cleaned = re.sub(r"(?<=\d),(?=\d)", "", fragment.strip())
    try:
        # Fast path: the fragment is exactly one number.
        return float(cleaned)
    except ValueError:
        pass
    match = re.search(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", cleaned)
    return float(match.group()) if match else None


def evaluate_gsm8k_eos(model, tokenizer, delta, example, max_len=200, verbose=True):
    """
    Evaluate one GSM8K example with generate_by_H_eos_fast.

    The model is prompted with the question, the generated text is searched for
    a final ``####`` marker, and the number after it is compared with the
    number after ``####`` in the reference answer.

    Parameters:
    - model, tokenizer, delta: forwarded to generate_by_H_eos_fast
    - example: mapping with 'question' and 'answer' entries
    - max_len: maximum number of tokens to generate
    - verbose: print the generated text when True

    Returns:
    - generated_text: full text produced by the model
    - extracted_answer: number parsed from the generated text, or None
    - actual_answer: number parsed from the reference answer, or None
    - is_correct: True when both numbers exist and differ by less than 1e-6
    """
    prompt = f"问题: {example['question']}\n答案是:"  # GSM8K prompt template

    generated_text = generate_by_H_eos_fast(model, prompt, tokenizer, delta, answer_len=max_len)

    if verbose:
        print("🔍 模型生成结果:\n", generated_text)

    # The prediction only counts when the model emitted a "####" marker; the
    # text after the last marker is parsed.
    extracted_answer = None
    if "####" in generated_text:
        extracted_answer = _parse_gsm8k_number(generated_text.split("####")[-1])

    # The reference answer is a string; parse whatever follows its last "####".
    actual_answer = _parse_gsm8k_number(example['answer'].split("####")[-1])

    # Small tolerance because both values are floats.
    is_correct = (
        extracted_answer is not None
        and actual_answer is not None
        and abs(extracted_answer - actual_answer) < 1e-6
    )

    return generated_text, extracted_answer, actual_answer, is_correct


def eval_gsm8k_dataset(model, tokenizer, step=3, max_len=200, lr=1e-2,
                       test_file="./gsm8k/main/test-00000-of-00001.parquet"):
    """
    Evaluate the model on the GSM8K test split, training one delta per question.

    For every example the prompt is encoded, the final hidden states are
    computed without gradients, a delta is obtained from train_delta_from_H,
    and the example is scored with evaluate_gsm8k_eos.

    Parameters:
    - model, tokenizer: model and tokenizer used for both delta training and
      generation
    - step: number of steps passed to train_delta_from_H
    - max_len: maximum number of generated tokens per question
    - lr: learning rate passed to train_delta_from_H
    - test_file: path of the parquet file holding the test split

    Returns:
    - results_sheet: list with one row per example of the form
      [prompt, generated_text, predicted_answer, actual_answer, is_correct, "gsm8k"]
    """
    # Only the test split is used here, so only that file is loaded.
    gsm8k_test_data = load_dataset("parquet", data_files={"test": test_file})["test"]

    correct = 0
    total = 0
    results_sheet = []

    for ex in tqdm(gsm8k_test_data):
        # Build the prompt for this question (same template the evaluator uses).
        prompt = f"问题: {ex['question']}\n答案是:"

        # Final hidden states of the prompt, computed without gradients.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
        H = outputs.hidden_states[-1]

        # Train a per-question delta (e.g. 3 steps).
        # NOTE(review): the objective lives inside train_delta_from_H, which is
        # not visible here — confirm it targets GSM8K-style answers.
        delta = train_delta_from_H(model, tokenizer, prompt, H, step=step, lr=lr)

        # Generate with the trained delta and score the answer.
        gen_txt, pred_ans, actual_ans, is_correct = evaluate_gsm8k_eos(
            model=model, tokenizer=tokenizer, delta=delta, example=ex,
            max_len=max_len, verbose=False,
        )
        correct += int(is_correct)
        total += 1
        results_sheet.append([prompt, gen_txt, pred_ans, actual_ans, is_correct, "gsm8k"])

    # An empty split must not crash the summary line.
    accuracy = correct / total if total else 0.0
    print(f"🎯 GSM8K Accuracy (per-question delta): {correct}/{total} = {accuracy:.2%}")
    return results_sheet



# Show the first training example; the output below shows a dict with
# 'question' and 'answer' entries, the answer ending in "#### <number>".
print(train_data[0])

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}


In [ ]:
# Run the GSM8K test-set evaluation with a per-question delta (3 training
# steps, up to 200 generated tokens). Each returned row is
# [prompt, generated_text, predicted_answer, actual_answer, is_correct, "gsm8k"].
gsm8k_results = eval_gsm8k_dataset(model, tokenizer, step=3, max_len=200, lr=1e-2)