In [None]:
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import numpy as np 
import torch
from tqdm import tqdm
import gc
import re
from datasets import load_dataset, DatasetDict, concatenate_datasets
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import TrainingArguments
from trl import GRPOConfig, GRPOTrainer
from peft import PeftModel

In [None]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise run on CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

if device == "cuda":
    # Report which GPU was found and its total memory in GB (1e9 bytes).
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {gpu_name}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_gb:.2f} GB")

In [None]:
# GSM8K from the Hugging Face Hub, "main" configuration, both splits.
dataset_name = "openai/gsm8k"
subset = "main"
trainset, testset = (
    load_dataset(dataset_name, subset, split=split_name) for split_name in ("train", "test")
)
for loaded_split in (trainset, testset):
    print(len(loaded_split))

In [None]:
# Work on a reproducible 10% random slice of each split (limited compute).
SUBSET_FRACTION = 0.1


def _take_fraction(split, n_rows):
    """Shuffle `split` with a fixed seed (42) and keep its first `n_rows` rows."""
    return split.shuffle(seed=42).select(range(n_rows))


train_size = int(SUBSET_FRACTION * len(trainset))
test_size = int(SUBSET_FRACTION * len(testset))
trainset = _take_fraction(trainset, train_size)
testset = _take_fraction(testset, test_size)

print(f"Train size: {len(trainset)}")
print(f"Test size: {len(testset)}")

In [None]:
# Eyeball one (shuffled) training question as a quick sanity check.
first_question = trainset[0]["question"]
print(first_question)

In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub without hard-coding a token: use the
# HF_TOKEN environment variable when set; with token=None, login() prompts
# interactively instead.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_id = 'Qwen/Qwen3-4B'

# 4-bit NF4 weights with nested (double) quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print('model loading with 4 bit quantization')

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Only fall back to EOS as the pad token when the tokenizer has none.
# Overwriting an existing pad token with EOS lets collators that mask pad ids
# out of the labels also mask the real end-of-turn token, so the model may
# never learn to stop generating.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="cuda:0",
    # NOTE(review): requires the flash-attn package; "sdpa" is the fallback if it is missing.
    attn_implementation="flash_attention_2"
)

print("Model loaded on GPU.")

In [None]:
# Release GPU memory held by a trainer left over from an earlier run of the
# training cells below; on a fresh kernel this is a no-op.
# `model` is deliberately kept: the next cell calls methods on it, so deleting
# it here would make that cell fail with NameError on a re-run. To start over
# from an untouched base model, re-run the model-loading cell instead.
if "trainer" in globals():
    del trainer
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Trade recomputation for activation memory, then prepare the 4-bit model for
# adapter training with PEFT's k-bit helper.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# LoRA adapters on every linear projection; LoRA scales updates by
# lora_alpha / r (here 32 / 16 = 2).
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

def format_gsm8k_for_sft(example, system_prompt="You are a helpful mathematics assistant."):
    """Turn one GSM8K row into a chat example for supervised fine-tuning.

    The worked solution becomes the assistant's thinking block, and the final
    answer (the text after ``####``) follows ``</think>``. This is the layout
    the GRPO reward functions later parse.

    Args:
        example: A GSM8K row with ``question`` and ``answer`` strings.
        system_prompt: System message; defaults to the one used for GRPO prompts.

    Returns:
        ``{"messages": [system, user, assistant]}`` in the conversational
        format accepted by TRL's SFTTrainer.

    Raises:
        ValueError: If ``answer`` contains no ``####`` delimiter.
    """
    reasoning, delimiter, solution = example["answer"].rpartition("####")
    if not delimiter:
        raise ValueError("GSM8K answer is missing the '####' final-answer delimiter")

    assistant_text = f"<think>\n{reasoning.strip()}\n</think>\n\n{solution.strip()}"
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": assistant_text},
        ]
    }

# Build SFT copies rather than overwriting `trainset`/`testset`: the GRPO cell
# below maps `trainset` again, and its correctness reward needs the raw
# `answer` column that `remove_columns` would otherwise drop.
sft_trainset = trainset.map(format_gsm8k_for_sft, remove_columns=trainset.column_names)
sft_testset = testset.map(format_gsm8k_for_sft, remove_columns=testset.column_names)


sft_config = SFTConfig(
    output_dir="./primed_baseline",
    max_length=1024,
    # Only consulted for plain-text rows; conversational ("messages") rows are
    # rendered with the tokenizer's chat template instead.
    dataset_text_field="text",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # 8 examples per optimizer step per device
    learning_rate=2e-4,
    logging_steps=10,
    max_steps=200,
    save_steps=100,
    fp16=False,
    bf16=True,
    report_to="none",
    packing=False,
)

# `peft_config` makes SFTTrainer train LoRA adapters rather than full weights.
# No eval strategy is set in `sft_config`, so `eval_dataset` is not evaluated
# during training.
trainer = SFTTrainer(
    model=model,
    train_dataset=sft_trainset,
    eval_dataset=sft_testset,
    args=sft_config,
    peft_config=lora_config,
)

