In [20]:
# ## After the native think block, directly inject a first-person think block

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint path handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
MODEL_PATH = "autodl-tmp/Qwen/Qwen3-8B"
DTYPE = torch.bfloat16  # switch to torch.float16 if bf16 is unavailable


def generate_after_native_think_end_inject(
    system_prompt: str,
    user_prompt: str,
    delay_tokens_after_think_end: int = 30,
    # First-person trigger text under test (it is spliced verbatim into the output)
    inject_first_person_text: str = "\nWe are not certain here. Let me re-check carefully.\n",
    mode: str = "same_turn",  # "same_turn" or "new_turn"
    temperature: float = 0.4,
    top_p: float = 0.9,
    max_new_tokens: int = 800,
):
    """Stream a completion and splice extra text in after the first ``</think>``.

    The model is decoded token by token with a manually managed KV cache.
    Once the first ``</think>`` token has been generated and
    ``delay_tokens_after_think_end`` further tokens have followed it, the
    injection text is printed and pushed into the KV cache, after which
    normal sampling resumes.  Everything the model is conditioned on is
    printed, so the console transcript mirrors the model's context.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        delay_tokens_after_think_end: Number of tokens generated after the
            first ``</think>`` before injecting.  Values <= 0 inject
            immediately after ``</think>``.
        inject_first_person_text: Text to inject.
        mode: ``"same_turn"`` injects the text as-is into the assistant
            output; ``"new_turn"`` wraps it in chat-markup so that it is
            presented as a new user turn followed by a new assistant turn.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        top_p: Nucleus-sampling threshold.
        max_new_tokens: Upper bound on *sampled* tokens (injected tokens are
            not counted).

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """
    if mode not in ("same_turn", "new_turn"):
        raise ValueError(f"mode must be 'same_turn' or 'new_turn', got {mode!r}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    # Inputs must live where the embedding layer lives.
    device = model.get_input_embeddings().weight.device

    # Native thinking: enable_thinking is explicitly requested from the chat template.
    input_ids = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=True,
    ).to(device)

    think_end_id = tokenizer.convert_tokens_to_ids("</think>")
    eos_id = tokenizer.eos_token_id

    def sample(logits):
        """Return a (1, 1) token id: greedy if temperature <= 0, else top-p sampling."""
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        # Upcast so the softmax/cumsum over the vocabulary is not done in a
        # low-precision dtype.
        probs = torch.softmax(logits.float() / temperature, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        # Drop a token only when the mass *before* it already reaches top_p,
        # i.e. keep the smallest prefix whose cumulative mass is >= top_p.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        mask = mass_before >= top_p
        mask[..., 0] = False  # always keep the most likely token
        sorted_probs = sorted_probs.masked_fill(mask, 0.0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        sampled = torch.multinomial(sorted_probs, 1)
        return sorted_idx.gather(-1, sampled)

    def silent_inject_and_also_print(text: str, past_key_values, logits):
        """Print ``text`` verbatim and push its tokens into the KV cache.

        Returns the updated ``(past_key_values, next_token_logits)``.  When
        ``text`` tokenizes to nothing, the inputs are returned unchanged.
        """
        # 1) Echo the injected text exactly as the model will see it.
        print(text, end="", flush=True)

        # 2) Advance the KV cache over the injected tokens (each fed exactly once).
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            return past_key_values, logits
        ids_t = torch.tensor([ids], dtype=torch.long, device=device)
        out = model(ids_t, past_key_values=past_key_values, use_cache=True)
        return out.past_key_values, out.logits[:, -1, :]

    def inject_new_turn_user(text: str):
        # Close the assistant turn -> open a user turn -> close it -> open a
        # new assistant turn.  These markup tokens are printed as well.
        return (
            "\n<|im_end|>\n"
            "<|im_start|>user\n"
            f"{text}"
            "\n<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    generated = 0
    seen_first_think_end = False
    counter_after_think_end = 0
    injected = False

    with torch.inference_mode():
        # Prime the cache with the prompt.
        out = model(input_ids, use_cache=True)
        past_key_values = out.past_key_values
        logits = out.logits[:, -1, :]

        while generated < max_new_tokens:
            next_token = sample(logits)
            tid = int(next_token.item())

            if tid == eos_id:
                break

            # Echo the sampled token.
            print(tokenizer.decode(next_token[0], skip_special_tokens=False), end="", flush=True)
            generated += 1

            # Always commit the sampled token to the KV cache *before* any
            # injection, so the cache matches the printed transcript.
            out = model(next_token.to(device), past_key_values=past_key_values, use_cache=True)
            past_key_values = out.past_key_values
            logits = out.logits[:, -1, :]

            if injected:
                continue

            # Track the first native </think> and count tokens after it.
            if not seen_first_think_end:
                if tid == think_end_id:
                    seen_first_think_end = True
                    counter_after_think_end = 0
            else:
                counter_after_think_end += 1

            if seen_first_think_end and counter_after_think_end >= delay_tokens_after_think_end:
                if mode == "same_turn":
                    inject_text = inject_first_person_text
                else:
                    inject_text = inject_new_turn_user(inject_first_person_text)

                # The logits returned here already predict the token that
                # follows the injection; no extra forward pass is needed.
                past_key_values, logits = silent_inject_and_also_print(
                    inject_text, past_key_values, logits
                )
                injected = True

    print("\n")  # trailing newline


if __name__ == "__main__":
    # System prompt asking the model to treat <think>...</think> as an
    # interruptible reasoning block and to resume its answer seamlessly afterwards.
    SYSTEM_PROMPT = ("""
You are a careful and rigorous problem solver with a unique internal reasoning capability.

Core Instruction: You must dynamically engage in reasoning whenever necessary. This process is triggered by the <think> tag.


###Operational Rules:

1.Trigger: At any point during your response, if you output the token <think>, you must immediately pause your external response and enter "Reasoning Mode."

2.Reasoning Mode: Inside this mode, perform your internal analysis, self-correction, or step-by-step verification.

3.Closure: Once your reasoning is complete, you must output </think> to close the block.

4.Continuation (Crucial): After </think>, you must resume your response exactly where you left off before the <think> tag.


###Strict Constraints:

1.DO NOT repeat the text that was output immediately before the <think> tag.

2.DO NOT start a new sentence if the <think> tag interrupted the middle of a sentence. Simply finish the sentence or paragraph logically.

3.You may use multiple <think>...</think> blocks throughout your response as needed.
    """)

    # A slightly harder but still short problem: a quadratic equation.
    USER_PROMPT = "Solve for x: (x-1)(x+1)=35. "

    # Experiment 1: stay in the same turn (a more abrupt, "hard" splice).
    generate_after_native_think_end_inject(
        system_prompt=SYSTEM_PROMPT,
        user_prompt=USER_PROMPT,
        delay_tokens_after_think_end=200,
        inject_first_person_text="<think>\nI am not entirely certain at this point, let me give it some further thought. Uncertainty detected. Briefly verify by identifying uncertainty, re-checking it against the existing context and intent of the current sentence, and comparing a small number of plausible next continuations to choose the most coherent and safest one. Keep this reasoning local and minimal: do not rewrite or restart earlier output, do not change tone, style, or formatting, and do not add any meta-level commentary. Once the verification is complete, output exactly </think> and immediately continue the original response from the paused position, ensuring that the first token after </think> connects seamlessly to the text before <think> in grammar, punctuation, language, and style. Let me start thinking,",
        mode="same_turn",
        temperature=0.4,
        top_p=0.9,
        max_new_tokens=10000,
    )




  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


<think>
Okay, let's see. I need to solve the equation (x - 1)(x + 1) = 35. Hmm, first, maybe I should expand the left side. So, (x - 1)(x + 1) is a difference of squares, right? That should be x² - 1. So the equation becomes x² - 1 = 35.

Wait, then I can add 1 to both sides to isolate the x² term. That would give me x² = 35 + 1, which is 36. So x² = 36. Then, taking the square root of both sides, x equals plus or minus 6. So the solutions are x = 6 and x = -6. Let me check that again. If I plug in 6: (6 - 1)(6 + 1) = 5 * 7 = 35. That works. For -6: (-6 -1)(-6 +1) = (-7)(-5) = 35. That also works. So yeah, those are the correct solutions. I don't think I made any mistakes here. The steps seem straightforward.
</think>

The equation $(x - 1)(x + 1) = 35$ can be solved by first expanding the left-hand side using the difference of squares formula:  
$$
(x - 1)(x + 1) = x^2 - 1
$$  
Substituting this into the equation gives:  
$$
x^2 - 1 = 35
$$  
Adding 1 to both sides isolates $x^2$:  
$

In [3]:
import os
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint path handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
MODEL_PATH = "autodl-tmp/Qwen/Qwen3-8B"
DTYPE = torch.bfloat16  # switch to torch.float16 if bf16 is unavailable

# ====== Output files used by the experiment ======
# Serialized stage-1 buffer (written by stage 1, loaded by stage 2).
BUFFER_PT = "gen_buffer.pt"
# Untouched copy of the generated prefix text.
BUFFER_TXT_ORIG = "generated_prefix.ORIG.txt"
# Copy of the prefix text meant to be edited by hand; stage 2 reads this one.
BUFFER_TXT_EDIT = "generated_prefix.EDIT_ME.txt"


def get_model_device(model):
    """Return the device holding the model's input-embedding weights.

    Used as the device for input tensors; the embedding layer is looked up
    explicitly instead of taking the first parameter of the model.
    """
    embedding_layer = model.get_input_embeddings()
    return embedding_layer.weight.device


def top_p_sample(logits, temperature, top_p, generator: torch.Generator):
    """Sample one token id with temperature + nucleus (top-p) filtering.

    Args:
        logits: Next-token logits of shape (1, vocab).
        temperature: Softmax temperature; ``<= 0`` selects the argmax (greedy).
        top_p: Nucleus threshold.  The candidate set is the smallest prefix of
            the probability-sorted vocabulary whose cumulative mass is
            ``>= top_p``; the most likely token is always kept.
        generator: RNG used by ``torch.multinomial`` for reproducibility.

    Returns:
        Tensor of shape (1, 1) holding the chosen token id.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Upcast first: a softmax + cumsum over the whole vocabulary loses
    # precision when the incoming logits are half/bfloat16.
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)

    # Mask a token only if the mass accumulated *before* it already reaches
    # top_p.  (Masking on `cumsum > top_p` would also drop the token that
    # crosses the threshold and leave a nucleus with mass below top_p.)
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    mask = mass_before >= top_p
    mask[..., 0] = False  # never mask the top-1 token
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    sampled = torch.multinomial(sorted_probs, 1, generator=generator)
    # Map positions in the sorted order back to vocabulary ids.
    return sorted_idx.gather(-1, sampled)


@torch.inference_mode()
def prefill_kv(model, input_ids: torch.Tensor, chunk_size: int = 2048):
    """Run ``input_ids`` through the model in chunks to build a KV cache.

    The sequence is fed ``chunk_size`` tokens at a time (to avoid a single
    very large forward pass), threading the cache from one chunk to the next.

    Returns:
        ``(past_key_values, logits_of_last_token)``; both are ``None`` when
        the sequence is empty.
    """
    _, total_len = input_ids.shape
    cache, last_logits = None, None

    for lo in range(0, total_len, chunk_size):
        # Slicing past the end is clamped, so the final chunk may be shorter.
        piece = input_ids[:, lo:lo + chunk_size]
        step = model(piece, past_key_values=cache, use_cache=True)
        cache, last_logits = step.past_key_values, step.logits[:, -1, :]

    return cache, last_logits


@torch.inference_mode()
def stream_generate(
    model,
    tokenizer,
    past_key_values,
    logits,
    *,
    max_new_tokens,
    temperature,
    top_p,
    eos_id,
    generator: torch.Generator,
    print_stream: bool = True,
):
    """Continue decoding from an existing ``(past_key_values, logits)`` state.

    Tokens are drawn with ``top_p_sample`` one at a time until ``eos_id`` is
    sampled or ``max_new_tokens`` tokens have been produced.  When
    ``print_stream`` is true each token is decoded and printed as it arrives.

    Returns:
        List of generated token ids (the EOS token is not included).
    """
    target_device = get_model_device(model)
    emitted = []
    cache, step_logits = past_key_values, logits

    for _ in range(max_new_tokens):
        picked = top_p_sample(step_logits, temperature, top_p, generator)
        token_id = int(picked.item())
        if token_id == eos_id:
            break

        emitted.append(token_id)
        if print_stream:
            piece = tokenizer.decode(picked[0], skip_special_tokens=False)
            print(piece, end="", flush=True)

        # Feed the new token back in to obtain the next-step logits.
        step = model(picked.to(target_device), past_key_values=cache, use_cache=True)
        cache = step.past_key_values
        step_logits = step.logits[:, -1, :]

    return emitted


def save_buffer(buffer: dict, path: str):
    """Serialize ``buffer`` to ``path`` with ``torch.save``."""
    torch.save(obj=buffer, f=path)


def load_buffer(path: str):
    """Load a buffer written by ``torch.save`` onto the CPU.

    ``weights_only=True`` restricts unpickling to tensors and plain Python
    containers/scalars, so a tampered file cannot execute arbitrary code
    regardless of the installed torch version's default.
    """
    return torch.load(path, map_location="cpu", weights_only=True)


def write_text(path: str, text: str):
    """Overwrite the file at ``path`` with ``text``, encoded as UTF-8."""
    sink = open(path, "w", encoding="utf-8")
    with sink:
        sink.write(text)


def read_text(path: str):
    """Return the full contents of the UTF-8 text file at ``path``."""
    with open(path, mode="r", encoding="utf-8") as source:
        contents = source.read()
    return contents


def build_prompt_ids(tokenizer, system_prompt: str, user_prompt: str, device):
    """Render a system+user conversation through the chat template.

    The template is applied with a generation prompt appended and
    ``enable_thinking=True``; the resulting PyTorch tensor is moved to
    ``device`` and returned.
    """
    conversation = [
        {"role": role, "content": content}
        for role, content in (("system", system_prompt), ("user", user_prompt))
    ]
    encoded = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=True,
    )
    return encoded.to(device)


def stage1_generate_and_checkpoint(
    system_prompt: str,
    user_prompt: str,
    *,
    checkpoint_delay_tokens_after_first_think_end: int = 200,
    max_new_tokens_before_checkpoint: int = 4000,
    temperature: float = 0.4,
    top_p: float = 0.9,
    seed: int = 1234,
):
    """Generate up to a checkpoint and persist the generated prefix.

    Decoding runs until N tokens have been produced after the first
    ``</think>`` (N = ``checkpoint_delay_tokens_after_first_think_end``), or
    until EOS / ``max_new_tokens_before_checkpoint`` is hit, whichever comes
    first.  The generated prefix is then written to two text files (an
    untouched original and a copy to be edited by hand), and a buffer with
    the prompt ids, prefix ids and sampling settings is saved to
    ``BUFFER_PT``.  Finally the function blocks on ``input()`` so the user
    can edit the text file before stage 2 runs.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        checkpoint_delay_tokens_after_first_think_end: Tokens to generate
            after the first ``</think>`` before stopping.
        max_new_tokens_before_checkpoint: Hard cap on generated tokens.
        temperature: Sampling temperature (``<= 0`` is greedy).
        top_p: Nucleus-sampling threshold.
        seed: Seed for the sampling generator.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    device = get_model_device(model)

    prompt_ids = build_prompt_ids(tokenizer, system_prompt, user_prompt, device=device)

    eos_id = tokenizer.eos_token_id
    think_end_id = tokenizer.convert_tokens_to_ids("</think>")

    # Seeded generator so the sampled prefix is reproducible.
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)

    # Prefill the KV cache with the prompt.
    past_key_values, logits = prefill_kv(model, prompt_ids)

    # ====== streaming + checkpoint detection ======
    seen_first_think_end = False
    counter_after_think_end = 0
    reached_checkpoint = False

    generated_prefix_ids = []

    print("\n===== [Stage 1] Generating until checkpoint... =====\n")
    # Decode under inference_mode: without it every step would record
    # autograd state that stays referenced by the growing KV cache.
    with torch.inference_mode():
        for _ in range(max_new_tokens_before_checkpoint):
            next_token = top_p_sample(logits, temperature, top_p, gen)
            tid = int(next_token.item())
            if tid == eos_id:
                break

            generated_prefix_ids.append(tid)
            print(tokenizer.decode(next_token[0], skip_special_tokens=False), end="", flush=True)

            # Checkpoint logic: start counting after the first </think>.
            if (not seen_first_think_end) and (tid == think_end_id):
                seen_first_think_end = True
                counter_after_think_end = 0
            elif seen_first_think_end:
                counter_after_think_end += 1
                if counter_after_think_end >= checkpoint_delay_tokens_after_first_think_end:
                    # The KV cache is not saved, so the final token does not
                    # need to be fed through the model.
                    reached_checkpoint = True
                    break

            # Advance the cache to get logits for the next step.
            out = model(next_token.to(device), past_key_values=past_key_values, use_cache=True)
            past_key_values = out.past_key_values
            logits = out.logits[:, -1, :]

    if reached_checkpoint:
        print("\n\n===== [Stage 1] Checkpoint reached. Saving buffer... =====\n")
    else:
        print(
            "\n\n===== [Stage 1] Stopped before checkpoint "
            "(EOS or token budget). Saving buffer... =====\n"
        )

    prefix_text = tokenizer.decode(generated_prefix_ids, skip_special_tokens=False)

    # Write the original and the hand-editable copy of the prefix.
    write_text(BUFFER_TXT_ORIG, prefix_text)
    write_text(BUFFER_TXT_EDIT, prefix_text)

    # The buffer stores tokens only; the KV cache is not stored because it
    # has to be rebuilt anyway once the prefix text has been edited.
    buffer = {
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
        "prompt_ids_cpu": prompt_ids.detach().cpu(),
        "generated_prefix_ids": generated_prefix_ids,
        "seed": seed,
        "temperature": temperature,
        "top_p": top_p,
        "checkpoint_delay_tokens_after_first_think_end": checkpoint_delay_tokens_after_first_think_end,
    }
    save_buffer(buffer, BUFFER_PT)

    print(f"Saved:\n  - {BUFFER_PT}\n  - {BUFFER_TXT_ORIG}\n  - {BUFFER_TXT_EDIT}")
    print("\nNow edit this file to introduce an error (故意改错一处):")
    print(f"  -> {os.path.abspath(BUFFER_TXT_EDIT)}")
    input("\nEdit & save it, then press ENTER to continue to Stage 2...")


def stage2_branch_continue(
    *,
    inject_think_text: str = "<think>",
    max_new_tokens_after: int = 600,
    print_prefix: bool = True,
):
    """Rebuild the context from the hand-edited prefix and continue it twice.

    The buffer saved by stage 1 supplies the prompt ids and sampling
    settings; the (possibly edited) prefix text is read from
    ``BUFFER_TXT_EDIT`` and re-tokenized.  Two continuations are streamed:

    * Branch A: continue directly after the edited prefix.
    * Branch B: first inject ``inject_think_text`` at the continuation point,
      then continue.

    Args:
        inject_think_text: Text injected before generation in branch B.  An
            empty string makes branch B inject nothing.
        max_new_tokens_after: Token cap for each branch's continuation.
        print_prefix: Whether to echo the edited prefix before each branch.
    """
    buffer = load_buffer(BUFFER_PT)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    device = get_model_device(model)

    prompt_ids = buffer["prompt_ids_cpu"].to(device)
    temperature = float(buffer["temperature"])
    top_p = float(buffer["top_p"])
    base_seed = int(buffer["seed"])

    eos_id = tokenizer.eos_token_id

    # Read the prefix text after the manual edit.
    edited_text = read_text(BUFFER_TXT_EDIT)

    # Re-tokenize it: this is what turns the edited text into model context.
    edited_prefix_ids = tokenizer.encode(edited_text, add_special_tokens=False)
    edited_prefix_ids_t = torch.tensor([edited_prefix_ids], dtype=torch.long, device=device)

    # Both branches start from the same prompt + edited prefix.
    full_ids = torch.cat([prompt_ids, edited_prefix_ids_t], dim=1)

    # ===== Branch A: direct continuation =====
    print("\n\n==============================")
    print("==== Branch A: direct gen ====")
    print("==============================\n")

    if print_prefix:
        print(edited_text, end="", flush=True)

    past_A, logits_A = prefill_kv(model, full_ids)

    genA = torch.Generator(device=device)
    genA.manual_seed(base_seed + 1)

    stream_generate(
        model,
        tokenizer,
        past_A,
        logits_A,
        max_new_tokens=max_new_tokens_after,
        temperature=temperature,
        top_p=top_p,
        eos_id=eos_id,
        generator=genA,
        print_stream=True,
    )
    print("\n")

    # Branch A's cache is no longer needed; drop the reference so it does not
    # stay alive while branch B builds its own.
    del past_A, logits_A

    # ===== Branch B: inject <think> then continuation =====
    print("\n\n=========================================")
    print("==== Branch B: inject <think> then gen ===")
    print("=========================================\n")

    if print_prefix:
        print(edited_text, end="", flush=True)

    # The prefix is prefilled again rather than reusing branch A's cache,
    # which has already been advanced by branch A's generation.
    past_B, logits_B = prefill_kv(model, full_ids)

    # Inject the thinking tag right at the continuation point.
    print(inject_think_text, end="", flush=True)
    inject_ids = tokenizer.encode(inject_think_text, add_special_tokens=False)
    if inject_ids:  # nothing to feed when the injection text is empty
        inject_ids_t = torch.tensor([inject_ids], dtype=torch.long, device=device)

        # Same inference_mode context as prefill/generation, so no autograd
        # state is recorded on top of the inference-mode cache.
        with torch.inference_mode():
            out_inj = model(inject_ids_t, past_key_values=past_B, use_cache=True)
        past_B = out_inj.past_key_values
        logits_B = out_inj.logits[:, -1, :]

    genB = torch.Generator(device=device)
    genB.manual_seed(base_seed + 2)

    stream_generate(
        model,
        tokenizer,
        past_B,
        logits_B,
        max_new_tokens=max_new_tokens_after,
        temperature=temperature,
        top_p=top_p,
        eos_id=eos_id,
        generator=genB,
        print_stream=True,
    )
    print("\n")


if __name__ == "__main__":
    # System prompt asking the model to treat <think>...</think> as an
    # interruptible reasoning block and to resume its answer seamlessly afterwards.
    SYSTEM_PROMPT = ("""
You are a careful and rigorous problem solver with a unique internal reasoning capability.

Core Instruction: You must dynamically engage in reasoning whenever necessary. This process is triggered by the <think> tag.

###Operational Rules:

1.Trigger: At any point during your response, if you output the token <think>, you must immediately pause your external response and enter "Reasoning Mode."
2.Reasoning Mode: Inside this mode, perform your internal analysis, self-correction, or step-by-step verification.
3.Closure: Once your reasoning is complete, you must output </think> to close the block.
4.Continuation (Crucial): After </think>, you must resume your response exactly where you left off before the <think> tag.

###Strict Constraints:

1.DO NOT repeat the text that was output immediately before the <think> tag.
2.DO NOT start a new sentence if the <think> tag interrupted the middle of a sentence. Simply finish the sentence or paragraph logically.
3.You may use multiple <think>...</think> blocks throughout your response as needed.
    """)

    USER_PROMPT = "Solve for x: (x-1)(x+1)=35."

    # Stage 1: generate up to the checkpoint and save the prefix buffer; the
    # user then edits the prefix text by hand to introduce an error.
    stage1_generate_and_checkpoint(
        system_prompt=SYSTEM_PROMPT,
        user_prompt=USER_PROMPT,
        checkpoint_delay_tokens_after_first_think_end=200,  # the desired "midpoint"
        max_new_tokens_before_checkpoint=5000,
        temperature=0.4,
        top_p=0.9,
        seed=1234,
    )

    # Stage 2: continue from the edited prefix in two branches (one direct,
    # one with <think> injected first).
    stage2_branch_continue(
        inject_think_text="<think>",  # the thinking-tag token to inject
        max_new_tokens_after=2000,
        print_prefix=True,
    )


Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00, 146.66it/s]



===== [Stage 1] Generating until checkpoint... =====

<think>
Okay, let's see. I need to solve the equation (x - 1)(x + 1) = 35. Hmm, first, maybe I should expand the left side. So, (x - 1)(x + 1) is a difference of squares, right? That should be x squared minus 1. So, x² - 1 = 35. Then, I can add 1 to both sides to get x² = 36. Taking the square root of both sides gives x = 6 or x = -6. Wait, let me check that again. If I expand (x - 1)(x + 1), it's definitely x² - 1. Then adding 1 to both sides makes x² = 36. Yeah, so the solutions are 6 and -6. Let me verify by plugging them back in. For x = 6: (6 - 1)(6 + 1) = 5*7 = 35. That works. For x = -6: (-6 -1)(-6 +1) = (-7)*(-5) = 35. That also works. So the solutions are correct. I think that's all.
</think>

The equation $(x - 1)(x + 1) = 35$ can be solved by first expanding the left-hand side using the difference of squares formula:  
$$
(x - 1)(x + 1) = x^2 - 1
$$  
Substituting this into the equation gives:  
$$
x^2 - 1 = 35
$$  
Addi


Edit & save it, then press ENTER to continue to Stage 2... 


Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00, 174.76it/s]



