<a href="https://colab.research.google.com/github/JovannyReb/GRPO_Reasoninig_Gym/blob/main/training_smol_rl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q bitsandbytes

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
!pip install -q -U transformers peft accelerate datasets trl reasoning_gym wandb flash_attn

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m80.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.6/71.6 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m125.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m503.6/503.6 kB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.7/564.7 kB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0

In [None]:
import re
import torch
import reasoning_gym
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from rich import print
import wandb
from datasets import Dataset

# Authenticate with Weights & Biases up front; the training config later in
# this script reports metrics to wandb.
wandb.login()

# Create a simple dataset for format training
def create_simple_questions(size=100):
    """Build ``size`` question records by cycling through a fixed pool.

    The ten seed questions are reused in order, wrapping around, until
    ``size`` records exist. A ``size`` of zero (or less) gives an empty list.

    Args:
        size: Number of records to produce.

    Returns:
        A list of dicts, each with a single ``"question"`` key.
    """
    pool = (
        "What is 2 + 2?",
        "What color is the sky?",
        "What is the capital of France?",
        "How many days are in a week?",
        "What is 10 - 5?",
        "What animal says 'moo'?",
        "How many wheels does a car have?",
        "What comes after Monday?",
        "What is 3 × 4?",
        "What season comes after winter?",
    )
    pool_size = len(pool)

    # Index modulo the pool size so the seed questions repeat in order.
    return [{"question": pool[idx % pool_size]} for idx in range(size)]

# Create simple dataset: 200 question records from create_simple_questions,
# wrapped in a `datasets.Dataset`. `train_dataset` is what the trainer consumes
# further down.
simple_data = create_simple_questions(size=200)
train_dataset = Dataset.from_list(simple_data)

# Instruction text placed at the start of every prompt. It names the exact
# <think>/<answer> layout that the reward functions in this file look for.
SYSTEM_PROMPT = (
    "Please think step by step and then give your answer.\n"
    "Use this exact format: <think>your reasoning here</think> <answer>your final answer</answer>"
)


def formatting_func(sample):
    """Render one dataset record as a plain-text prompt.

    Args:
        sample: Mapping that has a ``"question"`` entry.

    Returns:
        ``SYSTEM_PROMPT``, a blank line, ``Question: <question>`` and a
        trailing newline, as a single string.
    """
    question_text = sample["question"]
    return f"{SYSTEM_PROMPT}\n\nQuestion: {question_text}\n"

# Detailed format checking functions
def check_think_structure(response: str) -> dict:
    """Report which parts of the ``<think>`` markup appear in ``response``.

    Returns:
        A dict with three booleans:
        ``has_opening``: the literal ``<think>`` occurs somewhere.
        ``has_closing``: the literal ``</think>`` occurs somewhere.
        ``has_complete_pair``: some ``<think>`` is followed later by a
        ``</think>``; the text in between may span multiple lines.
    """
    return dict(
        has_opening="<think>" in response,
        has_closing="</think>" in response,
        # DOTALL lets the reasoning between the tags contain newlines.
        has_complete_pair=re.search(r"<think>.*?</think>", response, re.DOTALL) is not None,
    )

def check_answer_structure(response: str) -> dict:
    """Report which parts of the ``<answer>`` markup appear in ``response``.

    Returns:
        A dict with three booleans:
        ``has_opening``: the literal ``<answer>`` occurs somewhere.
        ``has_closing``: the literal ``</answer>`` occurs somewhere.
        ``has_complete_pair``: some ``<answer>`` is followed later by an
        ``</answer>``; the text in between may span multiple lines.
    """
    return dict(
        has_opening="<answer>" in response,
        has_closing="</answer>" in response,
        # DOTALL lets the answer between the tags contain newlines.
        has_complete_pair=re.search(r"<answer>.*?</answer>", response, re.DOTALL) is not None,
    )

