# Parallelism Configuration Test

Run `math_rl` training under several parallelism configurations — tensor (TP), context (CP), and data (DP) parallelism — and compare the resulting metrics.

## Contents
1. [Configuration](#configuration)
2. [Helper Functions](#helper-functions)
3. [Config 1: DP=4, TP=1, CP=1 (Baseline)](#config-1-dp4-tp1-cp1-baseline)
4. [Config 2: DP=2, TP=2, CP=1](#config-2-dp2-tp2-cp1)
5. [Config 3: DP=1, TP=2, CP=2](#config-3-dp1-tp2-cp2)
6. [Results Comparison](#results-comparison)
7. [Quick Reference](#quick-reference)

## Configuration

In [None]:
import os
import subprocess
import json
import time
import requests
from pathlib import Path

# Training Configuration
MODEL_NAME = "meta-llama/Llama-3.2-1B"
GROUP_SIZE = 4
GROUPS_PER_BATCH = 100
LEARNING_RATE = 1e-4
N_BATCHES = 10

# Wandb Configuration
WANDB_PROJECT = "parallelism-test"
# Credentials must not live in source: read the key from the environment
# (export WANDB_API_KEY before starting the notebook). The key that used to be
# hard-coded here was committed in plain text and should be revoked.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "")

# API Configuration
API_BASE_URL = "http://localhost:8000"
API_KEY = "slime-dev-key"

# Paths
OUTPUT_DIR = Path("/tmp/parallelism_test")
LOG_FILE = Path("/data/logs/tinkercloud.log")
OPENTINKER_DIR = Path("/root/gavin/tinkercloud")

# Environment: picked up by the training subprocesses launched later.
os.environ["TINKER_API_KEY"] = API_KEY
os.environ["TINKER_BASE_URL"] = API_BASE_URL
if not WANDB_API_KEY:
    print("WARNING: WANDB_API_KEY is not set in the environment; wandb logging may not work.")

# Prepend the project checkouts to PYTHONPATH. The existing value is appended
# only when non-empty, so no empty trailing entry is created.
_extra_pythonpath = [str(OPENTINKER_DIR), "/root/Megatron-LM", "/root/gavin/miles"]
_existing_pythonpath = os.environ.get("PYTHONPATH", "")
if _existing_pythonpath:
    _extra_pythonpath.append(_existing_pythonpath)
os.environ["PYTHONPATH"] = os.pathsep.join(_extra_pythonpath)

# Create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Group Size: {GROUP_SIZE}")
print(f"  Groups per Batch: {GROUPS_PER_BATCH}")
print(f"  N Batches: {N_BATCHES}")
print(f"  Wandb Project: {WANDB_PROJECT}")
print(f"  Output Dir: {OUTPUT_DIR}")

## Helper Functions

In [None]:
def run_cmd(cmd, check=True, capture=True, cwd=None):
    """Run a shell command and return its output.

    The command string is executed through the shell (``shell=True``), so
    ``&&``, redirections and background jobs work. Only pass trusted strings.

    Args:
        cmd: Shell command line to execute.
        check: When True, print a diagnostic if the command exits non-zero.
            This never raises; callers must inspect the returned exit code.
        capture: When True, capture stdout/stderr as text and return them.
            When False, output goes straight to this process's stdout/stderr.
        cwd: Optional working directory for the command.

    Returns:
        ``(stdout, stderr, returncode)`` with stdout/stderr stripped when
        ``capture`` is True, otherwise ``(None, None, returncode)``.
    """
    result = subprocess.run(cmd, shell=True, capture_output=capture, text=True, cwd=cwd)
    if check and result.returncode != 0:
        if capture:
            print(f"Error: {result.stderr}")
        else:
            # stderr was not captured, so report the exit code; otherwise a
            # failure with check=True would go unnoticed.
            print(f"Error: command exited with code {result.returncode}")
    if capture:
        return result.stdout.strip(), result.stderr.strip(), result.returncode
    return None, None, result.returncode


def wait_for_api(timeout=60):
    """Poll the API health endpoint until it reports healthy or time runs out.

    Sends ``GET {API_BASE_URL}/health`` with the ``X-API-Key`` header, using a
    5 s request timeout, and retries every 2 seconds.

    Args:
        timeout: Total number of seconds to keep polling.

    Returns:
        True as soon as the endpoint answers HTTP 200 with a JSON object whose
        ``"status"`` is ``"healthy"``; False if that never happens in time.
    """
    # Monotonic clock: a wall-clock adjustment cannot shorten or extend the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            resp = requests.get(
                f"{API_BASE_URL}/health",
                headers={"X-API-Key": API_KEY},
                timeout=5,
            )
            body = resp.json() if resp.status_code == 200 else None
        except (requests.RequestException, ValueError):
            # Server not reachable yet, or a non-JSON reply while it boots.
            # A narrow catch (instead of a bare ``except:``) keeps
            # KeyboardInterrupt working, so the notebook cell can be stopped.
            body = None
        if isinstance(body, dict) and body.get("status") == "healthy":
            return True
        time.sleep(2)
    return False


def restart_server(tp: int, cp: int, startup_timeout: int = 30):
    """Restart the tinkercloud API server with a new parallelism config.

    Steps:
      1. Force-kill any process matching ``uvicorn.*training``.
      2. Truncate ``LOG_FILE``.
      3. Launch ``training.api:app`` in the background from ``OPENTINKER_DIR``
         with ``SLIME_DEFAULT_TP`` / ``SLIME_DEFAULT_CP`` set.
      4. Wait for the health check to pass.
      5. Run the session-cleanup script.

    Args:
        tp: Tensor-parallel size, exported as ``SLIME_DEFAULT_TP``.
        cp: Context-parallel size, exported as ``SLIME_DEFAULT_CP``.
        startup_timeout: Seconds to wait for the health check (default 30).

    Raises:
        RuntimeError: If the server is not healthy within ``startup_timeout``.
            The tail of the server log is printed first.
    """
    print(f"\n{'='*60}")
    print(f"Restarting server with TP={tp}, CP={cp}")
    print(f"{'='*60}")

    # Kill existing server (SIGKILL, no graceful shutdown). "|| true" keeps
    # the command from failing when nothing matches.
    run_cmd("pkill -9 -f 'uvicorn.*training' 2>/dev/null || true", check=False)
    # Brief pause so the killed process can release its resources (e.g. the port).
    time.sleep(3)

    # Truncate the log so any failure output below shows only this start attempt.
    # The parent directory may not exist on a fresh machine.
    LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
    LOG_FILE.write_text("")

    # Start the server in the background. NOTE: the port here must match
    # API_BASE_URL, which wait_for_api() polls.
    cmd = f"""
    cd {OPENTINKER_DIR} && \
    SLIME_DEFAULT_TP={tp} SLIME_DEFAULT_CP={cp} ALLOW_PARTIAL_BATCHES=true \
    nohup python3 -m uvicorn training.api:app --host 0.0.0.0 --port 8000 \
    >> {LOG_FILE} 2>&1 &
    """
    run_cmd(cmd, check=False)

    print("Waiting for server to start...")
    if not wait_for_api(timeout=startup_timeout):
        print("ERROR: Server failed to start")
        # errors="replace": a crash log with undecodable bytes should still print.
        print(LOG_FILE.read_text(errors="replace")[-2000:])
        raise RuntimeError("Server failed to start")
    print("Server is healthy")

    # Cleanup existing sessions (best effort: script errors are ignored).
    print("Cleaning up existing sessions...")
    run_cmd("python /root/gavin/tinker_gmi/tests_integration/cleanup_test_env.py 2>/dev/null || true", check=False)


def run_training(config_name: str, tp: int, cp: int, dp: int):
    """Launch one math_rl training run against the currently running server.

    ``tp``, ``cp`` and ``dp`` are only echoed in the banner. The server was
    already started with TP/CP by ``restart_server``, and none of the three
    values is passed to the training command below.

    Args:
        config_name: Run label. Also the log sub-directory under OUTPUT_DIR
            and the prefix of the wandb run name.
        tp: Tensor-parallel size (display only).
        cp: Context-parallel size (display only).
        dp: Data-parallel size (display only).

    Returns:
        ``OUTPUT_DIR / config_name``, returned even if the training command
        exited with an error.
    """
    import shutil
    from datetime import datetime

    run_dir = OUTPUT_DIR / config_name
    stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    run_label = f"{config_name}-{stamp}"
    rule = '=' * 60

    print(f"\n{rule}")
    print(f"Running training: {config_name} ({N_BATCHES} batches)")
    print(f"  TP={tp}, CP={cp}, DP={dp}")
    print(f"  Log path: {run_dir}")
    print(f"  Wandb name: {run_label}")
    print(f"{rule}\n")

    # Start from an empty log directory for this config.
    if run_dir.exists():
        shutil.rmtree(run_dir)

    log_path = run_dir
    wandb_name = run_label
    cmd = f"""
    python -m tinker_cookbook.recipes.math_rl.train \
        model_name="{MODEL_NAME}" \
        group_size={GROUP_SIZE} \
        groups_per_batch={GROUPS_PER_BATCH} \
        learning_rate={LEARNING_RATE} \
        n_batches={N_BATCHES} \
        log_path="{log_path}" \
        wandb_project="{WANDB_PROJECT}" \
        wandb_name="{wandb_name}" \
        behavior_if_log_dir_exists=delete
    """

    out, err, exit_code = run_cmd(cmd, check=False)
    if out:
        print(out[-3000:])  # tail only: full training output can be very long
    if exit_code:
        print(f"\nTraining failed with exit code {exit_code}")
        if err:
            print(err[-1000:])

    return run_dir


def get_final_metrics(log_path: Path) -> dict:
    """Return the last metrics record written by a training run.

    Reads ``log_path / "metrics.jsonl"`` (one JSON object per line) and
    returns the last line that parses as a JSON object. Scanning backwards
    skips blank lines and a truncated final line, e.g. from a run killed
    mid-write. In that case the previous complete record is returned; its
    ``progress/batch`` value shows which step it is.

    Args:
        log_path: Run directory (``Path`` or ``str``) that contains metrics.jsonl.

    Returns:
        The final metrics dict, or ``{}`` if the file is missing, empty, or
        has no parseable JSON object.
    """
    metrics_file = Path(log_path) / "metrics.jsonl"
    if not metrics_file.exists():
        return {}

    # splitlines() returns [] for an empty file ("".split("\n") would give [""]).
    lines = metrics_file.read_text(encoding="utf-8").splitlines()
    for raw in reversed(lines):
        line = raw.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(record, dict):
            return record
    return {}


def print_metrics(config_name: str, metrics: dict):
    """Print the key metrics of one run as a short, human-readable block.

    Reads ``progress/batch``, ``env/all/correct``, ``env/all/reward/total``
    and ``time/train`` from ``metrics``. Missing or non-numeric values are
    shown as ``N/A``, matching the comparison table. Printing them as 0 would
    make a failed run look like one with zero accuracy.

    Args:
        config_name: Label printed in the header line.
        metrics: Metrics record, e.g. as returned by ``get_final_metrics``.
    """
    def fmt(value, spec, suffix=""):
        # bool is an int subclass but is not a meaningful metric value here.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:{spec}}{suffix}"
        return "N/A"

    step = metrics.get('progress/batch', '-')
    correct = fmt(metrics.get('env/all/correct'), '.3f')
    reward = fmt(metrics.get('env/all/reward/total'), '.3f')
    train_time = fmt(metrics.get('time/train'), '.1f', 's')

    print(f"\nFinal metrics for {config_name}:")
    print(f"  Step: {step}")
    print(f"  Correct: {correct}")
    print(f"  Reward: {reward}")
    print(f"  Train Time: {train_time}")


# Confirms in the notebook output that this cell ran and the helpers above are defined.
print("Helper functions defined")

---
## Config 1: DP=4, TP=1, CP=1 (Baseline)

Pure data parallelism: each GPU holds a full copy of the model and processes a different subset of samples. Gradients are then combined across GPUs with an all-reduce.

In [None]:
# Baseline (config 1): all 4 GPUs used for data parallelism, no tensor or context sharding.
CONFIG_1 = "dp4-tp1-cp1"
TP_1 = 1
CP_1 = 1
DP_1 = 4

restart_server(cp=CP_1, tp=TP_1)

In [None]:
# Train against the freshly restarted server, then summarize the last logged step.
log_path_1 = run_training(CONFIG_1, tp=TP_1, cp=CP_1, dp=DP_1)
metrics_1 = get_final_metrics(log_path_1)
print_metrics(CONFIG_1, metrics=metrics_1)

---
## Config 2: DP=2, TP=2, CP=1

Tensor parallelism splits the model weights across pairs of GPUs (TP=2). The two pairs then train with 2-way data parallelism.

In [None]:
# Config 2: weights sharded across GPU pairs (TP=2); the two pairs run data-parallel (DP=2).
CONFIG_2 = "dp2-tp2-cp1"
TP_2 = 2
CP_2 = 1
DP_2 = 2

restart_server(cp=CP_2, tp=TP_2)

In [None]:
# Train against the freshly restarted server, then summarize the last logged step.
log_path_2 = run_training(CONFIG_2, tp=TP_2, cp=CP_2, dp=DP_2)
metrics_2 = get_final_metrics(log_path_2)
print_metrics(CONFIG_2, metrics=metrics_2)

---
## Config 3: DP=1, TP=2, CP=2

Tensor + context parallelism with no data parallelism: TP=2 splits the model weights, and CP=2 splits each sequence along its length.

In [None]:
# Config 3: weights sharded 2 ways (TP=2) and sequences split 2 ways (CP=2); no data parallelism.
CONFIG_3 = "dp1-tp2-cp2"
TP_3 = 2
CP_3 = 2
DP_3 = 1

restart_server(cp=CP_3, tp=TP_3)

In [None]:
# Train against the freshly restarted server, then summarize the last logged step.
log_path_3 = run_training(CONFIG_3, tp=TP_3, cp=CP_3, dp=DP_3)
metrics_3 = get_final_metrics(log_path_3)
print_metrics(CONFIG_3, metrics=metrics_3)

---
## Results Comparison

Compare all three configurations side by side.

In [None]:
# Collect all results. The configs are spelled out as literals (rather than
# reusing CONFIG_1..3 / TP_1..), so this cell only depends on the
# metrics.jsonl files already written under OUTPUT_DIR.
configs = [
    ("dp4-tp1-cp1", 1, 1, 4),
    ("dp2-tp2-cp1", 2, 1, 2),
    ("dp1-tp2-cp2", 2, 2, 1),
]

results = []
for config_name, tp, cp, dp in configs:
    metrics = get_final_metrics(OUTPUT_DIR / config_name)
    results.append({
        "config": config_name,
        "tp": tp,
        "cp": cp,
        "dp": dp,
        "correct": metrics.get("env/all/correct", "N/A"),
        "reward": metrics.get("env/all/reward/total", "N/A"),
        # Taken from the final metrics record only. NOTE(review): presumably
        # the last step's training time, not the whole run's; verify.
        "train_time": metrics.get("time/train", "N/A"),
    })


def _fmt_metric(value, spec, suffix=""):
    """Format a numeric metric with ``spec``; render anything else as text.

    Handles ints as well as floats, and converts non-numeric values such as
    ``"N/A"`` or a JSON ``null`` (None) to strings, so the width-aligned row
    below never fails on them.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:{spec}}{suffix}"
    return str(value)


# Print comparison table
print("=" * 70)
print("PARALLELISM CONFIGURATION COMPARISON")
print("=" * 70)
print(f"{'Config':<15} {'TP':>4} {'CP':>4} {'DP':>4} {'Correct':>10} {'Reward':>10} {'Time':>10}")
print("-" * 70)

for r in results:
    correct = _fmt_metric(r["correct"], ".3f")
    reward = _fmt_metric(r["reward"], ".3f")
    train_time = _fmt_metric(r["train_time"], ".1f", "s")

    print(f"{r['config']:<15} {r['tp']:>4} {r['cp']:>4} {r['dp']:>4} {correct:>10} {reward:>10} {train_time:>10}")

print("=" * 70)

In [None]:
# Persist the comparison rows as JSON alongside the per-run log directories.
results_file = OUTPUT_DIR / "comparison_results.json"
results_file.write_text(json.dumps(results, indent=2, default=str))

for summary_line in (
    f"Results saved to: {results_file}",
    f"\nWandb project: {WANDB_PROJECT}",
    f"Output directory: {OUTPUT_DIR}",
):
    print(summary_line)

---
## Quick Reference

### Parallelism Types

| Type | Abbrev | What it splits | Communication |
|------|--------|----------------|---------------|
| Data Parallel | DP | Samples across GPUs | Gradient all-reduce |
| Tensor Parallel | TP | Model weights (columns/rows) | Activation all-reduce |
| Context Parallel | CP | Sequence length | Attention KV exchange |

### GPU Allocation

With 4 GPUs: `DP × TP × CP = 4`

| Config | DP | TP | CP | Use Case |
|--------|----|----|----|---------|
| dp4-tp1-cp1 | 4 | 1 | 1 | Small models, max throughput |
| dp2-tp2-cp1 | 2 | 2 | 1 | Medium models |
| dp1-tp2-cp2 | 1 | 2 | 2 | Large models, long sequences |

### Commands

```bash
# Set parallelism via environment
SLIME_DEFAULT_TP=2 SLIME_DEFAULT_CP=2 python -m uvicorn training.api:app

# Check server logs
tail -f /data/logs/tinkercloud.log

# View wandb results
# https://wandb.ai/<entity>/parallelism-test
```