In [1]:
import json
import copy
import math
import re
import torch
import pandas as pd
import torch.nn.functional as F
from trl import GRPOConfig, GRPOTrainer
from torch.utils.data import DataLoader, Dataset as TorchDataset, IterableDataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split
from tqdm import tqdm

INFO 05-12 14:02:22 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-12 14:02:22 [__init__.py:239] Automatically detected platform cuda.


In [2]:
# Central experiment configuration; every tunable value used by later cells lives here.
PARAMS = dict(
    # --- adaptive sampling (consumed by AdaptiveDataset) ---
    ALPHA=0.5,               # smoothing factor of the running reward mean/variance
    DIFFICULTY_FACTOR=0.5,   # weight of (1 - mean reward) in a problem's sampling weight
    VARIANCE_FACTOR=0.5,     # weight of the reward standard deviation in the sampling weight
    MIN_WEIGHT=0.01,         # added to every weight before normalising into probabilities
    # --- batching / schedule ---
    TRAIN_BATCH_SIZE=4,      # passed as per_device_train_batch_size
    VALID_BATCH_SIZE=2,      # not referenced by the cells shown here
    MAX_STEPS=16384,         # passed as max_steps
    STEPS_PER_UPDATE=1,      # not referenced by the cells shown here
    STEPS_PER_EVAL=10,       # not referenced by the cells shown here
    # --- GRPO generation / optimisation ---
    NUM_GENERATIONS=4,       # completions per prompt; also the dataset's repeat count
    MAX_COMPLETION_LENGTH=1024,
    LEARNING_RATE=2e-4,
    BETA=0.04,               # passed as `beta` to GRPOConfig
)

In [3]:
# Device bookkeeping. The model itself is placed by `device_map="auto"` below.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()

model_dir = "Qwen/Qwen3-4B-Base"
compute_dtype = torch.float16

# 4-bit NF4 quantisation with double quantisation; compute runs in `compute_dtype`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map="auto",
)

# LoRA adapters on the attention projections.
# The attention output projection in Qwen-style decoders is named `o_proj`;
# the previous entry `dense` (the Falcon / GPT-NeoX / Phi name) matches no
# module in this architecture, so the output projection was never adapted.
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=[
        'q_proj',
        'k_proj',
        'v_proj',
        'o_proj',
    ],
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