print("Starting SFT")
trainer.train()

print("Saving Primed Baseline")
# Save the LoRA adapters (trainer.model is the PEFT-wrapped model).
trainer.model.save_pretrained("./primed_baseline_final")
# Save the tokenizer alongside so the adapter directory can be reloaded on its own.
tokenizer.save_pretrained("./primed_baseline_final")
print("Primed baseline saved to ./primed_baseline_final")

In [None]:
def extract_xml_answer(text):
    """Return the text after the last ``</think>`` tag, with surrounding whitespace stripped.

    When the tag does not occur, the whole input is returned (stripped).
    """
    _, _, tail = text.rpartition("</think>")
    return tail.strip()
    
# One number: comma-grouped thousands ("1,234.5") or a plain digit run
# ("1234", "0.5"), each with an optional leading minus sign.
_GSM8K_NUMBER_RE = re.compile(r"-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?")


def _last_number(text):
    """Return the last number in ``text`` as a float, or None if it has none."""
    matches = _GSM8K_NUMBER_RE.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def correctness_reward_func(prompts, completions, answer, **kwargs):
    """Reward 1.0 when the final number of a completion equals the GSM8K answer.

    Only the text after ``</think>`` is scored (see ``extract_xml_answer``).
    The last number in it is compared numerically. Comparing numbers avoids
    gluing every digit together ("18 dollars for 3" -> "183"), and signs are
    respected.

    Args:
        prompts: Conversational prompts (unused; part of the TRL reward signature).
        completions: One message list per completion; the first message's
            ``content`` is scored.
        answer: Raw GSM8K answer strings; the gold value follows ``####``.
        **kwargs: Other dataset columns forwarded by the trainer (ignored).

    Returns:
        One float per completion: 1.0 if correct, else 0.0.
    """
    rewards = []
    for completion, truth in zip(completions, answer):
        predicted = _last_number(extract_xml_answer(completion[0]["content"]))
        expected = _last_number(truth.split("####")[-1])
        is_correct = (
            predicted is not None
            and expected is not None
            and abs(predicted - expected) < 1e-6
        )
        rewards.append(1.0 if is_correct else 0.0)
    return rewards

def format_reward_func(completions, **kwargs):
    """Give a small bonus (0.2) to each completion that closes its ``</think>`` block.

    Each completion is a message list; only the first message's ``content``
    is inspected. Completions without the tag score 0.0.
    """
    return [
        0.2 if "</think>" in completion[0]["content"] else 0.0
        for completion in completions
    ]


def make_prompt(example):
    """Wrap a GSM8K question in a two-turn (system + user) chat prompt for GRPO."""
    system_turn = {"role": "system", "content": "You are a helpful mathematics assistant."}
    user_turn = {"role": "user", "content": example["question"]}
    return {"prompt": [system_turn, user_turn]}


# Conversational GRPO prompts. The other dataset columns (notably `answer`)
# are forwarded by GRPOTrainer to the reward functions as keyword arguments.
grpo_dataset = trainset.map(make_prompt)

# Start GRPO from a freshly loaded base model plus the saved SFT adapter.
# When SFTTrainer gets a peft_config it wraps `model` via get_peft_model,
# which injects the LoRA layers into `model` in place. `model` is therefore
# not a PeftModel, so an isinstance/unload() check never fires.
# PeftModel.from_pretrained(model, ...) would then stack a second adapter on
# top of the already-trained one.
for _stale_name in ("trainer", "model"):
    globals().pop(_stale_name, None)  # drop references so the old weights can be freed
gc.collect()
torch.cuda.empty_cache()

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="cuda:0",
    attn_implementation="flash_attention_2",
)
# Same k-bit preparation the SFT run used before its adapters were attached.
base_model = prepare_model_for_kbit_training(base_model)

adapter_path = "./primed_baseline_final"
print("Attaching SFT adapters")
model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)

model.gradient_checkpointing_enable()

training_args = GRPOConfig(
    output_dir="./grpo_gsm8k_final",
    learning_rate=5e-6,
    lr_scheduler_type="cosine",
    logging_steps=1,
    bf16=True,
    # Counted in completions. Older TRL releases require
    # per_device_train_batch_size * num_processes to be divisible by
    # num_generations, so a batch size of 1 fails there. 2 x 4 keeps the same
    # 8 completions per optimizer step as 1 x 8.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_generations=2,
    max_prompt_length=256,
    # NOTE(review): short for thinking-mode output; completions truncated
    # before </think> earn no format reward.
    max_completion_length=128,
    max_steps=200,
    save_steps=100,
    report_to="none",
    use_vllm=False,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[correctness_reward_func, format_reward_func],
    args=training_args,
    train_dataset=grpo_dataset,
    processing_class=tokenizer,
)

print("Starting GRPO Training")
trainer.train()

print("Saving GRPO Model & Tokenizer")
trainer.model.save_pretrained("./grpo_gsm8k_final")
tokenizer.save_pretrained("./grpo_gsm8k_final")
print("Model and tokenizer saved to ./grpo_gsm8k_final")