# Multi-Turn RL Evaluation

Framework for evaluating multi-turn RL rollouts where:
- Tasks run for ~50 steps (some terminate early)
- Every step is recorded for each rollout
- Rewards are assigned once, at the end of each rollout
- Per-step data is collected so it can be used later for backprop

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Dict, Union, Optional, Any, Tuple
from copy import deepcopy
import json
import pickle
from dataclasses import dataclass, field, asdict
from datetime import datetime

In [None]:
class LLMGenerator:
    """Batched chat generation on top of a Hugging Face causal language model.

    Conversations are rendered with the tokenizer's chat template, generated
    in one padded batch, and each completion can optionally be split into its
    "thinking" part and its final answer at the ``</think>`` token.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model identified by ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto", device_map="auto")
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"  # decoder-only: left padding
        # NOTE(review): hard-coded token ids; presumably those of the Qwen3
        # tokenizer used in the example below -- confirm before switching models.
        self.think_start_id = 151667  # <think>
        self.think_end_id = 151668    # </think>

    @staticmethod
    def _normalize_prompts(prompts) -> Tuple[List[Any], bool]:
        """Convert any accepted ``prompts`` shape into a list of conversations.

        Accepted shapes:
          * ``str``                      -> one conversation with one user message
          * ``[{"role": ..., ...}, ...]`` -> one conversation (list of messages)
          * ``[str, ...]``               -> one single-message conversation each
          * ``[[message, ...], ...]``    -> already a batch of conversations

        Returns:
            ``(conversations, single_prompt)`` where ``single_prompt`` tells the
            caller to unwrap the one-element result list again.
        """
        if isinstance(prompts, str):
            return [[{"role": "user", "content": prompts}]], True
        if len(prompts) == 0:
            return [], False
        if isinstance(prompts[0], dict):
            # A flat list of message dicts is a single conversation.
            return [prompts], True
        conversations = [
            [{"role": "user", "content": item}] if isinstance(item, str) else item
            for item in prompts
        ]
        return conversations, False

    def generate(self,
                 prompts: Union[str, List[str], List[Dict[str, str]], List[List[Dict[str, str]]]],
                 max_new_tokens: int = 512,
                 temperature: float = 0.6,
                 num_return_sequences: int = 1,
                 enable_thinking: bool = True,
                 return_thinking: bool = True,
                 **kwargs
                 ) -> Union[str, List[Any]]:
        """Generate completions for one conversation or a batch of them.

        Args:
            prompts: A plain string, a single conversation (list of message
                dicts), a list of strings, or a list of conversations.
            max_new_tokens: Upper bound on generated tokens per sequence.
            temperature: Sampling temperature; sampling is only enabled when
                this is greater than zero.
            num_return_sequences: Number of completions per conversation.
            enable_thinking: Forwarded to the chat template.
            return_thinking: When true, every completion is returned as a
                ``[thinking, content]`` pair instead of a single string.
            **kwargs: Forwarded unchanged to ``model.generate``.

        Returns:
            For a batch: one entry per conversation. Each entry is a single
            completion when ``num_return_sequences == 1``, otherwise a list of
            completions. For a single prompt/conversation the one entry is
            returned directly. An empty batch yields ``[]``.
        """
        conversations, single_prompt = self._normalize_prompts(prompts)
        if not conversations:
            # Nothing to generate; avoid calling the model on an empty batch.
            return []

        # Apply chat template with thinking enabled/disabled
        texts = [
            self.tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=enable_thinking
            )
            for conversation in conversations
        ]

        model_inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                num_return_sequences=num_return_sequences,
                **kwargs
            )

        results = self._decode_batch(
            outputs,
            model_inputs.input_ids,
            num_return_sequences=num_return_sequences,
            return_thinking=return_thinking
        )

        return results[0] if single_prompt else results

    def _decode_batch(self, outputs, input_ids, num_return_sequences, return_thinking):
        """Decode generated ids into text, grouped per input conversation.

        Args:
            outputs: Generated ids, ``batch_size * num_return_sequences`` rows,
                each row being the padded prompt followed by the completion.
            input_ids: The padded prompt ids that were fed to the model.
            num_return_sequences: Completions generated per conversation.
            return_thinking: Emit ``[thinking, content]`` pairs when true,
                plain decoded strings otherwise.

        Returns:
            A list with one entry per conversation; the entry is unwrapped to a
            single completion when ``num_return_sequences == 1``.
        """
        batch_size = input_ids.shape[0]
        # Every row of the padded prompt batch has the same length, so a single
        # offset separates prompt tokens from generated tokens for all rows.
        prompt_len = input_ids.shape[1]

        # Reshape to [batch_size, num_return_sequences, seq_len]
        grouped = outputs.reshape(batch_size, num_return_sequences, -1)

        results = []
        for batch_idx in range(batch_size):
            batch_results = []
            for seq_idx in range(num_return_sequences):
                # Slice off prompt tokens to keep the model's output only
                output_ids = grouped[batch_idx, seq_idx, prompt_len:].tolist()

                if return_thinking:
                    thinking, content = self._parse_thinking(output_ids)
                    # Kept as a two-element list: callers index and unpack it.
                    batch_results.append([thinking, content])
                else:
                    content = self.tokenizer.decode(
                        output_ids,
                        skip_special_tokens=True
                    ).strip()
                    batch_results.append(content)

            # If num_return_sequences=1, unwrap the list
            results.append(batch_results[0] if num_return_sequences == 1 else batch_results)

        return results

    def _parse_thinking(self, output_ids: List[int]) -> Tuple[str, str]:
        """Split generated ids into ``(thinking, content)`` at the last ``</think>``.

        Everything up to and including the last ``</think>`` token is decoded as
        the thinking part; the remainder is the content. When no ``</think>``
        token is present the thinking part is empty.
        """
        try:
            # Searching the reversed list finds the LAST occurrence, so any
            # earlier </think> tokens stay inside the thinking part.
            think_end_idx = len(output_ids) - output_ids[::-1].index(self.think_end_id)
        except ValueError:
            # No </think> token found - model skipped thinking
            thinking_ids = []
            content_ids = output_ids
        else:
            # Include </think> in thinking part
            thinking_ids = output_ids[:think_end_idx]
            content_ids = output_ids[think_end_idx:]

        thinking = self.tokenizer.decode(thinking_ids, skip_special_tokens=True)
        content = self.tokenizer.decode(content_ids, skip_special_tokens=True)

        return thinking.strip(), content.strip()

In [None]:
def generate_templates(prompts):
    """Wrap every prompt in a user-role chat message.

    Returns a list with one ``{"role": "user", "content": prompt}`` dict per
    input prompt, in input order (a one-element list for a single prompt).
    """
    messages = []
    for text in prompts:
        messages.append({"role": "user", "content": text})
    return messages


def batch_history(history, prompts):
    """Build one conversation per prompt by extending a shared history.

    Args:
        history: List of chat messages common to every conversation.
        prompts: Prompt strings; each one becomes the final user message of its
            own conversation.

    Returns:
        A list of conversations, one per prompt. Each conversation is an
        independent deep copy of ``history`` followed by the new user message,
        so later mutation of one conversation cannot leak into another or into
        ``history`` itself.
    """
    batch = []
    for message in generate_templates(prompts):
        conversation = deepcopy(history)
        conversation.append(message)
        batch.append(conversation)
    return batch


def create_batch(bh, num_gen, outputs, include_thinking=False):
    """Flatten prompts x generations into a list of extended conversations.

    Args:
        bh: Batch of conversations (lists of chat messages), one per prompt.
        num_gen: Generations per prompt. When greater than 1, ``outputs[i]`` is
            a list of ``[thinking, content]`` pairs; otherwise ``outputs[i]``
            is itself one ``[thinking, content]`` pair.
        outputs: Generation results aligned with ``bh``.
        include_thinking: When true, the "with thinking" conversations get an
            assistant message whose content is the full ``[thinking, content]``
            pair. When false they are plain copies of the input conversation.

    Returns:
        ``(bh_history, bh_wthinking)``: two lists of length
        ``len(bh) * max(num_gen, 1)`` ordered prompt-major. ``bh_history``
        conversations end with an assistant message holding only the content.

    Every returned conversation is an independent copy, so one generation's
    reply never shows up in another generation's history and ``bh`` itself is
    left untouched.
    """
    bh_history = []
    bh_wthinking = []
    for i, base_conversation in enumerate(bh):
        if num_gen > 1:  # outputs[i] holds num_gen [thinking, content] pairs
            generations = [outputs[i][j] for j in range(num_gen)]
        else:  # outputs[i] is a single [thinking, content] pair
            generations = [outputs[i]]

        for generation in generations:
            # Fresh copies per generation: sharing one list across generations
            # would accumulate every reply in the same conversation.
            without_thinking = deepcopy(base_conversation)
            with_thinking = deepcopy(base_conversation)

            if include_thinking:
                # NOTE(review): content here is the [thinking, content] pair,
                # not a string -- confirm downstream consumers expect that.
                with_thinking.append({"role": "assistant", "content": generation})

            without_thinking.append({"role": "assistant", "content": generation[1]})
            bh_history.append(without_thinking)
            bh_wthinking.append(with_thinking)
    return bh_history, bh_wthinking


def res(generator, batch, max_new_tokens=1000, num_return_sequences=1):
    """Generate replies for ``batch`` and extend each conversation with them.

    Returns a 3-tuple ``(outputs, histories, histories_with_thinking)``:
    the raw result of ``generator.generate`` followed by the two lists that
    ``create_batch`` builds (called with ``include_thinking=True``).
    """
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_return_sequences": num_return_sequences,
    }
    outputs = generator.generate(batch, **generation_kwargs)
    histories, histories_with_thinking = create_batch(
        batch,
        num_return_sequences,
        outputs,
        include_thinking=True,
    )
    return outputs, histories, histories_with_thinking

In [None]:
@dataclass
class StepData:
    """Record of one generation step inside a rollout."""

    # Index of this step as supplied by the caller.
    step_num: int
    # User prompt that was sent for this step.
    prompt: str
    # Thinking part of the model output.
    thinking: str
    # Final answer part of the model output.
    content: str
    # Conversation after this step, assistant replies holding content only.
    history_wo_thinking: List[Dict[str, str]]
    # Conversation after this step, built with thinking included.
    history_w_thinking: List[Dict[str, str]]
    # Token IDs for backprop; a fresh empty list per instance by default.
    tokens: List[int] = field(default_factory=list)


@dataclass
class RolloutData:
    """Everything recorded for one rollout: its steps, status and reward."""

    # Identifier of the rollout.
    rollout_id: int
    # Prompt the rollout was created with.
    initial_prompt: str
    # Recorded steps, in the order they were added.
    steps: List[StepData] = field(default_factory=list)
    # Set to True by mark_complete().
    is_complete: bool = False
    # Whether mark_complete() was called with early=True.
    terminated_early: bool = False
    # Number of recorded steps when mark_complete() ran; -1 until then.
    termination_step: int = -1
    # Reward stored by set_reward().
    reward: float = 0.0
    # Free-form extra information.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_step(self, step_data: StepData):
        """Append ``step_data`` to the recorded steps."""
        self.steps += [step_data]

    def get_total_steps(self) -> int:
        """Return how many steps have been recorded so far."""
        recorded = self.steps
        return len(recorded)

    def mark_complete(self, early: bool = False):
        """Flag the rollout as finished and remember the step count.

        ``early`` is stored in ``terminated_early``.
        """
        self.termination_step = len(self.steps)
        self.terminated_early = early
        self.is_complete = True

    def set_reward(self, reward: float):
        """Store ``reward`` on the rollout."""
        self.reward = reward

    def to_dict(self):
        """Return the rollout, including nested steps, as plain dicts."""
        as_plain = asdict(self)
        return as_plain

In [None]:
class RolloutManager:
    """Manages multiple parallel rollouts for multi-turn RL evaluation.

    The manager keeps one ``RolloutData`` per rollout plus the list of rollout
    indices that are still active. Each step generates one reply for every
    active rollout in a single batch and records it as a ``StepData``.
    """

    def __init__(self, generator: LLMGenerator, n_rollouts: int, max_steps: int = 50):
        """Store the generator and the rollout limits; no rollouts exist yet.

        Args:
            generator: Object used to produce model replies.
            n_rollouts: Number of rollouts ``initialize_rollouts`` must create.
            max_steps: Maximum number of steps per rollout.
        """
        self.generator = generator
        self.n_rollouts = n_rollouts
        self.max_steps = max_steps
        self.rollouts: List[RolloutData] = []
        self.active_rollout_indices: List[int] = []

    def initialize_rollouts(self, initial_prompts: List[str]):
        """Create ``n_rollouts`` fresh rollouts, one per initial prompt.

        Raises:
            ValueError: If the number of prompts differs from ``n_rollouts``.
        """
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if len(initial_prompts) != self.n_rollouts:
            raise ValueError("Number of prompts must match n_rollouts")

        self.rollouts = [
            RolloutData(
                rollout_id=i,
                initial_prompt=prompt,
                metadata={"created_at": datetime.now().isoformat()}
            )
            for i, prompt in enumerate(initial_prompts)
        ]
        self.active_rollout_indices = list(range(self.n_rollouts))
        print(f"Initialized {self.n_rollouts} rollouts")

    def execute_step(self, step_num: int, prompts: List[str], max_new_tokens: int = 1000):
        """Execute a single step for all active rollouts.

        Args:
            step_num: Step index stored on the recorded ``StepData``.
            prompts: One prompt per active rollout, in the same order as
                ``active_rollout_indices``.
            max_new_tokens: Max tokens to generate per reply.

        Raises:
            ValueError: If ``prompts`` does not line up with the active rollouts.
        """
        if not self.active_rollout_indices:
            print("No active rollouts")
            return

        # Fail before generating anything rather than with an IndexError later.
        if len(prompts) != len(self.active_rollout_indices):
            raise ValueError(
                f"Expected {len(self.active_rollout_indices)} prompts (one per active rollout), "
                f"got {len(prompts)}"
            )

        # Build one conversation per active rollout: the previous step's
        # history without thinking (empty on the first step) plus the new prompt.
        batch = []
        for rollout_idx, prompt in zip(self.active_rollout_indices, prompts):
            previous_steps = self.rollouts[rollout_idx].steps
            conversation = deepcopy(previous_steps[-1].history_wo_thinking) if previous_steps else []
            conversation.append({"role": "user", "content": prompt})
            batch.append(conversation)

        # Generate outputs (num_return_sequences=1 as per requirement)
        outputs, gens_with_history, gens_w_thinking = res(
            self.generator,
            batch,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1
        )

        # Store step data for each active rollout
        for i, rollout_idx in enumerate(self.active_rollout_indices):
            thinking, content = outputs[i]  # outputs[i] is [thinking, content]

            step_data = StepData(
                step_num=step_num,
                prompt=prompts[i],
                thinking=thinking,
                content=content,
                history_wo_thinking=gens_with_history[i],
                history_w_thinking=gens_w_thinking[i],
                tokens=[]  # Will be populated if needed for backprop
            )

            self.rollouts[rollout_idx].add_step(step_data)

        print(f"Step {step_num} completed for {len(self.active_rollout_indices)} rollouts")

    def check_termination_conditions(self, termination_fn=None) -> List[int]:
        """Finish rollouts that hit ``max_steps`` or satisfy ``termination_fn``.

        Args:
            termination_fn: Optional callable ``(rollout, last_step) -> bool``;
                only consulted for rollouts that have at least one step.

        Returns:
            Indices of the rollouts terminated by this call. They are marked
            complete and removed from ``active_rollout_indices``.
        """
        to_terminate = []

        for rollout_idx in self.active_rollout_indices:
            rollout = self.rollouts[rollout_idx]

            # Reaching the step limit counts as a regular (not early) finish.
            if rollout.get_total_steps() >= self.max_steps:
                to_terminate.append(rollout_idx)
                rollout.mark_complete(early=False)
                continue

            # Custom termination function
            if termination_fn and rollout.steps:
                last_step = rollout.steps[-1]
                if termination_fn(rollout, last_step):
                    to_terminate.append(rollout_idx)
                    rollout.mark_complete(early=True)

        if to_terminate:
            # Single in-place filter; keeps the same list object alive.
            finished = set(to_terminate)
            self.active_rollout_indices[:] = [
                idx for idx in self.active_rollout_indices if idx not in finished
            ]
            print(f"Terminated {len(to_terminate)} rollouts. {len(self.active_rollout_indices)} still active")

        return to_terminate

    def run_rollouts(self, prompt_generator_fn, termination_fn=None, max_new_tokens: int = 1000):
        """
        Run all rollouts until completion.

        Args:
            prompt_generator_fn: Function that takes (rollout_data, step_num) and returns next prompt
            termination_fn: Optional function that takes (rollout_data, step_data) and returns bool
            max_new_tokens: Max tokens to generate per step
        """
        step_num = 0

        while self.active_rollout_indices and step_num < self.max_steps:
            # One prompt per active rollout, in active-index order.
            prompts = [
                prompt_generator_fn(self.rollouts[rollout_idx], step_num)
                for rollout_idx in self.active_rollout_indices
            ]

            self.execute_step(step_num, prompts, max_new_tokens)
            self.check_termination_conditions(termination_fn)

            step_num += 1

        # Mark any remaining active rollouts as complete
        for rollout_idx in self.active_rollout_indices:
            self.rollouts[rollout_idx].mark_complete(early=False)

        print(f"\nAll rollouts complete. Total steps: {step_num}")

    def assign_rewards(self, reward_fn):
        """Store ``reward_fn(rollout)`` as the reward of every rollout."""
        for rollout in self.rollouts:
            rollout.set_reward(reward_fn(rollout))
        print(f"Rewards assigned to {len(self.rollouts)} rollouts")

    def get_rollout_summary(self) -> Dict[str, Any]:
        """Return summary statistics over all rollouts.

        Averages and min/max fall back to 0 when there are no rollouts.
        """
        total_steps = [r.get_total_steps() for r in self.rollouts]
        early_terminations = sum(1 for r in self.rollouts if r.terminated_early)
        rewards = [r.reward for r in self.rollouts]

        return {
            "total_rollouts": len(self.rollouts),
            "avg_steps": sum(total_steps) / len(total_steps) if total_steps else 0,
            "min_steps": min(total_steps) if total_steps else 0,
            "max_steps": max(total_steps) if total_steps else 0,
            "early_terminations": early_terminations,
            "avg_reward": sum(rewards) / len(rewards) if rewards else 0,
            "step_distribution": total_steps,
            "rewards": rewards
        }

    def save_rollouts(self, filepath: str, format: str = "pickle"):
        """Save all rollout data to ``filepath``.

        Args:
            filepath: Destination file.
            format: ``"pickle"`` (dataclass objects) or ``"json"`` (plain dicts).

        Raises:
            ValueError: If ``format`` is neither ``"pickle"`` nor ``"json"``.
        """
        if format == "pickle":
            with open(filepath, "wb") as f:
                pickle.dump(self.rollouts, f)
        elif format == "json":
            data = [r.to_dict() for r in self.rollouts]
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
        else:
            raise ValueError(f"Unknown format: {format}")

        print(f"Saved {len(self.rollouts)} rollouts to {filepath}")

    @staticmethod
    def load_rollouts(filepath: str, format: str = "pickle") -> List[RolloutData]:
        """Load rollout data previously written by ``save_rollouts``.

        Both formats return ``RolloutData`` objects; JSON files are rebuilt
        from the stored dicts, including their nested ``StepData`` entries.

        Raises:
            ValueError: If ``format`` is neither ``"pickle"`` nor ``"json"``.
        """
        if format == "pickle":
            # SECURITY: unpickling can execute arbitrary code. Only load files
            # that this project wrote itself, never untrusted input.
            with open(filepath, "rb") as f:
                return pickle.load(f)
        elif format == "json":
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            rollouts = []
            for entry in data:
                steps = [StepData(**step) for step in entry.get("steps", [])]
                rollouts.append(RolloutData(**{**entry, "steps": steps}))
            return rollouts
        else:
            raise ValueError(f"Unknown format: {format}")

In [None]:
def dummy_reward_function(rollout: RolloutData) -> float:
    """
    Placeholder reward - swap in task-specific scoring.

    Ideas for a real implementation:
    - Compare the final output with an expected answer
    - Count the number of correct steps
    - Score similarity against a reference solution

    As written, the reward depends only on the number of recorded steps:
    half a point per step for rollouts that terminated early, one point per
    step plus a flat bonus of 10 otherwise.
    """
    steps_done = rollout.get_total_steps()

    if not rollout.terminated_early:
        # Ran to the end: full credit per step plus the completion bonus.
        return steps_done * 1.0 + 10.0

    # Stopped early: reduced credit per step.
    return steps_done * 0.5


def dummy_termination_check(rollout: RolloutData, last_step: StepData) -> bool:
    """
    Placeholder termination rule - swap in task-specific logic.

    Ideas for a real implementation:
    - Look for a stop phrase in the output ("DONE", "FINAL ANSWER:", etc.)
    - Test whether the task objective has been reached
    - Detect errors or invalid states

    As written, the rollout stops when the last step's content contains
    "DONE" (case-insensitive) and otherwise stops at random with a 10% chance.
    """
    import random

    said_done = "DONE" in last_step.content.upper()

    # The random draw only happens when no stop phrase was found; it simulates
    # the rollout giving up.
    return said_done or random.random() < 0.1


def dummy_prompt_generator(rollout: RolloutData, step_num: int) -> str:
    """
    Placeholder prompt source - swap in task-specific logic.

    Ideas for a real implementation:
    - Math: emit the next problem in a sequence
    - Coding: supply the next test case or requirement
    - Reasoning: ask a follow-up question

    As written, step 0 returns the rollout's initial prompt and every later
    step returns a generic continuation request numbered from 2 upwards.
    """
    if step_num != 0:
        # Later steps: simple continuation, numbered one-based.
        return f"Continue with step {step_num + 1}"

    # Step 0 starts the conversation with the rollout's own prompt.
    return rollout.initial_prompt

In [None]:
# Example usage: build the generator, the rollout settings and the manager.

# Model wrapper shared by all rollouts
generator = LLMGenerator(model_name="Qwen/Qwen3-0.6B")

# Rollout settings: number of parallel rollouts and per-rollout step limit
n_rollouts = 4
max_steps = 50

# One initial prompt per rollout (must contain exactly n_rollouts entries)
initial_prompts = [
    "Solve this problem step by step: What is 15 + 27?",
    "Write a story about a robot learning to cook.",
    "Explain how photosynthesis works.",
    "Debug this code: for i in range(10) print(i)",
]

# Manager that drives the rollouts
manager = RolloutManager(
    generator=generator,
    n_rollouts=n_rollouts,
    max_steps=max_steps,
)

In [None]:
# Create the rollouts, one per initial prompt, and mark them all active
manager.initialize_rollouts(initial_prompts=initial_prompts)

In [None]:
# Step every rollout until it terminates or reaches the step limit
manager.run_rollouts(
    dummy_prompt_generator,
    dummy_termination_check,
    max_new_tokens=512,
)

In [None]:
# Score each finished rollout with the placeholder reward function
manager.assign_rewards(reward_fn=dummy_reward_function)

In [None]:
# Collect aggregate statistics over all rollouts and print them one
# "key: value" line at a time.
summary = manager.get_rollout_summary()
print("\n=== Rollout Summary ===")
for key, value in summary.items():
    print(f"{key}: {value}")

In [None]:
# Persist the rollouts (pickled dataclass objects) for later analysis/backprop
manager.save_rollouts(filepath="rollout_data.pkl", format="pickle")

In [None]:
# Print the headline fields of the first rollout, then one line per step
# showing the first 50 characters of its prompt and of its response.
rollout = manager.rollouts[0]
print(f"\n=== Rollout {rollout.rollout_id} ===")
print(f"Initial prompt: {rollout.initial_prompt}")
print(f"Total steps: {rollout.get_total_steps()}")
print(f"Terminated early: {rollout.terminated_early}")
print(f"Reward: {rollout.reward}")
print("\nSteps:")
for step in rollout.steps:
    print(
        f"  Step {step.step_num}: {step.prompt[:50]}... -> {step.content[:50]}..."
    )

In [None]:
# Access full step data for backprop.
# Each recorded step exposes:
# - step.prompt: the user prompt sent for that step
# - step.thinking: the thinking part of the model output
# - step.content: the final response part of the model output
# - step.history_wo_thinking: conversation history without thinking
# - step.history_w_thinking: conversation history with thinking
# - step.tokens: token IDs (left empty by execute_step; populate during
#   generation if they are needed)

# Example: walk every step of every rollout. This cell only prints the length
# of each response; the training logic itself is left as a placeholder.
for rollout in manager.rollouts:
    print(f"\nRollout {rollout.rollout_id} - Reward: {rollout.reward}")
    for step in rollout.steps:
        # Here you would:
        # 1. Re-tokenize or use stored tokens
        # 2. Compute loss with reward signal
        # 3. Backprop gradients
        print(f"  Step {step.step_num}: content length = {len(step.content)} chars")