# RL Training Primitives: Function-Based Building Blocks

**Ultra-hackable, composable primitives for RL training experiments**

## Philosophy

**Functions, not classes.** Every component accepts function arguments for maximum flexibility:

```python
# Swap reward function
rollout = create_rollout(model, tokenizer, prompt, gen_config,
                         reward_fn=my_custom_reward)

# Swap KL computation
kl = compute_forward_kl(logprobs, ref_logprobs, valid_mask,
                        kl_fn=create_truncated_kl(threshold=1.0))

# Swap algorithm mid-training
config = get_algorithm_config(iteration,
                              switch_strategy=lambda it: "grpo" if it < 100 else "dapo")
```

## What You Can Do

✅ **Scalar → Vector rewards** without touching other code  
✅ **Track tool calls** in multi-turn rollouts  
✅ **Switch algorithms** mid-training (GRPO → DAPO)  
✅ **Update reference model** with custom strategies  
✅ **Experiment with KL** divergence methods  
✅ **Inject custom logic** at 50+ extension points
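Algorithm switching goes through the `switch_strategy` callable (iteration → algorithm name) that `get_algorithm_config` accepts in section 6. A minimal sketch of one such strategy, assuming a single hard switch point; `make_switch_strategy` is a hypothetical helper, not one of the primitives:

```python
from typing import Callable


def make_switch_strategy(switch_at: int, first: str = "grpo", second: str = "dapo") -> Callable[[int], str]:
    """Hypothetical helper: return `first` before iteration `switch_at`, `second` from then on."""
    def strategy(iteration: int) -> str:
        return first if iteration < switch_at else second
    return strategy


strategy = make_switch_strategy(switch_at=100)
print(strategy(0), strategy(99), strategy(100))  # grpo grpo dapo
```

Passing `switch_strategy=strategy` to `get_algorithm_config` would return the GRPO config for iterations 0 to 99 and the DAPO config from iteration 100 on.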

## Primitive Categories

1. **Data Structures** - NamedTuples for rollouts, batches, configs
2. **Inference/Generation** - Model loading, generation, multi-turn
3. **Rewards** - Scalar, vector, custom functions, aggregation
4. **KL Divergence** - Forward/reverse KL, reference models
5. **Rollouts** - Single/multi-turn, tool tracking
6. **Algorithms** - GRPO, DAPO, PPO components
7. **Training Loop** - Iteration, logging, checkpointing

In [None]:
# Setup
!pip install -q "transformers>=4.51.3" "accelerate>=1.4.0" "peft>=0.14.0" \
              torch datasets tqdm numpy pandas matplotlib

import os, sys, json, time, re
from typing import NamedTuple, Optional, List, Dict, Callable, Any, Tuple, Union
from dataclasses import dataclass, field
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")
print(f"PyTorch: {torch.__version__}")

## 1. Data Structures

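All data structures are **NamedTuples**: fields cannot be reassigned after construction, although mutable contents (tensors, dicts, lists) can still be modified in place.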

In [None]:
# =============================================================================
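# DATA STRUCTURES (NamedTuples: fields are read-only, their contents are not)
# NOTE: {} / [] defaults are shared by every instance that doesn't pass its own
#       value; build new dicts/lists (e.g. via _replace) instead of mutating them.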
# =============================================================================

class Generation(NamedTuple):
    """Result of text generation."""
    text: str                          # Decoded text
    token_ids: torch.Tensor            # [seq_len] generated token IDs
    logprobs: Optional[torch.Tensor]   # [seq_len] log probabilities
    valid_mask: torch.Tensor           # [seq_len] mask for tokens before EOS
    metadata: Dict[str, Any] = {}      # Extra info (thinking, tool calls, etc.)

class Message(NamedTuple):
    """Single message in a conversation."""
    role: str                          # "user", "assistant", "system"
    content: str                       # Message text
    tool_calls: List[Dict] = []        # Tool calls in this message
    metadata: Dict[str, Any] = {}      # Extra info

class Conversation(NamedTuple):
    """Multi-turn conversation."""
    messages: List[Message]            # Conversation history
    metadata: Dict[str, Any] = {}      # Extra info

class Rollout(NamedTuple):
    """Complete rollout with prompt, generation, and reward."""
    prompt: str                        # Input prompt (or conversation)
    generation: Generation             # Generated response
    reward: Union[float, np.ndarray]   # Scalar or vector reward
    kl_penalty: float = 0.0            # KL divergence penalty
    conversation: Optional[Conversation] = None  # For multi-turn
    metadata: Dict[str, Any] = {}      # Extra info

class Batch(NamedTuple):
    """Batch of rollouts for training."""
    rollouts: List[Rollout]            # List of rollouts
    advantages: torch.Tensor           # [batch_size, seq_len] advantages
    returns: Optional[torch.Tensor] = None  # [batch_size, seq_len] for PPO
    metadata: Dict[str, Any] = {}      # Extra info

class GenerationConfig(NamedTuple):
    """Configuration for text generation."""
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50
    do_sample: bool = True
    enable_thinking: bool = False      # Qwen3 thinking mode
    parse_thinking: bool = False       # Parse <think> blocks
    use_cache: bool = True

class TrainingConfig(NamedTuple):
    """Configuration for RL training."""
    algorithm: str = "grpo"            # "grpo", "dapo", "ppo"
    learning_rate: float = 1e-4
    kl_coef: float = 0.1
    clip_range: float = 0.2            # For PPO
    normalize_advantages: bool = True
    normalize_rewards: bool = False
    max_grad_norm: float = 1.0
    
class TrainingState(NamedTuple):
    """State during training."""
    step: int
    total_tokens: int
    metrics: Dict[str, float]
    ema_metrics: Dict[str, float] = {}

print("✓ Data structures defined")
print("  - Generation, Message, Conversation")
print("  - Rollout, Batch")
print("  - GenerationConfig, TrainingConfig, TrainingState")
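NamedTuple fields cannot be reassigned, but the `{}` and `[]` defaults used above (`metadata`, `tool_calls`) are single objects shared by every instance that doesn't supply its own. A small demonstration with a hypothetical `Sample` class that follows the same default pattern as `Generation` and `Rollout`:

```python
from typing import Any, Dict, NamedTuple


class Sample(NamedTuple):
    """Hypothetical stand-in using the same `= {}` default as Generation/Rollout."""
    text: str
    metadata: Dict[str, Any] = {}


a = Sample("a")
b = Sample("b")

# Fields themselves are read-only...
try:
    a.text = "changed"
    reassigned = True
except AttributeError:
    reassigned = False

# ...but the {} default is one dict object shared by both instances
a.metadata["seen"] = True
print(b.metadata)  # {'seen': True}

# Safer: build a new tuple with a fresh dict via _replace
c = a._replace(metadata={**a.metadata, "score": 1.0})
```

Building a new dict and calling `_replace` keeps the "return a new value" style the primitives already use, and never touches the shared default.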

## 2. Inference & Generation Primitives

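Model loading, prompt formatting, generation, and log-probability computation. Multi-turn conversations are built on top of these in the rollout primitives (section 5).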
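`parse_thinking` (in the cell below) finds the *last* `</think>` token with list arithmetic: `token_ids[::-1].index(marker)` is the distance from the end, so `len(token_ids) - distance` is the position just past the marker, and a missing marker makes everything the response. The same logic isolated from the tokenizer so it can be checked on plain integer lists (`split_at_last_marker` is a hypothetical name):

```python
from typing import List, Tuple

THINK_END = 151668  # the Qwen3 </think> token id used in this notebook


def split_at_last_marker(token_ids: List[int], marker: int) -> Tuple[List[int], List[int]]:
    """Split just after the LAST occurrence of `marker`; without a marker, all tokens are response."""
    try:
        index = len(token_ids) - token_ids[::-1].index(marker)
    except ValueError:  # list.index raises ValueError when the marker is absent
        index = 0
    return token_ids[:index], token_ids[index:]


thinking, response = split_at_last_marker([1, 2, THINK_END, 3, 4], THINK_END)
print(thinking, response)  # [1, 2, 151668] [3, 4]
```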

In [None]:
# =============================================================================
# INFERENCE & GENERATION PRIMITIVES
# =============================================================================

def load_model(
    model_id: str,
    dtype: str = "bfloat16",
    device_map: str = "auto",
    use_lora: bool = False,
    lora_rank: int = 16,
    lora_alpha: int = 32,
    model_loader: Optional[Callable] = None,  # Extension point!
) -> torch.nn.Module:
    """
    Load a causal LM with optional LoRA.
    
    Extension point:
        model_loader: Custom loader function (e.g., vLLM, SGLang)
    """
    if model_loader is not None:
        return model_loader(model_id, dtype, device_map, use_lora, lora_rank, lora_alpha)
    
    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
    torch_dtype = dtype_map.get(dtype, torch.bfloat16)
    
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch_dtype, device_map=device_map, trust_remote_code=True
    )
    
    if use_lora:
        lora_config = LoraConfig(
            r=lora_rank, lora_alpha=lora_alpha, lora_dropout=0.05,
            bias="none", task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        )
        model = get_peft_model(model, lora_config)
    
    return model

def load_tokenizer(
    model_id: str,
    padding_side: str = "left"
) -> AutoTokenizer:
    """Load tokenizer with sensible defaults."""
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = padding_side
    return tokenizer

def format_prompt(
    prompt: str,
    tokenizer: AutoTokenizer,
    use_chat_template: bool = False,
    system_prompt: Optional[str] = None,
    enable_thinking: bool = False,
    template_fn: Optional[Callable] = None,  # Extension point!
) -> str:
    """
    Format prompt with chat template.
    
    Extension point:
        template_fn: Custom template function, called positionally as
            (prompt, tokenizer, use_chat_template, system_prompt, enable_thinking) -> str
    """
    if template_fn is not None:
        return template_fn(prompt, tokenizer, use_chat_template, system_prompt, enable_thinking)
    
    if use_chat_template:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        
        if getattr(tokenizer, "chat_template", None):  # base models often ship without a template
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
                enable_thinking=enable_thinking
            )
        else:
            formatted = ""
            if system_prompt:
                formatted += f"System: {system_prompt}\n\n"
            formatted += f"User: {prompt}\n\nAssistant:"
            return formatted
    else:
        return prompt

QWEN3_THINK_END_TOKEN = 151668  # id of "</think>" in the Qwen3 tokenizer

def parse_thinking(
    token_ids: List[int],
    tokenizer: AutoTokenizer
) -> Tuple[str, str]:
    """Parse Qwen3 <think>...</think> blocks."""
    try:
        index = len(token_ids) - token_ids[::-1].index(QWEN3_THINK_END_TOKEN)
    except ValueError:
        index = 0
    
    thinking = tokenizer.decode(token_ids[:index], skip_special_tokens=True).strip()
    response = tokenizer.decode(token_ids[index:], skip_special_tokens=True).strip()
    return thinking, response

def generate(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    prompt: str,
    config: GenerationConfig,
    compute_logprobs: bool = False,
    generator: Optional[Callable] = None,  # Extension point!
) -> Generation:
    """
    Generate text from a prompt.
    
    Extension point:
        generator: Custom generation function
            (model, tokenizer, prompt, config, compute_logprobs) -> Generation
    """
    if generator is not None:
        return generator(model, tokenizer, prompt, config, compute_logprobs)
    
    # Tokenize
    encoding = tokenizer(
        [prompt], padding=True, truncation=True, max_length=2048, return_tensors="pt"
    ).to(model.device)
    
    # Generate
    was_training = model.training
    was_cache = model.config.use_cache
    model.eval()
    model.config.use_cache = config.use_cache
    
    gen_kwargs = {
        "max_new_tokens": config.max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "do_sample": config.do_sample,
    }
    
    if config.do_sample:
        gen_kwargs.update({
            "temperature": config.temperature,
            "top_p": config.top_p,
            "top_k": config.top_k,
        })
    
    with torch.no_grad():
        output = model.generate(**encoding, **gen_kwargs)
    
    # Extract generated tokens
    generated_ids = output[0, encoding.input_ids.size(1):]
    
    # Create validity mask (tokens before first EOS)
    is_eos = (generated_ids == tokenizer.eos_token_id)
    eos_positions = is_eos.cumsum(dim=0)
    valid_mask = (eos_positions == 0).float()
    
    # Parse thinking if enabled
    metadata = {}
    if config.enable_thinking and config.parse_thinking:
        thinking, text = parse_thinking(generated_ids.tolist(), tokenizer)
        metadata["thinking"] = thinking
    else:
        text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    
    # Compute log probabilities if requested
    logprobs = None
    if compute_logprobs:
        logprobs = compute_logprobs_for_tokens(
            model, encoding.input_ids, encoding.attention_mask,
            generated_ids.unsqueeze(0), tokenizer.pad_token_id
        )[0]
    
    model.train(was_training)
    model.config.use_cache = was_cache
    
    return Generation(
        text=text,
        token_ids=generated_ids,
        logprobs=logprobs,
        valid_mask=valid_mask,
        metadata=metadata
    )

def compute_logprobs_for_tokens(
    model: torch.nn.Module,
    prompt_ids: torch.Tensor,      # [batch, prompt_len]
    prompt_mask: torch.Tensor,     # [batch, prompt_len]
    generated_ids: torch.Tensor,   # [batch, gen_len]
    pad_token_id: int,
    micro_batch_size: int = 8
) -> torch.Tensor:
    """
    Compute log probabilities for generated tokens.
    Returns: [batch, gen_len]
    """
    batch_size = prompt_ids.size(0)
    gen_len = generated_ids.size(1)
    
    # Concatenate prompt + generated
    full_ids = torch.cat([prompt_ids, generated_ids], dim=1)
    full_mask = torch.cat([
        prompt_mask,
        torch.ones_like(generated_ids)
    ], dim=1)
    
    if micro_batch_size >= batch_size:
        # No micro-batching
        outputs = model(input_ids=full_ids[:, :-1], attention_mask=full_mask[:, :-1])
        logits = outputs.logits[:, -gen_len:, :]
        logprobs = F.log_softmax(logits, dim=-1)
        token_logprobs = logprobs.gather(-1, generated_ids.unsqueeze(-1)).squeeze(-1)
    else:
        # Micro-batching
        token_logprobs_list = []
        for i in range(0, batch_size, micro_batch_size):
            micro_full_ids = full_ids[i:i+micro_batch_size]
            micro_full_mask = full_mask[i:i+micro_batch_size]
            micro_gen_ids = generated_ids[i:i+micro_batch_size]
            
            outputs = model(input_ids=micro_full_ids[:, :-1], attention_mask=micro_full_mask[:, :-1])
            logits = outputs.logits[:, -gen_len:, :]
            logprobs = F.log_softmax(logits, dim=-1)
            micro_logprobs = logprobs.gather(-1, micro_gen_ids.unsqueeze(-1)).squeeze(-1)
            token_logprobs_list.append(micro_logprobs)
        
        token_logprobs = torch.cat(token_logprobs_list, dim=0)
    
    return token_logprobs

print("✓ Inference primitives defined")
print("  - load_model, load_tokenizer")
print("  - format_prompt, parse_thinking")
print("  - generate, compute_logprobs_for_tokens")
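Two steps above are easy to get wrong: the cumulative-sum EOS mask in `generate`, and the one-position shift between logits and targets in `compute_logprobs_for_tokens`. A NumPy sketch of both for a single sequence rather than a batch (helper names are hypothetical):

```python
import numpy as np


def valid_mask_before_eos(token_ids: np.ndarray, eos_id: int) -> np.ndarray:
    """1.0 for tokens strictly before the first EOS, 0.0 from the EOS onward."""
    eos_seen = np.cumsum(token_ids == eos_id)  # number of EOS tokens up to each position
    return (eos_seen == 0).astype(np.float32)


def gather_generated_logprobs(logits: np.ndarray, prompt_len: int, generated: np.ndarray) -> np.ndarray:
    """Log-probs of `generated` given logits [seq_len, vocab] for prompt + generation.

    The logit at position t predicts token t + 1, so generated tokens at positions
    prompt_len .. prompt_len + G - 1 are scored by logits prompt_len - 1 .. prompt_len + G - 2.
    """
    num_gen = generated.shape[0]
    step_logits = logits[prompt_len - 1 : prompt_len - 1 + num_gen]
    # Numerically stable log-softmax over the vocabulary
    shifted = step_logits - step_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return log_probs[np.arange(num_gen), generated]


ids = np.array([11, 12, 2, 13, 2])
print(valid_mask_before_eos(ids, eos_id=2))  # [1. 1. 0. 0. 0.]

logits = np.zeros((5, 4))  # prompt_len=3 plus 2 generated tokens, vocabulary of 4
logits[2, 1] = 10.0        # position 2 strongly predicts token 1
print(gather_generated_logprobs(logits, prompt_len=3, generated=np.array([1, 2])))
```

Feeding `full_ids[:, :-1]` and keeping the last `gen_len` logits, as the primitive does, selects exactly the rows `prompt_len - 1 .. prompt_len + G - 2` used here.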

## 3. Reward Primitives

Scalar rewards, vector rewards, custom functions, aggregation.
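`exact_match_reward` (in the cell below) needs a `parse_fn` that pulls the final answer out of the response text. A minimal sketch, assuming answers are wrapped in `\boxed{...}` (a common math-benchmark convention, not something this notebook prescribes); `parse_boxed_answer` is hypothetical:

```python
import re
from typing import Optional

# Literal "\boxed{" followed by brace-free content and a closing "}"
_BOXED = re.compile(r"\\boxed\{([^{}]*)\}")


def parse_boxed_answer(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `text`, or None if there is none.

    Nested braces (e.g. \\boxed{\\frac{1}{2}}) are not handled by this simple regex.
    """
    matches = _BOXED.findall(text)
    return matches[-1].strip() if matches else None


print(parse_boxed_answer(r"so the answer is \boxed{42}."))  # 42
```

A `None` result never equals the ground truth, so `exact_match_reward` scores an unparseable response as 0.0.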

In [None]:
# =============================================================================
# REWARD PRIMITIVES
# =============================================================================

def compute_scalar_reward(
    rollout: Rollout,
    reward_fn: Callable[[Rollout], float],  # Extension point!
) -> float:
    """
    Compute scalar reward for a rollout.
    
    Extension point:
        reward_fn: Custom reward function (rollout) -> float
    """
    return reward_fn(rollout)

def compute_vector_reward(
    rollout: Rollout,
    reward_fns: List[Callable[[Rollout], float]],  # Extension point!
) -> np.ndarray:
    """
    Compute vector reward (multiple reward components).
    
    Extension point:
        reward_fns: List of reward functions, each returning a scalar
    """
    return np.array([fn(rollout) for fn in reward_fns])

def aggregate_vector_rewards(
    vector_rewards: np.ndarray,  # [num_components]
    aggregator: Callable[[np.ndarray], float],  # Extension point!
) -> float:
    """
    Aggregate vector rewards into a scalar.
    
    Extension point:
        aggregator: Custom aggregation function (vector) -> scalar
    """
    return aggregator(vector_rewards)

# ===== Common reward functions =====

def length_penalty_reward(min_length: int = 50, max_length: int = 500) -> Callable:
    """Length penalty in whitespace-separated words: -0.5 below min_length, -0.3 above max_length, else 0.0."""
    def reward_fn(rollout: Rollout) -> float:
        length = len(rollout.generation.text.split())
        if length < min_length:
            return -0.5
        elif length > max_length:
            return -0.3
        else:
            return 0.0
    return reward_fn

def format_reward(pattern: str) -> Callable:
    """Reward if response matches a regex pattern."""
    def reward_fn(rollout: Rollout) -> float:
        if re.search(pattern, rollout.generation.text):
            return 1.0
        return 0.0
    return reward_fn

def exact_match_reward(ground_truth_fn: Callable[[str], str], parse_fn: Callable[[str], str]) -> Callable:
    """Reward for exact match with ground truth."""
    def reward_fn(rollout: Rollout) -> float:
        ground_truth = ground_truth_fn(rollout.prompt)
        prediction = parse_fn(rollout.generation.text)
        return 1.0 if prediction == ground_truth else 0.0
    return reward_fn

def tool_usage_reward(min_tools: int = 1) -> Callable:
    """Reward based on number of tool calls."""
    def reward_fn(rollout: Rollout) -> float:
        if rollout.conversation is None:
            return 0.0
        total_tools = sum(len(msg.tool_calls) for msg in rollout.conversation.messages)
        return 1.0 if total_tools >= min_tools else 0.0
    return reward_fn

# ===== Common aggregators =====

def weighted_sum_aggregator(weights: List[float]) -> Callable:
    """Weighted sum aggregation (one weight per reward component)."""
    def aggregator(vector: np.ndarray) -> float:
        return float(np.dot(vector, weights))
    return aggregator

def min_aggregator(vector: np.ndarray) -> float:
    """Minimum of all components (pessimistic)."""
    return float(np.min(vector))

def max_aggregator(vector: np.ndarray) -> float:
    """Maximum of all components (optimistic)."""
    return float(np.max(vector))

def mean_aggregator(vector: np.ndarray) -> float:
    """Mean of all components."""
    return float(np.mean(vector))

print("✓ Reward primitives defined")
print("  - compute_scalar_reward, compute_vector_reward")
print("  - aggregate_vector_rewards")
print("  - Built-in: length_penalty, format_reward, exact_match, tool_usage")
print("  - Aggregators: weighted_sum, min, max, mean")
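The choice of aggregator changes what the policy is pushed toward. A sketch that mirrors `compute_vector_reward` followed by the weighted-sum and min aggregators, using a `SimpleNamespace` stand-in for `Rollout` (the two reward functions are hypothetical examples):

```python
import numpy as np
from types import SimpleNamespace

# Hypothetical stand-in exposing only the field these reward functions read
rollout = SimpleNamespace(generation=SimpleNamespace(text="<answer>4</answer> " + "word " * 20))


def has_answer_tags(r) -> float:
    return 1.0 if "<answer>" in r.generation.text else 0.0


def short_enough(r, max_words: int = 10) -> float:
    return 0.0 if len(r.generation.text.split()) > max_words else 1.0


# Same pattern as compute_vector_reward: one scalar per component
vector = np.array([fn(rollout) for fn in (has_answer_tags, short_enough)])
weighted = float(np.dot(vector, [0.8, 0.2]))  # like weighted_sum_aggregator([0.8, 0.2])
pessimistic = float(np.min(vector))           # like min_aggregator
print(vector, weighted, pessimistic)
```

With a weighted sum, the passing format check still earns 0.8 despite the failed length check; with `min`, the reward stays at 0.0 until every component passes.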

## 4. KL Divergence Primitives

Forward/reverse KL, reference model handling, update strategies.
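The default per-token estimator log π − log π_ref can be negative on individual tokens. The `kl_fn` extension point of `compute_forward_kl` lets you swap in another estimator; one well-known choice is the "k3" estimator from John Schulman's note on approximating KL divergence, which is also the KL term in the GRPO paper. A NumPy sketch written against the same `(policy_lp, ref_lp, mask) -> [batch]` contract (porting to torch means swapping `np.expm1`/`np.clip` for `torch.expm1`/`clamp`):

```python
import numpy as np


def k3_kl(policy_lp: np.ndarray, ref_lp: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Per-sequence mean of the k3 estimator of KL(policy || ref) over valid tokens.

    With r = ref(x) / policy(x) and x sampled from the policy, k3 = (r - 1) - log r
    is unbiased and non-negative for every token. Shapes: [batch, seq_len] -> [batch].
    """
    log_ratio = ref_lp - policy_lp
    k3 = np.expm1(log_ratio) - log_ratio  # expm1(log r) = r - 1
    return (k3 * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1.0, None)


policy_lp = np.log([[0.5, 0.5]])
ref_lp = np.log([[0.25, 0.5]])
print(k3_kl(policy_lp, ref_lp, np.ones((1, 2))))
```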

In [None]:
# =============================================================================
# KL DIVERGENCE PRIMITIVES
# =============================================================================

def compute_forward_kl(
    policy_logprobs: torch.Tensor,   # [batch, seq_len]
    ref_logprobs: torch.Tensor,      # [batch, seq_len]
    valid_mask: torch.Tensor,        # [batch, seq_len]
    kl_fn: Optional[Callable] = None,  # Extension point!
) -> torch.Tensor:
    """
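    Estimate forward KL, KL(policy || ref), on tokens sampled from the policy:
    mean over valid tokens of log(policy) - log(ref).
    Per-token terms can be negative; only the expectation is non-negative.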
    Returns: [batch] KL per sequence
    
    Extension point:
        kl_fn: Custom KL computation (policy_lp, ref_lp, mask) -> kl
    """
    if kl_fn is not None:
        return kl_fn(policy_logprobs, ref_logprobs, valid_mask)
    
    kl_per_token = policy_logprobs - ref_logprobs
    kl_per_seq = (kl_per_token * valid_mask).sum(dim=1) / valid_mask.sum(dim=1).clamp(min=1.0)
    return kl_per_seq

def compute_reverse_kl(
    policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    valid_mask: torch.Tensor,
    kl_fn: Optional[Callable] = None,  # Extension point!
) -> torch.Tensor:
    """
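    Per-sequence mean of log(ref) - log(policy) over valid tokens. Returns: [batch]
    Note: KL(ref || policy) = E_{x~ref}[log ref(x) - log policy(x)] needs tokens
    sampled from ref. On tokens sampled from the policy (the usual case here) this
    equals -compute_forward_kl, i.e. the per-token signal log ref - log policy,
    which is how on-policy distillation uses it when ref is the teacher.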
    """
    if kl_fn is not None:
        return kl_fn(policy_logprobs, ref_logprobs, valid_mask)
    
    kl_per_token = ref_logprobs - policy_logprobs
    kl_per_seq = (kl_per_token * valid_mask).sum(dim=1) / valid_mask.sum(dim=1).clamp(min=1.0)
    return kl_per_seq

def create_truncated_kl(threshold: float = 1.0) -> Callable:
    """Create truncated KL function (clip values above threshold)."""
    def kl_fn(policy_lp, ref_lp, mask):
        kl_per_token = (policy_lp - ref_lp).clamp(max=threshold)
        return (kl_per_token * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return kl_fn

def create_adaptive_kl(confidence_threshold: float = 0.1) -> Callable:
    """Create adaptive KL (only consider high-confidence tokens)."""
    def kl_fn(policy_lp, ref_lp, mask):
        policy_prob = torch.exp(policy_lp)
        confidence_mask = (policy_prob > confidence_threshold).float()
        combined_mask = mask * confidence_mask
        kl_per_token = policy_lp - ref_lp
        return (kl_per_token * combined_mask).sum(dim=1) / combined_mask.sum(dim=1).clamp(min=1.0)
    return kl_fn

# ===== Reference model utilities =====

def create_reference_model(
    policy_model: torch.nn.Module,
    ref_model_fn: Optional[Callable] = None,  # Extension point!
) -> torch.nn.Module:
    """
    Create reference model (frozen copy of policy).
    
    Extension point:
        ref_model_fn: Custom reference model creation
    """
    if ref_model_fn is not None:
        return ref_model_fn(policy_model)
    
    # Default: deep-copy the policy and freeze the copy (doubles parameter memory)
    import copy
    ref_model = copy.deepcopy(policy_model)
    ref_model.eval()
    for param in ref_model.parameters():
        param.requires_grad_(False)
    return ref_model

def update_reference_model(
    ref_model: torch.nn.Module,
    policy_model: torch.nn.Module,
    update_fn: Callable,  # Extension point!
) -> None:
    """
    Update reference model with custom strategy.
    
    Extension point:
        update_fn: (ref_model, policy_model) -> None
    """
    update_fn(ref_model, policy_model)

def create_ema_updater(decay: float = 0.999) -> Callable:
    """Create EMA update function for reference model."""
    def update_fn(ref_model, policy_model):
        with torch.no_grad():
            for ref_param, policy_param in zip(ref_model.parameters(), policy_model.parameters()):
                ref_param.data.mul_(decay).add_(policy_param.data, alpha=1-decay)
    return update_fn

def create_periodic_updater(update_every: int) -> Callable:
    """Create periodic hard update function (every N steps)."""
    state = {"step": 0}
    
    def update_fn(ref_model, policy_model):
        state["step"] += 1
        if state["step"] % update_every == 0:
            with torch.no_grad():
                for ref_param, policy_param in zip(ref_model.parameters(), policy_model.parameters()):
                    ref_param.data.copy_(policy_param.data)
    return update_fn

print("✓ KL divergence primitives defined")
print("  - compute_forward_kl, compute_reverse_kl")
print("  - create_truncated_kl, create_adaptive_kl")
print("  - create_reference_model, update_reference_model")
print("  - create_ema_updater, create_periodic_updater")
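`create_ema_updater` moves every reference parameter a fraction `1 - decay` of the way toward the policy on each call. With the policy held fixed, the remaining gap after n calls is `decay**n` of the starting gap, so `decay=0.999` closes half the gap in about 693 updates (ln 2 / 0.001). The same rule on plain floats (`ema_update` is a hypothetical helper):

```python
from typing import List


def ema_update(ref: List[float], policy: List[float], decay: float) -> List[float]:
    """Same rule as create_ema_updater: ref <- decay * ref + (1 - decay) * policy."""
    return [decay * r + (1.0 - decay) * p for r, p in zip(ref, policy)]


ref, policy = [0.0], [1.0]
for step in range(3):
    ref = ema_update(ref, policy, decay=0.9)
    print(step + 1, round(ref[0], 4))
# 1 0.1
# 2 0.19
# 3 0.271
```

`create_periodic_updater` is the other extreme: the reference stays fixed for `update_every` calls and then jumps to the policy in one step.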

## 5. Rollout Primitives

Single-turn, multi-turn, tool tracking.

In [None]:
# =============================================================================
# ROLLOUT PRIMITIVES
# =============================================================================

def create_rollout(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    prompt: str,
    gen_config: GenerationConfig,
    reward_fn: Callable[[Rollout], float],  # Extension point!
    ref_model: Optional[torch.nn.Module] = None,
    kl_coef: float = 0.0,
) -> Rollout:
    """
    Create a single rollout: prompt -> generation -> reward.
    
    Extension point:
        reward_fn: Custom reward function
    """
    # Generate response
    generation = generate(
        model, tokenizer, prompt, gen_config, compute_logprobs=True
    )
    
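    # Compute KL penalty if reference model provided.
    # Score the policy's own tokens under the reference model; sampling a fresh
    # response from ref_model would compare log-probs of two different sequences.
    kl_penalty = 0.0
    if ref_model is not None and kl_coef > 0:
        encoding = tokenizer(
            [prompt], padding=True, truncation=True, max_length=2048, return_tensors="pt"
        ).to(ref_model.device)
        with torch.no_grad():
            ref_logprobs = compute_logprobs_for_tokens(
                ref_model, encoding.input_ids, encoding.attention_mask,
                generation.token_ids.unsqueeze(0).to(ref_model.device),
                tokenizer.pad_token_id
            )
            kl = compute_forward_kl(
                generation.logprobs.detach().unsqueeze(0).to(ref_logprobs.device),
                ref_logprobs,
                generation.valid_mask.unsqueeze(0).to(ref_logprobs.device)
            )
        kl_penalty = kl_coef * kl.item()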
    
    # Create preliminary rollout for reward computation
    rollout = Rollout(
        prompt=prompt,
        generation=generation,
        reward=0.0,  # Placeholder
        kl_penalty=kl_penalty
    )
    
    # Compute reward
    reward = reward_fn(rollout)
    
    # Return final rollout
    return Rollout(
        prompt=prompt,
        generation=generation,
        reward=reward - kl_penalty,
        kl_penalty=kl_penalty
    )

def create_multiturn_rollout(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    initial_message: str,
    gen_config: GenerationConfig,
    reward_fn: Callable[[Rollout], float],
    max_turns: int = 5,
    tool_parser: Optional[Callable[[str], List[Dict]]] = None,  # Extension point!
    tool_executor: Optional[Callable[[Dict], str]] = None,  # Extension point!
    stop_condition: Optional[Callable[[Conversation], bool]] = None,  # Extension point!
) -> Rollout:
    """
    Create multi-turn rollout with tool support.
    
    Extension points:
        tool_parser: Extract tool calls from text
        tool_executor: Execute a tool call and return result
        stop_condition: Custom stopping condition
    """
    messages = [
        Message(role="user", content=initial_message)
    ]
    
    all_generations = []
    
    for turn in range(max_turns):
        # Format conversation as prompt
        conversation_text = "\n\n".join(
            f"{msg.role.capitalize()}: {msg.content}" for msg in messages
        )
        conversation_text += "\n\nAssistant:"
        
        # Generate response
        generation = generate(
            model, tokenizer, conversation_text, gen_config, compute_logprobs=True
        )
        all_generations.append(generation)
        
        # Parse tool calls if parser provided
        tool_calls = []
        if tool_parser is not None:
            tool_calls = tool_parser(generation.text)
        
        # Add assistant message
        messages.append(
            Message(role="assistant", content=generation.text, tool_calls=tool_calls)
        )
        
        # Execute tools and add results
        if tool_calls and tool_executor is not None:
            for tool_call in tool_calls:
                result = tool_executor(tool_call)
                messages.append(
                    Message(role="user", content=f"Tool result: {result}")
                )
        
        # Check stop condition
        conversation = Conversation(messages=messages)
        if stop_condition is not None and stop_condition(conversation):
            break
        
        # If no tools, stop
        if not tool_calls:
            break
    
    # Only the final turn's generation is attached to the rollout (and scored);
    # earlier turns remain available in the conversation and in metadata.
    final_generation = all_generations[-1]
    
    conversation = Conversation(messages=messages)
    
    # Create preliminary rollout for reward computation
    rollout = Rollout(
        prompt=initial_message,
        generation=final_generation,
        reward=0.0,  # Placeholder
        conversation=conversation,
        metadata={"turn_generations": all_generations}
    )
    
    # Compute reward
    reward = reward_fn(rollout)
    
    return rollout._replace(reward=reward)

def create_batch_rollouts(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    prompts: List[str],
    gen_config: GenerationConfig,
    reward_fn: Callable[[Rollout], float],
    ref_model: Optional[torch.nn.Module] = None,
    kl_coef: float = 0.0,
) -> List[Rollout]:
    """Create one rollout per prompt (sequentially, one generate call each)."""
    rollouts = []
    for prompt in prompts:
        rollout = create_rollout(
            model, tokenizer, prompt, gen_config, reward_fn, ref_model, kl_coef
        )
        rollouts.append(rollout)
    return rollouts

print("✓ Rollout primitives defined")
print("  - create_rollout, create_multiturn_rollout")
print("  - create_batch_rollouts")
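`create_multiturn_rollout` leaves the tool-call format entirely to `tool_parser` and `tool_executor`. A minimal sketch of both, assuming a made-up `<tool>name(args)</tool>` syntax and a two-function calculator toolbox; none of these names come from a real tool-calling API:

```python
import re
from typing import Dict, List

# Hypothetical tool-call syntax: <tool>name(arg1, arg2)</tool>
_TOOL_CALL = re.compile(r"<tool>\s*(\w+)\((.*?)\)\s*</tool>", re.DOTALL)

TOOLS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}


def parse_tool_calls(text: str) -> List[Dict]:
    """tool_parser: extract every tool call in the assistant text."""
    return [
        {"name": name, "args": [a.strip() for a in args.split(",") if a.strip()]}
        for name, args in _TOOL_CALL.findall(text)
    ]


def execute_tool_call(call: Dict) -> str:
    """tool_executor: run one call and return its result as text; errors become text too."""
    fn = TOOLS.get(call["name"])
    if fn is None:
        return f"error: unknown tool {call['name']!r}"
    try:
        return str(fn(*(float(a) for a in call["args"])))
    except (TypeError, ValueError) as exc:  # wrong arity or non-numeric argument
        return f"error: {exc}"


calls = parse_tool_calls("Let me compute. <tool>add(2, 3)</tool> and <tool>mul(4,5)</tool>")
print([execute_tool_call(c) for c in calls])  # ['5.0', '20.0']
```

These would be passed as `tool_parser=parse_tool_calls, tool_executor=execute_tool_call`. Returning errors as strings keeps the rollout going, so the model sees its own malformed call in the next turn instead of the episode crashing.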

## 6. Algorithm Primitives

GRPO, DAPO, PPO components with extension points.
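The default `compute_advantages` below uses a single baseline for the whole batch. GRPO as usually described samples several completions per prompt and normalizes each reward against the other completions of the *same* prompt. A NumPy sketch of that per-group normalization, which could back a custom `advantage_fn`; the `group_ids` array mapping each rollout to its prompt is an assumption, since `Rollout` does not store it:

```python
import numpy as np


def group_relative_advantages(rewards: np.ndarray, group_ids: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Normalize each reward within its prompt group: (r - group mean) / (group std + eps)."""
    rewards = rewards.astype(np.float64)
    advantages = np.zeros_like(rewards)
    for g in np.unique(group_ids):
        idx = group_ids == g
        group = rewards[idx]
        advantages[idx] = (group - group.mean()) / (group.std() + eps)
    return advantages


rewards = np.array([1.0, 0.0, 1.0, 1.0])
group_ids = np.array([0, 0, 1, 1])  # two prompts, two samples each
print(group_relative_advantages(rewards, group_ids))  # approximately [1, -1, 0, 0]
```

In the second group every sample earned the same reward, so all its advantages are zero and that prompt contributes no gradient.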

In [None]:
# =============================================================================
# ALGORITHM PRIMITIVES
# =============================================================================

def compute_advantages(
    rollouts: List[Rollout],
    normalize: bool = True,
    advantage_fn: Optional[Callable] = None,  # Extension point!
) -> torch.Tensor:
    """
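    Compute advantages from rollouts.
    Default: reward minus the batch-mean reward (the whole batch acts as one group),
    broadcast to every token and optionally normalized over valid tokens.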
    Returns: [batch, seq_len]
    
    Extension point:
        advantage_fn: Custom advantage computation
    """
    if advantage_fn is not None:
        return advantage_fn(rollouts, normalize)
    
    # Extract rewards
    rewards = torch.tensor([r.reward for r in rollouts], dtype=torch.float32)
    baseline = rewards.mean()
    
    # Broadcast to token level
    max_len = max(r.generation.token_ids.size(0) for r in rollouts)
    batch_size = len(rollouts)
    
    advantages = torch.zeros((batch_size, max_len), dtype=torch.float32)
    
    for i, rollout in enumerate(rollouts):
        seq_len = rollout.generation.token_ids.size(0)
        adv = rewards[i] - baseline
        advantages[i, :seq_len] = adv
    
    # Normalize
    if normalize:
        valid_mask = torch.zeros_like(advantages)
        for i, rollout in enumerate(rollouts):
            seq_len = rollout.generation.valid_mask.size(0)
            valid_mask[i, :seq_len] = rollout.generation.valid_mask
        
        adv_mean = (advantages * valid_mask).sum() / valid_mask.sum().clamp(min=1.0)
        adv_std = torch.sqrt(
            ((advantages - adv_mean) ** 2 * valid_mask).sum() / valid_mask.sum().clamp(min=1.0)
        )
        advantages = (advantages - adv_mean) / (adv_std + 1e-8)
    
    return advantages

def compute_policy_gradient_loss(
    rollouts: List[Rollout],
    advantages: torch.Tensor,
    loss_fn: Optional[Callable] = None,  # Extension point!
) -> torch.Tensor:
    """
    Compute policy gradient loss.
    
    Extension point:
        loss_fn: Custom loss computation
    """
    if loss_fn is not None:
        return loss_fn(rollouts, advantages)
    
    # Standard policy gradient: -E[log π(a|s) * advantage]
    total_loss = 0.0
    total_tokens = 0
    
    for i, rollout in enumerate(rollouts):
        seq_len = rollout.generation.logprobs.size(0)
        logprobs = rollout.generation.logprobs
        valid_mask = rollout.generation.valid_mask
        adv = advantages[i, :seq_len].to(logprobs.device)  # advantages are built on CPU
        
        loss_per_token = -adv.detach() * logprobs * valid_mask
        total_loss += loss_per_token.sum()
        total_tokens += valid_mask.sum()
    
    return total_loss / total_tokens.clamp(min=1.0)

# ===== GRPO (Group Relative Policy Optimization) =====

def create_grpo_config(
    kl_coef: float = 0.1,
    normalize_advantages: bool = True,
) -> TrainingConfig:
    """Create GRPO configuration."""
    return TrainingConfig(
        algorithm="grpo",
        kl_coef=kl_coef,
        normalize_advantages=normalize_advantages
    )

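# ===== DAPO (Direct Alignment from Preferences Optimization) =====
# NOTE: the loss below is a DPO-style pairwise loss without a reference-model term.
# Not to be confused with DAPO = "Decoupled Clip and Dynamic sAmpling Policy
# Optimization", a GRPO variant from recent RL literature.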

def create_dapo_config(
    beta: float = 0.1,
    normalize_advantages: bool = True,
) -> TrainingConfig:
    """Create DAPO configuration."""
    return TrainingConfig(
        algorithm="dapo",
        kl_coef=beta,
        normalize_advantages=normalize_advantages
    )

def compute_dapo_loss(
    rollouts: List[Rollout],
    beta: float = 0.1,
    pair_selector: Optional[Callable] = None,  # Extension point!
) -> torch.Tensor:
    """
    Compute DAPO loss (pairwise preference between rollouts in the batch).
    Assumes scalar rewards; pairs with equal rewards carry no preference and are skipped.
    
    Extension point:
        pair_selector: Custom pair selection (rollouts) -> [(i, j)]
    """
    if pair_selector is None:
        # Default: compare all pairs
        pairs = [(i, j) for i in range(len(rollouts)) for j in range(i+1, len(rollouts))]
    else:
        pairs = pair_selector(rollouts)
    
    total_loss = 0.0
    num_pairs = 0
    
    for i, j in pairs:
        r_i, r_j = rollouts[i], rollouts[j]
        
        # Ties carry no preference signal
        if r_i.reward == r_j.reward:
            continue
        
        # Preference: higher reward is preferred
        if r_i.reward > r_j.reward:
            preferred, dispreferred = r_i, r_j
        else:
            preferred, dispreferred = r_j, r_i
        
        # DAPO loss: -log sigmoid(beta * (log π_pref - log π_dispref)),
        # with sequence log-probs summed over valid (pre-EOS) tokens only
        pref_logp = (preferred.generation.logprobs * preferred.generation.valid_mask).sum()
        dispref_logp = (dispreferred.generation.logprobs * dispreferred.generation.valid_mask).sum()
        
        loss = -F.logsigmoid(beta * (pref_logp - dispref_logp))
        total_loss += loss
        num_pairs += 1
    
    return total_loss / max(num_pairs, 1)

# ===== PPO (Proximal Policy Optimization) =====

def create_ppo_config(
    clip_range: float = 0.2,
    value_loss_coef: float = 0.5,
    kl_coef: float = 0.0,
) -> TrainingConfig:
    """Create PPO configuration (value_loss_coef is accepted but not stored: TrainingConfig has no value-loss field)."""
    return TrainingConfig(
        algorithm="ppo",
        clip_range=clip_range,
        kl_coef=kl_coef,
        normalize_advantages=True
    )

def compute_ppo_loss(
    rollouts: List[Rollout],
    old_logprobs: torch.Tensor,  # [batch, seq_len]
    advantages: torch.Tensor,
    clip_range: float = 0.2,
) -> torch.Tensor:
    """
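    Compute PPO clipped loss.
    New log-probs are read from rollout.generation.logprobs, so they must be
    recomputed under the current policy (with grad) before each update;
    old_logprobs are the detached values from rollout time.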
    """
    total_loss = 0.0
    total_tokens = 0
    
    for i, rollout in enumerate(rollouts):
        seq_len = rollout.generation.logprobs.size(0)
        new_logprobs = rollout.generation.logprobs
        old_lp = old_logprobs[i, :seq_len]
        valid_mask = rollout.generation.valid_mask
        adv = advantages[i, :seq_len]
        
        # Importance ratio
        ratio = torch.exp(new_logprobs - old_lp)
        
        # Clipped objective
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * adv
        policy_loss = -torch.min(surr1, surr2) * valid_mask
        
        total_loss += policy_loss.sum()
        total_tokens += valid_mask.sum()
    
    return total_loss / total_tokens.clamp(min=1.0)
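A scalar version of the clipped surrogate makes the clipping behavior concrete. `ppo_token_loss` is a hypothetical single-token helper, with `clip_range` playing the role of ε: once the ratio leaves [1 - ε, 1 + ε] in the direction the advantage favors, the objective stops improving, so there is no incentive to move the policy further in one update.

```python
import math


def ppo_token_loss(new_logprob: float, old_logprob: float,
                   advantage: float, clip_range: float = 0.2) -> float:
    """Per-token PPO loss: -min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    ratio = math.exp(new_logprob - old_logprob)
    clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
    return -min(ratio * advantage, clipped_ratio * advantage)


# A > 0, ratio 1.5: the objective is capped at 1.2 * A, so the loss is -1.2
print(ppo_token_loss(math.log(1.5), 0.0, advantage=1.0))
# A < 0, ratio 0.5: the clipped term (0.8 * A) is the minimum, so the loss is 0.8
print(ppo_token_loss(math.log(0.5), 0.0, advantage=-1.0))
```

The minimum is one-sided: with a negative advantage and a ratio of 1.5, the unclipped term is the smaller one, so the full penalty of 1.5 applies.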

# ===== Algorithm switching =====

def get_algorithm_config(
    iteration: int,
    switch_strategy: Callable[[int], str],  # Extension point!
) -> TrainingConfig:
    """
    Get algorithm config for current iteration.
    
    Extension point:
        switch_strategy: (iteration) -> algorithm_name
    """
    algo = switch_strategy(iteration)
    
    if algo == "grpo":
        return create_grpo_config()
    elif algo == "dapo":
        return create_dapo_config()
    elif algo == "ppo":
        return create_ppo_config()
    else:
        raise ValueError(f"Unknown algorithm: {algo}")

print("✓ Algorithm primitives defined")
print("  - compute_advantages, compute_policy_gradient_loss")
print("  - GRPO: create_grpo_config")
print("  - DAPO: create_dapo_config, compute_dapo_loss")
print("  - PPO: create_ppo_config, compute_ppo_loss")
print("  - get_algorithm_config (for switching)")

## 7. Training Loop Primitives

Iteration management, logging, checkpointing.

In [None]:
# =============================================================================
# TRAINING LOOP PRIMITIVES
# =============================================================================

def training_iteration(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    prompts: List[str],
    gen_config: GenerationConfig,
    train_config: TrainingConfig,
    reward_fn: Callable,
    optimizer: torch.optim.Optimizer,
    ref_model: Optional[torch.nn.Module] = None,
    ref_update_fn: Optional[Callable] = None,  # Extension point!
) -> TrainingState:
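    """
    Run one training iteration: generate rollouts, compute the loss for the
    configured algorithm, take an optimizer step, and optionally update the
    reference model.
    
    Extension point:
        ref_update_fn: (ref_model, policy_model) -> None, called after the optimizer step
    """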
    # Create rollouts
    rollouts = create_batch_rollouts(
        model, tokenizer, prompts, gen_config, reward_fn, ref_model, train_config.kl_coef
    )
    
    # Compute advantages
    advantages = compute_advantages(rollouts, normalize=train_config.normalize_advantages)
    
    # Compute loss based on algorithm
    if train_config.algorithm == "grpo" or train_config.algorithm == "pg":
        loss = compute_policy_gradient_loss(rollouts, advantages)
    elif train_config.algorithm == "dapo":
        loss = compute_dapo_loss(rollouts, beta=train_config.kl_coef)
    elif train_config.algorithm == "ppo":
        # For PPO, need old logprobs (not implemented in this simplified version)
        raise NotImplementedError("PPO requires storing old logprobs")
    else:
        raise ValueError(f"Unknown algorithm: {train_config.algorithm}")
    
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    
    # Gradient clipping
    if train_config.max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.max_grad_norm)
    
    optimizer.step()
    
    # Update reference model if function provided
    if ref_model is not None and ref_update_fn is not None:
        ref_update_fn(ref_model, model)
    
    # Compute metrics
    mean_reward = np.mean([r.reward for r in rollouts])
    mean_kl = np.mean([r.kl_penalty for r in rollouts])
    total_tokens = sum(r.generation.valid_mask.sum().item() for r in rollouts)
    
    return TrainingState(
        step=0,  # Will be set by caller
        total_tokens=int(total_tokens),
        metrics={
            "loss": loss.item(),
            "reward": mean_reward,
            "kl": mean_kl,
        }
    )

def update_ema_metrics(
    current_metrics: Dict[str, float],
    ema_metrics: Dict[str, float],
    momentum: float = 0.9
) -> Dict[str, float]:
    """Blend current metrics into their EMA; a metric seen for the first time starts at its current value."""
    new_ema = {}
    for key, value in current_metrics.items():
        if key in ema_metrics:
            new_ema[key] = momentum * ema_metrics[key] + (1 - momentum) * value
        else:
            new_ema[key] = value
    return new_ema
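With momentum 0.9, each new value moves the average 10% of the way toward it, and the first observation seeds the average. The standalone trace below restates that recurrence for a single scalar; `ema_trace` is an illustrative helper, not one of the primitives.

```python
from typing import List, Optional


def ema_trace(values: List[float], momentum: float = 0.9) -> List[float]:
    """Running EMA after each value; the first value seeds the average."""
    ema: Optional[float] = None
    trace = []
    for v in values:
        ema = v if ema is None else momentum * ema + (1 - momentum) * v
        trace.append(ema)
    return trace


# A loss that drops from 1.0 to 0.0 is reported as roughly 1.0, 0.9, 0.81, ...
print(ema_trace([1.0, 0.0, 0.0, 0.0]))
```

This lag is why a smoothed loss can sit noticeably above the latest raw loss while training is still improving.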

def save_checkpoint(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    step: int,
    output_dir: str,
    checkpoint_fn: Optional[Callable] = None,  # Extension point!
) -> None:
    """
    Save model checkpoint.
    
    Extension point:
        checkpoint_fn: Custom checkpoint saving
    """
    if checkpoint_fn is not None:
        checkpoint_fn(model, tokenizer, step, output_dir)
        return
    
    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
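The `checkpoint_fn` hook can change how checkpoints are stored without touching the loop. As one sketch, `make_rotating_checkpoint_fn` (a hypothetical helper) wraps any saver and deletes all but the newest `keep_last` `checkpoint-<step>` directories; `save_fn(model, tokenizer, checkpoint_dir)` is an assumed signature for the underlying saver.

```python
import os
import re
import shutil
from typing import Any, Callable


def make_rotating_checkpoint_fn(save_fn: Callable[[Any, Any, str], None],
                                keep_last: int = 3) -> Callable:
    """Build a checkpoint_fn that saves via save_fn and keeps only the newest `keep_last` checkpoints."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")

    def checkpoint_fn(model, tokenizer, step: int, output_dir: str) -> None:
        checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        save_fn(model, tokenizer, checkpoint_dir)

        # Find existing checkpoint-<step> directories, sorted by step (oldest first)
        steps = sorted(
            int(m.group(1))
            for name in os.listdir(output_dir)
            if (m := re.fullmatch(r"checkpoint-(\d+)", name))
            and os.path.isdir(os.path.join(output_dir, name))
        )
        # Delete everything except the newest `keep_last`
        for old_step in steps[:-keep_last]:
            shutil.rmtree(os.path.join(output_dir, f"checkpoint-{old_step}"))

    return checkpoint_fn
```

For example, passing `checkpoint_fn=make_rotating_checkpoint_fn(lambda m, t, d: (m.save_pretrained(d), t.save_pretrained(d)))` to `save_checkpoint` keeps the default `save_pretrained` behavior while bounding disk usage. Steps are sorted numerically, so `checkpoint-100` correctly counts as newer than `checkpoint-50`.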

print("✓ Training loop primitives defined")
print("  - training_iteration")
print("  - update_ema_metrics")
print("  - save_checkpoint")

## 8. Example: Basic Training Loop

Putting it all together!

In [None]:
# =============================================================================
# EXAMPLE: BASIC TRAINING LOOP
# =============================================================================

def train(
    model_name: str,
    prompts: List[str],
    reward_fn: Callable,
    num_steps: int = 100,
    batch_size: int = 4,
    algorithm: str = "grpo",
    learning_rate: float = 1e-4,
    use_ref_model: bool = False,
    kl_coef: float = 0.1,
    output_dir: str = "./output",
):
    """
    Complete training loop example.
    """
    print("=" * 80)
    print("TRAINING SETUP")
    print("=" * 80)
    print(f"Model: {model_name}")
    print(f"Algorithm: {algorithm}")
    print(f"Steps: {num_steps}")
    print(f"Batch size: {batch_size}")
    print(f"Learning rate: {learning_rate}")
    print(f"Use reference model: {use_ref_model}")
    print(f"KL coefficient: {kl_coef if use_ref_model else 0.0}")
    print("=" * 80)
    
    # Load model and tokenizer
    print("\nLoading model...")
    model = load_model(model_name, use_lora=True)
    tokenizer = load_tokenizer(model_name)
    
    # Create reference model if needed
    ref_model = None
    if use_ref_model:
        print("Creating reference model...")
        ref_model = create_reference_model(model)
    
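    # Create optimizer over trainable parameters only (with LoRA, the base weights are frozen)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=learning_rate
    )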
    
    # Create configs
    gen_config = GenerationConfig(max_new_tokens=256, temperature=0.7)
    
    if algorithm == "grpo":
        train_config = create_grpo_config(kl_coef=kl_coef)
    elif algorithm == "dapo":
        train_config = create_dapo_config(beta=kl_coef)
    else:
        raise ValueError(f"Unknown algorithm: {algorithm}")
    
    # Training loop
    print("\nStarting training...\n")
    os.makedirs(output_dir, exist_ok=True)
    
    total_tokens = 0
    ema_metrics = {}
    
    pbar = tqdm(range(num_steps), desc="Training")
    
    for step in pbar:
        # Sample prompts
        batch_prompts = np.random.choice(prompts, size=batch_size, replace=True).tolist()
        
        # Training iteration
        state = training_iteration(
            model, tokenizer, batch_prompts, gen_config, train_config,
            reward_fn, optimizer, ref_model
        )
        
        total_tokens += state.total_tokens
        
        # Update EMA
        ema_metrics = update_ema_metrics(state.metrics, ema_metrics)
        
        # Update progress bar
        pbar.set_postfix({
            "loss": f"{state.metrics['loss']:.3f}",
            "reward": f"{state.metrics['reward']:.3f}",
            "tokens": total_tokens
        })
        
        # Save a checkpoint every 50 steps, named by the number of completed steps
        if (step + 1) % 50 == 0:
            save_checkpoint(model, tokenizer, step + 1, output_dir)
            print(f"\nCheckpoint saved at step {step + 1}")
    
    print("\n" + "=" * 80)
    print("TRAINING COMPLETE")
    print("=" * 80)
    print(f"Final loss: {ema_metrics['loss']:.4f}")
    print(f"Final reward: {ema_metrics['reward']:.4f}")
    print(f"Total tokens: {total_tokens:,}")
    print(f"Output dir: {output_dir}")
    
    return model, tokenizer

print("✓ Training loop example defined")
print("  Use: train(model_name, prompts, reward_fn, ...)")

## 9. Creative Examples

Demonstrating the power of primitives!

### Example 1: Vector Rewards with Custom Aggregation

In [None]:
# Example: Vector rewards (correctness, safety, helpfulness)

def correctness_reward(rollout: Rollout) -> float:
    # Placeholder: looks for the word "correct"; a real check compares against a reference answer
    return 1.0 if "correct" in rollout.generation.text.lower() else 0.0

def safety_reward(rollout: Rollout) -> float:
    # Keyword filter for unsafe content (plain substring match)
    unsafe_words = ["harm", "illegal", "dangerous"]
    text = rollout.generation.text.lower()
    return 0.0 if any(word in text for word in unsafe_words) else 1.0

def helpfulness_reward(rollout: Rollout) -> float:
    # Length proxy for helpfulness: 50-200 words scores 1.0, anything else 0.5
    length = len(rollout.generation.text.split())
    return 1.0 if 50 <= length <= 200 else 0.5

# Combine into vector reward
def vector_reward_fn(rollout: Rollout) -> float:
    vector = compute_vector_reward(rollout, [
        correctness_reward,
        safety_reward,
        helpfulness_reward
    ])
    # Aggregate with custom weights
    return aggregate_vector_rewards(
        vector,
        weighted_sum_aggregator([0.5, 0.3, 0.2])  # Correctness most important
    )

print("✓ Vector reward example defined")
print("  Components: correctness (0.5), safety (0.3), helpfulness (0.2)")
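Assuming `weighted_sum_aggregator(weights)` returns a plain Σ wᵢ·rᵢ aggregator (its definition is not part of this section), the arithmetic for one reward vector can be checked directly. `weighted_sum_agg` and `min_agg` below are hypothetical stand-ins; the min variant shows how a different aggregator changes the signal without touching the component rewards.

```python
from typing import Callable, List, Sequence


def weighted_sum_agg(weights: Sequence[float]) -> Callable[[List[float]], float]:
    """Return an aggregator computing sum_i w_i * r_i."""
    def aggregate(vector: List[float]) -> float:
        if len(vector) != len(weights):
            raise ValueError("reward vector and weights must have the same length")
        return sum(w * r for w, r in zip(weights, vector))
    return aggregate


def min_agg(vector: List[float]) -> float:
    """Worst-component aggregation: the reward is only as good as its weakest part."""
    return min(vector)


# Correct, safe, but only 30 words long: correctness=1.0, safety=1.0, helpfulness=0.5
vector = [1.0, 1.0, 0.5]
print(weighted_sum_agg([0.5, 0.3, 0.2])(vector))  # 0.5 + 0.3 + 0.1
print(min_agg(vector))                            # 0.5
```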

### Example 2: Algorithm Switching (GRPO → DAPO)

In [None]:
# Example: Switch from GRPO to DAPO after 50 steps

def adaptive_algorithm_switch(iteration: int) -> str:
    """Switch algorithm based on training progress."""
    if iteration < 50:
        return "grpo"  # Start with GRPO for exploration
    else:
        return "dapo"  # Switch to DAPO for refinement

# Usage in training loop:
# for step in range(num_steps):
#     train_config = get_algorithm_config(step, adaptive_algorithm_switch)
#     # ... rest of training iteration

print("✓ Algorithm switching example defined")
print("  Strategy: GRPO (steps 0-49) → DAPO (steps 50+)")
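Because `switch_strategy` is just `(iteration) -> algorithm_name`, multi-phase schedules can be built from data instead of hand-written `if` chains. A minimal sketch with a hypothetical `make_step_schedule` helper, using `bisect` to find the active phase:

```python
import bisect
from typing import Callable, List, Tuple


def make_step_schedule(phases: List[Tuple[int, str]]) -> Callable[[int], str]:
    """Build a switch_strategy from (start_step, algorithm) pairs sorted by start_step."""
    starts = [start for start, _ in phases]
    names = [name for _, name in phases]
    if not starts or starts[0] != 0 or starts != sorted(starts):
        raise ValueError("phases must be sorted by start_step and begin at step 0")

    def strategy(iteration: int) -> str:
        # The active phase is the last one whose start_step <= iteration
        return names[bisect.bisect_right(starts, iteration) - 1]

    return strategy


switch = make_step_schedule([(0, "grpo"), (50, "dapo")])
print([switch(step) for step in (0, 49, 50, 120)])  # ['grpo', 'grpo', 'dapo', 'dapo']
```

The result plugs straight into `get_algorithm_config(step, switch)`.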

### Example 3: Reference Model Update Strategy

In [None]:
# Example: EMA update for reference model

# Create EMA updater (decay=0.999)
ref_update_fn = create_ema_updater(decay=0.999)

# Alternative: Periodic hard update every 20 steps
ref_update_fn_periodic = create_periodic_updater(update_every=20)

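# Custom: adaptive hard update driven by accumulated KL.
# Note: training_iteration calls ref_update_fn(ref_model, model) with only two
# arguments, so call this updater yourself after each iteration, passing
# state.metrics["kl"] as kl_value.
def adaptive_ref_update(kl_threshold: float = 0.5) -> Callable:
    state = {"accumulated_kl": 0.0}
    
    def update_fn(ref_model, policy_model, kl_value: float) -> None:
        state["accumulated_kl"] += kl_value
        if state["accumulated_kl"] > kl_threshold:
            # Hard update: copy policy weights into the reference model
            with torch.no_grad():
                for ref_param, policy_param in zip(ref_model.parameters(), policy_model.parameters()):
                    ref_param.copy_(policy_param)
            state["accumulated_kl"] = 0.0
    
    return update_fn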

print("✓ Reference model update strategies defined")
print("  - EMA (decay=0.999)")
print("  - Periodic (every 20 steps)")
print("  - Adaptive (based on accumulated KL)")
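Assuming `create_ema_updater(decay)` applies the standard rule ref ← decay·ref + (1 - decay)·policy after each step (its implementation is not shown in this section), the reference trails the policy with an effective horizon of roughly 1 / (1 - decay) steps, about 1000 for decay = 0.999. The scalar sketch below (hypothetical helpers `ema_param_update` and `steps_to_close_gap`) shows how slowly such a reference moves.

```python
def ema_param_update(ref: float, policy: float, decay: float) -> float:
    """One EMA step for a single parameter: ref <- decay * ref + (1 - decay) * policy."""
    return decay * ref + (1.0 - decay) * policy


def steps_to_close_gap(decay: float, fraction: float) -> int:
    """Steps until a reference starting at 0.0 reaches `fraction` of a fixed policy value of 1.0."""
    if not 0.0 <= decay < 1.0:
        raise ValueError("decay must be in [0, 1)")
    if not 0.0 < fraction < 1.0:
        raise ValueError("fraction must be strictly between 0 and 1")
    ref, steps = 0.0, 0
    while ref < fraction:
        ref = ema_param_update(ref, 1.0, decay)
        steps += 1
    return steps


# Closed form: ref_n = 1 - decay**n, so half the gap closes after ln(2) / -ln(0.999) ≈ 692.8 steps
print(steps_to_close_gap(0.999, 0.5))  # 693
```

By contrast, the periodic updater closes the whole gap at once every 20 steps, so the KL penalty resets abruptly instead of decaying smoothly.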

### Example 4: Tool Call Tracking in Multi-Turn Rollouts

In [None]:
# Example: Track and reward tool usage

def parse_tool_calls(text: str) -> List[Dict]:
    """Parse XML-style tool calls such as <tool name="calculator">2 + 2</tool>."""
    pattern = r'<tool name="([^"]+)">(.*?)</tool>'
    matches = re.findall(pattern, text, re.DOTALL)
    return [{"name": name, "args": args.strip()} for name, args in matches]

def execute_tool(tool_call: Dict) -> str:
    """Execute a tool call (mock implementation)."""
    if tool_call["name"] == "calculator":
        try:
            # UNSAFE: eval runs arbitrary code from model output; demo only
            result = eval(tool_call["args"])
            return f"Result: {result}"
        except Exception:
            return "Error: Invalid expression"
    return "Unknown tool"

def tool_call_stop_condition(conversation: Conversation) -> bool:
    """Stop when the assistant says 'DONE' or after 5 turns (10 messages)."""
    if not conversation.messages:
        return False
    if len(conversation.messages) >= 10:  # 5 turns * 2 messages per turn
        return True
    last_message = conversation.messages[-1]
    return last_message.role == "assistant" and "DONE" in last_message.content

# Reward based on tool usage
def multiturn_tool_reward(rollout: Rollout) -> float:
    if rollout.conversation is None:
        return 0.0
    
    total_tools = sum(len(msg.tool_calls) for msg in rollout.conversation.messages)
    
    # Reward for using tools (but not too many)
    if total_tools == 0:
        return 0.0
    elif 1 <= total_tools <= 3:
        return 1.0
    else:
        return 0.5  # Penalize excessive tool use

# Usage:
# rollout = create_multiturn_rollout(
#     model, tokenizer, "Solve: 15 * 23 + 17",
#     gen_config, multiturn_tool_reward,
#     tool_parser=parse_tool_calls,
#     tool_executor=execute_tool,
#     stop_condition=tool_call_stop_condition
# )

print("✓ Multi-turn tool tracking example defined")
print("  - parse_tool_calls (XML format)")
print("  - execute_tool (mock calculator)")
print("  - Reward: 1.0 for 1-3 tools, 0.5 for more, 0.0 for none")
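Because `execute_tool` passes model output straight to `eval`, a policy under RL pressure could learn to emit arbitrary code. One safer sketch parses the expression with Python's `ast` module and evaluates only numeric literals and whitelisted arithmetic operators; `safe_eval_arithmetic` and `execute_tool_safe` are hypothetical names, and `**` is deliberately left out because expressions like `9**9**9` can run for a very long time.

```python
import ast
import operator

# Whitelisted operators: everything else (calls, names, attributes, **) is rejected
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def safe_eval_arithmetic(expr: str) -> float:
    """Evaluate +, -, *, /, //, % over numeric literals; raise ValueError on anything else."""
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        raise ValueError(f"Unsupported expression: {ast.dump(node)}")

    return _eval(ast.parse(expr, mode="eval"))


def execute_tool_safe(tool_call: dict) -> str:
    """Drop-in alternative to execute_tool that never calls eval."""
    if tool_call["name"] == "calculator":
        try:
            return f"Result: {safe_eval_arithmetic(tool_call['args'])}"
        except (ValueError, SyntaxError, ZeroDivisionError, OverflowError, RecursionError):
            return "Error: Invalid expression"
    return "Unknown tool"


print(execute_tool_safe({"name": "calculator", "args": "15 * 23 + 17"}))  # Result: 362
```

Passing `tool_executor=execute_tool_safe` keeps the rest of the multi-turn rollout unchanged.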

### Example 5: Custom KL with Truncation

In [None]:
# Example: Truncated KL to prevent excessive penalization

# Create truncated KL function
truncated_kl_fn = create_truncated_kl(threshold=1.0)

# Use in rollout creation:
# def custom_kl_rollout(model, tokenizer, prompt, gen_config, reward_fn, ref_model):
#     # ... generate with policy and reference models ...
#     kl = compute_forward_kl(
#         policy_logprobs, ref_logprobs, valid_mask,
#         kl_fn=truncated_kl_fn  # Use truncated KL!
#     )
#     # ... rest of rollout creation ...

print("✓ Truncated KL example defined")
print("  Threshold: 1.0 (clips KL values above this)")
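To make the truncation concrete, the sketch below assumes per-token KL is estimated with the simple log-ratio (k1) estimator log π(x) - log π_ref(x) and that truncation clips each token's value at `threshold` before a masked mean. `truncated_token_kl` is a hypothetical helper; the notebook's `create_truncated_kl` may use a different estimator or reduction.

```python
from typing import List


def truncated_token_kl(policy_logprobs: List[float], ref_logprobs: List[float],
                       valid_mask: List[int], threshold: float = 1.0) -> float:
    """Masked mean of per-token log-ratio KL estimates, each clipped at `threshold`."""
    total, count = 0.0, 0
    for lp, ref_lp, m in zip(policy_logprobs, ref_logprobs, valid_mask):
        if not m:
            continue  # skip padding tokens
        kl_token = lp - ref_lp             # k1 estimator: log pi(x) - log pi_ref(x)
        total += min(kl_token, threshold)  # truncate large per-token values
        count += 1
    return total / max(count, 1)


# The second token's estimate (2.0) is clipped to 1.0; the third token is masked out
print(truncated_token_kl([-0.1, -0.5, -3.0], [-0.2, -2.5, -0.5], [1, 1, 0]))
```

Clipping bounds the penalty any single surprising token can contribute, at the cost of under-penalizing genuinely large divergences.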

## Summary

You now have **50+ composable primitives** for RL training:

### Data Structures
- Generation, Message, Conversation
- Rollout, Batch
- GenerationConfig, TrainingConfig, TrainingState

### Inference & Generation
- load_model, load_tokenizer
- format_prompt, parse_thinking
- generate, compute_logprobs_for_tokens

### Rewards
- compute_scalar_reward, compute_vector_reward
- aggregate_vector_rewards
- Built-in rewards: length_penalty, format_reward, exact_match, tool_usage
- Aggregators: weighted_sum, min, max, mean

### KL Divergence
- compute_forward_kl, compute_reverse_kl
- create_truncated_kl, create_adaptive_kl
- create_reference_model, update_reference_model
- create_ema_updater, create_periodic_updater

### Rollouts
- create_rollout, create_multiturn_rollout
- create_batch_rollouts

### Algorithms
- compute_advantages, compute_policy_gradient_loss
- GRPO: create_grpo_config
- DAPO: create_dapo_config, compute_dapo_loss
- PPO: create_ppo_config, compute_ppo_loss
- get_algorithm_config (for switching)

### Training Loop
- training_iteration
- update_ema_metrics
- save_checkpoint

## Key Benefits

✅ **Hackable**: Swap any component with a custom function  
✅ **Composable**: Mix and match primitives  
✅ **Testable**: Each function is independently testable  
✅ **Extensible**: 50+ extension points for custom logic  
✅ **Typed**: NamedTuples give every data structure explicit, named fields
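To illustrate the testability point, a reward function can be checked with tiny stand-ins instead of a loaded model. `FakeGeneration` and `FakeRollout` are hypothetical, test-only NamedTuples that expose just the field `safety_reward` reads; they are not the notebook's real `Generation` and `Rollout` definitions.

```python
from typing import NamedTuple


# Minimal stand-ins exposing only rollout.generation.text (hypothetical, test-only)
class FakeGeneration(NamedTuple):
    text: str


class FakeRollout(NamedTuple):
    generation: FakeGeneration


def safety_reward(rollout) -> float:
    # Same logic as the safety_reward in Example 1
    unsafe_words = ["harm", "illegal", "dangerous"]
    text = rollout.generation.text.lower()
    return 0.0 if any(word in text for word in unsafe_words) else 1.0


def test_safety_reward() -> None:
    assert safety_reward(FakeRollout(FakeGeneration("Here is a recipe."))) == 1.0
    assert safety_reward(FakeRollout(FakeGeneration("This is ILLEGAL."))) == 0.0


test_safety_reward()
```

A third case such as `"This is harmless."` scores 0.0 because the check matches substrings, which is exactly the kind of edge case these small tests surface before a reward shapes a training run.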

## Next Steps

1. **Test with real model**: Load a small model and run basic training
2. **Add your use case**: Create custom reward/KL/algorithm functions
3. **Experiment**: Try vector rewards, algorithm switching, tool tracking
4. **Scale up**: Use with larger models and datasets

Happy experimenting! 🚀