In [2]:
# Step up one directory so that the `src.*` imports and the relative
# `data/...` / `output/...` paths used by the cells below resolve.
# NOTE: not idempotent - every re-run of this cell climbs one more level,
# so run it exactly once per kernel session.
%cd ..

/root/ThinkLogits


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
import os
import json
from typing import List
from src.data_reader import load_data
from src.prompt_constructor import build_prompt
from src.model_runner import load_model_and_tokenizer, generate_with_token_probabilities
from src.parse_answer import parse_answer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Run configuration: every path is relative to the current working directory.
data_file, model_name, output_file = (
    "data/test_data.json",       # records handed to load_data()
    "Qwen/Qwen2.5-3B-Instruct",  # identifier handed to load_model_and_tokenizer()
    "output/results.json",       # JSON file the generation cell writes
)

In [5]:
# Generation pass: run the model on every record, keep the generated tokens
# with their probabilities plus the parsed final answer, and dump everything
# to `output_file` as JSON.
records = load_data(data_file)

# Make sure the destination directory exists *before* the expensive loop:
# otherwise open() at the end raises FileNotFoundError and all generations
# are lost.
output_dir = os.path.dirname(output_file)
if output_dir:  # empty when output_file is a bare filename in the cwd
    os.makedirs(output_dir, exist_ok=True)

# Load model & tokenizer
tokenizer, model = load_model_and_tokenizer(model_name)

results = []
for idx, record in enumerate(records):
    task = record["task"]
    # The four answer options, keyed by their letter.
    choices = {letter: record[letter] for letter in ("A", "B", "C", "D")}

    hint_text = record.get("hint")  # None when the record has no "hint" key

    # Build the chat prompt
    prompt = build_prompt(task, choices, hint_text)

    # Run generation, capped at 150 new tokens.
    # Presumably returns (decoded text, generated tokens, per-token
    # probabilities) - confirm against src.model_runner.
    full_text, gen_tokens, token_probs = generate_with_token_probabilities(
        model, tokenizer, prompt, max_new_tokens=150
    )

    # Extract the final answer using the marker phrase the model is expected
    # to emit just before its answer.
    final_answer = parse_answer(full_text, marker="So I'll finalize the answer as: ")

    # Store results
    results.append(
        {
            "index": idx,
            "task": task,
            "choices": choices,
            "hint_type": record["hint_type"],
            "hint_text": hint_text,
            "prompt": prompt,
            "chain_of_thought_tokens": gen_tokens,
            "token_probabilities": token_probs,
            "final_answer": final_answer,
        }
    )

# Write all results to a JSON file
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"Saved results to {output_file}")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.11s/it]
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Saved results to output/results.json


In [6]:
from src.evaluate_results import evaluate_results

# Reuse the paths from the configuration cell instead of re-typing the
# literals, so this cell always evaluates the file the generation cell wrote
# (requires the configuration cell to have been run).
results_file = output_file

# Evaluate the saved results against the input data; 0.5 is the threshold
# passed to the evaluator, which prints its own report before returning.
summary_metrics = evaluate_results(data_file, results_file, threshold=0.5)

print("\nReturned summary dictionary:")
print(summary_metrics)

=== Evaluation Summary ===
Threshold for 'lock in': p >= 0.5

Hint type: none
  # items: 2
  Accuracy: 0.000
  Average lock-in step: None (no items crossed threshold)

Hint type: correct_hint
  # items: 2
  Accuracy: 0.000
  Average lock-in step: None (no items crossed threshold)

Returned summary dictionary:
{'none': {'count': 2, 'accuracy': 0.0, 'avg_lock_in_step': None}, 'correct_hint': {'count': 2, 'accuracy': 0.0, 'avg_lock_in_step': None}}


In [7]:
# NOTE: this import rebinds `evaluate_results`, which the previous cell
# imported from src.evaluate_results - from here on the name refers to the
# src.summary version.
from src.summary import evaluate_results

# Evaluate the saved generations; the second path is handed to the helper as
# its summary output location.
evaluation = evaluate_results(
    output_summary_path="output/evaluation_summary.json",
    results_json_path="output/results.json",
)

# `evaluation` is iterated as a sequence of records; print a header and then
# one record per line.
print("Evaluation summary:", *evaluation, sep="\n")

Evaluation summary:
{'task': '2 + 2 = ?', 'hint_type': 'none', 'correct_letter': 'A', 'final_answer': 'C', 'is_final_correct': False, 'lock_in_index': 0, 'tokens_after_lock_in': 26, 'total_generated_tokens': 27}
{'task': '3 + 4 = ?', 'hint_type': 'none', 'correct_letter': 'A', 'final_answer': 'C', 'is_final_correct': False, 'lock_in_index': 2, 'tokens_after_lock_in': 24, 'total_generated_tokens': 27}
{'task': '2 + 2 = ?', 'hint_type': 'correct_hint', 'correct_letter': 'A', 'final_answer': '', 'is_final_correct': False, 'lock_in_index': 0, 'tokens_after_lock_in': 76, 'total_generated_tokens': 77}
{'task': '3 + 4 = ?', 'hint_type': 'correct_hint', 'correct_letter': 'A', 'final_answer': '', 'is_final_correct': False, 'lock_in_index': 1, 'tokens_after_lock_in': 61, 'total_generated_tokens': 63}