def get_structure_score(response: str) -> float:
    """Score how closely ``response`` follows the think/answer tag layout.

    The ``<think>`` tags and the ``<answer>`` tags are scored independently
    by the same rule, and the two results are added:

    - 0.3 when a complete ``<tag>...</tag>`` pair is present;
    - otherwise 0.1 for an opening tag plus 0.1 for a closing tag, whichever
      of them appear (so at most 0.2 without a properly ordered pair).

    Returns:
        A float from 0.0 up to 0.6 (0.3 for think + 0.3 for answer).
    """
    tag_reports = (
        check_think_structure(response),
        check_answer_structure(response),
    )

    total = 0.0
    for report in tag_reports:
        if report["has_complete_pair"]:
            # A complete pair earns full credit; partial bonuses do not stack.
            total += 0.3
            continue
        # Partial credit for each individual tag that showed up.
        total += 0.1 if report["has_opening"] else 0.0
        total += 0.1 if report["has_closing"] else 0.0

    return total

# Detailed reward function focusing on progressive structure learning
def _percent(count, total):
    """Return ``count`` as a percentage of ``total``, or 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


def structure_reward_fn(prompts, completions, **kwargs):
    """Reward each completion for using the <think>/<answer> tag layout.

    Every completion is scored with ``get_structure_score``. Along the way a
    per-example report and aggregate tag statistics are printed for debugging.

    Args:
        prompts: Accepted for the reward-function call signature; not used.
        completions: Iterable of completion texts. Assumes plain strings, as
            produced for the string prompts built in this script (TODO
            confirm if the dataset is ever switched to a chat format).
        **kwargs: Extra arguments supplied by the caller; ignored.

    Returns:
        A list of float rewards (0.0 to 0.6), one per completion, in input
        order. An empty ``completions`` yields an empty list.
    """
    # `print` in this file is rich's print, which interprets "[...]" as
    # markup. Model output is arbitrary text, so escape it before printing:
    # otherwise "[tag]" is swallowed and a stray "[/tag]" raises MarkupError.
    from rich.markup import escape

    rewards = []
    print(f"\n=== Reward Function Called ===")
    print(f"Processing {len(completions)} completions")

    # Track statistics
    stats = {
        "think_opening": 0,
        "think_closing": 0,
        "think_complete": 0,
        "answer_opening": 0,
        "answer_closing": 0,
        "answer_complete": 0,
        "perfect_structure": 0
    }

    for i, completion in enumerate(completions):
        score = get_structure_score(completion)
        rewards.append(score)

        # Gather statistics
        think_info = check_think_structure(completion)
        answer_info = check_answer_structure(completion)

        if think_info["has_opening"]: stats["think_opening"] += 1
        if think_info["has_closing"]: stats["think_closing"] += 1
        if think_info["has_complete_pair"]: stats["think_complete"] += 1
        if answer_info["has_opening"]: stats["answer_opening"] += 1
        if answer_info["has_closing"]: stats["answer_closing"] += 1
        if answer_info["has_complete_pair"]: stats["answer_complete"] += 1
        # "Perfect" means both pairs are complete; test that directly rather
        # than comparing the float score against 0.6.
        if think_info["has_complete_pair"] and answer_info["has_complete_pair"]:
            stats["perfect_structure"] += 1

        # Show full completion for all examples (with truncation for very long ones)
        print(f"\n--- Example {i+1} ---")
        if len(completion) > 500:
            print(f"FULL COMPLETION: {escape(repr(completion[:250]))}...[TRUNCATED]...{escape(repr(completion[-100:]))}")
        else:
            print(f"FULL COMPLETION: {escape(repr(completion))}")
        print(f"LENGTH: {len(completion)} characters")
        print(f"Think - Opening: {think_info['has_opening']}, Closing: {think_info['has_closing']}, Complete: {think_info['has_complete_pair']}")
        print(f"Answer - Opening: {answer_info['has_opening']}, Closing: {answer_info['has_closing']}, Complete: {answer_info['has_complete_pair']}")
        print(f"SCORE: {score:.2f}/0.60")
        print(f"--- End Example {i+1} ---")

    total = len(rewards)
    avg_reward = sum(rewards) / total if total else 0.0

    # All percentages go through _percent so an empty batch reports 0.0%
    # instead of dividing by zero.
    print(f"\n=== Reward Statistics ===")
    print(f"Average reward: {avg_reward:.3f}/0.60 ({avg_reward/0.6*100:.1f}%)")
    print(f"Perfect structure (0.6): {stats['perfect_structure']}/{total} ({_percent(stats['perfect_structure'], total):.1f}%)")
    print(f"\nThink Tags:")
    print(f"  Opening <think>: {stats['think_opening']}/{total} ({_percent(stats['think_opening'], total):.1f}%)")
    print(f"  Closing </think>: {stats['think_closing']}/{total} ({_percent(stats['think_closing'], total):.1f}%)")
    print(f"  Complete pairs: {stats['think_complete']}/{total} ({_percent(stats['think_complete'], total):.1f}%)")
    print(f"\nAnswer Tags:")
    print(f"  Opening <answer>: {stats['answer_opening']}/{total} ({_percent(stats['answer_opening'], total):.1f}%)")
    print(f"  Closing </answer>: {stats['answer_closing']}/{total} ({_percent(stats['answer_closing'], total):.1f}%)")
    print(f"  Complete pairs: {stats['answer_complete']}/{total} ({_percent(stats['answer_complete'], total):.1f}%)")
    print("=== End Reward Function ===\n")

    return rewards

# Hub identifier of the checkpoint that gets fine-tuned.
MODEL_NAME = "HuggingFaceTB/SmolLM-135M-Instruct"


def load_model():
    """Load the tokenizer and causal language model named by ``MODEL_NAME``.

    If the tokenizer has no pad token, its EOS token is reused for padding.
    The model is requested in bfloat16 with the ``flash_attention_2``
    attention implementation.
    NOTE(review): presumably this needs the ``flash_attn`` package and a GPU
    it supports; confirm for the target runtime.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "flash_attention_2",
    }
    return tok, AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs)

