In [None]:
!nvidia-smi -L || true

import sys
print("Python:", sys.version)

# Install required packages
# Core ML
!uv pip install -q transformers>=4.51.3 accelerate>=1.4.0 peft>=0.14.0 \
                 datasets>=3.3.2 torch wandb huggingface_hub \
                 sentencepiece protobuf tqdm matplotlib pandas

# Inference backends (install based on what you'll use)
# Uncomment the backend you want:
# !uv pip install -q vllm>=0.6.0  # For vLLM backend
# !uv pip install -q "sglang[all]>=0.4.0"  # For SGLang backend

print("\n=== Environment ===")
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "Please connect a GPU for RL training."

In [None]:
# Core imports for RL harness
import os
import random
import time
import json
import asyncio
import uuid
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional, Any, Tuple, Callable
from collections import defaultdict, deque
from queue import Queue

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm
import pandas as pd

# HuggingFace
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Optional: vLLM. The ImportError guard lets this cell run whether or not
# vLLM is installed (nothing needs commenting out); VLLM_AVAILABLE records
# the outcome and VLLMBackend refuses to construct when it is False.
try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
    print("✓ vLLM available")
except ImportError:
    VLLM_AVAILABLE = False
    print("⚠ vLLM not installed")

# Optional: SGLang. Same pattern: SGLANG_AVAILABLE gates SGLangBackend.
try:
    import sglang as sgl
    SGLANG_AVAILABLE = True
    print("✓ SGLang available")
except ImportError:
    SGLANG_AVAILABLE = False
    print("⚠ SGLang not installed")

# Same expression as in the setup cell; repeated so this cell does not depend
# on that one having been run.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Core Data Structures
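Dataclasses for tracking rollouts and generation results. They are mutable (e.g. `RolloutState.add_step` appends to the trajectory in place); use `RolloutState.checkpoint()` to get an independent deep copy.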

In [None]:
@dataclass
class GenerationResult:
    """
    One completion produced by an inference backend.

    Only ``text`` and ``token_ids`` are required; the remaining fields have
    defaults.
    """

    # Decoded completion text.
    text: str
    # Token ids of the completion.
    token_ids: List[int]
    # Log probability of each generated token, when the backend was asked for it.
    logprobs: Optional[List[float]] = field(default=None)
    # Per-position top-k alternatives as {token_id: logprob}.
    top_logprobs: Optional[List[Dict[int, float]]] = field(default=None)
    # Why generation ended: "stop", "length" or "error".
    finish_reason: str = field(default="length")
    
@dataclass
class StepData:
    """One generation step inside a rollout trajectory."""

    # Context that was sent to the backend for this step.
    prompt: str
    # What the backend returned for that context.
    response: GenerationResult
    # Log-probs kept for gradient computation, if any.
    action_logprobs: Optional[torch.Tensor] = field(default=None)
    # Reward for this step; one entry per objective.
    reward: Optional[np.ndarray] = field(default=None)
    value_estimate: Optional[float] = field(default=None)
    # Creation time, taken from time.time().
    timestamp: float = field(default_factory=time.time)
    # Free-form per-step annotations.
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class RolloutState:
    """
    Mutable record of one rollout: its prompt, the steps taken so far, and
    termination / reward / verification-budget bookkeeping.
    """

    rollout_id: str
    initial_prompt: str
    trajectory: List[StepData] = field(default_factory=list)
    step_count: int = 0
    terminated: bool = False
    termination_reason: Optional[str] = None
    # Running total of step rewards, one entry per objective.
    cumulative_reward: np.ndarray = field(default_factory=lambda: np.zeros(1))
    verification_budget_spent: float = 0.0
    max_verification_budget: float = 10.0
    ref_model_snapshot_id: Optional[str] = None  # Reference-model snapshot tied to this rollout, if any
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self) -> 'RolloutState':
        """Return a fully independent (deep) copy, e.g. to branch from this point."""
        branch = deepcopy(self)
        return branch

    def can_verify(self, cost: float) -> bool:
        """Whether spending ``cost`` more keeps total spend within ``max_verification_budget``."""
        projected_spend = self.verification_budget_spent + cost
        return projected_spend <= self.max_verification_budget

    def add_step(self, step: StepData):
        """Append ``step``, increment ``step_count``, and add its reward (if any) to the running total."""
        self.trajectory.append(step)
        self.step_count += 1
        if step.reward is None:
            return
        # Out-of-place addition lets the 1-element default broadcast up to a
        # multi-objective reward vector.
        self.cumulative_reward = self.cumulative_reward + step.reward

    def get_full_context(self) -> str:
        """The initial prompt followed by every step's response text, in order."""
        segments = [self.initial_prompt]
        segments.extend(step.response.text for step in self.trajectory)
        return "".join(segments)

    def to_dict(self) -> Dict:
        """Recursively convert this rollout (and nested dataclasses) to dicts via ``asdict``."""
        as_mapping = asdict(self)
        return as_mapping

@dataclass
class TrainingBatch:
    """Inputs for one policy update; list fields are aligned by batch index."""

    prompts: List[str]
    responses: List[str]
    token_ids: List[List[int]]
    # Shape (batch_size, num_objectives).
    rewards: np.ndarray
    # Behaviour-policy log-probs, used for importance sampling.
    old_logprobs: Optional[torch.Tensor] = field(default=None)
    advantages: Optional[torch.Tensor] = field(default=None)
    returns: Optional[torch.Tensor] = field(default=None)


print("✓ Data structures defined")

In [None]:
import os

# Credentials are exported as environment variables for the login cell.
# Leave a value empty to skip it. Avoid saving real keys here: anything typed
# into this cell is stored in the .ipynb file.
#   WandB API key:     https://wandb.ai/authorize
#   HuggingFace token: https://huggingface.co/settings/tokens
WANDB_API_KEY = ""  # Your WandB API key
HF_TOKEN = ""  # Your HuggingFace token

for _env_var, _secret in (("WANDB_API_KEY", WANDB_API_KEY), ("HF_TOKEN", HF_TOKEN)):
    if _secret:
        os.environ[_env_var] = _secret
del _env_var, _secret

In [None]:
import wandb
from huggingface_hub import login

# WandB login. Without WANDB_API_KEY in the environment, wandb.login() may fall
# back to stored credentials or an interactive prompt.
try:
    wandb.login()
    print("✓ WandB login successful")
except Exception as e:
    print(f"⚠ WandB login failed: {e}")
    print("Training will continue without WandB logging")

# HuggingFace login: only attempted when a token was provided. A missing token
# is reported explicitly instead of being skipped silently.
hf_token = os.environ.get('HF_TOKEN')
if hf_token:
    try:
        login(token=hf_token)
        print("✓ HuggingFace login successful")
    except Exception as e:
        print(f"⚠ HuggingFace login failed: {e}")
else:
    print("⚠ HF_TOKEN not set; skipping HuggingFace login")

# Async Rollout manager with dynamic batching
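Rollouts vary widely in length (6–50 steps), so many rollouts are kept in flight at once. Each call to `step_all()` batches generation across all rollouts currently at the same step. Despite the class name, stepping is synchronous: `step_all()` blocks until every group has been generated.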

In [None]:
class AsyncRolloutManager:
    """
    Manages concurrent rollouts with dynamic batching.
    Groups rollouts by step count for efficient batch inference.

    Despite the name, stepping is synchronous: ``step_all`` blocks until the
    backend has produced one step for every active, non-terminated rollout.

    Customizable hooks (assign after construction):
        should_terminate_fn(state) -> bool
            Extra stop condition, evaluated after every step.
        get_reward_fn(state, step) -> np.ndarray
            Per-step reward vector; ``np.zeros(num_objectives)`` when unset.
        process_response_fn(text) -> str
            Rewrites ``GenerationResult.text`` before it is recorded. Only the
            text changes; ``token_ids`` and ``logprobs`` still describe the raw
            generation.
    """

    def __init__(
        self,
        # Quoted: InferenceBackend is defined in a later cell, so an unquoted
        # annotation raises NameError on a fresh "Restart & Run All".
        inference_backend: 'InferenceBackend',
        max_concurrent_rollouts: int = 32,
        max_steps: int = 50,
        num_objectives: int = 1,
    ):
        """
        Args:
            inference_backend: Object whose ``generate(prompts, **kwargs)``
                returns one GenerationResult per prompt, in prompt order.
            max_concurrent_rollouts: Upper bound on ``active_rollouts``.
                Terminated rollouts count until ``complete_terminated_rollouts``
                removes them.
            max_steps: A rollout is terminated once it reaches this many steps.
            num_objectives: Length of the reward vectors.
        """
        self.backend = inference_backend
        self.max_concurrent = max_concurrent_rollouts
        self.max_steps = max_steps
        self.num_objectives = num_objectives

        self.active_rollouts: Dict[str, RolloutState] = {}
        self.completed_rollouts: List[RolloutState] = []
        # Never reset, so rollout ids stay unique across reset() calls.
        self.rollout_counter = 0

        # Customizable hooks - you override these!
        self.should_terminate_fn: Optional[Callable[[RolloutState], bool]] = None
        self.get_reward_fn: Optional[Callable[[RolloutState, StepData], np.ndarray]] = None
        self.process_response_fn: Optional[Callable[[str], str]] = None

    def start_rollout(self, prompt: str, metadata: Optional[Dict] = None) -> str:
        """
        Register a new rollout starting from ``prompt``.

        Returns:
            The new rollout id.

        Raises:
            RuntimeError: if ``max_concurrent_rollouts`` rollouts are already active.
        """
        if len(self.active_rollouts) >= self.max_concurrent:
            raise RuntimeError(f"Max concurrent rollouts ({self.max_concurrent}) reached")

        rollout_id = f"rollout_{self.rollout_counter}"
        self.rollout_counter += 1

        state = RolloutState(
            rollout_id=rollout_id,
            initial_prompt=prompt,
            cumulative_reward=np.zeros(self.num_objectives),
            metadata=metadata or {},
        )

        self.active_rollouts[rollout_id] = state
        return rollout_id

    def _termination_reason(self, state: RolloutState, gen_result: GenerationResult) -> Optional[str]:
        """
        Return why ``state`` should stop after its latest step, or None.

        When several conditions hold, the precedence is custom > eos > max_steps.
        ``should_terminate_fn`` (if set) is always evaluated.
        """
        if self.should_terminate_fn and self.should_terminate_fn(state):
            return "custom"
        if gen_result.finish_reason == "stop":
            return "eos"
        if state.step_count >= self.max_steps:
            return "max_steps"
        return None

    def step_all(
        self,
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        **gen_kwargs
    ) -> Dict[str, GenerationResult]:
        """
        Advance every non-terminated active rollout by exactly one step.

        Rollouts are grouped by their current step count, and each group is sent
        to the backend as one batch. After each step the reward hook runs and
        termination is checked. Terminated rollouts stay in ``active_rollouts``
        until ``complete_terminated_rollouts`` is called.

        Args:
            max_new_tokens, temperature, top_p, **gen_kwargs: Forwarded to
                ``self.backend.generate``.

        Returns:
            Dict mapping rollout_id to the GenerationResult produced this call.

        Raises:
            RuntimeError: if the backend returns a different number of results
                than prompts. Without this check, ``zip`` would silently drop
                rollouts.
        """
        if not self.active_rollouts:
            return {}

        # Group up front, so a rollout stepped in one group is not picked up
        # again under its new step count during the same call.
        step_groups: Dict[int, List[str]] = defaultdict(list)
        for rid, state in self.active_rollouts.items():
            if not state.terminated:
                step_groups[state.step_count].append(rid)

        all_results: Dict[str, GenerationResult] = {}

        for rollout_ids in step_groups.values():
            # The full context so far is the prompt for this step.
            prompts = [self.active_rollouts[rid].get_full_context() for rid in rollout_ids]

            gen_results = self.backend.generate(
                prompts,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                **gen_kwargs
            )
            if len(gen_results) != len(prompts):
                raise RuntimeError(
                    f"Backend returned {len(gen_results)} results for {len(prompts)} prompts"
                )

            for rid, prompt, gen_result in zip(rollout_ids, prompts, gen_results):
                state = self.active_rollouts[rid]

                if self.process_response_fn:
                    gen_result.text = self.process_response_fn(gen_result.text)

                # `prompt` equals state.get_full_context() here: the trajectory
                # has not changed since the prompts were built.
                step_data = StepData(prompt=prompt, response=gen_result)

                if self.get_reward_fn:
                    step_data.reward = self.get_reward_fn(state, step_data)
                else:
                    step_data.reward = np.zeros(self.num_objectives)

                state.add_step(step_data)

                reason = self._termination_reason(state, gen_result)
                if reason is not None:
                    state.terminated = True
                    state.termination_reason = reason

                all_results[rid] = gen_result

        return all_results

    def complete_terminated_rollouts(self) -> List[RolloutState]:
        """Move terminated rollouts from ``active_rollouts`` to ``completed_rollouts`` and return them."""
        newly_completed = [s for s in self.active_rollouts.values() if s.terminated]
        for state in newly_completed:
            self.completed_rollouts.append(state)
            del self.active_rollouts[state.rollout_id]
        return newly_completed

    def has_active_rollouts(self) -> bool:
        """Check if there are active rollouts (terminated-but-uncollected ones included)."""
        return len(self.active_rollouts) > 0

    def get_statistics(self) -> Dict:
        """Counts, mean step counts, and the cumulative rewards of the last 10 completed rollouts."""
        active_steps = [s.step_count for s in self.active_rollouts.values()]
        completed_steps = [s.step_count for s in self.completed_rollouts]

        return {
            "active": len(self.active_rollouts),
            "completed": len(self.completed_rollouts),
            "active_avg_steps": np.mean(active_steps) if active_steps else 0,
            "completed_avg_steps": np.mean(completed_steps) if completed_steps else 0,
            "completed_rewards": [s.cumulative_reward.tolist() for s in self.completed_rollouts[-10:]],
        }

    def reset(self):
        """Drop all active and completed rollouts (``rollout_counter`` is kept so ids stay unique)."""
        self.active_rollouts.clear()
        self.completed_rollouts.clear()

print("✓ AsyncRolloutManager defined")

# Flexible Rollout State with Checkpointing
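Per-rollout state lives in `RolloutState` (see Core Data Structures). It records the full trajectory, and its `checkpoint()` method returns a deep copy you can branch from or fall back to.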


# Multi-Backend Inference with Unified Interface

In [None]:
class InferenceBackend(ABC):
    """
    Interface shared by all generation backends.

    Subclasses must implement both abstract methods; instantiating a subclass
    that leaves either one abstract raises TypeError (standard ``abc`` behaviour).
    """

    @abstractmethod
    def generate(self, prompts: List[str], **kwargs) -> List[GenerationResult]:
        """
        Generate one result for each prompt in ``prompts``.

        Callers such as AsyncRolloutManager pair results with prompts by
        position, so implementations should preserve prompt order.
        """
        ...

    @abstractmethod
    def shutdown(self):
        """Release whatever the backend holds (models, engines, GPU memory)."""
        ...


class HuggingFaceBackend(InferenceBackend):
    """
    HuggingFace backend - full gradient access.
    Use this for:
    - Gradient computation during training (FULL UPDATES)
    - Reference model logit computation
    - Small-scale experiments

    Note: ``generate``, ``get_logits`` and ``compute_log_probs`` all run under
    ``torch.no_grad()``; gradients require calling ``self.model`` directly.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        dtype: str = "bfloat16",  # "float32", "bfloat16", "float16"
        gradient_checkpointing: bool = False,  # Save memory during backward pass
    ):
        """
        Load the tokenizer and full model weights and move the model to ``device``.

        Args:
            model_name: HuggingFace model id or local path.
            device: Device the whole model (and all inputs) are placed on.
            dtype: "float32", "bfloat16" or "float16"; unknown values fall back
                to bfloat16.
            gradient_checkpointing: Enable activation checkpointing on the model.
        """
        self.device = device
        self.model_name = model_name
        self.dtype_map = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }
        self.torch_dtype = self.dtype_map.get(dtype, torch.bfloat16)

        # Tokenizer; reuse EOS as the pad token when the model defines none.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model - FULL WEIGHTS
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
            device_map=None,  # We'll move to device manually for full control
            trust_remote_code=True,
        )
        self.model = self.model.to(device)

        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
            print("✓ Gradient checkpointing enabled")

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"✓ HuggingFace backend loaded: {model_name}")
        print(f"  Total params: {total_params:,}")
        print(f"  Trainable params: {trainable_params:,}")
        print(f"  Dtype: {dtype}")

    def _eos_token_ids(self) -> set:
        """EOS ids from the tokenizer and the model's generation_config (int or list)."""
        eos_ids = set()
        generation_config = getattr(self.model, "generation_config", None)
        for source in (self.tokenizer.eos_token_id, getattr(generation_config, "eos_token_id", None)):
            if source is None:
                continue
            if isinstance(source, int):
                eos_ids.add(source)
            else:
                eos_ids.update(source)
        return eos_ids

    def generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        do_sample: bool = True,
        return_logprobs: bool = False,
        **kwargs
    ) -> List[GenerationResult]:
        """
        Generate one completion per prompt. Prompts are processed one at a time.

        Args:
            prompts: Raw prompt strings.
            max_new_tokens, temperature, top_p, do_sample: Passed to
                ``model.generate``. temperature and top_p are forced to 1.0 when
                ``do_sample`` is False.
            return_logprobs: Also return the log-prob of each sampled token.
                NOTE: these come from ``outputs.scores`` (the processed scores
                used for sampling), so they may differ from
                ``compute_log_probs`` on the same text.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            One GenerationResult per prompt. ``finish_reason`` is "stop" when the
            last generated token is an EOS id (tokenizer or generation_config),
            otherwise "length".
        """
        eos_ids = self._eos_token_ids()
        results = []
        # Generation needs eval(); remember the caller's mode so a training loop
        # that samples between updates gets train() mode back afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            for prompt in prompts:
                inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature if do_sample else 1.0,
                        top_p=top_p if do_sample else 1.0,
                        do_sample=do_sample,
                        pad_token_id=self.tokenizer.pad_token_id,
                        output_scores=return_logprobs,
                        return_dict_in_generate=True,
                    )

                # Strip the prompt; keep only newly generated tokens.
                generated_ids = outputs.sequences[0][inputs["input_ids"].shape[1]:]
                text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

                logprobs = None
                scores = getattr(outputs, "scores", None)
                if return_logprobs and scores is not None:
                    logprobs = []
                    # scores[t] holds the (1, vocab) scores used to pick token t.
                    for step_idx, step_scores in enumerate(scores):
                        step_log_probs = F.log_softmax(step_scores[0], dim=-1)
                        token_id = generated_ids[step_idx].item()
                        logprobs.append(step_log_probs[token_id].item())

                # generated_ids can be empty (e.g. max_new_tokens=0); indexing
                # [-1] would then raise.
                ended_on_eos = len(generated_ids) > 0 and generated_ids[-1].item() in eos_ids
                results.append(GenerationResult(
                    text=text,
                    token_ids=generated_ids.tolist(),
                    logprobs=logprobs,
                    finish_reason="stop" if ended_on_eos else "length",
                ))
        finally:
            self.model.train(was_training)

        return results

    def get_logits(
        self,
        prompts: List[str],
        responses: List[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get logits for prompt+response pairs.
        Returns: (logits, attention_mask)
        Essential for KL divergence computation in RL.

        Inputs are padded with the tokenizer's configured padding side and
        truncated to the tokenizer's maximum length.
        """
        full_texts = [p + r for p, r in zip(prompts, responses)]

        inputs = self.tokenizer(
            full_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        return outputs.logits, inputs["attention_mask"]

    @staticmethod
    def _encode_right_padded(tokenizer, texts: List[str]):
        """
        Tokenize ``texts`` with right padding, then restore the tokenizer's padding side.

        Response slicing indexes each row from position 0, which is only
        correct when padding is appended on the right.
        """
        original_side = tokenizer.padding_side
        tokenizer.padding_side = "right"
        try:
            return tokenizer(texts, padding=True, return_tensors="pt")
        finally:
            tokenizer.padding_side = original_side

    @staticmethod
    def _sum_response_log_probs(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        prompt_lens: List[int],
    ) -> torch.Tensor:
        """
        Sum log p(token) over each row's response (non-prompt, non-pad) positions.

        Args:
            logits: (batch, seq_len, vocab) output for right-padded inputs.
            input_ids, attention_mask: (batch, seq_len) tensors the logits came from.
            prompt_lens: Prompt length in tokens for each row.

        Returns:
            (batch,) float32 tensor on ``logits.device``.
        """
        sums = []
        for i, prompt_len in enumerate(prompt_lens):
            # Logits at position t predict token t+1. Token 0 has no predicting
            # position, so scoring starts at index >= 1 (matters only when the
            # prompt tokenizes to zero tokens).
            start = max(prompt_len, 1)
            # log_softmax in float32 for numerical stability under bf16/fp16 models.
            response_logits = logits[i, start - 1:-1].float()
            response_ids = input_ids[i, start:].to(logits.device)
            mask = attention_mask[i, start:].to(logits.device).float()
            token_log_probs = F.log_softmax(response_logits, dim=-1).gather(
                1, response_ids.unsqueeze(1)
            ).squeeze(1)
            sums.append((token_log_probs * mask).sum())
        return torch.stack(sums)

    def compute_log_probs(
        self,
        prompts: List[str],
        responses: List[str],
    ) -> torch.Tensor:
        """
        Sum of log-probabilities of each response given its prompt.

        Runs under ``torch.no_grad()``, so the result carries no gradient.

        Args:
            prompts, responses: Parallel lists; each pair is scored as prompt+response.

        Returns:
            1-D float32 tensor with one value per pair (empty for empty input).
        """
        if not prompts:
            return torch.zeros(0, device=self.device)

        full_texts = [p + r for p, r in zip(prompts, responses)]
        # Prompt length in tokens. Assumes the prompt's tokens form a prefix of
        # the prompt+response tokenization (not guaranteed at the boundary).
        prompt_lens = [len(self.tokenizer.encode(p)) for p in prompts]

        full_encodings = self._encode_right_padded(self.tokenizer, full_texts)
        inputs = {k: v.to(self.device) for k, v in full_encodings.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits  # (batch, seq_len, vocab)
            return self._sum_response_log_probs(
                logits, inputs["input_ids"], inputs["attention_mask"], prompt_lens
            )

    def shutdown(self):
        """Drop the model and release cached CUDA memory."""
        del self.model
        torch.cuda.empty_cache()
        print("✓ HuggingFace backend shutdown")


class VLLMBackend(InferenceBackend):
    """
    Fast batched generation through vLLM (PagedAttention).

    Intended for:
    - Fast rollout generation
    - Large batch inference
    vLLM offers no gradient computation, so it cannot be used for the
    training forward pass.
    """

    def __init__(
        self,
        model_name: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        max_model_len: int = 4096,
    ):
        """Build a vLLM ``LLM`` engine for ``model_name``; raises ImportError when vLLM is missing."""
        if not VLLM_AVAILABLE:
            raise ImportError("vLLM not installed. Run: pip install vllm")

        self.model_name = model_name
        engine_config = dict(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            trust_remote_code=True,
        )
        self.llm = LLM(**engine_config)
        print(f"✓ vLLM backend loaded: {model_name}")

    def generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        return_logprobs: bool = False,
        n_logprobs: int = 5,
        **kwargs
    ) -> List[GenerationResult]:
        """
        Generate one completion per prompt in a single vLLM call.

        When ``return_logprobs`` is set, vLLM is asked for ``n_logprobs``
        candidates per position and the log-prob of the sampled token is kept.
        Extra ``**kwargs`` are ignored.
        """
        requested_logprobs = n_logprobs if return_logprobs else None
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            logprobs=requested_logprobs,
        )

        request_outputs = self.llm.generate(prompts, sampling_params)
        return [
            self._completion_to_result(request_output.outputs[0], return_logprobs)
            for request_output in request_outputs
        ]

    @staticmethod
    def _completion_to_result(completion, return_logprobs: bool) -> GenerationResult:
        """Convert the first vLLM completion of a request into a GenerationResult."""
        sampled_logprobs = None
        if return_logprobs and completion.logprobs:
            # Each position maps token_id -> logprob entry; keep the sampled token's.
            sampled_logprobs = []
            for position, candidates in enumerate(completion.logprobs):
                chosen_id = completion.token_ids[position]
                sampled_logprobs.append(candidates[chosen_id].logprob)

        return GenerationResult(
            text=completion.text,
            token_ids=list(completion.token_ids),
            logprobs=sampled_logprobs,
            finish_reason=completion.finish_reason or "length",
        )

    def shutdown(self):
        """Drop the engine and release cached CUDA memory."""
        del self.llm
        torch.cuda.empty_cache()
        print("✓ vLLM backend shutdown")


class SGLangBackend(InferenceBackend):
    """
    SGLang backend - RadixAttention for efficient KV caching.
    Use this for:
    - Fast rollout generation
    - Prefix caching (good for multi-turn)
    Note: SGLang doesn't support gradient computation!
    """

    def __init__(
        self,
        model_name: str,
        mem_fraction_static: float = 0.8,
        tp_size: int = 1,
    ):
        """Start an in-process ``sgl.Engine``; raises ImportError when SGLang is missing."""
        if not SGLANG_AVAILABLE:
            raise ImportError("SGLang not installed. Run: pip install 'sglang[all]'")

        self.model_name = model_name
        self.engine = sgl.Engine(
            model_path=model_name,
            mem_fraction_static=mem_fraction_static,
            tp_size=tp_size,
        )
        print(f"✓ SGLang backend loaded: {model_name}")

    def generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        return_logprobs: bool = False,
        n_logprobs: int = 5,
        **kwargs
    ) -> List[GenerationResult]:
        """
        Batch generate with SGLang.

        Missing or None ``meta_info`` fields degrade gracefully: logprobs become
        None and finish_reason falls back to "length", instead of raising
        KeyError/AttributeError. Extra ``**kwargs`` are ignored.
        """
        sampling_params = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        if return_logprobs:
            outputs = self.engine.generate(
                prompts,
                sampling_params,
                return_logprob=True,
                top_logprobs_num=n_logprobs,
            )
        else:
            outputs = self.engine.generate(prompts, sampling_params)

        results = []
        for output in outputs:
            # `or {}` also covers keys that are present but None.
            meta_info = output.get("meta_info") or {}

            logprobs = None
            token_logprobs = meta_info.get("output_token_logprobs")
            if return_logprobs and token_logprobs:
                # Entries are (logprob, token_id, text); keep the logprob.
                logprobs = [entry[0] for entry in token_logprobs]

            finish_info = meta_info.get("finish_reason")
            if isinstance(finish_info, dict):
                finish_reason = finish_info.get("type") or "length"
            elif isinstance(finish_info, str):
                finish_reason = finish_info
            else:
                finish_reason = "length"

            results.append(GenerationResult(
                text=output["text"],
                token_ids=output.get("output_ids", []),
                logprobs=logprobs,
                finish_reason=finish_reason,
            ))

        return results

    def shutdown(self):
        """Shut down the SGLang engine."""
        self.engine.shutdown()
        print("✓ SGLang backend shutdown")


print("✓ Inference backends defined")

# Dynamic Reference Model Manager
Handle changing reference models mid-training:

In [None]:
class ReferenceModelManager:
    """
    Manages reference model for KL divergence computation.
    Supports dynamic snapshots mid-training.

    Snapshots are full ``state_dict`` copies kept on CPU. They are applied to
    the reference model with ``strict=False``; snapshot keys the reference
    model does not have are ignored, and a warning is printed.
    """

    def __init__(self, model_name: str, device: str = "cuda"):
        """
        Load a frozen copy of ``model_name`` as the reference policy.

        Args:
            model_name: HuggingFace model id or local path.
            device: Device that tokenized inputs are moved to. NOTE(review): the
                model itself is placed with ``device_map="auto"``; confirm its
                input layer lives on this device.
        """
        self.model_name = model_name
        self.device = device
        self.snapshots: Dict[str, Dict] = {}  # snapshot_id -> state_dict (CPU tensors)

        # Load reference model (frozen)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        # Freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False

        print(f"✓ Reference model loaded: {model_name}")

    @staticmethod
    def _encode_right_padded(tokenizer, texts: List[str]):
        """
        Tokenize ``texts`` with right padding, then restore the tokenizer's padding side.

        Response slicing indexes each row from position 0, which is only
        correct when padding is appended on the right.
        """
        original_side = tokenizer.padding_side
        tokenizer.padding_side = "right"
        try:
            return tokenizer(texts, padding=True, return_tensors="pt")
        finally:
            tokenizer.padding_side = original_side

    @staticmethod
    def _sum_response_log_probs(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        prompt_lens: List[int],
    ) -> torch.Tensor:
        """
        Sum log p(token) over each row's response (non-prompt, non-pad) positions.

        Args:
            logits: (batch, seq_len, vocab) output for right-padded inputs.
            input_ids, attention_mask: (batch, seq_len) tensors the logits came from.
            prompt_lens: Prompt length in tokens for each row.

        Returns:
            (batch,) float32 tensor on ``logits.device``. Gradients flow if
            ``logits`` requires grad.
        """
        sums = []
        for i, prompt_len in enumerate(prompt_lens):
            # Logits at position t predict token t+1. Token 0 has no predicting
            # position, so scoring starts at index >= 1 (matters only when the
            # prompt tokenizes to zero tokens).
            start = max(prompt_len, 1)
            # log_softmax in float32 for numerical stability under bf16 models.
            response_logits = logits[i, start - 1:-1].float()
            response_ids = input_ids[i, start:].to(logits.device)
            mask = attention_mask[i, start:].to(logits.device).float()
            token_log_probs = F.log_softmax(response_logits, dim=-1).gather(
                1, response_ids.unsqueeze(1)
            ).squeeze(1)
            sums.append((token_log_probs * mask).sum())
        return torch.stack(sums)

    @staticmethod
    def _module_device(module, fallback) -> torch.device:
        """Device of ``module``'s first parameter, or ``fallback`` if it has none."""
        try:
            return next(module.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device(fallback)

    def _load_state(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load ``state_dict`` non-strictly and warn when some of its keys were ignored."""
        result = self.model.load_state_dict(state_dict, strict=False)
        if result.unexpected_keys:
            print(
                f"⚠ {len(result.unexpected_keys)} snapshot keys do not match the reference "
                f"model and were ignored (e.g. {result.unexpected_keys[0]!r})"
            )

    def snapshot_current_training_model(self, training_model: nn.Module) -> str:
        """
        Copy the training model's full ``state_dict`` to CPU and register it.

        NOTE(review): this stores every tensor in the state_dict, not only
        trainable (e.g. LoRA) weights. If ``training_model`` is a wrapped model,
        its keys may not match the reference model's; ``_load_state`` reports
        such mismatches when the snapshot is applied.

        Returns:
            The new snapshot id.
        """
        # The random suffix keeps two snapshots taken in the same millisecond
        # from colliding (the second would otherwise overwrite the first).
        snapshot_id = f"snapshot_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"
        self.snapshots[snapshot_id] = {
            k: v.detach().cpu().clone() for k, v in training_model.state_dict().items()
        }
        print(f"✓ Created snapshot: {snapshot_id}")
        return snapshot_id

    def load_snapshot(self, snapshot_id: str):
        """
        Permanently load a saved snapshot into the reference model.

        Raises:
            ValueError: if ``snapshot_id`` is unknown.
        """
        if snapshot_id not in self.snapshots:
            raise ValueError(f"Snapshot {snapshot_id} not found")

        self._load_state(self.snapshots[snapshot_id])
        print(f"✓ Loaded snapshot: {snapshot_id}")

    def compute_log_probs(
        self,
        prompts: List[str],
        responses: List[str],
        snapshot_id: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Sum of reference-model log-probabilities of each response given its prompt.

        Args:
            prompts, responses: Parallel lists; each pair is scored as prompt+response.
            snapshot_id: If given and known, the snapshot is applied for this call
                only. The weights it overwrote are restored afterwards, even on
                error. An unknown id prints a warning and the current weights are used.

        Returns:
            1-D float32 tensor with one value per pair, computed without gradients
            (empty for empty input).
        """
        if not prompts:
            return torch.zeros(0, device=self.device)

        snapshot_state = None
        if snapshot_id:
            snapshot_state = self.snapshots.get(snapshot_id)
            if snapshot_state is None:
                print(f"⚠ Snapshot {snapshot_id} not found; using current reference weights")

        backup = None
        try:
            if snapshot_state is not None:
                current_state = self.model.state_dict()
                # Back up only the tensors the snapshot overwrites, and keep them
                # on CPU rather than cloning the whole model on the GPU.
                backup = {
                    k: current_state[k].detach().cpu().clone()
                    for k in snapshot_state
                    if k in current_state
                }
                self._load_state(snapshot_state)

            full_texts = [p + r for p, r in zip(prompts, responses)]
            # Assumes the prompt's tokens form a prefix of the prompt+response
            # tokenization (not guaranteed at the boundary).
            prompt_lens = [len(self.tokenizer.encode(p)) for p in prompts]
            full_encodings = self._encode_right_padded(self.tokenizer, full_texts)
            inputs = {k: v.to(self.device) for k, v in full_encodings.items()}

            with torch.no_grad():
                logits = self.model(**inputs).logits
                return self._sum_response_log_probs(
                    logits, inputs["input_ids"], inputs["attention_mask"], prompt_lens
                )
        finally:
            if backup is not None:
                self.model.load_state_dict(backup, strict=False)

    def compute_kl_divergence(
        self,
        prompts: List[str],
        responses: List[str],
        train_model: nn.Module,
        train_tokenizer,
        snapshot_id: Optional[str] = None,
        per_token: bool = False,
    ) -> torch.Tensor:
        """
        Log-ratio estimate of KL(π_train || π_ref) for each (prompt, response) pair.

        For each pair this is log π_train(response|prompt) - log π_ref(response|prompt),
        summed over response tokens. It is a single-sample estimate of the
        sequence-level KL when responses were sampled from π_train. The training
        forward pass runs with autograd enabled (unless the caller disabled it),
        so the result can be used as a loss term.

        Args:
            train_model: Called as ``train_model(**inputs)``; must return ``.logits``.
                Inputs are placed on the device of its first parameter.
            train_tokenizer: Tokenizer matching ``train_model``.
            snapshot_id: Passed to ``compute_log_probs`` for the reference side.
            per_token: Despite the name, True returns the per-*sequence* values,
                shape (batch,). False returns their mean (a scalar). Per-token
                values are not computed.
        """
        if not prompts:
            empty = torch.zeros(0, device=self.device)
            return empty if per_token else empty.sum()

        ref_log_probs = self.compute_log_probs(prompts, responses, snapshot_id)

        train_device = self._module_device(train_model, self.device)
        full_texts = [p + r for p, r in zip(prompts, responses)]
        prompt_lens = [len(train_tokenizer.encode(p)) for p in prompts]
        full_encodings = self._encode_right_padded(train_tokenizer, full_texts)
        inputs = {k: v.to(train_device) for k, v in full_encodings.items()}

        train_logits = train_model(**inputs).logits
        train_log_probs = self._sum_response_log_probs(
            train_logits, inputs["input_ids"], inputs["attention_mask"], prompt_lens
        )

        kl_div = train_log_probs - ref_log_probs.to(train_log_probs.device)
        return kl_div if per_token else kl_div.mean()

    def clear_snapshots(self):
        """Free memory by clearing all stored snapshots."""
        self.snapshots.clear()
        torch.cuda.empty_cache()
        print("✓ Cleared all snapshots")

    def shutdown(self):
        """Drop the reference model and all snapshots, and release cached CUDA memory."""
        del self.model
        self.snapshots.clear()
        torch.cuda.empty_cache()
        print("✓ Reference model manager shutdown")

print("✓ Reference model manager defined")

# Multi-Objective Reward System with Pareto Tracking

In [None]:
class ParetoRewardTracker:
    """
    Track multi-objective rewards and maintain their Pareto frontier.

    ``all_rewards`` records every vector passed to ``update_pareto_frontier``.
    ``pareto_frontier`` holds the subset not dominated by any other recorded
    vector; near-duplicates (per ``np.allclose``) are stored only once.
    Plug in your own reward functions with ``register_reward_function``.
    """

    def __init__(self, objective_names: List[str]):
        self.objectives = objective_names
        self.num_objectives = len(objective_names)
        self.pareto_frontier: List[np.ndarray] = []
        self.all_rewards: List[np.ndarray] = []

        # YOU OVERRIDE THESE - map objective name to function
        self.reward_functions: Dict[str, Callable[[RolloutState], float]] = {}

    def register_reward_function(self, name: str, fn: Callable[[RolloutState], float]):
        """
        Register ``fn`` as the scorer for objective ``name``.

        Raises:
            ValueError: if ``name`` is not one of ``self.objectives``.
        """
        if name not in self.objectives:
            raise ValueError(f"Objective {name} not in {self.objectives}")
        self.reward_functions[name] = fn

    def compute_rewards(self, state: RolloutState) -> np.ndarray:
        """
        Compute the reward vector for a completed rollout.

        Objectives with a registered function use it. Every other objective falls
        back to the matching entry of ``state.cumulative_reward``, or 0.0 when that
        vector is shorter than the number of objectives.
        """
        rewards = np.zeros(self.num_objectives)

        for i, obj_name in enumerate(self.objectives):
            if obj_name in self.reward_functions:
                rewards[i] = self.reward_functions[obj_name](state)
            else:
                # Default: use cumulative reward
                rewards[i] = state.cumulative_reward[i] if i < len(state.cumulative_reward) else 0.0

        return rewards

    def update_pareto_frontier(self, reward_vector: np.ndarray):
        """
        Record ``reward_vector`` and update the Pareto frontier with it.

        The vector is always appended to ``all_rewards``. It joins the frontier
        only if no frontier point dominates it and it is not already on the
        frontier. Frontier points it dominates are removed.
        """
        # Keep a private float copy. Callers may mutate their array in place later,
        # and list inputs would otherwise reach `_dominates` and be compared
        # lexicographically by Python instead of element-wise.
        reward_vector = np.array(reward_vector, dtype=float)
        self.all_rewards.append(reward_vector)

        # Equal vectors never dominate each other, so without this check every
        # repeat of a frontier point would be appended again.
        if self.is_on_pareto_frontier(reward_vector):
            return

        if any(self._dominates(p, reward_vector) for p in self.pareto_frontier):
            return

        # Remove points dominated by the new vector, then add it
        self.pareto_frontier = [
            p for p in self.pareto_frontier
            if not self._dominates(reward_vector, p)
        ]
        self.pareto_frontier.append(reward_vector)

    def _dominates(self, a: np.ndarray, b: np.ndarray) -> bool:
        """
        Return True if ``a`` Pareto-dominates ``b`` (higher is better).

        ``a`` dominates ``b`` when it is at least as good in every objective and
        strictly better in at least one. Equal vectors do not dominate each other.
        """
        return np.all(a >= b) and np.any(a > b)

    def is_on_pareto_frontier(self, reward_vector: np.ndarray) -> bool:
        """Return True if some frontier point is ``np.allclose`` to ``reward_vector``."""
        for p in self.pareto_frontier:
            if np.allclose(p, reward_vector):
                return True
        return False

    def get_pareto_weights(self, strategy: str = "random") -> np.ndarray:
        """
        Get a weight vector for scalarizing multi-objective rewards.
        Used in training to sample different trade-offs.

        Strategies:
            random: Dirichlet(1, ..., 1) sample, i.e. uniform over the simplex
            uniform: equal weights
            maximize_first: all weight on the first objective

        Raises:
            ValueError: for an unknown strategy.
        """
        if strategy == "random":
            # Random scalarization (linear scalarization)
            return np.random.dirichlet(np.ones(self.num_objectives))
        elif strategy == "uniform":
            return np.ones(self.num_objectives) / self.num_objectives
        elif strategy == "maximize_first":
            weights = np.zeros(self.num_objectives)
            weights[0] = 1.0
            return weights
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    def scalarize_reward(self, reward_vector: np.ndarray, weights: Optional[np.ndarray] = None) -> float:
        """Convert a multi-objective reward to a scalar by a weighted sum (default: equal weights)."""
        if weights is None:
            weights = np.ones(self.num_objectives) / self.num_objectives
        return float(np.dot(weights, reward_vector))

    def get_frontier_statistics(self) -> Dict:
        """Return the frontier size plus per-objective mean/std/min/max (only ``size`` when empty)."""
        if not self.pareto_frontier:
            return {"size": 0}

        frontier_array = np.array(self.pareto_frontier)
        return {
            "size": len(self.pareto_frontier),
            "mean": frontier_array.mean(axis=0).tolist(),
            "std": frontier_array.std(axis=0).tolist(),
            "min": frontier_array.min(axis=0).tolist(),
            "max": frontier_array.max(axis=0).tolist(),
        }

print("✓ ParetoRewardTracker defined")

# SVRL Environment Interface

In [None]:
class SVRLEnvironment:
    """
    Self-Verifying RL Environment.

    The model can query the environment for verification. Each query has a
    cost that is charged against the rollout's verification budget. Supply your
    own ``cost_function`` and ``verification_handler`` to implement the
    verification logic; the defaults are placeholders.
    """

    def __init__(
        self,
        cost_function: Optional[Callable[[str, RolloutState], float]] = None,
        verification_handler: Optional[Callable[[str], Any]] = None,
    ):
        # YOU OVERRIDE THESE
        self.cost_fn = cost_function or self._default_cost
        self.verification_handler = verification_handler or self._default_verification

        # Tracking: only queries that fit the budget (and were executed) are counted
        self.total_queries = 0
        self.total_cost_spent = 0.0
        self.query_history: List[Dict] = []

    def _default_cost(self, query: str, state: RolloutState) -> float:
        """Default cost: 0.1 per whitespace-separated word of the query."""
        return len(query.split()) * 0.1

    def _default_verification(self, query: str) -> str:
        """Default verification: a placeholder string that echoes the query."""
        return f"[Verification result for: {query}]"

    def query_environment(
        self,
        query: str,
        state: RolloutState,
    ) -> Tuple[Optional[Any], float, bool]:
        """
        Model queries environment for verification.

        If ``state.can_verify(cost)`` allows it, the cost is added to
        ``state.verification_budget_spent``, the handler runs, and the query
        is logged.

        Returns:
            result: Verification result (None if over budget)
            cost: Cost of this query (returned even when the query is refused)
            success: Whether query was successful
        """
        cost = self.cost_fn(query, state)

        if state.can_verify(cost):
            state.verification_budget_spent += cost
            result = self.verification_handler(query)

            # Track query
            self.total_queries += 1
            self.total_cost_spent += cost
            self.query_history.append({
                "rollout_id": state.rollout_id,
                "step": state.step_count,
                "query": query,
                "cost": cost,
                "budget_remaining": state.max_verification_budget - state.verification_budget_spent,
            })

            return result, cost, True
        else:
            return None, cost, False

    def parse_verification_request(self, model_output: str) -> Optional[str]:
        """
        Parse model output to extract a verification request.
        You can customize this format!

        Default format: VERIFY: <query>

        The query is the rest of the line after the first "VERIFY:" marker,
        stripped of surrounding whitespace. Returns None when there is no
        marker or the query is empty.
        """
        _, marker, remainder = model_output.partition("VERIFY:")
        if not marker:
            return None
        # Cut at the end of the marker's own line *before* stripping. Stripping
        # first would let a bare "VERIFY:" line take the next line as its query.
        query = remainder.split("\n", 1)[0].strip()
        return query or None

    def get_statistics(self) -> Dict:
        """Get environment statistics, including the 10 most recent queries."""
        return {
            "total_queries": self.total_queries,
            "total_cost": self.total_cost_spent,
            "avg_cost_per_query": self.total_cost_spent / max(1, self.total_queries),
            "recent_queries": self.query_history[-10:],
        }

    def reset_tracking(self):
        """Reset tracking for new experiment"""
        self.total_queries = 0
        self.total_cost_spent = 0.0
        self.query_history.clear()

print("✓ SVRLEnvironment defined")

# Trajectory Buffer for Variable-Length Rollouts

In [None]:
class TrajectoryBuffer:
    """
    Bounded buffer of completed trajectories with several sampling strategies.

    Storage is a ``deque(maxlen=max_buffer_size)``. Once the buffer is full,
    adding a trajectory silently evicts the oldest one. Supports priority
    sampling for RL training.
    """

    def __init__(
        self,
        max_buffer_size: int = 10000,
        num_objectives: int = 1,
    ):
        self.max_size = max_buffer_size
        self.num_objectives = num_objectives
        self.trajectories: deque = deque(maxlen=max_buffer_size)
        # Maps state.step_count to the cumulative reward vectors of every trajectory
        # ever added. It is NOT trimmed when the deque evicts, so it covers all
        # history, not only what is currently in the buffer.
        self.step_statistics: Dict[int, List[np.ndarray]] = defaultdict(list)

        # Optional: Pareto tracker for priority sampling
        self.pareto_tracker: Optional[ParetoRewardTracker] = None

    def add_trajectory(self, state: RolloutState):
        """Add a completed trajectory, update the step statistics, and update the Pareto frontier if a tracker is set."""
        self.trajectories.append(state)

        # Track statistics keyed by the rollout's step count
        self.step_statistics[state.step_count].append(state.cumulative_reward)

        # Update Pareto frontier if tracker available
        if self.pareto_tracker:
            self.pareto_tracker.update_pareto_frontier(state.cumulative_reward)

    def add_batch(self, states: List[RolloutState]):
        """Add batch of completed trajectories"""
        for state in states:
            self.add_trajectory(state)

    def sample_batch(
        self,
        batch_size: int,
        strategy: str = "uniform",
    ) -> List[RolloutState]:
        """
        Sample a batch of distinct trajectories for training.

        ``batch_size`` is capped at the buffer size. An empty buffer returns [].

        Strategies:
        - uniform: Random sampling
        - completion_weighted: Prefer longer trajectories (weight = number of steps)
        - pareto_weighted: Prefer Pareto-optimal trajectories (2x weight); falls
          back to uniform when no tracker is attached
        - reward_weighted: Prefer high-reward trajectories (summed objectives,
          shifted so the minimum gets weight ~1e-6)

        Raises:
            ValueError: for an unknown strategy (on a non-empty buffer).
        """
        if len(self.trajectories) == 0:
            return []

        # Snapshot once. The previous code rebuilt list(self.trajectories) for every
        # sampled index, which made weighted sampling O(buffer_size * batch_size).
        pool = list(self.trajectories)
        batch_size = min(batch_size, len(pool))

        if strategy == "uniform":
            return random.sample(pool, batch_size)

        if strategy == "completion_weighted":
            weights = np.array([len(t.trajectory) for t in pool], dtype=float)
        elif strategy == "pareto_weighted":
            if not self.pareto_tracker:
                return self.sample_batch(batch_size, "uniform")
            weights = np.array([
                2.0 if self.pareto_tracker.is_on_pareto_frontier(t.cumulative_reward) else 1.0
                for t in pool
            ])
        elif strategy == "reward_weighted":
            totals = np.array([t.cumulative_reward.sum() for t in pool], dtype=float)
            weights = totals - totals.min() + 1e-6  # Shift to positive
        else:
            raise ValueError(f"Unknown sampling strategy: {strategy}")

        return self._weighted_sample(pool, weights, batch_size)

    @staticmethod
    def _weighted_sample(pool: list, weights: np.ndarray, batch_size: int) -> list:
        """Draw ``batch_size`` distinct items from ``pool`` with probability proportional to ``weights``."""
        # np.random.choice(replace=False) needs at least `batch_size` entries with
        # p > 0 and a non-zero total. Give zero-weight items (e.g. empty trajectories)
        # a tiny floor so they can still fill the batch. Positive weights are unchanged.
        probs = np.maximum(np.asarray(weights, dtype=float), 1e-6)
        probs = probs / probs.sum()
        indices = np.random.choice(len(pool), size=batch_size, p=probs, replace=False)
        return [pool[i] for i in indices]

    def prepare_training_batch(
        self,
        trajectories: List[RolloutState],
        weights: Optional[np.ndarray] = None,
    ) -> TrainingBatch:
        """
        Convert trajectories to training batch format.
        Flattens all steps of all trajectories, in order, for batch processing.

        Each step's reward is scalarized as ``weights · reward`` when ``weights`` is
        given. Otherwise the objectives are summed. Steps without a reward contribute 0.0.
        """
        prompts = []
        responses = []
        token_ids = []
        rewards = []

        for traj in trajectories:
            for step in traj.trajectory:
                prompts.append(step.prompt)
                responses.append(step.response.text)
                token_ids.append(step.response.token_ids)

                # Scalarize reward if multi-objective
                if weights is not None and step.reward is not None:
                    scalar_reward = np.dot(weights, step.reward)
                else:
                    scalar_reward = step.reward.sum() if step.reward is not None else 0.0
                rewards.append(scalar_reward)

        return TrainingBatch(
            prompts=prompts,
            responses=responses,
            token_ids=token_ids,
            rewards=np.array(rewards),
        )

    def get_statistics(self) -> Dict:
        """Get buffer statistics (only ``size`` when empty)."""
        if not self.trajectories:
            return {"size": 0}

        all_lengths = [len(t.trajectory) for t in self.trajectories]
        all_rewards = [t.cumulative_reward for t in self.trajectories]

        return {
            "size": len(self.trajectories),
            "avg_length": np.mean(all_lengths),
            "std_length": np.std(all_lengths),
            "min_length": min(all_lengths),
            "max_length": max(all_lengths),
            "avg_reward": np.mean(all_rewards, axis=0).tolist() if all_rewards else [],
            "step_distribution": {k: len(v) for k, v in self.step_statistics.items()},
        }

    def clear(self):
        """Clear stored trajectories and step statistics (the Pareto tracker is left untouched)."""
        self.trajectories.clear()
        self.step_statistics.clear()

print("✓ TrajectoryBuffer defined")

# Training Primitives: Losses, Termination Conditions (Early Stopping), Rewards, and the Training Step

In [None]:
# =============================================================================
# PURE FUNCTIONS: Core RL Primitives (Stateless, Composable)
# =============================================================================

def compute_sequence_log_probs(
    model: nn.Module,
    tokenizer,
    prompts: List[str],
    responses: List[str],
    device: str = "cuda",
) -> torch.Tensor:
    """
    Compute the total log probability of each response given its prompt.
    Pure function - no side effects beyond the model forward pass.

    Each prompt/response pair is tokenized as one string. The first
    ``len(tokenizer.encode(prompt))`` real tokens are treated as the prompt,
    and every later non-padding token is scored. Both right- and left-padding
    tokenizers (``tokenizer.padding_side``) are supported.

    NOTE(review): assumes that tokenizing ``prompt + response`` reproduces the
    prompt's own tokenization as a prefix. BPE merges across the boundary can
    shift the split by a token.

    Args:
        model: causal LM whose output has ``.logits`` of shape (batch, seq_len, vocab).
        tokenizer: HF-style tokenizer with ``encode`` and batched ``__call__`` (padding).
        prompts: prompt strings.
        responses: response strings, parallel to ``prompts``.
        device: device the encoded inputs are moved to.

    Returns: Tensor of shape (batch_size,) with total log prob per sequence
    """
    full_texts = [p + r for p, r in zip(prompts, responses)]
    prompt_lens = [len(tokenizer.encode(p)) for p in prompts]

    full_encodings = tokenizer(full_texts, padding=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in full_encodings.items()}

    outputs = model(**inputs)
    logits = outputs.logits  # (batch, seq_len, vocab)

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    seq_len = input_ids.shape[1]
    left_padded = getattr(tokenizer, "padding_side", "right") == "left"

    log_probs_list = []
    for i, prompt_len in enumerate(prompt_lens):
        # With left padding the real tokens start after the run of pads, so the
        # prompt/response boundary shifts by the number of pads in this row.
        pad_offset = seq_len - int(attention_mask[i].sum()) if left_padded else 0
        # Logits at position t predict token t + 1. The first real token has no
        # predicting position, so it is never scored (this only matters for empty prompts).
        response_start = pad_offset + max(prompt_len, 1)

        response_logits = logits[i, response_start - 1:-1]
        response_ids = input_ids[i, response_start:]
        mask = attention_mask[i, response_start:].float()  # zero out trailing pads

        log_probs = F.log_softmax(response_logits, dim=-1)
        token_log_probs = log_probs.gather(1, response_ids.unsqueeze(1)).squeeze(-1)
        log_probs_list.append((token_log_probs * mask).sum())

    return torch.stack(log_probs_list)


def compute_kl_divergence(
    train_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    reduction: str = "mean",
    estimator: str = "k1",
) -> torch.Tensor:
    """
    Estimate KL(π_train || π_ref) from log probabilities of samples drawn from π_train.
    Pure function.

    Estimators, with log_ratio = ref_log_probs - train_log_probs:
      - "k1" (default): train_log_probs - ref_log_probs. Unbiased, but individual
        values can be negative. When differentiated directly as a loss term on
        fixed on-policy samples, its expected gradient is zero.
      - "k3": exp(log_ratio) - log_ratio - 1. Always >= 0 and lower variance; this
        is the KL penalty form used by GRPO. NOTE: exp() can overflow if
        |log_ratio| is very large (e.g. long sequence-level sums).

    Args:
        train_log_probs: log probs under the policy being trained.
        ref_log_probs: log probs of the same samples under the reference policy.
        reduction: "mean", "sum", or "none" (elementwise).
        estimator: "k1" or "k3".

    Raises:
        ValueError: for an unknown ``estimator`` or ``reduction``.
    """
    if estimator == "k1":
        kl_div = train_log_probs - ref_log_probs
    elif estimator == "k3":
        log_ratio = ref_log_probs - train_log_probs
        kl_div = torch.exp(log_ratio) - log_ratio - 1
    else:
        raise ValueError(f"Unknown estimator: {estimator}")

    if reduction == "mean":
        return kl_div.mean()
    elif reduction == "sum":
        return kl_div.sum()
    elif reduction == "none":
        return kl_div
    else:
        raise ValueError(f"Unknown reduction: {reduction}")


def reinforce_loss(
    log_probs: torch.Tensor,
    rewards: torch.Tensor,
    baseline: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Vanilla REINFORCE policy-gradient loss. Pure function.

    advantage = rewards - baseline (or just rewards when no baseline is given)
    loss      = -mean(advantage * log_prob)
    """
    advantages = rewards if baseline is None else rewards - baseline
    weighted_log_probs = advantages * log_probs
    return -weighted_log_probs.mean()


def ppo_loss(
    log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    clip_epsilon: float = 0.2,
) -> torch.Tensor:
    """
    PPO clipped surrogate loss. Pure function.

    ratio = exp(log_probs - old_log_probs)
    loss  = -mean(min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A))
    """
    ratio = (log_probs - old_log_probs).exp()
    unclipped_objective = ratio * advantages
    clipped_objective = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    # Pessimistic (element-wise minimum) bound on the surrogate objective
    return -torch.minimum(unclipped_objective, clipped_objective).mean()


def grpo_loss(
    log_probs: torch.Tensor,
    rewards: torch.Tensor,
    group_size: int = 4,
) -> torch.Tensor:
    """
    Group Relative Policy Optimization loss.
    Pure function.

    Consecutive runs of ``group_size`` samples form a group. Each sample's
    advantage is its reward standardized within its group. Trailing samples
    that do not fill a whole group are ignored.

    Edge cases (which previously produced NaN):
      - group_size == 1: a singleton group has no spread, so the advantage is 0.
      - fewer samples than one group: returns a zero loss that stays attached
        to the graph, so ``.backward()`` still works.

    Raises:
        ValueError: if ``group_size < 1``.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")

    num_groups = log_probs.shape[0] // group_size
    if num_groups == 0:
        return log_probs.sum() * 0.0

    usable = num_groups * group_size
    grouped_rewards = rewards[:usable].view(num_groups, group_size)
    grouped_log_probs = log_probs[:usable].view(num_groups, group_size)

    # Compute advantages relative to group mean
    group_means = grouped_rewards.mean(dim=1, keepdim=True)
    if group_size > 1:
        group_stds = grouped_rewards.std(dim=1, keepdim=True)
    else:
        # The unbiased std of a single sample is NaN, which would poison every gradient
        group_stds = torch.zeros_like(group_means)
    advantages = (grouped_rewards - group_means) / (group_stds + 1e-8)

    return -(advantages * grouped_log_probs).mean()


# =============================================================================
# ABSTRACT BASE CLASSES: Extend These
# =============================================================================

class LossFunction(ABC):
    """
    Interface for pluggable RL losses. Subclass it and implement ``__call__``.
    """

    @abstractmethod
    def __call__(
        self,
        batch: TrainingBatch,
        model: nn.Module,
        tokenizer,
        device: str = "cuda",
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Evaluate the loss of ``model`` on ``batch``.

        Returns:
            loss: scalar tensor that carries gradient
            metrics: name -> float values for logging
        """
        ...


class ReinforceLoss(LossFunction):
    """
    Simple REINFORCE loss.

    With ``normalize_rewards`` the batch rewards are standardized
    ((r - mean) / std) before being used as advantages. For a single-sample
    batch only the centering is applied: the unbiased std of one value is NaN
    and would turn every gradient into NaN.
    """

    def __init__(self, normalize_rewards: bool = True):
        self.normalize_rewards = normalize_rewards

    def __call__(self, batch, model, tokenizer, device="cuda"):
        model.train()

        log_probs = compute_sequence_log_probs(
            model, tokenizer, batch.prompts, batch.responses, device
        )

        rewards = torch.tensor(batch.rewards, dtype=torch.float32, device=device)

        if self.normalize_rewards:
            if rewards.numel() > 1:
                rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
            else:
                rewards = rewards - rewards.mean()

        loss = reinforce_loss(log_probs, rewards)

        return loss, {
            "pg_loss": loss.item(),
            "avg_reward": float(batch.rewards.mean()),
            "reward_std": float(batch.rewards.std()),
        }


class PPOLoss(LossFunction):
    """
    PPO with clipping.

    Requires ``batch.old_logprobs`` (sequence log probs under the behavior
    policy) and ``batch.advantages``. Either may be a numpy array or a tensor.
    Both are moved to ``device`` as float32 and treated as constants.
    """

    def __init__(self, clip_epsilon: float = 0.2):
        self.clip_epsilon = clip_epsilon

    def __call__(self, batch, model, tokenizer, device="cuda"):
        model.train()

        # Validate before the forward pass so a misconfigured batch fails cheaply
        if batch.old_logprobs is None:
            raise ValueError("PPO requires old_logprobs in batch")
        if batch.advantages is None:
            raise ValueError("PPO requires advantages in batch")

        log_probs = compute_sequence_log_probs(
            model, tokenizer, batch.prompts, batch.responses, device
        )

        # Move inputs to the same device as log_probs (numpy or CPU inputs would
        # otherwise fail to combine with them). Detach old log probs so no gradient
        # flows through the behavior policy.
        old_log_probs = torch.as_tensor(batch.old_logprobs, dtype=torch.float32, device=device).detach()
        advantages = torch.as_tensor(batch.advantages, dtype=torch.float32, device=device).detach()

        loss = ppo_loss(log_probs, old_log_probs, advantages, self.clip_epsilon)

        return loss, {
            "ppo_loss": loss.item(),
            "avg_advantage": float(batch.advantages.mean()),
        }


class GRPOLoss(LossFunction):
    """
    Group Relative Policy Optimization.

    NOTE(review): groups are formed from *consecutive* batch entries, so the
    batch is assumed to be laid out with each prompt's samples adjacent. Confirm
    the batch builder produces that order.
    """

    def __init__(self, group_size: int = 4):
        self.group_size = group_size

    def __call__(self, batch, model, tokenizer, device="cuda"):
        model.train()

        seq_log_probs = compute_sequence_log_probs(
            model, tokenizer, batch.prompts, batch.responses, device
        )
        reward_tensor = torch.tensor(batch.rewards, dtype=torch.float32, device=device)

        loss = grpo_loss(seq_log_probs, reward_tensor, self.group_size)
        metrics = {
            "grpo_loss": loss.item(),
            "avg_reward": float(batch.rewards.mean()),
        }
        return loss, metrics


class TerminationCondition(ABC):
    """
    Interface deciding when a rollout stops. Subclass it and implement ``should_terminate``.
    """

    @abstractmethod
    def should_terminate(self, state: RolloutState) -> Tuple[bool, Optional[str]]:
        """
        Decide whether the rollout in ``state`` should stop now.

        Returns:
            (should_stop, reason): reason is a short tag, or None when not stopping
        """
        ...


class MaxStepsTermination(TerminationCondition):
    """Stop once ``state.step_count`` reaches ``max_steps``."""

    def __init__(self, max_steps: int):
        self.max_steps = max_steps

    def should_terminate(self, state):
        limit_reached = state.step_count >= self.max_steps
        return (True, "max_steps") if limit_reached else (False, None)


class KeywordTermination(TerminationCondition):
    """Stop when the latest response contains any of ``keywords`` (the first match wins)."""

    def __init__(self, keywords: List[str]):
        self.keywords = keywords

    def should_terminate(self, state):
        if not state.trajectory:
            return False, None

        latest = state.trajectory[-1].response.text
        matched = next((kw for kw in self.keywords if kw in latest), None)
        if matched is None:
            return False, None
        return True, f"keyword:{matched}"


class CompositeTermination(TerminationCondition):
    """OR-combination of conditions: stop as soon as any one fires, using that condition's reason."""

    def __init__(self, conditions: List[TerminationCondition]):
        self.conditions = conditions

    def should_terminate(self, state):
        # The generator is lazy, so later conditions are not evaluated after the first hit
        verdicts = (cond.should_terminate(state) for cond in self.conditions)
        return next(((True, why) for stop, why in verdicts if stop), (False, None))


class RewardFunction(ABC):
    """
    Interface for per-step rewards. Subclass it and implement ``compute_reward``.
    """

    @abstractmethod
    def compute_reward(self, state: RolloutState, step: StepData) -> np.ndarray:
        """
        Score ``step`` given the rollout ``state``.

        Returns:
            np.ndarray of shape (num_objectives,)
        """
        ...


class SparseReward(RewardFunction):
    """Zero reward on every step except the terminal one, which gets ``final_reward_fn(state)``."""

    def __init__(self, final_reward_fn: Callable[[RolloutState], np.ndarray]):
        self.final_reward_fn = final_reward_fn

    def compute_reward(self, state, step):
        if not state.terminated:
            return np.zeros_like(state.cumulative_reward)
        return self.final_reward_fn(state)


class StepPenaltyReward(RewardFunction):
    """Constant per-step penalty on the first objective, which encourages shorter rollouts."""

    def __init__(self, penalty: float = -0.01, num_objectives: int = 1):
        self.penalty = penalty
        self.num_objectives = num_objectives

    def compute_reward(self, state, step):
        penalty_vector = np.zeros(self.num_objectives)
        penalty_vector[0] = self.penalty  # other objectives stay at 0
        return penalty_vector


class CompositeReward(RewardFunction):
    """
    Combine multiple reward functions by summing their reward vectors.

    With no reward functions this returns a zero vector shaped like
    ``state.cumulative_reward``. Previously it returned the int ``0``, which
    breaks callers that expect an ndarray (e.g. ``step.reward.sum()``).
    """

    def __init__(self, reward_fns: List[RewardFunction]):
        self.reward_fns = reward_fns

    def compute_reward(self, state, step):
        rewards = [fn.compute_reward(state, step) for fn in self.reward_fns]
        if not rewards:
            return np.zeros_like(state.cumulative_reward)
        return sum(rewards)


# =============================================================================
# COMPOSABLE ROLLOUT COLLECTOR: Inject Your Dependencies
# =============================================================================

def collect_single_rollout(
    backend: InferenceBackend,
    initial_prompt: str,
    termination_condition: TerminationCondition,
    reward_function: RewardFunction,
    max_steps: int = 50,
    num_objectives: int = 1,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    **gen_kwargs,
) -> RolloutState:
    """
    Collect a single complete rollout.
    No global state: all behavior comes from the injected dependencies.

    You inject:
    - backend: How to generate
    - termination_condition: When to stop
    - reward_function: How to compute rewards

    A rollout ends on the first of:
    - an EOS finish (reason "eos");
    - the termination condition firing (its reason);
    - the ``max_steps``-th step (reason "max_steps").

    ``state.terminated`` is set *before* the final step's reward is computed. Rewards
    that depend on termination (e.g. ``SparseReward``) therefore also fire when
    the rollout is cut off by ``max_steps``.
    """
    state = RolloutState(
        rollout_id=str(uuid.uuid4()),
        initial_prompt=initial_prompt,
        cumulative_reward=np.zeros(num_objectives),
    )

    for step_num in range(max_steps):
        # Generate response
        context = state.get_full_context()
        results = backend.generate(
            [context],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **gen_kwargs,
        )
        gen_result = results[0]

        step_data = StepData(
            prompt=context,
            response=gen_result,
        )

        # Add step before checking termination so conditions can inspect it
        state.add_step(step_data)

        should_stop, reason = termination_condition.should_terminate(state)
        if gen_result.finish_reason == "stop":
            should_stop, reason = True, "eos"
        elif not should_stop and step_num == max_steps - 1:
            # Step budget exhausted. Mark this step terminal now, not after the
            # loop, so that the final-step reward sees state.terminated == True.
            should_stop, reason = True, "max_steps"

        if should_stop:
            state.terminated = True
            state.termination_reason = reason

        step_data.reward = reward_function.compute_reward(state, step_data)
        state.cumulative_reward = state.cumulative_reward + step_data.reward

        if should_stop:
            break

    # Only reachable when no step ran (max_steps <= 0)
    if not state.terminated:
        state.terminated = True
        state.termination_reason = "max_steps"

    return state


def collect_rollout_batch(
    backend: InferenceBackend,
    prompts: List[str],
    termination_condition: TerminationCondition,
    reward_function: RewardFunction,
    **kwargs,
) -> List[RolloutState]:
    """
    Run ``collect_single_rollout`` once per prompt, one after another.

    Extra keyword arguments are forwarded unchanged. For parallel collection
    use AsyncRolloutManager instead.
    """
    rollouts: List[RolloutState] = []
    for prompt in tqdm(prompts, desc="Collecting rollouts"):
        rollout = collect_single_rollout(
            backend, prompt, termination_condition, reward_function, **kwargs
        )
        rollouts.append(rollout)
    return rollouts


# =============================================================================
# TRAINING STEP: Compose Your Own Pipeline
# =============================================================================

def training_step(
    batch: TrainingBatch,
    model: nn.Module,
    tokenizer,
    optimizer: torch.optim.Optimizer,
    loss_fn: LossFunction,
    ref_model: Optional[nn.Module] = None,
    ref_tokenizer = None,
    kl_coef: float = 0.0,
    max_grad_norm: float = 1.0,
    device: str = "cuda",
) -> Dict[str, float]:
    """
    Run a single optimizer update.

    You inject:
    - loss_fn: How to compute loss (REINFORCE, PPO, GRPO, custom)
    - ref_model: Optional reference for KL penalty
    - kl_coef: KL penalty coefficient (0 = no penalty)

    If the gradient norm is NaN or Inf, the optimizer step is skipped and the
    gradients are cleared, so one bad batch cannot corrupt the weights.
    ``metrics["skipped_step"]`` is 1.0 when that happens and 0.0 otherwise.

    Returns:
        metrics from ``loss_fn`` plus "kl_div", "grad_norm", "skipped_step", "total_loss".
    """
    # Compute main loss
    loss, metrics = loss_fn(batch, model, tokenizer, device)

    # Add KL penalty if requested
    if kl_coef > 0 and ref_model is not None:
        with torch.no_grad():
            ref_log_probs = compute_sequence_log_probs(
                ref_model, ref_tokenizer or tokenizer,
                batch.prompts, batch.responses, device
            )

        # NOTE: this is a second forward pass of the training model (loss_fn already
        # ran one), because the LossFunction interface does not return its log probs.
        train_log_probs = compute_sequence_log_probs(
            model, tokenizer, batch.prompts, batch.responses, device
        )

        # NOTE(review): the default `train - ref` estimate has zero expected gradient
        # on on-policy samples; consider a k3-style estimator if the penalty seems inert.
        kl_div = compute_kl_divergence(train_log_probs, ref_log_probs)
        loss = loss + kl_coef * kl_div
        metrics["kl_div"] = kl_div.item()
    else:
        metrics["kl_div"] = 0.0

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping; returns the total norm measured before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    metrics["grad_norm"] = grad_norm.item()

    if np.isfinite(metrics["grad_norm"]):
        optimizer.step()
        metrics["skipped_step"] = 0.0
    else:
        # Stepping with NaN/Inf gradients would permanently corrupt the parameters
        optimizer.zero_grad()
        metrics["skipped_step"] = 1.0

    metrics["total_loss"] = loss.item()
    return metrics


# Cell-completion marker plus a short index of what this cell defines.
print("✓ Core primitives defined")
# Real newline escape; the previous "\\n" printed a literal backslash-n.
print("\nAvailable:")
print("  Pure functions: compute_sequence_log_probs, compute_kl_divergence, reinforce_loss, ppo_loss, grpo_loss")
print("  Loss classes: ReinforceLoss, PPOLoss, GRPOLoss (or extend LossFunction)")
print("  Termination: MaxStepsTermination, KeywordTermination, CompositeTermination (or extend TerminationCondition)")
print("  Rewards: SparseReward, StepPenaltyReward, CompositeReward (or extend RewardFunction)")
print("  Rollout: collect_single_rollout, collect_rollout_batch")
print("  Training: training_step (compose your own pipeline)")

# Modular Configuration System

# Implementation Summary

## What's Implemented:
- **Dynamic batching**: Groups rollouts by step count for efficient inference
- **Padding masks**: Handled automatically via HuggingFace tokenizers
- **Priority sampling**: Multiple strategies (completion_weighted, pareto_weighted, reward_weighted)
- **Variable-length rollouts**: AsyncRolloutManager handles different trajectory lengths
- **Full parameter updates**: No LoRA - direct gradient computation on all weights

## Key Features:
- **Pareto frontier tracking**: Maintains non-dominated solutions for multi-objective optimization
- **Dynamic reference model**: Snapshot and restore training model weights mid-training
- **KL divergence control**: Optional penalty to keep policy close to reference
- **SVRL environment**: Query environment for verification at a cost
- **Checkpointing**: Save/load training state including Pareto frontier

In [None]:
# =============================================================================
# UTILITY: Visualization and Analysis
# =============================================================================

def plot_pareto_frontier(tracker: ParetoRewardTracker):
    """
    Scatter every recorded reward vector and highlight the Pareto frontier.

    Only two-objective trackers are supported. Otherwise, and when the frontier
    is empty, a message is printed and nothing is drawn.
    """
    import matplotlib.pyplot as plt

    if tracker.num_objectives != 2:
        print("Plotting only supports 2 objectives")
        return
    if not tracker.pareto_frontier:
        print("No Pareto frontier points yet")
        return

    frontier = np.array(tracker.pareto_frontier)
    history = np.array(tracker.all_rewards)
    x_name, y_name = tracker.objectives[0], tracker.objectives[1]

    fig, ax = plt.subplots(figsize=(10, 8))

    ax.scatter(history[:, 0], history[:, 1], alpha=0.3, label="All trajectories")
    ax.scatter(frontier[:, 0], frontier[:, 1], color="red", s=100, marker="*",
               label="Pareto frontier", zorder=5)

    # Join the frontier points left-to-right so the trade-off curve is visible
    ordered = frontier[np.argsort(frontier[:, 0])]
    ax.plot(ordered[:, 0], ordered[:, 1], "r--", alpha=0.5)

    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_title("Pareto Frontier")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.show()


def plot_training_metrics(trainer: RLTrainer):
    """
    Plot training metrics over time as a 2x2 grid.

    Panels: losses, KL divergence, average reward, and buffer / Pareto-frontier
    sizes. ``trainer.train_metrics`` must be a list of per-step dicts containing
    ``"step"``. Metrics that were never logged are skipped instead of raising
    ``KeyError``. For example, PPO and GRPO runs log ``ppo_loss`` /
    ``grpo_loss`` (and PPO has no ``avg_reward``) rather than ``pg_loss``.
    """
    import matplotlib.pyplot as plt

    if not trainer.train_metrics:
        print("No training metrics yet")
        return

    df = pd.DataFrame(trainer.train_metrics)

    def _plot(ax, column, **plot_kwargs):
        # Draw `column` against the training step, but only if it was logged.
        if column in df.columns:
            ax.plot(df["step"], df[column], **plot_kwargs)

    def _legend(ax):
        # Call legend() only when a labeled line exists (avoids matplotlib's empty-legend warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Loss: total plus whichever algorithm-specific loss the LossFunction reported
    loss_ax = axes[0, 0]
    _plot(loss_ax, "total_loss", label="Total Loss")
    _plot(loss_ax, "pg_loss", label="PG Loss", alpha=0.7)
    _plot(loss_ax, "ppo_loss", label="PPO Loss", alpha=0.7)
    _plot(loss_ax, "grpo_loss", label="GRPO Loss", alpha=0.7)
    loss_ax.set_xlabel("Step")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training Loss")
    _legend(loss_ax)
    loss_ax.grid(True, alpha=0.3)

    # KL Divergence
    kl_ax = axes[0, 1]
    _plot(kl_ax, "kl_div")
    kl_ax.set_xlabel("Step")
    kl_ax.set_ylabel("KL Divergence")
    kl_ax.set_title("KL Divergence from Reference")
    kl_ax.grid(True, alpha=0.3)

    # Average Reward
    reward_ax = axes[1, 0]
    _plot(reward_ax, "avg_reward")
    reward_ax.set_xlabel("Step")
    reward_ax.set_ylabel("Average Reward")
    reward_ax.set_title("Average Reward per Step")
    reward_ax.grid(True, alpha=0.3)

    # Buffer and Pareto size
    size_ax = axes[1, 1]
    _plot(size_ax, "buffer_size", label="Buffer Size")
    _plot(size_ax, "pareto_frontier_size", label="Pareto Frontier")
    size_ax.set_xlabel("Step")
    size_ax.set_ylabel("Count")
    size_ax.set_title("Buffer & Pareto Statistics")
    _legend(size_ax)
    size_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def analyze_trajectory_lengths(buffer: TrajectoryBuffer):
    """Show a histogram and box plot of rollout lengths, then print summary stats.

    Length is the number of steps in each stored trajectory
    (``len(t.trajectory)`` for every ``t`` in ``buffer.trajectories``).
    Prints a notice and returns early when the buffer is empty.
    """
    import matplotlib.pyplot as plt

    if not buffer.trajectories:
        print("No trajectories in buffer")
        return

    lengths = [len(traj.trajectory) for traj in buffer.trajectories]
    mean_len = np.mean(lengths)

    fig, (hist_ax, box_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: distribution with the mean marked.
    hist_ax.hist(lengths, bins=20, edgecolor="black")
    hist_ax.axvline(mean_len, color="red", linestyle="--", label=f"Mean: {mean_len:.1f}")
    hist_ax.set_xlabel("Trajectory Length (steps)")
    hist_ax.set_ylabel("Count")
    hist_ax.set_title("Distribution of Trajectory Lengths")
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    # Right panel: quartiles / outliers.
    box_ax.boxplot(lengths)
    box_ax.set_ylabel("Steps")
    box_ax.set_title("Trajectory Length Statistics")
    box_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    summary_lines = (
        "Length statistics:",
        f"  Mean: {mean_len:.2f}",
        f"  Std: {np.std(lengths):.2f}",
        f"  Min: {min(lengths)}",
        f"  Max: {max(lengths)}",
        f"  Median: {np.median(lengths):.1f}",
    )
    print("\n".join(summary_lines))

print("✓ Visualization utilities defined")

In [None]:
# =============================================================================
# EXAMPLE: Composing Your Own Training Pipeline
# =============================================================================
# These are PRIMITIVES - you compose them however you want

# -------------------------
# 1. CUSTOM REWARD FUNCTION (Extend RewardFunction)
# -------------------------

class MyTaskReward(RewardFunction):
    """Example multi-objective reward: task completion plus a per-step cost.

    Objective 0 is 1.0 when the step's response text contains "ANSWER:".
    Objective 1 is a flat -0.01 charged on every step.
    Any additional objectives stay 0.0. Requires num_objectives >= 2
    (indices 0 and 1 are written).
    """

    def __init__(self, num_objectives: int = 2):
        self.num_objectives = num_objectives

    def compute_reward(self, state: RolloutState, step: StepData) -> np.ndarray:
        out = np.zeros(self.num_objectives)
        text = step.response.text

        # Completion bonus (YOUR LOGIC HERE)
        if "ANSWER:" in text:
            out[0] = 1.0

        # Efficiency: constant cost for taking a step at all.
        out[1] = -0.01

        return out


# -------------------------
# 2. CUSTOM TERMINATION (Extend TerminationCondition)
# -------------------------

class MyTermination(TerminationCondition):
    """Example termination rule driven by the most recent response.

    Checked in order: "ANSWER:" -> "task_complete", "I give up" -> "gave_up",
    then a step budget of 20 -> "max_steps". A rollout with no steps never stops.
    """

    def should_terminate(self, state: RolloutState) -> Tuple[bool, Optional[str]]:
        if not state.trajectory:
            return False, None

        latest = state.trajectory[-1].response.text

        # Keyword rules (YOUR LOGIC HERE); the first match wins.
        for marker, reason in (("ANSWER:", "task_complete"), ("I give up", "gave_up")):
            if marker in latest:
                return True, reason

        if state.step_count >= 20:
            return True, "max_steps"

        return False, None


# -------------------------
# 3. CUSTOM LOSS FUNCTION (Extend LossFunction)
# -------------------------

class MyCustomLoss(LossFunction):
    """Your custom loss computation

    Template for a user-defined loss: REINFORCE on sequence log-probs, with a
    placeholder for an entropy bonus. ``entropy_coef`` is stored for that bonus
    but is NOT used by the loss currently returned.
    """
    
    def __init__(self, entropy_coef: float = 0.01):
        self.entropy_coef = entropy_coef
    
    def __call__(self, batch, model, tokenizer, device="cuda"):
        """Compute the loss for one training batch.

        Args:
            batch: training batch exposing ``prompts``, ``responses`` and
                ``rewards`` (rewards may be a list, numpy array or tensor)
            model: policy model; switched to train mode here
            tokenizer: tokenizer matching ``model``
            device: device the reward tensor is placed on

        Returns:
            (loss, metrics): ``loss`` is a scalar tensor; ``metrics`` holds
            "my_loss" and "avg_reward" as Python floats.
        """
        model.train()
        
        # Compute log probs using the primitive
        log_probs = compute_sequence_log_probs(
            model, tokenizer, batch.prompts, batch.responses, device
        )
        
        # YOUR LOSS LOGIC HERE
        # as_tensor accepts lists, arrays and existing tensors alike
        # (torch.tensor warns and copies when handed a tensor).
        rewards = torch.as_tensor(batch.rewards, dtype=torch.float32, device=device)
        
        # Example: REINFORCE (entropy bonus left as an optional extension)
        pg_loss = reinforce_loss(log_probs, rewards)
        
        # Add entropy term (optional)
        # entropy = ... (compute if needed)
        # loss = pg_loss - self.entropy_coef * entropy
        
        loss = pg_loss
        
        return loss, {
            "my_loss": loss.item(),
            # Averaged from the tensor so this works for any rewards container
            # (a plain list has no .mean()).
            "avg_reward": rewards.mean().item(),
        }


# =============================================================================
# EXAMPLE: Simple Training Loop (Compose the Primitives)
# =============================================================================

def example_training_loop():
    """
    Demonstrate a hand-written RL loop assembled from the harness primitives.

    Every component (backend, optimizer, termination, reward, loss, buffer,
    Pareto tracker) is built here and passed explicitly, so any one of them can
    be swapped without touching the rest. Each of the 10 iterations does
    rollout -> buffer -> sample -> scalarize -> training_step and prints the
    loss and reward.

    Returns:
        (model, buffer, pareto) after training.
    """
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    num_iterations = 10
    task_prompts = (
        "Solve: What is 15 + 27? Think step by step, then give ANSWER:",
        "Solve: What is 8 * 9? Think step by step, then give ANSWER:",
    )

    # The policy being optimized is the backend's own model object.
    backend = HuggingFaceBackend(model_name, dtype="bfloat16")
    model, tokenizer = backend.model, backend.tokenizer

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

    # Injected dependencies: stop on a keyword or after 20 steps.
    termination = CompositeTermination([
        KeywordTermination(["ANSWER:", "DONE"]),
        MaxStepsTermination(20),
    ])
    reward_fn = MyTaskReward(num_objectives=2)
    loss_fn = ReinforceLoss(normalize_rewards=True)

    # No reference model (and kl_coef=0.0 below) -> no KL penalty.
    # To enable one: ref_model = AutoModelForCausalLM.from_pretrained(model_name, ...)
    ref_model = None

    buffer = TrajectoryBuffer(max_buffer_size=1000, num_objectives=2)
    pareto = ParetoRewardTracker(["completion", "efficiency"])
    buffer.pareto_tracker = pareto

    for iteration in range(num_iterations):
        rollouts = collect_rollout_batch(
            backend=backend,
            prompts=list(task_prompts),
            termination_condition=termination,
            reward_function=reward_fn,
            max_steps=20,
            num_objectives=2,
            max_new_tokens=128,
            temperature=0.7,
        )
        buffer.add_batch(rollouts)

        sampled = buffer.sample_batch(batch_size=4, strategy="uniform")
        # Random scalarization of the two objectives for this update.
        weights = pareto.get_pareto_weights("random")
        train_batch = buffer.prepare_training_batch(sampled, weights)

        metrics = training_step(
            batch=train_batch,
            model=model,
            tokenizer=tokenizer,
            optimizer=optimizer,
            loss_fn=loss_fn,
            ref_model=ref_model,
            kl_coef=0.0,
            max_grad_norm=1.0,
        )

        print(f"Iter {iteration}: loss={metrics['total_loss']:.4f}, reward={metrics.get('avg_reward', 0):.4f}")

    return model, buffer, pareto


# =============================================================================
# EXAMPLE: SVRL (Self-Verifying RL) with Primitives
# =============================================================================

class SVRLReward(RewardFunction):
    """Three-objective reward for self-verifying RL.

    Objectives by index:
        0: task correctness, ``correctness_fn(state)``, on the terminal step only
        1: efficiency, ``-step_count * 0.1``, on the terminal step only
        2: verification cost, ``-cost`` whenever the step's response contains a
           verification request that ``env`` serves successfully
    """

    def __init__(self, env: SVRLEnvironment, correctness_fn: Callable):
        self.env = env
        self.correctness_fn = correctness_fn

    def compute_reward(self, state: RolloutState, step: StepData) -> np.ndarray:
        rewards = np.zeros(3)

        request = self.env.parse_verification_request(step.response.text)
        if request:
            # The verification output itself does not affect the reward.
            _, cost, served = self.env.query_environment(request, state)
            if served:
                rewards[2] = -cost

        if not state.terminated:
            return rewards

        rewards[0] = self.correctness_fn(state)
        rewards[1] = -state.step_count * 0.1
        return rewards


def example_svrl():
    """SVRL experiment using primitives

    Builds a verification environment whose handler evaluates the arithmetic
    expression the model asks to check. The reward is an SVRLReward with a toy
    correctness check. Verification cost grows with the steps taken so far
    (``1.0 + 0.1 * step_count``).

    Returns:
        (svrl_env, reward_fn), ready to be used with a rollout collector.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Bounds on exponentiation so a request like 9**9**9 cannot hang or exhaust memory.
    max_exponent = 100
    max_pow_result_bits = 10_000

    def eval_arithmetic(node):
        """Evaluate a parsed arithmetic AST; raise ValueError on anything else."""
        if isinstance(node, ast.Expression):
            return eval_arithmetic(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](eval_arithmetic(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = eval_arithmetic(node.left)
            right = eval_arithmetic(node.right)
            if isinstance(node.op, ast.Pow):
                if abs(right) > max_exponent:
                    raise ValueError("exponent too large")
                if isinstance(left, int) and abs(left).bit_length() * abs(right) > max_pow_result_bits:
                    raise ValueError("result too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError("unsupported expression")

    # Your verification environment
    def verify_math(query: str) -> str:
        # The query is model output, i.e. untrusted text. It is parsed and
        # evaluated as plain arithmetic only and is never passed to eval().
        try:
            tree = ast.parse(query.strip(), mode="eval")
            return f"Result: {eval_arithmetic(tree)}"
        except Exception:  # any failure is reported back to the model, never raised
            return "Error: Invalid expression"
    
    svrl_env = SVRLEnvironment(
        cost_function=lambda q, s: 1.0 + s.step_count * 0.1,
        verification_handler=verify_math,
    )
    
    def check_correctness(state: RolloutState) -> float:
        # YOUR LOGIC: check if final answer is correct
        return 1.0 if "42" in state.get_full_context() else 0.0
    
    reward_fn = SVRLReward(svrl_env, check_correctness)
    
    # Use with collect_single_rollout...
    return svrl_env, reward_fn


# =============================================================================
# EXAMPLE: PPO Training (Just Change the Loss Function)
# =============================================================================

def example_ppo_training():
    """Sketch of PPO training: the same loop as before, with a PPO loss swapped in.

    Only the loss object is built here. Model/optimizer/buffer setup, tracking
    old log-probs at rollout time, and advantage estimation (value network,
    GAE, ...) are left to the reader. Returns None.
    """
    # Setup (model, optimizer, buffer, ...) would mirror example_training_loop.
    ppo_loss = PPOLoss(clip_epsilon=0.2)

    # TODO: record old_logprobs during rollouts and compute advantages, then run
    # the unchanged training loop with loss_fn=ppo_loss.
    return None


# =============================================================================
# EXAMPLE: No KL Penalty (Just Set kl_coef=0)
# =============================================================================

def example_no_kl():
    """Show how to train without a KL penalty.

    There is nothing to build: call ``training_step`` with ``kl_coef=0.0``, e.g.
    ``training_step(batch, model, tokenizer, optimizer, loss_fn, kl_coef=0.0)``.
    Returns None.
    """
    return None


# =============================================================================
# EXAMPLE: Dynamic Reference Model
# =============================================================================

def example_dynamic_ref():
    """Show how to refresh the reference model during training.

    Either copy the policy weights into the reference periodically, e.g.
    ``ref_model.load_state_dict(train_model.state_dict())``, or keep
    snapshots with ``ReferenceModelManager``. Returns None.
    """
    return None


# Cell summary. "\n" (not "\\n", which prints a literal backslash-n) inserts
# a blank line before each section.
print("✓ Composition examples defined")
print("\nKey insight: No monolithic trainer class!")
print("You compose these primitives:")
print("  1. collect_single_rollout(backend, prompt, termination, reward, ...)")
print("  2. buffer.sample_batch(...)")
print("  3. buffer.prepare_training_batch(trajectories, weights)")
print("  4. training_step(batch, model, tokenizer, optimizer, loss_fn, ...)")
print("\nSwap any component:")
print("  - Different loss? Use PPOLoss, GRPOLoss, or write your own")
print("  - Different termination? Extend TerminationCondition")
print("  - Different rewards? Extend RewardFunction")
print("  - No KL penalty? Set kl_coef=0")
print("  - Different sampling? Use buffer.sample_batch with different strategy")

# Thinking Token Handling (Qwen, DeepSeek, etc.)
Models with reasoning/thinking tokens need special handling in multi-turn rollouts:
- Thinking tokens are generated (and kept for analysis), but they are NOT sent back to the model in the next turn.
- Only the actual response content is added to the conversation history.

In [None]:
# =============================================================================
# THINKING TOKEN HANDLING
# =============================================================================

# Qwen thinking tokens: IDs of the <think> / </think> marker tokens.
# NOTE(review): these IDs are tokenizer-specific; confirm they match the
# tokenizer in use before relying on them for other model families.
QWEN_THINK_START_TOKEN: int = 151667  # <think>
QWEN_THINK_END_TOKEN: int = 151668    # </think>

@dataclass
class ThinkingResponse:
    """Response split into thinking and content parts

    Fields:
        thinking: decoded reasoning text (the part before the end-of-thinking marker)
        content: decoded text after the marker; only this goes back into the conversation
        full_text: decoded text of the entire generated output
        thinking_token_ids: token IDs of the thinking part
        content_token_ids: token IDs of the content part
    """
    thinking: str
    content: str
    full_text: str
    thinking_token_ids: List[int]
    content_token_ids: List[int]


def parse_thinking_tokens(
    output_ids: List[int],
    tokenizer,
    think_end_token: int = QWEN_THINK_END_TOKEN,
    think_start_token: Optional[int] = QWEN_THINK_START_TOKEN,
) -> ThinkingResponse:
    """
    Parse model output to separate thinking from response.
    Critical for multi-turn: only content goes back to model.

    Split rule:
      * ``think_end_token`` present: everything up to and including its LAST
        occurrence is thinking; the remainder is content.
      * no end marker but ``think_start_token`` present: the reasoning was cut
        off (e.g. ``max_new_tokens`` reached). The whole output is thinking and
        content is empty, so unfinished reasoning is never treated as the answer
        or sent back to the model.
      * neither marker: the entire output is content.

    Args:
        output_ids: Generated token IDs (a list, or a 1-D tensor / array)
        tokenizer: Tokenizer for decoding
        think_end_token: Token ID marking end of thinking
        think_start_token: Token ID marking start of thinking; pass None to
            disable the truncated-reasoning rule

    Returns:
        ThinkingResponse with separated thinking and content.
        * ``thinking`` is decoded with the start/end marker tokens filtered out,
          since ``skip_special_tokens`` may not remove them.
        * ``thinking_token_ids`` keeps the raw slice, markers included, so it
          stays aligned with ``output_ids``.
    """
    if isinstance(output_ids, torch.Tensor):
        output_ids = output_ids.tolist()
    else:
        # Also accepts tuples / numpy arrays, which have no list.index().
        output_ids = list(output_ids)

    if think_end_token in output_ids:
        # The LAST end marker closes the reasoning block (search backwards).
        split_idx = len(output_ids) - output_ids[::-1].index(think_end_token)
    elif think_start_token is not None and think_start_token in output_ids:
        # Reasoning was opened but never closed: generation stopped mid-thought.
        split_idx = len(output_ids)
    else:
        # No thinking markers - entire output is content
        split_idx = 0

    thinking_ids = output_ids[:split_idx]
    content_ids = output_ids[split_idx:]

    markers = {think_end_token}
    if think_start_token is not None:
        markers.add(think_start_token)

    thinking_text = tokenizer.decode(
        [tok for tok in thinking_ids if tok not in markers], skip_special_tokens=True
    ).strip()
    content_text = tokenizer.decode(content_ids, skip_special_tokens=True).strip()
    full_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip()

    return ThinkingResponse(
        thinking=thinking_text,
        content=content_text,
        full_text=full_text,
        thinking_token_ids=thinking_ids,
        content_token_ids=content_ids,
    )


def apply_chat_template(
    messages: List[Dict[str, str]],
    tokenizer,
    enable_thinking: bool = True,
    add_generation_prompt: bool = True,
) -> str:
    """
    Apply chat template to messages.
    For models with thinking support, enables thinking mode.

    Falls back to a plain "role: content" transcript in two cases:
      * the tokenizer has no ``apply_chat_template`` method;
      * it has one, but ``chat_template`` is None (common for base models).
        Recent transformers versions raise in that case instead of formatting.
    In the fallback, ``add_generation_prompt`` appends an "assistant:" line so
    the model is still cued to answer.

    Args:
        messages: List of {"role": "user/assistant/system", "content": "..."}
        tokenizer: Tokenizer with chat template
        enable_thinking: Enable reasoning/thinking for supported models
        add_generation_prompt: Add generation prompt at end

    Returns:
        Formatted prompt string
    """
    has_template_api = hasattr(tokenizer, "apply_chat_template")
    # Only an attribute that exists AND is None counts as "no template";
    # wrappers without a chat_template attribute still go through the API.
    template_unset = has_template_api and getattr(tokenizer, "chat_template", "") is None

    if not has_template_api or template_unset:
        lines = [f"{m['role']}: {m['content']}" for m in messages]
        if add_generation_prompt:
            lines.append("assistant:")
        return "\n".join(lines)

    # Try to pass enable_thinking if tokenizer supports it
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            enable_thinking=enable_thinking,
        )
    except TypeError:
        # Tokenizer doesn't support enable_thinking parameter
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )


@dataclass
class Message:
    """One chat turn.

    ``thinking`` holds an assistant turn's reasoning. It is kept on the object
    for analysis but deliberately left out of ``to_dict``, so it is never fed
    back through the chat template.
    """
    role: str  # one of "user", "assistant", "system", "tool"
    content: str
    thinking: Optional[str] = None  # reasoning text; excluded from to_dict()
    tool_calls: Optional[List[Dict]] = None  # tool calls issued in this turn
    tool_call_id: Optional[str] = None  # links a tool reply to its call
    name: Optional[str] = None  # For tool responses

    def to_dict(self) -> Dict[str, Any]:
        """Chat-template dict: role and content, plus every non-empty tool field."""
        as_dict: Dict[str, Any] = {"role": self.role, "content": self.content}
        extras = (
            ("tool_calls", self.tool_calls),
            ("tool_call_id", self.tool_call_id),
            ("name", self.name),
        )
        as_dict.update((key, value) for key, value in extras if value)
        return as_dict


@dataclass
class Conversation:
    """Chat history that keeps each turn's reasoning out of the rendered prompt.

    The ``add_*`` helpers never modify ``self``. Each builds a new message list
    and returns a new Conversation, so earlier snapshots stay valid.

    Thinking text lives on the Message objects but is omitted when rendering
    (``to_messages_list`` / ``format_prompt``); ``get_full_text_with_thinking``
    shows it.
    """
    messages: List[Message] = field(default_factory=list)

    def _extended(self, message: Message) -> 'Conversation':
        # List concatenation builds a fresh list; self.messages is untouched.
        return Conversation(messages=self.messages + [message])

    def add_user_message(self, content: str) -> 'Conversation':
        """Return a new Conversation with a user turn appended."""
        return self._extended(Message(role="user", content=content))

    def add_assistant_message(
        self,
        content: str,
        thinking: Optional[str] = None,
        tool_calls: Optional[List[Dict]] = None,
    ) -> 'Conversation':
        """Return a new Conversation with an assistant turn appended.

        ``thinking`` is stored on the message but never rendered into prompts.
        """
        reply = Message(role="assistant", content=content, thinking=thinking, tool_calls=tool_calls)
        return self._extended(reply)

    def add_tool_result(
        self,
        tool_call_id: str,
        name: str,
        content: str,
    ) -> 'Conversation':
        """Return a new Conversation with a tool-result turn appended."""
        result = Message(role="tool", content=content, tool_call_id=tool_call_id, name=name)
        return self._extended(result)

    def to_messages_list(self) -> List[Dict[str, Any]]:
        """Render every message as a chat-template dict (thinking omitted)."""
        return [msg.to_dict() for msg in self.messages]

    def format_prompt(
        self,
        tokenizer,
        enable_thinking: bool = True,
    ) -> str:
        """Render the conversation into a prompt with the tokenizer's chat template."""
        rendered_messages = self.to_messages_list()
        return apply_chat_template(rendered_messages, tokenizer, enable_thinking=enable_thinking)

    def get_full_text_with_thinking(self) -> str:
        """Plain-text transcript that DOES include thinking (for inspection only)."""
        rendered = []
        for msg in self.messages:
            reasoning = f"<think>{msg.thinking}</think>" if msg.thinking else ""
            rendered.append(f"{msg.role}: {reasoning}{msg.content}")
        return "\n".join(rendered)


def generate_with_thinking(
    backend: InferenceBackend,
    conversation: Conversation,
    tokenizer,
    enable_thinking: bool = True,
    max_new_tokens: int = 512,
    temperature: float = 0.6,  # Lower for focused reasoning
    **gen_kwargs,
) -> Tuple[ThinkingResponse, GenerationResult]:
    """
    Render the conversation, sample one reply, and split it into thinking/content.

    The prompt is built with the chat template (thinking mode per
    ``enable_thinking``). Generation requests log-probs because RL updates
    need them. Only the first result returned by the backend is used.

    Returns:
        thinking_response: Parsed thinking and content
        gen_result: The untouched generation result (e.g. token_ids, finish_reason)
    """
    rendered = conversation.format_prompt(tokenizer, enable_thinking=enable_thinking)

    first_result = backend.generate(
        [rendered],
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        return_logprobs=True,
        **gen_kwargs,
    )[0]

    split = parse_thinking_tokens(first_result.token_ids, tokenizer)
    return split, first_result


def collect_multiturn_rollout(
    backend: InferenceBackend,
    tokenizer,
    initial_messages: List[Dict[str, str]],
    termination_condition: TerminationCondition,
    reward_function: RewardFunction,
    max_turns: int = 10,
    enable_thinking: bool = True,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    num_objectives: int = 1,
    **gen_kwargs,
) -> RolloutState:
    """
    Collect multi-turn rollout with thinking token handling.

    Key: Only content (not thinking) is added to conversation for next turn.
    Thinking is stored in step metadata for analysis.

    Per turn:
      1. the conversation is rendered and the model generates one reply;
      2. the reply's content (thinking removed) is appended to the conversation;
      3. a StepData is recorded whose ``prompt`` is the exact rendered prompt
         the model was conditioned on this turn, so training scores the
         response against the context it was sampled from;
      4. termination: the termination condition's verdict and reason take
         precedence. Otherwise a generation with finish_reason "stop" ends the
         rollout with reason "eos";
      5. the step reward is computed AFTER ``state.terminated`` is set, so
         reward functions can grant terminal rewards on the final step.

    A rollout that never terminates is closed with reason "max_turns".
    """
    # Initialize conversation
    conversation = Conversation(
        messages=[Message(**m) for m in initial_messages]
    )
    
    # Initialize rollout state
    state = RolloutState(
        rollout_id=str(uuid.uuid4()),
        initial_prompt=conversation.format_prompt(tokenizer, enable_thinking),
        cumulative_reward=np.zeros(num_objectives),
        metadata={"conversation": conversation, "enable_thinking": enable_thinking},
    )
    
    for turn in range(max_turns):
        # Exactly what the model sees this turn (earlier thinking already stripped).
        turn_prompt = conversation.format_prompt(tokenizer, enable_thinking=enable_thinking)

        # Generate with thinking
        thinking_resp, gen_result = generate_with_thinking(
            backend,
            conversation,
            tokenizer,
            enable_thinking=enable_thinking,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **gen_kwargs,
        )
        
        # Add ONLY content to conversation (not thinking!)
        conversation = conversation.add_assistant_message(
            content=thinking_resp.content,
            thinking=thinking_resp.thinking,
        )
        
        # Create step data with thinking info
        step_data = StepData(
            prompt=turn_prompt,
            response=gen_result,
            metadata={
                "thinking": thinking_resp.thinking,
                "content": thinking_resp.content,
                "turn": turn,
            },
        )
        
        # Add step
        state.add_step(step_data)
        state.metadata["conversation"] = conversation
        
        # Check termination; EOS only applies when the condition did not fire,
        # so a specific reason (e.g. "task_complete") is never replaced by "eos".
        should_stop, reason = termination_condition.should_terminate(state)
        if not should_stop and gen_result.finish_reason == "stop":
            should_stop, reason = True, "eos"
        
        if should_stop:
            state.terminated = True
            state.termination_reason = reason

        step_data.reward = reward_function.compute_reward(state, step_data)
        state.cumulative_reward = state.cumulative_reward + step_data.reward

        if should_stop:
            break
    
    if not state.terminated:
        state.terminated = True
        state.termination_reason = "max_turns"
    
    return state


# Cell summary: one line per exported helper.
print(
    "✓ Thinking token handling defined",
    "  - parse_thinking_tokens(output_ids, tokenizer)",
    "  - apply_chat_template(messages, tokenizer, enable_thinking)",
    "  - Conversation class for multi-turn with thinking separation",
    "  - generate_with_thinking(backend, conversation, tokenizer)",
    "  - collect_multiturn_rollout(...) for complete rollout with thinking",
    sep="\n",
)

# Tool/Function Calling Support
Enable LLMs to call tools/functions during rollouts. Tool calls are parsed from the generated text with model-specific parsers (Qwen, Llama, DeepSeek), so the same flow works with any inference backend, including vLLM and SGLang.

In [None]:
# =============================================================================
# TOOL/FUNCTION CALLING SUPPORT
# =============================================================================

@dataclass
class ToolDefinition:
    """A tool the model may call: its name, what it does, and a JSON Schema for its arguments."""
    name: str
    description: str
    parameters: Dict[str, Any]  # JSON Schema

    def to_openai_format(self) -> Dict[str, Any]:
        """Wrap this definition in the OpenAI-style ``{"type": "function", ...}`` envelope."""
        spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": "function", "function": spec}


@dataclass
class ToolCall:
    """One tool invocation parsed from model output."""
    id: str
    name: str
    arguments: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """OpenAI-style tool-call dict; ``arguments`` is serialized to a JSON string."""
        encoded_args = json.dumps(self.arguments)
        return dict(
            id=self.id,
            type="function",
            function=dict(name=self.name, arguments=encoded_args),
        )


@dataclass
class ToolResult:
    """Result from executing a tool

    Fields:
        tool_call_id: id of the tool call this result answers
            (presumably ToolCall.id — confirm against the caller)
        name: name of the tool that was executed
        content: result text produced by the tool
        success: defaults to True; NOTE(review): presumably set False by the
            caller when execution fails — confirm
        error: optional error description (defaults to None)
    """
    tool_call_id: str
    name: str
    content: str
    success: bool = True
    error: Optional[str] = None


class ToolExecutor(ABC):
    """Interface for anything that can run tools on the model's behalf.

    Subclasses supply the catalogue of available tools (``get_tools``) and a
    way to run one of them by name (``execute``).
    """

    @abstractmethod
    def execute(self, name: str, arguments: Dict[str, Any]) -> str:
        """Run tool ``name`` with ``arguments`` and return its result as a string."""
        ...

    @abstractmethod
    def get_tools(self) -> List[ToolDefinition]:
        """Return the definitions of every tool this executor offers."""
        ...


class SimpleToolExecutor(ToolExecutor):
    """Simple tool executor with registered functions

    Each tool is registered by name together with a handler callable that
    receives the argument dict. ``execute`` never raises: unknown tools and
    handler exceptions are returned as "Error..." strings, so a rollout can
    continue and the model sees what went wrong.
    """
    
    def __init__(self):
        self.tools: Dict[str, ToolDefinition] = {}
        self.handlers: Dict[str, Callable] = {}
    
    def register_tool(
        self,
        name: str,
        description: str,
        parameters: Dict[str, Any],
        handler: Callable[[Dict], str],
    ):
        """Register a tool (re-registering a name replaces the previous one)."""
        self.tools[name] = ToolDefinition(name, description, parameters)
        self.handlers[name] = handler
    
    def execute(self, name: str, arguments: Dict[str, Any]) -> str:
        """Run the handler registered for ``name`` and return its result as a string."""
        if name not in self.handlers:
            return f"Error: Unknown tool '{name}'"
        
        try:
            result = self.handlers[name](arguments)
            # Handlers may return numbers, dicts, ...; the executor contract is a
            # string (it becomes message content for the model).
            return result if isinstance(result, str) else str(result)
        except Exception as e:
            return f"Error executing {name}: {str(e)}"
    
    def get_tools(self) -> List[ToolDefinition]:
        """All registered tool definitions, in registration order."""
        return list(self.tools.values())


def parse_tool_calls_qwen(text: str) -> List[ToolCall]:
    """
    Parse tool calls from Qwen model output.

    Qwen2.5 / Qwen3 format (Hermes-style, one block per call):
    <tool_call>
    {"name": "tool_name", "arguments": {...}}
    </tool_call>

    Legacy format, still accepted when no <tool_call> block is present:
    <|plugin|>
    {"name": "tool_name", "parameters": {...}}

    Either "arguments" or "parameters" may hold the argument object. Arguments
    given as a JSON-encoded string are decoded. Malformed blocks are skipped.
    Call ids are "call_1", "call_2", ... by block position.
    """
    import re

    def _to_call(call_id: str, blob: str) -> Optional[ToolCall]:
        # Decode the outermost JSON object in ``blob``; None if it is not a valid call.
        json_start = blob.find("{")
        json_end = blob.rfind("}") + 1
        if json_start < 0 or json_end <= json_start:
            return None
        try:
            data = json.loads(blob[json_start:json_end])
            if not isinstance(data, dict):
                return None
            args = data["arguments"] if "arguments" in data else data.get("parameters", {})
            if isinstance(args, str):
                args = json.loads(args) if args.strip() else {}
        except json.JSONDecodeError:
            return None
        if not isinstance(args, dict):
            return None
        return ToolCall(id=call_id, name=data.get("name", ""), arguments=args)

    calls: List[ToolCall] = []

    blocks = re.findall(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    if blocks:
        for i, block in enumerate(blocks, start=1):
            call = _to_call(f"call_{i}", block)
            if call is not None:
                calls.append(call)
        return calls

    if "<|plugin|>" in text:
        parts = text.split("<|plugin|>")
        for i, part in enumerate(parts[1:], start=1):  # Skip text before the first marker
            part = part.strip()
            if part:
                call = _to_call(f"call_{i}", part)
                if call is not None:
                    calls.append(call)

    return calls


def parse_tool_calls_llama(text: str) -> List[ToolCall]:
    """
    Parse tool calls from Llama 3.x model output.

    Llama format (pythonic):
    [function_name(arg1="value1", arg2="value2")]
    [func_a(x=1), func_b(query="a, b")]    (several calls in one list)

    Each bracketed list is parsed with ``ast``; it is never executed.
    * Commas, brackets and parentheses inside string values are handled.
    * Literal values keep their Python type (``x=1`` -> ``1``).
    * A keyword value that is not a literal (e.g. a bare name) is kept as its
      source text.
    * Positional arguments are ignored, since tools take named parameters.

    Returns:
        List of ToolCall objects with ids "call_0", "call_1", ... in order.
    """
    import ast
    import re

    def _keyword_value(node: ast.AST) -> Any:
        # Literal -> Python value; anything else -> its source text.
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError, SyntaxError, RecursionError):
            return ast.unparse(node)

    def _parse_call_list(start: int) -> Tuple[Optional[ast.List], int]:
        # Try successive "]" positions until text[start:close] parses; accept it
        # only if it is a non-empty list of plain ``name(...)`` calls.
        close = text.find("]", start)
        while close != -1:
            try:
                node = ast.parse(text[start:close + 1], mode="eval").body
            except (SyntaxError, ValueError, RecursionError):
                close = text.find("]", close + 1)
                continue
            if isinstance(node, ast.List) and node.elts and all(
                isinstance(elt, ast.Call) and isinstance(elt.func, ast.Name)
                for elt in node.elts
            ):
                return node, close + 1
            return None, start
        return None, start

    calls: List[ToolCall] = []
    resume_at = 0
    # Every "[" followed by "name(" opens a candidate call list.
    for opener in re.finditer(r"\[\s*[A-Za-z_]\w*\s*\(", text):
        if opener.start() < resume_at:
            continue  # inside a list that was already parsed
        node, end = _parse_call_list(opener.start())
        if node is None:
            continue
        resume_at = end
        for call in node.elts:
            arguments = {
                kw.arg: _keyword_value(kw.value)
                for kw in call.keywords
                if kw.arg is not None  # skip **kwargs splats
            }
            calls.append(ToolCall(
                id=f"call_{len(calls)}",
                name=call.func.id,
                arguments=arguments,
            ))

    return calls


def parse_tool_calls_deepseek(text: str) -> List[ToolCall]:
    """
    Parse tool calls from DeepSeek model output.

    DeepSeek-V3 format (type, then name, then a fenced JSON argument object):
    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>name
    ```json
    {"arg": "value"}
    ```<｜tool▁call▁end｜><｜tool▁calls▁end｜>

    DeepSeek-V3.1 format (name, then raw JSON):
    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>name<｜tool▁sep｜>{"arg": "value"}<｜tool▁call▁end｜><｜tool▁calls▁end｜>

    The real special tokens separate words with "▁" (U+2581); the ASCII "_"
    spelling is accepted as well. Unterminated calls, calls without a name,
    and calls whose arguments are not a JSON object are skipped. Call ids are
    "call_1", "call_2", ... by position.
    """
    # Normalise complete marker strings only, so argument text is never altered.
    for word in ("tool_calls_begin", "tool_call_begin", "tool_sep", "tool_call_end", "tool_calls_end"):
        text = text.replace("<｜" + word.replace("_", "▁") + "｜>", "<｜" + word + "｜>")

    start_marker = "<｜tool_call_begin｜>"
    end_marker = "<｜tool_call_end｜>"
    sep_marker = "<｜tool_sep｜>"

    calls = []
    for i, part in enumerate(text.split(start_marker)[1:], start=1):
        if end_marker not in part:
            continue  # truncated call
        call_text = part.split(end_marker, 1)[0]
        if sep_marker not in call_text:
            continue

        head, rest = call_text.split(sep_marker, 1)
        if head.strip() == "function":
            # V3: "function<sep>name\n```json\n{...}\n```"
            name, _, args_text = rest.strip().partition("\n")
        else:
            # V3.1: "name<sep>{...}"
            name, args_text = head, rest
        name = name.strip()

        args_text = args_text.strip()
        if args_text.startswith("```"):
            # Drop the Markdown code fence (and optional "json" tag) around the arguments.
            args_text = args_text[3:]
            if args_text.startswith("json"):
                args_text = args_text[4:]
            if args_text.endswith("```"):
                args_text = args_text[:-3]
            args_text = args_text.strip()

        if not name:
            continue
        try:
            args = json.loads(args_text) if args_text else {}
        except json.JSONDecodeError:
            continue
        if not isinstance(args, dict):
            continue

        calls.append(ToolCall(
            id=f"call_{i}",
            name=name,
            arguments=args,
        ))

    return calls


def parse_tool_calls(text: str, parser: str = "qwen25") -> List[ToolCall]:
    """
    Extract tool calls from model output with the parser named by ``parser``.

    Known names: "qwen25"/"qwen3" (Qwen), "llama3"/"llama4"/"pythonic"
    (pythonic call lists), "deepseekv3"/"deepseekv31" (DeepSeek). Any other
    name tries Qwen, then Llama, then DeepSeek, and returns the first
    non-empty result (or the last, empty, one).

    Returns:
        List of ToolCall objects
    """
    dispatch = {
        "qwen25": parse_tool_calls_qwen,
        "qwen3": parse_tool_calls_qwen,
        "llama3": parse_tool_calls_llama,
        "llama4": parse_tool_calls_llama,
        "pythonic": parse_tool_calls_llama,  # same bracketed-call format
        "deepseekv3": parse_tool_calls_deepseek,
        "deepseekv31": parse_tool_calls_deepseek,
    }
    chosen = dispatch.get(parser)
    if chosen is not None:
        return chosen(text)

    # Unknown parser name: auto-detect by trying each format in turn.
    found = []
    for candidate in (parse_tool_calls_qwen, parse_tool_calls_llama, parse_tool_calls_deepseek):
        found = candidate(text)
        if found:
            break
    return found


def format_tools_for_prompt(
    tools: List[ToolDefinition],
    format_type: str = "openai",
) -> Any:
    """
    Render tool definitions for inclusion in a prompt.

    Args:
        tools: List of tool definitions
        format_type: "json_schema" gives ``{"tools": [{name, description,
            parameters}, ...]}``; "openai" and any unrecognised value give a list
            of OpenAI-style ``{"type": "function", ...}`` dicts.

    Returns:
        Formatted tools (a dict for "json_schema", otherwise a list)
    """
    if format_type == "json_schema":
        schema_entries = [
            {"name": t.name, "description": t.description, "parameters": t.parameters}
            for t in tools
        ]
        return {"tools": schema_entries}

    # "openai" and unknown formats share the OpenAI layout.
    return [t.to_openai_format() for t in tools]


class ToolUseRolloutManager:
    """
    Rollout manager that supports tool/function calling.

    Flow per turn:
    1. Generate a response from the current conversation
    2. Parse tool calls from the response content (thinking excluded)
    3. Keep at most ``max_tool_calls_per_turn`` of them and execute those
    4. Append the assistant message and the tool results to the conversation
    5. Repeat until a tool-free turn triggers termination, or ``max_turns``
       is exhausted
    """

    def __init__(
        self,
        backend: InferenceBackend,
        tokenizer,
        tool_executor: ToolExecutor,
        tool_call_parser: str = "qwen25",
        max_tool_calls_per_turn: int = 5,
    ):
        """
        Args:
            backend: Inference backend; ``generate`` is called with a list of
                prompts and must return a list of GenerationResults.
            tokenizer: Tokenizer used for prompt formatting and thinking parsing.
            tool_executor: Executes a tool via ``execute(name, arguments)``.
            tool_call_parser: Parser name forwarded to ``parse_tool_calls``.
            max_tool_calls_per_turn: Maximum number of parsed tool calls honored
                in a single turn; extra calls are dropped.
        """
        self.backend = backend
        self.tokenizer = tokenizer
        self.tool_executor = tool_executor
        self.tool_call_parser = tool_call_parser
        self.max_tool_calls_per_turn = max_tool_calls_per_turn

    def generate_with_tools(
        self,
        conversation: Conversation,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        enable_thinking: bool = True,
        **gen_kwargs,
    ) -> Tuple[GenerationResult, List[ToolCall], ThinkingResponse]:
        """
        Generate one response that may contain tool calls.

        Returns:
            gen_result: Full generation result (requested with logprobs)
            tool_calls: All tool calls parsed from the non-thinking content
                (not truncated here)
            thinking_response: Parsed thinking and content
        """
        # NOTE(review): tools are not passed explicitly; presumably
        # Conversation.format_prompt injects tool definitions -- confirm.
        prompt = conversation.format_prompt(self.tokenizer, enable_thinking)

        results = self.backend.generate(
            [prompt],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            return_logprobs=True,
            **gen_kwargs,
        )
        gen_result = results[0]

        thinking_response = parse_thinking_tokens(gen_result.token_ids, self.tokenizer)

        # Only the visible content is searched, so calls "drafted" inside the
        # thinking section are never executed.
        tool_calls = parse_tool_calls(thinking_response.content, self.tool_call_parser)

        return gen_result, tool_calls, thinking_response

    def execute_tool_calls(
        self,
        tool_calls: List[ToolCall],
    ) -> List[ToolResult]:
        """
        Execute tool calls in order, at most ``max_tool_calls_per_turn`` of them.

        Returns:
            One ToolResult per executed call, carrying the call's id and name.
        """
        results = []

        for call in tool_calls[:self.max_tool_calls_per_turn]:
            result_str = self.tool_executor.execute(call.name, call.arguments)

            results.append(ToolResult(
                tool_call_id=call.id,
                name=call.name,
                content=result_str,
            ))

        return results

    def collect_rollout_with_tools(
        self,
        initial_messages: List[Dict[str, str]],
        termination_condition: TerminationCondition,
        reward_function: RewardFunction,
        max_turns: int = 10,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        enable_thinking: bool = True,
        num_objectives: int = 1,
        **gen_kwargs,
    ) -> RolloutState:
        """
        Collect a multi-turn rollout with tool use.

        Each turn:
        1. Generate (may include tool calls)
        2. If tool calls: execute them, append results, score the step, and
           continue (termination is never checked on a tool turn)
        3. If no tool calls: check termination (an EOS finish always stops),
           then score the step

        A step's reward is computed after the step has been added to the state
        and, on the final step, after ``state.terminated`` is set.
        If the loop exhausts ``max_turns``, the rollout is marked terminated
        with reason "max_turns".
        """
        conversation = Conversation(
            messages=[Message(**m) for m in initial_messages]
        )

        state = RolloutState(
            rollout_id=str(uuid.uuid4()),
            initial_prompt=conversation.format_prompt(self.tokenizer, enable_thinking),
            cumulative_reward=np.zeros(num_objectives),
            metadata={
                "conversation": conversation,
                "tool_calls": [],
                "tool_results": [],
            },
        )

        for turn in range(max_turns):
            gen_result, tool_calls, thinking_resp = self.generate_with_tools(
                conversation,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                enable_thinking=enable_thinking,
                **gen_kwargs,
            )

            # Apply the per-turn limit *before* recording anything, so the
            # assistant message, tracked metadata and tool results all cover
            # the same calls (every recorded call id receives a tool result).
            tool_calls = tool_calls[:self.max_tool_calls_per_turn]

            if tool_calls:
                conversation = conversation.add_assistant_message(
                    content=thinking_resp.content,
                    thinking=thinking_resp.thinking,
                    tool_calls=[tc.to_dict() for tc in tool_calls],
                )

                tool_results = self.execute_tool_calls(tool_calls)
                for result in tool_results:
                    conversation = conversation.add_tool_result(
                        tool_call_id=result.tool_call_id,
                        name=result.name,
                        content=result.content,
                    )

                state.metadata["tool_calls"].append(tool_calls)
                state.metadata["tool_results"].append(tool_results)

                # NOTE(review): get_full_context() may not include the tool
                # messages that were in the actual generation prompt; confirm it
                # matches conversation.format_prompt if steps are replayed for
                # logprob recomputation.
                step_data = StepData(
                    prompt=state.get_full_context(),
                    response=gen_result,
                    metadata={
                        "thinking": thinking_resp.thinking,
                        "content": thinking_resp.content,
                        "tool_calls": [tc.to_dict() for tc in tool_calls],
                        "tool_results": [r.content for r in tool_results],
                        "turn": turn,
                    },
                )
                state.add_step(step_data)
                state.metadata["conversation"] = conversation

                step_data.reward = reward_function.compute_reward(state, step_data)
                state.cumulative_reward = state.cumulative_reward + step_data.reward

                # The model must see the tool results, so never stop on a tool turn.
                continue

            # No tool calls: a regular response.
            conversation = conversation.add_assistant_message(
                content=thinking_resp.content,
                thinking=thinking_resp.thinking,
            )

            step_data = StepData(
                prompt=state.get_full_context(),
                response=gen_result,
                metadata={
                    "thinking": thinking_resp.thinking,
                    "content": thinking_resp.content,
                    "turn": turn,
                },
            )
            state.add_step(step_data)
            state.metadata["conversation"] = conversation

            should_stop, reason = termination_condition.should_terminate(state)
            # A natural end-of-sequence always ends the rollout, overriding
            # whatever reason the termination condition reported.
            if gen_result.finish_reason == "stop":
                should_stop, reason = True, "eos"

            if should_stop:
                # Mark terminal before scoring so the reward function can see it.
                state.terminated = True
                state.termination_reason = reason

            step_data.reward = reward_function.compute_reward(state, step_data)
            state.cumulative_reward = state.cumulative_reward + step_data.reward

            if should_stop:
                break

        if not state.terminated:
            state.terminated = True
            state.termination_reason = "max_turns"

        return state


# Example tool setup
def _safe_eval_math(expression: str, max_exponent: int = 1000) -> Any:
    """
    Evaluate an arithmetic expression without executing arbitrary code.

    Tool arguments come from model output and must be treated as untrusted,
    so instead of ``eval`` the expression is parsed with ``ast`` and only a
    whitelist of nodes is evaluated:
    numeric literals, + - * / // % ** (binary), unary +/-, the constants
    ``pi``/``e``, and calls to abs, round, min, max, sqrt, log, exp, sin,
    cos, tan, floor, ceil.

    Args:
        expression: Expression text, e.g. "2 + 3 * 4".
        max_exponent: Largest allowed |exponent| for ``**``; guards against
            CPU/memory exhaustion from inputs like "10 ** 10 ** 10".

    Returns:
        The int or float result.

    Raises:
        ValueError: On any non-whitelisted construct or an oversized exponent.
        SyntaxError: If the text is not a valid expression.
        ZeroDivisionError / ArithmeticError: Propagated from the arithmetic.
    """
    import ast
    import math
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    functions = {
        "abs": abs, "round": round, "min": min, "max": max,
        "sqrt": math.sqrt, "log": math.log, "exp": math.exp,
        "sin": math.sin, "cos": math.cos, "tan": math.tan,
        "floor": math.floor, "ceil": math.ceil,
    }
    constants = {"pi": math.pi, "e": math.e}

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if (isinstance(node, ast.Constant)
                and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left, right = _eval(node.left), _eval(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > max_exponent:
                raise ValueError("Exponent too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        if isinstance(node, ast.Name) and node.id in constants:
            return constants[node.id]
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id in functions
                and not node.keywords):
            return functions[node.func.id](*[_eval(arg) for arg in node.args])
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    return _eval(ast.parse(expression, mode="eval"))


def example_tool_setup():
    """
    Example: set up tools for math verification.

    Registers two tools on a fresh SimpleToolExecutor:
      - "calculate": evaluates an arithmetic expression with a sandboxed
        evaluator (never ``eval``, since arguments come from the model) and
        returns the result as a string.
      - "search": placeholder that echoes the query.

    Returns:
        The configured SimpleToolExecutor.
    """
    executor = SimpleToolExecutor()

    # Register math calculator
    executor.register_tool(
        name="calculate",
        description="Perform mathematical calculations",
        parameters={
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": "Mathematical expression to evaluate"
                }
            },
            "required": ["expression"]
        },
        handler=lambda args: str(_safe_eval_math(args["expression"]))
    )

    # Register web search (placeholder)
    executor.register_tool(
        name="search",
        description="Search the web for information",
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query"
                }
            },
            "required": ["query"]
        },
        handler=lambda args: f"Search results for: {args['query']}"
    )

    return executor


# Summarize what the tool-calling cell defined (one line per item).
print("\n".join([
    "✓ Tool/function calling support defined",
    "  - ToolDefinition, ToolCall, ToolResult dataclasses",
    "  - SimpleToolExecutor for registering tools",
    "  - parse_tool_calls(text, parser) for Qwen/Llama/DeepSeek",
    "  - ToolUseRolloutManager.collect_rollout_with_tools(...)",
    "  - format_tools_for_prompt(tools)",
]))

# Supabase Database Logging
Log every step, rollout, and training update to Supabase for later analysis. Data is stored raw, so how you query and analyze it is up to you; the table schema is provided by `create_supabase_tables_sql()`.

In [None]:
# =============================================================================
# SUPABASE DATABASE LOGGING
# =============================================================================

# Install supabase if needed
# !pip install -q supabase

# Supabase is optional: record whether the client library can be imported.
try:
    from supabase import Client, create_client
except ImportError:
    SUPABASE_AVAILABLE = False
    print("⚠ Supabase not installed. Run: pip install supabase")
else:
    SUPABASE_AVAILABLE = True
    print("✓ Supabase client available")


class SupabaseLogger:
    """
    Log RL training data to Supabase.

    Rows are buffered per table and written with one insert per flush.
    Structured fields (rewards, tool calls, metadata, metrics, config) are
    converted to JSON-native Python values at log time. As a result:
      - they are stored in the JSONB columns as JSON objects/arrays, not as
        double-encoded JSON strings, and
      - numpy/torch values, dataclasses and other non-JSON objects never make
        logging raise (unknown objects fall back to ``str``).

    A failed flush prints a warning and keeps the buffer, so those rows are
    retried on the next flush.
    """

    def __init__(
        self,
        supabase_url: str,
        supabase_key: str,
        experiment_id: Optional[str] = None,
        batch_size: int = 100,
        auto_flush: bool = True,
    ):
        """
        Args:
            supabase_url: Project URL.
            supabase_key: API key passed to ``create_client``.
            experiment_id: Tag written on every row; defaults to
                ``exp_<unix seconds>``.
            batch_size: Buffer length that triggers an automatic flush.
            auto_flush: If False, rows are only written by explicit flush calls.

        Raises:
            ImportError: If the supabase package is not installed.
        """
        if not SUPABASE_AVAILABLE:
            raise ImportError("Supabase not installed. Run: pip install supabase")

        self.client: Client = create_client(supabase_url, supabase_key)
        self.experiment_id = experiment_id or f"exp_{int(time.time())}"
        self.batch_size = batch_size
        self.auto_flush = auto_flush

        # Batch buffers
        self.step_buffer: List[Dict] = []
        self.rollout_buffer: List[Dict] = []
        self.training_buffer: List[Dict] = []

        print(f"✓ SupabaseLogger initialized")
        print(f"  Experiment ID: {self.experiment_id}")

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """
        Convert ``value`` to plain JSON-native Python types.

        Conversion rules, applied recursively:
          - objects with ``tolist`` (numpy arrays/scalars, torch tensors) use it
          - objects with ``to_dict`` use it
          - dataclass instances expand to a dict of their fields
          - sets become lists
          - anything else becomes ``str(obj)``

        If serialization still fails (circular references, unsupported dict
        keys), the whole value falls back to ``str(value)``.
        """
        import dataclasses

        def _fallback(obj):
            if hasattr(obj, "tolist"):
                return obj.tolist()
            if hasattr(obj, "to_dict"):
                return obj.to_dict()
            if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
                # Shallow field map; json recurses into the values itself.
                return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
            if isinstance(obj, (set, frozenset)):
                return list(obj)
            return str(obj)

        try:
            return json.loads(json.dumps(value, default=_fallback))
        except (TypeError, ValueError):
            return str(value)

    def log_step(
        self,
        rollout_id: str,
        step_num: int,
        prompt: str,
        response: str,
        thinking: Optional[str] = None,
        reward: Optional[np.ndarray] = None,
        tool_calls: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
    ):
        """Buffer one step row (flushes automatically once ``batch_size`` rows are buffered)."""
        step_data = {
            "experiment_id": self.experiment_id,
            "rollout_id": rollout_id,
            "step_num": step_num,
            "timestamp": time.time(),
            "prompt": prompt,
            "response": response,
            "thinking": thinking,
            "reward": self._to_jsonable(reward),
            "tool_calls": self._to_jsonable(tool_calls) if tool_calls else None,
            "metadata": self._to_jsonable(metadata) if metadata else None,
        }

        self.step_buffer.append(step_data)

        if self.auto_flush and len(self.step_buffer) >= self.batch_size:
            self.flush_steps()

    def log_rollout(
        self,
        state: "RolloutState",
        final_reward: Optional[np.ndarray] = None,
        log_steps: bool = True,
    ):
        """
        Buffer one rollout row and, by default, one step row per trajectory step.

        Args:
            state: Completed rollout state.
            final_reward: Optional final reward vector to record.
            log_steps: Set False when steps were already logged during the
                rollout (e.g. via LoggingRewardWrapper) to avoid duplicate
                rows in ``rl_steps``.
        """
        rollout_data = {
            "experiment_id": self.experiment_id,
            "rollout_id": state.rollout_id,
            "timestamp": time.time(),
            "initial_prompt": state.initial_prompt,
            "num_steps": state.step_count,
            "terminated": state.terminated,
            "termination_reason": state.termination_reason,
            "cumulative_reward": self._to_jsonable(state.cumulative_reward),
            "final_reward": self._to_jsonable(final_reward),
            "verification_budget_spent": self._to_jsonable(state.verification_budget_spent),
            "metadata": self._to_jsonable(state.metadata) if state.metadata else None,
        }

        self.rollout_buffer.append(rollout_data)

        if log_steps:
            for i, step in enumerate(state.trajectory):
                self.log_step(
                    rollout_id=state.rollout_id,
                    step_num=i,
                    prompt=step.prompt,
                    response=step.response.text,
                    thinking=step.metadata.get("thinking") if step.metadata else None,
                    reward=step.reward,
                    tool_calls=step.metadata.get("tool_calls") if step.metadata else None,
                    metadata=step.metadata,
                )

        if self.auto_flush and len(self.rollout_buffer) >= self.batch_size:
            self.flush_rollouts()

    def log_training_step(
        self,
        step_num: int,
        loss: float,
        metrics: Dict[str, float],
        learning_rate: Optional[float] = None,
        batch_size: Optional[int] = None,
    ):
        """Buffer one training-update row; ``loss``/``metrics`` may be numpy or torch scalars."""
        training_data = {
            "experiment_id": self.experiment_id,
            "step_num": step_num,
            "timestamp": time.time(),
            "loss": self._to_jsonable(loss),
            "metrics": self._to_jsonable(metrics),
            "learning_rate": self._to_jsonable(learning_rate),
            "batch_size": batch_size,
        }

        self.training_buffer.append(training_data)

        if self.auto_flush and len(self.training_buffer) >= self.batch_size:
            self.flush_training()

    def _flush(self, table: str, buffer: List[Dict], label: str, error_label: str) -> None:
        """Insert ``buffer`` into ``table``; clear it only if the insert succeeded."""
        if not buffer:
            return

        try:
            self.client.table(table).insert(list(buffer)).execute()
            print(f"✓ Flushed {len(buffer)} {label} to Supabase")
            buffer.clear()
        except Exception as e:
            # Best-effort logging: keep rows so the next flush retries them.
            print(f"⚠ Error flushing {error_label}: {e}")

    def flush_steps(self):
        """Flush step buffer to database"""
        self._flush("rl_steps", self.step_buffer, "steps", "steps")

    def flush_rollouts(self):
        """Flush rollout buffer to database"""
        self._flush("rl_rollouts", self.rollout_buffer, "rollouts", "rollouts")

    def flush_training(self):
        """Flush training buffer to database"""
        self._flush("rl_training", self.training_buffer, "training steps", "training")

    def flush_all(self):
        """Flush all buffers"""
        self.flush_steps()
        self.flush_rollouts()
        self.flush_training()

    def log_experiment_config(self, config: Dict[str, Any]):
        """Insert the experiment configuration immediately (not buffered)."""
        try:
            self.client.table("rl_experiments").insert({
                "experiment_id": self.experiment_id,
                "timestamp": time.time(),
                "config": self._to_jsonable(config),
            }).execute()
            print(f"✓ Logged experiment config")
        except Exception as e:
            print(f"⚠ Error logging config: {e}")

    def get_experiment_data(self, table: str = "rl_rollouts") -> List[Dict]:
        """Return all rows of ``table`` for this experiment ([] on error)."""
        try:
            response = self.client.table(table).select("*").eq(
                "experiment_id", self.experiment_id
            ).execute()
            return response.data
        except Exception as e:
            print(f"⚠ Error retrieving data: {e}")
            return []


def create_supabase_tables_sql():
    """
    SQL to create tables in Supabase.
    Run this in your Supabase SQL editor.

    Every table carries ``experiment_id`` and is indexed on it, because
    ``SupabaseLogger.get_experiment_data`` filters any table by that column.
    ``rl_steps`` is additionally indexed on ``rollout_id``.
    All statements use IF NOT EXISTS, so the script can be re-run safely.
    """
    return '''
-- Experiments table
CREATE TABLE IF NOT EXISTS rl_experiments (
    id SERIAL PRIMARY KEY,
    experiment_id TEXT NOT NULL,
    timestamp FLOAT NOT NULL,
    config JSONB,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);

-- Rollouts table
CREATE TABLE IF NOT EXISTS rl_rollouts (
    id SERIAL PRIMARY KEY,
    experiment_id TEXT NOT NULL,
    rollout_id TEXT NOT NULL,
    timestamp FLOAT NOT NULL,
    initial_prompt TEXT,
    num_steps INTEGER,
    terminated BOOLEAN,
    termination_reason TEXT,
    cumulative_reward JSONB,
    final_reward JSONB,
    verification_budget_spent FLOAT,
    metadata JSONB,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);

-- Steps table
CREATE TABLE IF NOT EXISTS rl_steps (
    id SERIAL PRIMARY KEY,
    experiment_id TEXT NOT NULL,
    rollout_id TEXT NOT NULL,
    step_num INTEGER NOT NULL,
    timestamp FLOAT NOT NULL,
    prompt TEXT,
    response TEXT,
    thinking TEXT,
    reward JSONB,
    tool_calls JSONB,
    metadata JSONB,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);

-- Training updates table
CREATE TABLE IF NOT EXISTS rl_training (
    id SERIAL PRIMARY KEY,
    experiment_id TEXT NOT NULL,
    step_num INTEGER NOT NULL,
    timestamp FLOAT NOT NULL,
    loss FLOAT,
    metrics JSONB,
    learning_rate FLOAT,
    batch_size INTEGER,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);

-- Indexes for faster queries
CREATE INDEX IF NOT EXISTS idx_experiments_experiment ON rl_experiments(experiment_id);
CREATE INDEX IF NOT EXISTS idx_rollouts_experiment ON rl_rollouts(experiment_id);
CREATE INDEX IF NOT EXISTS idx_steps_experiment ON rl_steps(experiment_id);
CREATE INDEX IF NOT EXISTS idx_steps_rollout ON rl_steps(rollout_id);
CREATE INDEX IF NOT EXISTS idx_training_experiment ON rl_training(experiment_id);
'''


class LoggingRewardWrapper(RewardFunction):
    """
    RewardFunction decorator: delegates scoring to an inner reward function
    and writes each scored step to a SupabaseLogger via ``log_step``.

    Args:
        reward_fn: The reward function that actually computes rewards.
        logger: Destination for the per-step rows.
    """

    def __init__(
        self,
        reward_fn: RewardFunction,
        logger: SupabaseLogger,
    ):
        self.reward_fn = reward_fn
        self.logger = logger

    def compute_reward(self, state: RolloutState, step: StepData) -> np.ndarray:
        """Return the inner reward for ``step`` unchanged, logging it on the way out."""
        inner_reward = self.reward_fn.compute_reward(state, step)
        meta = step.metadata

        # Assumes the caller already appended `step` to `state` before scoring
        # it, so its index is step_count - 1.
        self.logger.log_step(
            rollout_id=state.rollout_id,
            step_num=state.step_count - 1,
            prompt=step.prompt,
            response=step.response.text,
            thinking=meta.get("thinking") if meta else None,
            reward=inner_reward,
            tool_calls=meta.get("tool_calls") if meta else None,
            metadata=meta,
        )
        return inner_reward


def example_supabase_usage():
    """
    Example: using the Supabase logger in a training loop.

    Credentials are read from the SUPABASE_URL and SUPABASE_KEY environment
    variables, so they never have to be hard-coded in the notebook. The
    placeholder values are used only when those variables are unset.

    Returns:
        The configured SupabaseLogger.
    """
    import os

    # Setup: export SUPABASE_URL / SUPABASE_KEY instead of editing this cell.
    supabase_url = os.environ.get("SUPABASE_URL", "https://your-project.supabase.co")
    supabase_key = os.environ.get("SUPABASE_KEY", "your-anon-key")

    # Create logger
    logger = SupabaseLogger(
        supabase_url=supabase_url,
        supabase_key=supabase_key,
        experiment_id="pareto_rl_exp_001",
        batch_size=50,  # Flush every 50 items
    )

    # Log experiment config
    logger.log_experiment_config({
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "num_objectives": 2,
        "max_steps": 20,
        "learning_rate": 1e-6,
    })

    # Wrap your reward function with logging.
    # NOTE(review): MyTaskReward is assumed to be defined elsewhere in the notebook.
    my_reward = MyTaskReward(num_objectives=2)
    logged_reward = LoggingRewardWrapper(my_reward, logger)

    # Use in training loop
    # rollout = collect_single_rollout(..., reward_function=logged_reward)
    # logger.log_rollout(rollout)
    # Note: LoggingRewardWrapper already logs every scored step, and
    # log_rollout() also logs the rollout's trajectory steps, so using both
    # writes each step twice to rl_steps -- pick one mechanism for steps.

    # After training step
    # logger.log_training_step(step_num, loss, metrics)

    # Don't forget to flush at end
    # logger.flush_all()

    return logger


# Summarize the Supabase logging cell. The blank line before "Setup:" was
# previously written as "\\n", which printed a literal backslash-n.
print("\n".join([
    "✓ Supabase logging defined",
    "  - SupabaseLogger for batched writes",
    "  - log_step(), log_rollout(), log_training_step()",
    "  - LoggingRewardWrapper to auto-log rewards",
    "  - create_supabase_tables_sql() for schema",
    "",
    "Setup:",
    "  1. Create tables in Supabase using create_supabase_tables_sql()",
    "  2. Get your SUPABASE_URL and SUPABASE_KEY",
    "  3. Create SupabaseLogger(url, key)",
    "  4. Wrap reward function or call log methods manually",
]))