In [1]:
import re, argparse
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
import torch
from tqdm import tqdm
from pathlib import Path
from accelerate.utils import fsdp_utils

In [3]:
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

In [44]:
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

def extract_numerical_answer(answer_text):
    # GSM8K answers end with #### followed by the numerical answer
    match = re.search(r"#### ([-\d,]+)", answer_text)
    if match:
        # Remove commas and convert to int
        return int(match.group(1).replace(",", ""))
    return None

In [45]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    #print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

In [46]:
model_name = '/data2/alex/verifiers/outputs/Qwen-1.5B-GRPO-base/checkpoint-1869'

In [47]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

In [48]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Set pad token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Set padding side to left for decoder-only models
tokenizer.padding_side = "left"

In [49]:
data = load_dataset("openai/gsm8k", "main")["test"]

In [50]:
eval_data = []
for i, item in enumerate(data):
    proccessed = {
        "question": item["question"],
        "prompt": SYSTEM_PROMPT + " " + item["question"],
        "answer": item["answer"],
        "numerical_answer": extract_numerical_answer(item["answer"]),
        "other_answer": extract_hash_answer(item["answer"]),
    }
    eval_data.append(proccessed)

In [54]:
prompts = [item["prompt"] for item in eval_data]

In [None]:
generated_texts = []
batch_size = 128
for i in tqdm(range(0, len(prompts), batch_size), desc="Generating"):
    batch_prompt = prompts[i:i+batch_size]
    inputs = tokenizer(
        batch_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=786,
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **input,
            max_new_tokens=786,
            temperature = 0.0,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
    for j, outputs in enumerate(outputs):
        generated_text = tokenizer.decode(outputs, skip_special_tokens=True)
        prompt = batch_prompt[j]
        generated_texts.append(generated_text.replace(prompt[j], '')})
        #print(f"Prompt {i+j}: {prompt[j]}")
        #print('-' * 160)
        #print(f"Generated Text: {generated_text.replace(prompt[j], '')}")
        #print("-" * 160)
    

Generating:   0%|          | 0/11 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Generating:   0%|          | 0/11 [00:06<?, ?it/s]


{'input_ids': tensor([[  198, 65354,   304,   279,  2701,  3561,   510,    27, 19895,   287,
           397,  9338,   522, 19895,   287,   397,    27,  9217,   397,  9338,
           522,  9217,   397, 17599,   323,   220,    18,   315,   806,  4780,
          1973,   220,    22, 87770,   369, 15786,    13,  8886, 22502,   374,
          3931,  1119,   220,    23, 34254,    13,  1416, 17599,   323,   806,
          4780,  1366,   311,  4332,   279, 87770, 18308,    11,  1246,  1657,
         34254,   646,  1817,   315,  1105,   614,    30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}