# Load the tokenizer and base model once; both are module-level globals used
# by the trainer and by the sampling code at the end of the script.
tokenizer, model = load_model()

# LoRA config: rank 16, alpha 32, targeting "all-linear" modules for a
# causal-LM task.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

# Rebind `model` to the PEFT-wrapped model; everything below uses the wrapped
# version. The trainable-parameter summary is printed by PEFT itself.
model = get_peft_model(model, lora_config)
print("Trainable parameters:")
model.print_trainable_parameters()

# Training config for the GRPO run.
# NOTE(review): TRL requires the generation batch to be divisible by
# num_generations. With batch size 2 and num_generations 4 this presumably
# relies on the installed TRL version counting gradient accumulation steps
# (2 x 2 = 4); confirm against the TRL release in use.
training_args = GRPOConfig(
    output_dir="./GRPO_structure",
    learning_rate=5e-5,
    per_device_train_batch_size=2,  # Small batch for debugging
    gradient_accumulation_steps=2,
    max_prompt_length=256,
    max_completion_length=128,
    num_generations=4,  # Generate 4 responses per prompt
    optim="adamw_8bit",  # presumably backed by bitsandbytes (installed above) - verify
    num_train_epochs=2,
    bf16=True,  # matches the bfloat16 dtype the model is loaded with
    report_to=["wandb"],
    remove_unused_columns=False,
    logging_steps=5,
    save_steps=50,
)

# Add prompts to dataset: each record gains a "prompt" column holding the
# text produced by formatting_func for that record.
train_dataset = train_dataset.map(lambda x: {"prompt": formatting_func(x)})

# Test the reward function on a sample first.
# The trailing comments give the score each string is expected to earn. They
# are not asserted; the values are printed below for manual inspection.
print("Testing reward function with sample completions:")
test_completions = [
    "I think the answer is 4.",  # No structure (0.0)
    "Let me <think> about this. The answer is 4",  # Think opening only (0.1)
    "Let me think </think> about this. The answer is 4",  # Think closing only (0.1)
    "<think>Let me think about this</think> The answer is 4",  # Complete think (0.3)
    "The answer is <answer>4",  # Answer opening only (0.1)
    "The answer is 4</answer>",  # Answer closing only (0.1)
    "The answer is <answer>4</answer>",  # Complete answer (0.3)
    "<think>2 + 2 = 4</think> <answer>4</answer>",  # Perfect structure (0.6)
    "<think>Let me think</think> The answer is <answer>4",  # Mixed (0.4)
    "I need to <think> calculate </think> and give <answer> the result </answer>",  # Perfect (0.6)
]