# Left padding so that generated tokens directly follow each prompt in a batch.
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, padding_side="left")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
def create_prompt(question):
    """Render a problem statement as a chat-formatted prompt string.

    The conversation consists of a fixed system message describing the
    <think>/<answer> response layout and a user message made of `question`
    followed by an instruction to box the final answer. The notebook-level
    `tokenizer` applies its chat template without tokenising and with the
    generation prompt appended.

    Args:
        question: Problem text (str).

    Returns:
        The value returned by `tokenizer.apply_chat_template(..., tokenize=False)`.
    """
    system_message = {"role": "system", "content": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>"}
    user_message = {"role": "user", "content": question + ' Return final answer within \\boxed{}.'}
    return tokenizer.apply_chat_template(
        conversation=[system_message, user_message],
        tokenize=False,
        add_generation_prompt=True
    )

In [5]:
# Read the problem set and attach a ready-to-use chat prompt to every row.
dataset_path = "AIME_IMO_MATH.csv"
data = pd.read_csv(dataset_path)
data['prompt'] = [create_prompt(problem) for problem in data['problem']]

# Hold out 10% of the problems; the fixed seed keeps the split reproducible.
train_df, test_df = train_test_split(
    data,
    test_size=0.1,
    shuffle=True,
    random_state=42,
)
train_df.head()

Unnamed: 0,problem,answer,prompt
239,In an equation of the form $k = ax^2 + bx + c$...,-150,<|im_start|>system\nA conversation between Use...
1534,"In the land of Ink, the money system is unique...",6,<|im_start|>system\nA conversation between Use...
1103,"A street has 50 houses on each side, for a tot...",245,<|im_start|>system\nA conversation between Use...
555,The largest term in the binomial expansion of ...,1024,<|im_start|>system\nA conversation between Use...
888,Let $\alpha$ and $\beta$ be angles for which\n...,8,<|im_start|>system\nA conversation between Use...


In [6]:
class AdaptiveDataset(IterableDataset):
    """Endless prompt stream whose sampling distribution adapts to observed rewards.

    Each problem carries a running (exponentially smoothed) mean and variance of
    the rewards recorded for it. Its sampling weight is
    ``difficulty_factor * (1 - mean) + variance_factor * sqrt(variance)``, and
    ``min_weight`` is added to every weight before normalising into
    probabilities. Every drawn problem is yielded ``num_generations`` times in a
    row.

    Args:
        prompts: DataFrame with "problem", "answer" and "prompt" columns.
        num_generations: How many consecutive times each drawn sample is yielded.
        alpha: Smoothing factor for the running mean / variance.
        difficulty_factor: Coefficient of ``1 - mean_reward`` in the weight.
        variance_factor: Coefficient of the reward standard deviation in the weight.
        min_weight: Constant added to every weight before normalisation.
    """

    def __init__(self, prompts, num_generations=4, alpha=0.5, difficulty_factor=0.5, variance_factor=0.5, min_weight=0.01):
        self.alpha = alpha
        self.difficulty_factor = difficulty_factor
        self.variance_factor = variance_factor
        self.num_generations = num_generations
        self.min_weight = min_weight

        # Answers are stored as strings so they can be compared with extracted text.
        self.data = []
        for _, prompt in prompts.iterrows():
            self.data.append({
                "problem": prompt["problem"],
                "answer": str(prompt["answer"]),
                "prompt": prompt["prompt"]
            })

        n = len(self.data)

        # Per-problem running statistics; all problems start with equal weight.
        self._mean_rewards = torch.zeros(n)
        self._var_rewards = torch.zeros(n)
        self._weights = torch.ones(n)
        self._reward_counts = torch.zeros(n)
        self._update_probabilities()

    def _update_probabilities(self):
        """Recompute the sampling distribution from the current weights."""
        weights = self._weights + self.min_weight
        self._probabilities = weights / weights.sum()

    def __iter__(self):
        """Yield samples forever, each drawn index repeated `num_generations` times.

        The probabilities are re-read on every draw, so weight updates made
        while iterating affect subsequent draws.
        """
        while True:
            idx = torch.multinomial(self._probabilities, 1, replacement=True).item()
            sample = {
                "index": idx,
                "problem": self.data[idx]["problem"],
                "answer": self.data[idx]["answer"],
                "prompt": self.data[idx]["prompt"],
                "current_weight": self._weights[idx].item()
            }
            for _ in range(self.num_generations):
                yield sample

    def update_weights(self, reward_records):
        """Fold new rewards into the running statistics and refresh probabilities.

        Args:
            reward_records: Iterable of ``(index, rewards)`` pairs, where
                ``rewards`` is a list of floats observed for problem ``index``.
                Pairs with an empty list are ignored.
        """
        alpha = self.alpha
        for idx, rewards in reward_records:
            if not rewards:
                continue
            r_tensor = torch.tensor(rewards, dtype=torch.float32)
            new_mean = r_tensor.mean().item()
            # The sample variance of a single observation is undefined (torch
            # returns NaN), and one NaN would propagate through the weights into
            # every sampling probability. Treat a lone reward as zero spread.
            new_var = r_tensor.var().item() if r_tensor.numel() > 1 else 0.0

            self._mean_rewards[idx] = (1 - alpha) * self._mean_rewards[idx] + alpha * new_mean
            self._var_rewards[idx] = (1 - alpha) * self._var_rewards[idx] + alpha * new_var
            self._reward_counts[idx] += len(rewards)

            difficulty = 1 - self._mean_rewards[idx]
            uncertainty = torch.sqrt(self._var_rewards[idx])
            self._weights[idx] = self.difficulty_factor * difficulty + self.variance_factor * uncertainty

        self._update_probabilities()

    def get_statistics(self):
        """Return summary statistics of weights, rewards and the sampling distribution."""
        weights_np = self._weights.numpy()
        means_np = self._mean_rewards.numpy()

        return {
            "weight_distribution": {
                "min": float(weights_np.min()),
                "max": float(weights_np.max()),
                "mean": float(weights_np.mean()),
                "std": float(weights_np.std())
            },
            "performance_distribution": {
                "mean_reward": float(means_np.mean()),
                # Indices of the five lowest running mean rewards.
                "hardest_problems": means_np.argsort()[:5].tolist(),
                # Indices of the five highest running reward variances.
                "most_uncertain": self._var_rewards.argsort(descending=True)[:5].tolist()
            },
            "sampling_distribution": {
                # The small epsilon keeps log() finite for zero probabilities.
                "entropy": float(-(self._probabilities * torch.log(self._probabilities + 1e-8)).sum()),
                "max_probability": float(self._probabilities.max()),
                # Inverse of the sum of squared probabilities.
                "effective_dataset_size": float(1.0 / (self._probabilities ** 2).sum())
            }
        }

    def get_item_metadata(self, idx):
        """Return the running statistics stored for problem `idx`."""
        return {
            "mean_reward": self._mean_rewards[idx].item(),
            "var_reward": self._var_rewards[idx].item(),
            "weight": self._weights[idx].item(),
            "reward_count": self._reward_counts[idx].item(),
            "sampling_probability": self._probabilities[idx].item()
        }

    def reset_weights(self):
        """Restore all running statistics and weights to their initial values."""
        self._mean_rewards.fill_(0.0)
        self._var_rewards.fill_(0.0)
        self._weights.fill_(1.0)
        self._reward_counts.fill_(0.0)
        self._update_probabilities()


class ValidationDataset(TorchDataset):
    """Map-style dataset over the held-out problems.

    Rows of `validation_df` are copied into plain dicts on construction, with
    the answer converted to a string.
    """

    def __init__(self, validation_df):
        self.data = [
            {
                "problem": row["problem"],
                "answer": str(row["answer"]),
                "prompt": row["prompt"],
            }
            for _, row in validation_df.iterrows()
        ]

    def __len__(self):
        """Number of stored problems."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record at `idx` together with the index it was requested by."""
        record = self.data[idx]
        item = {"index": idx}
        item.update(
            problem=record["problem"],
            answer=record["answer"],
            prompt=record["prompt"],
        )
        return item

In [7]:
# Training stream with reward-driven sampling; all knobs come from PARAMS.
train_dataset = AdaptiveDataset(
    train_df,
    num_generations=PARAMS["NUM_GENERATIONS"],
    alpha=PARAMS["ALPHA"],
    difficulty_factor=PARAMS["DIFFICULTY_FACTOR"],
    variance_factor=PARAMS["VARIANCE_FACTOR"],
    min_weight=PARAMS["MIN_WEIGHT"],
)

# Held-out problems wrapped as a map-style dataset.
valid_dataset = ValidationDataset(test_df)

In [8]:
def extract_boxed_text(text):
    """Return the content of the last non-empty ``\\boxed{...}`` in `text`.

    Braces are matched by depth, so nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole and the content may span
    several lines. A backslash escapes the character after it, so ``\\{`` and
    ``\\}`` do not affect the depth. A ``\\boxed{`` whose closing brace is
    missing (for example a truncated completion) is ignored.

    Args:
        text: String to search.

    Returns:
        The stripped content of the last ``\\boxed{...}`` whose content is not
        blank, or ``""`` when there is none.
    """
    marker = "\\boxed{"
    contents = []
    length = len(text)
    pos = text.find(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1
        i = start
        while i < length and depth > 0:
            ch = text[i]
            if ch == "\\":
                # Skip the escaped character so `\{` / `\}` are not counted.
                i += 2
                continue
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            i += 1
        if depth == 0:
            contents.append(text[start:i - 1])
            pos = text.find(marker, i)
        else:
            # Unbalanced: drop this occurrence but keep looking inside it, a
            # later `\boxed{...}` may still be complete.
            pos = text.find(marker, start)

    for content in reversed(contents):
        stripped = content.strip()
        if stripped:
            return stripped
    return ""

def accuracy_reward_func(completions, answer, index, **kwargs):
    """Score completions by exact match and feed the results to the sampler.

    A completion earns 1.0 when the text extracted from its last
    ``\\boxed{...}`` equals the reference answer as a string, else 0.0. The
    rewards are then split into consecutive groups of
    ``PARAMS["NUM_GENERATIONS"]`` and passed to `train_dataset.update_weights`,
    each group keyed by the dataset index of its first element.

    Args:
        completions: Generated completion strings.
        answer: Reference answers, aligned with `completions`.
        index: Dataset indices, aligned with `completions`.
        **kwargs: Ignored.

    Returns:
        List of float rewards, one per completion.
    """
    rewards = [
        1.0 if str(extract_boxed_text(comp)) == str(val) else 0.0
        for comp, val in zip(completions, answer)
    ]

    group_size = PARAMS["NUM_GENERATIONS"]
    # Group by the batch actually received instead of the configured batch
    # size: the two differ whenever the trainer hands over more or fewer
    # samples than PARAMS["TRAIN_BATCH_SIZE"], which previously dropped groups
    # or raised IndexError. An incomplete trailing group is skipped.
    rewards_buffer = [
        (index[start], rewards[start:start + group_size])
        for start in range(0, len(rewards) - group_size + 1, group_size)
    ]
    train_dataset.update_weights(rewards_buffer)
    return rewards

In [9]:
from trl import GRPOConfig, GRPOTrainer
from transformers import PrinterCallback


# GRPO hyper-parameters; numeric values are taken from PARAMS.
training_args = GRPOConfig(
    # optimisation
    learning_rate=PARAMS["LEARNING_RATE"],
    beta=PARAMS["BETA"],
    max_steps=PARAMS["MAX_STEPS"],
    # batching
    per_device_train_batch_size=PARAMS["TRAIN_BATCH_SIZE"],
    gradient_accumulation_steps=1,
    # generation
    num_generations=PARAMS["NUM_GENERATIONS"],
    max_completion_length=PARAMS["MAX_COMPLETION_LENGTH"],
    # logging / infrastructure
    logging_steps=1,
    report_to="none",
    accelerator_config={"dispatch_batches": False},
)

# The LoRA config is handed to the trainer together with the model loaded above.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    reward_funcs=accuracy_reward_func,
    peft_config=lora_config,
    callbacks=[PrinterCallback()],
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# Launch GRPO training with the trainer configured in the previous cell.
trainer.train()

Step,Training Loss


{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 0.0002, 'num_tokens': 4780.0, 'completions/mean_length': 1024.0, 'completions/min_length': 1024.0, 'completions/max_length': 1024.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/accuracy_reward_func/mean': 0.0, 'rewards/accuracy_reward_func/std': 0.0, 'reward': 0.0, 'reward_std': 0.0, 'kl': 0.0, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 6.103515625e-05}
