In [1]:
import os
from huggingface_hub import snapshot_download
from pika.probe.linear_eoi_probe import LinearEoiProbe
from pika.hub import download_probe_path, download_probe
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
MODEL_USED = "Qwen/Qwen2.5-Math-7B-Instruct"

In [4]:
probe = download_probe(
    repo_id="CoffeeGitta/pika-probes",
    model_name=MODEL_USED,
    dataset="MATH",
    label_type="majority_vote_is_correct",
    device="cuda:2"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_USED)


# probe = download_probe(probe_name="Qwen2.5-Math-7B-Instruct--MATH--linear-eoi-probe--mv-correct--k5-t0.7", probe_type="linear_eoi_probe", device="cuda:2")

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 55370.35it/s]
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 186.83it/s]


In [5]:
BOXED_TEMPLATE = "Provide your answer in \\boxed{}"

In [6]:
SAMPLE_QUESTIONS = [
    "Solve: What is 2 + 2? ",
    "Find the sum of all positive integers $ n $ such that $ n + 2 $ divides the product $ 3(n + 3)(n^2 + 9) $. ",
    "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
    "There are exactly three positive real numbers $ k $ such that the function $ f(x) = \frac{(x - 18)(x - 72)(x - 98)(x - k)}{x} $ defined over the positive real numbers achieves its minimum value at exactly two positive real numbers $ x $. Find the sum of these three values of $ k $."
]

for i in range(len(SAMPLE_QUESTIONS)):
    tmp_msg = SAMPLE_QUESTIONS[i] + BOXED_TEMPLATE
    SAMPLE_QUESTIONS[i] = tokenizer.apply_chat_template([{"role": "user", "content": tmp_msg}], tokenize=False, add_generation_prompt=True)



SAMPLE_QUESTIONS

['<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>user\nSolve: What is 2 + 2? Provide your answer in \\boxed{}<|im_end|>\n<|im_start|>assistant\n',
 '<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>user\nFind the sum of all positive integers $ n $ such that $ n + 2 $ divides the product $ 3(n + 3)(n^2 + 9) $. Provide your answer in \\boxed{}<|im_end|>\n<|im_start|>assistant\n',
 '<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>user\nA robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?Provide your answer in \\boxed{}<|im_end|>\n<|im_start|>assistant\n',
 '<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>user\nThere are exactly three positive real numbers $ k $ such that the

In [7]:
predictions = probe.predict(SAMPLE_QUESTIONS)

Extracting activations: 100%|██████████| 1/1 [00:01<00:00,  1.01s/it]


In [8]:
predictions #((idx), (preds))

(tensor([0, 1, 2, 3], dtype=torch.int32),
 tensor([0.9380, 0.5321, 0.9586, 0.1813]))

In [9]:
idx, scores = predictions

In [10]:
results = [
    {"i": int(i), "p_correct": float(s), "question": SAMPLE_QUESTIONS[int(i)]}
    for i, s in zip(idx.tolist(), scores.tolist())
]

for r in results:
    print(f'idx:{r["i"]:>2},  p_correct={r["p_correct"]:.3f}')

idx: 0,  p_correct=0.938
idx: 1,  p_correct=0.532
idx: 2,  p_correct=0.959
idx: 3,  p_correct=0.181
