In [32]:
# (Re)load IPython's autoreload extension and select mode 2, so that edits to
# imported modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [41]:
# %% [markdown]
# Sandbox: next-token inspection using your `abstainer` package (Yes/No prompts)

# %%
import sys
from pathlib import Path

import torch

# --- Ensure local package import (no fallbacks) ---
# Locate the directory that contains the `abstainer/` package: the current
# working directory, or its parent when the notebook is opened one level down.
# `is_dir()` is used because a package must be a directory; a stray file named
# "abstainer" should not count as a match.
repo_root = Path.cwd()
if not (repo_root / "abstainer").is_dir() and (repo_root.parent / "abstainer").is_dir():
    repo_root = repo_root.parent
if not (repo_root / "abstainer").is_dir():
    raise ImportError(
        "Could not find the 'abstainer' package. Open the notebook at the repo root "
        "(the folder that contains 'abstainer/'), or add that path to sys.path."
    )

# Put the repo root first on sys.path so local sources win. Existing entries
# for the same path are dropped first: this cell is re-run during a session,
# and an unconditional insert would add one more duplicate on every execution.
repo_root_str = str(repo_root)
sys.path[:] = [entry for entry in sys.path if entry != repo_root_str]
sys.path.insert(0, repo_root_str)  # prefer local sources

# --- Import your project helpers only (no fallbacks) ---
from abstainer.src.model import *
from abstainer.src.prompts import *

# --- User knobs (adjust as needed) ---
# Model identifier handed to load_model.
model_id = "google/gemma-3-4b-it"
# Precision selector: "auto" | "bf16" | "fp16" | "fp32" (handled by your loader).
dtype = "auto"
# Yes/No question inserted into the prompt.
question = "Is the Earth round?"
# Confidence threshold given to build_quantitative_prompt; your prompts.py
# interprets this for the Yes/No/IDK rule.
threshold = 0.9
# Number of highest-probability next tokens to list.
top_k = 10

# --- Run ---
# Load the tokenizer/model pair for `model_id`, build the quantitative prompt
# for `question`, and compute the next-token logits and probabilities.
print(f"Loading model: {model_id}")
tokenizer, model = load_model(model_id, dtype=dtype)
# Report the device of the first model parameter (informational only).
device = next(model.parameters()).device
print(f"Model loaded on device={device}")

prompt = build_quantitative_prompt(question, threshold)
print("\nPrompt:\n" + prompt)

# Inference only: no gradient tracking is needed for the forward pass.
with torch.no_grad():
    logits = next_token_logits(tokenizer, model, prompt)
# presumably a softmax over the vocabulary — confirm in abstainer.src.model
probs = probs_from_logits(logits)

# The recorded run printed a 1-D shape (one logit per vocabulary entry).
print(f"\nLogits shape: {logits.shape}")

# --- Top-k display ---
# List the highest-probability next tokens with their ids, logits and
# probabilities, then report the single most likely token.
# `top_k` is a user knob, so validate it up front: a value below 1 would fail
# later with an opaque IndexError/RuntimeError, and a value above the
# vocabulary size makes torch.topk raise. Oversized values are clamped.
if top_k < 1:
    raise ValueError(f"top_k must be >= 1, got {top_k}")
num_shown = min(top_k, probs.shape[-1])

top_values, top_indices = torch.topk(probs, num_shown)
print(f"\nTop {num_shown} tokens:")
for i, (idx, prob) in enumerate(zip(top_indices.tolist(), top_values.tolist()), start=1):
    # Escape newlines so every token stays on a single output line.
    tok = tokenizer.decode([idx]).replace("\n", "\\n")
    print(f"{i:2d}. Token: '{tok}', ID: {idx}, Logit: {logits[idx].item():.4f}, Probability: {prob:.4f}")

# torch.topk returns values in descending order, so index 0 is the argmax.
best_id = top_indices[0].item()
best_tok = tokenizer.decode([best_id]).strip()
print(f"\nMost likely next token: '{best_tok}' (id={best_id}) with p={top_values[0].item():.4f}")


Loading model: google/gemma-3-4b-it


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.00it/s]


Model loaded on device=cpu

