In [None]:
!nvidia-smi -L || true

import sys
print("Python:", sys.version)

# Install required packages
try:
    %pip install -q transformers>=4.51.3 accelerate>=1.4.0 peft>=0.14.0 \
                     datasets>=3.3.2 torch wandb huggingface_hub \
                     sentencepiece protobuf tqdm matplotlib pandas
except Exception:
    !pip install -q transformers>=4.51.3 accelerate>=1.4.0 peft>=0.14.0 \
                     datasets>=3.3.2 torch wandb huggingface_hub \
                     sentencepiece protobuf tqdm matplotlib pandas

import os, random, time, json, platform
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import dataclass, field
# NOTE(review): AsyncRolloutManager uses a bare `Queue`; the stdlib thread-safe
# queue is assumed here. Switch to `asyncio.Queue` if consumers need to await it.
from queue import Queue
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from IPython.display import display
from tqdm.auto import tqdm

# Resolve the compute device up front; the report below is keyed off it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Summarise the runtime so logs show which stack the run used.
print(
    "\n=== Environment ===",
    f"PyTorch: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if DEVICE == "cuda":
    print(
        f"CUDA device: {torch.cuda.get_device_name(0)}",
        f"CUDA version: {torch.version.cuda}",
        sep="\n",
    )

# The rest of the notebook is written for GPU execution; stop early otherwise.
assert DEVICE == "cuda", "Please connect a GPU for RL training."

In [None]:
import os

# Credentials are taken from the environment (shell export, .env loader, Colab
# secrets, ...) instead of being pasted into the notebook, so that a saved or
# shared copy of this file never contains a secret.

# WandB API key - get from https://wandb.ai/authorize
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "").strip()
if WANDB_API_KEY:
    # Re-export the normalised (whitespace-stripped) value for downstream tools.
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY

# HuggingFace token - get from https://huggingface.co/settings/tokens
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()
if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN

In [None]:
import wandb
from huggingface_hub import login

# WandB login
# `wandb.login()` signals "no credentials configured" through a falsy return
# value rather than an exception, so the result must be checked before
# reporting success.
try:
    if wandb.login():
        print("✓ WandB login successful")
    else:
        print("⚠ WandB login failed: no API key configured")
        print("Training will continue without WandB logging")
except Exception as e:
    # Deliberately best-effort: logging is optional for the rest of the notebook.
    print(f"⚠ WandB login failed: {e}")
    print("Training will continue without WandB logging")

# HuggingFace login
try:
    if os.environ.get('HF_TOKEN'):
        login(token=os.environ['HF_TOKEN'])
        print("✓ HuggingFace login successful")
    else:
        # Make the skipped login visible instead of passing silently.
        print("⚠ HF_TOKEN not set; skipping HuggingFace login")
except Exception as e:
    print(f"⚠ HuggingFace login failed: {e}")

# Async Rollout Manager with Dynamic Batching
Because rollouts vary dramatically in length (6-50 steps), use an async architecture that batches together the rollouts currently at the same step:

In [None]:
class AsyncRolloutManager:
    """Tracks in-flight rollouts and advances them in per-step batches.

    NOTE(review): `batch_generate` and `process_results` are called by
    `step_all` but are not defined in this class as shown; they must be
    supplied (e.g. by a subclass) before `step_all` can run.
    """

    def __init__(self, max_concurrent_rollouts=32):
        """Create an empty manager.

        Args:
            max_concurrent_rollouts: Intended cap on simultaneously active
                rollouts. It is stored for callers to consult (it was
                previously dropped); nothing in this class enforces it yet.
        """
        self.max_concurrent_rollouts = max_concurrent_rollouts
        self.active_rollouts = {}  # id -> RolloutState
        self.completed_queue = Queue()
        self.step_buffers = defaultdict(list)  # group by step count

    async def step_all(self):
        """Advance every non-terminated rollout by one generation step.

        Rollouts are bucketed by their current `step_count`, and each bucket is
        sent through `batch_generate` as one batch. Buckets are processed one
        after another, in order of first appearance in `active_rollouts`.
        """
        # The grouping is built completely before any result is processed, so
        # changes made by `process_results` cannot disturb this iteration.
        step_groups = defaultdict(list)
        for rid, state in self.active_rollouts.items():
            if not state.terminated:
                step_groups[state.step_count].append(rid)

        # Process each group with vllm/sglang batch inference
        for rollout_ids in step_groups.values():
            batch_results = await self.batch_generate(rollout_ids)
            self.process_results(rollout_ids, batch_results)

# Flexible Rollout State with Checkpointing
Track everything per rollout, with the ability to branch/restore:


In [None]:
@dataclass
class RolloutState:
    """Mutable per-rollout record: prompt, trajectory, reward and budget use.

    Field order is unchanged, so existing positional construction still maps to
    the same fields; `max_budget` is appended at the end.
    """

    prompt: str
    # Quoted forward reference: `StepData` does not have to be defined before
    # this class is created.
    trajectory: List["StepData"]
    step_count: int = 0
    terminated: bool = False
    termination_reason: Optional[str] = None
    # Multi-dimensional for Pareto. A default is required because this field
    # follows defaulted fields (a bare annotation here makes `@dataclass` raise
    # TypeError). The default is an empty vector; pass an array sized to the
    # number of objectives when rewards are tracked.
    cumulative_reward: np.ndarray = field(default_factory=lambda: np.zeros(0))
    verification_budget_spent: float = 0.0
    ref_model_snapshot: Optional[str] = None  # For dynamic ref changes
    # Upper bound checked by `can_verify`; infinite means "no cap configured".
    max_budget: float = float("inf")

    def checkpoint(self) -> 'RolloutState':
        """Return an independent deep copy of this state for branch/restore."""
        return deepcopy(self)

    def can_verify(self, cost: float) -> bool:
        """Return True if spending `cost` more keeps usage within `max_budget`."""
        return self.verification_budget_spent + cost <= self.max_budget

# Multi-Backend Inference with Unified Interface

In [None]:
class InferenceBackend(ABC):
    """Abstract interface that every inference backend must implement.

    Subclasses have to override both `generate` and `get_logits`; until they
    do, instantiating them raises TypeError.
    """

    @abstractmethod
    async def generate(self, prompts, **kwargs) -> List["GenerationResult"]:
        """Generate one result per prompt.

        The return annotation is a quoted forward reference so that this class
        can be defined before (or without) `GenerationResult` being in scope.

        Args:
            prompts: Prompts to generate from.
            **kwargs: Backend-specific generation options.
        """

    @abstractmethod
    def get_logits(self, prompts, responses) -> torch.Tensor:
        """Return the logits for `responses` given `prompts`."""

class VLLMBackend(InferenceBackend):
    """Inference backend that stores an `AsyncLLMEngine` in `self.engine`.

    NOTE(review): this class does not override the abstract `generate` /
    `get_logits` methods of `InferenceBackend`, so it cannot be instantiated
    as written.
    """
    def __init__(self, model_name, tensor_parallel_size=1):
        # NOTE(review): `...` is a literal Ellipsis placeholder; neither
        # `model_name` nor `tensor_parallel_size` is forwarded to the engine
        # yet. `AsyncLLMEngine` is also not imported in the visible notebook.
        self.engine = AsyncLLMEngine.from_engine_args(...)
        
class SGLangBackend(InferenceBackend):
    """Inference backend that stores an `sgl.Runtime` in `self.runtime`.

    NOTE(review): this class does not override the abstract `generate` /
    `get_logits` methods of `InferenceBackend`, so it cannot be instantiated
    as written. `sgl` is not imported in the visible notebook.
    """
    def __init__(self, model_name):
        # `model_name` is passed to the runtime as its `model_path`.
        self.runtime = sgl.Runtime(model_path=model_name)

# Dynamic Reference Model Manager
Handle changing reference models mid-training:

In [None]:

class ReferenceModelManager:
    """Keeps snapshots of the reference model and computes KL against them.

    NOTE(review): `get_model_state`, `get_logits_from_snapshot` and the
    module-level `kl_divergence` are called here but are not defined in the
    visible notebook; they must be provided before these methods can run.
    """

    def __init__(self):
        self.checkpoints = {}  # snapshot id ("ref_<timestamp>") -> model state
        self.current_ref = None

    def snapshot_current(self) -> str:
        """Store the current reference-model state and return its snapshot id."""
        snapshot_id = f"ref_{time.time()}"
        self.checkpoints[snapshot_id] = self.get_model_state()
        return snapshot_id

    def compute_kl(self, prompts, responses, ref_snapshot=None, train_logits=None):
        """Return `kl_divergence(train_logits, ref_logits)` for the given batch.

        Args:
            prompts: Prompts the responses were generated from.
            responses: Responses to score.
            ref_snapshot: Optional snapshot id. When it names a stored
                snapshot, that snapshot supplies the reference logits;
                otherwise (including an unknown id) `current_ref` is used.
            train_logits: Logits of the policy being trained for the same
                prompts/responses. Previously read from an undefined free
                variable; it is now an explicit argument.

        Raises:
            ValueError: If `train_logits` is not supplied.
        """
        if train_logits is None:
            raise ValueError("compute_kl requires train_logits for the policy being trained")

        # The logit getters receive the actual batch; the original passed a
        # literal `...` placeholder in its place.
        if ref_snapshot and ref_snapshot in self.checkpoints:
            # Use specific snapshot
            ref_logits = self.get_logits_from_snapshot(ref_snapshot, prompts, responses)
        else:
            ref_logits = self.current_ref.get_logits(prompts, responses)
        return kl_divergence(train_logits, ref_logits)