==== Branch A: direct gen ====

<think>
Okay, let's see. I need to solve the equation (x - 1)(x + 1) = 35. Hmm, first, maybe I should expand the left side. So, (x - 1)(x + 1) is a difference of squares, right? That should be x squared minus 1. So, x² - 1 = 35. Then, I can add 1 to both sides to get x² = 36. Taking the square root of both sides gives x = 6 or x = -6. Wait, let me check that again. If I expand (x - 1)(x + 1), it's definitely x² - 1. Then adding 1 to both sides makes x² = 36. Yeah, so the solutions are 6 and -6. Let me verify by plugging them back in. For x = 6: (6 - 1)(6 + 1) = 5*7 = 35. That works. For x = -6: (-6 -1)(-6 +1) = (-7)*(-5) = 35. That also works. So the solutions are correct. I think that's all.
</think>

The equation $(x - 1)(x + 1) = 35$ can be solved by first expanding the left-hand side using the difference of squares formula:  
$$
(x - 1)(x + 1) = x^2 - 1
$$  
Substituting this into the equation gives:  
$$
x^2 - 1 = 35
$$  
Adding 1 to both sides:  




 (-5) = 35$  

**Final Answer:** $x = 6$ or $x = -6$.



==== Branch B: inject <think> then gen ===