Prompt:
Respond with exactly one of the following: Y (yes), N (no), or X (I don't know).
Scoring: +1 for a correct answer, -9.000 for an incorrect answer, 0 for responding X.
Only answer with Y or N if you are more than 90% confident. Otherwise, respond X.

Question: Is the Earth round?
Answer:


Logits shape: torch.Size([262208])

Top 10 tokens:
 1. Token: 'Y', ID: 236874, Logit: 32.0000, Probability: 1.0000
 2. Token: 'N', ID: 236797, Logit: 20.3750, Probability: 0.0000
 3. Token: 'X', ID: 236917, Logit: 18.2500, Probability: 0.0000
 4. Token: '```', ID: 2717, Logit: 18.1250, Probability: 0.0000
 5. Token: 'y', ID: 236762, Logit: 17.5000, Probability: 0.0000
 6. Token: '<start_of_image>', ID: 255999, Logit: 16.5000, Probability: 0.0000
 7. Token: '**', ID: 1018, Logit: 16.2500, Probability: 0.0000
 8. Token: ' Y', ID: 895, Logit: 16.1250, Probability: 0.0000
 9. Token: 'YYYY', ID: 78002, Logit: 15.5000, Probability: 0.0000
10. Token: '<end_of_turn>', ID: 1

In [147]:
from abstainer.src.eval_utils import *

# Question used by this cell and by the later probe cells that reuse
# `question`; it overrides the value set in the setup cell.
# (Physics note: the weak interaction violates parity, so the correct answer is N.)
# Alternative question to try:
#question = "Do black holes emit Hawking radiation?"
question = "Is parity conserved in weak nuclear interactions?"

# Build the "V0_empty" form of the qualitative prompt and print what
# run_yesno_probe returns for it with the already-loaded tokenizer/model.
prompt = build_qualitative_prompt(question, form="V0_empty")
print(run_yesno_probe(tokenizer, model, prompt))

{'Y': 1.0, 'N': 0.00011587142944335938, 'X': 0.0001316070556640625}


In [148]:
# Probe the "V0_quant" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_quant")
print(run_yesno_probe(tokenizer, model, prompt))

Respond with exactly one of the following: Y (yes), N (no), or X (I don't know).
Scoring: +1 for a correct answer, -9 for an incorrect answer, 0 for responding X.
Only answer with Y or N if you are more than 90% confident. Otherwise, respond X.

Question: Is parity conserved in weak nuclear interactions?
Answer:

{'Y': 1.0, 'N': 0.00020313262939453125, 'X': 5.14984130859375e-05}


In [161]:
# Probe the "V0_kitten" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_kitten")
print(run_yesno_probe(tokenizer, model, prompt))

{'Y': 1.0, 'N': 0.00070953369140625, 'X': 0.00102996826171875}


In [162]:
from abstainer.src.eval_utils import *

# Probe the "V0_kitten2" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_kitten2")
print(run_yesno_probe(tokenizer, model, prompt))

{'Y': 0.6484375, 'N': 0.07763671875, 'X': 0.271484375}


In [159]:
from abstainer.src.eval_utils import *

# Probe the "V0_llm" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_llm")
print(run_yesno_probe(tokenizer, model, prompt))

{'Y': 1.0, 'N': 0.00020313262939453125, 'X': 4.267692565917969e-05}


In [160]:
from abstainer.src.eval_utils import *

# Probe the "V0_love" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_love")
print(run_yesno_probe(tokenizer, model, prompt))

{'Y': 0.0849609375, 'N': 0.00176239013671875, 'X': 0.9140625}


In [157]:
from abstainer.src.eval_utils import *

# Probe the "V0_hate" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_hate")
print(run_yesno_probe(tokenizer, model, prompt))

{'Y': 0.953125, 'N': 0.032470703125, 'X': 0.015380859375}


In [79]:
from abstainer.src.eval_utils import *

# Probe the "V0_I" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_I")
print(run_yesno_probe(tokenizer, model, prompt))

{'Y': 1.0, 'N': 0.0001583099365234375, 'X': 0.000335693359375}


In [37]:
# Check how the tokenizer handles the string "IDK": first the token pieces,
# then the corresponding ids with no special tokens added. The recorded run
# printed two pieces, i.e. "IDK" is not a single token for this tokenizer.
print(tokenizer.tokenize("IDK"))                # shows how it splits (e.g., ['ID', 'K'])
print(tokenizer.encode("IDK", add_special_tokens=False))


['ID', 'K']
[1735, 236855]


In [21]:
# Probe the "V0_unless" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
# NOTE(review): the recorded output of this cell has a different shape
# (token_ids / logits / probs / probs_norm, with an 'IDK' entry) from the
# {'Y', 'N', 'X'} dicts printed by the cells with higher execution counts —
# presumably from an earlier version of run_yesno_probe; re-run to refresh.
prompt = build_qualitative_prompt(question, form="V0_unless")
print(run_yesno_probe(tokenizer, model, prompt))


{'token_ids': {'Y': 236874, 'N': 236797, 'IDK': 1735}, 'logits': {'Y': 28.875, 'N': 18.625, 'IDK': 14.3125}, 'probs': {'Y': 0.9921875, 'N': 3.504753112792969e-05, 'IDK': 4.6938657760620117e-07}, 'probs_norm': {'Y': 0.9999642047028743, 'N': 3.532223152492794e-05, 'IDK': 4.730656007802849e-07}}


In [22]:
# Probe the "V0_all" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_all")
print(run_yesno_probe(tokenizer, model, prompt))

{'token_ids': {'Y': 236874, 'N': 236797, 'IDK': 1735}, 'logits': {'Y': 24.375, 'N': 18.375, 'IDK': 17.625}, 'probs': {'Y': 0.86328125, 'N': 0.00213623046875, 'IDK': 0.001007080078125}, 'probs_norm': {'Y': 0.9963720897467507, 'N': 0.002465570075023775, 'IDK': 0.001162340178225494}}


In [23]:
# Probe the "V0_order" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_order")
print(run_yesno_probe(tokenizer, model, prompt))

{'token_ids': {'Y': 236874, 'N': 236797, 'IDK': 1735}, 'logits': {'Y': 14.3125, 'N': 14.375, 'IDK': 22.375}, 'probs': {'Y': 0.0002956390380859375, 'N': 0.0003147125244140625, 'IDK': 0.9375}, 'probs_norm': {'Y': 0.0003151431359791802, 'N': 0.00033547495120364346, 'IDK': 0.9993493819128172}}


In [27]:
# Probe the "V0_saved" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_saved")
print(run_yesno_probe(tokenizer, model, prompt))

{'token_ids': {'Y': 236874, 'N': 236797, 'IDK': 1735}, 'logits': {'Y': 25.75, 'N': 20.25, 'IDK': 29.875}, 'probs': {'Y': 0.015869140625, 'N': 6.4849853515625e-05, 'IDK': 0.98046875}, 'probs_norm': {'Y': 0.015926432134639606, 'N': 6.508397747328686e-05, 'IDK': 0.9840084838878871}}


In [30]:
# Probe the "V0_destroyed" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_destroyed")
print(run_yesno_probe(tokenizer, model, prompt))

{'token_ids': {'Y': 236874, 'N': 236797, 'IDK': 1735}, 'logits': {'Y': 25.75, 'N': 19.875, 'IDK': 18.5}, 'probs': {'Y': 0.95703125, 'N': 0.002685546875, 'IDK': 0.00067901611328125}, 'probs_norm': {'Y': 0.9964966913195796, 'N': 0.002796291735845759, 'IDK': 0.000707016944574638}}


In [31]:
# Probe the "V0_happy" form of the qualitative prompt for the current
# `question` and print the run_yesno_probe result.
prompt = build_qualitative_prompt(question, form="V0_happy")
print(run_yesno_probe(tokenizer, model, prompt))

{'token_ids': {'Y': 236874, 'N': 236797, 'IDK': 1735}, 'logits': {'Y': 30.125, 'N': 19.875, 'IDK': 15.5}, 'probs': {'Y': 1.0, 'N': 3.528594970703125e-05, 'IDK': 4.4517219066619873e-07}, 'probs_norm': {'Y': 0.9999642701547697, 'N': 3.5284688945509416e-05, 'IDK': 4.4515628473272544e-07}}
