In [None]:

%load_ext autoreload
%autoreload 2

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import os
from tqdm import tqdm
import re
from typing import Any

from cs336_alignment.zeroshot import parse_gsm8k_response

# Location of the GSM8K test split (read below as one JSON object per line)
gsm8k_test_path = "../data/gsm8k/test.jsonl"

# Local Qwen checkpoint: load the tokenizer first, then the causal-LM weights
# in float16 with `device_map="auto"`.
model_path = "../models/Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Build the zero-shot prompt for a single GSM8K question
def format_gsm8k_prompt(question):
    """Return ``question`` followed by the answer cue used for zero-shot prompting.

    Args:
        question: The problem statement to present to the model.

    Returns:
        The question text, a newline, the literal ``"Answer: "`` and a final
        newline.
    """
    answer_cue = "\nAnswer: \n"
    return f"{question}{answer_cue}"

# Generate predictions
def generate_answer(prompt):
    """Greedily generate a completion for ``prompt`` and return only the new text.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        prompt: The full text prompt to condition the model on.

    Returns:
        The decoded completion (special tokens stripped), *without* the prompt
        echoed back. Excluding the prompt matters for answer extraction:
        otherwise numbers that appear in the question are part of the text
        handed to the parser.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens; generate() returns prompt + continuation for
    # decoder-only models, so this is where the new tokens start.
    prompt_length = inputs["input_ids"].shape[1]

    # Pass the pad token explicitly so generate() does not log
    # "Setting `pad_token_id` to `eos_token_id`" on every single call.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        # Greedy decoding: do_sample=False is sufficient. Sampling parameters
        # (temperature / top_p) are not used in greedy mode and only trigger
        # warnings, so they are not passed.
        # max_new_tokens leaves room for multi-step arithmetic reasoning.
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            pad_token_id=pad_id,
        )

    completion_ids = output[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)

# Evaluate GSM8K
os.makedirs("gsm8k_evaluation_results", exist_ok=True)

# Load GSM8K test data (one JSON object per line). Blank lines are skipped so
# that a stray empty/trailing line does not crash json.loads.
with open(gsm8k_test_path, 'r', encoding='utf-8') as f:
    examples = [json.loads(line) for line in f if line.strip()]

# Reference answers end with "#### <number>". The number may carry a minus
# sign, thousands separators or a decimal part; capture all of it and strip
# the commas afterwards. (A plain r'\d+' would turn "1,080" into "1" and
# "-5" into "5".)
gt_pattern = re.compile(r'####\s*(-?\d[\d,]*(?:\.\d+)?)')

results = []
correct = 0

for example in tqdm(examples, desc="Evaluating GSM8K"):
    prompt = format_gsm8k_prompt(example['question'])
    raw_output = generate_answer(prompt)

    # Extract prediction from model output
    pred = parse_gsm8k_response(raw_output)

    # Extract and normalise the reference answer
    gt_match = gt_pattern.search(example['answer'])
    gt = gt_match.group(1).replace(",", "") if gt_match else None

    results.append({
        "question": example['question'],
        "ground_truth": gt,
        "prediction": pred,
        "raw_output": raw_output
    })

    # Only count a hit when a reference answer was actually found; otherwise
    # a failed parse on both sides (None == None) would be scored as correct.
    if gt is not None and pred == gt:
        correct += 1

accuracy = correct / len(results) if results else 0
print(f"Zero-shot accuracy on GSM8K: {accuracy:.2%}")
print(f"Correct: {correct}, Total: {len(results)}")

# Save results
with open("gsm8k_evaluation_results/qwen_gsm8k.json", "w", encoding="utf-8") as f:
    json.dump({"accuracy": accuracy, "results": results}, f, indent=2)



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Evaluating GSM8K:   0%|          | 0/1319 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating GSM8K:   0%|          | 1/1319 [00:05<1:58:25,  5.39s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating GSM8K:   0%|          | 2/1319 [00:10<1:56:28,  5.31s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating GSM8K:   0%|          | 3/1319 [00:15<1:56:06,  5.29s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating GSM8K:   0%|          | 4/1319 [00:17<1:26:31,  3.95s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating GSM8K:   0%|          | 5/1319 [00:23<1:37:20,  4.44s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating GSM8K:   0%|          | 6/1319 [00:26<1:31:46,  4.19s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Evaluating GSM8K:   1%|

Zero-shot accuracy on GSM8K: 22.59%
Correct: 298, Total: 1319





In [None]:
def analyze_errors(results, max_examples=50):
    """Print summary counts and the first few mispredicted examples.

    Args:
        results: List of dicts, each with the keys ``"prediction"``,
            ``"ground_truth"`` and ``"raw_output"``.
        max_examples: Maximum number of mispredicted examples to print in
            detail. Defaults to 50; pass ``None`` to print all of them.

    Returns:
        None. Everything is written to stdout.
    """
    # A prediction counts as missing when it is None or the literal string
    # "None"/"none".
    missing_markers = ("None", "none")
    total_none_outputs = sum(
        1
        for result in results
        if result["prediction"] is None or result["prediction"] in missing_markers
    )

    print(f"Total 'None' predictions: {total_none_outputs}")
    errors = [r for r in results if r["prediction"] != r["ground_truth"]]
    print(f"Total errors: {len(errors)}")

    for i, error in enumerate(errors[:max_examples]):
        print(f"\nError {i+1}:")
        print(f"Raw: {error['raw_output']}")
        print(f"Expected: {error['ground_truth']}")
        print(f"Predicted: {error['prediction']}")
        print("-" * 50)
    
# Print the None/error counts and the first mispredicted examples for `results`.
analyze_errors(results)

Total 'None' predictions: 0
Total errors: 1021

Error 1:
Raw: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Answer: 
1. **Calculate the total number of eggs laid by the ducks in a day:**
   - Janet's ducks lay 16 eggs per day.

2. **Calculate the total number of eggs eaten in a day:**
   - Janet eats 3 eggs for breakfast.
   - She eats 4 eggs for baking muffins.
   - Therefore, the total number of eggs eaten in a day is:
     \[
     3 + 4 = 7 \text{ eggs}
     \]

3. **Calculate the number of eggs remaining for sale:**
   - The total number of eggs laid in a day is 16.
   - The number of eggs eaten in a day is 7.
   - Therefore, the number of eggs remaining for sale is:
     \[
     16 - 7 = 9 \text{ eggs}
     \]

4. **Calculate the total revenue from s