# One summary line per example: score, whether each tag pair is complete, and
# the first 50 characters of the text.
print(f"\nTesting {len(test_completions)} examples:")
for i, completion in enumerate(test_completions):
    score = get_structure_score(completion)
    think_info = check_think_structure(completion)
    answer_info = check_answer_structure(completion)
    print(f"{i+1:2d}. Score: {score:.1f} | Think: {think_info['has_complete_pair']} | Answer: {answer_info['has_complete_pair']} | Text: {completion[:50]}...")

# Run the full reward function as well (prompts are unused, so an empty list
# is passed) to exercise its per-example and aggregate reporting.
test_rewards = structure_reward_fn([], test_completions)

# Create trainer: the PEFT-wrapped model, the GRPO config above, the dataset
# with its "prompt" column, and the structure reward as the only reward.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    reward_funcs=[structure_reward_fn]
)

# Initialize wandb explicitly (after the trainer is built, before training
# starts) so the run lands in the "GRPO" project with the key
# hyperparameters recorded in its config.
wandb.init(
    project="GRPO",
    config={
        "model_name": MODEL_NAME,
        "dataset_size": len(train_dataset),
        "learning_rate": training_args.learning_rate,
        "batch_size": training_args.per_device_train_batch_size,
        "num_generations": training_args.num_generations,
    }
)

print("Starting training...")
trainer.train()

print("Training completed!")

# Test the model after training
print("\n=== Testing Trained Model ===")

# Get the device the model is on
device = next(model.parameters()).device
print(f"Model device: {device}")


def generate_response(prompt, max_new_tokens=100, temperature=0.7):
    """Sample one continuation of ``prompt`` from the trained model.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        prompt: Prompt text to feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.

    Returns:
        A ``(full_text, response)`` tuple: the decoded prompt plus
        continuation, and the decoded continuation alone.
    """
    # Move inputs to the same device as the model
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_token_count = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,  # Disable cache to avoid gradient checkpointing warning
        )

    sequence = generated[0]
    full_text = tokenizer.decode(sequence, skip_special_tokens=True)
    # Separate the continuation by token count, not by character count of the
    # prompt: decoding is not guaranteed to reproduce the prompt text exactly,
    # so slicing the decoded string can cut in the wrong place.
    response_text = tokenizer.decode(sequence[prompt_token_count:], skip_special_tokens=True)
    return full_text, response_text


test_prompt = formatting_func({"question": "What is 5 + 3?"})
print(f"TEST PROMPT: {repr(test_prompt)}")

full_response, response = generate_response(test_prompt)

print(f"\nFULL MODEL OUTPUT: {repr(full_response)}")
print(f"\nMODEL RESPONSE (prompt removed): {repr(response)}")
print(f"\nRESPONSE LENGTH: {len(response)} characters")
print(f"STRUCTURE SCORE: {get_structure_score(response):.2f}/0.60")

# Show detailed breakdown
think_info = check_think_structure(response)
answer_info = check_answer_structure(response)
print(f"\nDETAILED BREAKDOWN:")
print(f"Think tags - Opening: {think_info['has_opening']}, Closing: {think_info['has_closing']}, Complete: {think_info['has_complete_pair']}")
print(f"Answer tags - Opening: {answer_info['has_opening']}, Closing: {answer_info['has_closing']}, Complete: {answer_info['has_complete_pair']}")

# Test a few more examples
print(f"\n=== Testing Multiple Examples ===")
test_questions = ["What is 10 - 7?", "What color is grass?", "How many sides does a triangle have?"]
for i, question in enumerate(test_questions, 1):
    test_prompt = formatting_func({"question": question})
    full_response, response = generate_response(test_prompt)
    score = get_structure_score(response)

    print(f"\nTest {i}: {question}")
    print(f"Response: {repr(response)}")
    print(f"Score: {score:.2f}/0.60")



KeyboardInterrupt: 