# RL's Razor Replication

Tasks (base model + dataset):

- **Math Reasoning**: Qwen 2.5 3B + Open-Reasoner-Zero
- **Science Q&A**: Qwen 2.5 3B + SciKnowEval Chemistry
- **Tool Use**: Qwen 2.5 3B + ToolAlpaca


In [None]:
!pip install transformers datasets trl lm_eval langdetect



## Setup

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from trl import SFTTrainer
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

# Allow code evaluation for metrics
os.environ["HF_ALLOW_CODE_EVAL"] = "1"

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cpu


In [None]:
MODEL_NAME = "openai-community/gpt2" # Changed to a publicly available model

## Hyperparameter Sweep

In [None]:
sft_config = {
    'learning_rates': [1e-5, 3e-5, 5e-5, 7e-5, 9e-5],
    'batch_sizes': [16, 32, 64, 128],
    'epochs': [1, 2],
    'lr_scheduler': ['constant_with_warmup', 'cosine_with_warmup'],
    'warmup_steps': 50,
    'optimizer': 'adamw',
    'max_grad_norm': 1.0,
    'weight_decay': 0,
    'bf16': True,
}
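
The grid above spans 5 × 4 × 2 × 2 = 80 SFT configurations. A minimal sketch of enumerating them, assuming every list-valued entry is swept and every scalar entry is held fixed (the helper name `expand_grid` is hypothetical and not part of the notebook):

```python
from itertools import product

def expand_grid(config):
    """Yield one flat dict per combination of the list-valued entries."""
    swept = {k: v for k, v in config.items() if isinstance(v, list)}
    fixed = {k: v for k, v in config.items() if not isinstance(v, list)}
    keys = list(swept)
    for values in product(*(swept[k] for k in keys)):
        yield {**fixed, **dict(zip(keys, values))}

# Copy of part of the SFT grid, repeated here so the sketch runs on its own.
example_grid = {
    'learning_rates': [1e-5, 3e-5, 5e-5, 7e-5, 9e-5],
    'batch_sizes': [16, 32, 64, 128],
    'epochs': [1, 2],
    'lr_scheduler': ['constant_with_warmup', 'cosine_with_warmup'],
    'warmup_steps': 50,
}

runs = list(expand_grid(example_grid))
print(len(runs))
print(runs[0])
```

Each yielded dict keeps the original key names (for example `learning_rates`), so a caller would read one value per key rather than a list.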

In [None]:
rl_config = {
    'learning_rates': [1e-5, 2e-5, 3e-5, 4e-5, 5e-5],
    'epochs': 1,
    'warmup_steps': 50,
    'optimizer': 'adamw',
    'max_grad_norm': 1.0,
    'weight_decay': 0,
    'bf16': True,
    'kl_reg': 0.0,  # NO explicit KL regularization
    'group_size': 64,
    'prompts_per_generation': 8,
    'num_iterations': [1, 2],
}

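## Load Base Model

The paper uses Qwen 2.5 3B-Instruct; this notebook loads `MODEL_NAME` (`openai-community/gpt2`), a smaller publicly available model.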

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # Use bfloat16
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded: {MODEL_NAME}")
print(f"Parameters: {model.num_parameters() / 1e9:.2f}B")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9:.2f}B")

`torch_dtype` is deprecated! Use `dtype` instead!

Model loaded: openai-community/gpt2
Parameters: 0.12B
Trainable parameters: 0.12B


## Datasets

The datasets follow Section 4 of the paper.

In [None]:
# Math Reasoning: Open-Reasoner-Zero
try:
    math_dataset = load_dataset("Tonic/OpenReasonerZero", split="train")
    print(f"Loaded Open-Reasoner-Zero: {len(math_dataset)} examples")
    print("Dataset columns:", math_dataset.column_names)
    print("First example:", math_dataset[0] if len(math_dataset) > 0 else 'Empty dataset')
except Exception as err:
    print(f"Warning: Open-Reasoner-Zero not available ({err}), using GSM8K")
    math_dataset = load_dataset("gsm8k", "main", split="train")

# Science Q&A: SciKnowEval Chemistry L-3
try:
    science_dataset = load_dataset("Sujal0077/sciknoweval", split="train")
    print(f"Loaded SciKnowEval: {len(science_dataset)} examples")
except Exception as err:
    print(f"Warning: SciKnowEval not available ({err}), using SciQ")
    science_dataset = load_dataset("sciq", split="train")

# Tool Use: ToolAlpaca
try:
    tool_url = "https://github.com/tangqiaoyu/ToolAlpaca/raw/main/data/train_data.json"
    tool_dataset = pd.read_json(tool_url)
    print(f"Loaded ToolAlpaca: {len(tool_dataset)} examples")
except Exception as err:
    print(f"Warning: ToolAlpaca not available ({err})")
    tool_dataset = None


Loaded Open-Reasoner-Zero: 56878 examples
Dataset columns: ['0', '1']
First example: {'0': {'from': 'human', 'value': '$P(x)$ is a polynomial of degree $3n$ such that\n\\begin{eqnarray*} P(0) = P(3) = \\cdots &=& P(3n) = 2, \\\\ P(1) = P(4) = \\cdots &=& P(3n-2) = 1, \\\\ P(2) = P(5) = \\cdots &=& P(3n-1) = 0, \\quad\\text{ and }\\\\ && P(3n+1) = 730.\\end{eqnarray*}\nDetermine $n$.'}, '1': {'from': 'assistant', 'ground_truth': {'pass_at_n': None, 'value': 'n = 4'}}}



Loaded SciKnowEval: 1000 examples
Loaded ToolAlpaca: 468 examples


## Evaluation Benchmarks
These are the benchmarks listed in Section 3 of the paper.

In [None]:
EVAL_BENCHMARKS = [
    "hellaswag",      # Zellers et al., 2019
    "truthfulqa_mc2", # Lin et al., 2021
    "mmlu",           # Hendrycks et al., 2020
    "ifeval",         # Zhou et al., 2023
    "winogrande",     # Sakaguchi et al., 2021
    "humaneval",      # Chen et al., 2021
]

## SFT Training

In [None]:
def train_sft(model, dataset, tokenizer, learning_rate=3e-5, batch_size=32, epochs=1):

    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()

    # First, format the dataset to create a 'text' field
    def format_dataset(examples):
        # Convert the nested structure to text format
        texts = []
        for i in range(len(examples['0'])):
            question = examples['0'][i]['value']

            # Get answer from ground_truth if available
            try:
                answer = examples['1'][i]['ground_truth']['value']
            except (KeyError, TypeError):
                answer = str(examples['1'][i])

            # Format as conversation
            text = f"Question: {question}\nAnswer: {answer}"
            texts.append(text)

        return {'text': texts}

    # Apply formatting
    formatted_dataset = dataset.map(format_dataset, batched=True, remove_columns=dataset.column_names)

    # Select 300 examples for a small run
    formatted_dataset = formatted_dataset.select(range(min(300, len(formatted_dataset))))
    print(f"Using {len(formatted_dataset)} examples for training")

    training_args = TrainingArguments(
        output_dir=f"./sft_lr{learning_rate}_bs{batch_size}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        lr_scheduler_type="constant_with_warmup",
        warmup_steps=50,
        bf16=True,
        max_grad_norm=1.0,
        weight_decay=0,
        logging_steps=10,
        save_strategy="epoch",
        optim="adamw_torch",
        report_to="none",
        gradient_checkpointing=True,
    )

    # Define formatting function for SFTTrainer
    def formatting_func(examples):
        return examples['text']

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=formatted_dataset,
        processing_class=tokenizer,
        formatting_func=formatting_func,
    )

    print(f"Training SFT (lr={learning_rate}, bs={batch_size}, epochs={epochs})...")
    trainer.train()

    return model, trainer
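
With the 300-example subset, the warmup length matters more than it may seem. A rough sketch of the arithmetic, assuming a single device and that one optimizer step consumes `batch_size * gradient_accumulation_steps` examples (the helper names are hypothetical, and the exact step count can differ by one depending on how the trainer rounds the last partial accumulation):

```python
import math

def optimizer_steps(num_examples, batch_size, grad_accum, epochs=1):
    """Approximate number of optimizer updates for one training run."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    return math.ceil(batches_per_epoch / grad_accum) * epochs

def warmup_fraction_reached(total_steps, warmup_steps):
    """Fraction of the target learning rate reached under linear warmup."""
    return min(1.0, total_steps / warmup_steps)

steps = optimizer_steps(num_examples=300, batch_size=2, grad_accum=4)
print(steps)                               # 38
print(warmup_fraction_reached(steps, 50))  # 0.76
```

Under these assumptions the run ends before the 50 warmup steps complete, so the learning rate never reaches its nominal value. For small runs it may be worth scaling `warmup_steps` down with the dataset size.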

## Answer Checking for RL Rewards

In [None]:
def check_answer_correctness(predicted_answer, ground_truth_answer):

    import re

    def extract_number(text):
        # Extract the final numerical answer from text and remove common answer prefixes
        text = text.lower()
        text = re.sub(r'(the answer is|therefore|thus|so|final answer:)', '', text)

        # Find integers and decimals (a fraction such as 3/4 is not parsed as one value)
        numbers = re.findall(r'-?\d+\.?\d*', text)

        if numbers:
            return float(numbers[-1])  # Return last number found
        return None

    def normalize_text(text):
        # Normalize text for string comparison
        text = str(text).lower().strip()
        # Remove punctuation
        text = re.sub(r'[^\w\s]', '', text)
        # Remove extra whitespace
        text = ' '.join(text.split())
        return text

    # Try numerical comparison first (for math problems)
    pred_num = extract_number(str(predicted_answer))
    true_num = extract_number(str(ground_truth_answer))

    if pred_num is not None and true_num is not None:
        # Allow small numerical tolerance
        return abs(pred_num - true_num) < 1e-4

    # Fall back to string matching
    pred_normalized = normalize_text(predicted_answer)
    true_normalized = normalize_text(ground_truth_answer)

    # An empty string is a substring of everything, so reject it explicitly
    if not pred_normalized or not true_normalized:
        return False

    # Check if one contains the other (handles different formatting)
    return (pred_normalized in true_normalized or
            true_normalized in pred_normalized)
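
The number regex in `extract_number` matches integers and decimals only, so `3/4` is read as `4` and `1,234` as `234`. A minimal sketch of a stricter extractor, assuming answers are plain integers, decimals, simple fractions, or comma-grouped numbers (`extract_last_number` is a hypothetical helper, not the notebook's function):

```python
import re
from fractions import Fraction

# Order matters: try fractions first, then comma-grouped numbers, then plain ones.
_NUMBER = re.compile(
    r'-?\d+\s*/\s*\d+|-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?'
)

def extract_last_number(text):
    """Return the last number in `text` as a float, or None if there is none."""
    matches = _NUMBER.findall(str(text))
    if not matches:
        return None
    token = matches[-1].replace(',', '').replace(' ', '')
    if '/' in token:
        try:
            return float(Fraction(token))
        except ZeroDivisionError:
            return None
    return float(token)

print(extract_last_number("the answer is 3/4"))    # 0.75
print(extract_last_number("total: 1,234 apples"))  # 1234.0
print(extract_last_number("n = 4"))                # 4.0
print(extract_last_number("no digits here"))       # None
```

This still takes the last number in the text, as the original does, so a trailing unit or step count can shadow the real answer.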

## RL Training
GRPO (Group Relative Policy Optimization)

In [None]:
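def train_grpo(model, dataset, learning_rate=2e-5):

    from trl import GRPOConfig, GRPOTrainer

    # GRPOTrainer expects a 'prompt' column; extra columns (here 'answer') are
    # forwarded to the reward function. The mapping assumes the
    # Open-Reasoner-Zero schema printed in the Datasets section.
    def to_prompt(example):
        return {
            'prompt': f"Question: {example['0']['value']}\nAnswer:",
            'answer': str(example['1']['ground_truth']['value']),
        }

    grpo_dataset = dataset.map(to_prompt, remove_columns=dataset.column_names)

    # Argument names follow recent TRL releases and may differ in other versions
    grpo_config = GRPOConfig(
        output_dir=f"./grpo_lr{learning_rate}",
        num_train_epochs=1,
        per_device_train_batch_size=64,
        num_generations=64,  # group_size from paper
        learning_rate=learning_rate,
        lr_scheduler_type="constant_with_warmup",
        warmup_steps=50,
        bf16=True,
        max_grad_norm=1.0,
        beta=0.0,  # NO KL regularization (key detail!)
        logging_steps=10,
        report_to="none",
    )

    def reward_fn(prompts, completions, answer, **kwargs):
        # Binary reward: 1 if correct, 0 if incorrect
        return [
            1.0 if check_answer_correctness(completion, truth) else 0.0
            for completion, truth in zip(completions, answer)
        ]

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_fn,
        args=grpo_config,
        train_dataset=grpo_dataset,
        processing_class=tokenizer,
    )

    print(f"Training GRPO (lr={learning_rate})...")
    trainer.train()

    return model, trainer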
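
GRPO scores each sampled completion relative to the other completions for the same prompt. A minimal pure-Python sketch of that normalization for binary rewards, assuming the common mean and standard-deviation form (the function name and the `eps` constant are illustrative choices, not values taken from the paper or from TRL):

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one prompt's group of completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# One prompt, four sampled completions, one of them correct.
mixed = group_relative_advantages([1.0, 0.0, 0.0, 0.0])
print(mixed)

# If every completion gets the same reward, all advantages are zero,
# so that prompt contributes no learning signal.
uniform = group_relative_advantages([0.0, 0.0, 0.0, 0.0])
print(uniform)
```

With a binary reward, prompts that the model always solves or never solves fall into the second case, which is one reason a large group size helps.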

## Evaluation Functions

In [None]:
def evaluate_benchmarks(model, tokenizer, tasks=EVAL_BENCHMARKS, limit=None):

    from lm_eval import evaluator
    from lm_eval.models.huggingface import HFLM

    # FIXED: Set max_length to accommodate both input and generation
    lm = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        device="cuda",
        max_length=2048  # Increase to 2048 so there's room for 256 generation + input
    )

    # Set a reasonable value for max_gen_toks that is less than the model's max_length
    max_gen_toks = 256  # Leave room for input (2048 - 256 = 1792 for input)

    results = evaluator.simple_evaluate(
        model=lm,
        tasks=tasks,
        num_fewshot=0,
        limit=limit,
        confirm_run_unsafe_code=True,
        gen_kwargs={"max_new_tokens": max_gen_toks}
    )

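    # Extract one headline metric per task. Recent lm-eval releases suffix
    # metric keys with the filter name (e.g. 'acc,none'), so match on the part
    # before the comma.
    preferred = ('acc', 'acc_norm', 'accuracy', 'prompt_level_strict_acc', 'pass@1')
    scores = {}
    for task in tasks:
        task_result = results['results'].get(task, {})
        by_name = {key.split(',')[0]: value for key, value in task_result.items()}
        for name in preferred:
            if isinstance(by_name.get(name), (int, float)):
                scores[task] = by_name[name]
                break

    # Avoid averaging an empty list when no metric was found
    scores['average'] = float(np.mean(list(scores.values()))) if scores else float('nan')
    return scores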

def compute_forward_kl(model, base_model, dataset, tokenizer, num_samples=100):

    import torch.nn.functional as F

    model.eval()
    base_model.eval()

    total_kl = 0.0
    count = 0

    indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)

    with torch.no_grad():
        for idx in tqdm(indices, desc="Computing KL"):
            example = dataset[int(idx)]
            # Use a 'text' column when present; otherwise fall back to the
            # question field of the Open-Reasoner-Zero schema.
            text = example['text'] if 'text' in example else example['0']['value']
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Log-probabilities in float32 for numerical stability
            base_logp = F.log_softmax(base_model(**inputs).logits.float(), dim=-1)
            model_logp = F.log_softmax(model(**inputs).logits.float(), dim=-1)

            # Forward KL(base || model): summed over the vocabulary, then
            # averaged over token positions so long texts do not dominate
            kl = (base_logp.exp() * (base_logp - model_logp)).sum(dim=-1).mean()

            total_kl += kl.item()
            count += 1

    return total_kl / max(count, 1)
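
Forward KL here means KL(base || model): the expectation is taken under the base model's distribution. A toy sketch on a three-token vocabulary, using made-up probabilities, shows that the direction matters:

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as probability lists.

    Assumes q is strictly positive wherever p is positive.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

base = [0.7, 0.2, 0.1]    # hypothetical base-model next-token probabilities
tuned = [0.4, 0.4, 0.2]   # hypothetical fine-tuned probabilities

forward_kl = kl_divergence(base, tuned)   # KL(base || tuned)
reverse_kl = kl_divergence(tuned, base)   # KL(tuned || base)
print(f"forward={forward_kl:.4f} reverse={reverse_kl:.4f}")
```

The two values differ, so any comparison between runs should state which direction was measured.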

## Experiment

In [None]:
def run_full_experiment(dataset_name="math"):

    import gc

    # Set memory management
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Clear GPU memory before starting
    torch.cuda.empty_cache()
    gc.collect()

    # Select dataset
    datasets_by_name = {"math": math_dataset, "science": science_dataset, "tool": tool_dataset}
    if dataset_name not in datasets_by_name:
        raise ValueError(f"Unknown dataset_name: {dataset_name!r}")
    dataset = datasets_by_name[dataset_name]

    # Load base model ONCE
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    results = {
        'sft': [],
        'rl': [],
    }

    # SFT sweep (as in Table 2)
    print("\n" + "="*70)
    print("RUNNING SFT HYPERPARAMETER SWEEP")
    print("="*70)

    for lr in sft_config['learning_rates']:
        for bs in [2, 4]:  # Smaller batch sizes

            # Clear memory before loading new model
            torch.cuda.empty_cache()
            gc.collect()

            print(f"\nLoading fresh model for lr={lr}, bs={bs}...")

            # Load a fresh copy of the base model
            sft_model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

            # Train
            sft_model, trainer = train_sft(sft_model, dataset, tokenizer, learning_rate=lr, batch_size=bs)

            # Evaluate
            prior_scores = evaluate_benchmarks(sft_model, tokenizer)
            kl_div = compute_forward_kl(sft_model, base_model, dataset, tokenizer)

            results['sft'].append({
                'lr': lr,
                'batch_size': bs,
                'prior_task_score': prior_scores['average'],
                'kl_divergence': kl_div,
                'detailed_scores': prior_scores,
            })

            print(f"SFT lr={lr}, bs={bs}: Prior={prior_scores['average']:.4f}, KL={kl_div:.4f}")

            # CRITICAL: Delete model and trainer immediately after use
            del sft_model
            del trainer
            torch.cuda.empty_cache()
            gc.collect()

            print(f"Memory freed. GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    # Delete base model before RL sweep
    print("\nDeleting base model before RL sweep...")
    del base_model
    torch.cuda.empty_cache()
    gc.collect()

    # Reload base model for RL
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # RL sweep (as in Table 2)
    print("\n" + "="*70)
    print("RUNNING RL (GRPO) HYPERPARAMETER SWEEP")
    print("="*70)

    for lr in rl_config['learning_rates']:

        # Clear memory before loading new model
        torch.cuda.empty_cache()
        gc.collect()

        # Load a fresh copy of the base model
        rl_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Train
        rl_model, trainer = train_grpo(rl_model, dataset, learning_rate=lr)

        # Evaluate
        prior_scores = evaluate_benchmarks(rl_model, tokenizer)
        kl_div = compute_forward_kl(rl_model, base_model, dataset, tokenizer)

        results['rl'].append({
            'lr': lr,
            'prior_task_score': prior_scores['average'],
            'kl_divergence': kl_div,
            'detailed_scores': prior_scores,
        })

        print(f"RL lr={lr}: Prior={prior_scores['average']:.4f}, KL={kl_div:.4f}")

        # Delete model immediately
        del rl_model
        del trainer
        torch.cuda.empty_cache()
        gc.collect()

    # Save results
    import json
    with open(f'results_{dataset_name}.json', 'w') as f:
        json.dump(results, f, indent=2)

    # Final cleanup
    del base_model
    torch.cuda.empty_cache()
    gc.collect()

    return results

## Visualization

In [None]:
def plot_results(results):
    import matplotlib.pyplot as plt

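    sft_prior = [r['prior_task_score'] for r in results['sft']]
    rl_prior = [r['prior_task_score'] for r in results['rl']]

    # New Task vs Prior Task (Pareto frontier). run_full_experiment does not
    # record a new-task score, so this figure is drawn only when every run
    # carries a 'new_task_score' entry (a hypothetical key, not set above).
    all_runs = results['sft'] + results['rl']
    if all_runs and all('new_task_score' in r for r in all_runs):
        plt.figure(figsize=(10, 6))
        plt.scatter(sft_prior, [r['new_task_score'] for r in results['sft']], label='SFT', alpha=0.6, s=50)
        plt.scatter(rl_prior, [r['new_task_score'] for r in results['rl']], label='RL', alpha=0.6, s=50)
        plt.xlabel('Prior Task Performance', fontsize=12)
        plt.ylabel('New Task Performance', fontsize=12)
        plt.title('Learning-Forgetting Trade-off', fontsize=14, fontweight='bold')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig('figure2_replication.png', dpi=150)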

    # KL vs Forgetting
    plt.figure(figsize=(10, 6))

    sft_kl = [r['kl_divergence'] for r in results['sft']]
    rl_kl = [r['kl_divergence'] for r in results['rl']]

    plt.scatter(sft_kl, sft_prior, label='SFT', alpha=0.6, s=50)
    plt.scatter(rl_kl, rl_prior, label='RL', alpha=0.6, s=50)

    plt.xlabel('KL Divergence', fontsize=12)
    plt.ylabel('Prior Task Performance', fontsize=12)
    plt.title('KL Predicts Forgetting', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('figure3_replication.png', dpi=150)

    plt.show()
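
The first figure is described as a Pareto frontier. A minimal sketch of extracting the non-dominated runs from (prior-task, new-task) score pairs, assuming higher is better on both axes (the scores below are made up for illustration):

```python
def pareto_frontier(points):
    """Return the points not dominated by any other point (higher is better)."""
    frontier = []
    for p in points:
        dominated = any(
            q[0] >= p[0] and q[1] >= p[1] and q != p
            for q in points
        )
        if not dominated:
            frontier.append(p)
    return sorted(frontier)

# (prior_task_score, new_task_score) pairs, hypothetical values
example_runs = [(0.60, 0.30), (0.55, 0.45), (0.50, 0.40), (0.40, 0.70), (0.58, 0.20)]
frontier = pareto_frontier(example_runs)
print(frontier)
```

Computing one frontier for the SFT runs and one for the RL runs would give the two curves that the trade-off plot is meant to compare.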

In [None]:
if __name__ == "__main__":
    print("\n" + "="*70)
    print("RL'S RAZOR - EXACT PAPER REPLICATION")
    print("="*70)
    print("\nConfiguration:")
    print(f"  Model: {MODEL_NAME}")
    print(f"  Full fine-tuning: YES (no LoRA)")
    print(f"  KL regularization: NO (kl_coef=0.0)")
    print(f"  Hyperparameters: Exactly from Table 2")
    print("="*70 + "\n")

    # Run experiment on Math dataset
    results = run_full_experiment(dataset_name="math")

    # Create visualizations
    plot_results(results)

    print("\n" + "="*70)
    print("EXPERIMENT COMPLETE")
    print("="*70)
    print("\nKey Findings:")
    print(f"  • RL average prior score: {np.mean([r['prior_task_score'] for r in results['rl']]):.4f}")
    print(f"  • SFT average prior score: {np.mean([r['prior_task_score'] for r in results['sft']]):.4f}")
    print(f"  • RL average KL: {np.mean([r['kl_divergence'] for r in results['rl']]):.4f}")
    print(f"  • SFT average KL: {np.mean([r['kl_divergence'] for r in results['sft']]):.4f}")
    print("\nCompare these averages with the paper's claim that RL forgets less and has smaller KL shifts.")


RL'S RAZOR - EXACT PAPER REPLICATION

Configuration:
  Model: openai-community/gpt2
  Full fine-tuning: YES (no LoRA)
  KL regularization: NO (kl_coef=0.0)
  Hyperparameters: Exactly from Table 2






RUNNING SFT HYPERPARAMETER SWEEP

Loading fresh model for lr=1e-05, bs=2...


The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.


Using 300 examples for training
Training SFT (lr=1e-05, bs=2, epochs=1)...


Step,Training Loss
10,3.1429
20,3.1749
30,3.2301




## Run Experiment

## Copying the Notebook to GitHub

In [2]:
import json
from google.colab import _message

# get current notebook (the one that's open)
nb = _message.blocking_request('get_ipynb')['ipynb']

# remove bad widget metadata
for cell in nb.get("cells", []):
    cell.get("metadata", {}).pop("widgets", None)
nb.get("metadata", {}).pop("widgets", None)

# save with your name, but cleaned
clean_path = "/content/c_rls_razor_replication_clean.ipynb"
with open(clean_path, "w", encoding="utf-8") as f:
    json.dump(nb, f, indent=1)

print("Saved:", clean_path)

Saved: /content/c_rls_razor_replication_clean.ipynb
