In [None]:
!nvidia-smi -L || true

import sys
print("Python:", sys.version)

# Install required packages
# Core ML
!pip install -q transformers>=4.51.3 accelerate>=1.4.0 peft>=0.14.0 \
                 datasets>=3.3.2 torch wandb huggingface_hub \
                 sentencepiece protobuf tqdm matplotlib pandas

# Inference backends (install based on what you'll use)
# Uncomment the backend you want:
# !pip install -q vllm>=0.6.0  # For vLLM backend
# !pip install -q "sglang[all]>=0.4.0"  # For SGLang backend

print("\n=== Environment ===")
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "Please connect a GPU for RL training."

In [None]:
# Core imports for RL harness
import os
import random
import time
import json
import asyncio
import uuid
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional, Any, Tuple, Callable
from collections import defaultdict, deque
from queue import Queue

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm
import pandas as pd

# HuggingFace
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Optional inference backends. Each *_AVAILABLE flag records whether the import
# succeeded, so the backend classes can fail later with a clear install hint
# instead of a NameError. No need to comment these out when not installed.
try:
    from vllm import LLM, SamplingParams
except ImportError:
    VLLM_AVAILABLE = False
    print("⚠ vLLM not installed")
else:
    VLLM_AVAILABLE = True
    print("✓ vLLM available")

try:
    import sglang as sgl
except ImportError:
    SGLANG_AVAILABLE = False
    print("⚠ SGLang not installed")
else:
    SGLANG_AVAILABLE = True
    print("✓ SGLang available")

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Core Data Structures
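Dataclasses for tracking rollouts and generation results. They are mutable (`RolloutState.add_step` appends to the trajectory in place), so use `RolloutState.checkpoint()` to branch a rollout without aliasing its history.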

In [None]:
@dataclass
class GenerationResult:
    """One completion produced by an inference backend.

    Only `text` and `token_ids` are required; the remaining fields are
    optional extras that a backend may or may not fill in.
    """
    text: str                       # decoded completion text
    token_ids: List[int]            # generated token ids
    logprobs: Optional[List[float]] = field(default=None)  # log-prob of each sampled token
    top_logprobs: Optional[List[Dict[int, float]]] = field(default=None)  # per position: {token_id: logprob} for top-k
    finish_reason: str = field(default="length")  # "stop", "length", "error"
    
@dataclass
class StepData:
    """Record of a single generation step inside a rollout trajectory."""
    prompt: str                    # context the response was generated from
    response: GenerationResult     # backend output for this step
    action_logprobs: Optional[torch.Tensor] = field(default=None)  # For gradient computation
    reward: Optional[np.ndarray] = field(default=None)             # Multi-dimensional reward vector
    value_estimate: Optional[float] = field(default=None)          # presumably a critic/value estimate
    timestamp: float = field(default_factory=time.time)            # creation time, from time.time()
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class RolloutState:
    """Everything tracked for one rollout: trajectory, running reward vector,
    verification budget and an optional reference-model snapshot id."""
    rollout_id: str
    initial_prompt: str
    trajectory: List[StepData] = field(default_factory=list)
    step_count: int = 0
    terminated: bool = False
    termination_reason: Optional[str] = None
    cumulative_reward: np.ndarray = field(default_factory=lambda: np.zeros(1))  # one entry per objective
    verification_budget_spent: float = 0.0
    max_verification_budget: float = 10.0
    ref_model_snapshot_id: Optional[str] = None  # For dynamic ref changes
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self) -> 'RolloutState':
        """Return an independent deep copy, e.g. to branch this rollout."""
        return deepcopy(self)

    def can_verify(self, cost: float) -> bool:
        """True if spending `cost` more keeps us within the verification budget."""
        projected_spend = self.verification_budget_spent + cost
        return projected_spend <= self.max_verification_budget

    def add_step(self, step: StepData):
        """Append `step`, bump the step counter and accumulate its reward (if any)."""
        self.trajectory.append(step)
        self.step_count += 1
        if step.reward is None:
            return
        # Out-of-place add: binds a new array instead of mutating the old one.
        self.cumulative_reward = self.cumulative_reward + step.reward

    def get_full_context(self) -> str:
        """Initial prompt followed by every step's response text, concatenated."""
        return "".join([self.initial_prompt, *(s.response.text for s in self.trajectory)])

    def to_dict(self) -> Dict:
        """Recursively convert to plain dicts/lists via `dataclasses.asdict`.

        numpy arrays and tensors are deep-copied, not converted, so the result
        is not directly JSON-serializable.
        """
        return asdict(self)

@dataclass
class TrainingBatch:
    """Inputs for one policy-update step.

    The list fields are presumably parallel (one entry per sample).
    """
    prompts: List[str]
    responses: List[str]
    token_ids: List[List[int]]
    rewards: np.ndarray                                          # (batch_size, num_objectives)
    old_logprobs: Optional[torch.Tensor] = field(default=None)   # For importance sampling
    advantages: Optional[torch.Tensor] = field(default=None)
    returns: Optional[torch.Tensor] = field(default=None)


print("✓ Data structures defined")

In [None]:
import os

# Credentials are read from environment variables instead of being pasted into
# the notebook, so they never end up in saved notebooks, outputs or git history.
# Set them before starting Jupyter, or via your platform's secrets manager:
#   WANDB_API_KEY - get from https://wandb.ai/authorize
#   HF_TOKEN      - get from https://huggingface.co/settings/tokens
# Downstream cells read them back from os.environ.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Report presence only; never print the values themselves.
print(f"WANDB_API_KEY set: {bool(WANDB_API_KEY)} | HF_TOKEN set: {bool(HF_TOKEN)}")

In [None]:
import wandb
from huggingface_hub import login

# WandB login. WANDB_ENABLED records the outcome so later cells can skip
# wandb.init()/wandb.log() instead of failing mid-training.
try:
    wandb.login()
    WANDB_ENABLED = True
    print("✓ WandB login successful")
except Exception as e:
    WANDB_ENABLED = False
    print(f"⚠ WandB login failed: {e}")
    print("Training will continue without WandB logging")

# HuggingFace login (needed for gated/private models and Hub uploads).
if os.environ.get('HF_TOKEN'):
    try:
        login(token=os.environ['HF_TOKEN'])
        print("✓ HuggingFace login successful")
    except Exception as e:
        print(f"⚠ HuggingFace login failed: {e}")
else:
    print("⚠ HF_TOKEN not set; skipping HuggingFace login (gated/private models may fail to download)")

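# Async Rollout Manager with dynamic batching
Rollouts vary dramatically in length (6–50 steps). The manager therefore advances all active rollouts together and batches each generation call across rollouts that are at the same step count. Terminated rollouts are skipped on later steps, so the batch never waits on finished rollouts. Stepping itself is synchronous; "async" refers to rollouts finishing independently: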

In [None]:
class AsyncRolloutManager:
    """
    Manages concurrent rollouts with dynamic batching.

    `step_all` advances every non-terminated rollout by one generation step.
    Rollouts that share a step count go to the backend in one batched
    `generate` call. Stepping is synchronous (no event loop); "async" refers
    to rollouts terminating at different steps and being retired independently.

    Customizable hooks (assign after construction):
        should_terminate_fn(state) -> bool         extra stop condition
        get_reward_fn(state, step) -> np.ndarray   per-step reward vector
        process_response_fn(text) -> str          post-processes generated text
    """

    def __init__(
        self,
        # Quoted: InferenceBackend is defined in a later cell, and an unquoted
        # annotation is evaluated when this cell runs (NameError on Run All).
        inference_backend: "InferenceBackend",
        max_concurrent_rollouts: int = 32,
        max_steps: int = 50,
        num_objectives: int = 1,
    ):
        self.backend = inference_backend
        self.max_concurrent = max_concurrent_rollouts
        self.max_steps = max_steps
        self.num_objectives = num_objectives

        self.active_rollouts: Dict[str, RolloutState] = {}
        self.completed_rollouts: List[RolloutState] = []
        # Not cleared by reset(), so rollout ids stay unique across iterations.
        self.rollout_counter = 0

        # Customizable hooks - you override these!
        self.should_terminate_fn: Optional[Callable[[RolloutState], bool]] = None
        self.get_reward_fn: Optional[Callable[[RolloutState, StepData], np.ndarray]] = None
        self.process_response_fn: Optional[Callable[[str], str]] = None

    def start_rollout(self, prompt: str, metadata: Optional[Dict] = None) -> str:
        """Start a new rollout from `prompt` and return its id.

        Raises:
            RuntimeError: if `max_concurrent_rollouts` are already active. Terminated
                rollouts count until `complete_terminated_rollouts` removes them.
        """
        if len(self.active_rollouts) >= self.max_concurrent:
            raise RuntimeError(f"Max concurrent rollouts ({self.max_concurrent}) reached")

        rollout_id = f"rollout_{self.rollout_counter}"
        self.rollout_counter += 1

        state = RolloutState(
            rollout_id=rollout_id,
            initial_prompt=prompt,
            cumulative_reward=np.zeros(self.num_objectives),
            metadata=metadata or {},
        )

        self.active_rollouts[rollout_id] = state
        return rollout_id

    def _termination_reason(self, state: RolloutState, gen_result: GenerationResult) -> Optional[str]:
        """Return why `state` should stop after this step, or None to continue.

        Precedence: custom hook > EOS/stop > step cap. The custom hook, when
        set, is always called exactly once per step.
        """
        if self.should_terminate_fn and self.should_terminate_fn(state):
            return "custom"
        if gen_result.finish_reason == "stop":
            return "eos"
        if state.step_count >= self.max_steps:
            return "max_steps"
        return None

    def step_all(
        self,
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        **gen_kwargs
    ) -> Dict[str, GenerationResult]:
        """
        Step all active (non-terminated) rollouts by one generation.
        Groups by step count for batching.

        Returns:
            Dict mapping rollout_id to its generation result for this step.

        Raises:
            RuntimeError: if the backend returns a different number of results
                than prompts (otherwise rollouts would be silently skipped).
        """
        if not self.active_rollouts:
            return {}

        # Group non-terminated rollouts by step count; each group is one backend call.
        step_groups: Dict[int, List[str]] = defaultdict(list)
        for rid, state in self.active_rollouts.items():
            if not state.terminated:
                step_groups[state.step_count].append(rid)

        all_results: Dict[str, GenerationResult] = {}

        for rollout_ids in step_groups.values():
            # Full context so far for each rollout in the group.
            prompts = [self.active_rollouts[rid].get_full_context() for rid in rollout_ids]

            gen_results = self.backend.generate(
                prompts,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                **gen_kwargs
            )
            if len(gen_results) != len(rollout_ids):
                raise RuntimeError(
                    f"Backend returned {len(gen_results)} results for {len(rollout_ids)} prompts"
                )

            for rid, prompt, gen_result in zip(rollout_ids, prompts, gen_results):
                state = self.active_rollouts[rid]

                if self.process_response_fn:
                    # NOTE: only the text is rewritten; token_ids/logprobs still
                    # describe the raw generation.
                    gen_result.text = self.process_response_fn(gen_result.text)

                step_data = StepData(prompt=prompt, response=gen_result)
                step_data.reward = (
                    self.get_reward_fn(state, step_data)
                    if self.get_reward_fn
                    else np.zeros(self.num_objectives)
                )
                state.add_step(step_data)

                reason = self._termination_reason(state, gen_result)
                if reason is not None:
                    state.terminated = True
                    state.termination_reason = reason

                all_results[rid] = gen_result

        return all_results

    def complete_terminated_rollouts(self) -> List[RolloutState]:
        """Move terminated rollouts from active to completed; return the newly moved ones."""
        finished_ids = [rid for rid, state in self.active_rollouts.items() if state.terminated]
        newly_completed = [self.active_rollouts.pop(rid) for rid in finished_ids]
        self.completed_rollouts.extend(newly_completed)
        return newly_completed

    def has_active_rollouts(self) -> bool:
        """True while any rollout (terminated-but-uncollected included) is active."""
        return len(self.active_rollouts) > 0

    def get_statistics(self) -> Dict:
        """Current counts, mean step counts and the last 10 completed reward vectors."""
        active_steps = [s.step_count for s in self.active_rollouts.values()]
        completed_steps = [s.step_count for s in self.completed_rollouts]

        return {
            "active": len(self.active_rollouts),
            "completed": len(self.completed_rollouts),
            "active_avg_steps": np.mean(active_steps) if active_steps else 0,
            "completed_avg_steps": np.mean(completed_steps) if completed_steps else 0,
            "completed_rewards": [s.cumulative_reward.tolist() for s in self.completed_rollouts[-10:]],
        }

    def reset(self):
        """Drop all active and completed rollouts (rollout_counter keeps counting)."""
        self.active_rollouts.clear()
        self.completed_rollouts.clear()

print("✓ AsyncRolloutManager defined")

# Flexible Rollout State with Checkpointing
Track everything per-rollout with ability to branch/restore:


In [None]:
# RolloutState is defined once, in the "Core Data Structures" cell. That class
# already tracks the trajectory, the multi-objective cumulative reward, the
# verification budget and the reference-snapshot id. Do not re-declare it here:
# a second @dataclass would shadow the complete one. AsyncRolloutManager builds
# it with rollout_id/initial_prompt.
#
# Quick check of the branch/restore behaviour the harness relies on:
_root = RolloutState(rollout_id="checkpoint_demo", initial_prompt="Q: 2+2=")
_branch = _root.checkpoint()
_branch.add_step(StepData(
    prompt=_branch.get_full_context(),
    response=GenerationResult(text=" 4", token_ids=[]),
    reward=np.ones_like(_branch.cumulative_reward),
))
assert _root.step_count == 0 and not _root.trajectory, "checkpoint must not alias the original"
assert _branch.step_count == 1 and _branch.get_full_context() == "Q: 2+2= 4"
assert _root.can_verify(_root.max_verification_budget)
assert not _root.can_verify(_root.max_verification_budget + 1)
del _root, _branch
print("✓ RolloutState checkpointing verified")

# Multi-Backend Inference with Unified Interface

In [None]:
class InferenceBackend(ABC):
    """Common interface implemented by every inference backend.

    Subclasses provide batched generation plus a hook that releases whatever
    resources (weights, engines) they hold.
    """

    @abstractmethod
    def generate(self, prompts: List[str], **kwargs) -> List[GenerationResult]:
        """Generate one result per prompt; callers zip results with prompts, so keep order."""
        ...

    @abstractmethod
    def shutdown(self):
        """Release the backend's resources."""
        ...


class HuggingFaceBackend(InferenceBackend):
    """
    HuggingFace backend - full gradient access.
    Use this for:
    - Gradient computation during training (FULL UPDATES)
    - Reference model logit computation
    - Small-scale experiments

    `generate` runs prompts one at a time (no batching).
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        dtype: str = "bfloat16",  # "float32", "bfloat16", "float16"
        gradient_checkpointing: bool = False,  # Save memory during backward pass
    ):
        self.device = device
        self.model_name = model_name
        self.dtype_map = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }
        if dtype not in self.dtype_map:
            print(f"⚠ Unknown dtype {dtype!r}; falling back to bfloat16")
        self.torch_dtype = self.dtype_map.get(dtype, torch.bfloat16)

        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model - FULL WEIGHTS
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
            device_map=None,  # We'll move to device manually for full control
            trust_remote_code=True,
        )
        self.model = self.model.to(device)

        # Optional: Gradient checkpointing for memory efficiency
        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
            print("✓ Gradient checkpointing enabled")

        # Count trainable parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"✓ HuggingFace backend loaded: {model_name}")
        print(f"  Total params: {total_params:,}")
        print(f"  Trainable params: {trainable_params:,}")
        print(f"  Dtype: {self.torch_dtype}")

    def _stop_token_ids(self) -> set:
        """Token ids that end generation.

        Includes the tokenizer's EOS plus the model generation config's
        `eos_token_id`, which may be a single int or a list of ids.
        """
        stop_ids = set()
        if self.tokenizer.eos_token_id is not None:
            stop_ids.add(self.tokenizer.eos_token_id)
        gen_config = getattr(self.model, "generation_config", None)
        config_eos = getattr(gen_config, "eos_token_id", None)
        if isinstance(config_eos, int):
            stop_ids.add(config_eos)
        elif config_eos is not None:
            stop_ids.update(config_eos)
        return stop_ids

    def generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        do_sample: bool = True,
        return_logprobs: bool = False,
        **kwargs
    ) -> List[GenerationResult]:
        """Generate one completion per prompt (sequentially), with optional logprobs.

        `logprobs` are taken from `outputs.scores`. Those are logits *after*
        generation-time processing (e.g. temperature, top-p), so they can differ
        from the model's raw log-probabilities.
        `finish_reason` is "stop" when the last token is any configured EOS id,
        otherwise "length".
        """
        # NOTE: leaves the model in eval mode; call .train() before updates if needed.
        self.model.eval()
        stop_ids = self._stop_token_ids()
        results = []

        for prompt in prompts:
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature if do_sample else 1.0,
                    top_p=top_p if do_sample else 1.0,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                    output_scores=return_logprobs,
                    return_dict_in_generate=True,
                )

            # Decode only the newly generated tokens.
            generated_ids = outputs.sequences[0][prompt_len:]
            text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

            logprobs = None
            if return_logprobs and getattr(outputs, "scores", None) is not None:
                logprobs = []
                for step, score in enumerate(outputs.scores):
                    # float32 log-softmax: bf16/fp16 lose precision over large vocabularies.
                    step_log_probs = F.log_softmax(score[0].float(), dim=-1)
                    logprobs.append(step_log_probs[generated_ids[step].item()].item())

            # Guard against an empty generation (e.g. max_new_tokens=0).
            ended_on_stop = len(generated_ids) > 0 and generated_ids[-1].item() in stop_ids
            results.append(GenerationResult(
                text=text,
                token_ids=generated_ids.tolist(),
                logprobs=logprobs,
                finish_reason="stop" if ended_on_stop else "length",
            ))

        return results

    def get_logits(
        self,
        prompts: List[str],
        responses: List[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get logits for prompt+response pairs.
        Returns: (logits, attention_mask)
        Essential for KL divergence computation in RL.
        """
        full_texts = [p + r for p, r in zip(prompts, responses)]

        inputs = self.tokenizer(
            full_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        return outputs.logits, inputs["attention_mask"]

    def compute_log_probs(
        self,
        prompts: List[str],
        responses: List[str],
    ) -> torch.Tensor:
        """
        Sum of log p(response token | preceding tokens) for each prompt/response pair.

        Works with both left- and right-padded tokenizers and computes the
        log-softmax in float32. The prompt/response boundary is taken from
        tokenizing the prompt alone. If the tokenizer merges tokens across that
        boundary, the split is approximate.

        Returns:
            Tensor of shape (batch,), no gradient.
        """
        full_texts = [p + r for p, r in zip(prompts, responses)]
        full_encodings = self.tokenizer(full_texts, padding=True, return_tensors="pt")
        # Prompt length in tokens, with the same special tokens as the full text.
        prompt_lens = [len(self.tokenizer.encode(p)) for p in prompts]

        inputs = {k: v.to(self.device) for k, v in full_encodings.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits  # (batch, seq_len, vocab)

            input_ids = inputs["input_ids"].to(logits.device)
            attention_mask = inputs["attention_mask"].to(logits.device)
            seq_len = input_ids.shape[1]
            left_padded = getattr(self.tokenizer, "padding_side", "right") == "left"

            log_probs_list = []
            for i, prompt_len in enumerate(prompt_lens):
                # With left padding the real tokens start after the pad run.
                pad_offset = seq_len - int(attention_mask[i].sum()) if left_padded else 0
                # First response position; >= 1 because position 0 has no predicting logit.
                start = max(pad_offset + prompt_len, 1)

                response_logits = logits[i, start - 1:-1].float()  # logit at t-1 predicts token t
                response_ids = input_ids[i, start:]
                mask = attention_mask[i, start:].float()  # zeroes right-padding

                token_log_probs = F.log_softmax(response_logits, dim=-1).gather(
                    1, response_ids.unsqueeze(1)
                ).squeeze(1)
                log_probs_list.append((token_log_probs * mask).sum())

        return torch.stack(log_probs_list)

    def shutdown(self):
        del self.model
        torch.cuda.empty_cache()
        print("✓ HuggingFace backend shutdown")


class VLLMBackend(InferenceBackend):
    """
    vLLM backend - fast inference with PagedAttention.
    Use this for:
    - Fast rollout generation
    - Large batch inference
    Note: vLLM doesn't support gradient computation!
    """

    def __init__(
        self,
        model_name: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        max_model_len: int = 4096,
    ):
        if not VLLM_AVAILABLE:
            raise ImportError("vLLM not installed. Run: pip install vllm")

        self.model_name = model_name
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            trust_remote_code=True,
        )
        print(f"✓ vLLM backend loaded: {model_name}")

    def generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        return_logprobs: bool = False,
        n_logprobs: int = 5,
        **kwargs
    ) -> List[GenerationResult]:
        """Batch generate with vLLM (first completion per prompt).

        With `return_logprobs`, fills `logprobs` (log-prob of each sampled
        token) and `top_logprobs` (per position, {token_id: logprob} for the
        `n_logprobs` candidates vLLM returned).
        """
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            logprobs=n_logprobs if return_logprobs else None,
        )

        outputs = self.llm.generate(prompts, sampling_params)

        results = []
        for output in outputs:
            completion = output.outputs[0]

            logprobs = None
            top_logprobs = None
            if return_logprobs and completion.logprobs:
                # One {token_id: Logprob} dict per generated token; vLLM includes
                # the sampled token in each dict.
                logprobs = [
                    position[token_id].logprob
                    for token_id, position in zip(completion.token_ids, completion.logprobs)
                ]
                top_logprobs = [
                    {tid: lp.logprob for tid, lp in position.items()}
                    for position in completion.logprobs
                ]

            results.append(GenerationResult(
                text=completion.text,
                token_ids=list(completion.token_ids),
                logprobs=logprobs,
                top_logprobs=top_logprobs,
                finish_reason=completion.finish_reason or "length"
            ))

        return results

    def shutdown(self):
        # NOTE(review): `del` may not release all vLLM GPU memory (engine workers);
        # restart the kernel if memory is not reclaimed.
        del self.llm
        torch.cuda.empty_cache()
        print("✓ vLLM backend shutdown")


class SGLangBackend(InferenceBackend):
    """
    SGLang backend - RadixAttention for efficient KV caching.
    Use this for:
    - Fast rollout generation
    - Prefix caching (good for multi-turn)
    Note: SGLang doesn't support gradient computation!
    """

    def __init__(
        self,
        model_name: str,
        mem_fraction_static: float = 0.8,
        tp_size: int = 1,
    ):
        if not SGLANG_AVAILABLE:
            raise ImportError("SGLang not installed. Run: pip install 'sglang[all]'")

        self.model_name = model_name
        self.engine = sgl.Engine(
            model_path=model_name,
            mem_fraction_static=mem_fraction_static,
            tp_size=tp_size,
        )
        print(f"✓ SGLang backend loaded: {model_name}")

    def generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_p: float = 0.95,
        return_logprobs: bool = False,
        n_logprobs: int = 5,
        **kwargs
    ) -> List[GenerationResult]:
        """Batch generate with SGLang.

        Missing or null fields in SGLang's output dicts fall back to defaults
        (no logprobs, empty token ids, finish_reason "length") instead of raising.
        """
        sampling_params = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        logprob_kwargs = (
            {"return_logprob": True, "top_logprobs_num": n_logprobs} if return_logprobs else {}
        )
        outputs = self.engine.generate(prompts, sampling_params, **logprob_kwargs)

        results = []
        for output in outputs:
            meta = output.get("meta_info") or {}

            # SGLang logprob entries are tuples: (logprob, token_id, text).
            token_logprobs = meta.get("output_token_logprobs") if return_logprobs else None
            logprobs = [entry[0] for entry in token_logprobs] if token_logprobs is not None else None

            token_ids = output.get("output_ids")
            if token_ids is None:
                # Some SGLang versions may omit output_ids; recover them from the
                # logprob entries when those were requested.
                token_ids = [entry[1] for entry in token_logprobs] if token_logprobs else []

            # finish_reason is normally a dict like {"type": "stop", ...}; it may be None.
            finish = meta.get("finish_reason") or {}
            finish_reason = finish.get("type", "length") if isinstance(finish, dict) else str(finish)

            results.append(GenerationResult(
                text=output["text"],
                token_ids=token_ids,
                logprobs=logprobs,
                finish_reason=finish_reason,
            ))

        return results

    def shutdown(self):
        self.engine.shutdown()
        print("✓ SGLang backend shutdown")


print("✓ Inference backends defined")

# Dynamic Reference Model Manager
Compute reference-model log-probabilities for the KL penalty. Optional weight snapshots let the reference policy be swapped mid-training:

In [None]:
class ReferenceModelManager:
    """
    Manages reference model for KL divergence computation.
    Supports dynamic snapshots mid-training.

    Snapshots are full copies of a training model's state_dict, kept on CPU.
    They are loaded into the reference model non-strictly (`strict=False`), so
    only keys whose names match the reference model's own take effect.
    """

    def __init__(self, model_name: str, device: str = "cuda"):
        self.model_name = model_name
        self.device = device
        self.snapshots: Dict[str, Dict] = {}  # snapshot_id -> state_dict (CPU tensors)

        # Load reference model (frozen)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        # Freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False

        print(f"✓ Reference model loaded: {model_name}")

    def snapshot_current_training_model(self, training_model: nn.Module) -> str:
        """
        Create snapshot of current training model weights.
        Useful for dynamic reference model updates.

        Stores the FULL state_dict (all parameters and buffers) on CPU.

        Returns:
            The new snapshot id.
        """
        # Millisecond timestamp for readability plus a random suffix, so two
        # snapshots taken in the same millisecond cannot overwrite each other.
        snapshot_id = f"snapshot_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"
        self.snapshots[snapshot_id] = {
            k: v.cpu().clone() for k, v in training_model.state_dict().items()
        }
        print(f"✓ Created snapshot: {snapshot_id}")
        return snapshot_id

    def _load_weights(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load `state_dict` into the reference model non-strictly.

        `strict=False` silently ignores unknown keys, so warn when *none* of
        the given keys were applied.
        """
        result = self.model.load_state_dict(state_dict, strict=False)
        if state_dict and len(result.unexpected_keys) == len(state_dict):
            print(
                "⚠ No snapshot keys match the reference model; weights unchanged "
                "(e.g. a PEFT-wrapped model's keys carry a 'base_model.model.' prefix)."
            )

    def load_snapshot(self, snapshot_id: str):
        """Load a previously saved snapshot into reference model (persistently).

        Raises:
            ValueError: if `snapshot_id` is unknown.
        """
        if snapshot_id not in self.snapshots:
            raise ValueError(f"Snapshot {snapshot_id} not found")

        self._load_weights(self.snapshots[snapshot_id])
        print(f"✓ Loaded snapshot: {snapshot_id}")

    def _sequence_log_probs(
        self,
        model: nn.Module,
        tokenizer,
        prompts: List[str],
        responses: List[str],
        no_grad: bool = True,
    ) -> torch.Tensor:
        """Sum of log p(response token | preceding tokens) per prompt/response pair.

        Handles left- or right-padded batches and computes log-softmax in
        float32. When `no_grad` is False the forward pass runs under the
        caller's grad mode, so gradients can flow into `model`. The boundary
        comes from tokenizing the prompt alone, so it is approximate if the
        tokenizer merges tokens across it.

        Returns:
            Tensor of shape (batch,).
        """
        full_texts = [p + r for p, r in zip(prompts, responses)]
        prompt_lens = [len(tokenizer.encode(p)) for p in prompts]

        encodings = tokenizer(full_texts, padding=True, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in encodings.items()}

        if no_grad:
            with torch.no_grad():
                logits = model(**inputs).logits
        else:
            logits = model(**inputs).logits

        # With device_map="auto" the logits may live on a different device than the inputs.
        input_ids = inputs["input_ids"].to(logits.device)
        attention_mask = inputs["attention_mask"].to(logits.device)
        seq_len = input_ids.shape[1]
        left_padded = getattr(tokenizer, "padding_side", "right") == "left"

        per_sequence = []
        for i, prompt_len in enumerate(prompt_lens):
            # With left padding the real tokens start after the pad run.
            pad_offset = seq_len - int(attention_mask[i].sum()) if left_padded else 0
            # First response position; >= 1 because position 0 has no predicting logit.
            start = max(pad_offset + prompt_len, 1)

            response_logits = logits[i, start - 1:-1].float()  # logit at t-1 predicts token t
            response_ids = input_ids[i, start:]
            mask = attention_mask[i, start:].to(response_logits.dtype)  # zeroes right-padding

            token_log_probs = F.log_softmax(response_logits, dim=-1).gather(
                1, response_ids.unsqueeze(1)
            ).squeeze(1)
            per_sequence.append((token_log_probs * mask).sum())

        return torch.stack(per_sequence)

    def compute_log_probs(
        self,
        prompts: List[str],
        responses: List[str],
        snapshot_id: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Compute log probabilities from reference model (no gradient).

        If `snapshot_id` names a stored snapshot, its weights are swapped in for
        this call only. The previous weights are restored afterwards, even if
        the forward pass raises. Unknown ids are ignored and the current
        reference weights are used.

        Returns:
            Tensor of shape (batch,): summed response log-probs per pair.
        """
        if not (snapshot_id and snapshot_id in self.snapshots):
            return self._sequence_log_probs(self.model, self.tokenizer, prompts, responses, no_grad=True)

        snapshot = self.snapshots[snapshot_id]
        current = self.model.state_dict()
        # Back up only the tensors the snapshot will overwrite, not the whole model.
        backup = {k: current[k].clone() for k in snapshot if k in current}
        self._load_weights(snapshot)
        try:
            return self._sequence_log_probs(self.model, self.tokenizer, prompts, responses, no_grad=True)
        finally:
            self.model.load_state_dict(backup, strict=False)

    def compute_kl_divergence(
        self,
        prompts: List[str],
        responses: List[str],
        train_model: nn.Module,
        train_tokenizer,
        snapshot_id: Optional[str] = None,
        per_token: bool = False,
    ) -> torch.Tensor:
        """
        Sampled estimate of KL(π_train || π_ref) on the given responses.

        For each pair the estimate is
        log π_train(response | prompt) - log π_ref(response | prompt),
        with both terms summed over response tokens. Gradients flow through
        `train_model` (under the caller's grad mode).

        Args:
            per_token: If True, return the unreduced estimates, shape (batch,)
                (one value per sequence, despite the name). Otherwise return
                their mean.
        """
        ref_log_probs = self.compute_log_probs(prompts, responses, snapshot_id)
        train_log_probs = self._sequence_log_probs(
            train_model, train_tokenizer, prompts, responses, no_grad=False
        )

        kl_div = train_log_probs - ref_log_probs.to(train_log_probs.device)

        return kl_div if per_token else kl_div.mean()

    def clear_snapshots(self):
        """Free memory by clearing old snapshots"""
        self.snapshots.clear()
        torch.cuda.empty_cache()
        print("✓ Cleared all snapshots")

    def shutdown(self):
        del self.model
        self.snapshots.clear()
        torch.cuda.empty_cache()
        print("✓ Reference model manager shutdown")

print("✓ Reference model manager defined")

# Multi-Objective Reward System with Pareto Tracking

In [None]:
class ParetoRewardTracker:
    """
    Track multi-objective rewards and maintain Pareto frontier.
    Plug in your own reward functions!

    A reward vector ``a`` dominates ``b`` when ``a >= b`` in every objective
    and ``a > b`` in at least one. ``pareto_frontier`` holds the
    non-dominated vectors seen so far, each at most once; ``all_rewards``
    holds every vector passed to ``update_pareto_frontier``.
    """
    
    def __init__(self, objective_names: List[str]):
        self.objectives = objective_names
        self.num_objectives = len(objective_names)
        self.pareto_frontier: List[np.ndarray] = []
        self.all_rewards: List[np.ndarray] = []
        
        # YOU OVERRIDE THESE - map objective name to function
        self.reward_functions: Dict[str, Callable[["RolloutState"], float]] = {}
    
    def register_reward_function(self, name: str, fn: Callable[["RolloutState"], float]):
        """Register a reward function for an objective.

        Raises:
            ValueError: if ``name`` is not one of the tracker's objectives.
        """
        if name not in self.objectives:
            raise ValueError(f"Objective {name} not in {self.objectives}")
        self.reward_functions[name] = fn
    
    def compute_rewards(self, state: "RolloutState") -> np.ndarray:
        """Compute reward vector for a completed rollout.

        Objectives with a registered function use it; the others fall back to
        the matching entry of ``state.cumulative_reward`` (0.0 if missing).
        """
        rewards = np.zeros(self.num_objectives)
        
        for i, obj_name in enumerate(self.objectives):
            if obj_name in self.reward_functions:
                rewards[i] = self.reward_functions[obj_name](state)
            else:
                # Default: use cumulative reward
                rewards[i] = state.cumulative_reward[i] if i < len(state.cumulative_reward) else 0.0
        
        return rewards
    
    def update_pareto_frontier(self, reward_vector: np.ndarray):
        """Record ``reward_vector`` and update the Pareto frontier with it.

        The vector is copied first, so later in-place changes made by the
        caller cannot silently alter the recorded history or the frontier.
        A vector that is dominated by, or identical to, an existing frontier
        point leaves the frontier unchanged; otherwise every frontier point it
        dominates is removed and it is appended.
        """
        point = np.array(reward_vector)  # defensive copy (keeps dtype)
        self.all_rewards.append(point)
        
        for frontier_point in self.pareto_frontier:
            # An identical point is neither dominated nor dominating, so without
            # the equality check the frontier would fill up with duplicates.
            if self._dominates(frontier_point, point) or np.array_equal(frontier_point, point):
                return
        
        # Remove points dominated by the new vector, then add it.
        self.pareto_frontier = [
            p for p in self.pareto_frontier
            if not self._dominates(point, p)
        ]
        self.pareto_frontier.append(point)
    
    def _dominates(self, a: np.ndarray, b: np.ndarray) -> bool:
        """Return True if ``a`` Pareto-dominates ``b``: >= in all objectives and > in at least one."""
        return bool(np.all(a >= b) and np.any(a > b))
    
    def is_on_pareto_frontier(self, reward_vector: np.ndarray) -> bool:
        """Check if point is (approximately, via ``np.allclose``) on the Pareto frontier."""
        return any(np.allclose(p, reward_vector) for p in self.pareto_frontier)
    
    def get_pareto_weights(self, strategy: str = "random") -> np.ndarray:
        """
        Get weight vector for scalarizing multi-objective rewards.
        Used in training to sample different trade-offs.

        Strategies: ``random`` (Dirichlet(1, ..., 1) sample), ``uniform``
        (equal weights), ``maximize_first`` (all weight on objective 0).

        Raises:
            ValueError: for an unknown strategy.
        """
        if strategy == "random":
            # Random scalarization (linear scalarization)
            return np.random.dirichlet(np.ones(self.num_objectives))
        elif strategy == "uniform":
            return np.ones(self.num_objectives) / self.num_objectives
        elif strategy == "maximize_first":
            weights = np.zeros(self.num_objectives)
            weights[0] = 1.0
            return weights
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
    
    def scalarize_reward(self, reward_vector: np.ndarray, weights: Optional[np.ndarray] = None) -> float:
        """Convert multi-objective reward to scalar (uniform weights by default)."""
        if weights is None:
            weights = np.ones(self.num_objectives) / self.num_objectives
        return float(np.dot(weights, reward_vector))
    
    def get_frontier_statistics(self) -> Dict:
        """Get per-objective statistics (mean/std/min/max) of the Pareto frontier."""
        if not self.pareto_frontier:
            return {"size": 0}
        
        frontier_array = np.array(self.pareto_frontier)
        return {
            "size": len(self.pareto_frontier),
            "mean": frontier_array.mean(axis=0).tolist(),
            "std": frontier_array.std(axis=0).tolist(),
            "min": frontier_array.min(axis=0).tolist(),
            "max": frontier_array.max(axis=0).tolist(),
        }

print("✓ ParetoRewardTracker defined")

# SVRL Environment Interface

In [None]:
class SVRLEnvironment:
    """
    Self-Verifying RL Environment.
    Model can query environment for information at a cost.
    You implement the verification logic!

    A query whose cost ``state.can_verify(cost)`` accepts is charged to
    ``state.verification_budget_spent``, answered by the verification handler
    and recorded in ``query_history``; otherwise nothing is charged or recorded.
    """
    
    def __init__(
        self,
        cost_function: Optional[Callable[[str, "RolloutState"], float]] = None,
        verification_handler: Optional[Callable[[str], Any]] = None,
    ):
        # YOU OVERRIDE THESE
        self.cost_fn = cost_function or self._default_cost
        self.verification_handler = verification_handler or self._default_verification
        
        # Tracking
        self.total_queries = 0
        self.total_cost_spent = 0.0
        self.query_history: List[Dict] = []
    
    def _default_cost(self, query: str, state: "RolloutState") -> float:
        """Default cost: 0.1 per whitespace-separated word of the query."""
        return len(query.split()) * 0.1
    
    def _default_verification(self, query: str) -> str:
        """Default verification: placeholder"""
        return f"[Verification result for: {query}]"
    
    def query_environment(
        self,
        query: str,
        state: "RolloutState",
    ) -> Tuple[Optional[Any], float, bool]:
        """
        Model queries environment for verification.
        
        Returns:
            result: Verification result (None if over budget)
            cost: Cost of this query
            success: Whether query was successful
        """
        cost = self.cost_fn(query, state)
        
        if state.can_verify(cost):
            state.verification_budget_spent += cost
            result = self.verification_handler(query)
            
            # Track query
            self.total_queries += 1
            self.total_cost_spent += cost
            self.query_history.append({
                "rollout_id": state.rollout_id,
                "step": state.step_count,
                "query": query,
                "cost": cost,
                "budget_remaining": state.max_verification_budget - state.verification_budget_spent,
            })
            
            return result, cost, True
        else:
            return None, cost, False
    
    def parse_verification_request(self, model_output: str) -> Optional[str]:
        """
        Parse model output to extract verification request.
        You can customize this format!
        
        Default format: VERIFY: <query>

        The query is the text after the first ``VERIFY:`` marker up to the end
        of that line, with surrounding whitespace removed. Returns None when
        there is no marker or the query on that line is empty.
        """
        marker = "VERIFY:"
        if marker not in model_output:
            return None
        remainder = model_output.split(marker, 1)[1]
        # Cut at the first newline BEFORE stripping; stripping first would make
        # "VERIFY:\n<next line>" return the following line as the query.
        query = remainder.split("\n", 1)[0].strip()
        return query or None
    
    def get_statistics(self) -> Dict:
        """Get environment statistics (totals, average cost, last 10 queries)."""
        return {
            "total_queries": self.total_queries,
            "total_cost": self.total_cost_spent,
            "avg_cost_per_query": self.total_cost_spent / max(1, self.total_queries),
            "recent_queries": self.query_history[-10:],
        }
    
    def reset_tracking(self):
        """Reset tracking for new experiment"""
        self.total_queries = 0
        self.total_cost_spent = 0.0
        self.query_history.clear()

print("✓ SVRLEnvironment defined")

# Trajectory Buffer with Priority Sampling

In [None]:
class TrajectoryBuffer:
    """
    Buffer for completed trajectories with various sampling strategies.
    Supports priority sampling for RL training.

    Holds at most ``max_buffer_size`` trajectories; when full, adding one
    evicts the oldest, and that trajectory's entry in ``step_statistics`` is
    dropped with it so the statistics describe the current buffer contents.
    """
    
    def __init__(
        self,
        max_buffer_size: int = 10000,
        num_objectives: int = 1,
    ):
        self.max_size = max_buffer_size
        self.num_objectives = num_objectives
        self.trajectories: deque = deque(maxlen=max_buffer_size)
        # step_count -> cumulative_reward of each buffered trajectory with that step count
        self.step_statistics: Dict[int, List[np.ndarray]] = defaultdict(list)
        
        # Optional: Pareto tracker for priority sampling
        self.pareto_tracker: Optional[ParetoRewardTracker] = None
    
    def add_trajectory(self, state: "RolloutState"):
        """Add completed trajectory to buffer (evicting the oldest one when full)."""
        maxlen = self.trajectories.maxlen
        if maxlen and len(self.trajectories) == maxlen:
            # deque.append is about to drop trajectories[0]; forget its stats too.
            self._forget_statistics(self.trajectories[0])
        self.trajectories.append(state)
        
        # Track statistics by trajectory length
        self.step_statistics[state.step_count].append(state.cumulative_reward)
        
        # Update Pareto frontier if tracker available
        if self.pareto_tracker is not None:
            self.pareto_tracker.update_pareto_frontier(state.cumulative_reward)
    
    def _forget_statistics(self, state) -> None:
        """Remove the ``step_statistics`` entry recorded for an about-to-be-evicted trajectory."""
        # Assumes state.step_count was not changed after the trajectory was added.
        bucket = self.step_statistics.get(state.step_count)
        if bucket:
            # Buckets are filled in insertion order and the evicted trajectory is
            # the oldest overall, so its entry is the first one in its bucket.
            bucket.pop(0)
            if not bucket:
                del self.step_statistics[state.step_count]
    
    def add_batch(self, states: List["RolloutState"]):
        """Add batch of completed trajectories"""
        for state in states:
            self.add_trajectory(state)
    
    def sample_batch(
        self,
        batch_size: int,
        strategy: str = "uniform",
    ) -> List["RolloutState"]:
        """
        Sample batch of trajectories for training (without replacement).
        
        Strategies:
        - uniform: Random sampling
        - completion_weighted: Prefer longer trajectories
        - pareto_weighted: Prefer Pareto-optimal trajectories
        - reward_weighted: Prefer high-reward trajectories

        Returns at most ``len(self.trajectories)`` items; ``[]`` when empty.

        Raises:
            ValueError: for an unknown strategy.
        """
        if len(self.trajectories) == 0:
            return []
        
        pool = list(self.trajectories)  # materialize once, not once per sampled index
        batch_size = min(batch_size, len(pool))
        
        if strategy == "uniform":
            return random.sample(pool, batch_size)
        
        if strategy == "completion_weighted":
            # Prefer longer trajectories
            weights = np.array([len(t.trajectory) for t in pool], dtype=float)
        elif strategy == "pareto_weighted":
            # Prefer Pareto-optimal trajectories
            if self.pareto_tracker is None:
                return self.sample_batch(batch_size, "uniform")
            weights = np.array([
                2.0 if self.pareto_tracker.is_on_pareto_frontier(t.cumulative_reward) else 1.0
                for t in pool
            ])
        elif strategy == "reward_weighted":
            # Prefer high-reward trajectories (sum across objectives), shifted to >= 0
            weights = np.array([t.cumulative_reward.sum() for t in pool], dtype=float)
            weights = weights - weights.min()
        else:
            raise ValueError(f"Unknown sampling strategy: {strategy}")
        
        return self._weighted_sample(pool, weights, batch_size)
    
    @staticmethod
    def _weighted_sample(pool: List, weights: np.ndarray, batch_size: int) -> List:
        """Draw ``batch_size`` distinct items from ``pool`` with probability proportional to ``weights``."""
        # The epsilon keeps zero-weight items (empty trajectories, the minimum-reward
        # item after shifting) eligible: np.random.choice(replace=False) raises when
        # fewer than batch_size entries of p are non-zero, and all-zero weights
        # would otherwise give NaN probabilities.
        weights = weights + 1e-6
        probs = weights / weights.sum()
        indices = np.random.choice(len(pool), size=batch_size, p=probs, replace=False)
        return [pool[i] for i in indices]
    
    def prepare_training_batch(
        self,
        trajectories: List["RolloutState"],
        weights: Optional[np.ndarray] = None,
    ) -> "TrainingBatch":
        """
        Convert trajectories to training batch format.
        Flattens all steps for batch processing.

        Each step's reward is scalarized as ``weights · reward`` when weights are
        given, otherwise as the sum of its components; a missing reward counts as 0.0.
        """
        prompts = []
        responses = []
        token_ids = []
        rewards = []
        
        for traj in trajectories:
            for step in traj.trajectory:
                prompts.append(step.prompt)
                responses.append(step.response.text)
                token_ids.append(step.response.token_ids)
                
                # Scalarize reward if multi-objective
                if weights is not None and step.reward is not None:
                    scalar_reward = np.dot(weights, step.reward)
                else:
                    scalar_reward = step.reward.sum() if step.reward is not None else 0.0
                rewards.append(scalar_reward)
        
        return TrainingBatch(
            prompts=prompts,
            responses=responses,
            token_ids=token_ids,
            rewards=np.array(rewards),
        )
    
    def get_statistics(self) -> Dict:
        """Get buffer statistics (size, length stats, mean reward, step-count histogram)."""
        if not self.trajectories:
            return {"size": 0}
        
        all_lengths = [len(t.trajectory) for t in self.trajectories]
        all_rewards = [t.cumulative_reward for t in self.trajectories]
        
        return {
            "size": len(self.trajectories),
            "avg_length": np.mean(all_lengths),
            "std_length": np.std(all_lengths),
            "min_length": min(all_lengths),
            "max_length": max(all_lengths),
            "avg_reward": np.mean(all_rewards, axis=0).tolist() if all_rewards else [],
            "step_distribution": {k: len(v) for k, v in self.step_statistics.items()},
        }
    
    def clear(self):
        """Clear buffer"""
        self.trajectories.clear()
        self.step_statistics.clear()

print("✓ TrajectoryBuffer defined")

# Training Loop

In [None]:
class RLTrainer:
    """
    Main RL trainer that orchestrates rollouts and gradient updates.
    Modular design - swap out components as needed.

    One ``train_step`` collects rollouts with the inference backend, samples
    trajectories from the buffer, scalarizes multi-objective rewards, and takes
    one optimizer step on the training model using a policy-gradient loss plus
    ``kl_coef`` times the reference manager's KL estimate.
    """
    
    def __init__(
        self,
        # Inference backend for fast rollout generation
        inference_backend: "InferenceBackend",
        # HuggingFace model for gradient computation
        training_backend: "HuggingFaceBackend",
        # Reference model manager for KL divergence
        ref_model_manager: "ReferenceModelManager",
        # Pareto reward tracking
        pareto_tracker: "ParetoRewardTracker",
        # Config
        num_objectives: int = 1,
        max_steps: int = 50,
        batch_size: int = 4,
        learning_rate: float = 1e-5,
        kl_coef: float = 0.1,
        max_grad_norm: float = 1.0,
        use_wandb: bool = False,
    ):
        self.inference_backend = inference_backend
        self.training_backend = training_backend
        self.ref_model_manager = ref_model_manager
        self.pareto_tracker = pareto_tracker
        
        self.num_objectives = num_objectives
        self.max_steps = max_steps
        self.batch_size = batch_size
        self.kl_coef = kl_coef
        self.max_grad_norm = max_grad_norm
        self.use_wandb = use_wandb
        
        # Rollout manager
        self.rollout_manager = AsyncRolloutManager(
            inference_backend=inference_backend,
            max_concurrent_rollouts=batch_size * 2,
            max_steps=max_steps,
            num_objectives=num_objectives,
        )
        
        # Trajectory buffer
        self.buffer = TrajectoryBuffer(
            max_buffer_size=10000,
            num_objectives=num_objectives,
        )
        self.buffer.pareto_tracker = pareto_tracker
        
        # Optimizer (on training model)
        self.training_backend.model.train()
        self.optimizer = torch.optim.AdamW(
            self.training_backend.model.parameters(),
            lr=learning_rate,
        )
        
        # Metrics
        self.step_count = 0
        self.train_metrics: List[Dict] = []
        self._wandb_warned = False  # report a wandb logging failure only once
    
    def collect_rollouts(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
    ) -> List["RolloutState"]:
        """Collect complete rollouts for given prompts.

        Every rollout completed while stepping is added to ``self.buffer``.
        Returns the last ``len(prompts)`` entries of the rollout manager's
        ``completed_rollouts`` (presumably this round's rollouts, assuming the
        manager appends them in completion order — TODO confirm). Returns ``[]``
        for an empty ``prompts`` list.
        """
        if not prompts:
            # Guard: completed_rollouts[-0:] would return the ENTIRE history.
            return []
        
        # Start rollouts
        for prompt in prompts:
            self.rollout_manager.start_rollout(prompt)
        
        # Step until all complete
        with tqdm(total=self.max_steps, desc="Rollout steps") as pbar:
            while self.rollout_manager.has_active_rollouts():
                self.rollout_manager.step_all(
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                )
                
                # Move completed to buffer
                completed = self.rollout_manager.complete_terminated_rollouts()
                for state in completed:
                    self.buffer.add_trajectory(state)
                
                pbar.update(1)
        
        # Return all completed from this round
        return self.rollout_manager.completed_rollouts[-len(prompts):]
    
    @staticmethod
    def _compute_advantages(rewards: np.ndarray) -> np.ndarray:
        """Subtract the batch-mean reward as a variance-reducing baseline (empty in, empty out)."""
        rewards = np.asarray(rewards, dtype=float)
        if rewards.size == 0:
            return rewards
        return rewards - rewards.mean()
    
    def _response_log_prob(self, prompt: str, response: str) -> torch.Tensor:
        """Summed log-probability of ``response`` given ``prompt`` under the training model (with grad)."""
        tokenizer = self.training_backend.tokenizer
        device = self.training_backend.device
        # NOTE: the prompt length comes from tokenizing the prompt alone, which
        # assumes prompt+response tokenizes with the same boundary.
        prompt_len = len(tokenizer.encode(prompt))
        
        encoding = tokenizer(prompt + response, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in encoding.items()}
        logits = self.training_backend.model(**inputs).logits
        
        # The logit at position t predicts token t+1, hence the shift by one.
        response_logits = logits[0, prompt_len - 1:-1]
        response_ids = encoding["input_ids"][0, prompt_len:].to(device)
        
        log_probs = F.log_softmax(response_logits, dim=-1)
        return log_probs.gather(1, response_ids.unsqueeze(1)).squeeze(1).sum()
    
    def compute_policy_gradient_loss(
        self,
        batch: "TrainingBatch",
        weights: Optional[np.ndarray] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute policy gradient loss.
        REINFORCE with a mean-reward baseline: each sample's advantage is its
        scalarized reward minus the batch mean, and the policy loss is
        ``-advantage * log pi(response | prompt)`` averaged over samples. The
        returned loss adds ``kl_coef`` times the reference-model KL estimate.
        You can extend this for PPO, GRPO, etc.

        Args:
            batch: Flattened training batch whose rewards are already scalar.
            weights: Scalarization weights; unused here because ``batch.rewards``
                is already scalarized.

        Raises:
            ValueError: if the batch contains no samples.
        """
        if len(batch.prompts) == 0:
            raise ValueError("compute_policy_gradient_loss received an empty batch")
        
        # Enable gradients on training model
        self.training_backend.model.train()
        
        advantages = self._compute_advantages(batch.rewards)
        total_loss = torch.tensor(0.0, device=self.training_backend.device)
        for prompt, response, advantage in zip(batch.prompts, batch.responses, advantages):
            # Negative because we want to maximize advantage-weighted log-likelihood.
            pg_loss = -float(advantage) * self._response_log_prob(prompt, response)
            total_loss = total_loss + pg_loss
        
        avg_loss = total_loss / len(batch.prompts)
        
        # Compute KL divergence
        kl_div = self.ref_model_manager.compute_kl_divergence(
            batch.prompts,
            batch.responses,
            self.training_backend.model,
            self.training_backend.tokenizer,
        )
        
        # Total loss with KL penalty
        final_loss = avg_loss + self.kl_coef * kl_div
        
        metrics = {
            "pg_loss": avg_loss.item(),
            "kl_div": kl_div.item(),
            "total_loss": final_loss.item(),
            "avg_reward": float(np.mean(batch.rewards)),
        }
        
        return final_loss, metrics
    
    def _log_to_wandb(self, metrics: Dict) -> None:
        """Best-effort wandb logging that reports the first failure instead of hiding it."""
        try:
            # Imported here because no module-level `import wandb` is guaranteed.
            import wandb
            wandb.log(metrics, step=self.step_count)
        except Exception as exc:
            if not getattr(self, "_wandb_warned", False):
                print(f"⚠ wandb logging failed (further failures suppressed): {exc}")
                self._wandb_warned = True
    
    def train_step(
        self,
        prompts: List[str],
        scalarization_strategy: str = "random",
        sampling_strategy: str = "completion_weighted",
    ) -> Dict:
        """
        Single training step:
        1. Collect rollouts
        2. Sample from buffer (using ``sampling_strategy``)
        3. Compute loss and update

        If the sampled trajectories contain no steps, the optimizer update is
        skipped and the loss/KL/reward metrics are recorded as NaN.
        """
        # 1. Collect rollouts
        self.collect_rollouts(prompts)
        
        # 2. Sample batch from buffer
        sampled_trajs = self.buffer.sample_batch(
            batch_size=self.batch_size,
            strategy=sampling_strategy,
        )
        
        # 3. Get scalarization weights for multi-objective
        weights = self.pareto_tracker.get_pareto_weights(strategy=scalarization_strategy)
        
        # 4. Prepare training batch
        train_batch = self.buffer.prepare_training_batch(sampled_trajs, weights)
        
        if len(train_batch.prompts) == 0:
            # Nothing to learn from: calling backward() on a loss with no graph would raise.
            nan = float("nan")
            metrics = {
                "pg_loss": nan,
                "kl_div": nan,
                "total_loss": nan,
                "avg_reward": nan,
                "grad_norm": 0.0,
            }
        else:
            # 5. Compute loss
            loss, metrics = self.compute_policy_gradient_loss(train_batch, weights)
            
            # 6. Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.training_backend.model.parameters(),
                self.max_grad_norm,
            )
            metrics["grad_norm"] = float(grad_norm)
            
            # 7. Optimizer step
            self.optimizer.step()
        
        # 8. Update tracking
        self.step_count += 1
        metrics["step"] = self.step_count
        metrics["buffer_size"] = len(self.buffer.trajectories)
        metrics["pareto_frontier_size"] = len(self.pareto_tracker.pareto_frontier)
        metrics["weights"] = weights.tolist()
        
        self.train_metrics.append(metrics)
        
        # Log to wandb if enabled
        if self.use_wandb:
            self._log_to_wandb(metrics)
        
        return metrics
    
    def train(
        self,
        prompt_generator: Callable[[], List[str]],
        num_iterations: int = 100,
        eval_every: int = 10,
        save_every: int = 50,
        checkpoint_dir: str = "./checkpoints",
        sampling_strategy: str = "completion_weighted",
    ):
        """
        Main training loop.
        
        Args:
            prompt_generator: Function that returns batch of prompts
            num_iterations: Number of training iterations
            eval_every: Print a metrics summary every N iterations (0 disables)
            save_every: Save a checkpoint every N iterations (0 disables)
            checkpoint_dir: Directory for checkpoint files (created if missing)
            sampling_strategy: Buffer sampling strategy passed to ``train_step``
        """
        print(f"Starting training for {num_iterations} iterations")
        
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        for i in tqdm(range(num_iterations), desc="Training"):
            iteration = i + 1
            prompts = prompt_generator()
            metrics = self.train_step(prompts, sampling_strategy=sampling_strategy)
            
            # Log progress
            if eval_every and iteration % eval_every == 0:
                print(f"\nIteration {iteration}:")
                print(f"  Loss: {metrics['total_loss']:.4f}")
                print(f"  KL: {metrics['kl_div']:.4f}")
                print(f"  Avg Reward: {metrics['avg_reward']:.4f}")
                print(f"  Buffer Size: {metrics['buffer_size']}")
                print(f"  Pareto Frontier: {metrics['pareto_frontier_size']}")
            
            # Save checkpoint
            if save_every and iteration % save_every == 0:
                self.save_checkpoint(
                    os.path.join(checkpoint_dir, f"checkpoint_{iteration}.pt")
                )
    
    def save_checkpoint(self, path: str):
        """Save training checkpoint (model, optimizer, metrics, Pareto frontier)."""
        checkpoint = {
            "step": self.step_count,
            "model_state": self.training_backend.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "metrics": self.train_metrics,
            "pareto_frontier": [p.tolist() for p in self.pareto_tracker.pareto_frontier],
        }
        torch.save(checkpoint, path)
        print(f"✓ Saved checkpoint to {path}")
    
    def load_checkpoint(self, path: str):
        """Load training checkpoint.

        Only ``pareto_tracker.pareto_frontier`` is restored; ``all_rewards`` is not.
        ``torch.load`` is pickle-based: only load checkpoints from trusted sources.
        """
        checkpoint = torch.load(path, map_location=self.training_backend.device)
        self.step_count = checkpoint["step"]
        self.training_backend.model.load_state_dict(checkpoint["model_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        self.train_metrics = checkpoint["metrics"]
        self.pareto_tracker.pareto_frontier = [
            np.array(p) for p in checkpoint["pareto_frontier"]
        ]
        print(f"✓ Loaded checkpoint from {path}")

print("✓ RLTrainer defined")

# Modular Configuration System

In [ ]:
# =============================================================================
# CONFIGURATION: Experiment Settings
# =============================================================================

@dataclass
class ExperimentConfig:
    """Every tunable knob of an RL experiment, serializable to and from JSON.

    Field order matters: it fixes positional construction and the key order
    of ``to_dict()`` / the saved JSON, so append new fields at the end.
    """
    
    # --- Model ---
    model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    dtype: str = "bfloat16"
    gradient_checkpointing: bool = True
    
    # --- Training ---
    num_objectives: int = 2
    max_steps: int = 50
    batch_size: int = 4
    learning_rate: float = 1e-6
    kl_coef: float = 0.1
    max_grad_norm: float = 1.0
    
    # --- Rollouts ---
    max_new_tokens: int = 256
    temperature: float = 1.0
    top_p: float = 0.95
    
    # --- Self-verifying RL (if using) ---
    svrl_enabled: bool = False
    verification_budget: float = 10.0
    
    # --- Trajectory buffer ---
    buffer_size: int = 10000
    # one of: uniform, completion_weighted, pareto_weighted, reward_weighted
    sampling_strategy: str = "completion_weighted"
    
    # --- Dynamic reference model ---
    update_ref_model: bool = False
    ref_update_frequency: int = 100  # measured in training steps
    
    # --- Logging / checkpointing ---
    use_wandb: bool = False
    checkpoint_dir: str = "./checkpoints"
    save_every: int = 50
    
    def to_dict(self) -> Dict:
        """Return the config as a plain dict (a deep copy, via ``dataclasses.asdict``)."""
        return asdict(self)
    
    def save(self, path: str):
        """Write the config to ``path`` as JSON indented by two spaces."""
        payload = self.to_dict()
        with open(path, "w") as handle:
            json.dump(payload, handle, indent=2)
        print(f"✓ Config saved to {path}")
    
    @classmethod
    def load(cls, path: str) -> "ExperimentConfig":
        """Rebuild a config from a JSON file such as one written by ``save``."""
        with open(path, "r") as handle:
            return cls(**json.load(handle))


# Example configurations for different experiments
def get_pareto_rl_config() -> ExperimentConfig:
    """Preset for Pareto / multi-objective RL with two objectives (task_completion, efficiency)."""
    overrides = dict(
        model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_objectives=2,  # task_completion, efficiency
        max_steps=50,
        batch_size=4,
        learning_rate=1e-6,
        kl_coef=0.1,
    )
    return ExperimentConfig(**overrides)


def get_svrl_config() -> ExperimentConfig:
    """Preset for Self-Verifying RL: SVRL enabled, three objectives, shorter rollouts."""
    overrides = dict(
        model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_objectives=3,  # correctness, efficiency, verification_cost
        svrl_enabled=True,
        verification_budget=10.0,
        max_steps=30,
    )
    return ExperimentConfig(**overrides)


def get_no_kl_config() -> ExperimentConfig:
    """Preset for ablations without a KL penalty (kl_coef = 0)."""
    overrides = dict(
        model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        kl_coef=0.0,  # No KL penalty
        max_steps=20,
    )
    return ExperimentConfig(**overrides)


def get_dynamic_ref_config() -> ExperimentConfig:
    """Preset for dynamic-reference-model experiments (reference refreshed every 50 steps)."""
    overrides = dict(
        model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        update_ref_model=True,
        ref_update_frequency=50,
        kl_coef=0.05,  # Lower KL coef since ref model updates
    )
    return ExperimentConfig(**overrides)


print("✓ Configuration system defined")
# "\n" (not "\\n"): emit a real blank line instead of a literal backslash-n.
print("\nPreset configs available:")
print("  - get_pareto_rl_config()")
print("  - get_svrl_config()")
print("  - get_no_kl_config()")
print("  - get_dynamic_ref_config()")

# Implementation Summary

## What's Implemented:
- **Dynamic batching**: Groups rollouts by step count for efficient inference
- **Padding masks**: Handled automatically via HuggingFace tokenizers
- **Priority sampling**: Multiple strategies (completion_weighted, pareto_weighted, reward_weighted)
- **Variable-length rollouts**: AsyncRolloutManager handles different trajectory lengths
- **Full parameter updates**: No LoRA - direct gradient computation on all weights

## Key Features:
- **Pareto frontier tracking**: Maintains non-dominated solutions for multi-objective optimization
- **Dynamic reference model**: Snapshot and restore training model weights mid-training
- **KL divergence control**: Optional penalty to keep policy close to reference
- **SVRL environment**: Query environment for verification at a cost
- **Checkpointing**: Save/load training state including Pareto frontier

In [None]:
# =============================================================================
# UTILITY: Visualization and Analysis
# =============================================================================

def plot_pareto_frontier(tracker: "ParetoRewardTracker"):
    """Plot the 2D Pareto frontier on top of all recorded reward vectors.

    Only trackers with exactly two objectives are supported; otherwise, or when
    the frontier is empty, a message is printed and nothing is drawn. The
    "All trajectories" scatter is skipped when ``tracker.all_rewards`` is empty
    (e.g. when only the frontier was restored from a checkpoint), which
    previously crashed with an IndexError.
    """
    if tracker.num_objectives != 2:
        print("Plotting only supports 2 objectives")
        return
    
    if not tracker.pareto_frontier:
        print("No Pareto frontier points yet")
        return
    
    # Imported lazily: only needed once there is something to draw.
    import matplotlib.pyplot as plt
    
    frontier = np.array(tracker.pareto_frontier)
    
    fig, ax = plt.subplots(figsize=(10, 8))
    
    # Plot all points (history may be missing even when a frontier exists)
    if tracker.all_rewards:
        all_points = np.array(tracker.all_rewards)
        ax.scatter(all_points[:, 0], all_points[:, 1], alpha=0.3, label="All trajectories")
    
    # Plot Pareto frontier
    ax.scatter(frontier[:, 0], frontier[:, 1], color="red", s=100, marker="*", 
               label="Pareto frontier", zorder=5)
    
    # Sort frontier by the first objective so the connecting line is monotone in x
    sorted_idx = np.argsort(frontier[:, 0])
    ax.plot(frontier[sorted_idx, 0], frontier[sorted_idx, 1], "r--", alpha=0.5)
    
    ax.set_xlabel(tracker.objectives[0])
    ax.set_ylabel(tracker.objectives[1])
    ax.set_title("Pareto Frontier")
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.show()

def plot_training_metrics(trainer: RLTrainer):
    """Draw a 2x2 dashboard of loss, KL, average reward, and buffer/frontier sizes per step."""
    import matplotlib.pyplot as plt
    
    if not trainer.train_metrics:
        print("No training metrics yet")
        return
    
    history = pd.DataFrame(trainer.train_metrics)
    steps = history["step"]
    
    # Each panel: (axes index, [(column, legend label or None, extra plot kwargs)], y label, title).
    # A legend is drawn only for panels whose series carry labels.
    panels = [
        ((0, 0), [("total_loss", "Total Loss", {}), ("pg_loss", "PG Loss", {"alpha": 0.7})],
         "Loss", "Training Loss"),
        ((0, 1), [("kl_div", None, {})],
         "KL Divergence", "KL Divergence from Reference"),
        ((1, 0), [("avg_reward", None, {})],
         "Average Reward", "Average Reward per Step"),
        ((1, 1), [("buffer_size", "Buffer Size", {}), ("pareto_frontier_size", "Pareto Frontier", {})],
         "Count", "Buffer & Pareto Statistics"),
    ]
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    for position, series, ylabel, title in panels:
        ax = axes[position]
        for column, label, extra in series:
            if label is None:
                ax.plot(steps, history[column], **extra)
            else:
                ax.plot(steps, history[column], label=label, **extra)
        ax.set_xlabel("Step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if any(label is not None for _, label, _ in series):
            ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()


def analyze_trajectory_lengths(buffer: TrajectoryBuffer):
    """Plot a histogram and box plot of buffered trajectory lengths, then print summary stats."""
    import matplotlib.pyplot as plt
    
    if not buffer.trajectories:
        print("No trajectories in buffer")
        return
    
    step_counts = [len(traj.trajectory) for traj in buffer.trajectories]
    mean_len = np.mean(step_counts)
    
    fig, (hist_ax, box_ax) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Left: histogram with the mean marked
    hist_ax.hist(step_counts, bins=20, edgecolor="black")
    hist_ax.axvline(mean_len, color="red", linestyle="--", label=f"Mean: {mean_len:.1f}")
    hist_ax.set_xlabel("Trajectory Length (steps)")
    hist_ax.set_ylabel("Count")
    hist_ax.set_title("Distribution of Trajectory Lengths")
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)
    
    # Right: box plot
    box_ax.boxplot(step_counts)
    box_ax.set_ylabel("Steps")
    box_ax.set_title("Trajectory Length Statistics")
    box_ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    summary_lines = [
        "Length statistics:",
        f"  Mean: {mean_len:.2f}",
        f"  Std: {np.std(step_counts):.2f}",
        f"  Min: {min(step_counts)}",
        f"  Max: {max(step_counts)}",
        f"  Median: {np.median(step_counts):.1f}",
    ]
    for line in summary_lines:
        print(line)

print("✓ Visualization utilities defined")

In [None]:
# =============================================================================
# EXAMPLE: How to use the RL harness (FULL PARAMETER UPDATES)
# =============================================================================
# This shows how to plug everything together.
# You customize: rewards, termination, prompts

def example_setup():
    """
    Example setup for a Pareto RL experiment with full parameter updates (no LoRA).

    Builds an inference backend (rollout generation), a training backend
    (gradient updates), a reference-model manager (KL penalty) and a
    two-objective ``ParetoRewardTracker``. It then wires the termination and
    per-step reward callbacks into ``trainer.rollout_manager`` and returns
    ``(trainer, get_prompts)``.

    Memory guide for a 7B model in bfloat16 (~14GB per model copy):
    - Inference backend (Option A below, a separate HF copy): ~14GB
    - Training backend: ~28GB (weights + gradients)
    - Reference model: ~14GB
    - Total: ~56GB, plus optimizer state and activations on top.
      NOTE(review): optimizer state size depends on the optimizer RLTrainer
      builds (e.g. AdamW with fp32 moments adds ~8 bytes/param) - confirm.

    Use gradient_checkpointing=True to reduce activation memory at the cost of
    speed, or use a smaller model (1B-3B) for experiments.
    """

    # MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # Requires 40GB+ VRAM
    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Good for testing
    # MODEL_NAME = "facebook/opt-1.3b"  # Another small option

    # Shared markers/constants used by the reward and termination callbacks below
    ANSWER_MARKER = "ANSWER:"
    DONE_MARKER = "DONE"
    # Step count at which the efficiency reward reaches 0.
    # NOTE(review): independent of RLTrainer(max_steps=20) below - confirm whether they should agree.
    EFFICIENCY_HORIZON_STEPS = 50
    STEP_PENALTY = -0.01

    # 1. Setup inference backend (for fast rollout generation)
    # Option A: Same HuggingFace model (simplest, slower).
    # NOTE(review): this is a separate model instance from `training`; whether its
    # weights are synced after updates depends on RLTrainer - confirm.
    inference = HuggingFaceBackend(
        model_name=MODEL_NAME,
        dtype="bfloat16",
        gradient_checkpointing=False,
    )

    # Option B: vLLM (faster, uncomment if installed)
    # inference = VLLMBackend(
    #     model_name=MODEL_NAME,
    #     tensor_parallel_size=1,
    #     gpu_memory_utilization=0.3,  # Leave room for training model
    # )

    # 2. Setup training backend (for gradient computation) - FULL PARAMS
    training = HuggingFaceBackend(
        model_name=MODEL_NAME,
        dtype="bfloat16",
        gradient_checkpointing=True,  # Enable for memory efficiency
    )

    # 3. Setup reference model (for KL divergence)
    ref_manager = ReferenceModelManager(
        model_name=MODEL_NAME,
    )

    # 4. Setup Pareto reward tracker with 2 objectives
    objectives = ["task_completion", "efficiency"]
    pareto_tracker = ParetoRewardTracker(objectives)

    # 5. YOUR REWARD FUNCTIONS - customize these!
    def task_completion_reward(state: RolloutState) -> float:
        """Example: 1.0 if the last response contains the answer marker, else 0.0."""
        final_response = state.trajectory[-1].response.text if state.trajectory else ""
        return 1.0 if ANSWER_MARKER in final_response else 0.0

    def efficiency_reward(state: RolloutState) -> float:
        """Example: linearly decays from 1.0 at step 0 to 0.0 at the horizon (clamped)."""
        efficiency = 1.0 - (state.step_count / EFFICIENCY_HORIZON_STEPS)
        return max(0.0, efficiency)

    pareto_tracker.register_reward_function("task_completion", task_completion_reward)
    pareto_tracker.register_reward_function("efficiency", efficiency_reward)

    # 6. Create trainer
    trainer = RLTrainer(
        inference_backend=inference,
        training_backend=training,
        ref_model_manager=ref_manager,
        pareto_tracker=pareto_tracker,
        num_objectives=len(objectives),
        max_steps=20,
        batch_size=2,  # Smaller batch for memory
        learning_rate=1e-6,  # Lower LR for full model updates
        kl_coef=0.1,
        use_wandb=False,
    )

    # 7. YOUR PROMPT GENERATOR
    def get_prompts() -> List[str]:
        problems = [
            "Solve step by step: What is 15 + 27?",
            "Solve step by step: What is 8 * 9?",
        ]
        return problems

    # 8. YOUR TERMINATION CONDITION
    def should_terminate(state: RolloutState) -> bool:
        """Stop once the latest response contains the answer or done marker."""
        if state.trajectory:
            last_response = state.trajectory[-1].response.text
            return ANSWER_MARKER in last_response or DONE_MARKER in last_response
        return False

    trainer.rollout_manager.should_terminate_fn = should_terminate

    # 9. YOUR STEP REWARD FUNCTION
    def step_reward(state: RolloutState, step: StepData) -> np.ndarray:
        """Per-step reward vector, one entry per objective (in `objectives` order).

        Built from `objectives` so its length always matches num_objectives;
        currently a small per-step penalty on "task_completion" only.
        """
        rewards = np.zeros(len(objectives))
        rewards[objectives.index("task_completion")] = STEP_PENALTY
        return rewards

    trainer.rollout_manager.get_reward_fn = step_reward

    return trainer, get_prompts


def example_svrl_setup():
    """
    Example setup for SVRL (Self-Verifying RL) experiment.
    The model can query the environment for verification at a cost.

    Returns:
        An ``SVRLEnvironment`` whose verification handler evaluates plain
        arithmetic expressions ("Result: <value>" or "Error: <message>") and
        whose cost is ``1.0 + 0.1 * state.step_count``.
    """
    import ast
    import operator

    # SECURITY: queries are generated by the policy being trained, i.e. untrusted.
    # The previous eval()-based handler let the model execute arbitrary Python
    # (e.g. "__import__('os').system(...)"). Only numeric literals and basic
    # arithmetic operators are accepted here.
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Cap on the estimated bit size of int ** int results, so a query like
    # "9 ** 9 ** 9" cannot stall the rollout with runaway big-int arithmetic.
    max_pow_bits = 4096

    def _evaluate(node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        # type() check (not isinstance) so bools and other int subclasses are rejected
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = _evaluate(node.left)
            right = _evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and abs(left) > 1
                and right > 0
                and abs(left).bit_length() * right > max_pow_bits
            ):
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    def verify_math(query: str) -> str:
        """Evaluate an arithmetic expression; errors are reported, never raised."""
        try:
            tree = ast.parse(query.strip(), mode="eval")
            result = _evaluate(tree.body)
            return f"Result: {result}"
        except Exception as e:
            return f"Error: {str(e)}"

    def verification_cost(query: str, state: RolloutState) -> float:
        """Cost of one verification call: base 1.0 plus 0.1 per rollout step taken."""
        base_cost = 1.0
        step_penalty = state.step_count * 0.1
        return base_cost + step_penalty

    svrl_env = SVRLEnvironment(
        cost_function=verification_cost,
        verification_handler=verify_math,
    )

    return svrl_env


def example_dynamic_ref_model():
    """
    Example: Change reference model mid-training.
    Useful for curriculum learning or adaptive KL penalties.

    Returns:
        ``update_ref_model_periodically(trainer, interval=100)``, a callback meant
        to be invoked once per training step.
    """
    def update_ref_model_periodically(trainer: RLTrainer, interval: int = 100):
        """
        Snapshot the training model as the new reference every ``interval`` steps.

        This lets the policy drift without a strong KL penalty, since the
        reference is periodically re-anchored to the current training weights.

        Returns:
            The id returned by ``trainer.ref_model_manager.snapshot_current_training_model``
            when a snapshot is taken (``trainer.step_count`` is positive and a
            multiple of ``interval``); otherwise None.

        Raises:
            ValueError: if ``interval`` is not positive (a zero interval would
                otherwise fail with ZeroDivisionError, and a negative one would
                trigger at surprising steps).
        """
        if interval <= 0:
            raise ValueError(f"interval must be positive, got {interval}")

        step = trainer.step_count
        if step <= 0 or step % interval != 0:
            return None

        snapshot_id = trainer.ref_model_manager.snapshot_current_training_model(
            trainer.training_backend.model
        )
        print(f"✓ Updated reference model at step {trainer.step_count}")
        return snapshot_id

    return update_ref_model_periodically


def example_no_kl_training():
    """
    Example: Train without KL divergence penalty.

    Returns a subclass of ``RLTrainer`` whose ``compute_policy_gradient_loss``
    is a plain REINFORCE-style loss, ``-reward * sum(log p(response tokens))``
    averaged over the batch, with no KL term. The simpler alternative is to keep
    ``RLTrainer`` and set ``kl_coef=0``.
    """
    class NoKLTrainer(RLTrainer):
        def compute_policy_gradient_loss(self, batch, weights=None):
            """Policy-gradient loss over ``batch`` with the KL term skipped entirely.

            Args:
                batch: object with parallel sequences ``prompts``, ``responses``
                    and ``rewards`` (each reward is multiplied into the loss as a
                    scalar).
                weights: accepted for signature compatibility; not used here.
                    NOTE(review): confirm whether the base class expects it applied.

            Returns:
                ``(avg_loss, metrics)``: the batch-mean loss tensor (with grad) and a
                dict with ``pg_loss``, ``kl_div`` (always 0.0), ``total_loss`` and
                ``avg_reward``.

            Raises:
                ValueError: if the batch is empty (the mean would otherwise be a
                    NaN tensor without grad, failing later in ``backward()``).
            """
            num_samples = len(batch.prompts)
            if num_samples == 0:
                raise ValueError("compute_policy_gradient_loss received an empty batch")

            backend = self.training_backend
            tokenizer = backend.tokenizer
            device = backend.device

            backend.model.train()
            total_loss = torch.tensor(0.0, device=device)

            for i in range(num_samples):
                prompt = batch.prompts[i]
                response = batch.responses[i]
                reward = batch.rewards[i]

                # Assumes encode(prompt) and tokenizer(prompt + response) add the same
                # special tokens and split at the prompt boundary - TODO confirm for
                # the tokenizer in use; a boundary merge shifts the response slice.
                prompt_len = len(tokenizer.encode(prompt))
                encoding = tokenizer(prompt + response, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in encoding.items()}

                logits = backend.model(**inputs).logits

                # Logits at position t predict token t+1, so response tokens
                # [prompt_len:] are scored by positions [prompt_len - 1:-1].
                response_logits = logits[0, prompt_len - 1:-1]
                response_ids = inputs["input_ids"][0, prompt_len:]

                log_probs = F.log_softmax(response_logits, dim=-1)
                token_log_probs = log_probs.gather(1, response_ids.unsqueeze(1)).squeeze(1)

                pg_loss = -reward * token_log_probs.sum()
                total_loss = total_loss + pg_loss

            avg_loss = total_loss / num_samples

            metrics = {
                "pg_loss": avg_loss.item(),
                "kl_div": 0.0,  # No KL
                "total_loss": avg_loss.item(),
                "avg_reward": float(np.mean(batch.rewards)),
            }

            return avg_loss, metrics

    return NoKLTrainer


# =============================================================================
# RUN TRAINING (set RUN_TRAINING = True to execute)
# =============================================================================
# Guarded by a flag rather than kept inside a string literal, so the example
# stays syntax-checked and can be enabled by flipping a single constant.
RUN_TRAINING = False

if RUN_TRAINING:
    # Setup
    trainer, prompt_generator = example_setup()

    # Optional: Disable KL penalty
    # trainer.kl_coef = 0.0

    # Optional: Dynamic reference model updates
    # update_ref_fn = example_dynamic_ref_model()

    # Train!
    for _ in range(10):  # Short run
        prompts = prompt_generator()
        metrics = trainer.train_step(prompts)
        print(f"Step {metrics['step']}: Loss={metrics['total_loss']:.4f}, Reward={metrics['avg_reward']:.4f}")

        # Optional: Update ref model
        # update_ref_fn(trainer, interval=5)

    # Check results
    print("\nPareto Frontier:", trainer.pareto_tracker.get_frontier_statistics())
    print("Buffer Stats:", trainer.buffer.get_statistics())

print("✓ Example setup defined - ready for full parameter updates!")
print("\nYour TODO:")
print("1. Choose your model (MODEL_NAME) - TinyLlama for testing, larger for real experiments")
print("2. Implement YOUR reward functions")
print("3. Implement YOUR prompt generator")
print("4. Implement YOUR termination conditions")
print("5. Run training loop")
# Estimates assume ~8 bytes/param: inference, training and reference copies in
# bfloat16 (2 bytes each) plus bfloat16 gradients (2 bytes); optimizer state extra.
print("\nMemory requirements (bfloat16, excluding optimizer state):")
print("  - 1B model: ~8GB total")
print("  - 3B model: ~24GB total")
print("  - 7B model: ~56GB total (use gradient_checkpointing)")
print("\nFor experiments without KL penalty: set trainer.kl_coef = 0.0")