In [8]:
import re
import json
import gc
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from tqdm import tqdm

In [None]:
# Base checkpoint id and the local directory holding the SFT ("primed baseline") adapter.
model_id = "Qwen/Qwen3-4B"
adapter_path = "./primed_baseline_final"  

# Quantization settings for the base model: 4-bit NF4 weights with double
# quantization, using bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load the quantized base model onto the first GPU with FlashAttention-2.
# NOTE(review): no explicit torch_dtype is passed here; confirm the dtype the
# library picks for the non-quantized modules is the intended one.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="cuda:0",
    attn_implementation="flash_attention_2"
)

# The tokenizer is read from the adapter directory rather than from model_id.
tokenizer = AutoTokenizer.from_pretrained(adapter_path)
# Attach the primed-baseline adapter to the base model and switch to eval mode.
# `base_model` is kept as a separate name because a later cell reuses it when
# loading the GRPO adapter.
model= PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

print("Primed Baseline loaded")

In [None]:
def extract_ground_truth(text):
    """Return the reference answer embedded in a GSM8K-style answer string.

    When the text contains the ``####`` marker, the segment between the first
    and second marker (or the end of the string) is returned, stripped of
    surrounding whitespace. Without a marker the whole text is returned,
    stripped.
    """
    marker = "####"
    if marker not in text:
        return text.strip()
    segments = text.split(marker)
    return segments[1].strip()

def is_correct(model_output, ground_truth):
    """Check whether the model's final answer contains the reference answer.

    Only the text after the last ``</think>`` tag is inspected. If the tag is
    missing (e.g. generation was cut off), the whole output is inspected
    instead. Thousands separators (commas) are ignored on both sides.

    A numeric ground truth is compared by value against every whole number
    found in the answer, so "5" does not match "15" or "2.5", while "18"
    still matches "18.00". A non-numeric ground truth falls back to a plain
    substring check.

    Args:
        model_output: Full generated text, optionally containing ``</think>``.
        ground_truth: Reference answer string, e.g. ``"1,234"``.

    Returns:
        True if the reference answer appears in the final answer, else False.
        An empty ground truth never matches.
    """
    if "</think>" in model_output:
        final_answer = model_output.split("</think>")[-1]
    else:
        final_answer = model_output

    gt_clean = ground_truth.replace(',', '').strip()
    if not gt_clean:
        # "" is a substring of every string; treat it as "no reference" instead.
        return False
    answer_clean = final_answer.replace(',', '')

    number_pattern = r"-?\d+(?:\.\d+)?"
    if re.fullmatch(number_pattern, gt_clean) is None:
        # Non-numeric reference: keep the substring semantics.
        return gt_clean in answer_clean

    gt_value = float(gt_clean)
    # The lookbehind prevents a match from starting in the middle of a larger
    # number (e.g. the "5" in "15" or in "2.5").
    candidates = re.findall(r"(?<![\d.])" + number_pattern, answer_clean)
    return any(float(candidate) == gt_value for candidate in candidates)

In [None]:
# Build the evaluation subset: the first 10% of the GSM8K test split after a
# seeded shuffle (seed 42; per the original note, the same seed as training).
dataset_name = "openai/gsm8k"
subset = 'main'

full_test_split = load_dataset(dataset_name, subset, split='test')
test_size = int(0.1 * len(full_test_split))
shuffled_split = full_test_split.shuffle(seed=42)
testset = shuffled_split.select(range(test_size))

print("Test set loaded.")

In [None]:
def generate_response(question_text):
    """Generate the model's reply to a single question.

    The question is wrapped in the ``User: ... Assistant: <think>`` prompt,
    run through the notebook-level ``model`` with sampling disabled and at
    most 512 new tokens, and only the newly generated tokens are decoded.

    Args:
        question_text: The question to ask.

    Returns:
        The decoded completion prefixed with ``<think>``, so the returned
        text starts with the tag the prompt ended on.
    """
    prompt = f"User: {question_text}\nAssistant: <think>"
    encoded = tokenizer([prompt], return_tensors="pt").to("cuda")

    output_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False,
    )

    # The output holds prompt + completion; drop the prompt tokens per row.
    completions = []
    for prompt_ids, full_ids in zip(encoded.input_ids, output_ids):
        completions.append(full_ids[len(prompt_ids):])

    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return "<think>" + decoded[0]

