In [None]:

%load_ext autoreload
%autoreload 2

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import os
from tqdm import tqdm
import re
from typing import Any

from cs336_alignment.zeroshot import parse_gsm8k_response

# Load the fine-tuned Qwen model and tokenizer. Set GSM8K_MODEL_PATH to point
# at a different checkpoint without editing the notebook.
model_path = os.environ.get(
    "GSM8K_MODEL_PATH",
    "/home/alvin/Homework/s2025-assignment3-alignment/notebooks/qwen2.5-3B-instruct-finetuned/final_model",
)
# Qwen2.5 weights are published in bfloat16. float16 has a far narrower
# exponent range, so activations can overflow and produce degenerate text;
# prefer bf16 when the GPU supports it and fall back to fp16 otherwise.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model_dtype = torch.bfloat16
else:
    model_dtype = torch.float16
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=model_dtype,
    device_map="auto",
)

# GSM8K test file path
gsm8k_test_path = "../data/gsm8k/test.jsonl"

def format_gsm8k_prompt(question, answer_prefix="Answer: "):
    """Wrap a GSM8K question in the Alpaca-style instruction template.

    Args:
        question: Raw GSM8K question text, inserted verbatim.
        answer_prefix: Text appended right after the "### Response:" header to
            prime the completion. Defaults to "Answer: ". Pass "" to end the
            prompt exactly at the response header, e.g. to match a
            fine-tuning template that has no such prefix.

    Returns:
        The complete prompt string to feed to the model.
    """
    return (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request."
        f"\n\n### Instruction:\n{question}\n\n### Response:\n"
        f"{answer_prefix}"
    )

def generate_answer(prompt, max_new_tokens=200, include_prompt=False):
    """Greedily generate a completion for *prompt* using the global ``model``.

    Args:
        prompt: Fully formatted prompt string.
        max_new_tokens: Upper bound on the number of generated tokens; math
            problems may need room for several reasoning steps.
        include_prompt: If True, return the decoded full sequence (prompt
            followed by the completion). By default only the newly generated
            text is returned, so a parser that picks the last number in the
            string cannot fall back on numbers copied from the question.

    Returns:
        The decoded text with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens; generate() returns prompt + new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        # do_sample=False selects greedy decoding; temperature/top_p only affect
        # sampling and are passed to make the deterministic intent explicit.
        # NOTE(review): other defaults from the checkpoint's generation_config
        # (if any) may still apply — confirm if exact greedy output matters.
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.0,
            top_p=1.0,
            do_sample=False,
        )
    sequence = output[0] if include_prompt else output[0][prompt_len:]
    return tokenizer.decode(sequence, skip_special_tokens=True)

# Evaluate GSM8K
EVAL_OUTPUT_DIR = "gsm8k_evaluation_results_finetuned"
NUM_EVAL_EXAMPLES = 100  # evaluate a prefix of the test set to bound runtime
os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)

# Reference solutions end with "#### <number>". The number may be negative or
# use thousands separators, so capture a leading "-" and commas as well.
GSM8K_GT_PATTERN = re.compile(r"####\s*(-?[\d,]*\.?\d+)")


def _normalize_number(text):
    """Canonicalise a numeric answer string for comparison.

    Strips commas and surrounding whitespace. If the result parses as a float,
    integral values are rendered without a fractional part so "18", "18.0"
    and "18.00" compare equal. Returns None for None input and the stripped
    string for anything that does not parse as a number.
    """
    if text is None:
        return None
    cleaned = str(text).replace(",", "").strip()
    try:
        value = float(cleaned)
    except ValueError:
        return cleaned
    if value.is_integer():
        return str(int(value))
    return repr(value)


# Load GSM8K test data (one JSON object per non-blank line).
examples = []
with open(gsm8k_test_path, "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():
            examples.append(json.loads(line))

results = []
correct = 0

for example in tqdm(examples[:NUM_EVAL_EXAMPLES], desc="Evaluating GSM8K"):
    prompt = format_gsm8k_prompt(example["question"])
    raw_output = generate_answer(prompt)

    # Extract the predicted number from the model output.
    pred = parse_gsm8k_response(raw_output)

    gt_match = GSM8K_GT_PATTERN.search(example["answer"])
    gt = gt_match.group(1).replace(",", "") if gt_match else None

    # A reference that fails to parse must never be scored as correct;
    # otherwise a None prediction would "match" a None ground truth.
    is_correct = gt is not None and _normalize_number(pred) == _normalize_number(gt)
    results.append({
        "question": example["question"],
        "ground_truth": gt,
        "prediction": pred,
        "correct": is_correct,
        "raw_output": raw_output,
    })

    if is_correct:
        correct += 1

accuracy = correct / len(results) if results else 0
print(f"Fine tuned accuracy on GSM8K: {accuracy:.2%}")
print(f"Correct: {correct}, Total: {len(results)}")

# Save results
with open(os.path.join(EVAL_OUTPUT_DIR, "qwen_gsm8k.json"), "w", encoding="utf-8") as f:
    json.dump({"accuracy": accuracy, "results": results}, f, indent=2)



  from .autonotebook import tqdm as notebook_tqdm
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.26it/s]
Evaluating GSM8K: 100%|██████████| 100/100 [13:15<00:00,  7.96s/it]

Zero-shot accuracy on GSM8K: 1.00%
Correct: 1, Total: 100





In [None]:
def analyze_errors(results, max_shown=50):
    """Print a summary of failed predictions from a GSM8K evaluation run.

    Args:
        results: List of per-example dicts containing at least the keys
            "prediction", "ground_truth" and "raw_output". If a record also
            carries a "correct" key, that verdict decides whether it is an
            error; otherwise a record is an error when its prediction differs
            from its ground truth.
        max_shown: Maximum number of individual errors to print in detail
            (None prints all of them).
    """
    # Predictions the parser could not extract: stored as None, or as the
    # literal strings "None"/"none".
    total_none_outputs = sum(
        1
        for r in results
        if r["prediction"] is None or r["prediction"] in ("None", "none")
    )
    print(f"Total 'None' predictions: {total_none_outputs}")

    errors = [
        r for r in results
        if not r.get("correct", r["prediction"] == r["ground_truth"])
    ]
    print(f"Total errors: {len(errors)}")

    for i, error in enumerate(errors[:max_shown], start=1):
        print(f"\nError {i}:")
        print(f"Raw: {error['raw_output']}")
        print(f"Expected: {error['ground_truth']}")
        print(f"Predicted: {error['prediction']}")
        print("-" * 50)
    
# Print a breakdown of the failures from the evaluation loop above; relies on
# the module-level `results` list built there.
analyze_errors(results)

Total 'None' predictions: 0
Total errors: 99

Error 1:
Raw: Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

### Response:
Answer: 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
Expected: 18
Predicted: 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
--------------------------------------------------

Error 2:
Raw: Below is an instruct