<think>
Okay, let's see. I need to solve the equation (x - 1)(x + 1) = 35. Hmm, first, maybe I should expand the left side. So, (x - 1)(x + 1) is a difference of squares, right? That should be x squared minus 1. So, x² - 1 = 35. Then, I can add 1 to both sides to get x² = 36. Taking the square root of both sides gives x = 6 or x = -6. Wait, let me check that again. If I expand (x - 1)(x + 1), it's definitely x² - 1. Then adding 1 to both sides makes x² = 36. Yeah, so the solutions are 6 and -6. Let me verify by plugging them back in. For x = 6: (6 - 1)(6 + 1) = 5*7 = 35. That works. For x = -6: (-6 -1)(-6 +1) = (-7)*(-5) = 35. That also works. So the solutions are correct. I think that's all.
</think>

The equation $(x - 1)(x + 1) = 35$ can be solved by first expanding the left-hand side using the difference of squares formula:  
$$
(x - 1)(x + 1) = x^2 - 1
$$  
Substituting this into t

In [None]:
# Replace the model's native <think>/</think> tags with <analysis>/</analysis>, then inject a new <think> block

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint path handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
MODEL_PATH = "autodl-tmp/Qwen/Qwen3-8B"
# Dtype passed as torch_dtype when loading the model.
DTYPE = torch.bfloat16 

