# 🚀 Local P5 Multi-Turn GRPO Training

This notebook runs Multi-Turn GRPO training locally on your P5 EC2 instance.

## Architecture

- **GPU 7**: vLLM server for fast inference during rollouts
- **GPUs 0-6**: Distributed training with DeepSpeed ZeRO-3
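A minimal sketch of this GPU split. It assumes the launch script pins processes with `CUDA_VISIBLE_DEVICES` and reserves the highest-numbered GPU for vLLM; `split_gpus` is a hypothetical helper, so check `local_mt_grpo_train.sh` for the mechanism it actually uses:

```python
def split_gpus(num_gpus: int) -> tuple[str, str]:
    """Return (training_devices, vllm_device) as CUDA_VISIBLE_DEVICES strings.

    Hypothetical helper: reserves the last GPU for the vLLM server and
    gives every other GPU to the distributed training processes.
    """
    if num_gpus < 2:
        raise ValueError("need at least 2 GPUs: 1 for vLLM and 1+ for training")
    training = ",".join(str(i) for i in range(num_gpus - 1))
    vllm = str(num_gpus - 1)
    return training, vllm


train_devices, vllm_device = split_gpus(8)
print(f"CUDA_VISIBLE_DEVICES={train_devices}  # training")
print(f"CUDA_VISIBLE_DEVICES={vllm_device}  # vLLM server")
```

Putting vLLM on its own device keeps inference memory separate from the ZeRO-3 training shards, at the cost of one fewer training GPU.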

## Prerequisites

1. Java 21 installed (for Pyserini)
2. 8 GPUs available
3. Docker container with PyTorch 2.8.0

## Step 1: Install Dependencies

In [None]:
# Install required packages
!pip install -q -r requirements_local.txt
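After installing, it helps to confirm which versions of the key packages actually ended up in the environment. This sketch uses only the standard library's `importlib.metadata`. The package list is an assumption drawn from the rest of this notebook, so extend it to match `requirements_local.txt`; `installed_versions` is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Hypothetical helper: map each distribution name to its installed version, or None."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed in this environment
    return result


# Package list is an assumption; adjust to requirements_local.txt
for name, ver in installed_versions(["torch", "vllm", "trl", "transformers", "deepspeed", "pyserini"]).items():
    print(f"{name:>12}: {ver or 'NOT INSTALLED'}")
```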

## Step 2: Verify Java Installation

In [None]:
import os

# Set Java environment
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-21-openjdk-amd64'
os.environ['PATH'] = f"{os.environ['JAVA_HOME']}/bin:{os.environ['PATH']}"

# Verify Java
!java -version
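`java -version` writes to stderr, so a wrong version is easy to miss in notebook output. A small sketch that checks the major version programmatically. The parsing assumes the usual `version "21.0.x"` line printed by OpenJDK (and the legacy `"1.8.0_x"` form), and both helpers are hypothetical:

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_java_major(version_output: str) -> Optional[int]:
    """Hypothetical helper: extract the major version from `java -version` output.

    Handles both the modern scheme ("21.0.2") and the legacy one ("1.8.0_392").
    """
    match = re.search(r'version "(\d+)(?:\.(\d+))?', version_output)
    if match is None:
        return None
    major = int(match.group(1))
    if major == 1 and match.group(2):  # legacy "1.x" scheme, e.g. Java 8
        major = int(match.group(2))
    return major


def java_major_version() -> Optional[int]:
    """Hypothetical helper: run `java -version` (output goes to stderr) and parse it."""
    if shutil.which("java") is None:
        return None
    proc = subprocess.run(["java", "-version"], capture_output=True, text=True)
    return parse_java_major(proc.stderr + proc.stdout)


major = java_major_version()
print(f"Java major version: {major}")
if major is None or major < 21:
    print("⚠ Java 21 is listed as a prerequisite for Pyserini; check JAVA_HOME and PATH")
```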

## Step 3: Check GPU Availability

In [None]:
!nvidia-smi --list-gpus
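To turn this listing into a pass/fail check, the lines can be counted. A sketch assuming the usual one-line-per-device format (`GPU <n>: <name> (UUID: ...)`); `count_gpus` and `REQUIRED_GPUS` are hypothetical names, with the required count taken from the architecture above:

```python
import shutil
import subprocess


def count_gpus(list_gpus_output: str) -> int:
    """Hypothetical helper: count "GPU <n>: ..." lines in `nvidia-smi --list-gpus` output."""
    return sum(1 for line in list_gpus_output.splitlines() if line.startswith("GPU "))


REQUIRED_GPUS = 8  # 7 for training + 1 for vLLM, per the architecture above

if shutil.which("nvidia-smi"):
    out = subprocess.run(["nvidia-smi", "--list-gpus"], capture_output=True, text=True).stdout
    found = count_gpus(out)
    status = "✓" if found >= REQUIRED_GPUS else "✗"
    print(f"{status} found {found} GPU(s), need {REQUIRED_GPUS}")
else:
    print("✗ nvidia-smi not found on PATH")
```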

## Step 4: Pre-download Pyserini Index (Optional but Recommended)

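This downloads the 10GB prebuilt Wikipedia index (`wikipedia-kilt-doc`) that the search tool queries during tool calling. Downloading it once up front avoids a long pause the first time training needs it.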

In [None]:
from pyserini.search.lucene import LuceneSearcher

print("Downloading Pyserini Wikipedia index (10GB)...")
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
print("✓ Pyserini index ready")
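Once the index is loaded, a quick query confirms that retrieval works before any GPU time is spent. A minimal sketch: `search_snippets` is a hypothetical helper that relies on `LuceneSearcher.search(query, k=...)` returning hits with `.docid` and `.score`, and on `searcher.doc(docid).raw()` returning the stored document. The exact layout of `raw()` for this index is not assumed, so the snippet just shows the first few hundred characters:

```python
def search_snippets(searcher, query: str, k: int = 3, max_chars: int = 300):
    """Hypothetical helper: run a query and return (docid, score, snippet) tuples.

    `searcher` is expected to behave like pyserini's LuceneSearcher.
    """
    results = []
    for hit in searcher.search(query, k=k):
        doc = searcher.doc(hit.docid)
        raw = doc.raw() if doc is not None else ""
        snippet = " ".join(raw.split())[:max_chars]  # collapse whitespace, then truncate
        results.append((hit.docid, round(hit.score, 3), snippet))
    return results
```

With the `searcher` from the cell above, calling `search_snippets(searcher, "Pride and Prejudice author")` and printing each tuple is enough to see whether sensible Wikipedia pages come back.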

## Step 5: Configure Training

Choose your model and training configuration:

In [None]:
# Available configs
!ls -1 hf_recipes/Qwen/

In [None]:
# Set your configuration
CONFIG_FILE = "hf_recipes/Qwen/Qwen3-0.6B--mt-grpo.yaml"
NUM_GPUS = 8  # Total GPUs (7 for training, 1 for vLLM)
VLLM_PORT = 8000

print(f"Configuration: {CONFIG_FILE}")
print(f"Total GPUs: {NUM_GPUS}")
print(f"Training GPUs: 0-{NUM_GPUS-2}")
print(f"vLLM GPU: {NUM_GPUS-1}")

## Step 6: View Training Configuration

In [None]:
import yaml

with open(CONFIG_FILE, 'r') as f:
    config = yaml.safe_load(f)
    
print("Training Configuration:")
print(f"  Model: {config.get('model_name_or_path', 'N/A')}")
print(f"  Max steps: {config.get('max_steps', 'N/A')}")
print(f"  Learning rate: {config.get('learning_rate', 'N/A')}")
print(f"  Batch size: {config.get('per_device_train_batch_size', 'N/A')}")
print(f"  Num generations: {config.get('num_generations', 'N/A')}")
print(f"  Max env steps: {config.get('max_env_steps', 'N/A')}")
print(f"  Turn advantage coef: {config.get('turn_advantage_coef', 'N/A')}")
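GRPO samples `num_generations` completions per prompt and computes advantages within each group, so the completions in a batch must split evenly into groups. Assuming this trainer keeps the constraint from TRL's `GRPOTrainer` (training processes × `per_device_train_batch_size` divisible by `num_generations`), a pre-flight check can be sketched as below. The default of 7 training processes comes from the architecture above; newer TRL versions phrase the rule via `generation_batch_size`, so treat this as a rough check, and `check_grpo_batch` as a hypothetical helper:

```python
def check_grpo_batch(per_device_batch: int, num_generations: int, num_train_processes: int = 7) -> int:
    """Hypothetical helper: return the global batch size, or raise if it can't form whole groups.

    Assumes the TRL-style rule: (processes x per-device batch) % num_generations == 0.
    """
    global_batch = num_train_processes * per_device_batch
    if global_batch % num_generations != 0:
        valid = [n for n in range(2, global_batch + 1) if global_batch % n == 0]
        raise ValueError(
            f"global batch {global_batch} is not divisible by num_generations="
            f"{num_generations}; valid values: {valid}"
        )
    return global_batch


# Example: 7 processes x 8 completions per device = 56 completions, i.e. 7 groups of 8
print(check_grpo_batch(per_device_batch=8, num_generations=8))
```

In the notebook, pass `config['per_device_train_batch_size']` and `config['num_generations']` from the cell above. With 7 training processes, the per-device batch usually has to be a multiple of `num_generations` for the check to pass.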

## Step 7: Launch Training

### Option A: Run in Notebook (Blocking)

In [None]:
# This will block until training completes
!bash local_mt_grpo_train.sh --config {CONFIG_FILE} --num_process {NUM_GPUS} --vllm_port {VLLM_PORT}

### Option B: Run in Background (Non-blocking)

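Alternatively, run the script from a terminal with `nohup` so training keeps running if the notebook kernel or SSH session disconnects: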

```bash
cd /app/mt-grpo/local_training
nohup bash local_mt_grpo_train.sh \
  --config hf_recipes/Qwen/Qwen3-0.6B--mt-grpo.yaml \
  --num_process 8 \
  --vllm_port 8000 \
  > training.log 2>&1 &

# Monitor progress
tail -f training.log
```
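To stay in the notebook without blocking the kernel, a similar background launch can be sketched with `subprocess.Popen`. This is not part of the provided scripts: `build_train_command` and `launch_in_background` are hypothetical helpers, and only the three flags used above are passed:

```python
import shlex
import subprocess


def build_train_command(config_file: str, num_process: int, vllm_port: int) -> list[str]:
    """Hypothetical helper: the Option A invocation as an argument list."""
    return [
        "bash", "local_mt_grpo_train.sh",
        "--config", config_file,
        "--num_process", str(num_process),
        "--vllm_port", str(vllm_port),
    ]


def launch_in_background(cmd: list[str], log_path: str = "training.log") -> subprocess.Popen:
    """Hypothetical helper: start `cmd` in its own session, with stdout and stderr sent to `log_path`."""
    with open(log_path, "w") as log_file:
        # The child keeps its own copy of the file descriptor, so closing ours is safe.
        return subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, start_new_session=True)


cmd = build_train_command("hf_recipes/Qwen/Qwen3-0.6B--mt-grpo.yaml", 8, 8000)
print(shlex.join(cmd))
# proc = launch_in_background(cmd)  # proc.poll() returns None while training is still running
```

`start_new_session=True` keeps a kernel interrupt from being delivered to the training process group.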

## Step 8: Monitor Training

### Check vLLM Server Status

In [None]:
import requests

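try:
    response = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=5)
    if response.ok:
        print(f"✓ vLLM server is running (status: {response.status_code})")
    else:
        print(f"✗ vLLM server responded but is not healthy (status: {response.status_code})")
except requests.RequestException as e:
    print(f"✗ vLLM server not reachable: {e}")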
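The server may take a while to load the model before `/health` responds, so a single check can fail even when startup is fine. A minimal polling sketch using only the standard library; `wait_for_vllm` is a hypothetical helper and the 600 s default is an arbitrary assumption:

```python
import time
import urllib.request


def wait_for_vllm(url: str, timeout_s: float = 600.0, interval_s: float = 5.0) -> bool:
    """Hypothetical helper: poll `url` until it returns HTTP 200 or `timeout_s` elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval_s) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            pass  # connection refused, timeout, or HTTP error: server not ready yet
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Usage in this notebook: wait_for_vllm(f"http://localhost:{VLLM_PORT}/health")
```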

### View vLLM Server Logs

In [None]:
!tail -50 vllm_server.log
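When the log is long, it can be quicker to pull out only suspicious lines. A sketch: the pattern list is a guess at common failure messages (Python tracebacks, CUDA out-of-memory, port conflicts), not an exhaustive list of what vLLM logs, and `find_problems` is a hypothetical helper:

```python
import re
from collections import deque

# Assumed patterns for common failures; extend as needed.
PROBLEM_PATTERNS = re.compile(r"error|traceback|out of memory|address already in use", re.IGNORECASE)


def find_problems(path: str, last_n: int = 2000) -> list[str]:
    """Hypothetical helper: return matching lines among the last `last_n` lines of a log."""
    try:
        with open(path, errors="replace") as f:
            tail = deque(f, maxlen=last_n)  # keeps only the final last_n lines
    except FileNotFoundError:
        return []
    return [line.rstrip("\n") for line in tail if PROBLEM_PATTERNS.search(line)]


for line in find_problems("vllm_server.log"):
    print(line)
```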

### Monitor GPU Usage

In [None]:
!nvidia-smi
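For a compact per-GPU view, which makes it easy to compare the vLLM GPU against the training GPUs, `nvidia-smi`'s CSV query mode can be parsed. A sketch assuming every queried field is numeric (some GPUs report `[N/A]` for utilization, which this would not handle); `parse_gpu_query` is a hypothetical helper:

```python
import shutil
import subprocess

QUERY = ["nvidia-smi", "--query-gpu=index,memory.used,memory.total,utilization.gpu",
         "--format=csv,noheader,nounits"]


def parse_gpu_query(csv_text: str) -> list[dict]:
    """Hypothetical helper: parse the CSV rows produced by QUERY (memory is in MiB)."""
    rows = []
    for line in csv_text.strip().splitlines():
        index, used, total, util = (field.strip() for field in line.split(","))
        rows.append({"gpu": int(index), "mem_used_mib": int(used),
                     "mem_total_mib": int(total), "util_pct": int(util)})
    return rows


if shutil.which("nvidia-smi"):
    proc = subprocess.run(QUERY, capture_output=True, text=True)
    if proc.returncode == 0:
        for r in parse_gpu_query(proc.stdout):
            print(f"GPU {r['gpu']}: {r['mem_used_mib']}/{r['mem_total_mib']} MiB, {r['util_pct']}% util")
```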

### View Training Logs (if running in background)

In [None]:
# If you ran training in background
!tail -100 training.log

## Step 9: Check Training Output

Training checkpoints and logs will be saved according to your config file's `output_dir` setting.

In [None]:
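# List outputs under the config's output_dir (falls back to "outputs" if unset)
import glob
import os

output_root = config.get('output_dir', 'outputs')
output_dirs = glob.glob(os.path.join(output_root, "*"))
if output_dirs:
    print(f"Training outputs in {output_root}:")
    for d in sorted(output_dirs):
        print(f"  {d}")
else:
    print(f"No outputs found in {output_root} yet")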
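Assuming the trainer follows the Hugging Face `Trainer` convention of writing `checkpoint-<step>` subdirectories into `output_dir`, the newest checkpoint can be found by sorting on the step number (plain string sorting would put `checkpoint-900` after `checkpoint-1000`). `transformers.trainer_utils.get_last_checkpoint` serves a similar purpose; this standard-library sketch avoids the import, and `latest_checkpoint` is a hypothetical helper:

```python
import os
import re
from typing import Optional

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Hypothetical helper: return the checkpoint-<step> subdirectory with the highest step."""
    if not os.path.isdir(output_dir):
        return None
    steps = []
    for name in os.listdir(output_dir):
        match = _CKPT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append((int(match.group(1)), name))
    if not steps:
        return None
    return os.path.join(output_dir, max(steps)[1])  # numeric, not lexicographic, maximum


# Usage in this notebook: latest_checkpoint(config.get('output_dir', 'outputs'))
```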

## Troubleshooting

### Common Issues

1. **Java not found**: Install OpenJDK 21
   ```bash
   apt-get update && apt-get install -y openjdk-21-jdk
   ```

2. **vLLM server fails to start**: Check GPU availability and port conflicts
   ```bash
   lsof -i :8000  # Check if port is in use
   ```

3. **Out of memory**: Reduce `per_device_train_batch_size` or `num_generations` in the config file

4. **Import errors**: Ensure all dependencies are installed with correct versions
   ```bash
   pip list | grep -E "vllm|trl|transformers"
   ```
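If `lsof` is not available in the container, the port conflict from item 2 can be checked in pure Python. This sketch only detects a listener reachable on the loopback interface of this machine; `port_in_use` is a hypothetical helper:

```python
import socket


def port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    """Hypothetical helper: True if something accepts TCP connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1.0)
        return sock.connect_ex((host, port)) == 0  # 0 means the connection succeeded


print(f"Port 8000 in use: {port_in_use(8000)}")
```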

## Cleanup

Stop the vLLM server if it is still running:

In [None]:
# Find and kill vLLM process
!pkill -f "vllm.entrypoints.openai.api_server"