<a href="https://colab.research.google.com/github/Nathan22ux/nnsj-rl-razor-experiment/blob/main/rls_razor_replication.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RL's Razor Replication

Datasets:

- **Math Reasoning**: Qwen 2.5 3B + Open-Reasoner-Zero
- **Science Q&A**: Qwen 2.5 3B + SciKnowEval Chemistry
- **Tool Use**: Qwen 2.5 3B + ToolAlpaca


In [None]:
!pip install transformers datasets trl lm_eval langdetect

## Setup

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from trl import SFTTrainer
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

# Allow code evaluation for metrics (sets the HF_ALLOW_CODE_EVAL flag for this process)
os.environ.update({"HF_ALLOW_CODE_EVAL": "1"})

# Choose the compute device once, then report it; GPU name and memory are printed
# only when CUDA is actually available.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Using device: {device}")
if cuda_available:
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_properties.total_memory / 1e9:.2f} GB")

In [None]:
MODEL_NAME = "openai-community/gpt2" # Small stand-in model used for every load in this notebook. NOTE(review): the original comment named LLAMA-3B as the target, while the notebook headings name Qwen 2.5 3B - confirm which is intended

## Hyperparameters Sweep
The values below follow Table 2 (page 19) of the RL's Razor paper.

In [None]:
# SFT hyperparameter grid.
# In the code visible in this notebook only `learning_rates` is read (by the sweep in
# run_full_experiment); the remaining entries are recorded for reference, and
# train_sft sets its own scheduler, warmup, optimizer and precision.
sft_config = dict(
    learning_rates=[1e-5, 3e-5, 5e-5, 7e-5, 9e-5],
    batch_sizes=[16, 32, 64, 128],
    epochs=[1, 2],
    lr_scheduler=['constant_with_warmup', 'cosine_with_warmup'],
    warmup_steps=50,
    optimizer='adamw',
    max_grad_norm=1.0,
    weight_decay=0,
    bf16=True,
)

In [None]:
# RL (GRPO) hyperparameter grid.
# In the code visible in this notebook only `learning_rates` is read (by the sweep in
# run_full_experiment); the remaining entries are recorded for reference.
rl_config = dict(
    learning_rates=[1e-5, 2e-5, 3e-5, 4e-5, 5e-5],
    epochs=1,
    warmup_steps=50,
    optimizer='adamw',
    max_grad_norm=1.0,
    weight_decay=0,
    bf16=True,
    kl_reg=0.0,  # NO explicit KL regularization
    group_size=64,
    prompts_per_generation=8,
    num_iterations=[1, 2],
)

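## Load Base Model

Target model: Qwen 2.5 3B-Instruct. The cell below loads whatever `MODEL_NAME` is set to, which is currently a smaller GPT-2 stand-in.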

In [None]:
# Load the causal LM named by MODEL_NAME in bfloat16; device_map="auto" leaves
# device placement to the library.
model_load_options = dict(
    torch_dtype=torch.bfloat16,  # Use bfloat16
    device_map="auto",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_load_options)

# Matching tokenizer; the EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Report total and trainable parameter counts in billions.
total_parameters = model.num_parameters()
trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model loaded: {MODEL_NAME}")
print(f"Parameters: {total_parameters / 1e9:.2f}B")
print(f"Trainable parameters: {trainable_parameters / 1e9:.2f}B")

## Datasets

The datasets follow the dataset description in Section 4 of the paper.

In [None]:
def _gsm8k_to_nested_schema(example):
    """Re-shape one GSM8K row into the nested '0'/'1' layout the formatters read.

    The formatting helpers in this notebook read the question from
    example['0']['value'] and the answer from example['1']['ground_truth']['value'],
    so the flat GSM8K `question` / `answer` columns are wrapped accordingly.
    """
    return {
        '0': {'value': example['question']},
        '1': {'ground_truth': {'value': example['answer']}},
    }


# Math Reasoning: Open-Reasoner-Zero
try:
    math_dataset = load_dataset("Tonic/OpenReasonerZero", split="train")
    print(f"Loaded Open-Reasoner-Zero: {len(math_dataset)} examples")
    # Check dataset values
    print("Dataset columns:", math_dataset.column_names if hasattr(math_dataset, 'column_names') else 'N/A')
    print("First example:", math_dataset[0] if len(math_dataset) > 0 else 'Empty dataset')
except Exception as exc:
    # `except Exception` (not a bare except) so Ctrl-C still interrupts the cell, and the
    # real cause is printed instead of being hidden.
    print(f"Warning: Open-Reasoner-Zero not available ({exc!r}), using GSM8K")
    gsm8k_raw = load_dataset("gsm8k", "main", split="train")
    # Convert to the nested layout so the fallback works with train_sft / train_grpo /
    # run_full_experiment, which index columns '0' and '1'.
    math_dataset = gsm8k_raw.map(_gsm8k_to_nested_schema, remove_columns=gsm8k_raw.column_names)

# Science Q&A: SciKnowEval Chemistry L-3
# NOTE(review): neither science source is converted to the '0'/'1' layout here; confirm
# their columns before running the "science" experiment.
try:
    science_dataset = load_dataset("Sujal0077/sciknoweval", split="train")
    print(f"Loaded SciKnowEval: {len(science_dataset)} examples")
except Exception as exc:
    print(f"Warning: SciKnowEval not available ({exc!r}), using SciQ")
    science_dataset = load_dataset("sciq", split="train")

# Tool Use: ToolAlpaca
# NOTE(review): this yields a pandas DataFrame (or None on failure), not a
# datasets.Dataset, so it cannot be passed to the `.map(...)`-based training code as is.
try:
    tool_url = "https://github.com/tangqiaoyu/ToolAlpaca/raw/main/data/train_data.json"
    tool_dataset = pd.read_json(tool_url)
    print(f"Loaded ToolAlpaca: {len(tool_dataset)} examples")
except Exception as exc:
    print(f"Warning: ToolAlpaca not available ({exc!r})")
    tool_dataset = None

## EVALUATION BENCHMARKS
The benchmarks are described in Section 3 of the paper.

In [None]:
# Benchmark tasks passed to the evaluation harness by default.
#
# Candidate list kept for reference (with citations):
#   hellaswag       - Zellers et al., 2019
#   truthfulqa_mc2  - Lin et al., 2021
#   mmlu            - Hendrycks et al., 2020
#   ifeval          - Zhou et al., 2023
#   winogrande      - Sakaguchi et al., 2021
#   humaneval       - Chen et al., 2021
# Only the reduced set below is active, to limit evaluation compute on a constrained GPU.
EVAL_BENCHMARKS = [
    "winogrande",
    "hellaswag",
    "mmlu_high_school_mathematics",
    "mmlu_high_school_computer_science",
]

## SFT Training

In [None]:
def train_sft(model, dataset, tokenizer, learning_rate=3e-5, batch_size=32, epochs=1,
              max_examples=300, gradient_accumulation_steps=4, warmup_steps=50):
    """Supervised fine-tuning of `model` on question/answer text.

    Each row is rendered as "Question: ...\\nAnswer: ..." and trained with
    trl's SFTTrainer.

    Args:
        model: Causal LM to fine-tune; it is trained in place.
        dataset: Dataset whose rows expose the question at row['0']['value'] and the
            answer at row['1']['ground_truth']['value'] (any row where that lookup
            fails falls back to str(row['1'])).
        tokenizer: Tokenizer handed to SFTTrainer as `processing_class`.
        learning_rate: Peak learning rate.
        batch_size: Per-device batch size. The effective batch per optimizer step is
            batch_size * gradient_accumulation_steps.
        epochs: Number of training epochs.
        max_examples: Keep only the first `max_examples` rows (None keeps all rows).
            The default of 300 is a small-run setting for limited GPUs.
        gradient_accumulation_steps: Micro-batches accumulated per optimizer step.
        warmup_steps: Linear warmup length in optimizer steps.

    Returns:
        (model, trainer): the trained model and the SFTTrainer that trained it.
    """
    import math

    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()

    def format_dataset(examples):
        """Render a batch of nested rows as {'text': [...]} prompt/answer strings."""
        texts = []
        for question_turn, answer_turn in zip(examples['0'], examples['1']):
            question = question_turn['value']

            # Prefer the nested ground-truth value; otherwise stringify the whole entry
            try:
                answer = answer_turn['ground_truth']['value']
            except (KeyError, TypeError):
                answer = str(answer_turn)

            texts.append(f"Question: {question}\nAnswer: {answer}")

        return {'text': texts}

    # Subsample BEFORE formatting so only the rows that are actually trained on get
    # processed (the previous order formatted the entire dataset and then kept 300).
    num_examples = len(dataset) if max_examples is None else min(max_examples, len(dataset))
    subset = dataset.select(range(num_examples))
    formatted_dataset = subset.map(format_dataset, batched=True, remove_columns=subset.column_names)
    print(f"Using {len(formatted_dataset)} examples for training")

    # With few examples the fixed warmup can cover the whole run, in which case the
    # requested learning rate is never reached and a learning-rate sweep is distorted.
    # The estimate assumes a single training process.
    approx_total_steps = math.ceil(
        math.ceil(num_examples / batch_size) / gradient_accumulation_steps * epochs
    )
    if warmup_steps >= approx_total_steps:
        print(
            f"Warning: warmup_steps={warmup_steps} >= ~{approx_total_steps} optimizer steps; "
            f"the learning rate will not reach {learning_rate} during this run"
        )

    training_args = TrainingArguments(
        output_dir=f"./sft_lr{learning_rate}_bs{batch_size}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type="constant_with_warmup",
        warmup_steps=warmup_steps,
        bf16=True,
        max_grad_norm=1.0,
        weight_decay=0,
        logging_steps=10,
        save_strategy="epoch",
        optim="adamw_torch",
        report_to="none",
        gradient_checkpointing=True,
    )

    # Define formatting function for SFTTrainer: the text is already built above
    def formatting_func(examples):
        return examples['text']

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=formatted_dataset,
        processing_class=tokenizer,
        formatting_func=formatting_func,
    )

    print(f"Training SFT (lr={learning_rate}, bs={batch_size}, epochs={epochs})...")
    trainer.train()

    return model, trainer

# Check Answer for RL Function

In [None]:
def check_answer_correctness(predicted_answer, ground_truth_answer):
    """Decide whether a predicted answer matches the ground truth.

    Comparison is done in two stages:

    1. Numeric: if both inputs contain a number, the LAST number in each is compared
       with an absolute tolerance of 1e-4. Thousands separators ("1,000") and simple
       fractions ("1/2") are understood. LaTeX fractions such as \\frac{1}{2} are not.
    2. Text: otherwise both inputs are lower-cased, stripped of punctuation and
       whitespace-normalised, and they match when one appears in the other as a run
       of whole words.

    An empty (or punctuation-only) prediction or ground truth never matches in the
    text stage.

    Args:
        predicted_answer: Model output; converted with str().
        ground_truth_answer: Reference answer; converted with str().

    Returns:
        bool: True when the answers are considered equal.
    """
    import re

    # An unsigned decimal: "12", "12.5", or ".5" (the latter only when not glued to a
    # preceding digit, so "1.2.3" still ends in the number 3).
    unsigned = r'(?:\d+(?:\.\d+)?|(?<!\d)\.\d+)'
    # Optional sign, a number, and optionally "/ denominator".
    number_re = re.compile(rf'(-?{unsigned})(?:\s*/\s*({unsigned}))?')

    def extract_number(text):
        """Return the last number in `text` as a float, or None if there is none."""
        # Drop thousands separators: a comma between a digit and exactly three digits.
        text = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', text)

        matches = number_re.findall(text)
        if not matches:
            return None

        numerator, denominator = matches[-1]
        if not denominator:
            return float(numerator)
        if float(denominator) == 0:
            # "a/0" is not a fraction; treat the trailing 0 as the last number.
            return float(denominator)
        return float(numerator) / float(denominator)

    def normalize_text(text):
        """Lower-case, strip punctuation and collapse whitespace."""
        text = str(text).lower().strip()
        text = re.sub(r'[^\w\s]', '', text)
        return ' '.join(text.split())

    # Try numerical comparison first (for math problems)
    pred_num = extract_number(str(predicted_answer))
    true_num = extract_number(str(ground_truth_answer))

    if pred_num is not None and true_num is not None:
        # Allow small numerical tolerance
        return abs(pred_num - true_num) < 1e-4

    # Fall back to string matching
    pred_normalized = normalize_text(predicted_answer)
    true_normalized = normalize_text(ground_truth_answer)

    # "" is a substring of every string, so an empty side must be rejected explicitly;
    # otherwise an empty prediction would be scored as correct.
    if not pred_normalized or not true_normalized:
        return False

    # Whole-word containment in either direction (padding with spaces prevents
    # partial-word hits such as "par" matching "paris"); equality is a special case.
    padded_pred = f" {pred_normalized} "
    padded_true = f" {true_normalized} "
    return padded_pred in padded_true or padded_true in padded_pred

## RL Training
GRPO (Group Relative Policy Optimization)

In [None]:
def train_grpo(model, dataset, tokenizer, learning_rate=2e-5, max_examples=100):
    """Train `model` with GRPO using an answer-correctness reward.

    Each row becomes a prompt "Question: ...\\nAnswer:" plus its reference answer.
    Completions are rewarded 1.0 when `check_answer_correctness` accepts them
    against the reference answer and 0.0 otherwise.

    Args:
        model: Causal LM to train; it is trained in place.
        dataset: Dataset whose rows expose the question at row['0']['value'] and the
            answer at row['1']['ground_truth']['value'] (any row where that lookup
            fails falls back to str(row['1'])).
        tokenizer: Tokenizer handed to GRPOTrainer as `processing_class`.
        learning_rate: Peak learning rate.
        max_examples: Keep only the first `max_examples` rows (None keeps all rows).
            The default of 100 is a small-run setting for Colab limits.

    Returns:
        (model, trainer): the trained model and the GRPOTrainer that trained it.
    """
    from trl import GRPOConfig, GRPOTrainer

    # Format dataset for GRPO - this needs a 'prompt' field
    def format_for_grpo(examples):
        """Build {'prompt': [...], 'answer': [...]} from a batch of nested rows."""
        prompts = []
        answers = []
        for question_turn, answer_turn in zip(examples['0'], examples['1']):
            question = question_turn['value']
            try:
                answer = answer_turn['ground_truth']['value']
            except (KeyError, TypeError):
                answer = str(answer_turn)

            prompts.append(f"Question: {question}\nAnswer:")
            answers.append(answer)

        return {'prompt': prompts, 'answer': answers}

    # Subsample BEFORE formatting so only the rows that are used get processed.
    num_examples = len(dataset) if max_examples is None else min(max_examples, len(dataset))
    subset = dataset.select(range(num_examples))
    formatted_dataset = subset.map(format_for_grpo, batched=True, remove_columns=subset.column_names)
    print(f"Using {len(formatted_dataset)} examples for GRPO training")

    # NOTE(review): with this few examples the 50 warmup steps may span the entire run,
    # so the requested learning rate may never be reached - confirm against the step
    # count the trainer logs.
    grpo_config = GRPOConfig(
        output_dir=f"./grpo_lr{learning_rate}",
        num_train_epochs=1,
        per_device_train_batch_size=64,
        learning_rate=learning_rate,
        lr_scheduler_type="constant_with_warmup",
        warmup_steps=50,
        max_grad_norm=1.0,
        logging_steps=10,
        report_to="none",
        # Pin the KL coefficient to 0 (rl_config['kl_reg']: no explicit KL
        # regularization) rather than relying on the installed trl version's default.
        beta=0.0,
    )

    def reward_fn(prompts, completions, answer=None, **kwargs):
        """
        Correctness reward for GRPOTrainer.

        Args:
            prompts: List of prompt strings
            completions: List of completion strings
            answer: Reference answers, one per completion; supplied by the trainer
                from the dataset's 'answer' column
            **kwargs: Any other values the trainer passes (e.g. completion_ids); unused
        Returns:
            List of reward scores (float): 1.0 for a correct completion, else 0.0
        """
        if answer is None:
            raise ValueError("reward_fn needs the dataset's 'answer' column to score completions")

        rewards = []
        for completion, reference in zip(completions, answer):
            text = completion if isinstance(completion, str) else str(completion)
            # A reward that is identical for every completion in a group (such as
            # "1.0 if non-empty") gives GRPO no preference signal, so the reward must
            # depend on whether the answer is right.
            is_correct = bool(text.strip()) and check_answer_correctness(text, reference)
            rewards.append(1.0 if is_correct else 0.0)

        return rewards

    trainer = GRPOTrainer(
        model=model,
        args=grpo_config,
        train_dataset=formatted_dataset,
        processing_class=tokenizer,
        reward_funcs=[reward_fn],
    )

    print(f"Training GRPO (lr={learning_rate})...")
    trainer.train()

    return model, trainer

## Evaluation Functions

In [None]:
def _pick_accuracy(task, task_result):
    """Return the accuracy-style metric for one task, or None if none is present.

    Metrics are looked up in priority order `acc`, `acc_norm`, and (for ifeval only)
    `accuracy`. Each name is tried both bare and with the ",none" filter suffix that
    lm-eval 0.4.x appends to metric keys (e.g. "acc,none").
    """
    metric_names = ['acc', 'acc_norm']
    if task == 'ifeval':
        metric_names.append('accuracy')

    for metric_name in metric_names:
        for key in (metric_name, f"{metric_name},none"):
            if key in task_result:
                return task_result[key]
    return None


def evaluate_benchmarks(model, tokenizer, tasks=EVAL_BENCHMARKS, limit=300):  # Changed from limit=None to limit=300
    """Score `model` on lm-eval benchmark tasks (zero-shot).

    Args:
        model: Already-loaded causal LM to evaluate.
        tokenizer: Tokenizer matching `model`.
        tasks: lm-eval task names to run.
        limit: Maximum number of examples per task (None evaluates everything).

    Returns:
        dict mapping each task that produced an accuracy metric to its score, plus
        'average' - the mean over those tasks, or NaN when no task produced a score.
    """
    from lm_eval import evaluator
    from lm_eval.models.huggingface import HFLM

    lm = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        device="cuda",
        max_length=2048
    )

    max_gen_toks = 256

    results = evaluator.simple_evaluate(
        model=lm,
        tasks=tasks,
        num_fewshot=0,
        limit=limit,
        confirm_run_unsafe_code=True,
        gen_kwargs={"max_new_tokens": max_gen_toks}
    )

    # Extract accuracy scores
    scores = {}
    for task in tasks:
        task_result = results['results'].get(task)
        if task_result is None:
            print(f"Warning: no results returned for task '{task}'")
            continue

        accuracy = _pick_accuracy(task, task_result)
        if accuracy is None:
            # Make a silent drop visible: otherwise the average is taken over fewer
            # tasks (or none) without any hint.
            print(f"Warning: no accuracy metric found for task '{task}' (keys: {sorted(task_result)})")
            continue
        scores[task] = accuracy

    if scores:
        scores['average'] = float(np.mean(list(scores.values())))
    else:
        print("Warning: no benchmark produced a score; average is NaN")
        scores['average'] = float('nan')
    return scores

def compute_forward_kl(model, base_model, dataset, tokenizer, num_samples=100, seed=0):
    """Estimate the forward KL divergence KL(base || model) on dataset text.

    For each sampled text both models are run on the same tokens; the per-position
    KL between their next-token distributions is summed over all positions, and the
    result is averaged over the sampled texts. Because positions are summed (not
    averaged), longer texts contribute larger values.

    Args:
        model: Fine-tuned model (the second argument of the KL).
        base_model: Reference model (the first argument of the KL).
        dataset: Indexable collection whose items have a 'text' field.
        tokenizer: Tokenizer shared by both models.
        num_samples: Maximum number of texts to sample without replacement.
        seed: Seed for the sampling generator. A fixed seed makes every call score
            the same texts, so KL values of different models are comparable. Pass
            None for a fresh random subset on each call.

    Returns:
        float: mean (over texts) of the summed per-position KL.

    Raises:
        ValueError: if there are no examples to evaluate.
    """
    import torch.nn.functional as F

    model.eval()
    base_model.eval()

    sample_count = min(num_samples, len(dataset))
    if sample_count <= 0:
        raise ValueError("compute_forward_kl needs at least one example to evaluate")

    rng = np.random.default_rng(seed)
    indices = rng.choice(len(dataset), sample_count, replace=False)

    total_kl = 0.0

    with torch.no_grad():
        for idx in tqdm(indices, desc="Computing KL"):
            text = dataset[int(idx)]['text']
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)

            # Each model gets the inputs on its own device; the two may differ.
            base_inputs = {k: v.to(base_model.device) for k, v in inputs.items()}
            model_inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Work with float32 log-probabilities: log_softmax is numerically stable,
            # unlike softmax followed by log(p + eps) on low-precision logits.
            base_log_probs = F.log_softmax(base_model(**base_inputs).logits.float(), dim=-1)
            model_log_probs = F.log_softmax(model(**model_inputs).logits.float(), dim=-1)
            model_log_probs = model_log_probs.to(base_log_probs.device)

            # KL(base || model) - forward KL, summed over positions and vocabulary
            kl = (base_log_probs.exp() * (base_log_probs - model_log_probs)).sum()

            total_kl += kl.item()

    return total_kl / sample_count

## Experiment

In [None]:
def _release_gpu_memory():
    """Garbage-collect, then release unoccupied cached CUDA memory."""
    import gc

    # Collect first: objects kept alive only by reference cycles must be freed before
    # their memory can be released from the CUDA cache.
    gc.collect()
    torch.cuda.empty_cache()


def _load_fresh_model():
    """Load a new copy of MODEL_NAME in bfloat16 with automatic device placement."""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )


def run_full_experiment(dataset_name="math"):
    """
    Run the SFT and RL (GRPO) sweeps and record prior-task score and KL for each run.

    For every configuration a fresh model is loaded, trained, evaluated on the
    benchmarks, and compared with the untouched base model via forward KL on the
    training-task text. Results are written to results_<dataset_name>.json after
    every run, so a crash part-way through does not lose finished runs.

    Args:
        dataset_name: "math", "science" or "tool".

    Returns:
        dict with keys 'sft' and 'rl', each a list of per-run result dicts
        (lr, prior_task_score, kl_divergence, detailed_scores, and batch_size for SFT).

    Raises:
        ValueError: for an unknown dataset name or a dataset that failed to load.
        TypeError: if the selected dataset is not a datasets.Dataset.
    """
    import json

    # Set memory management
    # NOTE(review): this variable is set after earlier cells have already used the GPU,
    # so it may not affect the current process - confirm.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Clear GPU memory before starting
    _release_gpu_memory()

    # Select dataset (if/elif keeps the lookup lazy: only the chosen global is read)
    if dataset_name == "math":
        dataset = math_dataset
    elif dataset_name == "science":
        dataset = science_dataset
    elif dataset_name == "tool":
        dataset = tool_dataset
    else:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'math', 'science' or 'tool'"
        )

    if dataset is None:
        raise ValueError(f"Dataset '{dataset_name}' is not available (loading failed earlier)")
    if not isinstance(dataset, Dataset):
        # The sweeps rely on Dataset.map / select / column_names.
        raise TypeError(
            f"Dataset '{dataset_name}' must be a datasets.Dataset, got {type(dataset).__name__}"
        )

    # FORMAT DATASET ONCE HERE
    def format_dataset_for_kl(examples):
        """Convert the nested structure to text format (capped at 800 characters)"""
        texts = []
        for question_turn, answer_turn in zip(examples['0'], examples['1']):
            question = question_turn['value']
            try:
                answer = answer_turn['ground_truth']['value']
            except (KeyError, TypeError):
                answer = str(answer_turn)
            text = f"Question: {question}\nAnswer: {answer}"
            texts.append(text[:800])
        return {'text': texts}

    # Create formatted version for KL computation
    formatted_dataset_kl = dataset.map(format_dataset_for_kl, batched=True, remove_columns=dataset.column_names)

    # Load base model ONCE; it is only used as a frozen reference, so the same instance
    # serves both sweeps (no delete-and-reload in between).
    base_model = _load_fresh_model()

    results = {
        'sft': [],
        'rl': [],
    }

    def save_results():
        """Write everything collected so far to the results JSON file."""
        with open(f'results_{dataset_name}.json', 'w') as f:
            json.dump(results, f, indent=2)

    # SFT sweep (as in Table 2)
    print("\n" + "="*70)
    print("RUNNING SFT HYPERPARAMETER SWEEP")
    print("="*70)

    for lr in sft_config['learning_rates']:
        for bs in [2, 4]:  # Smaller batch sizes, on paper [16,64]

            # Clear memory before loading new model
            _release_gpu_memory()

            print(f"\nLoading fresh model for lr={lr}, bs={bs}...")

            # Fresh copy of the base weights for this configuration
            sft_model = _load_fresh_model()

            # Train (train_sft will format the dataset internally)
            sft_model, trainer = train_sft(sft_model, dataset, tokenizer, learning_rate=lr, batch_size=bs)

            # Evaluate
            prior_scores = evaluate_benchmarks(sft_model, tokenizer)
            kl_div = compute_forward_kl(sft_model, base_model, formatted_dataset_kl, tokenizer)  # Use formatted dataset

            results['sft'].append({
                'lr': lr,
                'batch_size': bs,
                'prior_task_score': prior_scores['average'],
                'kl_divergence': kl_div,
                'detailed_scores': prior_scores,
            })
            save_results()

            print(f"SFT lr={lr}, bs={bs}: Prior={prior_scores['average']:.4f}, KL={kl_div:.4f}")

            # Delete model and trainer immediately after use
            del sft_model
            del trainer
            _release_gpu_memory()

            print(f"Memory freed. GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")

    # RL sweep (as in Table 2)
    print("\n" + "="*70)
    print("RUNNING RL (GRPO) HYPERPARAMETER SWEEP")
    print("="*70)

    for lr in rl_config['learning_rates']:

        # Clear memory before loading new model
        _release_gpu_memory()

        # Fresh copy of the base weights for this configuration
        rl_model = _load_fresh_model()

        # Train
        rl_model, trainer = train_grpo(rl_model, dataset, tokenizer, learning_rate=lr)

        # Evaluate
        prior_scores = evaluate_benchmarks(rl_model, tokenizer)
        kl_div = compute_forward_kl(rl_model, base_model, formatted_dataset_kl, tokenizer)  # Use formatted dataset

        results['rl'].append({
            'lr': lr,
            'prior_task_score': prior_scores['average'],
            'kl_divergence': kl_div,
            'detailed_scores': prior_scores,
        })
        save_results()

        print(f"RL lr={lr}: Prior={prior_scores['average']:.4f}, KL={kl_div:.4f}")

        # Delete model immediately
        del rl_model
        del trainer
        _release_gpu_memory()

    # Save results (also covers the case where both sweeps were empty)
    save_results()

    # Final cleanup
    del base_model
    _release_gpu_memory()

    return results

## Visualization

In [None]:
def plot_results(results):
    """Plot the sweep results and save both figures as PNG files.

    Figure 1 (kl_vs_forgetting.png): scatter of KL divergence against prior-task
    score, one point per run, for SFT and RL.
    Figure 2 (sft_vs_rl_comparison.png): per-method means of prior-task score and KL
    divergence, drawn against separate y-axes because the two quantities are on
    different scales.

    Args:
        results: dict with keys 'sft' and 'rl', each a list of dicts containing
            'prior_task_score' and 'kl_divergence'.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    def column(method, key):
        """Collect one field across all runs of a method."""
        return [run[key] for run in results[method]]

    def mean_or_nan(values):
        """Mean of `values`, or NaN for an empty list (avoids a NumPy warning)."""
        return float(np.mean(values)) if values else float('nan')

    # Extract data
    sft_prior = column('sft', 'prior_task_score')
    sft_kl = column('sft', 'kl_divergence')
    rl_prior = column('rl', 'prior_task_score')
    rl_kl = column('rl', 'kl_divergence')

    # KL vs Prior Task (showing forgetting)
    scatter_fig, scatter_ax = plt.subplots(figsize=(10, 6))
    scatter_ax.scatter(sft_kl, sft_prior, label='SFT', alpha=0.6, s=50)
    scatter_ax.scatter(rl_kl, rl_prior, label='RL', alpha=0.6, s=50)
    scatter_ax.set_xlabel('KL Divergence', fontsize=12)
    scatter_ax.set_ylabel('Prior Task Performance', fontsize=12)
    scatter_ax.set_title('KL Predicts Forgetting (Lower KL = Less Forgetting)', fontsize=14, fontweight='bold')
    scatter_ax.legend()
    scatter_ax.grid(True, alpha=0.3)
    scatter_fig.savefig('kl_vs_forgetting.png', dpi=150)
    plt.show()

    # Comparison plot: exactly one figure is created here (an extra plt.figure() call
    # would leave behind an empty figure).
    methods = ['SFT', 'RL']
    prior_scores = [mean_or_nan(sft_prior), mean_or_nan(rl_prior)]
    kl_divs = [mean_or_nan(sft_kl), mean_or_nan(rl_kl)]

    x = np.arange(len(methods))
    width = 0.35

    bar_fig, score_ax = plt.subplots(figsize=(10, 6))
    score_bars = score_ax.bar(x - width/2, prior_scores, width, label='Prior Task Score', alpha=0.8)

    # KL gets its own y-axis so neither set of bars is flattened by the other's scale.
    kl_ax = score_ax.twinx()
    kl_bars = kl_ax.bar(x + width/2, kl_divs, width, label='KL Divergence', alpha=0.8, color='tab:orange')

    score_ax.set_ylabel('Prior Task Score')
    kl_ax.set_ylabel('KL Divergence')
    score_ax.set_title('SFT vs RL: Forgetting Comparison')
    score_ax.set_xticks(x)
    score_ax.set_xticklabels(methods)
    score_ax.legend([score_bars, kl_bars], ['Prior Task Score', 'KL Divergence'])
    score_ax.grid(True, alpha=0.3)

    bar_fig.tight_layout()
    bar_fig.savefig('sft_vs_rl_comparison.png', dpi=150)
    plt.show()

In [None]:
if __name__ == "__main__":
    # Entry point: print the run header, execute the math experiment, plot the
    # results, then print per-method averages.
    banner = "=" * 70

    print("\n" + banner)
    print("RL'S RAZOR")
    print(banner)
    print("\nConfiguration:")
    print(f"  Model: {MODEL_NAME}")
    print("  Hyperparameters: Exactly from Table 2")
    print(banner + "\n")

    # Run experiment on Math dataset
    results = run_full_experiment(dataset_name="math")

    # Create visualizations
    plot_results(results)

    print("\n" + banner)
    print("EXPERIMENT COMPLETE")
    print(banner)
    print("\nKey Findings:")

    # Average each recorded metric over all runs of a method
    rl_runs = results['rl']
    sft_runs = results['sft']

    rl_prior_mean = np.mean([run['prior_task_score'] for run in rl_runs])
    print(f"  • RL average prior score: {rl_prior_mean:.4f}")
    sft_prior_mean = np.mean([run['prior_task_score'] for run in sft_runs])
    print(f"  • SFT average prior score: {sft_prior_mean:.4f}")
    rl_kl_mean = np.mean([run['kl_divergence'] for run in rl_runs])
    print(f"  • RL average KL: {rl_kl_mean:.4f}")
    sft_kl_mean = np.mean([run['kl_divergence'] for run in sft_runs])
    print(f"  • SFT average KL: {sft_kl_mean:.4f}")

## Run Experiment