In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import pandas as pd
import numpy as np
import json
import torch
import re
import ast

In [None]:
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_name = "Qwen/Qwen2.5-14B-Instruct"

# Tokenizer belonging to the same checkpoint; later cells use it to build the
# chat prompt and to decode the generated ids.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Causal LM weights, requested in float16 and mapped onto the "mps" device.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="mps")



In [None]:
# System message sent as the first chat turn for every question.
# NOTE(review): it asks for JSON output, while the user template below asks for
# a plain "Reason:" / "Answer:" layout, which is what save_pred parses — confirm
# which one is intended before editing either string.
system_prompt = "You are an expert at solving complex grade-school math problems with step-by-step reason and return output in json format"

# User-turn template; `{question}` is filled in with str.format in the
# generation loop. Everything inside the triple quotes is sent to the model
# verbatim, including the trailing "# present only numeric value" hint.
user_prompt_template = """{question}

Solve the above grade-school math problem by breaking it down step by step.

Follow this format strictly:
Reason: [step 1: ..., step 2: ..., ...]
Answer: [final numeric answer] # present only numeric value for Answer
"""


In [None]:
# Load the dataset; later cells index it as a sequence of records with
# 'question' and 'answer' keys.
# The encoding is given explicitly because JSON text is UTF-8, whereas open()
# would otherwise fall back to the platform's locale encoding.
with open('../gsm8k/gsm8k.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# Sanity check: show the question text of the first record.
data[0]["question"]

In [None]:
# Sanity check: show how the first record's answer splits around '#### '.
data[0]["answer"].split("#### ")

In [None]:

# First signed decimal number in free text. Commas are allowed inside the
# integer part so that thousands separators ("1,000") are captured whole.
_NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?|-?\.\d+")


def _parse_answer(answer_str):
    """Convert the text following 'Answer:' into a JSON-serialisable value.

    A clean Python literal (e.g. "18", "18.5", "[18]") is returned exactly as
    ast.literal_eval yields it. Anything else falls back to the first number
    found in the text, with thousands separators removed. Returns None when
    no number can be found.
    """
    try:
        value = ast.literal_eval(answer_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Typical model output such as "$18", "18 dollars" or "**18**".
        pass
    else:
        # Tuples are excluded on purpose: "1,000" evaluates to (1, 0).
        # Sets, bytes, etc. are excluded because json cannot serialise them.
        if isinstance(value, (int, float, str, list, dict)):
            return value

    number_match = _NUMBER_RE.search(answer_str)
    if number_match is None:
        return None
    number_text = number_match.group().replace(",", "")
    return float(number_text) if "." in number_text else int(number_text)


def save_pred(response, ques, gt_reason, gt_answer, path):
    """Parse a model response and write it, with the ground truth, to `path`.

    The response is expected to look like "Reason: ... Answer: ...". The text
    from the first "Reason:" up to the first "Answer:" is stored as a list of
    stripped, non-empty lines (the "Reason:" prefix is kept on the first line,
    matching files written by the earlier version of this function). The text
    after "Answer:" is parsed by `_parse_answer`.

    Malformed responses no longer raise:
      * no "Reason:" marker -> reasoning is taken from the start of the text;
      * no "Answer:" marker -> the whole remaining text is kept as reasoning
        and "Answer" is stored as null;
      * unparseable answer text -> the first number in it, or null.

    Args:
        response: Decoded model output.
        ques: Question text, stored under "Question".
        gt_reason: Reference reasoning, stored under "GT_reason".
        gt_answer: Reference answer, stored under "GT_answer".
        path: Destination JSON file; overwritten if it exists.

    Raises:
        OSError: If `path` cannot be opened for writing.
        TypeError: If any stored value is not JSON-serialisable.
    """
    # Double quotes are normalised to single quotes, as before.
    response = response.replace('"', "'")

    reason_match = re.search(r"Reason:", response)
    answer_match = re.search(r"Answer:", response)
    reason_start = reason_match.start() if reason_match else 0
    answer_start = answer_match.start() if answer_match else len(response)

    reason_text = response[reason_start:answer_start]
    reason_lines = [line.strip() for line in reason_text.split("\n") if line.strip()]

    if answer_match:
        answer_str = response[answer_start:].replace("Answer:", "").strip()
        answer = _parse_answer(answer_str)
    else:
        # E.g. generation stopped at the token limit before reaching an answer.
        answer = None

    result = {
        "Question": ques,
        "Reason": reason_lines,
        "Answer": answer,
        "GT_answer": gt_answer,
        "GT_reason": gt_reason,
    }
    # Serialise before opening the file so a serialisation error cannot leave
    # a truncated file behind.
    payload = json.dumps(result, indent=2)
    with open(path, "w") as f:
        f.write(payload)


In [None]:

# Resume point: records with a smaller index are skipped
# (presumably already predicted in an earlier run — adjust before re-running).
START_INDEX = 809
# Directory that receives one "<index>.json" prediction file per record.
PRED_DIR = '../gsm8k/qwen_25_pred'
# Upper bound on newly generated tokens per question.
MAX_NEW_TOKENS = 1024

# Opening a file for writing does not create missing parent directories,
# so make sure the output directory exists before the first prediction.
os.makedirs(PRED_DIR, exist_ok=True)

for num, pair in enumerate(data):
    if num < START_INDEX:
        continue
    print(num)
    ques = pair['question']
    # The reference answer holds the worked solution and the final answer
    # separated by '#### '; split once and reuse both halves.
    answer_parts = pair['answer'].split('#### ')
    gt_reason, gt_answer = answer_parts[0], answer_parts[1]

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt_template.format(question=ques)}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS
    )
    # Drop the leading prompt-length slice of every output sequence so that
    # only the newly generated tokens are decoded.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    path = os.path.join(PRED_DIR, str(num) + '.json')
    save_pred(response, ques, gt_reason, gt_answer, path)
    
    
