# Multi-Turn RL Evaluation

Framework for evaluating multi-turn RL rollouts in which:
- Each task runs for up to ~50 steps (some rollouts terminate early)
- Every step of every rollout is recorded
- Rewards are assigned once the rollouts have finished
- The recorded data is kept for later backpropagation

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Dict, Union, Optional, Any, Tuple
from copy import deepcopy
import json
import pickle
from dataclasses import dataclass, field, asdict
from datetime import datetime

In [None]:
class LLMGenerator:
    """Batched chat generation on top of a Hugging Face causal language model.

    Conversations are rendered with the tokenizer's chat template, generated
    in one left-padded batch, and each completion can optionally be split
    into its "thinking" and "content" parts at the ``</think>`` token.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model for ``model_name`` and set up padding."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto", device_map="auto")
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"  # decoder-only: left padding
        # NOTE(review): hard-coded token ids; presumably specific to the
        # Qwen3 vocabulary -- confirm before using a different model.
        self.think_start_id = 151667  # <think>
        self.think_end_id = 151668    # </think>

    @staticmethod
    def _normalize_prompts(prompts) -> Tuple[List[Any], bool]:
        """Turn any accepted prompt form into a list of conversations.

        Returns ``(conversations, single_prompt)`` where ``single_prompt`` is
        True when the caller passed exactly one prompt (a bare string or one
        conversation, i.e. a list of message dicts) rather than a batch.
        Bare strings are wrapped as a single user message.
        """
        if isinstance(prompts, str):
            return [[{"role": "user", "content": prompts}]], True
        if prompts and isinstance(prompts[0], dict):
            # A list of message dicts is one conversation, not a batch.
            return [prompts], True
        conversations = [
            [{"role": "user", "content": p}] if isinstance(p, str) else p
            for p in prompts
        ]
        return conversations, False

    def generate(self,
                 prompts: Union[str, List[Any]],
                 max_new_tokens: int = 512,
                 temperature: float = 0.6,
                 num_return_sequences: int = 1,
                 enable_thinking: bool = True,
                 return_thinking: bool = True,
                 **kwargs
                 ) -> Union[str, List[Any]]:
        """Generate completions for one prompt or a batch of prompts.

        Args:
            prompts: A string, one conversation (list of message dicts), or a
                batch (list) of strings/conversations.
            max_new_tokens: Maximum number of new tokens per completion.
            temperature: Sampling temperature; sampling is enabled only when
                this is greater than zero.
            num_return_sequences: Completions to produce per prompt.
            enable_thinking: Forwarded to the chat template.
            return_thinking: When True each completion is a
                ``[thinking, content]`` pair, otherwise a plain string.
            **kwargs: Extra arguments forwarded to ``model.generate``.

        Returns:
            For a single prompt, the result for that prompt; for a batch, a
            list with one result per prompt. A per-prompt result is a single
            completion when ``num_return_sequences == 1``, otherwise a list
            of completions. An empty batch yields an empty list.
        """
        conversations, single_prompt = self._normalize_prompts(prompts)
        if not conversations:
            return []

        # Apply chat template with thinking enabled/disabled
        texts = [
            self.tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=enable_thinking
            )
            for conversation in conversations
        ]

        model_inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(self.model.device)

        # Pad generation with the same token the tokenizer used for the
        # inputs, unless the caller explicitly chose another one.
        generation_kwargs = dict(kwargs)
        generation_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)

        with torch.no_grad():
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                num_return_sequences=num_return_sequences,
                **generation_kwargs
            )

        results = self._decode_batch(
            outputs,
            model_inputs.input_ids,
            num_return_sequences=num_return_sequences,
            return_thinking=return_thinking
        )

        return results[0] if single_prompt else results

    def _decode_batch(self, outputs, input_ids, num_return_sequences, return_thinking):
        """Decode generated sequences, dropping the prompt tokens.

        ``outputs`` is viewed as ``[batch, num_return_sequences, seq_len]``.
        Returns one entry per prompt: a single decoded completion when
        ``num_return_sequences == 1``, otherwise a list of them. A decoded
        completion is ``[thinking, content]`` when ``return_thinking`` is
        True and a stripped string otherwise.
        """
        batch_size = input_ids.shape[0]
        # The batch is padded to a common length, so every row's prompt
        # occupies the same number of positions; with left padding the new
        # tokens start right after that offset.
        prompt_len = input_ids.shape[1]

        # Reshape to [batch_size, num_return_sequences, seq_len]
        outputs = outputs.view(batch_size, num_return_sequences, -1)

        results = []
        for batch_idx in range(batch_size):
            batch_results = []

            for seq_idx in range(num_return_sequences):
                # Keep only the tokens produced after the prompt.
                new_tokens = outputs[batch_idx, seq_idx, prompt_len:]

                if return_thinking:
                    thinking, content = self._parse_thinking(new_tokens.tolist())
                    # Stored as a two-element list: [thinking, content]
                    batch_results.append([thinking, content])
                else:
                    content = self.tokenizer.decode(
                        new_tokens,
                        skip_special_tokens=True
                    ).strip()
                    batch_results.append(content)

            # If num_return_sequences=1, unwrap the list
            results.append(batch_results[0] if num_return_sequences == 1 else batch_results)

        return results

    def _parse_thinking(self, output_ids: List[int]) -> Tuple[str, str]:
        """Split generated token ids into (thinking, content) text.

        The split point is just after the last occurrence of
        ``self.think_end_id``; the end token itself goes to the thinking
        part. When the token is absent, thinking is empty and every token is
        treated as content. Both strings are stripped.
        """
        try:
            # Searching the reversed list finds the LAST end-of-thinking
            # token, which copes with the marker appearing more than once.
            think_end_idx = len(output_ids) - output_ids[::-1].index(self.think_end_id)
        except ValueError:
            # No </think> token found - model skipped thinking
            thinking_ids = []
            content_ids = output_ids
        else:
            thinking_ids = output_ids[:think_end_idx]
            content_ids = output_ids[think_end_idx:]

        thinking = self.tokenizer.decode(thinking_ids, skip_special_tokens=True)
        content = self.tokenizer.decode(content_ids, skip_special_tokens=True)

        return thinking.strip(), content.strip()

In [None]:
import re


class ToolUseLLM:
    """
    Tool-using LLM with Qwen's Hermes-style function calling.

    Answers two key questions:
    1. How do you know if there's a tool call? → Check for <tool_call> tags
    2. How do you execute efficiently? → Parse JSON and use **kwargs
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model; fall back to EOS as the pad token."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype="auto",
            device_map="auto"
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def has_tool_call(self, response_text: str) -> bool:
        """
        Q1: How do you know if the LLM responded with a tool call?
        A: Check if response contains both <tool_call> and </tool_call> tags

        Expected (Hermes-style) output:
        <tool_call>
        {"name": "function_name", "arguments": {"arg": "value"}}
        </tool_call>

        This is only a substring check; the payload is validated later by
        ``_parse_tool_calls``.
        """
        return '<tool_call>' in response_text and '</tool_call>' in response_text

    def execute_tool_call(
        self,
        tool_call: Dict[str, Any],
        tool_implementations: Dict[str, callable]
    ) -> str:
        """
        Q2: How do you execute the tool call with supplied parameters efficiently?
        A: Parse JSON arguments and use **kwargs unpacking

        Args:
            tool_call: Dict with ``tool_call["function"]["name"]`` and
                ``tool_call["function"]["arguments"]`` (a dict or a JSON
                string encoding one).
            tool_implementations: Maps tool names to callables.

        Returns:
            The tool result as a string (non-string results are JSON
            encoded). Failures are reported as a descriptive string rather
            than raised, so the text can be fed back to the model.

        Example:
            tool_call = {
                "function": {
                    "name": "get_weather",
                    "arguments": {"location": "SF", "unit": "celsius"}
                }
            }

            # Executes: get_weather(location="SF", unit="celsius")
        """
        tool_name = tool_call["function"]["name"]
        tool_args = tool_call["function"]["arguments"]

        # Step 1: Validate tool exists
        if tool_name not in tool_implementations:
            return f"Tool {tool_name} not found"

        # Step 2: Parse arguments if they're a JSON string
        if isinstance(tool_args, str):
            try:
                tool_args = json.loads(tool_args)
            except json.JSONDecodeError as e:
                return f"Error parsing arguments: {str(e)}"

        # ** unpacking needs a mapping; say so explicitly instead of
        # surfacing Python's generic "argument after ** must be a mapping".
        if not isinstance(tool_args, dict):
            return (
                f"Error executing {tool_name}: arguments must be a JSON object, "
                f"got {type(tool_args).__name__}"
            )

        # Step 3: Execute with **kwargs unpacking
        try:
            result = tool_implementations[tool_name](**tool_args)
            return json.dumps(result) if not isinstance(result, str) else result
        except Exception as e:
            # Deliberately broad: any tool failure becomes text for the model.
            return f"Error executing {tool_name}: {str(e)}"

    def run(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict[str, Any]],
        tool_implementations: Dict[str, callable],
        max_turns: int = 10,
        max_new_tokens: int = 1024,
        temperature: float = 0.6
    ) -> Dict[str, Any]:
        """
        Run multi-turn agent loop with tool calling.

        Args:
            messages: Initial conversation history (not mutated; a deep copy
                is extended instead)
            tools: Tool definitions in OpenAI format
            tool_implementations: Dict mapping tool names to functions
            max_turns: Maximum agent turns
            max_new_tokens: Max tokens per generation
            temperature: Sampling temperature

        Returns:
            {
                "messages": full conversation history,
                "final_response": last assistant message, or the literal
                    "Max turns reached" when the turn budget is exhausted,
                "tool_calls": list of all tool calls made,
                "turns": number of turns taken
            }
        """
        conversation = deepcopy(messages)
        all_tool_calls = []

        for turn in range(max_turns):
            # Generate response with tools in context
            prompt = self.tokenizer.apply_chat_template(
                conversation,
                tools=tools,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=temperature > 0,
                    pad_token_id=self.tokenizer.pad_token_id
                )

            # Drop the prompt tokens; keep only what the model just produced.
            response_ids = outputs[0][inputs.input_ids.shape[1]:]
            response_text = self.tokenizer.decode(response_ids, skip_special_tokens=True)

            # Q1: Check if response has tool calls
            if not self.has_tool_call(response_text):
                # No tool calls → final response
                conversation.append({
                    "role": "assistant",
                    "content": response_text
                })
                return {
                    "messages": conversation,
                    "final_response": response_text,
                    "tool_calls": all_tool_calls,
                    "turns": turn + 1
                }

            # Parse tool calls; offset the ids by the calls already made so
            # they stay unique across the whole conversation. If the tags are
            # present but nothing parses, the list is empty and the loop
            # simply moves on to the next turn.
            tool_calls = self._parse_tool_calls(
                response_text, id_offset=len(all_tool_calls)
            )

            conversation.append({
                "role": "assistant",
                "content": response_text,
                "tool_calls": tool_calls
            })

            # Q2: Execute each tool call
            for tool_call in tool_calls:
                all_tool_calls.append(tool_call)

                result_str = self.execute_tool_call(tool_call, tool_implementations)

                # Add tool result to conversation, tagged with the id of the
                # call it answers.
                conversation.append({
                    "role": "tool",
                    "name": tool_call["function"]["name"],
                    "tool_call_id": tool_call["id"],
                    "content": result_str
                })

        return {
            "messages": conversation,
            "final_response": "Max turns reached",
            "tool_calls": all_tool_calls,
            "turns": max_turns
        }

    def _parse_tool_calls(self, response_text: str, id_offset: int = 0) -> List[Dict[str, Any]]:
        """
        Parse Hermes-style tool calls from response text.

        Format: <tool_call>{"name": "...", "arguments": {...}}</tool_call>

        Args:
            response_text: Decoded model output.
            id_offset: Added to each match's position when building its
                ``call_<n>`` id, so ids can be kept unique across turns.

        Returns OpenAI-format tool call dicts. Payloads that are not valid
        JSON are skipped.
        """
        tool_calls = []

        pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
        matches = re.findall(pattern, response_text, re.DOTALL)

        for i, match in enumerate(matches):
            try:
                tool_data = json.loads(match)
            except json.JSONDecodeError:
                # Best effort: ignore malformed payloads, keep the valid ones.
                continue

            tool_calls.append({
                "id": f"call_{id_offset + i}",
                "type": "function",
                "function": {
                    "name": tool_data.get("name", ""),
                    "arguments": tool_data.get("arguments", {})
                }
            })

        return tool_calls

In [None]:
def generate_templates(prompts):
    """Wrap each prompt in a ``{"role": "user", "content": ...}`` message.

    Args:
        prompts: An iterable of prompt strings, or a single prompt string.
            A bare string is treated as one prompt rather than being
            iterated character by character.

    Returns:
        A list with one user message per prompt (a one-element list for a
        single prompt).
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    return [{"role": "user", "content": p} for p in prompts]


def batch_history(history, prompts):
    """Build one conversation per prompt, each starting from ``history``.

    Every prompt is turned into a user message (via ``generate_templates``)
    and appended to its own deep copy of ``history``, so the returned
    conversations share no state with each other or with the input.

    Returns:
        A list of conversations, one per prompt, in prompt order.
    """
    batch = []
    for message in generate_templates(prompts):
        conversation = deepcopy(history)
        conversation.append(message)
        batch.append(conversation)
    return batch


def create_batch(bh, num_gen, outputs, include_thinking=False):
    """Flatten prompts x generations into one conversation per generation.

    Args:
        bh: Batch of conversation histories, one per prompt.
        num_gen: Generations per prompt. When greater than 1, ``outputs[i]``
            is a list of generations; otherwise ``outputs[i]`` is itself the
            single generation.
        outputs: Generations, each indexable so that element ``[1]`` is the
            content (e.g. a ``[thinking, content]`` pair).
        include_thinking: When True, the "with thinking" conversations get an
            assistant message whose content is the whole generation entry;
            when False they are plain copies of the history.

    Returns:
        ``(bh_history, bh_wthinking)``: two lists of equal length
        (prompts * generations, prompt-major). ``bh_history`` holds each
        history plus an assistant message with the content only. Every
        returned conversation is an independent copy, so generations of the
        same prompt do not accumulate into, or alias, one another.
    """
    bh_history = []
    bh_wthinking = []
    for i, history in enumerate(bh):
        # Normalise the 2d (single generation) case to the 3d layout.
        generations = outputs[i] if num_gen > 1 else [outputs[i]]
        for j in range(max(num_gen, 1)):
            generation = generations[j]

            without_thinking = deepcopy(history)
            without_thinking.append({"role": "assistant", "content": generation[1]})

            with_thinking = deepcopy(history)
            if include_thinking:
                with_thinking.append({"role": "assistant", "content": generation})

            bh_history.append(without_thinking)
            bh_wthinking.append(with_thinking)
    return bh_history, bh_wthinking


def res(generator, batch, max_new_tokens=1000, num_return_sequences=1):
    """Generate for ``batch`` and build the follow-up conversation histories.

    Calls ``generator.generate`` on the batch, then ``create_batch`` with
    ``include_thinking=True``.

    Returns:
        ``(outputs, gens_with_history, gens_w_thinking)``: the raw generation
        outputs plus the two history lists produced by ``create_batch``.
    """
    generation_options = {
        "max_new_tokens": max_new_tokens,
        "num_return_sequences": num_return_sequences,
    }
    outputs = generator.generate(batch, **generation_options)
    histories = create_batch(batch, num_return_sequences, outputs, include_thinking=True)
    gens_with_history, gens_w_thinking = histories
    return outputs, gens_with_history, gens_w_thinking

In [None]:
@dataclass
class StepData:
    """Record of a single step (one prompt and one model reply) in a rollout."""
    # Index of the step, as supplied by the code that creates the record.
    step_num: int
    # The user prompt sent at this step.
    prompt: str
    # The model's reasoning text for this step.
    thinking: str
    # The model's reply text for this step.
    content: str
    # Conversation up to and including this step, assistant replies without thinking.
    history_wo_thinking: List[Dict[str, str]]
    # Conversation variant that keeps the thinking alongside the reply.
    history_w_thinking: List[Dict[str, str]]
    # Defaults to an empty list; intended to hold token ids for training.
    tokens: List[int] = field(default_factory=list)  # Token IDs for backprop


@dataclass
class RolloutData:
    """Everything recorded for one rollout: its steps, status and reward."""
    rollout_id: int
    initial_prompt: str
    steps: List[StepData] = field(default_factory=list)
    is_complete: bool = False
    terminated_early: bool = False
    termination_step: int = -1
    reward: float = 0.0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get_total_steps(self) -> int:
        """Return how many steps have been recorded so far."""
        return len(self.steps)

    def add_step(self, step_data: StepData):
        """Append one step record to this rollout."""
        self.steps.append(step_data)

    def set_reward(self, reward: float):
        """Store the rollout's reward."""
        self.reward = reward

    def mark_complete(self, early: bool = False):
        """Flag the rollout as finished.

        ``early`` is stored as ``terminated_early`` and the current number of
        recorded steps becomes ``termination_step``.
        """
        self.is_complete, self.terminated_early = True, early
        self.termination_step = len(self.steps)

    def to_dict(self):
        """Return the rollout (including nested steps) as plain dicts."""
        return asdict(self)

In [None]:
class RolloutManager:
    """Manages multiple parallel rollouts for multi-turn RL evaluation.

    Rollouts are stepped together as one generation batch; a rollout leaves
    the batch once it reaches ``max_steps`` or the termination function
    says so.
    """

    def __init__(self, generator: LLMGenerator, n_rollouts: int, max_steps: int = 50):
        """Store the generator and limits; rollouts are created later."""
        self.generator = generator
        self.n_rollouts = n_rollouts
        self.max_steps = max_steps
        self.rollouts: List[RolloutData] = []
        self.active_rollout_indices: List[int] = []

    def initialize_rollouts(self, initial_prompts: List[str]):
        """Create one rollout per initial prompt and mark them all active.

        Raises:
            ValueError: If the number of prompts differs from ``n_rollouts``.
        """
        # Explicit raise instead of assert so the check survives ``python -O``.
        if len(initial_prompts) != self.n_rollouts:
            raise ValueError("Number of prompts must match n_rollouts")

        self.rollouts = [
            RolloutData(
                rollout_id=i,
                initial_prompt=prompt,
                metadata={"created_at": datetime.now().isoformat()}
            )
            for i, prompt in enumerate(initial_prompts)
        ]
        self.active_rollout_indices = list(range(self.n_rollouts))
        print(f"Initialized {self.n_rollouts} rollouts")

    def execute_step(self, step_num: int, prompts: List[str], max_new_tokens: int = 1000):
        """Execute a single step for all active rollouts.

        Args:
            step_num: Step index recorded on the resulting ``StepData``.
            prompts: One prompt per active rollout, in the order of
                ``active_rollout_indices``.
            max_new_tokens: Max tokens to generate for this step.

        Raises:
            ValueError: If ``prompts`` does not have exactly one entry per
                active rollout.
        """
        if not self.active_rollout_indices:
            print("No active rollouts")
            return

        # Fail before the (expensive) generation call rather than with an
        # IndexError afterwards.
        if len(prompts) != len(self.active_rollout_indices):
            raise ValueError(
                f"Expected {len(self.active_rollout_indices)} prompts "
                f"(one per active rollout), got {len(prompts)}"
            )

        # Build the generation batch: each rollout's previous history
        # (without thinking) plus its new user prompt.
        batch = []
        for rollout_idx, prompt in zip(self.active_rollout_indices, prompts):
            rollout = self.rollouts[rollout_idx]
            # First step: empty history; afterwards continue from the last step.
            previous = rollout.steps[-1].history_wo_thinking if rollout.steps else []
            hist = deepcopy(previous)
            hist.append({"role": "user", "content": prompt})
            batch.append(hist)

        # Generate outputs (num_return_sequences=1 as per requirement)
        outputs, gens_with_history, gens_w_thinking = res(
            self.generator,
            batch,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1
        )

        # Store step data for each active rollout
        for i, rollout_idx in enumerate(self.active_rollout_indices):
            thinking, content = outputs[i]  # outputs[i] is [thinking, content]

            step_data = StepData(
                step_num=step_num,
                prompt=prompts[i],
                thinking=thinking,
                content=content,
                history_wo_thinking=gens_with_history[i],
                history_w_thinking=gens_w_thinking[i],
                tokens=[]  # Will be populated if needed for backprop
            )

            self.rollouts[rollout_idx].add_step(step_data)

        print(f"Step {step_num} completed for {len(self.active_rollout_indices)} rollouts")

    def check_termination_conditions(self, termination_fn=None) -> List[int]:
        """Terminate active rollouts that are done.

        A rollout is terminated (not early) when it has reached ``max_steps``
        or (early) when ``termination_fn(rollout, last_step)`` is truthy.

        Returns:
            Indices of the rollouts terminated by this call.
        """
        to_terminate = []

        for rollout_idx in self.active_rollout_indices:
            rollout = self.rollouts[rollout_idx]

            # Check max steps
            if rollout.get_total_steps() >= self.max_steps:
                rollout.mark_complete(early=False)
                to_terminate.append(rollout_idx)
                continue

            # Custom termination function (needs at least one recorded step)
            if termination_fn and rollout.steps and termination_fn(rollout, rollout.steps[-1]):
                rollout.mark_complete(early=True)
                to_terminate.append(rollout_idx)

        if to_terminate:
            # Filter in a single pass; the list object itself is kept.
            finished = set(to_terminate)
            self.active_rollout_indices[:] = [
                idx for idx in self.active_rollout_indices if idx not in finished
            ]
            print(f"Terminated {len(to_terminate)} rollouts. {len(self.active_rollout_indices)} still active")

        return to_terminate

    def run_rollouts(self, prompt_generator_fn, termination_fn=None, max_new_tokens: int = 1000):
        """
        Run all rollouts until completion.

        Args:
            prompt_generator_fn: Function that takes (rollout_data, step_num) and returns next prompt
            termination_fn: Optional function that takes (rollout_data, step_data) and returns bool
            max_new_tokens: Max tokens to generate per step
        """
        step_num = 0

        while self.active_rollout_indices and step_num < self.max_steps:
            # Generate prompts for active rollouts
            prompts = [
                prompt_generator_fn(self.rollouts[rollout_idx], step_num)
                for rollout_idx in self.active_rollout_indices
            ]

            # Execute step
            self.execute_step(step_num, prompts, max_new_tokens)

            # Check termination
            self.check_termination_conditions(termination_fn)

            step_num += 1

        # Mark any remaining active rollouts as complete, and stop treating
        # them as active so a later execute_step() does not resume them.
        for rollout_idx in self.active_rollout_indices:
            self.rollouts[rollout_idx].mark_complete(early=False)
        self.active_rollout_indices.clear()

        print(f"\nAll rollouts complete. Total steps: {step_num}")

    def assign_rewards(self, reward_fn):
        """Set each rollout's reward to ``reward_fn(rollout)``."""
        for rollout in self.rollouts:
            rollout.set_reward(reward_fn(rollout))
        print(f"Rewards assigned to {len(self.rollouts)} rollouts")

    def get_rollout_summary(self) -> Dict[str, Any]:
        """Get summary statistics of all rollouts.

        Averages and extrema are reported as 0 when there are no rollouts.
        """
        total_steps = [r.get_total_steps() for r in self.rollouts]
        early_terminations = sum(1 for r in self.rollouts if r.terminated_early)
        rewards = [r.reward for r in self.rollouts]

        return {
            "total_rollouts": len(self.rollouts),
            "avg_steps": sum(total_steps) / len(total_steps) if total_steps else 0,
            "min_steps": min(total_steps) if total_steps else 0,
            "max_steps": max(total_steps) if total_steps else 0,
            "early_terminations": early_terminations,
            "avg_reward": sum(rewards) / len(rewards) if rewards else 0,
            "step_distribution": total_steps,
            "rewards": rewards
        }

    def save_rollouts(self, filepath: str, format: str = "pickle"):
        """Save all rollout data to file.

        Args:
            filepath: Destination path.
            format: ``"pickle"`` (the rollout objects) or ``"json"``
                (their ``to_dict()`` form).

        Raises:
            ValueError: For any other ``format``.
        """
        if format == "pickle":
            with open(filepath, "wb") as f:
                pickle.dump(self.rollouts, f)
        elif format == "json":
            data = [r.to_dict() for r in self.rollouts]
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
        else:
            raise ValueError(f"Unknown format: {format}")

        print(f"Saved {len(self.rollouts)} rollouts to {filepath}")

    @staticmethod
    def load_rollouts(filepath: str, format: str = "pickle") -> List[RolloutData]:
        """Load rollout data from file.

        Both formats return ``RolloutData`` objects; JSON entries are rebuilt
        from the dict layout written by ``save_rollouts``.

        Raises:
            ValueError: For an unknown ``format``.
        """
        if format == "pickle":
            # SECURITY: unpickling executes arbitrary code from the file;
            # only load pickle files from a trusted source.
            with open(filepath, "rb") as f:
                return pickle.load(f)
        elif format == "json":
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            rollouts = []
            for entry in data:
                fields = dict(entry)
                # Nested steps come back as plain dicts; rebuild them first.
                fields["steps"] = [StepData(**step) for step in fields.get("steps", [])]
                rollouts.append(RolloutData(**fields))
            return rollouts
        else:
            raise ValueError(f"Unknown format: {format}")

In [None]:
def dummy_reward_function(rollout: RolloutData) -> float:
    """
    Placeholder reward - swap in task-specific scoring.

    Ideas for a real implementation:
    - Compare the final output with an expected answer
    - Count the steps that were correct
    - Score similarity against a reference solution

    As written: a rollout that ran to the end earns one point per step plus
    a 10-point bonus; one that terminated early earns half a point per step.
    """
    steps_taken = rollout.get_total_steps()

    if not rollout.terminated_early:
        # Full rollout: per-step credit plus completion bonus
        return 10.0 + steps_taken * 1.0

    # Early termination: reduced per-step credit, no bonus
    return 0.5 * steps_taken


def dummy_termination_check(rollout: RolloutData, last_step: StepData) -> bool:
    """
    Placeholder termination rule - swap in task-specific logic.

    Ideas for a real implementation:
    - Look for a stop phrase ("DONE", "FINAL ANSWER:", etc.)
    - Test whether the task objective has been met
    - Detect errors or invalid states

    As written: stop when the last reply contains "DONE" (case-insensitive);
    otherwise stop with 10% probability to simulate giving up.
    """
    import random

    said_done = "DONE" in last_step.content.upper()
    # `or` short-circuits, so the random draw happens only without "DONE".
    return said_done or random.random() < 0.1


def dummy_prompt_generator(rollout: RolloutData, step_num: int) -> str:
    """
    Placeholder prompt source - swap in task-specific logic.

    Ideas for a real implementation:
    - Math: pose the next problem in a sequence
    - Coding: supply the next test case or requirement
    - Reasoning: ask a follow-up question

    As written: step 0 returns the rollout's initial prompt; every later
    step returns a fixed continuation message numbered from 2 upwards.
    """
    is_first_step = step_num == 0
    return rollout.initial_prompt if is_first_step else f"Continue with step {step_num + 1}"

In [None]:
# Example usage: build the generator and rollout manager used by the cells below.

# Initialize generator (loads the tokenizer and model once)
generator = LLMGenerator("Qwen/Qwen3-0.6B")

# Setup rollouts: number of parallel rollouts and the per-rollout step cap
n_rollouts = 4
max_steps = 50

# Create initial prompts for each rollout (one entry per rollout, so the
# list length must equal n_rollouts)
initial_prompts = [
    "Solve this problem step by step: What is 15 + 27?",
    "Write a story about a robot learning to cook.",
    "Explain how photosynthesis works.",
    "Debug this code: for i in range(10) print(i)"
]

# Create manager (starts empty; initialize_rollouts() populates it)
manager = RolloutManager(generator, n_rollouts=n_rollouts, max_steps=max_steps)

In [None]:
# Initialize rollouts: one RolloutData per initial prompt, all marked active
manager.initialize_rollouts(initial_prompts)

In [None]:
# Run all rollouts to completion: each loop iteration steps every active
# rollout once (up to 512 new tokens per step), then applies the termination
# check.
manager.run_rollouts(
    prompt_generator_fn=dummy_prompt_generator,
    termination_fn=dummy_termination_check,
    max_new_tokens=512
)

In [None]:
# Assign rewards: score every rollout with the placeholder reward function
manager.assign_rewards(dummy_reward_function)

In [None]:
# Get summary statistics (step counts, early terminations, rewards) and
# print them one key per line
summary = manager.get_rollout_summary()
print("\n=== Rollout Summary ===")
for key, value in summary.items():
    print(f"{key}: {value}")

In [None]:
# Save rollouts for later analysis/backprop (pickle keeps the RolloutData
# objects as-is)
manager.save_rollouts("rollout_data.pkl", format="pickle")

In [None]:
# Inspect individual rollout: print the first rollout's status, then a
# one-line preview (first 50 characters of prompt and reply) for each step
rollout = manager.rollouts[0]
print(f"\n=== Rollout {rollout.rollout_id} ===")
print(f"Initial prompt: {rollout.initial_prompt}")
print(f"Total steps: {rollout.get_total_steps()}")
print(f"Terminated early: {rollout.terminated_early}")
print(f"Reward: {rollout.reward}")
print(f"\nSteps:")
for step in rollout.steps:
    print(f"  Step {step.step_num}: {step.prompt[:50]}... -> {step.content[:50]}...")

# ========================================
# TOOL USE / FUNCTION CALLING
# ========================================

In [None]:
# Example: Using ToolUseLLM for agentic workflows

# Define tools in OpenAI format: each entry names a function, describes it,
# and gives a JSON-schema "parameters" object. The names here must match the
# keys of the tool_implementations mapping defined further down.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "Temperature unit"
                    }
                },
                "required": ["location"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Perform a mathematical calculation",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "The mathematical expression to evaluate"
                    }
                },
                "required": ["expression"]
            }
        }
    }
]

# Implement the tools
def get_weather(location: str, unit: str = "fahrenheit") -> str:
    """Return a canned weather sentence for ``location`` (dummy tool).

    Args:
        location: City name; a few known cities have fixed temperatures and
            anything else falls back to 70°F / 21°C.
        unit: ``"celsius"`` reports Celsius; any other value reports
            Fahrenheit.

    Returns:
        A sentence such as "The weather in London, UK is 13°C and sunny".
    """
    temps = {
        "San Francisco, CA": (65, 18),
        "New York, NY": (50, 10),
        "London, UK": (55, 13)
    }
    temp_f, temp_c = temps.get(location, (70, 21))
    # Choose value and label together so the label always matches the number
    # (and an empty or unexpected unit cannot break the formatting).
    if unit == "celsius":
        temp, symbol = temp_c, "C"
    else:
        temp, symbol = temp_f, "F"
    return f"The weather in {location} is {temp}°{symbol} and sunny"

def calculate(expression: str) -> Union[int, float, str]:
    """Evaluate a plain arithmetic expression without using ``eval``.

    The expression is parsed with ``ast`` and only numeric literals, the
    binary operators ``+ - * / // % **`` and unary ``+``/``-`` are evaluated.
    Names, calls, attribute access and everything else are rejected, so
    model-supplied text cannot run arbitrary code.

    Args:
        expression: The arithmetic expression, e.g. ``"15 * 27 + 100"``.

    Returns:
        The numeric result, or a string starting with ``"Error: "`` when the
        expression is invalid, unsupported, or fails to evaluate.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node):
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        # Exact type check on purpose: bool is a subclass of int but is not
        # a number we want to accept here.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left, right = evaluate(node.left), evaluate(node.right)
            # Guard against expressions like 9**9**9 that would hang.
            if isinstance(node.op, ast.Pow) and abs(right) > 1000:
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    try:
        # strip(): ast.parse rejects leading whitespace as an indent error.
        return evaluate(ast.parse(expression.strip(), mode="eval"))
    except Exception as e:
        # Tool boundary: report every failure as text for the model.
        return f"Error: {str(e)}"

# Map tool names to implementations (keys must match the "name" fields in
# the tools list, since tool calls are dispatched by name)
tool_implementations = {
    "get_weather": get_weather,
    "calculate": calculate
}

# Initialize ToolUseLLM (loads the tokenizer and model)
tool_llm = ToolUseLLM("Qwen/Qwen3-0.6B")

In [None]:
# DEMO: Showing the two key mechanisms explicitly

# Reuse the ToolUseLLM created in the previous cell. The helpers exercised
# here only inspect text, so loading a second copy of the model would just
# cost memory and start-up time.
tool_llm_demo = tool_llm

# Example response from model
example_response = """
I'll check the weather for you.
<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "celsius"}}
</tool_call>
"""

print("=== DEMO: Tool Calling Mechanisms ===\n")

# Q1: How do you know there's a tool call?
has_call = tool_llm_demo.has_tool_call(example_response)
print(f"Q1: Does response have tool call? {has_call}")
print("    Detection: Check for '<tool_call>' tags in response text\n")

# Parse the call
tool_calls = tool_llm_demo._parse_tool_calls(example_response)
print(f"Parsed {len(tool_calls)} tool call(s):")
print(f"  {tool_calls[0]}\n")

# Q2: How do you execute efficiently?
print("Q2: Executing tool call...")
result = tool_llm_demo.execute_tool_call(
    tool_calls[0],
    {"get_weather": get_weather}  # From earlier definition
)
print("    Method: **kwargs unpacking")
print("    Execution: get_weather(location='San Francisco, CA', unit='celsius')")
print(f"    Result: {result}\n")

print("Key insight: Detection = string check, Execution = **kwargs unpacking!")

In [None]:
# Example 1: Weather query. Runs the agent loop for at most 5 turns.
result = tool_llm.run(
    messages=[
        {"role": "user", "content": "What's the weather in San Francisco?"}
    ],
    tools=tools,
    tool_implementations=tool_implementations,
    max_turns=5
)

# Report how the run went, then preview each message of the conversation
# (first 100 characters of its content).
print("=== Weather Query Example ===")
print(f"Turns: {result['turns']}")
print(f"Tool calls made: {len(result['tool_calls'])}")
print(f"\nFinal response: {result['final_response']}")
print(f"\nFull conversation:")
for msg in result['messages']:
    role = msg['role']
    content = msg.get('content', '')
    print(f"{role}: {content[:100]}...")

In [None]:
# Example 2: Multi-step calculation. The request needs two dependent
# operations, so the loop is allowed up to 10 turns.
result = tool_llm.run(
    messages=[
        {"role": "user", "content": "Calculate 15 * 27, then add 100 to the result"}
    ],
    tools=tools,
    tool_implementations=tool_implementations,
    max_turns=10
)

# List every tool call that was made (name and arguments), then the answer.
print("\n=== Multi-step Calculation Example ===")
print(f"Turns: {result['turns']}")
print(f"Tool calls made: {len(result['tool_calls'])}")
print(f"\nTool calls:")
for i, tc in enumerate(result['tool_calls']):
    print(f"  {i+1}. {tc['function']['name']}({tc['function']['arguments']})")
print(f"\nFinal response: {result['final_response']}")

## Tool Use Implementation Notes

**Q1: How do you know if the LLM responded with a tool call?**
```python
def has_tool_call(response_text: str) -> bool:
    # Qwen3 outputs: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
    return '<tool_call>' in response_text and '</tool_call>' in response_text
```

**Q2: How do you execute the tool call efficiently?**
```python
def execute_tool_call(tool_call, tool_implementations):
    tool_name = tool_call["function"]["name"]
    tool_args = tool_call["function"]["arguments"]
    
    # Parse if JSON string
    if isinstance(tool_args, str):
        tool_args = json.loads(tool_args)
    
    # Execute with **kwargs unpacking (EFFICIENT!)
    return tool_implementations[tool_name](**tool_args)
```

**Example:**
```python
# Tool call from LLM:
# <tool_call>{"name": "get_weather", "arguments": {"location": "SF", "unit": "celsius"}}</tool_call>

# Parsed to:
tool_call = {
    "function": {
        "name": "get_weather",
        "arguments": {"location": "SF", "unit": "celsius"}
    }
}

# Executed as:
result = get_weather(location="SF", unit="celsius")  # **kwargs unpacking
```

**Key Points:**
- Detection: Simple string check for `<tool_call>` tags
- Execution: `**kwargs` unpacking for clean parameter passing
- Hermes format: JSON inside XML tags (Qwen3 native)
- Multi-turn: Loop until no more tool calls detected

In [None]:
# Walk the recorded steps that a backprop pass would consume.
# Fields referenced on each step:
# - step.prompt: the input prompt
# - step.thinking: the thinking process
# - step.content: the actual response
# - step.history_wo_thinking: conversation history without thinking
# - step.history_w_thinking: conversation history with thinking
# - step.tokens: token IDs (can be populated during generation if needed)

# Example: visit every response, rollout by rollout
for rollout in manager.rollouts:
    print(f"\nRollout {rollout.rollout_id} - Reward: {rollout.reward}")
    for step in rollout.steps:
        # A real training pass would, at this point:
        # 1. Re-tokenize or use stored tokens
        # 2. Compute loss with reward signal
        # 3. Backprop gradients
        n_chars = len(step.content)
        print(f"  Step {step.step_num}: content length = {n_chars} chars")

# ========================================
# RL TRAINING COMPONENTS
# ========================================

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np


@dataclass
class RLStepData(StepData):
    """Step record extended with the per-step quantities used for RL training.

    Every field added here has a default, so a step can be created first and
    filled in once logprobs, values and advantages have been computed.
    """
    # Log probabilities of the generated tokens; None until populated.
    logprobs: Optional[torch.Tensor] = None
    # Value estimates V(s_t); None until populated.
    values: Optional[torch.Tensor] = None
    # Response token IDs.
    response_tokens: List[int] = field(default_factory=list)
    # Computed advantage A_t for this step (a single scalar despite the plural name).
    advantages: float = 0.0
    # Computed return R_t for this step.
    returns: float = 0.0
    # Log probabilities under the policy that generated the data, for the PPO ratio.
    old_logprobs: Optional[torch.Tensor] = None


class SimpleValueHead(nn.Module):
    """Two-layer MLP that maps LM hidden states to scalar value estimates."""

    def __init__(self, hidden_size: int):
        super().__init__()
        bottleneck = hidden_size // 2
        # Layer order is kept so the Sequential indices (0 and 2) stay the same.
        layers = [
            nn.Linear(hidden_size, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
        ]
        self.value_head = nn.Sequential(*layers)

    def forward(self, hidden_states):
        """
        Apply the value MLP along the last dimension.

        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            values: [batch_size, seq_len, 1]
        """
        values = self.value_head(hidden_states)
        return values

In [None]:
def _step_state_value(step) -> float:
    """Return the scalar state value stored on a step, or 0.0 when none is available."""
    raw = getattr(step, 'values', None)
    if raw is None:
        return 0.0
    if torch.is_tensor(raw):
        if raw.numel() == 0:
            # An empty response leaves an empty value tensor; treat it like a missing value.
            return 0.0
        # Flatten first so 0-dim and multi-dim tensors are handled too;
        # the last element is used as the state value.
        return raw.reshape(-1)[-1].item()
    return float(raw)


def compute_advantages_and_returns(
    rollouts: List["RolloutData"],
    gamma: float = 0.99,
    lam: float = 0.95,
    use_gae: bool = True
) -> None:
    """
    Compute per-step advantages and returns for all rollouts, in place.

    For outcome-based RL (reward only at the end of a rollout of T steps):
    - r_{T-1} = rollout.reward (final reward), r_t = 0 for earlier steps
    - V(s_t) is the value estimate stored on step t (0.0 when missing),
      and the value after the final step is taken to be 0
    - δ_t = r_t + γV(s_{t+1}) - V(s_t)
    - GAE:    A_t = δ_t + γλδ_{t+1} + (γλ)^2δ_{t+2} + ...
    - simple: A_t = G_t - V(s_t), with G_t the discounted return from step t
    - stored return: R_t = A_t + V(s_t)

    Results are written to `step.advantages` and `step.returns`; steps that
    have no `advantages` attribute are left untouched. Rollouts without steps
    are skipped.

    Args:
        rollouts: List of rollout data with values populated
        gamma: Discount factor
        lam: GAE lambda parameter
        use_gae: Whether to use GAE (True) or simple advantage (False)
    """
    for rollout in rollouts:
        T = len(rollout.steps)
        if T == 0:
            continue

        values = np.array([_step_state_value(step) for step in rollout.steps], dtype=float)

        # Rewards: 0 everywhere except final step
        rewards = np.zeros(T)
        rewards[-1] = rollout.reward

        advantages = np.zeros(T)
        if use_gae:
            # There is no state after the last step, so its bootstrap value is 0.
            next_values = np.append(values[1:], 0.0)
            deltas = rewards + gamma * next_values - values

            # Backward recursion: A_t = δ_t + γλ A_{t+1}, with A_T = 0
            running = 0.0
            for t in reversed(range(T)):
                running = deltas[t] + gamma * lam * running
                advantages[t] = running
        else:
            # Backward recursion for the discounted return: G_t = r_t + γ G_{t+1}
            running = 0.0
            for t in reversed(range(T)):
                running = rewards[t] + gamma * running
                advantages[t] = running - values[t]

        # Returns used as value targets: R_t = A_t + V(s_t)
        returns = advantages + values

        # Store in step data
        for t, step in enumerate(rollout.steps):
            if hasattr(step, 'advantages'):
                step.advantages = float(advantages[t])
                step.returns = float(returns[t])

    print(f"Computed advantages and returns for {len(rollouts)} rollouts using GAE={use_gae}")

In [None]:
def compute_ppo_loss(
    logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    clip_epsilon: float = 0.2
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Clipped-surrogate PPO policy loss.

    L^CLIP(θ) = -E[min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A)]
    with r(θ) = exp(logprob - old_logprob), the ratio of new to old policy
    probability for each element.

    Args:
        logprobs: Log probabilities from current policy [batch_size, seq_len]
        old_logprobs: Log probabilities from old policy [batch_size, seq_len]
        advantages: Advantage estimates [batch_size]
        clip_epsilon: PPO clip parameter

    Returns:
        loss: Policy loss (scalar)
        stats: Dictionary with 'policy_loss', 'approx_kl', 'clipfrac',
            'ratio_mean' and 'ratio_std' as Python floats
    """
    log_ratio = logprobs - old_logprobs
    ratio = log_ratio.exp()

    # A per-sample advantage vector is broadcast across the sequence axis.
    if logprobs.dim() == 2 and advantages.dim() == 1:
        advantages = advantages[:, None]  # [batch_size, 1]

    unclipped = ratio * advantages
    clipped = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages

    # Negated because the optimizer minimizes while the surrogate is maximized.
    policy_loss = torch.minimum(unclipped, clipped).mean().neg()

    # Diagnostics only; no gradients needed.
    with torch.no_grad():
        approx_kl = ((ratio - 1) - log_ratio).mean().item()
        outside_clip = (ratio - 1.0).abs().gt(clip_epsilon)
        clipfrac = outside_clip.float().mean().item()

    stats = dict(
        policy_loss=policy_loss.item(),
        approx_kl=approx_kl,
        clipfrac=clipfrac,
        ratio_mean=ratio.mean().item(),
        ratio_std=ratio.std().item(),
    )

    return policy_loss, stats


def compute_value_loss(
    values: torch.Tensor,
    returns: torch.Tensor,
    clip_value: bool = False,
    old_values: Optional[torch.Tensor] = None,
    clip_epsilon: float = 0.2
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Value-function regression loss, optionally with PPO-style clipping.

    Plain form:   L^VF(θ) = (V_θ(s) - R)^2
    Clipped form: L^VF_CLIP(θ) = max((V - R)^2, (clip(V, V_old - ε, V_old + ε) - R)^2)

    The clipped form is used only when `clip_value` is True AND `old_values`
    is supplied; otherwise the plain MSE is returned.

    Args:
        values: Value predictions [batch_size, seq_len] or [batch_size]
        returns: Target returns [batch_size]
        clip_value: Whether to use value clipping
        old_values: Old value predictions (needed if clip_value=True)
        clip_epsilon: Clip parameter

    Returns:
        loss: Value loss (scalar)
        stats: Dictionary with 'value_loss', 'value_mean', 'value_std' and
            'returns_mean' as Python floats
    """
    # A 2-D input is reduced to its last sequence position.
    if values.dim() == 2:
        values = values[:, -1]  # [batch_size]

    use_clipping = clip_value and old_values is not None
    if not use_clipping:
        value_loss = F.mse_loss(values, returns)
    else:
        if old_values.dim() == 2:
            old_values = old_values[:, -1]

        # Limit how far the new prediction may move from the old one.
        bounded_step = (values - old_values).clamp(-clip_epsilon, clip_epsilon)
        values_clipped = old_values + bounded_step
        unclipped_err = (values - returns).pow(2)
        clipped_err = (values_clipped - returns).pow(2)
        value_loss = torch.maximum(unclipped_err, clipped_err).mean()

    stats = dict(
        value_loss=value_loss.item(),
        value_mean=values.mean().item(),
        value_std=values.std().item(),
        returns_mean=returns.mean().item(),
    )

    return value_loss, stats


def compute_kl_divergence(
    logprobs: torch.Tensor,
    old_logprobs: torch.Tensor
) -> torch.Tensor:
    """
    Sample-based estimate of the KL divergence from the old to the new policy.

    KL(π_old || π_new) = E[log(π_old) - log(π_new)]
                        = E[old_logprobs - logprobs]

    The expectation is taken as the mean over all elements of the inputs.

    Args:
        logprobs: Log probabilities from new policy
        old_logprobs: Log probabilities from old policy

    Returns:
        kl: KL divergence (scalar)
    """
    return torch.mean(old_logprobs - logprobs)

In [None]:
class RLTrainer:
    """
    Outer training loop for multi-turn RL.

    Workflow:
    1. Generate n rollouts (trajectories)
    2. Assign rewards to completed rollouts
    3. Compute values for all steps
    4. Calculate advantages using GAE
    5. Update model using PPO
    6. Repeat
    """

    # Qwen3 token ids for <think> / </think>; same values LLMGenerator uses.
    THINK_START_ID = 151667
    THINK_END_ID = 151668

    def __init__(
        self,
        model: "AutoModelForCausalLM",
        tokenizer: "AutoTokenizer",
        value_head: "SimpleValueHead",
        config: Dict[str, Any]
    ):
        """
        Args:
            model: Policy language model to train
            tokenizer: Tokenizer matching `model`. It is configured in place
                for batched generation (pad token fallback, left padding),
                the same way LLMGenerator configures its own tokenizer.
            value_head: Value head trained alongside the policy
            config: Hyperparameters; every key is optional and read with
                `config.get(...)` where it is used
        """
        self.model = model
        self.tokenizer = tokenizer
        self.value_head = value_head
        self.config = config

        # Batched generation needs a pad token, and decoder-only models need
        # left padding so the new tokens directly follow each prompt.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Optimizers
        self.policy_optimizer = Adam(
            self.model.parameters(),
            lr=config.get('policy_lr', 1e-5)
        )
        self.value_optimizer = Adam(
            self.value_head.parameters(),
            lr=config.get('value_lr', 1e-4)
        )

        # Training stats
        self.training_stats = []

    @staticmethod
    def _strip_trailing_padding(token_ids: List[int], pad_id, eos_id) -> List[int]:
        """Drop the pad tokens that batched generation appends to rows that finished early."""
        if pad_id is None:
            return token_ids
        end = len(token_ids)
        while end > 0 and token_ids[end - 1] == pad_id:
            end -= 1
        if pad_id == eos_id and end < len(token_ids):
            # When pad and EOS share an id, the first token of the trailing
            # run is the EOS the model actually produced; keep it.
            end += 1
        return token_ids[:end]

    @staticmethod
    def _old_logprobs(step) -> torch.Tensor:
        """Return the step's old-policy logprobs (flattened, detached).

        Falls back to the step's own logprobs when no old logprobs were stored.
        """
        old = getattr(step, 'old_logprobs', None)
        if old is None:
            old = step.logprobs
        return old.detach().reshape(-1)

    @staticmethod
    def _last_value(step, device) -> torch.Tensor:
        """Return the last stored value of a step as a 0-dim float tensor (0.0 if missing)."""
        values = getattr(step, 'values', None)
        if values is None or values.numel() == 0:
            return torch.zeros((), device=device)
        return values.reshape(-1)[-1].float().to(device)

    def _build_generator(self):
        """Wrap the trainer's model and tokenizer in an LLMGenerator without reloading them.

        LLMGenerator's constructor takes a model *name* and loads fresh
        weights, so it is bypassed and the attributes it would set are
        assigned directly.
        """
        generator = LLMGenerator.__new__(LLMGenerator)
        generator.tokenizer = self.tokenizer
        generator.model = self.model
        generator.think_start_id = self.THINK_START_ID
        generator.think_end_id = self.THINK_END_ID
        return generator

    def generate_with_values(
        self,
        prompts: List[List[Dict[str, str]]],
        max_new_tokens: int = 512,
        temperature: float = 0.6
    ) -> Tuple[List[Tuple[str, str]], List[torch.Tensor], List[torch.Tensor], List[List[int]]]:
        """
        Generate responses and compute values + logprobs.

        Generation runs under `torch.no_grad()`, so the returned logprobs are
        detached from the model graph.

        Args:
            prompts: One chat-message list per sample
            max_new_tokens: Generation budget per sample
            temperature: Sampling temperature; sampling is enabled only when > 0

        Returns:
            outputs: List of (thinking, content) tuples
            all_logprobs: List of logprob tensors for each response
            all_values: List of value tensors for each response (currently zeros)
            all_tokens: List of response token ID lists, trailing padding removed
        """
        # Apply chat template
        texts = [
            self.tokenizer.apply_chat_template(
                prompt,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True
            )
            for prompt in prompts
        ]

        model_inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Hidden states are not requested: nothing below reads them, and they
        # are expensive to keep for every generated token.
        with torch.no_grad():
            generation_output = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                output_scores=True,
                return_dict_in_generate=True
            )

        generated_sequences = generation_output.sequences
        scores = generation_output.scores  # Tuple of [batch_size, vocab_size] for each step

        batch_size = generated_sequences.shape[0]
        # New tokens are appended after the *padded* prompt, so every row's
        # response starts at the same column regardless of its own length.
        prompt_width = model_inputs.input_ids.shape[1]
        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id

        all_outputs = []
        all_logprobs = []
        all_values = []
        all_tokens = []

        for i in range(batch_size):
            response_ids = self._strip_trailing_padding(
                generated_sequences[i, prompt_width:].tolist(), pad_id, eos_id
            )

            # Decode thinking and content
            thinking, content = self._parse_thinking(response_ids)
            all_outputs.append((thinking, content))
            all_tokens.append(response_ids)

            # Log-probability of each generated token; scores[t] is the
            # distribution that token t of the response was drawn from.
            if response_ids:
                step_scores = torch.stack([score[i] for score in scores[:len(response_ids)]])
                token_logprobs = F.log_softmax(step_scores, dim=-1)
                picked = torch.tensor(response_ids, device=token_logprobs.device).unsqueeze(-1)
                all_logprobs.append(token_logprobs.gather(-1, picked).squeeze(-1))
            else:
                all_logprobs.append(torch.tensor([]))

            # Placeholder values (zeros). Real estimates would require
            # requesting hidden states and running self.value_head on them.
            all_values.append(torch.zeros(len(response_ids)))

        return all_outputs, all_logprobs, all_values, all_tokens

    def _parse_thinking(self, output_ids: List[int]) -> Tuple[str, str]:
        """Split response token IDs into (thinking, content) text.

        Everything up to and including the last `</think>` token is decoded
        as thinking; the remainder is the content. Without a `</think>`
        token the whole response is content.
        """
        try:
            split = len(output_ids) - output_ids[::-1].index(self.THINK_END_ID)
        except ValueError:
            split = 0

        thinking = self.tokenizer.decode(output_ids[:split], skip_special_tokens=True)
        content = self.tokenizer.decode(output_ids[split:], skip_special_tokens=True)

        return thinking.strip(), content.strip()

    def train_step(
        self,
        rollouts: List["RolloutData"],
        n_epochs: int = 4,
        batch_size: int = 32
    ) -> Dict[str, float]:
        """
        Perform one training step on collected rollouts.

        Only steps carrying a non-empty `logprobs` tensor are used. Responses
        may differ in length: tokens of a minibatch are concatenated and each
        token receives the (batch-normalized) advantage of its step, which
        equals the stacked [batch, seq] computation when all lengths match.

        NOTE(review): the stored logprobs are reused in every epoch rather
        than recomputed under the updated policy, and they must carry
        gradients for the update to do anything; confirm how steps are
        populated before relying on this for real training.

        Args:
            rollouts: List of completed rollouts with advantages computed
            n_epochs: Number of PPO epochs
            batch_size: Minibatch size

        Returns:
            stats: Training statistics averaged over all updates, or an empty
                dict when no step has logprobs

        Raises:
            RuntimeError: if the loss is not connected to any parameter
        """
        import random  # local import: `random` is not imported at file level

        # Flatten all usable steps from all rollouts
        all_steps = [
            step
            for rollout in rollouts
            for step in rollout.steps
            if getattr(step, 'logprobs', None) is not None and step.logprobs.numel() > 0
        ]

        if not all_steps:
            print("No steps with logprobs to train on")
            return {}

        total_stats = {
            'policy_loss': 0.0,
            'value_loss': 0.0,
            'total_loss': 0.0,
            'approx_kl': 0.0,
            'clipfrac': 0.0
        }

        n_updates = 0
        device = self.model.device

        for _ in range(n_epochs):
            random.shuffle(all_steps)

            # Mini-batch training
            for start in range(0, len(all_steps), batch_size):
                batch = all_steps[start:start + batch_size]

                # Prepare batch tensors (token-level for the policy, step-level for the value)
                token_counts = torch.tensor([s.logprobs.numel() for s in batch], device=device)
                logprobs_batch = torch.cat([s.logprobs.reshape(-1) for s in batch]).to(device)
                old_logprobs_batch = torch.cat([self._old_logprobs(s) for s in batch]).to(device)
                values_batch = torch.stack([self._last_value(s, device) for s in batch])
                advantages_batch = torch.tensor([s.advantages for s in batch]).to(device)
                returns_batch = torch.tensor([s.returns for s in batch]).to(device)

                # Normalize advantages; the std of a single element is NaN,
                # so a one-step batch keeps its raw advantage.
                if advantages_batch.numel() > 1:
                    advantages_batch = (
                        (advantages_batch - advantages_batch.mean())
                        / (advantages_batch.std() + 1e-8)
                    )

                # Give every token the advantage of the step it belongs to.
                token_advantages = torch.repeat_interleave(advantages_batch, token_counts)

                # Compute losses
                policy_loss, policy_stats = compute_ppo_loss(
                    logprobs_batch,
                    old_logprobs_batch,
                    token_advantages,
                    clip_epsilon=self.config.get('clip_epsilon', 0.2)
                )

                value_loss, value_stats = compute_value_loss(
                    values_batch,
                    returns_batch,
                    clip_value=self.config.get('clip_value', False)
                )

                # Total loss
                total_loss = (
                    policy_loss +
                    self.config.get('vf_coef', 0.5) * value_loss
                )

                if not total_loss.requires_grad:
                    raise RuntimeError(
                        "train_step: loss is detached from the model graph; "
                        "step.logprobs / step.values must be computed with "
                        "gradients enabled before calling train_step"
                    )

                # Backprop
                self.policy_optimizer.zero_grad()
                self.value_optimizer.zero_grad()
                total_loss.backward()

                # Gradient clipping
                max_grad_norm = self.config.get('max_grad_norm', 0)
                if max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(self.value_head.parameters(), max_grad_norm)

                self.policy_optimizer.step()
                self.value_optimizer.step()

                # Accumulate stats
                total_stats['policy_loss'] += policy_stats['policy_loss']
                total_stats['value_loss'] += value_stats['value_loss']
                total_stats['total_loss'] += total_loss.item()
                total_stats['approx_kl'] += policy_stats['approx_kl']
                total_stats['clipfrac'] += policy_stats['clipfrac']
                n_updates += 1

        # Average stats
        for key in total_stats:
            total_stats[key] /= max(n_updates, 1)

        return total_stats

    def train(
        self,
        initial_prompts_generator,
        reward_function,
        termination_function,
        prompt_generator_function,
        n_iterations: int = 10,
        n_rollouts_per_iter: int = 4,
        max_steps: int = 50
    ):
        """
        Main training loop.

        NOTE(review): train_step only trains on steps that carry logprobs;
        confirm that the rollouts produced by RolloutManager populate them.

        Args:
            initial_prompts_generator: Function that returns list of initial prompts
            reward_function: Function to compute rewards
            termination_function: Function to check termination
            prompt_generator_function: Function to generate next prompt
            n_iterations: Number of training iterations
            n_rollouts_per_iter: Number of rollouts per iteration
            max_steps: Max steps per rollout

        Returns:
            The accumulated `self.training_stats` list, one entry per iteration
        """
        # One wrapper around the live model serves every iteration.
        generator = self._build_generator()
        save_interval = self.config.get('save_interval', 5)

        for iteration in range(n_iterations):
            print(f"\n=== Iteration {iteration + 1}/{n_iterations} ===")

            # 1. Generate rollouts
            print("Generating rollouts...")
            manager = RolloutManager(generator, n_rollouts_per_iter, max_steps)

            initial_prompts = initial_prompts_generator(n_rollouts_per_iter)
            manager.initialize_rollouts(initial_prompts)
            manager.run_rollouts(
                prompt_generator_fn=prompt_generator_function,
                termination_fn=termination_function,
                max_new_tokens=self.config.get('max_new_tokens', 512)
            )

            # 2. Assign rewards
            print("Assigning rewards...")
            manager.assign_rewards(reward_function)

            # 3. Compute advantages
            print("Computing advantages...")
            compute_advantages_and_returns(
                manager.rollouts,
                gamma=self.config.get('gamma', 0.99),
                lam=self.config.get('lam', 0.95)
            )

            # 4. Train
            print("Training model...")
            stats = self.train_step(
                manager.rollouts,
                n_epochs=self.config.get('ppo_epochs', 4),
                batch_size=self.config.get('batch_size', 32)
            )

            # 5. Log stats
            summary = manager.get_rollout_summary()
            print(f"\nIteration {iteration + 1} Results:")
            print(f"  Avg reward: {summary['avg_reward']:.2f}")
            print(f"  Avg steps: {summary['avg_steps']:.2f}")
            print(f"  Policy loss: {stats.get('policy_loss', 0):.4f}")
            print(f"  Value loss: {stats.get('value_loss', 0):.4f}")
            print(f"  Approx KL: {stats.get('approx_kl', 0):.4f}")

            self.training_stats.append({
                'iteration': iteration + 1,
                'rollout_summary': summary,
                'training_stats': stats
            })

            # Save checkpoint; a falsy interval (0 / None) disables checkpointing
            # instead of dividing by zero.
            if save_interval and (iteration + 1) % save_interval == 0:
                self.save_checkpoint(f"checkpoint_iter_{iteration + 1}.pt")

        print("\nTraining complete!")
        return self.training_stats

    def save_checkpoint(self, filepath: str):
        """Save model, value head, optimizer states, stats and config to `filepath`."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'value_head_state_dict': self.value_head.state_dict(),
            'policy_optimizer_state_dict': self.policy_optimizer.state_dict(),
            'value_optimizer_state_dict': self.value_optimizer.state_dict(),
            'training_stats': self.training_stats,
            'config': self.config
        }, filepath)
        print(f"Checkpoint saved to {filepath}")

# ========================================
# EXAMPLE: Using RLTrainer
# ========================================

In [None]:
# Initialize model and value head
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto", device_map="auto")

# Create value head (get hidden size from model config)
hidden_size = model.config.hidden_size
value_head = SimpleValueHead(hidden_size).to(model.device)

# Training configuration (every key is read with config.get(...) by RLTrainer)
config = {
    'policy_lr': 1e-5,
    'value_lr': 1e-4,
    'gamma': 0.99,
    'lam': 0.95,
    'clip_epsilon': 0.2,
    'vf_coef': 0.5,
    'max_grad_norm': 1.0,
    'ppo_epochs': 4,
    'batch_size': 32,
    'max_new_tokens': 512,
    'save_interval': 5,
    'clip_value': False
}

# Create trainer
trainer = RLTrainer(model, tokenizer, value_head, config)

In [None]:
# Define task-specific functions
def generate_initial_prompts(n):
    """Generate n initial prompts for rollouts.

    The four built-in prompts are repeated in order as often as needed;
    a non-positive n yields an empty list.
    """
    from itertools import cycle, islice

    prompts = [
        "Solve: What is 15 + 27?",
        "Write a short story about AI.",
        "Explain quantum computing.",
        "Debug: for i in range(10) print(i)"
    ]
    # islice rejects negative counts, so clamp to keep "n <= 0 -> []".
    return list(islice(cycle(prompts), max(n, 0)))

# The reward / termination / next-prompt hooks are the dummy functions
# defined in earlier cells:
#   dummy_reward_function, dummy_termination_check, dummy_prompt_generator

# Launch training: 10 iterations of 4 rollouts, each capped at 50 steps
training_stats = trainer.train(
    n_iterations=10,
    n_rollouts_per_iter=4,
    max_steps=50,
    initial_prompts_generator=generate_initial_prompts,
    reward_function=dummy_reward_function,
    termination_function=dummy_termination_check,
    prompt_generator_function=dummy_prompt_generator,
)

In [None]:
# Plot reward and loss curves across training iterations
import matplotlib.pyplot as plt

iterations = [entry['iteration'] for entry in training_stats]
avg_rewards = [entry['rollout_summary']['avg_reward'] for entry in training_stats]
policy_losses = [entry['training_stats'].get('policy_loss', 0) for entry in training_stats]
value_losses = [entry['training_stats'].get('value_loss', 0) for entry in training_stats]

# One panel per metric: (series, title, y-axis label)
panels = [
    (avg_rewards, 'Average Reward over Training', 'Reward'),
    (policy_losses, 'Policy Loss over Training', 'Loss'),
    (value_losses, 'Value Loss over Training', 'Loss'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (series, title, ylabel) in zip(axes, panels):
    ax.plot(iterations, series)
    ax.set_title(title)
    ax.set_xlabel('Iteration')
    ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()

# Key Implementation Notes

## How it works:

1. **Trajectory Generation**: `RolloutManager` creates n parallel rollouts, executing steps until completion or early termination
2. **Reward Assignment**: Rewards given at the end of each rollout (outcome-based RL)
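3. **Advantage Calculation**: Uses GAE (Generalized Advantage Estimation), with the sum truncated at the end of the rollout (T steps):
   - δ_t = r_t + γV(s_{t+1}) - V(s_t), where the value after the final step is taken to be 0
   - A_t = Σ_{l=0}^{T-t-1} (γλ)^l δ_{t+l}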
4. **PPO Loss**: 
   - Policy: L = -E[min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A)]
   - Value: L = (V(s) - R)^2
   - KL: For monitoring policy drift
5. **Training**: Multiple epochs over collected data, with gradient clipping and advantage normalization

## Data Structure:
- Each `StepData` stores the prompt, thinking, content, and conversation history; `RLStepData` extends it with logprobs, values, response tokens, advantages, and returns
- Each `RolloutData` tracks full trajectory with all steps and final reward
- All data saved for later analysis/debugging

## To Customize:
- `reward_function`: Define your task-specific reward
- `termination_function`: Define when rollouts should stop early
- `prompt_generator_function`: Define how to generate next prompts
- Config params: Learning rates, PPO hyperparameters, etc.