# RL Training Boilerplate

This notebook provides a modular boilerplate for Reinforcement Learning training of language models using PyTorch directly (no verl).

**Features:**
- All hyperparameters at the top for easy configuration
- Modular sections: model loading, tools, chat templates, environment, reward model
- WandB integration for logging and tracking
- Easy to swap models and modify components
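- Training loop written directly in PyTorch (Hugging Face libraries are used only for models, tokenizers, LoRA, and data loading)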

**Requirements:**
- GPU with sufficient VRAM (A100 recommended)
- WandB account for logging

## 🛠️ Setup & Dependencies

In [None]:
#@title Install Dependencies
!nvidia-smi -L || true

import sys
print("Python:", sys.version)

# Install required packages
try:
    %pip install -q transformers>=4.51.3 accelerate>=1.4.0 peft>=0.14.0 \
                     datasets>=3.3.2 torch wandb huggingface_hub \
                     sentencepiece protobuf tqdm matplotlib pandas
except Exception:
    !pip install -q transformers>=4.51.3 accelerate>=1.4.0 peft>=0.14.0 \
                     datasets>=3.3.2 torch wandb huggingface_hub \
                     sentencepiece protobuf tqdm matplotlib pandas

import os, random, time, json, platform
import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import pandas as pd
from IPython.display import display

print("\n=== Environment ===")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "Please connect a GPU for RL training."

## 🔑 WandB & HuggingFace Authentication

In [None]:
#@title Set API Keys (Optional - fill in if needed)
import os

# Paste credentials here, or leave them blank to rely on whatever is already
# exported in the environment. Blank values never overwrite the environment.
WANDB_API_KEY = ""  # Your WandB API key - https://wandb.ai/authorize
HF_TOKEN = ""       # Your HuggingFace token - https://huggingface.co/settings/tokens

for _env_var, _secret in (("WANDB_API_KEY", WANDB_API_KEY), ("HF_TOKEN", HF_TOKEN)):
    if _secret:
        os.environ[_env_var] = _secret

In [None]:
#@title Login to WandB and HuggingFace
import wandb
from huggingface_hub import login

# WandB login. On failure, switch WandB to disabled mode so the later
# wandb.init / wandb.log / artifact calls become no-ops instead of failing or
# prompting for a key -- otherwise the "continue without WandB" promise below
# would not hold.
try:
    wandb.login()
    print("✓ WandB login successful")
except Exception as e:
    print(f"⚠ WandB login failed: {e}")
    print("Training will continue without WandB logging")
    os.environ["WANDB_MODE"] = "disabled"

# HuggingFace login, attempted only when a token is available.
hf_token = os.environ.get('HF_TOKEN')
if hf_token:
    try:
        login(token=hf_token)
        print("✓ HuggingFace login successful")
    except Exception as e:
        print(f"⚠ HuggingFace login failed: {e}")
else:
    print("⊗ No HF_TOKEN set; skipping HuggingFace login")

## ⚙️ HYPERPARAMETERS

**All configurable parameters are in this section for easy modification.**

In [None]:
#@title Hyperparameters Configuration
from dataclasses import dataclass
from typing import Optional, List

@dataclass
class RLConfig:
    """Every hyperparameter for one RL training run.

    Edit the defaults below or pass keyword overrides, e.g.
    ``RLConfig(num_steps=200, learning_rate=5e-5)``. ``__post_init__`` fills in
    derived defaults (output dir, run name, reference model id, LoRA target
    modules), validates the integer loop settings, and creates ``output_dir``.

    Raises:
        ValueError: if a loop setting is out of range (e.g. grad_accumulation_steps=0,
            which would otherwise cause a ZeroDivisionError inside the training loop).
    """

    # ==================== MODEL CONFIGURATION ====================
    # Policy model (the model being trained)
    policy_model_id: str = "Qwen/Qwen3-0.6B-Base"
    policy_model_dtype: str = "bfloat16"  # "bfloat16", "float16", or "float32"

    # Reference model (for KL penalty, optional)
    use_reference_model: bool = False
    reference_model_id: Optional[str] = None  # If None, uses policy_model_id

    # Reward model (optional; also set reward_type="model" for it to be used)
    use_reward_model: bool = False
    reward_model_id: Optional[str] = None

    # ==================== LORA CONFIGURATION ====================
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # If None: attention + MLP projections (q/k/v/o, gate/up/down), see __post_init__
    lora_target_modules: Optional[List[str]] = None

    # ==================== TRAINING CONFIGURATION ====================
    num_steps: int = 100
    batch_size: int = 4           # distinct prompts per step
    samples_per_prompt: int = 4   # completions sampled per prompt

    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    grad_accumulation_steps: int = 1
    max_grad_norm: float = 1.0    # <= 0 disables gradient clipping

    max_new_tokens: int = 256
    train_temperature: float = 0.7
    eval_temperature: float = 0.0  # not used by the training loop in this notebook
    top_p: float = 0.9

    # ==================== RL ALGORITHM ====================
    # The training loop in this notebook implements only "pg" (REINFORCE with a
    # batch-mean baseline); other values merely label runs that use a custom loop.
    algorithm: str = "pg"  # "pg", "ppo", "dpo", "custom"
    use_kl_penalty: bool = False  # has no effect unless use_reference_model=True
    kl_coef: float = 0.1

    # ==================== DATA CONFIGURATION ====================
    dataset_name: str = "openai/gsm8k"
    dataset_config: Optional[str] = "main"
    dataset_split: str = "train"
    prompt_field: str = "question"
    answer_field: Optional[str] = "answer"

    val_size: int = 200   # first `val_size` rows are held out for validation
    val_every: int = 10   # not used by the training loop in this notebook

    # ==================== PROMPT TEMPLATE ====================
    prompt_template: str = (
        "Solve step by step.\n"
        "Problem: {prompt}\n\nSolution:"
    )
    use_chat_template: bool = False
    system_prompt: Optional[str] = None

    # ==================== REWARD CONFIGURATION ====================
    reward_type: str = "rule"  # "model", "rule", or "custom"

    # ==================== TOOLS & ENVIRONMENT ====================
    use_tools: bool = False
    use_environment: bool = False

    # ==================== LOGGING & CHECKPOINTING ====================
    wandb_project: str = "rl-training"
    wandb_run_name: Optional[str] = None
    log_every: int = 10   # not used by the training loop (it logs every step)
    save_every: int = 50
    ema_momentum: float = 0.9

    # If None: "./run_rl_<unix time>", resolved per instance in __post_init__.
    output_dir: Optional[str] = None
    push_to_hub: bool = False
    hub_repo_id: Optional[str] = None

    seed: int = 42

    def __post_init__(self):
        # One timestamp so the output dir and the run name agree.
        timestamp = int(time.time())
        if self.output_dir is None:
            # Computed here rather than as a class-level f-string default, which
            # would be frozen at class-definition time and shared by all instances.
            self.output_dir = f"./run_rl_{timestamp}"
        if self.wandb_run_name is None:
            model_short = self.policy_model_id.split("/")[-1]
            self.wandb_run_name = f"rl_{self.algorithm}_{model_short}_{timestamp}"
        if self.use_reference_model and self.reference_model_id is None:
            self.reference_model_id = self.policy_model_id
        if self.use_lora and self.lora_target_modules is None:
            self.lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

        # Loop settings used as counts / modulo divisors in the training loop.
        for name, minimum in (
            ("num_steps", 0),
            ("batch_size", 1),
            ("samples_per_prompt", 1),
            ("grad_accumulation_steps", 1),
            ("save_every", 1),
        ):
            value = getattr(self, name)
            if value < minimum:
                raise ValueError(f"{name} must be >= {minimum}, got {value!r}")

        if self.use_kl_penalty and not self.use_reference_model:
            import warnings
            warnings.warn(
                "use_kl_penalty=True has no effect unless use_reference_model=True",
                stacklevel=3,
            )

        os.makedirs(self.output_dir, exist_ok=True)

# Build the run configuration (edit RLConfig defaults above, or pass overrides here)
config = RLConfig()

print("\n=== RL Training Configuration ===")
for _label, _value in (
    ("Policy Model", config.policy_model_id),
    ("Algorithm", config.algorithm),
    ("Steps", config.num_steps),
    ("Batch Size", config.batch_size),
    ("Learning Rate", config.learning_rate),
    ("Output Dir", config.output_dir),
):
    print(f"{_label}: {_value}")

# Persist the resolved settings next to the run's outputs for reproducibility
serializable_config = {key: val for key, val in vars(config).items() if not callable(val)}
with open(os.path.join(config.output_dir, "config.json"), "w") as f:
    json.dump(serializable_config, f, indent=2)
print(f"✓ Config saved")

## 🎲 Set Random Seed

In [None]:
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces cuDNN into deterministic mode and disables its autotuner, trading
    some speed for run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_seed(seed=config.seed)  # seed every RNG before any model/data work
print(f"✓ Random seed set to {config.seed}")

## 📦 Model Loading

This section loads the policy model, reference model (if needed), and reward model (if needed).

In [None]:
#@title Model Loading Utilities
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

def get_torch_dtype(dtype_str: str):
    """Map a dtype name from the config to a ``torch.dtype``.

    Accepts the documented names ("bfloat16", "float16", "float32"), common
    aliases ("bf16", "fp16", "half", "fp32", "float"; case-insensitive), or a
    ``torch.dtype``, which is returned unchanged.

    Raises:
        ValueError: for an unrecognised name. Silently falling back to float32
            (the previous behaviour) turned typos such as "bf16" into a model
            loaded at twice the intended memory footprint.
    """
    if isinstance(dtype_str, torch.dtype):
        return dtype_str
    aliases = {
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
        "float": torch.float32,
    }
    key = str(dtype_str).strip().lower()
    if key not in aliases:
        raise ValueError(
            f"Unsupported dtype {dtype_str!r}; expected one of 'bfloat16', 'float16', 'float32'"
        )
    return aliases[key]

def load_tokenizer(model_id: str):
    """Load a fast tokenizer prepared for batched generation.

    Ensures a pad token exists (reusing EOS, or the literal "[PAD]" string when
    the tokenizer has no EOS either) and pads on the left, so that every prompt in
    a batch ends at the same position where generation starts.
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token or "[PAD]"
    tok.padding_side = "left"
    return tok

def load_causal_lm(model_id: str, dtype_str: str, use_lora: bool = False, lora_config=None):
    """Load a causal LM with ``device_map="auto"``, optionally wrapped with LoRA.

    Args:
        model_id: Hub id or local path.
        dtype_str: dtype name understood by ``get_torch_dtype``.
        use_lora: apply ``lora_config`` via PEFT when True.
        lora_config: a PEFT ``LoraConfig``; LoRA is skipped when it is falsy.

    Returns:
        The loaded model, PEFT-wrapped when LoRA was applied.
    """
    torch_dtype = get_torch_dtype(dtype_str)
    print(f"Loading model: {model_id}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    if not (use_lora and lora_config):
        return base_model

    print(f"Applying LoRA...")
    peft_model = get_peft_model(base_model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model

print("✓ Model loading utilities defined")

In [None]:
#@title Load Tokenizer
print("=== Loading Tokenizer ===")
tokenizer = load_tokenizer(model_id=config.policy_model_id)
vocab_size = len(tokenizer)
print(f"✓ Tokenizer loaded (vocab: {vocab_size})")

In [None]:
#@title Load Policy Model
print("=== Loading Policy Model ===")

# LoRA adapter spec; only built when LoRA is enabled in the config.
lora_config = (
    LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=config.lora_target_modules,
    )
    if config.use_lora
    else None
)

policy_model = load_causal_lm(
    config.policy_model_id,
    config.policy_model_dtype,
    use_lora=config.use_lora,
    lora_config=lora_config,
)
# KV cache off by default; the training loop enables it only while sampling.
policy_model.config.use_cache = False
print(f"✓ Policy model loaded")

In [None]:
#@title Load Reference Model (optional)
reference_model = None
if not config.use_reference_model:
    print("⊗ Reference model not used")
else:
    print("=== Loading Reference Model ===")
    reference_model = load_causal_lm(config.reference_model_id, config.policy_model_dtype)
    # Frozen copy used only for KL estimates: eval mode and no gradients.
    reference_model.eval()
    reference_model.requires_grad_(False)
    print("✓ Reference model loaded")

In [None]:
#@title Load Reward Model (optional)
reward_model = None
if config.use_reward_model and config.reward_model_id:
    print("=== Loading Reward Model ===")
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    # Frozen scorer: inference mode, no gradients.
    reward_model.eval()
    for param in reward_model.parameters():
        param.requires_grad_(False)
    print("✓ Reward model loaded")
    if config.reward_type != "model":
        # compute_reward() only consults the reward model when reward_type == "model".
        print(f"⚠ reward_type={config.reward_type!r}: the loaded reward model will be ignored "
              "(set reward_type='model' to use it)")
else:
    if config.use_reward_model:
        print("⚠ use_reward_model=True but reward_model_id is not set; no reward model loaded")
    if config.reward_type == "model":
        # Without a loaded model, compute_reward() falls through to its final branch (0.0).
        print("⚠ reward_type='model' but no reward model is loaded; every reward will be 0.0")
    print(f"⊗ Reward model not used (reward_type={config.reward_type!r})")

## 💬 Chat Template

In [None]:
#@title Prompt Formatting
def format_prompt(prompt: str) -> str:
    """Turn a raw dataset prompt into the text the policy is conditioned on.

    Without ``config.use_chat_template``, ``config.prompt_template`` is filled in
    via its ``{prompt}`` placeholder. With it, the tokenizer's chat template is
    applied to an optional system message plus the user turn, with the generation
    prompt appended. If the tokenizer defines no chat template (typical for base
    models), a plain "System:/User:/Assistant:" transcript is used instead.
    """
    if not config.use_chat_template:
        return config.prompt_template.format(prompt=prompt)

    messages = []
    if config.system_prompt:
        messages.append({"role": "system", "content": config.system_prompt})
    messages.append({"role": "user", "content": prompt})

    # Every modern tokenizer has an apply_chat_template *method*, but it raises
    # when no template is configured, so test for the template itself.
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    system_prefix = f"System: {config.system_prompt}\n\n" if config.system_prompt else ""
    return f"{system_prefix}User: {prompt}\n\nAssistant:"

# Smoke-test the formatter on a toy question and preview the result.
test_prompt = format_prompt(prompt="What is 2+2?")
print("=== Prompt Formatting ===")
preview = test_prompt[:200]
print(f"Example:\n{preview}")

## 🛠️ Tools & Environment

In [None]:
#@title Tools Configuration
# Registry of callable tools keyed by name (empty until you register some).
tools = {}
if not config.use_tools:
    print("⊗ Tools not used")
else:
    # TODO: Add your tools here, e.g. tools["calculator"] = calculator_function
    print(f"✓ Tools loaded: {list(tools)}")

environment = None
if not config.use_environment:
    print("⊗ Environment not used")
else:
    # TODO: Initialize environment here
    print("✓ Environment initialized")

## 🎁 Reward Function

**IMPORTANT: Customize this for your task!**

In [None]:
#@title Reward Function
import re

# Reward-model tokenizers keyed by model id, loaded on first use.
_REWARD_TOKENIZERS = {}


def _get_reward_tokenizer(model_id: str):
    """Return the tokenizer that belongs to the reward model, loading it only once."""
    if model_id not in _REWARD_TOKENIZERS:
        _REWARD_TOKENIZERS[model_id] = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    return _REWARD_TOKENIZERS[model_id]


def compute_reward(prompt: str, response: str, ground_truth: Optional[str] = None) -> float:
    """
    Compute reward for a generated response.

    **TODO: CUSTOMIZE THIS FOR YOUR TASK!**

    Args:
        prompt: the formatted prompt the policy was given.
        response: the decoded completion.
        ground_truth: optional reference answer (ignored by the placeholder rules).

    Returns:
        A scalar reward, depending on ``config.reward_type``:
        - "model" with a loaded reward model: the model's last logit for
          ``prompt + response``, tokenized with the reward model's own tokenizer.
        - "rule": +0.5 if the response contains a bracketed span such as "[4]",
          -0.1 if it has more than 500 whitespace-separated words.
        - anything else (including "model" with no loaded model): 0.0.

    This is a placeholder. Replace with your actual reward logic:
    - Call a learned reward model
    - Rule-based scoring (length, format, keywords)
    - Verifiable correctness (math, code execution)
    - Task-specific metrics (BLEU, ROUGE, exact match)
    """
    if config.reward_type == "model" and reward_model is not None:
        # Tokenize with the reward model's OWN tokenizer: ids produced by the
        # policy tokenizer index a different vocabulary whenever the models differ.
        rm_tokenizer = _get_reward_tokenizer(config.reward_model_id or reward_model.name_or_path)
        with torch.no_grad():
            inputs = rm_tokenizer(
                prompt + response, return_tensors="pt", truncation=True, max_length=1024
            ).to(reward_model.device)
            outputs = reward_model(**inputs)
            reward = outputs.logits[0, -1].item()
        return reward

    elif config.reward_type == "rule":
        # Rule-based reward (CUSTOMIZE THIS)
        reward = 0.0
        # Example: check for answer format
        if re.search(r"\[.*?\]", response):
            reward += 0.5
        # Example: length penalty
        if len(response.split()) > 500:
            reward -= 0.1
        return reward

    else:
        # Custom reward (TODO: implement)
        return 0.0

# Smoke-test the reward function on a toy prompt/response pair.
test_reward = compute_reward(prompt="What is 2+2?", response="The answer is [4].")
print(f"=== Reward Function ===")
print(f"Example reward: {test_reward:.4f}")
for _warning_line in ("\n⚠ WARNING: Placeholder reward function!",
                      "   Customize compute_reward() for your task."):
    print(_warning_line)

## 📊 Data Loading

In [None]:
#@title Load Dataset
from datasets import load_dataset

print("=== Loading Dataset ===")
if config.dataset_config:
    dataset = load_dataset(config.dataset_name, config.dataset_config, split=config.dataset_split)
else:
    dataset = load_dataset(config.dataset_name, split=config.dataset_split)

# Split into train/val: the first `val_size` rows are held out for validation.
val_size = min(config.val_size, len(dataset))
val_dataset = dataset.select(range(val_size))
train_dataset = dataset.select(range(val_size, len(dataset)))

print(f"Train: {len(train_dataset):,} | Val: {len(val_dataset):,}")

# The training loop samples `batch_size` distinct prompts per step without
# replacement, so fail here with a clear message instead of deep inside the loop.
if len(train_dataset) < config.batch_size:
    raise ValueError(
        f"Only {len(train_dataset)} training examples remain after holding out "
        f"{val_size} for validation, but batch_size={config.batch_size}. "
        "Reduce val_size/batch_size or use a larger split."
    )

# Prepare prompts
train_prompts = [format_prompt(ex[config.prompt_field]) for ex in train_dataset]
val_prompts = [format_prompt(ex[config.prompt_field]) for ex in val_dataset]

# Reference answers aligned index-for-index with the prompts, for use as the
# `ground_truth` argument of a task-specific reward (None when not configured).
if config.answer_field:
    train_answers = list(train_dataset[config.answer_field])
    val_answers = list(val_dataset[config.answer_field])
else:
    train_answers = [None] * len(train_prompts)
    val_answers = [None] * len(val_prompts)

print(f"\n✓ Prompts prepared")
print(f"Example:\n{train_prompts[0][:200]}")

## 🔧 Training Utilities

In [None]:
#@title Training Utilities

def mask_after_eos(token_ids: torch.Tensor, eos_id: int, include_eos: bool = False) -> torch.Tensor:
    """Create a 0/1 mask that keeps, per row, the tokens before the first EOS.

    Args:
        token_ids: integer tensor whose dim 1 is the sequence axis.
        eos_id: a single EOS token id, a list/tuple/set of ids (any of them ends
            the sequence), or None (no EOS known, so every token is kept).
        include_eos: also keep the first EOS token itself, e.g. so a policy-gradient
            loss can credit the decision to stop. Defaults to False (EOS excluded).

    Returns:
        float32 tensor shaped like ``token_ids``; 1.0 = keep, 0.0 = masked out.
    """
    if eos_id is None:
        return torch.ones_like(token_ids, dtype=torch.float32)
    if isinstance(eos_id, (list, tuple, set, frozenset)):
        eos_ids = torch.tensor(sorted(eos_id), dtype=token_ids.dtype, device=token_ids.device)
        is_eos = torch.isin(token_ids, eos_ids)
    else:
        is_eos = token_ids == eos_id
    is_eos = is_eos.long()
    # Running count of EOS tokens seen up to and including each position.
    eos_seen = is_eos.cumsum(dim=1)
    if include_eos:
        # Count only EOS tokens strictly before each position, so the first EOS survives.
        eos_seen = eos_seen - is_eos
    return (eos_seen == 0).float()

def compute_model_logprobs(model, input_ids, attention_mask, target_ids, micro_batch_size=8):
    """Compute log probabilities of target tokens given (left-padded) prompts.

    The prompt and target are concatenated and scored with teacher forcing. The
    logits at the position just before each target token give its log-probability.

    Args:
        model: causal LM whose forward accepts ``input_ids``, ``attention_mask`` and
            ``position_ids`` and returns an object with ``.logits``.
        input_ids, attention_mask: prompt ids and mask, shape (B, P); presumably
            left-padded as produced by the tokenizer used for generation.
        target_ids: continuation token ids, shape (B, T).
        micro_batch_size: rows per forward pass, to bound activation memory;
            None or <= 0 means a single pass over the whole batch.

    Returns:
        float32 tensor of shape (B, T) with log p(target_t | prefix).
    """
    batch_size, target_len = target_ids.shape
    if batch_size == 0 or target_len == 0:
        # Nothing to score; also avoids `[:, -0:]` selecting the whole sequence.
        return torch.zeros(target_ids.shape, dtype=torch.float32, device=target_ids.device)
    if micro_batch_size is None or micro_batch_size <= 0:
        micro_batch_size = batch_size

    full_ids = torch.cat([input_ids, target_ids], dim=1)
    full_mask = torch.cat([attention_mask, torch.ones_like(target_ids)], dim=1)

    # The final token is never context for a target, so it is dropped.
    model_ids = full_ids[:, :-1]
    model_mask = full_mask[:, :-1]
    # Positions that skip left padding, mirroring what `generate` uses while
    # sampling; without them, padded rows get shifted position ids.
    position_ids = model_mask.long().cumsum(dim=-1) - 1
    position_ids = position_ids.masked_fill(model_mask == 0, 1)

    chunks = []
    for start in range(0, batch_size, micro_batch_size):
        rows = slice(start, start + micro_batch_size)
        outputs = model(
            input_ids=model_ids[rows],
            attention_mask=model_mask[rows],
            position_ids=position_ids[rows],
        )
        # The last T positions predict the T target tokens. Upcast so the
        # log-softmax over a large vocabulary is not computed in bf16/fp16.
        logits = outputs.logits[:, -target_len:, :].float()
        targets = target_ids[rows].unsqueeze(-1)
        picked = logits.gather(-1, targets).squeeze(-1)
        chunks.append(picked - torch.logsumexp(logits, dim=-1))
    return torch.cat(chunks, dim=0)

class MetricsTracker:
    """Collect per-step training metrics, EMA-smooth loss/reward, and show a live table.

    Args:
        ema_momentum: weight given to the previous EMA value (closer to 1 = smoother).
    """

    def __init__(self, ema_momentum=0.9):
        self.ema_momentum = ema_momentum
        self.metrics = []      # one dict per logged step, in logging order
        self.ema_values = {}   # metric name -> current EMA value
        placeholder = pd.DataFrame(
            columns=["step", "loss", "loss_ema", "reward", "reward_ema", "kl", "tokens"]
        )
        # display_id=True returns a handle so the same output can be refreshed in place.
        self.display_handle = display(placeholder, display_id=True)

    def update_ema(self, key, value):
        """Fold `value` into the EMA for `key` (the first value seeds it); return the EMA."""
        if key in self.ema_values:
            m = self.ema_momentum
            self.ema_values[key] = m * self.ema_values[key] + (1 - m) * value
        else:
            self.ema_values[key] = value
        return self.ema_values[key]

    def log(self, step, loss, reward, kl=0.0, tokens=0):
        """Record one step, refresh the table with the last 100 rows, and return the row."""
        row = dict(
            step=step,
            loss=loss,
            loss_ema=self.update_ema("loss", loss),
            reward=reward,
            reward_ema=self.update_ema("reward", reward),
            kl=kl,
            tokens=tokens,
        )
        self.metrics.append(row)

        recent = pd.DataFrame(self.metrics[-100:])
        four_decimals = "{:.4f}"
        formats = {name: four_decimals for name in ("loss", "loss_ema", "reward", "reward_ema", "kl")}
        self.display_handle.update(recent.style.format(formats))
        return row

print("✓ Training utilities defined")

## 📈 Initialize WandB

In [None]:
#@title Initialize WandB Run
wandb_run = wandb.init(
    project=config.wandb_project,
    name=config.wandb_run_name,
    job_type="training",
    config=config.__dict__,
)

# Log policy gradients and parameter histograms every 100 steps.
wandb.watch(policy_model, log="all", log_freq=100)

run_url = wandb_run.get_url()
print(f"✓ WandB initialized")
print(f"  Run: {config.wandb_run_name}")
print(f"  URL: {run_url}")

## 🚀 Training Loop

In [None]:
#@title Setup Optimizer
# Give AdamW only the parameters that can receive gradients (just the LoRA
# adapter weights when use_lora=True). Frozen weights are never updated, and
# listing them only makes every optimizer step iterate over them.
trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
if not trainable_params:
    raise RuntimeError("The policy model has no trainable parameters; check the LoRA / freezing setup.")

optimizer = torch.optim.AdamW(
    trainable_params,
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)
n_trainable = sum(p.numel() for p in trainable_params)
print(f"✓ Optimizer initialized (lr={config.learning_rate})")
print(f"  Trainable parameters: {n_trainable:,}")

In [None]:
#@title Training Loop
# On-policy REINFORCE: sample completions, score them, and raise the
# log-probability of completions whose reward beats the batch mean.
metrics_tracker = MetricsTracker(ema_momentum=config.ema_momentum)
total_tokens = 0
accum_steps = max(1, config.grad_accumulation_steps)
# Only parameters that can receive gradients matter for clipping.
clip_params = [p for p in policy_model.parameters() if p.requires_grad]
optimizer.zero_grad(set_to_none=True)

print("=" * 80)
print("STARTING TRAINING")
print("=" * 80)

pbar = tqdm(range(config.num_steps), desc="Training")

for step in pbar:
    # Sample batch of prompts (seeded per step for reproducibility)
    rng = np.random.default_rng(config.seed + step)
    prompt_indices = rng.choice(len(train_prompts), size=config.batch_size, replace=False)

    # Repeat each prompt for multiple samples (copies of a prompt are contiguous)
    batch_prompts = []
    for idx in prompt_indices:
        batch_prompts.extend([train_prompts[idx]] * config.samples_per_prompt)

    # Tokenize (the tokenizer pads on the left)
    prompt_encoding = tokenizer(
        batch_prompts, padding=True, truncation=True, max_length=2048, return_tensors="pt"
    ).to(DEVICE)
    prompt_len = prompt_encoding.input_ids.size(1)

    # Generate responses (on-policy); the KV cache is only enabled while sampling
    policy_model.eval()
    policy_model.config.use_cache = True
    with torch.no_grad():
        full_sequences = policy_model.generate(
            **prompt_encoding, do_sample=True, temperature=config.train_temperature,
            top_p=config.top_p, max_new_tokens=config.max_new_tokens,
            pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id
        )
    generated_ids = full_sequences[:, prompt_len:]

    # Validity mask over generated tokens (tokens from the first EOS onward are masked out)
    valid_mask = mask_after_eos(generated_ids, tokenizer.eos_token_id)
    num_valid = valid_mask.sum().clamp(min=1.0)

    # Compute rewards (float32 so integer-valued reward functions also work with .mean())
    with torch.no_grad():
        generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        rewards = [compute_reward(prompt, response) for prompt, response in zip(batch_prompts, generated_texts)]
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=DEVICE).unsqueeze(1)
        mean_reward = rewards_tensor.mean().item()

    # Compute policy log probs of the sampled tokens (with gradient)
    policy_model.train()
    policy_model.config.use_cache = False
    policy_logprobs = compute_model_logprobs(
        policy_model, prompt_encoding.input_ids, prompt_encoding.attention_mask, generated_ids
    )

    # Policy-gradient loss with the batch-mean reward as baseline
    advantages = (rewards_tensor - rewards_tensor.mean()).expand_as(policy_logprobs)
    pg_loss = -(advantages.detach() * policy_logprobs * valid_mask).sum() / num_valid
    loss = pg_loss

    # KL penalty toward the frozen reference model.
    # Uses the non-negative per-token "k3" estimator exp(r) - r - 1 with
    # r = log pi_ref - log pi, added to the loss WITH gradient so it actually pulls
    # the policy toward the reference. (Subtracting one batch-wide scalar from every
    # advantage would only lower all sampled log-probs uniformly.)
    kl_value = 0.0
    if config.use_kl_penalty and reference_model is not None:
        with torch.no_grad():
            reference_logprobs = compute_model_logprobs(
                reference_model, prompt_encoding.input_ids, prompt_encoding.attention_mask, generated_ids
            )
        log_ratio = reference_logprobs.to(policy_logprobs) - policy_logprobs
        kl_per_token = torch.exp(log_ratio) - log_ratio - 1.0
        kl_term = (kl_per_token * valid_mask).sum() / num_valid
        loss = pg_loss + config.kl_coef * kl_term
        kl_value = kl_term.item()

    # Backward pass; scaling makes the accumulated gradient an average over micro-steps
    (loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0 or step == config.num_steps - 1:
        # Clip the fully accumulated gradient once, right before the update
        if config.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(clip_params, config.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # Logging (loss is reported unscaled by the accumulation factor)
    total_tokens += int(valid_mask.sum().item())
    loss_value = loss.item()
    metric = metrics_tracker.log(
        step=step, loss=loss_value, reward=mean_reward,
        kl=kl_value,
        tokens=total_tokens
    )
    wandb.log(metric, step=step)

    pbar.set_postfix({
        "loss": f"{loss_value:.3f}",
        "reward": f"{mean_reward:.3f}",
        "kl": f"{kl_value:.3f}"
    })

    # Checkpointing (every `save_every` steps and after the final step)
    if (step % config.save_every == 0 and step > 0) or (step == config.num_steps - 1):
        checkpoint_dir = os.path.join(config.output_dir, f"checkpoint-{step}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        policy_model.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)
        print(f"\n✓ Checkpoint saved: {checkpoint_dir}")

        artifact = wandb.Artifact(name=f"model-checkpoint-{step}", type="model")
        artifact.add_dir(checkpoint_dir)
        wandb.log_artifact(artifact)

    torch.cuda.empty_cache()

print("=" * 80)
print("TRAINING COMPLETE")
print("=" * 80)

## 💾 Save Final Model

In [None]:
#@title Save Final Model
final_dir = os.path.join(config.output_dir, "final_model")
os.makedirs(final_dir, exist_ok=True)

# Training left the KV cache disabled (use_cache=False). Re-enable it so the
# exported config generates at normal speed when the model is loaded later.
policy_model.config.use_cache = True

policy_model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"✓ Final model saved to: {final_dir}")

# Merge LoRA if used
merged_model = None
if config.use_lora:
    print("Merging LoRA adapters...")
    # NOTE(review): per the PEFT docs, merge_and_unload() folds the adapters into
    # the wrapped base model's layers; reload before training `policy_model` again.
    merged_model = policy_model.merge_and_unload()
    merged_dir = os.path.join(config.output_dir, "final_model_merged")
    os.makedirs(merged_dir, exist_ok=True)
    merged_model.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)
    print(f"✓ Merged model saved to: {merged_dir}")

# Push to Hub if configured
if config.push_to_hub and not config.hub_repo_id:
    print("⚠ push_to_hub=True but hub_repo_id is not set; skipping Hub upload")
elif config.push_to_hub:
    print(f"Pushing to Hub: {config.hub_repo_id}")
    model_to_push = merged_model if merged_model is not None else policy_model
    model_to_push.push_to_hub(repo_id=config.hub_repo_id, private=True)
    tokenizer.push_to_hub(repo_id=config.hub_repo_id, private=True)
    print(f"✓ Pushed to Hub")

# Log to WandB
final_artifact = wandb.Artifact(name="final-model", type="model")
final_artifact.add_dir(final_dir)
wandb.log_artifact(final_artifact)
print(f"✓ Logged to WandB: {wandb_run.get_url()}")

## 📊 Results Summary

In [None]:
#@title Training Summary
# Save metrics
metrics_df = pd.DataFrame(metrics_tracker.metrics)
metrics_df.to_csv(os.path.join(config.output_dir, "training_metrics.csv"), index=False)

if not metrics_tracker.metrics:
    # e.g. num_steps=0 or the loop was interrupted before its first step finished
    print("⚠ No training steps were logged; nothing to summarize.")
else:
    last = metrics_tracker.metrics[-1]
    summary = {
        "final_loss": last["loss"],
        "final_loss_ema": last["loss_ema"],
        "final_reward": last["reward"],
        "final_reward_ema": last["reward_ema"],
        "total_tokens": total_tokens,
        # Steps actually completed; differs from num_steps if training was interrupted.
        "total_steps": len(metrics_tracker.metrics),
        "configured_steps": config.num_steps
    }

    with open(os.path.join(config.output_dir, "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    wandb.summary.update(summary)

    print("=" * 80)
    print("TRAINING SUMMARY")
    print("=" * 80)
    print(f"Final Loss: {summary['final_loss']:.4f}")
    print(f"Final Loss (EMA): {summary['final_loss_ema']:.4f}")
    print(f"Final Reward: {summary['final_reward']:.4f}")
    print(f"Final Reward (EMA): {summary['final_reward_ema']:.4f}")
    print(f"Total Tokens: {summary['total_tokens']:,}")
    print(f"\nOutput: {config.output_dir}")
    print(f"WandB: {wandb_run.get_url()}")
    print("=" * 80)

In [None]:
#@title Finish WandB Run
# Flush any pending logs/artifacts and close the active run.
wandb.finish()
print("✓ WandB run finished")

## 🎉 Done!

Your RL training is complete. You can now:

1. **Evaluate your model** on test data
2. **Inspect the metrics** in WandB dashboard
3. **Load the trained model** from the checkpoints
4. **Customize** any section for your specific use case

### Next Steps:

- **Customize the reward function** in the "Reward Function" section
- **Add validation metrics** specific to your task
- **Implement more advanced algorithms** (e.g. PPO with ratio clipping, GRPO-style per-prompt baselines, or offline preference methods such as DPO)
- **Tune hyperparameters** at the top of the notebook

### Common Modifications:

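1. **Change model**: Edit `policy_model_id` in `RLConfig`
2. **Change dataset**: Edit `dataset_name`, `dataset_config`, and the `prompt_field`/`answer_field` names in `RLConfig`
3. **Change algorithm**: Modify the loss computation in the training loop
4. **Use a reward model**: Set `use_reward_model=True`, `reward_model_id`, and `reward_type="model"` in `RLConfig` (with the default `reward_type="rule"`, a loaded reward model is ignored)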