# Multi-Turn RL Agent Training on SageMaker - Complete Guide

## Part 1: Dataset Exploration

In [None]:
# # Create new environment with all dependencies
# conda create -n notebook-env python=3.11 -y
# conda activate notebook-env

# # Install exact versions from pyproject.toml
# pip install torch==2.5.1
# pip install transformers==4.49.0
# pip install datasets==4.5.0
# pip install accelerate deepspeed peft wandb rich liger-kernel vllm pyserini

# # Install TRL from specific commit (as in pyproject.toml)
# pip install git+https://github.com/huggingface/trl.git@fc4dae256d924dfbb906af9c2e817bc6fb7b590b


# # AWS/SageMaker
# pip install boto3 sagemaker

# # Install Jupyter
# pip install jupyter ipykernel ipywidgets

# # Register kernel
# python -m ipykernel install --user --name=notebook-env --display-name="Python (notebook-env)"

In [None]:
from datasets import load_dataset
import json

print("Loading TriviaQA dataset...")
dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
dataset = dataset.select(range(1000))

print(f"\nDataset size: {len(dataset)} examples")
print(f"\nDataset features: {list(dataset.features.keys())}")

In [None]:
sample = dataset[0]

print("=" * 80)
print("SAMPLE QUESTION")
print("=" * 80)
print(f"Question: {sample['question']}")
print(f"\nAnswer: {sample['answer']['value']}")
print(f"\nAnswer Aliases: {sample['answer']['aliases'][:5]}...")
print(f"\nNormalized Aliases: {sample['answer']['normalized_aliases'][:5]}...")
print(f"\nTotal aliases: {len(sample['answer']['aliases'])}")

In [None]:
print("\n" + "=" * 80)
print("MORE EXAMPLES")
print("=" * 80)

for i in range(5):
    ex = dataset[i]
    print(f"\n{i+1}. Q: {ex['question']}")
    print(f"   A: {ex['answer']['value']}")

## Part 2: Understanding Multi-Turn Reasoning and Rewards

### Important: Multi-Turn REASONING, Not Conversation

This is **Multi-Turn Reasoning RL**, not multi-turn conversation:
- The dataset only has: Question + Answer (no conversations)
- The agent **learns** to break down answering into multiple reasoning steps
- Each "turn" is an internal reasoning step, not a user-agent exchange

**Example of agent's internal reasoning:**
- **Turn 1**: "I should search for this" → `<search>capital of France</search>` → gets results
- **Turn 2**: "I found the answer" → `<answer>Paris</answer>`

The agent could also do:
- **Turn 1**: Search broadly
- **Turn 2**: Search more specifically based on Turn 1 results
- **Turn 3**: Verify with another search
- **Turn 4**: Provide final answer
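Each of these turns is keyed off an XML-style tag in the model's output. Below is a minimal sketch of how a turn could be parsed. It assumes a turn acts on its first well-formed `<search>` or `<answer>` pair. `parse_action` is a hypothetical helper, not the parser used by the training code.

```python
import re
from typing import Optional, Tuple

# First <search>...</search> or <answer>...</answer> pair; \1 forces matching open/close tags.
TAG_RE = re.compile(r"<(search|answer)>(.*?)</\1>", re.DOTALL)


def parse_action(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (tag, content) for the first well-formed tag, or (None, None) if there is none."""
    match = TAG_RE.search(text)
    if match is None:
        return None, None
    return match.group(1), match.group(2).strip()


print(parse_action("I should search for this. <search>capital of France</search>"))  # ('search', 'capital of France')
print(parse_action("<answer>Paris</answer>"))  # ('answer', 'Paris')
print(parse_action("Paris"))  # (None, None): untagged text is not a valid action
```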

### What is a "Turn"?

A **turn** is one reasoning step by the agent:
- Agent decides on an action (search, calculate, answer)
- Executes the action (if it's a tool)
- Receives feedback
- Gets rewarded for that step
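The steps above can be sketched as a rollout loop. This is an illustration, not the training code. `run_episode`, `toy_policy`, and `toy_search` are hypothetical stand-ins for the model and the search tool. Each per-turn record reuses the `tool`, `tool_result`, and `final_answer` keys that the reward functions in Part 3 read. The reward step (the last bullet) is applied afterwards to the returned trajectory.

```python
import re
from typing import Callable, Dict, List

TAG_RE = re.compile(r"<(search|answer)>(.*?)</\1>", re.DOTALL)


def run_episode(policy: Callable[[List[str]], str],
                search_tool: Callable[[str], str],
                question: str,
                max_turns: int = 4) -> List[Dict]:
    """Roll out one multi-turn episode and return one record per turn."""
    context = [question]  # everything the agent has seen so far
    trajectory = []
    for turn in range(1, max_turns + 1):
        action = policy(context)  # 1. agent decides on an action
        step = {"turn": turn, "action": action, "tool": None}
        match = TAG_RE.search(action)
        if match and match.group(1) == "search":
            step["tool"] = "search"
            try:  # 2. execute the tool
                step["tool_result"] = search_tool(match.group(2).strip())
            except Exception as exc:  # failed calls are recorded, not raised
                step["tool_result"] = f"Error: {exc}"
            context += [action, step["tool_result"]]  # 3. feedback for the next turn
            trajectory.append(step)
            continue
        # Any non-search turn ends the episode; untagged text gives no final answer.
        step["final_answer"] = match.group(2).strip() if match else None
        trajectory.append(step)
        break
    return trajectory


def toy_policy(context: List[str]) -> str:
    # Search first, then answer once a tool result is in the context.
    if len(context) == 1:
        return "<search>capital of France</search>"
    return "<answer>Paris</answer>"


def toy_search(query: str) -> str:
    return "Paris is the capital and most populous city of France."


for step in run_episode(toy_policy, toy_search, "What is the capital of France?"):
    print(step)
```

Recording a failed call as a string starting with `Error:` mirrors the convention that `tool_execution_reward` checks.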

The agent **learns through RL** to:
- Use tools effectively across multiple steps
- Gather information incrementally
- Build up to correct answers through reasoning

**This is NOT given in the dataset** - the agent discovers optimal multi-turn strategies through training!

### Two Types of Rewards:

#### 1. Turn-Level Rewards (Intermediate)
Rewards given **during each turn** when tools are used:
- **Tool Execution Reward**: Did the tool call succeed? (+0.2 per successful tool, 0 if no tools)
- **Answer Presence in Search**: Does the search result contain the answer? (+0.5 yes, 0 no)

**Note**: Turns without tool use (like final answer) get 0 turn-level reward.

#### 2. Outcome Rewards (Final)
Rewards given **at the end** of all turns:
- **Exact Match**: Does the final answer exactly match the ground truth? (+1.0)
- **Answer Presence**: Is the answer present in the final response? (+0.5)
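- **Format Rewards**: Is the XML format correct? (+0.1 for a properly closed `<answer>` or `<search>` tag in the simplified functions used below; the actual value depends on the parser)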

### Total Reward Calculation:
```
Total Reward = Σ(Turn-Level Rewards) + Σ(Outcome Rewards)
```
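As a rough sketch of that sum, the function below scores a whole trajectory using the reward values listed above. It assumes at most one tool call per turn and per-turn records with `tool`, `tool_result`, `action`, and `final_answer` keys. `total_reward` is a hypothetical helper; the real rubric may weight or combine the terms differently.

```python
from typing import Dict, List


def _has_alias(text, aliases) -> bool:
    """Case-insensitive substring check, as in the simplified rubric."""
    text = str(text).lower()
    return any(str(alias).lower() in text for alias in aliases)


def total_reward(trajectory: List[Dict], aliases: List[str]) -> float:
    # Turn-level rewards: only turns that used a tool can earn them.
    turn_level = 0.0
    for step in trajectory:
        if step.get("tool") is None:
            continue
        result = step.get("tool_result", "")
        if not result.startswith("Error:"):
            turn_level += 0.2  # tool executed successfully
        if _has_alias(result, aliases):
            turn_level += 0.5  # answer appears in the search result

    # Outcome rewards: scored once, on the final turn.
    final = trajectory[-1]
    answer = final.get("final_answer")
    outcome = 0.0
    if answer is not None:
        normalized = {str(alias).lower().strip() for alias in aliases}
        if str(answer).lower().strip() in normalized:
            outcome += 1.0  # exact match
        if _has_alias(answer, aliases):
            outcome += 0.5  # answer present in the response
    action = final.get("action", "")
    well_formed = (("<answer>" in action and "</answer>" in action)
                   or ("<search>" in action and "</search>" in action))
    if well_formed:
        outcome += 0.1  # format reward
    return turn_level + outcome


good = [
    {"action": "<search>capital of France</search>", "tool": "search",
     "tool_result": "Paris is the capital of France."},
    {"action": "<answer>Paris</answer>", "tool": None, "final_answer": "Paris"},
]
bad = [
    {"action": "<search>random query</search>", "tool": "search",
     "tool_result": "Nothing relevant here."},
    {"action": "Lyon", "tool": None, "final_answer": "Lyon"},
]
print(round(total_reward(good, ["paris"]), 2))  # prints 2.3 (0.7 turn-level + 1.6 outcome)
print(round(total_reward(bad, ["paris"]), 2))   # prints 0.2 (tool success only)
```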

**Example**: 
- Turn 1 (search): +0.7 (tool success 0.2 + answer found 0.5)
- Turn 2 (answer): +0 (no tool used)
- Outcome: +1.6 (exact match 1.0 + presence 0.5 + format 0.1)
- **Total: ~2.3**

## Part 3: Reward Calculation Examples (Using Real Dataset)

In [None]:
# Simplified reward functions modeled on verifiers/rubrics/triviaqa_rubric.py.
# They illustrate the reward logic; the training code's exact implementations may differ.

def tool_execution_reward(trajectory):
    """Return 0.2 scaled by the fraction of successful tool calls (0.0 if no tools were called)."""
    tool_attempts = 0
    successful_executions = 0
    
    for i, turn in enumerate(trajectory):
        if 'tool' in turn and turn['tool'] is not None:
            tool_attempts += 1
            if 'tool_result' in turn and not turn['tool_result'].startswith('Error:'):
                successful_executions += 1
    
    if tool_attempts == 0:
        return 0.0
    return 0.2 * (successful_executions / tool_attempts)

def exist_answer_in_search_results(search_result, answer_aliases):
    """Check if answer exists in search results. Returns +0.5 if found."""
    if search_result is None:
        return 0.0
    
    search_lower = search_result.lower()
    for alias in answer_aliases:
        if str(alias).lower() in search_lower:
            return 0.5
    return 0.0

def exist_answer_reward(final_answer, answer_aliases):
    """Check if answer exists in final response. Returns +0.5 if found."""
    if final_answer is None:
        return 0.0
    
    answer_lower = str(final_answer).lower()
    for alias in answer_aliases:
        if str(alias).lower() in answer_lower:
            return 0.5
    return 0.0

def exact_match_reward(final_answer, answer_aliases):
    """Check if answer exactly matches. Returns +1.0 if exact match."""
    if final_answer is None:
        return 0.0
    
    answer_lower = str(final_answer).lower().strip()
    for alias in answer_aliases:
        if str(alias).lower().strip() == answer_lower:
            return 1.0
    return 0.0

def format_reward(action):
    """Check if XML format is correct. Returns +0.1 for proper tags."""
    if '<answer>' in action and '</answer>' in action:
        return 0.1
    elif '<search>' in action and '</search>' in action:
        return 0.1
    return 0.0

print("✓ Defined reward functions modeled on the training rubric:")
print("  - tool_execution_reward: +0.2 x fraction of successful tool calls")
print("  - exist_answer_in_search_results: +0.5 if answer in search")
print("  - exist_answer_reward: +0.5 if answer in final response")
print("  - exact_match_reward: +1.0 if exact match")
print("  - format_reward: +0.1 for proper XML tags")
print("\nThese are simplified versions of the actual functions for demonstration.")
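One caveat with these simplified checks: `exist_answer_reward` and `exist_answer_in_search_results` use plain substring matching. A short alias can therefore match inside an unrelated word (for example, an alias `leo` inside `napoleon`) and earn reward it should not. One possible tightening is sketched below. It requires the alias to match on word boundaries. `exist_answer_reward_wb` is a hypothetical variant, not what training uses.

```python
import re


def exist_answer_reward_wb(final_answer, answer_aliases) -> float:
    """Like exist_answer_reward, but the alias must appear as whole words."""
    if final_answer is None:
        return 0.0
    text = str(final_answer).lower()
    for alias in answer_aliases:
        alias = str(alias).lower().strip()
        # \b anchors assume the alias starts and ends with a word character.
        if alias and re.search(r"\b" + re.escape(alias) + r"\b", text):
            return 0.5
    return 0.0


print("leo" in "napoleon")                                             # True: substring check fires
print(exist_answer_reward_wb("Napoleon", ["leo"]))                     # 0.0
print(exist_answer_reward_wb("It was Leo Tolstoy.", ["leo tolstoy"]))  # 0.5
```

Aliases that begin or end with punctuation (such as `$5`) would need a different boundary rule.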

In [None]:
# Use a real example from the dataset
example = dataset[0]
question = example['question']
answer_value = example['answer']['value']
answer_aliases = example['answer']['normalized_aliases']

print("Using real dataset example:")
print(f"Question: {question}")
print(f"Answer: {answer_value}")
print(f"Aliases: {answer_aliases[:3]}...")
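`exact_match_reward` only lowercases and strips the model's answer, but it compares the result against `normalized_aliases`, which are already stored in a normalized form. If that normalization also removed punctuation or articles, an answer like `The Beatles!` would miss an alias stored as `beatles`. The sketch below normalizes the prediction the same way before comparing. It assumes a SQuAD-style rule (lowercase, drop punctuation and the articles a/an/the, collapse whitespace), which may not exactly match how TriviaQA built its aliases. Both helpers are hypothetical.

```python
import re
import string

_PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def normalize_answer(text) -> str:
    """Lowercase, remove punctuation and English articles, collapse whitespace."""
    text = str(text).lower().translate(_PUNCT_TABLE)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match_normalized(final_answer, normalized_aliases) -> float:
    """Return 1.0 if the normalized prediction equals any normalized alias."""
    if final_answer is None:
        return 0.0
    targets = {normalize_answer(alias) for alias in normalized_aliases}
    return 1.0 if normalize_answer(final_answer) in targets else 0.0


print(normalize_answer("  The Beatles! "))                 # beatles
print(exact_match_normalized("The Beatles", ["beatles"]))  # 1.0
```

If you want to try it in the cells that follow, `exact_match_normalized(final_answer, answer_aliases)` takes the same arguments as `exact_match_reward`.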

In [None]:
# Simulate a GOOD agent conversation for this question
conversation = [
    {
        "turn": 1,
        "action": f"<search>{question}</search>",
        "tool_result": f"According to Wikipedia, the answer is {answer_value}. {answer_value} is widely known...",
        "tool_success": True,
    },
    {
        "turn": 2,
        "action": f"<answer>{answer_value}</answer>",
        "final_answer": answer_value,
    }
]

print("\n" + "=" * 80)
print("GOOD EXAMPLE CONVERSATION")
print("=" * 80)
print(f"Question: {question}")
print(f"Ground Truth: {answer_value}")
print()

In [None]:
print("\n" + "=" * 80)
print("TURN-LEVEL REWARDS (Tool Use and Search Results)")
print("=" * 80)

turn_rewards = []

for turn_data in conversation:
    turn_num = turn_data['turn']
    print(f"\nTurn {turn_num}:")
    print(f"  Action: {turn_data['action'][:80]}..." if len(turn_data['action']) > 80 else f"  Action: {turn_data['action']}")
    
    rewards_this_turn = 0.0
    
    # Tool execution reward
    if 'tool_success' in turn_data:
        tool_reward = 0.2 if turn_data['tool_success'] else 0.0
        rewards_this_turn += tool_reward
        print(f"  ✓ Tool Execution Reward: {tool_reward}")
    
    # Answer in search results reward
    if 'tool_result' in turn_data:
        search_reward = exist_answer_in_search_results(turn_data['tool_result'], answer_aliases)
        rewards_this_turn += search_reward
        print(f"  ✓ Answer in Search Results: {search_reward}")
    
    # Format is scored once, as an outcome reward, so it is not added per turn.
    
    print(f"  → Turn Total: {rewards_this_turn}")
    turn_rewards.append(rewards_this_turn)

total_turn_rewards = sum(turn_rewards)
print(f"\n{'='*80}")
print(f"TOTAL TURN-LEVEL REWARDS: {total_turn_rewards}")
print(f"{'='*80}")

In [None]:
print("\n" + "=" * 80)
print("OUTCOME REWARDS (Using Actual Functions)")
print("=" * 80)

final_answer = conversation[-1].get('final_answer', '')
print(f"\nFinal Answer: {final_answer}")
print(f"Ground Truth: {answer_value}")

# Use actual reward functions
exact_match = exact_match_reward(final_answer, answer_aliases)
print(f"\n✓ Exact Match Reward: {exact_match}")

answer_present = exist_answer_reward(final_answer, answer_aliases)
print(f"✓ Answer Presence Reward: {answer_present}")

fmt_reward = format_reward(conversation[-1]['action'])
print(f"✓ Format Reward: {fmt_reward}")

total_outcome_rewards = exact_match + answer_present + fmt_reward
print(f"\n{'='*80}")
print(f"TOTAL OUTCOME REWARDS: {total_outcome_rewards}")
print(f"{'='*80}")

In [None]:
print("\n" + "=" * 80)
print("FINAL REWARD CALCULATION")
print("=" * 80)

print(f"\nTurn-Level Rewards:  {total_turn_rewards}")
print(f"Outcome Rewards:     {total_outcome_rewards}")
print(f"{'─'*40}")
total_reward = total_turn_rewards + total_outcome_rewards
print(f"TOTAL REWARD:        {total_reward}")

print("\n" + "=" * 80)
print("REWARD BREAKDOWN BY TURN")
print("=" * 80)
for i, reward in enumerate(turn_rewards, 1):
    print(f"Turn {i}: {reward}")
print(f"Final (Outcome): {total_outcome_rewards}")

## Part 4: Bad Example (Low Rewards) - Same Question

In [None]:
# Simulate a BAD agent conversation for the same question
bad_conversation = [
    {
        "turn": 1,
        "action": "<search>random unrelated query</search>",
        "tool_result": "Some completely irrelevant information that doesn't contain the answer...",
        "tool_success": True,
    },
    {
        "turn": 2,
        "action": "Wrong Answer",  # No XML tags, wrong answer
        "final_answer": "Wrong Answer",
    }
]

print("=" * 80)
print("BAD EXAMPLE - LOW REWARDS (Same Question)")
print("=" * 80)
print(f"Question: {question}")
print(f"Ground Truth: {answer_value}")

bad_turn_rewards = []

print("\nTurn 1:")
print(f"  Action: {bad_conversation[0]['action']}")
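tool_reward = 0.2 if bad_conversation[0]['tool_success'] else 0.0  # Tool executed successfully
search_reward = exist_answer_in_search_results(bad_conversation[0]['tool_result'], answer_aliases)  # Answer not in results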
print(f"  Tool Execution: {tool_reward}")
print(f"  Answer in Search: {search_reward}")
bad_turn_rewards.append(tool_reward + search_reward)

print("\nTurn 2:")
print(f"  Action: {bad_conversation[1]['action']}")
print(f"  (No tool use this turn)")
bad_turn_rewards.append(0.0)

print("\nOutcome Rewards:")
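bad_final = bad_conversation[-1]['final_answer']
exact_match = exact_match_reward(bad_final, answer_aliases)        # Wrong answer
answer_presence = exist_answer_reward(bad_final, answer_aliases)   # Answer not present
bad_format = format_reward(bad_conversation[-1]['action'])         # No XML tags -> 0.0
print(f"  Exact Match: {exact_match}")
print(f"  Answer Presence: {answer_presence}")
print(f"  Format: {bad_format}")

bad_total = sum(bad_turn_rewards) + exact_match + answer_presence + bad_format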
print(f"\n{'='*80}")
print(f"TOTAL REWARD (BAD): {bad_total}")
print(f"TOTAL REWARD (GOOD): {total_reward}")
print(f"Difference: {total_reward - bad_total}")
print(f"{'='*80}")

## Part 5: Data Preparation

In [None]:
import os

os.makedirs('data', exist_ok=True)

dataset_info = {
    'name': 'trivia_qa',
    'config': 'rc',
    'source': 'mandarjoshi/trivia_qa',
    'split': 'train',
    'num_examples': len(dataset),
    'features': list(dataset.features.keys()),
}

with open('data/dataset_info.json', 'w') as f:
    json.dump(dataset_info, f, indent=2)

print("Dataset info saved to data/dataset_info.json")
print(f"\nDataset will be loaded directly from HuggingFace during training.")
print(f"No S3 upload required!")

## Part 6: SageMaker Training Setup (SDK v2)

In [None]:
# Clean install of SageMaker SDK v2.200.0
!pip uninstall sagemaker sagemaker-core sagemaker-train sagemaker-serve sagemaker-mlops -y -q
!pip install sagemaker==2.200.0 boto3 -q

import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker import get_execution_role
from datetime import datetime

sagemaker_session = sagemaker.Session()
role = get_execution_role()
region = boto3.Session().region_name
bucket = sagemaker_session.default_bucket()
prefix = 'mt-grpo-training'

print(f"SageMaker role: {role}")
print(f"S3 bucket: {bucket}")
print(f"Region: {region}")

## Part 7: Training Configuration

In [None]:
instance_type = 'ml.p4d.24xlarge'  # 8x A100 GPUs
# instance_type = 'ml.g5.24xlarge'  # 4x A10G GPUs - cheaper option; also set num_gpus = 4 and adjust num_generations

instance_count = 1
num_gpus = 8

hyperparameters = {
    'model_name': 'Qwen/Qwen2.5-3B',
    'num_gpus': num_gpus,
    'learning_rate': 1e-6,
    'num_generations': 14,  # Global batch (num_gpus - 1) * per_device_train_batch_size = 7 * 2 = 14 must be divisible by this
    'per_device_train_batch_size': 2,
    'grad_accum_steps': 4,
    'num_iterations': 2,
    'max_steps': 300,
    'beta': 0,
    'trainer': 'mt_grpo',
    'turn_advantage_coef': 1,
}

print(f"Instance type: {instance_type}")
print(f"Number of GPUs: {num_gpus}")
print(f"\nHyperparameters:")
for k, v in hyperparameters.items():
    print(f"  {k}: {v}")
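The comment on `num_generations` encodes a batch-size constraint. Below is a small sanity check for it. It assumes (from the `num_gpus - 1` in that comment) that one GPU is reserved for generation and each remaining GPU holds `per_device_train_batch_size` completions. `check_grpo_batch` is a hypothetical helper, so confirm the exact rule against the trainer you use.

```python
def check_grpo_batch(num_gpus: int, per_device_train_batch_size: int,
                     num_generations: int, reserved_gpus: int = 1) -> int:
    """Return the global batch size, raising if num_generations does not divide it."""
    train_gpus = num_gpus - reserved_gpus
    if train_gpus < 1:
        raise ValueError("need at least one training GPU")
    global_batch = train_gpus * per_device_train_batch_size
    if global_batch % num_generations != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by num_generations={num_generations}"
        )
    return global_batch


print(check_grpo_batch(8, 2, 14))  # 14 (7 training GPUs x 2)
print(check_grpo_batch(4, 2, 6))   # 6 (e.g. a 4-GPU instance: 3 training GPUs x 2)
print([g for g in range(1, 15) if 14 % g == 0])  # valid num_generations for 8 GPUs: [1, 2, 7, 14]
```

In this notebook the call would be `check_grpo_batch(num_gpus, hyperparameters['per_device_train_batch_size'], hyperparameters['num_generations'])`.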

## Part 8: Create Estimator

In [None]:
job_name = f"mt-grpo-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

# Use AWS Deep Learning Container for PyTorch 2.5.1
image_uri = f'763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:2.5.1-gpu-py311'

estimator = PyTorch(
    entry_point='train.py',
    # ../scripts is uploaded with the job; it includes requirements-pinned.txt for faster, conflict-free installation
    source_dir='../scripts',
    role=role,
    instance_type=instance_type,
    instance_count=instance_count,
    image_uri=image_uri,  # Explicit Deep Learning Container URI instead of framework_version/py_version
    hyperparameters=hyperparameters,
    output_path=f's3://{bucket}/{prefix}/output',
    code_location=f's3://{bucket}/{prefix}/code',
    checkpoint_s3_uri=f's3://{bucket}/{prefix}/checkpoints/{job_name}',
    volume_size=1024,
    max_run=24*60*60,
    keep_alive_period_in_seconds=3600,
    environment={
        'NCCL_DEBUG': 'INFO',
        'VLLM_WORKER_MULTIPROC_METHOD': 'spawn',
        'WANDB_API_KEY': '',  # set to your Weights & Biases API key if logging to W&B
    },
    disable_profiler=True,
    debugger_hook_config=False,
    base_job_name='mt-grpo',
)

print(f"Estimator created: {job_name}")
print(f"Using image: {image_uri}")

## Part 9: Launch Training

In [None]:
print(f"Launching training job: {job_name}")
print(f"Instance: {instance_type}")
print(f"GPUs: {num_gpus}")
print(f"\nEstimated duration: 2-4 hours")
print(f"Estimated cost: ~2-4 hours at the {instance_type} hourly rate (see SageMaker pricing for {region})\n")

estimator.fit(wait=True, logs='All', job_name=job_name)

## Part 10: Monitor and Download Results

In [None]:
training_job_name = estimator.latest_training_job.name
print(f"Training job: {training_job_name}")
print(f"Status: {estimator.latest_training_job.describe()['TrainingJobStatus']}")

logs_url = f"https://console.aws.amazon.com/cloudwatch/home?region={region}#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTrainingJobs/log-events/{training_job_name}"
print(f"\nCloudWatch Logs: {logs_url}")
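Because Part 9 calls `fit(wait=True)`, that cell blocks until the job finishes. If you launch with `wait=False` instead, a polling loop like the one below can track status. It is a sketch that accepts any client exposing boto3's `describe_training_job` (such as `boto3.client('sagemaker')`). The helper name and polling interval are assumptions.

```python
import time
from typing import Any

# TrainingJobStatus values that will not change further.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(sm_client: Any, job_name: str,
                          poll_seconds: float = 60.0,
                          timeout_seconds: float = 24 * 3600) -> str:
    """Poll DescribeTrainingJob until the job reaches a terminal status."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        desc = sm_client.describe_training_job(TrainingJobName=job_name)
        status = desc["TrainingJobStatus"]
        print(f"{job_name}: {status} / {desc.get('SecondaryStatus', 'n/a')}")
        if status in TERMINAL_STATUSES:
            if status == "Failed":
                print("FailureReason:", desc.get("FailureReason", "unknown"))
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{job_name} still {status} after {timeout_seconds} s")
        time.sleep(poll_seconds)


# Example (requires AWS credentials):
# import boto3
# final_status = wait_for_training_job(boto3.client("sagemaker", region_name=region), training_job_name)
```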

In [None]:
model_data = estimator.model_data
print(f"Model artifacts: {model_data}")

local_model_dir = './trained_model'
os.makedirs(local_model_dir, exist_ok=True)

!aws s3 cp {model_data} {local_model_dir}/model.tar.gz
!tar -xzf {local_model_dir}/model.tar.gz -C {local_model_dir}

print(f"\nModel downloaded to: {local_model_dir}")
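The shell magics above can also be replaced with plain Python, which is handy outside Jupyter. The sketch below covers the extraction step and refuses archive members whose paths would escape the target directory. `safe_extract` is a hypothetical helper. It uses `tarfile`'s `data` extraction filter only when the running Python provides it.

```python
import os
import tarfile
from typing import List


def safe_extract(archive_path: str, dest_dir: str) -> List[str]:
    """Extract a .tar.gz into dest_dir, rejecting members that would land outside it."""
    dest_root = os.path.realpath(dest_dir)
    os.makedirs(dest_root, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise ValueError(f"unsafe path in archive: {member.name}")
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: the "data" filter also rejects unsafe links and special files.
            tar.extractall(dest_root, members=members, filter="data")
        else:
            tar.extractall(dest_root, members=members)
    return [member.name for member in members]


# Example, after downloading model.tar.gz as above:
# print(safe_extract(f"{local_model_dir}/model.tar.gz", local_model_dir))
```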

## Summary

### What We Learned:

1. **Dataset**: TriviaQA with questions and multiple answer aliases
2. **Turn-Level Rewards**: Intermediate feedback during reasoning
   - Tool execution success (+0.2 per successful call, 0 otherwise)
   - Answer presence in search results (+0.5/0)
3. **Outcome Rewards**: Final episode rewards
   - Exact match with ground truth (+1.0)
   - Answer presence in final response (+0.5)
   - Format correctness (+0.1/0)
4. **Total Reward**: Sum of all turn and outcome rewards
5. **Good vs Bad**: The correct trajectory scored roughly 2 points higher than the incorrect one

### Training Configuration:
- **Instance**: ml.g5.24xlarge (4 GPUs) or ml.p4d.24xlarge (8 GPUs)
- **Model**: Qwen2.5-3B
- **Method**: MT-GRPO (Multi-Turn GRPO with turn-level credit assignment)
- **Cost**: ~$10-40/hour depending on instance
- **Duration**: 2-4 hours for 300 steps
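For intuition on the method: GRPO scores each rollout relative to the other `num_generations` rollouts for the same question by standardizing rewards within that group. The first function below is that standard group-relative advantage. It uses the population standard deviation; implementations differ on this detail. The second function is a hedged reading of turn-level credit assignment for a two-turn episode. It assumes `turn_advantage_coef` scales a turn-level advantage that is added to the outcome advantage for first-turn tokens. Both helper names are hypothetical, so check the MT-GRPO trainer for the actual formula.

```python
from statistics import mean, pstdev
from typing import List, Tuple


def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Standardize rewards within one group of rollouts for the same prompt."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def two_turn_advantages(turn_rewards: List[float], outcome_rewards: List[float],
                        turn_advantage_coef: float = 1.0) -> List[Tuple[float, float]]:
    """ASSUMED combination: (first-turn advantage, final-turn advantage) per rollout."""
    a_turn = group_advantages(turn_rewards)
    a_outcome = group_advantages(outcome_rewards)
    return [(turn_advantage_coef * t + o, o) for t, o in zip(a_turn, a_outcome)]


# Illustrative group of two rollouts (a good one and a bad one):
turn_r = [0.7, 0.2]      # turn-level rewards
outcome_r = [1.6, 0.0]   # outcome rewards
for first, final in two_turn_advantages(turn_r, outcome_r):
    print(f"first-turn adv {first:+.2f}, final-turn adv {final:+.2f}")
```

The small `eps` keeps the division finite when every rollout in a group earns the same reward. In that case every advantage is zero and the group contributes no learning signal.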