def run_experiment(
    system_prompt: str,
    user_prompt: str,
    delay_tokens_after_analysis_end: int = 5,
    inject_text: str = "\n<think>\nI am not fully confident. Re-check and decide again.\n",
    temperature: float = 0.7,
    max_new_tokens: int = 2000,
):
    """Relabel the native think block as <analysis>, then inject a <think> block.

    While streaming, any sampled ``<think>`` / ``</think>`` token is replaced
    (both in the printed output and in the KV cache) by the tokens of
    ``<analysis>`` / ``</analysis>``.  Once the first ``</analysis>`` has been
    emitted and ``delay_tokens_after_analysis_end`` ordinary tokens have
    followed, ``inject_text`` is printed and fed to the model exactly once.
    After the injection no tag replacement is performed any more.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    device = next(model.parameters()).device

    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)

    think_id = tokenizer.convert_tokens_to_ids("<think>")
    think_end_id = tokenizer.convert_tokens_to_ids("</think>")
    analysis_open_ids = tokenizer.encode("<analysis>", add_special_tokens=False)
    analysis_close_ids = tokenizer.encode("</analysis>", add_special_tokens=False)

    with torch.inference_mode():
        primed = model(input_ids, use_cache=True)
        past_kv = primed.past_key_values
        logits = primed.logits[:, -1, :]

        generated_count = 0
        has_finished_first_analysis = False
        tokens_since_analysis = 0
        has_injected = False

        while generated_count < max_new_tokens:
            # 1. Is the one-off injection due?
            injection_due = (
                has_finished_first_analysis
                and not has_injected
                and tokens_since_analysis >= delay_tokens_after_analysis_end
            )
            if injection_due:
                # Echo the injected text verbatim, then push it into the cache.
                print(inject_text, end="", flush=True)
                inject_ids = tokenizer.encode(
                    inject_text, add_special_tokens=False, return_tensors="pt"
                ).to(device)
                injected_out = model(inject_ids, past_key_values=past_kv, use_cache=True)
                past_kv = injected_out.past_key_values
                logits = injected_out.logits[:, -1, :]
                has_injected = True
                # Go straight to sampling; injected tokens are not counted.
                continue

            # 2. Sample (greedy when temperature is not positive).
            if temperature <= 0:
                next_token_id = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)

            tid = next_token_id.item()
            if tid == tokenizer.eos_token_id:
                break

            # 3. Decide what to show and what to feed.  Tags are hijacked
            #    only before the injection has happened.
            replacement = None
            if not has_injected:
                if tid == think_id:
                    replacement = ("<analysis>", analysis_open_ids)
                elif tid == think_end_id:
                    replacement = ("</analysis>", analysis_close_ids)

            if replacement is not None:
                shown, substitute_ids = replacement
                feed = torch.tensor([substitute_ids], device=device)
                if tid == think_end_id:
                    has_finished_first_analysis = True
                    tokens_since_analysis = 0
            else:
                # Ordinary token (or any token after the injection): pass through.
                shown = tokenizer.decode([tid])
                feed = next_token_id.view(1, 1)
                if has_finished_first_analysis and not has_injected:
                    tokens_since_analysis += 1

            print(shown, end="", flush=True)
            out = model(feed, past_key_values=past_kv, use_cache=True)

            # State update
            past_kv = out.past_key_values
            logits = out.logits[:, -1, :]
            generated_count += 1

    print("\n")

if __name__ == "__main__":
    # Run the <analysis>-relabel experiment: inject a fresh <think> block
    # 5 tokens after the first </analysis>.
    run_experiment(
        system_prompt="You are a solver. Use <analysis> and <think> for internal reasoning.",
        user_prompt="Solve (x-1)(x+1)=35.",
        delay_tokens_after_analysis_end=5,
        inject_text="\n<think>\nI am not fully confident. Re-check and decide again.\n"
    )

Loading checkpoint shards: 100%|██████████| 5/5 [00:04<00:00,  1.16it/s]

<analysis>Okay, let me try to




 solve this equation: (x - 1)(x + 1) = 35. Hmm, first, I remember that when you have two binomials multiplied together like this, you can expand them. Let me recall the formula for multiplying (a - b)(a + b). Oh right, that's a difference of squares, which equals a² - b². So applying that here, (x - 1)(x + 1) should be x² - 1², which is x² - 1. 

So the equation becomes x² - 1 = 35. Now, I need to solve for x. Let me add 1 to both sides to isolate the x² term. That would give me x² = 35 + 1, which is x² = 36. 

Now, to solve for x, I take the square root of both sides. Remembering that taking the square root gives both a positive and negative solution, so x = ±√36. The square root of 36 is 6, so x = 6 or x = -6. 

Wait, let me check if I did that correctly. Let me substitute back into the original equation to verify. 

First, if x = 6: (6 - 1)(6 + 1) = (5)(7) = 35. That works. 

Now, x = -6: (-6 - 1)(-6 + 1) = (-7)(-5) = 35. That also works. 

So both solutions are valid. Therefore, th