# Multi-Objective Reward System with Pareto Tracking

In [None]:
class ParetoRewardTracker:
    """Computes per-objective reward vectors and maintains a Pareto frontier."""

    def __init__(self, objectives: List[str], reward_functions=None):
        """Create a tracker.

        Args:
            objectives: Objective names; their order fixes the reward-vector order.
            reward_functions: Optional mapping of objective name to a callable
                taking the rollout state and returning a scalar. The attribute
                was read by `compute_rewards` but never initialised; it now
                defaults to an empty dict and can still be assigned later.
        """
        self.objectives = objectives
        self.reward_functions = dict(reward_functions) if reward_functions else {}
        self.pareto_frontier = []

    def compute_rewards(self, state: "RolloutState") -> np.ndarray:
        """Return the reward vector for `state`, one entry per objective.

        Raises:
            KeyError: If an objective has no registered reward function.
        """
        rewards = np.zeros(len(self.objectives))
        for i, obj in enumerate(self.objectives):
            rewards[i] = self.reward_functions[obj](state)
        return rewards

    def compute_frontier(self, reward_vectors):
        """Return the non-dominated members of `reward_vectors`, in input order.

        A vector is dominated when another vector is at least as large in every
        objective and strictly larger in at least one. This assumes every
        objective is "higher is better" -- TODO confirm. Exact duplicates do
        not dominate each other, so all copies are kept.
        """
        points = [np.asarray(v, dtype=float) for v in reward_vectors]
        frontier = []
        for i, candidate in enumerate(points):
            dominated = any(
                j != i and bool(np.all(other >= candidate)) and bool(np.any(other > candidate))
                for j, other in enumerate(points)
            )
            if not dominated:
                frontier.append(reward_vectors[i])
        return frontier

    def update_pareto_frontier(self, reward_vector):
        """Add `reward_vector` and drop every vector that is now dominated."""
        # Check if dominates or is dominated
        self.pareto_frontier = self.compute_frontier(
            self.pareto_frontier + [reward_vector]
        )

# SVRL Environment Interface

In [None]:
class SVRLEnvironment:
    """Environment in which verification queries are paid for from a budget."""

    def __init__(self, verification_cost_fn):
        # Callable invoked as cost_fn(query, state) to price a query.
        self.cost_fn = verification_cost_fn

    async def query_environment(self, query: str, state: RolloutState):
        """Run a paid verification query if the rollout can afford it.

        Returns a `(result, cost)` pair. When `state.can_verify(cost)` is
        false, nothing is charged and `result` is None. Otherwise the cost is
        added to `state.verification_budget_spent` before the verification is
        awaited.
        """
        price = self.cost_fn(query, state)

        # Over budget: report the quoted price without charging it.
        if not state.can_verify(price):
            return None, price

        state.verification_budget_spent += price
        outcome = await self.execute_verification(query)
        return outcome, price

# Efficient Trajectory Buffer for Variable-Length Rollouts

In [None]:
class TrajectoryBuffer:
    """Bounded store of completed rollouts with weighted batch sampling."""

    def __init__(self, max_buffer_size=10000):
        # The deque discards its oldest entry once `max_buffer_size` is reached.
        self.complete_trajectories = deque(maxlen=max_buffer_size)
        self.step_statistics = defaultdict(list)  # step -> rewards

    def add_trajectory(self, state: "RolloutState"):
        """Store a finished rollout and record its reward under its step count."""
        self.complete_trajectories.append(state)
        # Track statistics by trajectory length
        self.step_statistics[state.step_count].append(
            state.cumulative_reward
        )

    def sample_batch(self, strategy='completion_weighted', batch_size=32):
        """Sample `batch_size` stored rollouts, with replacement.

        Args:
            strategy: 'completion_weighted' weights each rollout by
                `len(rollout.trajectory)`; 'pareto_weighted' weights it by
                `self.is_pareto(rollout)`.
                NOTE(review): `is_pareto` is not defined in this class as
                shown and must be provided for that strategy.
            batch_size: Number of samples to draw. This used to be read from
                an undefined free variable; it is now an explicit argument.

        Returns:
            A list of sampled rollouts; empty when the buffer is empty.

        Raises:
            ValueError: If `strategy` is not one of the supported names.
        """
        # Snapshot into a list: random.choices indexes its population, and
        # indexing into the middle of a deque is O(n).
        population = list(self.complete_trajectories)

        if strategy == 'completion_weighted':
            # Prefer longer, successful trajectories
            weights = [len(t.trajectory) for t in population]
        elif strategy == 'pareto_weighted':
            # Sample from Pareto frontier more often
            weights = [self.is_pareto(t) for t in population]
        else:
            raise ValueError(f"Unknown sampling strategy: {strategy!r}")

        if not population:
            return []
        if not any(weights):
            # random.choices rejects an all-zero weight vector; fall back to
            # uniform sampling instead of crashing.
            weights = None
        return random.choices(population, weights=weights, k=batch_size)

