# RL Training Boilerplate

This notebook provides a modular boilerplate for Reinforcement Learning training of language models using PyTorch directly (no verl).

**Features:**
- All hyperparameters at the top for easy configuration
- Modular sections: model loading, tools, chat templates, environment, reward model
- WandB integration for logging and tracking
- Easy to swap models and modify components
- Training loop in plain PyTorch (models are loaded with Hugging Face Transformers and PEFT)

**Requirements:**
- GPU with sufficient VRAM (A100 recommended)
- WandB account for logging

## 🛠️ Setup & Dependencies

In [None]:
#@title Install Dependencies
!nvidia-smi -L || true

import sys
print("Python:", sys.version)

# Install required packages
try:
    # Quote version specifiers: an unquoted `>=` is parsed by the shell as output redirection
    %pip install -q "transformers>=4.51.3" "accelerate>=1.4.0" "peft>=0.14.0" \
                     "datasets>=3.3.2" torch wandb huggingface_hub \
                     sentencepiece protobuf tqdm matplotlib pandas
except Exception:
    !pip install -q "transformers>=4.51.3" "accelerate>=1.4.0" "peft>=0.14.0" \
                     "datasets>=3.3.2" torch wandb huggingface_hub \
                     sentencepiece protobuf tqdm matplotlib pandas

import os, random, time, json, platform
import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import pandas as pd
from IPython.display import display

print("\n=== Environment ===")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "Please connect a GPU for RL training."

## 🔑 WandB & HuggingFace Authentication

In [None]:
#@title Set API Keys (Optional - fill in if needed)
import os

# WandB API key - get from https://wandb.ai/authorize
WANDB_API_KEY = ""  # Your WandB API key
if WANDB_API_KEY:
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY

# HuggingFace token - get from https://huggingface.co/settings/tokens
HF_TOKEN = ""  # Your HuggingFace token
if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN

In [None]:
#@title Login to WandB and HuggingFace
import wandb
from huggingface_hub import login

# WandB login
try:
    wandb.login()
    print("✓ WandB login successful")
except Exception as e:
    print(f"⚠ WandB login failed: {e}")
    print("Training will continue without WandB logging")

# HuggingFace login
try:
    if os.environ.get('HF_TOKEN'):
        login(token=os.environ['HF_TOKEN'])
        print("✓ HuggingFace login successful")
except Exception as e:
    print(f"⚠ HuggingFace login failed: {e}")

## ⚙️ HYPERPARAMETERS

**All configurable parameters are in this section for easy modification.**

In [None]:
#@title Hyperparameters Configuration
from dataclasses import dataclass
from typing import Optional, List

@dataclass
class RLConfig:
    # ==================== MODEL CONFIGURATION ====================
    # Policy model (the model being trained)
    policy_model_id: str = "Qwen/Qwen3-0.6B"  # Use Qwen3-0.6B-Base for base model
    policy_model_dtype: str = "bfloat16"  # "bfloat16", "float16", or "float32"
    
    # Reference model (for KL penalty, optional)
    use_reference_model: bool = False
    reference_model_id: Optional[str] = None  # If None, uses policy_model_id
    
    # Reward model (optional)
    use_reward_model: bool = False
    reward_model_id: Optional[str] = None
    
    # ==================== LORA CONFIGURATION ====================
    # NOTE: RL typically uses full-weight updates. LoRA is available but not recommended.
    use_lora: bool = False  # Set to True only for parameter-efficient training
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: Optional[List[str]] = None  # Auto-detect if None
    
    # ==================== TRAINING CONFIGURATION ====================
    num_steps: int = 100
    num_prompts_per_step: int = 4  # Number of unique prompts per step
    samples_per_prompt: int = 4  # Rollouts per prompt (effective batch = num_prompts * samples_per_prompt)
    
    learning_rate: float = 1e-4  # Typical for LoRA; full-weight RL usually needs far smaller values (~1e-6)
    weight_decay: float = 0.0
    grad_accumulation_steps: int = 1
    max_grad_norm: float = 1.0
    
    max_new_tokens: int = 512  # Qwen3's model card suggests 32768 for most queries; keep small for testing
    train_temperature: float = 0.7  # For non-thinking mode training
    eval_temperature: float = 0.7  # For non-thinking mode eval
    top_p: float = 0.8  # Qwen3 recommendation for non-thinking
    top_k: int = 20  # Qwen3 recommendation
    
    # ==================== RL ALGORITHM ====================
    algorithm: str = "pg"  # Currently only "pg" (policy gradient) implemented
    use_kl_penalty: bool = False
    kl_coef: float = 0.1
    
    # Advantage normalization
    normalize_advantages: bool = True
    normalize_rewards: bool = False  # Set to True for reward normalization
    
    # ==================== QWEN3-SPECIFIC CONFIGURATION ====================
    # Chat template
    use_chat_template: bool = False  # Set True for instruct/post-trained models, including the default Qwen3-0.6B
    system_prompt: Optional[str] = None
    
    # Thinking mode (Qwen3 feature)
    enable_thinking: bool = False  # Enable thinking mode (<think>...</think> blocks)
    thinking_train_temperature: float = 0.6  # Qwen3 recommendation for thinking mode
    thinking_train_top_p: float = 0.95  # Qwen3 recommendation for thinking mode
    thinking_eval_temperature: float = 0.6  # For thinking mode eval
    thinking_eval_top_p: float = 0.95
    parse_thinking: bool = False  # Parse and separate thinking content from response
    
    # Tool use (requires Qwen-Agent)
    use_qwen_agent: bool = False  # Enable Qwen-Agent for tool calling
    qwen_agent_tools: Optional[List[str]] = None  # List of tools (e.g., ['code_interpreter', 'time'])
    
    # ==================== DATA CONFIGURATION ====================
    dataset_name: str = "openai/gsm8k"
    dataset_config: Optional[str] = "main"
    dataset_split: str = "train"
    prompt_field: str = "question"
    answer_field: Optional[str] = "answer"
    
    val_size: int = 200
    val_every: int = 10
    val_batch_size: int = 32
    
    # ==================== PROMPT TEMPLATE ====================
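    # Asks for the [answer] format that the rule-based reward in compute_reward() checks
    prompt_template: str = (
        "Solve step by step, then give the final answer in square brackets, e.g. [42].\n"
        "Problem: {prompt}\n\nSolution:"
    )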
    
    # ==================== REWARD CONFIGURATION ====================
    reward_type: str = "rule"  # "model", "rule", or "custom"
    use_ground_truth: bool = True  # Whether to pass ground truth to reward function
    
    # ==================== TOOLS & ENVIRONMENT ====================
    use_tools: bool = False
    use_environment: bool = False
    
    # ==================== LOGGING & CHECKPOINTING ====================
    wandb_project: str = "rl-training"
    wandb_run_name: Optional[str] = None
    log_every: int = 10
    save_every: int = 50
    ema_momentum: float = 0.9
    
    output_dir: str = f"./run_rl_{int(time.time())}"
    push_to_hub: bool = False
    hub_repo_id: Optional[str] = None
    
    seed: int = 42
    
    def __post_init__(self):
        # Auto-generate run name
        if self.wandb_run_name is None:
            model_short = self.policy_model_id.split("/")[-1]
            mode = "thinking" if self.enable_thinking else "normal"
            self.wandb_run_name = f"rl_{self.algorithm}_{model_short}_{mode}_{int(time.time())}"
        
        # Set reference model to policy model if not specified
        if self.use_reference_model and self.reference_model_id is None:
            self.reference_model_id = self.policy_model_id
        
        # Auto-detect LoRA target modules
        if self.use_lora and self.lora_target_modules is None:
            self.lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        
        # Validate configuration
        if self.use_kl_penalty and not self.use_reference_model:
            raise ValueError("KL penalty requires reference model (set use_reference_model=True)")
        if self.use_reward_model and not self.reward_model_id:
            raise ValueError("Reward model ID required when use_reward_model=True")
        if self.reward_type == "model" and not self.use_reward_model:
            raise ValueError("reward_type='model' requires use_reward_model=True and a reward_model_id")
        if self.enable_thinking:
            print(f"⚠ Thinking mode enabled. Using thinking-mode sampling "
                  f"(temp={self.thinking_train_temperature}, top_p={self.thinking_train_top_p})")
            print("  Note: thinking mode needs a large max_new_tokens budget (Qwen3 suggests 32768)")
        
        os.makedirs(self.output_dir, exist_ok=True)

# Create config
config = RLConfig()

print("\n=== RL Training Configuration ===")
print(f"Policy Model: {config.policy_model_id}")
print(f"Use LoRA: {config.use_lora}")
print(f"Algorithm: {config.algorithm}")
print(f"Enable Thinking: {config.enable_thinking}")
print(f"Use Chat Template: {config.use_chat_template}")
print(f"Steps: {config.num_steps}")
print(f"Prompts per step: {config.num_prompts_per_step}")
print(f"Samples per prompt: {config.samples_per_prompt}")
print(f"Effective batch size: {config.num_prompts_per_step * config.samples_per_prompt}")
print(f"Learning Rate: {config.learning_rate}")
print(f"Max New Tokens: {config.max_new_tokens}")
print(f"Output Dir: {config.output_dir}")

# Save config
with open(os.path.join(config.output_dir, "config.json"), "w") as f:
    config_dict = {k: v for k, v in config.__dict__.items() if not callable(v)}
    json.dump(config_dict, f, indent=2)
print(f"✓ Config saved")
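The `normalize_advantages` flag and the `samples_per_prompt` rollouts per prompt suggest a per-prompt baseline, but the advantage computation itself is not part of this section. As an assumption, here is a minimal sketch of one common choice, the group-relative scheme popularized by GRPO: subtract each prompt's mean reward and optionally divide by that group's standard deviation. `group_advantages` is a hypothetical helper, not part of this notebook.

```python
from statistics import mean, pstdev
from typing import List


def group_advantages(rewards: List[float], samples_per_prompt: int,
                     normalize: bool = True, eps: float = 1e-6) -> List[float]:
    """Advantages with a per-prompt mean baseline (GRPO-style sketch).

    Assumes a prompt-major layout: rewards[i*k:(i+1)*k] are the k rollouts
    of prompt i, where k = samples_per_prompt.
    """
    if samples_per_prompt <= 0 or len(rewards) % samples_per_prompt != 0:
        raise ValueError("len(rewards) must be a positive multiple of samples_per_prompt")
    advantages: List[float] = []
    for start in range(0, len(rewards), samples_per_prompt):
        group = rewards[start:start + samples_per_prompt]
        baseline = mean(group)
        # Divide by the group's std only when normalizing; eps avoids 0/0 on ties
        scale = pstdev(group) + eps if normalize else 1.0
        advantages.extend((r - baseline) / scale for r in group)
    return advantages


# Two prompts x two rollouts: the second prompt's rollouts tie, so they carry no signal
adv = group_advantages([1.5, 0.5, 0.0, 0.0], samples_per_prompt=2)
print([round(a, 3) for a in adv])  # [1.0, -1.0, 0.0, 0.0]
```

Groups whose rollouts all receive the same reward get zero advantage, which is one reason `samples_per_prompt > 1` matters for this kind of baseline.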

## 🎲 Set Random Seed

In [None]:
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(config.seed)
print(f"✓ Random seed set to {config.seed}")

## 📦 Model Loading

This section loads the policy model, reference model (if needed), and reward model (if needed).

In [None]:
#@title Model Loading Utilities
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

def get_torch_dtype(dtype_str: str):
    if dtype_str == "bfloat16":
        return torch.bfloat16
    elif dtype_str == "float16":
        return torch.float16
    else:
        return torch.float32

def load_tokenizer(model_id: str):
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
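    if tokenizer.pad_token is None:
        # Reuse EOS as padding; a new "[PAD]" string would typically not be in the vocabulary
        if tokenizer.eos_token is None:
            raise ValueError(f"{model_id} defines neither a pad nor an EOS token")
        tokenizer.pad_token = tokenizer.eos_token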
    tokenizer.padding_side = "left"
    return tokenizer

def load_causal_lm(model_id: str, dtype_str: str, use_lora: bool = False, lora_config=None):
    dtype = get_torch_dtype(dtype_str)
    print(f"Loading model: {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=dtype, device_map="auto", trust_remote_code=True
    )
    if use_lora and lora_config:
        print(f"Applying LoRA...")
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    return model

print("✓ Model loading utilities defined")
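`load_tokenizer` sets `padding_side = "left"`. A toy illustration of why, using plain lists instead of a real tokenizer (`pad_batch` is a hypothetical helper): with left padding every row ends with its last prompt token, so `generate()` can append new tokens to the whole batch at once and the generated part is simply `output[:, prompt_len:]`, as in the evaluation code later.

```python
from typing import List, Tuple


def pad_batch(seqs: List[List[int]], pad_id: int,
              side: str = "left") -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id lists to equal length; returns (ids, attention_mask)."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    width = max(len(s) for s in seqs)
    ids, mask = [], []
    for s in seqs:
        pad = [pad_id] * (width - len(s))
        if side == "left":
            ids.append(pad + s)
            mask.append([0] * len(pad) + [1] * len(s))
        else:
            ids.append(s + pad)
            mask.append([1] * len(s) + [0] * len(pad))
    return ids, mask


prompts = [[11, 12, 13], [21]]
left_ids, left_mask = pad_batch(prompts, pad_id=0, side="left")
right_ids, _ = pad_batch(prompts, pad_id=0, side="right")
# Left padding: every row ends with its last real prompt token
print([row[-1] for row in left_ids])   # [13, 21]
# Right padding: the short row ends in padding, so new tokens would follow pad tokens
print([row[-1] for row in right_ids])  # [13, 0]
```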

In [None]:
#@title Load Tokenizer
print("=== Loading Tokenizer ===")
tokenizer = load_tokenizer(config.policy_model_id)
print(f"✓ Tokenizer loaded (vocab: {len(tokenizer)})")

In [None]:
#@title Load Policy Model
print("=== Loading Policy Model ===")

lora_config = None
if config.use_lora:
    lora_config = LoraConfig(
        r=config.lora_r, lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout,
        bias="none", task_type="CAUSAL_LM", target_modules=config.lora_target_modules
    )

policy_model = load_causal_lm(
    config.policy_model_id, config.policy_model_dtype,
    use_lora=config.use_lora, lora_config=lora_config
)
policy_model.config.use_cache = False
print(f"✓ Policy model loaded")

In [None]:
#@title Load Reference Model (optional)
reference_model = None
if config.use_reference_model:
    print("=== Loading Reference Model ===")
    reference_model = load_causal_lm(config.reference_model_id, config.policy_model_dtype)
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad_(False)
    print("✓ Reference model loaded")
else:
    print("⊗ Reference model not used")

In [None]:
#@title Load Reward Model (optional)
reward_model = None
if config.use_reward_model and config.reward_model_id:
    print("=== Loading Reward Model ===")
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    reward_model.eval()
    for param in reward_model.parameters():
        param.requires_grad_(False)
    print("✓ Reward model loaded")
else:
    print("⊗ Reward model not used (will use rule-based rewards)")

## 💬 Chat Template & Thinking Mode

**Qwen3-specific features for chat formatting and thinking mode.**

In [None]:
#@title Chat Template Configuration

def format_prompt(prompt: str, enable_thinking: Optional[bool] = None) -> str:
    """
    Format prompt according to configuration.
    
    Args:
        prompt: Raw prompt text
        enable_thinking: Override config.enable_thinking (for testing)
    
    Returns:
        Formatted prompt string
    """
    if enable_thinking is None:
        enable_thinking = config.enable_thinking
    
    if config.use_chat_template:
        # Use Qwen3 chat template
        messages = []
        if config.system_prompt:
            messages.append({"role": "system", "content": config.system_prompt})
        messages.append({"role": "user", "content": prompt})
        
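        # Every tokenizer has apply_chat_template(); what matters is whether it ships a template
        if getattr(tokenizer, "chat_template", None):
            # enable_thinking is read by Qwen3's template; templates that don't use it ignore it
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=enable_thinking  # Qwen3-specific
            )
        else:
            # Fallback for tokenizers without a chat template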
            formatted = ""
            if config.system_prompt:
                formatted += f"System: {config.system_prompt}\n\n"
            formatted += f"User: {prompt}\n\nAssistant:"
            return formatted
    else:
        # Simple template (for base models)
        return config.prompt_template.format(prompt=prompt)

# Test both modes
print("=== Chat Template Examples ===")
test_q = "What is 2+2?"

if config.use_chat_template:
    # Test non-thinking mode
    prompt_no_think = format_prompt(test_q, enable_thinking=False)
    print(f"\nNon-thinking mode:")
    print(prompt_no_think[:300])
    
    # Test thinking mode
    prompt_think = format_prompt(test_q, enable_thinking=True)
    print(f"\nThinking mode:")
    print(prompt_think[:300])
else:
    prompt = format_prompt(test_q)
    print(f"\nSimple template:")
    print(prompt[:300])

print(f"\n✓ Chat template configured (thinking={config.enable_thinking})")

### 🧠 Thinking Mode Utilities

Qwen3 can generate `<think>...</think>` blocks for reasoning. These utilities separate the thinking content from the final response.
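`parse_thinking_response` below splits the generated IDs just after the *last* `</think>` token by searching the reversed list. The same index arithmetic on plain lists looks like this (`split_at_last` is a hypothetical name; 151668 is the `</think>` ID hard-coded in this notebook). Rather than hard-coding the ID, `tokenizer.convert_tokens_to_ids("</think>")` can look it up from the loaded tokenizer.

```python
from typing import List, Tuple


def split_at_last(ids: List[int], marker_id: int) -> Tuple[List[int], List[int]]:
    """Split ids just after the last occurrence of marker_id.

    Mirrors the reverse-index trick in parse_thinking_response: if the marker
    is absent, everything is treated as the final response.
    """
    try:
        cut = len(ids) - ids[::-1].index(marker_id)
    except ValueError:
        cut = 0
    return ids[:cut], ids[cut:]


THINK_END = 151668  # </think> ID used in this notebook
thinking, answer = split_at_last([5, 6, THINK_END, 7, 8], THINK_END)
print(thinking, answer)                  # [5, 6, 151668] [7, 8]
print(split_at_last([7, 8], THINK_END))  # ([], [7, 8])
```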

In [None]:
#@title Thinking Mode Utilities

QWEN3_THINK_END_TOKEN = 151668  # Token ID for </think>

def parse_thinking_response(output_ids: list, tokenizer) -> tuple:
    """
    Parse Qwen3 response to separate thinking content from final response.
    
    Args:
        output_ids: List of generated token IDs
        tokenizer: Tokenizer instance
    
    Returns:
        (thinking_content, response_content) tuple
    """
    try:
        # Find </think> token (151668)
        index = len(output_ids) - output_ids[::-1].index(QWEN3_THINK_END_TOKEN)
    except ValueError:
        # No thinking block found
        index = 0
    
    thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    response_content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
    
    return thinking_content, response_content

def get_generation_params(enable_thinking: Optional[bool] = None, mode: str = "train"):
    """
    Get Qwen3-recommended generation parameters.
    
    Args:
        enable_thinking: If None, uses config.enable_thinking
        mode: "train" or "eval" - determines which temperature/top_p to use
    
    Returns:
        Dictionary of generation parameters
    """
    if enable_thinking is None:
        enable_thinking = config.enable_thinking
    
    if enable_thinking:
        # Qwen3 best practices for thinking mode
        temp = config.thinking_train_temperature if mode == "train" else config.thinking_eval_temperature
        top_p = config.thinking_train_top_p if mode == "train" else config.thinking_eval_top_p
    else:
        # Qwen3 best practices for non-thinking mode
        temp = config.train_temperature if mode == "train" else config.eval_temperature
        top_p = config.top_p  # Same for both train/eval in non-thinking
    
    return {
        "do_sample": True,
        "temperature": temp,
        "top_p": top_p,
        "top_k": config.top_k,
    }

print("✓ Thinking mode utilities defined")
print(f"  Thinking token ID: {QWEN3_THINK_END_TOKEN}")
print(f"  Parse thinking: {config.parse_thinking}")
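`get_generation_params` only forwards `temperature`, `top_k`, and `top_p` to `generate()`. To make their effect concrete, here is a sketch of the standard sequence (temperature scaling, then top-k, then nucleus/top-p filtering) on a toy logit vector. Transformers implements these as logits warpers whose edge-case details (for example, a minimum number of tokens to keep) may differ from this sketch; `filter_probs` is hypothetical.

```python
import math
from typing import Dict, List


def filter_probs(logits: List[float], temperature: float,
                 top_k: int, top_p: float) -> Dict[int, float]:
    """Apply temperature, then top-k, then top-p filtering to toy logits.

    Returns {token_id: probability} for the surviving tokens, renormalized to sum to 1.
    """
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]  # numerically stable softmax
    total = sum(exps)
    ranked = sorted(((i, e / total) for i, e in enumerate(exps)),
                    key=lambda t: t[1], reverse=True)

    # Top-k: keep the k most likely tokens, then renormalize
    ranked = ranked[:top_k]
    k_mass = sum(p for _, p in ranked)
    ranked = [(i, p / k_mass) for i, p in ranked]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cum = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cum += p
        if cum >= top_p:
            break
    p_mass = sum(p for _, p in kept)
    return {tok: p / p_mass for tok, p in kept}


logits = [2.0, 1.0, 0.0, -1.0]
print(filter_probs(logits, temperature=1.0, top_k=3, top_p=0.8))  # keeps tokens 0 and 1 (~0.73 / ~0.27)
# A lower temperature sharpens the distribution before any truncation
print(filter_probs(logits, 0.5, 4, 1.0)[0] > filter_probs(logits, 1.0, 4, 1.0)[0])  # True
```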

## 🛠️ Tool Use (Qwen-Agent)

**Qwen3 excels at tool calling. Use Qwen-Agent for best results.**

In [None]:
#@title Tool Use Configuration

# Check if Qwen-Agent is available
try:
    import qwen_agent
    QWEN_AGENT_AVAILABLE = True
except ImportError:
    QWEN_AGENT_AVAILABLE = False

tools = {}
qwen_agent_bot = None

if config.use_qwen_agent and QWEN_AGENT_AVAILABLE:
    print("=== Qwen-Agent Tool Use ===")
    from qwen_agent.agents import Assistant
    
    # Configure LLM for Qwen-Agent
    llm_cfg = {
        'model': config.policy_model_id,
        'model_server': 'http://localhost:8000/v1',  # Set your API endpoint
        'api_key': 'EMPTY',
        'generate_cfg': {
            'thought_in_content': True,  # Include thinking in content
        },
    }
    
    # Configure tools
    tool_list = config.qwen_agent_tools if config.qwen_agent_tools else []
    
    # Example tools:
    # - 'code_interpreter': Execute Python code
    # - MCP servers for time, fetch, etc.
    
    print(f"Configured tools: {tool_list}")
    print("✓ Qwen-Agent configured")
    
    # Note: Agent will be used during generation if needed
    # qwen_agent_bot = Assistant(llm=llm_cfg, function_list=tool_list)

elif config.use_qwen_agent and not QWEN_AGENT_AVAILABLE:
    print("⚠ Qwen-Agent requested but not installed")
    print("  Install: pip install qwen-agent")

elif config.use_tools:
    print("=== Custom Tools ===")
    # Define custom tools here
    # Example: tools["calculator"] = calculator_function
    print(f"✓ Custom tools loaded: {list(tools.keys())}")

else:
    print("⊗ Tools not used")

# Environment
environment = None
if config.use_environment:
    # TODO: Initialize your environment here (nothing is created yet)
    print("⚠ use_environment=True, but no environment is implemented yet")
else:
    print("⊗ Environment not used")
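The custom-tools branch only leaves a placeholder (`tools["calculator"] = calculator_function`), and tool calls are not wired into generation in this section. Assuming a tool is simply a Python callable that takes the model's string argument, a hypothetical calculator could look like this. It walks the parsed expression tree and allows only numeric literals and arithmetic operators, instead of calling `eval()` on model output.

```python
import ast
import operator
from typing import Callable, Dict, Union

Number = Union[int, float]

# Whitelisted operators; anything else (names, calls, attributes) is rejected
_BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
            ast.Mult: operator.mul, ast.Div: operator.truediv, ast.Pow: operator.pow}
_UNARY_OPS = {ast.USub: operator.neg, ast.UAdd: operator.pos}


def calculator(expression: str) -> Number:
    """Safely evaluate an arithmetic expression such as '3 * (4 + 5)'."""
    def _eval(node: ast.AST) -> Number:
        # type() check excludes bools, which are ints in Python
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            left, right = _eval(node.left), _eval(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > 100:
                raise ValueError("exponent too large")  # rough guard against huge powers
            return _BIN_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        raise ValueError(f"unsupported expression: {expression!r}")
    return _eval(ast.parse(expression, mode="eval").body)


# Hypothetical registration following the tools["calculator"] = ... pattern above
custom_tools: Dict[str, Callable[[str], Number]] = {"calculator": calculator}
print(custom_tools["calculator"]("3 * (4 + 5) - 2 ** 3"))  # 19
```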

## 🎁 Reward Function

**IMPORTANT: Customize this for your task!**

In [None]:
#@title Reward Function
import re

def compute_reward(prompt: str, response: str, ground_truth: Optional[str] = None) -> float:
    """
    Compute reward for a generated response.
    
    **TODO: CUSTOMIZE THIS FOR YOUR TASK!**
    
    Args:
        prompt: The input prompt
        response: Generated response
        ground_truth: Ground truth answer (if available)
    
    Returns:
        Reward score (float)
    """
    if config.reward_type == "model" and reward_model is not None:
        # Use learned reward model.
        # NOTE: this reuses the policy tokenizer. If the reward model has a different
        # vocabulary, load its own tokenizer from config.reward_model_id instead.
        with torch.no_grad():
            inputs = tokenizer(
                prompt + response,
                return_tensors="pt",
                truncation=True,
                max_length=1024
            ).to(reward_model.device)
            outputs = reward_model(**inputs)
            # Assumes a single-scalar head (num_labels == 1)
            if hasattr(outputs, 'score'):
                reward = outputs.score.item()
            else:
                reward = outputs.logits[0].item()
        return reward
    
    elif config.reward_type == "rule":
        # Rule-based reward (CUSTOMIZE THIS FOR YOUR TASK)
        reward = 0.0
        
        # Example 1: Format checking (e.g., answer in brackets)
        if re.search(r"\[.*?\]", response):
            reward += 0.5
        
        # Example 2: Exact match with ground truth (if available)
        if ground_truth is not None and config.use_ground_truth:
            # Extract predicted answer from response
            pred_match = re.search(r"\[\s*([^\]]+)\s*\]", response)
            # Extract ground truth answer (customize based on format)
            if "####" in ground_truth:
                # GSM8K format: answer after ####
                true_match = re.search(r"####\s*([^\s]+)", ground_truth)
            else:
                true_match = re.search(r"\[\s*([^\]]+)\s*\]", ground_truth)
            
            if pred_match and true_match:
                pred = pred_match.group(1).strip().replace(",", "")
                true = true_match.group(1).strip().replace(",", "")
                if pred == true:
                    reward += 1.0  # Correct answer bonus
        
        # Example 3: Length penalty (optional)
        if len(response.split()) > 500:
            reward -= 0.1
        
        return reward
    
    else:
        # Custom reward (TODO: implement your own)
        return 0.0

# Test reward function
test_response = "Let me solve step by step. 2 + 2 = 4. The answer is [4]."
test_ground_truth = "#### 4"
test_reward = compute_reward("What is 2+2?", test_response, test_ground_truth)

print("=== Reward Function ===")
print(f"Example reward: {test_reward:.4f}")
print(f"Ground truth used: {'Yes' if config.use_ground_truth else 'No'}")
print("\n⚠ WARNING: Placeholder reward function!")
print("   Customize compute_reward() for your specific task.")
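`compute_reward` takes the *first* bracketed span as the prediction, so an earlier bracket in the reasoning (for example "[1]" used as a step marker) can shadow the final answer. Below is a sketch of an alternative extractor that takes the *last* bracketed span and normalizes commas, a leading `$`, and a trailing period. The helper names are hypothetical and the normalization rules are assumptions to adapt to your task.

```python
import re
from typing import Optional

_BRACKET = re.compile(r"\[\s*([^\]]+?)\s*\]")
_GSM8K = re.compile(r"####\s*(\S+)")


def normalize_number(text: str) -> str:
    """Strip whitespace, commas, a leading '$', and a trailing '.': '$1,000.' -> '1000'."""
    return text.strip().replace(",", "").lstrip("$").rstrip(".")


def extract_prediction(response: str) -> Optional[str]:
    """Return the normalized content of the last [...] span, or None."""
    matches = _BRACKET.findall(response)
    return normalize_number(matches[-1]) if matches else None


def extract_gsm8k_answer(solution: str) -> Optional[str]:
    """Return the normalized answer after '####' in a GSM8K solution, or None."""
    m = _GSM8K.search(solution)
    return normalize_number(m.group(1)) if m else None


def exact_match(response: str, solution: str) -> bool:
    pred, gold = extract_prediction(response), extract_gsm8k_answer(solution)
    return pred is not None and pred == gold


print(exact_match("Step [1] gives 3, so the answer is [1,000].", "... #### 1,000"))  # True
```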

## 📊 Data Loading

In [None]:
#@title Load Dataset
from datasets import load_dataset

print("=== Loading Dataset ===")
if config.dataset_config:
    dataset = load_dataset(config.dataset_name, config.dataset_config, split=config.dataset_split)
else:
    dataset = load_dataset(config.dataset_name, split=config.dataset_split)

# Validate required fields exist
if config.prompt_field not in dataset.column_names:
    raise ValueError(f"Prompt field '{config.prompt_field}' not found in dataset. Available: {dataset.column_names}")
if config.use_ground_truth and config.answer_field and config.answer_field not in dataset.column_names:
    raise ValueError(f"Answer field '{config.answer_field}' not found in dataset. Available: {dataset.column_names}")

# Split into train/val: the first val_size rows become validation (no shuffling)
val_size = min(config.val_size, len(dataset))
val_dataset = dataset.select(range(val_size))
train_dataset = dataset.select(range(val_size, len(dataset)))

print(f"Train: {len(train_dataset):,} | Val: {len(val_dataset):,}")
print(f"Using fields: prompt='{config.prompt_field}', answer='{config.answer_field}'")

# Prepare prompts
train_prompts = [format_prompt(ex[config.prompt_field]) for ex in train_dataset]
val_prompts = [format_prompt(ex[config.prompt_field]) for ex in val_dataset]

print(f"\n✓ Prompts prepared")
print(f"Example prompt:\n{train_prompts[0][:200]}")
if config.answer_field and train_dataset[0].get(config.answer_field):
    print(f"\nExample answer:\n{str(train_dataset[0][config.answer_field])[:200]}")
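The split above takes the first `val_size` rows as validation without shuffling, which is only safe if the dataset order carries no structure. A seeded shuffle keeps the split reproducible. The sketch below works on indices (`shuffled_split` is hypothetical); `datasets` also offers `Dataset.shuffle(seed=...)` and `Dataset.train_test_split(test_size=..., seed=...)` for the same purpose.

```python
import random
from typing import List, Tuple


def shuffled_split(n: int, val_size: int, seed: int) -> Tuple[List[int], List[int]]:
    """Return (train_indices, val_indices) after a seeded shuffle of range(n)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG: leaves the global random state untouched
    val_size = min(val_size, n)
    return indices[val_size:], indices[:val_size]


train_idx, val_idx = shuffled_split(n=10, val_size=3, seed=42)
print(len(train_idx), len(val_idx))  # 7 3
# With a Hugging Face dataset: dataset.select(val_idx), dataset.select(train_idx)
```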

## ✅ Pre-Training Validation

**Test model generation BEFORE training to ensure everything works correctly.**

This section validates:
- Chat template formatting
- Model generation (both training and inference modes)
- Thinking mode (if enabled)
- Generation parameters

In [None]:
#@title Pre-Training Validation

print("=" * 80)
print("PRE-TRAINING VALIDATION")
print("=" * 80)

# Test prompts
test_prompts = [
    "What is 15 + 27?",
    "Solve: 3x + 5 = 20",
]

print(f"\nTesting with {len(test_prompts)} prompts")
print(f"Model: {config.policy_model_id}")
print(f"Thinking mode: {config.enable_thinking}")
print(f"Chat template: {config.use_chat_template}")

# Test generation function
def test_generation(prompt, mode="train"):
    """Test generation with given prompt."""
    formatted = format_prompt(prompt, enable_thinking=config.enable_thinking)
    
    inputs = tokenizer(
        [formatted],
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt"
    ).to(DEVICE)
    
    # Get generation params (use mode parameter)
    gen_params = get_generation_params(config.enable_thinking, mode=mode)
    gen_params.update({
        "max_new_tokens": min(config.max_new_tokens, 512),  # Use smaller for testing
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
    })
    
    policy_model.eval()
    # use_cache=True is passed to generate() directly, so config.use_cache stays False for training
    
    with torch.no_grad():
        output = policy_model.generate(**inputs, **gen_params)
    
    # Extract generated part
    generated_ids = output[0, inputs.input_ids.size(1):].tolist()
    
    # Parse thinking if enabled
    if config.enable_thinking and config.parse_thinking:
        thinking, response = parse_thinking_response(generated_ids, tokenizer)
        return thinking, response, generated_ids
    else:
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        return None, response, generated_ids

# Test each prompt
for i, prompt in enumerate(test_prompts):
    print(f"\n{'=' * 80}")
    print(f"Test {i+1}: {prompt}")
    print("-" * 80)
    
    # Test training mode
    print("\n[Training Mode Generation]")
    thinking_train, response_train, ids_train = test_generation(prompt, mode="train")
    if thinking_train:
        print(f"Thinking: {thinking_train[:200]}")
        print(f"Response: {response_train[:200]}")
    else:
        print(f"Response: {response_train[:200]}")
    print(f"Tokens generated: {len(ids_train)}")
    
    # Test eval mode
    print("\n[Eval Mode Generation]")
    thinking_eval, response_eval, ids_eval = test_generation(prompt, mode="eval")
    if thinking_eval:
        print(f"Thinking: {thinking_eval[:200]}")
        print(f"Response: {response_eval[:200]}")
    else:
        print(f"Response: {response_eval[:200]}")
    print(f"Tokens generated: {len(ids_eval)}")

print(f"\n{'=' * 80}")
print("PRE-TRAINING VALIDATION COMPLETE")
print("✓ Generation ran without errors (check the outputs above for sensible text)")
print(f"✓ Prompt formatting applied (chat template: {config.use_chat_template})")
if config.enable_thinking:
    print("✓ Thinking mode active")
print("=" * 80)

## 📊 Baseline Evaluation

Evaluate model performance before training to establish baseline metrics.

In [None]:
#@title Baseline Evaluation

def evaluate_model(model, prompts, dataset, num_samples=None, desc="Eval"):
    """
    Evaluate model on a dataset.
    
    Args:
        model: Model to evaluate
        prompts: List of formatted prompts
        dataset: Dataset with ground truth
        num_samples: Number of samples to evaluate (None = all)
        desc: Description for progress bar
    
    Returns:
        Dictionary with evaluation metrics
    """
    if num_samples is None:
        num_samples = len(prompts)
    else:
        num_samples = min(num_samples, len(prompts))
    
    model.eval()
    # use_cache=True is passed to generate() below, so config.use_cache stays False for training
    
    total_reward = 0.0
    all_rewards = []
    
    with torch.no_grad():
        for i in tqdm(range(0, num_samples, config.val_batch_size), desc=desc):
            batch_end = min(i + config.val_batch_size, num_samples)
            batch_prompts = prompts[i:batch_end]
            
            # Tokenize
            encoding = tokenizer(
                batch_prompts,
                padding=True,
                truncation=True,
                max_length=2048,
                return_tensors="pt"
            ).to(DEVICE)
            
            # Get Qwen3-appropriate generation parameters for eval
            gen_params = get_generation_params(config.enable_thinking, mode="eval")
            gen_params.update({
                "max_new_tokens": config.max_new_tokens,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "use_cache": True,
            })
            
            outputs = model.generate(**encoding, **gen_params)
            generated_ids = outputs[:, encoding.input_ids.size(1):]
            
            # Parse responses
            generated_texts = []
            for gen_ids in generated_ids:
                if config.enable_thinking and config.parse_thinking:
                    # Parse thinking content
                    _, response = parse_thinking_response(gen_ids.tolist(), tokenizer)
                    generated_texts.append(response)
                else:
                    response = tokenizer.decode(gen_ids, skip_special_tokens=True)
                    generated_texts.append(response)
            
            # Compute rewards
            for j, (prompt, response) in enumerate(zip(batch_prompts, generated_texts)):
                idx = i + j
                ground_truth = dataset[idx].get(config.answer_field) if config.answer_field else None
                reward = compute_reward(prompt, response, ground_truth)
                all_rewards.append(reward)
                total_reward += reward
    
    mean_reward = total_reward / num_samples
    std_reward = np.std(all_rewards) if len(all_rewards) > 1 else 0.0
    
    model.train()
    return {
        "mean_reward": mean_reward,
        "std_reward": std_reward,
        "num_samples": num_samples
    }

# Run baseline evaluation
print("\n" + "="*80)
print("BASELINE EVALUATION")
print("="*80)

baseline_metrics = evaluate_model(
    policy_model, 
    val_prompts, 
    val_dataset, 
    num_samples=min(100, len(val_prompts)),
    desc="Baseline"
)

print(f"\nBaseline Metrics:")
print(f"  Mean Reward: {baseline_metrics['mean_reward']:.4f} ± {baseline_metrics['std_reward']:.4f}")
print(f"  Evaluated on: {baseline_metrics['num_samples']} examples")
print("="*80)

# Save baseline
baseline_path = os.path.join(config.output_dir, "baseline_metrics.json")
with open(baseline_path, "w") as f:
    json.dump(baseline_metrics, f, indent=2)
print(f"✓ Baseline saved to {baseline_path}")
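`evaluate_model` reports the spread of rewards across examples (`np.std`), not the uncertainty of the mean. Because evaluation samples with `do_sample=True` on about 100 examples, the standard error of the mean is a more useful yardstick when comparing later validation rewards to this baseline; as a rough heuristic, gaps smaller than about two standard errors may be noise. A minimal sketch, assuming the per-example rewards (`all_rewards` inside `evaluate_model`) are made available; `mean_and_stderr` is hypothetical.

```python
import math
from statistics import mean, stdev
from typing import List, Tuple


def mean_and_stderr(rewards: List[float]) -> Tuple[float, float]:
    """Mean reward and its standard error, using the sample std (ddof=1)."""
    if not rewards:
        raise ValueError("need at least one reward")
    if len(rewards) == 1:
        return rewards[0], 0.0
    return mean(rewards), stdev(rewards) / math.sqrt(len(rewards))


m, se = mean_and_stderr([1.0, 0.0, 1.0, 0.0])
print(f"{m:.3f} ± {se:.3f}")  # 0.500 ± 0.289
```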

## 🔧 Training Utilities
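`mask_after_eos` below zeros out the first EOS token and everything after it. Whether the EOS itself should count toward the loss is a design choice: keeping it lets the policy gradient also credit the decision to stop. Here is a list-based sketch of both variants (`eos_mask` is a hypothetical name); in tensor form the inclusive variant would be `(cumsum == 0) | (is_eos & (cumsum == 1))`.

```python
from typing import List


def eos_mask(ids: List[int], eos_id: int, include_eos: bool = False) -> List[int]:
    """List version of mask_after_eos: 1 before the first EOS, 0 from it onward.

    With include_eos=True the first EOS itself is also kept.
    """
    mask, eos_seen = [], 0  # eos_seen plays the role of is_eos.cumsum(dim=1)
    for tok in ids:
        is_eos = tok == eos_id
        eos_seen += is_eos
        keep = eos_seen == 0 or (include_eos and is_eos and eos_seen == 1)
        mask.append(int(keep))
    return mask


rollout = [5, 6, 2, 7, 2]  # eos_id = 2; tokens after the first EOS are padding or garbage
print(eos_mask(rollout, eos_id=2))                    # [1, 1, 0, 0, 0]
print(eos_mask(rollout, eos_id=2, include_eos=True))  # [1, 1, 1, 0, 0]
```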

In [None]:
#@title Training Utilities

def mask_after_eos(token_ids: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Float mask: 1 for tokens strictly before the first EOS, 0 for the EOS and everything after."""
    is_eos = (token_ids == eos_id)
    eos_count = is_eos.cumsum(dim=1)  # number of EOS tokens seen so far
    return (eos_count == 0).float()

def compute_model_logprobs(model, input_ids, attention_mask, target_ids, micro_batch_size=8):
    """Compute per-token log-probabilities of target_ids given left-padded prompts.

    Returns a tensor of shape (batch, target_len).
    """
    target_len = target_ids.size(1)

    full_ids = torch.cat([input_ids, target_ids], dim=1)
    full_mask = torch.cat([attention_mask, torch.ones_like(target_ids)], dim=1)
    # Derive position ids from the mask so left padding does not shift positions;
    # generate() derives them the same way for most decoder-only models
    position_ids = (full_mask.long().cumsum(dim=1) - 1).clamp(min=0)

    # Micro-batching bounds activation memory (a single chunk if micro_batch_size >= batch)
    token_logprobs_list = []
    for i in range(0, full_ids.size(0), micro_batch_size):
        chunk = slice(i, i + micro_batch_size)
        outputs = model(
            input_ids=full_ids[chunk, :-1],
            attention_mask=full_mask[chunk, :-1],
            position_ids=position_ids[chunk, :-1],
        )
        # Logits at position t predict token t + 1, so the last target_len
        # positions score exactly the target tokens
        logits = outputs.logits[:, -target_len:, :].float()  # fp32 for a stable log-softmax
        logprobs = F.log_softmax(logits, dim=-1)
        token_logprobs_list.append(
            logprobs.gather(-1, target_ids[chunk].unsqueeze(-1)).squeeze(-1)
        )

    return torch.cat(token_logprobs_list, dim=0)

class MetricsTracker:
    """Track and display training metrics with EMA smoothing."""
    
    def __init__(self, ema_momentum=0.9):
        self.ema_momentum = ema_momentum
        self.metrics = []
        self.ema_values = {}
        empty_df = pd.DataFrame(columns=["step", "loss", "loss_ema", "reward", "reward_ema", "kl", "tokens", "val_reward"])
        self.display_handle = display(empty_df, display_id=True)
    
    def update_ema(self, key, value):
        if key not in self.ema_values:
            self.ema_values[key] = value
        else:
            self.ema_values[key] = self.ema_momentum * self.ema_values[key] + (1 - self.ema_momentum) * value
        return self.ema_values[key]
    
    def log(self, step, loss, reward, kl=0.0, tokens=0, val_reward=None):
        loss_ema = self.update_ema("loss", loss)
        reward_ema = self.update_ema("reward", reward)
        
        metric = {
            "step": step, "loss": loss, "loss_ema": loss_ema,
            "reward": reward, "reward_ema": reward_ema, "kl": kl, "tokens": tokens
        }
        
        if val_reward is not None:
            metric["val_reward"] = val_reward
        
        self.metrics.append(metric)
        
        df = pd.DataFrame(self.metrics[-100:])
        styled = df.style.format({
            "loss": "{:.4f}", "loss_ema": "{:.4f}",
            "reward": "{:.4f}", "reward_ema": "{:.4f}", "kl": "{:.4f}",
            "val_reward": lambda x: "" if pd.isna(x) else f"{x:.4f}"
        })
        self.display_handle.update(styled)
        
        return metric

print("✓ Training utilities defined")
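The two utilities above rely on alignment conventions that are easy to get wrong when modifying them. The standalone sanity check below uses made-up token ids and a random stand-in for model logits. It shows that `mask_after_eos` masks out the EOS token itself as well as everything after it, and that logits at position `t` score the token at position `t + 1`, which is why `compute_model_logprobs` drops the last input token and keeps the last `target_len` logit positions. `mask_through_eos` is a hypothetical variant, included only to show how you could keep the EOS token in the loss if you want the policy to receive credit for stopping.

```python
import torch
import torch.nn.functional as F

def mask_after_eos(token_ids: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Same logic as the utility above: 1.0 strictly before the first EOS, else 0.0."""
    eos_seen = (token_ids == eos_id).cumsum(dim=1)
    return (eos_seen == 0).float()

def mask_through_eos(token_ids: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Hypothetical variant: also keeps the first EOS token itself."""
    is_eos = (token_ids == eos_id).long()
    eos_before = is_eos.cumsum(dim=1) - is_eos  # EOS tokens strictly before each position
    return (eos_before == 0).float()

EOS = 2  # hypothetical EOS id for this toy example
generated = torch.tensor([
    [5, 7, 2, 2],   # stops after two tokens, then EOS padding
    [9, 9, 9, 9],   # never emits EOS
])
mask = mask_after_eos(generated, EOS)
print(mask.tolist())  # [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]

# Log-prob alignment: logits at position t score the token at position t + 1.
prompt = torch.tensor([[4, 6, 8]])
target = torch.tensor([[5, 7]])
full = torch.cat([prompt, target], dim=1)               # length 5
vocab_size = 10
torch.manual_seed(0)
logits = torch.randn(1, full.size(1) - 1, vocab_size)   # stand-in for model(full[:, :-1]).logits
target_logits = logits[:, -target.size(1):, :]          # positions 2 and 3 score tokens 3 and 4
token_logprobs = (
    F.log_softmax(target_logits, dim=-1)
    .gather(-1, target.unsqueeze(-1))
    .squeeze(-1)
)
print(token_logprobs.shape)  # torch.Size([1, 2])
```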

## 📈 Initialize WandB

In [None]:
#@title Initialize WandB Run
wandb_run = wandb.init(
    project=config.wandb_project,
    name=config.wandb_run_name,
    config=config.__dict__,
    job_type="training"
)

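# log="all" records parameter and gradient histograms every log_freq batches,
# which can be slow for multi-billion-parameter models; log="gradients" is lighter.
wandb.watch(policy_model, log="all", log_freq=100)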

print(f"✓ WandB initialized")
print(f"  Run: {config.wandb_run_name}")
print(f"  URL: {wandb_run.get_url()}")

## 🚀 Training Loop

In [None]:
#@title Setup Optimizer
optimizer = torch.optim.AdamW(
    policy_model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)
print(f"✓ Optimizer initialized (lr={config.learning_rate})")
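Each pass through the training loop is one micro-step: `optimizer.step()` runs only when `(step + 1) % config.grad_accumulation_steps == 0`. It can therefore help to work out the effective batch size and the number of real optimizer updates before launching a run. The helper below is a hypothetical sketch (`RunShape` and `run_shape` are not part of the notebook). It assumes `num_prompts_per_step` is not clipped by a small dataset, and the numbers in the example call are illustrative, not recommended settings.

```python
from dataclasses import dataclass

@dataclass
class RunShape:
    sequences_per_step: int    # completions generated per loop iteration
    sequences_per_update: int  # completions contributing to one optimizer.step()
    optimizer_updates: int     # number of optimizer.step() calls in the run
    dropped_micro_steps: int   # trailing iterations whose gradients are never applied

def run_shape(num_steps, num_prompts_per_step, samples_per_prompt, grad_accumulation_steps):
    """Mirror the loop's `(step + 1) % grad_accumulation_steps == 0` update rule."""
    per_step = num_prompts_per_step * samples_per_prompt
    updates = num_steps // grad_accumulation_steps
    return RunShape(
        sequences_per_step=per_step,
        sequences_per_update=per_step * grad_accumulation_steps,
        optimizer_updates=updates,
        dropped_micro_steps=num_steps - updates * grad_accumulation_steps,
    )

print(run_shape(num_steps=250, num_prompts_per_step=4, samples_per_prompt=8, grad_accumulation_steps=4))
# RunShape(sequences_per_step=32, sequences_per_update=128, optimizer_updates=62, dropped_micro_steps=2)
```

A non-zero `dropped_micro_steps` means the last few iterations accumulate gradients that are never applied; choosing `num_steps` as a multiple of `grad_accumulation_steps` avoids this.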

In [None]:
#@title Training Loop
metrics_tracker = MetricsTracker(ema_momentum=config.ema_momentum)
total_tokens = 0
best_val_reward = -float('inf')

# Reward normalization tracking (if enabled)
if config.normalize_rewards:
    reward_mean = 0.0
    reward_std = 1.0
    reward_count = 0

print("=" * 80)
print("STARTING TRAINING")
print("=" * 80)

pbar = tqdm(range(config.num_steps), desc="Training")

# Initialize gradients
optimizer.zero_grad(set_to_none=True)

for step in pbar:
    # ========== SAMPLE PROMPTS ==========
    rng = np.random.default_rng(config.seed + step)
    
    # Sample prompts (with replacement to avoid crashes on small datasets)
    num_prompts = min(config.num_prompts_per_step, len(train_dataset))
    prompt_indices = rng.choice(len(train_dataset), size=num_prompts, replace=True)
    
    # Build batch: repeat each prompt for multiple samples
    batch_prompts = []
    batch_ground_truths = []
    batch_original_indices = []
    for idx in prompt_indices:
        for _ in range(config.samples_per_prompt):
            batch_prompts.append(train_prompts[idx])
            if config.use_ground_truth and config.answer_field:
                batch_ground_truths.append(train_dataset[idx][config.answer_field])
            else:
                batch_ground_truths.append(None)
            batch_original_indices.append(idx)
    
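    # ========== TOKENIZE ==========
    # Batched generation with a decoder-only model expects left padding
    # (tokenizer.padding_side == "left"); prompts longer than 2048 tokens are truncated.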
    prompt_encoding = tokenizer(
        batch_prompts, 
        padding=True, 
        truncation=True, 
        max_length=2048, 
        return_tensors="pt"
    ).to(DEVICE)
    
    # ========== GENERATE RESPONSES (ON-POLICY) ==========
    policy_model.eval()
    policy_model.config.use_cache = True
    
    with torch.no_grad():
        # Get Qwen3-appropriate generation parameters
        gen_params = get_generation_params(config.enable_thinking)
        gen_params.update({
            "max_new_tokens": config.max_new_tokens,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        })
        
        generation_output = policy_model.generate(**prompt_encoding, **gen_params)
        generated_ids = generation_output[:, prompt_encoding.input_ids.size(1):]
    
    # Create validity mask (tokens before first EOS)
    valid_mask = mask_after_eos(generated_ids, tokenizer.eos_token_id)
    
    # Check for empty generations
    if valid_mask.sum() == 0:
        print(f"\nWarning: All sequences empty at step {step}, skipping batch")
        continue
    
    # ========== COMPUTE REWARDS ==========
    with torch.no_grad():
        # Parse responses (handle thinking if enabled)
        generated_texts = []
        for gen_ids in generated_ids:
            if config.enable_thinking and config.parse_thinking:
                # Extract only the response part (not thinking)
                _, response = parse_thinking_response(gen_ids.tolist(), tokenizer)
                generated_texts.append(response)
            else:
                response = tokenizer.decode(gen_ids, skip_special_tokens=True)
                generated_texts.append(response)
        
        rewards = []
        for prompt, response, ground_truth in zip(batch_prompts, generated_texts, batch_ground_truths):
            reward = compute_reward(prompt, response, ground_truth)
            rewards.append(reward)
        
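        rewards_tensor = torch.tensor(rewards, device=DEVICE, dtype=torch.float32)
        # Log the raw (un-normalized) reward so it stays comparable with baseline/validation rewards
        mean_reward = rewards_tensor.mean().item()
        
        # Reward normalization (if enabled): merge this batch into running mean/variance
        # statistics over all rewards seen so far, then standardize the batch
        if config.normalize_rewards:
            batch_count = rewards_tensor.numel()
            batch_mean = rewards_tensor.mean().item()
            batch_m2 = ((rewards_tensor - batch_mean) ** 2).sum().item()
            prev_m2 = (reward_std ** 2) * reward_count  # running sum of squared deviations
            new_count = reward_count + batch_count
            delta = batch_mean - reward_mean
            reward_mean += delta * batch_count / new_count
            reward_m2 = prev_m2 + batch_m2 + delta ** 2 * reward_count * batch_count / new_count
            reward_count = new_count
            reward_std = max(float(np.sqrt(reward_m2 / reward_count)), 1e-8)
            rewards_tensor = (rewards_tensor - reward_mean) / reward_std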
    
    # ========== COMPUTE LOG PROBABILITIES ==========
    policy_model.train()
    policy_model.config.use_cache = False
    
    # Policy log probs (with gradient)
    policy_logprobs = compute_model_logprobs(
        policy_model,
        prompt_encoding.input_ids,
        prompt_encoding.attention_mask,
        generated_ids,
        micro_batch_size=8
    )
    
    # ========== COMPUTE KL PENALTY (if using reference model) ==========
    kl_penalty_value = 0.0
    if config.use_kl_penalty and reference_model is not None:
        with torch.no_grad():
            reference_logprobs = compute_model_logprobs(
                reference_model,
                prompt_encoding.input_ids,
                prompt_encoding.attention_mask,
                generated_ids,
                micro_batch_size=8
            )
            # Per-token log-ratio log pi(y_t) - log pi_ref(y_t): a sample-based KL estimate
            kl_per_token = policy_logprobs.detach() - reference_logprobs
            # Average over the valid (pre-EOS) tokens of each sequence
            kl_per_sequence = (kl_per_token * valid_mask).sum(dim=1) / valid_mask.sum(dim=1).clamp(min=1.0)
            # Mean for logging
            kl_penalty_value = kl_per_sequence.mean().item()
            # Subtract KL from rewards (before broadcasting)
            rewards_tensor = rewards_tensor - config.kl_coef * kl_per_sequence
    
    # ========== COMPUTE ADVANTAGES ==========
    # Broadcast rewards to token level
    rewards_broadcast = rewards_tensor.unsqueeze(1).expand_as(policy_logprobs)
    
    # Baseline: a single mean reward over the whole batch (all prompts and samples)
    baseline = rewards_tensor.mean()
    advantages = rewards_broadcast - baseline
    
    # Normalize advantages (recommended for stability)
    if config.normalize_advantages:
        adv_mean = (advantages * valid_mask).sum() / valid_mask.sum().clamp(min=1.0)
        adv_std = torch.sqrt(((advantages - adv_mean) ** 2 * valid_mask).sum() / valid_mask.sum().clamp(min=1.0))
        advantages = (advantages - adv_mean) / (adv_std + 1e-8)
    
    # ========== COMPUTE LOSS ==========
    loss_per_token = -advantages.detach() * policy_logprobs * valid_mask
    loss = loss_per_token.sum() / valid_mask.sum().clamp(min=1.0)
    
    # Scale loss for gradient accumulation
    if config.grad_accumulation_steps > 1:
        loss = loss / config.grad_accumulation_steps
    
    # ========== BACKWARD PASS ==========
    loss.backward()
    
    # Gradient clipping and optimizer step
    if (step + 1) % config.grad_accumulation_steps == 0:
        if config.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), config.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    
    # ========== LOGGING ==========
    total_tokens += int(valid_mask.sum().item())
    
    # Validation (if enabled)
    val_reward = None
    if (step % config.val_every == 0 and step > 0) or (step == config.num_steps - 1):
        val_metrics = evaluate_model(
            policy_model,
            val_prompts,
            val_dataset,
            num_samples=min(50, len(val_prompts)),
            desc=f"Val@{step}"
        )
        val_reward = val_metrics['mean_reward']
        
        # Track best model
        if val_reward > best_val_reward:
            best_val_reward = val_reward
            best_dir = os.path.join(config.output_dir, "best_model")
            os.makedirs(best_dir, exist_ok=True)
            policy_model.save_pretrained(best_dir)
            tokenizer.save_pretrained(best_dir)
            print(f"\n✓ New best model (reward={val_reward:.4f}) saved to {best_dir}")
    
    # Log metrics (undo the accumulation scaling so the logged loss is per step;
    # passing val_reward lets the live table show it as well)
    metric = metrics_tracker.log(
        step=step,
        loss=loss.item() * config.grad_accumulation_steps,
        reward=mean_reward,
        kl=kl_penalty_value,
        tokens=total_tokens,
        val_reward=val_reward,
    )
    
    wandb.log(metric, step=step)
    
    # Update progress bar
    pbar.set_postfix({
        "loss": f"{metric['loss']:.3f}",
        "reward": f"{mean_reward:.3f}",
        "kl": f"{kl_penalty_value:.3f}",
        **({"val": f"{val_reward:.3f}"} if val_reward is not None else {})
    })
    
    # ========== CHECKPOINTING ==========
    if (step % config.save_every == 0 and step > 0) or (step == config.num_steps - 1):
        checkpoint_dir = os.path.join(config.output_dir, f"checkpoint-{step}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        policy_model.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)
        print(f"\n✓ Checkpoint saved: {checkpoint_dir}")
        
        artifact = wandb.Artifact(name=f"model-checkpoint-{step}", type="model")
        artifact.add_dir(checkpoint_dir)
        wandb.log_artifact(artifact)
    
    # Clean up
    del generated_ids, generated_texts, rewards_tensor, policy_logprobs
    if config.use_kl_penalty and reference_model is not None:
        del reference_logprobs
    torch.cuda.empty_cache()

print("=" * 80)
print("TRAINING COMPLETE")
print(f"Best validation reward: {best_val_reward:.4f}")
print("=" * 80)
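The optional reward normalization in the loop keeps statistics across steps, which requires a running *variance* as well as a running mean. One standard way to merge a new batch into running statistics is the parallel (batched) form of Welford's algorithm: combine the counts, shift the mean by the weighted difference of means, and add a cross term to the sum of squared deviations. The pure-Python sketch below (`RunningMoments` is a hypothetical name) implements that merge and can be checked against the population standard deviation of every reward seen so far.

```python
import math
from statistics import fmean

class RunningMoments:
    """Running mean and population variance over every value seen so far."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, batch):
        n_b = len(batch)
        if n_b == 0:
            return
        mean_b = fmean(batch)
        m2_b = sum((x - mean_b) ** 2 for x in batch)
        total = self.count + n_b
        delta = mean_b - self.mean
        self.mean += delta * n_b / total
        # Cross term accounts for the two partial means differing
        self.m2 += m2_b + delta ** 2 * self.count * n_b / total
        self.count = total

    @property
    def std(self):
        return math.sqrt(self.m2 / self.count) if self.count else 1.0

    def normalize(self, batch, eps=1e-8):
        return [(x - self.mean) / max(self.std, eps) for x in batch]

stats = RunningMoments()
history = []
for batch in ([1.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.5]):
    stats.update(batch)        # update first, then standardize the same batch
    history.extend(batch)
    normalized = stats.normalize(batch)
```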

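The loop's advantage step uses one baseline, the mean reward of the whole batch, even though it generates `config.samples_per_prompt` completions per prompt and stores them contiguously. A common alternative (used, for example, in GRPO) is to subtract each prompt's own group mean and optionally divide by the group standard deviation, so prompts that are uniformly easy or hard contribute no gradient. The sketch below assumes the same contiguous layout as `batch_prompts`; `group_advantages` is a hypothetical helper, not part of the notebook.

```python
import torch

def group_advantages(rewards: torch.Tensor, samples_per_prompt: int,
                     scale_by_std: bool = True, eps: float = 1e-6) -> torch.Tensor:
    """Per-sequence advantages using each prompt's own mean reward as the baseline.

    Assumes rewards are ordered [p0_s0, p0_s1, ..., p1_s0, p1_s1, ...], i.e. all
    samples of a prompt are contiguous, matching the loop's batch layout.
    """
    groups = rewards.view(-1, samples_per_prompt)          # (num_prompts, samples_per_prompt)
    centered = groups - groups.mean(dim=1, keepdim=True)   # subtract the group mean
    if scale_by_std:
        group_std = centered.pow(2).mean(dim=1, keepdim=True).sqrt()
        centered = centered / (group_std + eps)            # identical-reward groups stay at 0
    return centered.reshape(-1)                            # back to (batch_size,)

rewards = torch.tensor([1.0, 0.0, 1.0, 0.0,    # prompt A: mixed outcomes
                        1.0, 1.0, 1.0, 1.0])   # prompt B: every sample scores the same
adv = group_advantages(rewards, samples_per_prompt=4)
```

To try it in the loop, one option is `advantages = group_advantages(rewards_tensor, config.samples_per_prompt).unsqueeze(1).expand_as(policy_logprobs)` in place of the batch-mean baseline; with `scale_by_std=True` you may also want to disable `config.normalize_advantages` to avoid normalizing twice.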
## 💾 Save Final Model

In [None]:
#@title Save Final Model
final_dir = os.path.join(config.output_dir, "final_model")
os.makedirs(final_dir, exist_ok=True)

policy_model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"✓ Final model saved to: {final_dir}")

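# Merge LoRA adapters if used. merge_and_unload() folds them into the base
# weights of policy_model in place, so do not continue training after this cell.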
if config.use_lora:
    print("Merging LoRA adapters...")
    merged_model = policy_model.merge_and_unload()
    merged_dir = os.path.join(config.output_dir, "final_model_merged")
    os.makedirs(merged_dir, exist_ok=True)
    merged_model.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)
    print(f"✓ Merged model saved to: {merged_dir}")

# Push to Hub if configured
if config.push_to_hub and config.hub_repo_id:
    print(f"Pushing to Hub: {config.hub_repo_id}")
    model_to_push = merged_model if config.use_lora else policy_model
    model_to_push.push_to_hub(repo_id=config.hub_repo_id, private=True)
    tokenizer.push_to_hub(repo_id=config.hub_repo_id, private=True)
    print(f"✓ Pushed to Hub")

# Log to WandB
final_artifact = wandb.Artifact(name="final-model", type="model")
final_artifact.add_dir(final_dir)
wandb.log_artifact(final_artifact)
print(f"✓ Logged to WandB: {wandb_run.get_url()}")

## 📊 Results Summary

In [None]:
#@title Training Summary
# Save metrics
metrics_df = pd.DataFrame(metrics_tracker.metrics)
metrics_df.to_csv(os.path.join(config.output_dir, "training_metrics.csv"), index=False)

# Summary
summary = {
    "baseline_reward": baseline_metrics["mean_reward"],
    "final_loss": metrics_tracker.metrics[-1]["loss"],
    "final_loss_ema": metrics_tracker.metrics[-1]["loss_ema"],
    "final_reward": metrics_tracker.metrics[-1]["reward"],
    "final_reward_ema": metrics_tracker.metrics[-1]["reward_ema"],
    "best_val_reward": best_val_reward,
    "total_tokens": total_tokens,
    "total_steps": config.num_steps
}

with open(os.path.join(config.output_dir, "summary.json"), "w") as f:
    json.dump(summary, f, indent=2)

wandb.summary.update(summary)

print("=" * 80)
print("TRAINING SUMMARY")
print("=" * 80)
print(f"Baseline Reward: {summary['baseline_reward']:.4f}")
print(f"Final Loss: {summary['final_loss']:.4f}")
print(f"Final Loss (EMA): {summary['final_loss_ema']:.4f}")
print(f"Final Reward: {summary['final_reward']:.4f}")
print(f"Final Reward (EMA): {summary['final_reward_ema']:.4f}")
print(f"Best Val Reward: {summary['best_val_reward']:.4f}")
print(f"Total Tokens: {summary['total_tokens']:,}")
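# The baseline used up to 100 validation examples and in-loop validation up to 50,
# so treat this difference as a rough indication rather than a controlled comparison
print(f"\nImprovement (best val - baseline): {summary['best_val_reward'] - summary['baseline_reward']:.4f}")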
print(f"Output: {config.output_dir}")
print(f"WandB: {wandb_run.get_url()}")
print("=" * 80)

In [None]:
#@title Finish WandB Run
wandb.finish()
print("✓ WandB run finished")

## 🎉 Done!

Your RL training is complete. You can now:

1. **Evaluate your model** on test data
2. **Inspect the metrics** in the WandB dashboard
3. **Load the trained model** from `best_model`, `final_model`, or a `checkpoint-*` directory under `config.output_dir`
4. **Customize** any section for your specific use case

### Next Steps:

- **Customize the reward function** in the "Reward Function" section
- **Add validation metrics** specific to your task
- **Implement more advanced algorithms** (e.g. PPO, or offline preference methods such as DPO)
- **Tune hyperparameters** at the top of the notebook
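The training loop calls `compute_reward(prompt, response, ground_truth)` and expects a float back; `ground_truth` is `None` unless `config.use_ground_truth` and `config.answer_field` are set. As a hypothetical starting point for a math-style dataset, an exact-match reward on the last number in the response might look like the sketch below. The regex, the comma stripping, and the 0/1 scoring are all assumptions, not part of the notebook.

```python
import re
from typing import Optional

_NUMBER = re.compile(r"-?\d+(?:\.\d+)?")

def extract_last_number(text: str) -> Optional[str]:
    """Return the last integer or decimal in `text` (thousands separators removed)."""
    matches = _NUMBER.findall(text.replace(",", ""))
    return matches[-1] if matches else None

def exact_match_reward(prompt: str, response: str, ground_truth: Optional[str]) -> float:
    """Hypothetical drop-in for compute_reward: 1.0 if the last numbers match, else 0.0."""
    if ground_truth is None:
        return 0.0
    predicted = extract_last_number(response)
    expected = extract_last_number(str(ground_truth))
    if predicted is None or expected is None:
        return 0.0
    return 1.0 if float(predicted) == float(expected) else 0.0
```

Binary rewards like this pair well with several samples per prompt, since a prompt only yields a useful learning signal when its samples disagree.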

### Common Modifications:

1. **Change model**: Edit `config.policy_model_id`
2. **Change dataset**: Edit `config.dataset_name` and the related field names (e.g. `config.answer_field`)
3. **Change algorithm**: Modify the loss computation in the training loop
4. **Use reward model**: Set `config.use_reward_model=True`
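For item 3, note that the current loss is plain REINFORCE: `-advantage * log pi(token)`, averaged over valid tokens. If you reuse each sampled batch for several optimizer updates, as PPO does, the usual modification is the clipped surrogate objective, which additionally needs the log-probs recorded at sampling time (`old_logprobs`, detached). The sketch below is a minimal version under those assumptions; `clipped_policy_loss` is a hypothetical name and `clip_eps=0.2` is a common default, not a value taken from this notebook.

```python
import torch

def clipped_policy_loss(new_logprobs, old_logprobs, advantages, valid_mask, clip_eps=0.2):
    """PPO-style clipped surrogate loss, averaged over valid tokens.

    new_logprobs: (B, T) log-probs under the current policy (requires grad)
    old_logprobs: (B, T) log-probs recorded when the samples were generated
    advantages:   (B, T) per-token advantages (treated as constants)
    valid_mask:   (B, T) 1.0 for tokens that count toward the loss
    """
    ratio = torch.exp(new_logprobs - old_logprobs.detach())
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    per_token = -torch.min(unclipped, clipped)  # pessimistic bound, negated for minimization
    return (per_token * valid_mask).sum() / valid_mask.sum().clamp(min=1.0)

# On the first update over fresh samples, new == old, so the ratio is 1 and the
# gradient matches the REINFORCE loss used in the training loop.
logp = torch.tensor([[-1.0, -2.0, -0.5]], requires_grad=True)
adv = torch.tensor([[1.0, 1.0, 1.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
loss = clipped_policy_loss(logp, logp.detach(), adv, mask)
loss.backward()
```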