# Evaluate the currently loaded model on every test example and keep one
# record per example for later inspection and saving.
inference_results = []
total_samples = len(testset)

print("Running inference")

for idx, sample in enumerate(tqdm(testset)):
    prediction = generate_response(sample['question'])
    reference = extract_ground_truth(sample['answer'])

    inference_results.append({
        "id": idx,
        "question": sample['question'],
        "model_output": prediction,
        "ground_truth": reference,
        "is_correct": is_correct(prediction, reference),
    })

# Count the hits from the stored records rather than with a running counter.
correct_count = sum(1 for record in inference_results if record["is_correct"])

accuracy = (correct_count / total_samples) * 100
print(f"Final Accuracy: {accuracy:.2f}%")

In [None]:
# Print the first few baseline records for a manual sanity check.
num_to_display = 20
separator = "-" * 30

for record in inference_results[:num_to_display]:
    example_id = record['id']
    print(f"Example {example_id}\n")
    print(f"Question:\n{record['question']}")
    print(separator)
    print(f"Model Reasoning:\n{record['model_output']}")
    print(separator)
    print(f"Ground_Truth: {record['ground_truth']}")
    print("\n")
    

In [None]:
# Persist the SFT baseline records as indented JSON.
output_file = "sft_results_baseline.json"
serialized_results = json.dumps(inference_results, indent=4)
with open(output_file, "w") as results_fh:
    results_fh.write(serialized_results)


In [None]:
# Detach the SFT adapter so the GRPO adapter is attached to a clean base model.
# `unload()` removes the adapter layers and returns the underlying base model;
# rebinding `base_model` to that return value makes the dependency explicit.
try:
    base_model = model.unload()
except AttributeError:
    # Best effort, but not silent: if the adapter cannot be removed, the next
    # adapter would be loaded on top of it and the GRPO numbers would be wrong.
    print("Warning: model.unload() is unavailable; the previous adapter may still be attached to base_model")

# Release whatever became unreachable and return cached GPU blocks.
gc.collect()
torch.cuda.empty_cache()

grpo_adapter_path = "./grpo_gsm8k_final"

# The tokenizer is reloaded from the GRPO adapter directory, with left padding.
tokenizer = AutoTokenizer.from_pretrained(grpo_adapter_path)
tokenizer.padding_side = 'left' 

# Attach the GRPO adapter to the base model and switch to eval mode.
model = PeftModel.from_pretrained(base_model, grpo_adapter_path)
model.eval()


In [None]:
# Evaluate the GRPO model with the same generation and scoring path as the
# SFT baseline so the two accuracies are directly comparable.
grpo_results = []
correct_count = 0
total_samples = len(testset)

print("Running GRPO inference")

for i, example in enumerate(tqdm(testset)):
    question = example['question']
    gt_full = example['answer']

    # Reuse the shared helper: it reads the notebook-level `model` and
    # `tokenizer` (now the GRPO ones) and decodes only the newly generated
    # tokens. Decoding prompt + completion and splitting on the last "<think>"
    # would drop text whenever the model emits "<think>" again.
    with torch.no_grad():
        model_output = generate_response(question)

    gt_value = extract_ground_truth(gt_full)
    correct = is_correct(model_output, gt_value)

    if correct:
        correct_count += 1

    grpo_results.append({
        "id": i,
        "question": question,
        "model_output": model_output,
        "ground_truth": gt_value,
        "is_correct": correct
    })

# Guard against an empty test set instead of raising ZeroDivisionError.
grpo_accuracy = (correct_count / total_samples) * 100 if total_samples else 0.0
print(f"\nGRPO final Accuracy: {grpo_accuracy:.2f}%")

In [None]:
# Print the first few GRPO records, tagged with their correctness verdict.
num_to_display = 20
separator = "-" * 30

for record in grpo_results[:num_to_display]:
    verdict = 'correct' if record['is_correct'] else 'wrong'
    print(f"Example {record['id']} [{verdict}]\n")
    print(f"Question:\n{record['question']}")
    print(separator)
    print(f"GRPO Reasoning:\n{record['model_output']}")
    print(separator)
    print(f"Ground Truth: {record['ground_truth']}")
    print("\n")

In [None]:
# Persist the GRPO records as indented JSON.
output_file = "grpo_results_final.json"
serialized_results = json.dumps(grpo_results, indent=4)
with open(output_file, "w") as results_fh:
    results_fh.write(serialized_results)