# Training Loop with Early Stopping

In [None]:
async def training_loop():
    """Alternate between collecting rollouts and updating the model.

    Each outer iteration starts `num_parallel_rollouts` rollouts, steps them
    until none is active, then runs one optimiser update on a sampled batch.

    NOTE(review): `converged`, `num_parallel_rollouts`, `should_terminate`,
    `get_termination_reason`, `trajectory_buffer`, `compute_loss` and
    `optimizer` are module-level names that are not defined in the visible
    notebook. Nothing in this function changes `converged`, so the loop only
    ends if something else does. The manager's `start_rollout`, `has_active`
    and `complete_rollout` are likewise not visible here.
    """
    rollout_manager = AsyncRolloutManager()

    while not converged:
        # Start new rollouts
        for _ in range(num_parallel_rollouts):
            rollout_manager.start_rollout()

        # Step all active rollouts
        while rollout_manager.has_active():
            await rollout_manager.step_all()

            # Check termination conditions per rollout. Iterate over a snapshot:
            # completing a rollout may remove it from `active_rollouts`, and
            # resizing a dict while iterating over it raises RuntimeError.
            for rid, state in list(rollout_manager.active_rollouts.items()):
                if should_terminate(state):
                    state.terminated = True
                    state.termination_reason = get_termination_reason(state)
                    rollout_manager.complete_rollout(rid)

        # Update model with completed trajectories
        batch = trajectory_buffer.sample_batch()
        loss = compute_loss(batch)
        # Without a backward pass `optimizer.step()` has no fresh gradients to
        # apply; clear stale gradients, backpropagate, then step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Modular Configuration System

## Implementation details

- Use padding masks when batching sequences of different lengths.
- Store trajectories in a format that allows fast slicing (e.g., Apache Arrow).
- Implement priority sampling based on trajectory quality/length.

In [None]:
def train_step_pareto(batch):
    """Scalarise the per-objective losses with a random Dirichlet weighting.

    NOTE(review): `num_objectives` and `losses` are module-level names that
    are not defined in the visible notebook, and `batch` is not used yet --
    presumably `losses` should be computed from it; confirm.

    Args:
        batch: Training batch (currently unused).

    Returns:
        `(loss, metrics)`: the weighted sum of the objective losses, and a
        dict mapping `"obj_<i>"` to each objective's scalar loss. Both were
        previously computed and then discarded.
    """
    # Sample random weight vector for this step
    weights = np.random.dirichlet(np.ones(num_objectives))

    # Compute weighted sum of objectives. Weights are converted to plain
    # floats so that the product keeps the type of the loss operand.
    loss = sum(float(w) * obj_loss for w, obj_loss in zip(weights, losses))

    # Track all objectives separately
    metrics = {f"obj_{i}": obj_loss.item() for i, obj_loss in enumerate(losses)}
    return loss, metrics

In [None]:
async def svrl_generate(prompt, state):
    """Generate a response, optionally paying for a verification query first.

    The model is first asked whether it wants to verify. If its decision
    contains "VERIFY:", the extracted query is sent to the environment and,
    when a result comes back, generation continues from a prompt enriched
    with that result. In every other case the plain prompt is used.

    NOTE(review): `model`, `env` and `extract_verification_query` are
    module-level names that are not defined in the visible notebook. `env` is
    assumed to expose `cost_fn(query, state)` and
    `query_environment(query, state)` as `SVRLEnvironment` does.

    Args:
        prompt: Prompt to answer.
        state: Rollout state whose verification budget is charged.
    """
    # The quoted cost must exist before the prompt is built. The original read
    # a local `cost` that was only assigned further down, which raised
    # UnboundLocalError on every call. The real query is not known yet, so the
    # prompt itself is priced as an estimate.
    estimated_cost = env.cost_fn(prompt, state)

    # Model first decides whether to verify
    verify_prompt = f"{prompt}\n\nShould I verify? (cost: {estimated_cost})"
    decision = await model.generate(verify_prompt)

    if "VERIFY:" in decision:
        query = extract_verification_query(decision)
        info, _charged_cost = await env.query_environment(query, state)
        if info is not None:
            # Continue generation with new info
            enriched_prompt = f"{prompt}\nVerification result: {info}"
            return await model.generate(enriched_prompt)
        # A None result means the query was over budget; do not splice
        # "Verification result: None" into the prompt.

    return await model.generate(prompt)