# Dr. Zero Full 3-Iteration Training (Google Colab)

**Complete Implementation** of the Dr. Zero paper (arXiv:2601.07055)

## Overview

This notebook implements the **full 3-iteration self-evolution pipeline**:
- **HRPO** (Hop-grouped Relative Policy Optimization) for proposer training
- **GRPO** for solver training
- **Difficulty-guided reward** (Paper Equation 4)
- **4:3:2:1 hop ratio** for question distribution

## Requirements

- Google Colab Pro+ (A100 80GB GPU)
- 2TB Google Drive storage
- ~28-42 hours total runtime (can be split across sessions)

## Paper Fidelity

| Parameter | Paper | This Implementation |
|-----------|-------|--------------------|
| Algorithm (Proposer) | HRPO | HRPO |
| Algorithm (Solver) | GRPO | GRPO |
| Steps/Iteration | 50 | 50 |
| Iterations | 3 | 3 |
| Hop Ratio | 4:3:2:1 | 4:3:2:1 |
| Reward Rollout N | 5 | 5 |
| Effective Batch | 256 | 256 (32×8) |

**Only the hardware parameters are adapted** (8 GPUs → 1 GPU).

In [None]:
# Cell 1: Mount Google Drive & Setup Directories
#
# Mounts Google Drive, defines the directory layout (persistent folders on
# Drive plus a local scratch folder) and creates every directory up front.

from google.colab import drive
from pathlib import Path
import os

# Mount Google Drive (2TB)
drive.mount('/content/drive')

# Persistent storage root on Google Drive and its sub-folders
DRIVE_BASE = Path('/content/drive/MyDrive/drzero_full')
CORPUS_DIR, DATA_DIR, CHECKPOINTS_DIR, LOGS_DIR = (
    DRIVE_BASE / sub for sub in ('corpus', 'data', 'checkpoints', 'logs')
)

# Local temp storage (faster I/O)
LOCAL_BASE = Path('/tmp/drzero')

# One proposer and one solver checkpoint folder per training iteration (1-3)
for iteration in (1, 2, 3):
    for role in ('proposer', 'solver'):
        (CHECKPOINTS_DIR / f'iter{iteration}' / role).mkdir(parents=True, exist_ok=True)

# Remaining Drive folders first, then the local scratch folder
for directory in (CORPUS_DIR, DATA_DIR, LOGS_DIR, LOCAL_BASE):
    directory.mkdir(parents=True, exist_ok=True)

# Summary of the layout that was just created
print("\n".join([
    "Storage Layout:",
    f"  Google Drive: {DRIVE_BASE}",
    f"    - Corpus: {CORPUS_DIR}",
    f"    - Data: {DATA_DIR}",
    f"    - Checkpoints: {CHECKPOINTS_DIR}",
    f"  Local temp: {LOCAL_BASE}",
    "\n[OK] Directories created!",
]))

In [None]:
# Cell 2: Install Dependencies
# IMPORTANT: After this cell finishes, restart the runtime, run Cell 1 again,
# skip this cell, and continue from Cell 3 (same steps as the banner below).

import subprocess
import sys
import os


def _pip_install(*pip_args):
    """Run a quiet `pip install` with the given arguments.

    Raises subprocess.CalledProcessError if pip exits with a non-zero status.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *pip_args])


def _clone_if_missing(repo_url, target_dir):
    """Quietly `git clone` repo_url into target_dir unless target_dir exists."""
    if not os.path.exists(target_dir):
        subprocess.check_call(["git", "clone", "-q", repo_url, target_dir])


print("Installing dependencies (10-15 minutes)...")

# Core packages
_pip_install("--upgrade", "pip")
_pip_install("numpy<2.0")

print("1/5 Installing ML packages...")
_pip_install("torch", "transformers", "accelerate", "datasets", "sentence-transformers")

print("2/5 Installing utilities...")
_pip_install(
    "biopython", "wandb", "tqdm", "psutil", "faiss-gpu", "uvicorn",
    "fastapi", "pydantic", "pandas", "pyarrow", "httpx", "openai",
)

print("3/5 Installing SGLang...")
_pip_install("sglang[all]")

# veRL and DrPubMedZero are cloned (if absent) and installed in editable mode
print("4/5 Installing veRL...")
_clone_if_missing("https://github.com/volcengine/verl.git", "/content/verl")
_pip_install("-e", "/content/verl")

print("5/5 Installing DrPubMedZero...")
_clone_if_missing("https://github.com/ShivaAyyar/DrPubMedZero.git", "/content/DrPubMedZero")
_pip_install("-e", "/content/DrPubMedZero")

banner = "=" * 60
print("\n" + banner)
print("RESTART RUNTIME NOW: Runtime -> Restart session")
print("Then run Cell 1 again, skip Cell 2, continue from Cell 3")
print(banner)

In [None]:
# Cell 3: Verify Setup & Clone Repository
#
# Checks that the core packages import and that a large enough GPU is present,
# makes sure the DrPubMedZero checkout exists (cloning or pulling it), switches
# into it, and finally imports the project modules used by later cells.
# Raises RuntimeError when no GPU (or a GPU under 70 GB) is found; import
# failures propagate as ImportError.

import subprocess
import sys
import os
from pathlib import Path

print("Verifying installations...")

import numpy as np
print(f"  [OK] numpy {np.__version__}")

import torch
print(f"  [OK] torch {torch.__version__}")
# Explicit raises instead of `assert`, so the checks also run under `python -O`
if not torch.cuda.is_available():
    raise RuntimeError("ERROR: No GPU!")
gpu_name = torch.cuda.get_device_name(0)
gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"  [OK] GPU: {gpu_name} ({gpu_mem:.0f}GB)")
if gpu_mem < 70:
    raise RuntimeError("ERROR: Need A100 80GB, got smaller GPU")

import transformers
print(f"  [OK] transformers {transformers.__version__}")

import faiss
print("  [OK] faiss")

# Clone/update repository
REPO_DIR = Path('/content/DrPubMedZero')
if not REPO_DIR.exists():
    subprocess.check_call(["git", "clone",
        "https://github.com/ShivaAyyar/DrPubMedZero.git", str(REPO_DIR)])
else:
    # Updating is best-effort: a failed pull keeps the existing checkout, but
    # the failure is reported instead of being silently discarded.
    pull = subprocess.run(["git", "-C", str(REPO_DIR), "pull"],
                          capture_output=True, text=True)
    if pull.returncode != 0:
        print(f"  [WARNING] git pull failed (exit code: {pull.returncode}); "
              f"using existing checkout")
        if pull.stderr.strip():
            print(f"    {pull.stderr.strip()}")

# Work from inside the repository and make its packages importable
os.chdir(REPO_DIR)
sys.path.insert(0, str(REPO_DIR))

# Verify custom modules
from biomedical import PubMedCorpusManager
print("  [OK] biomedical module")

from verl.custom_reward.dr_zero_reward import compute_difficulty_reward
print("  [OK] dr_zero_reward module")

from verl.custom_reward.hrpo_advantage import compute_hrpo_advantages
print("  [OK] hrpo_advantage module")

from verl.custom_reward.hop_counter import count_hops
print("  [OK] hop_counter module")

print("\n[OK] All checks passed!")

In [None]:
# Cell 4: Configuration for Full 3-Iteration Training
#
# Rebuilds the path constants (needed again after a runtime restart) and
# assembles CONFIG, the single dictionary that every later cell reads from.

import os
import getpass
from pathlib import Path

# Re-establish paths
REPO_DIR = Path('/content/DrPubMedZero')
DRIVE_BASE = Path('/content/drive/MyDrive/drzero_full')
CORPUS_DIR = DRIVE_BASE / 'corpus'
DATA_DIR = DRIVE_BASE / 'data'
CHECKPOINTS_DIR = DRIVE_BASE / 'checkpoints'

# CONFIG is filled section by section; insertion order matches the sections.
CONFIG = {}

# Model
CONFIG.update(model_name='Qwen/Qwen2.5-3B-Instruct')

# Corpus - larger for full training
CONFIG.update(
    corpus_size=200000,  # 200K papers
    training_seeds=5000,
    pubmed_query='(cancer OR diabetes OR alzheimer OR cardiovascular) AND (gene OR protein OR mutation)',
    date_range=('2020/01/01', '2024/12/31'),
)

# Training parameters (from Paper Tables 5-6)
CONFIG.update(
    micro_batch_size=32,
    gradient_accumulation=8,  # 32 * 8 = 256 effective batch
    proposer_lr=1e-6,
    solver_lr=1e-6,
    steps_per_iteration=50,
    num_iterations=3,
    save_freq=25,
)

# HRPO/GRPO parameters (from Paper)
CONFIG.update(
    proposer_group_size=1,
    solver_group_size=5,
    reward_rollout_n=5,
    proposer_kl_coef=0.0,  # Paper: KL=0 for HRPO
    solver_kl_coef=0.001,
    solver_clip_ratio=0.2,
    hop_ratio={1: 4, 2: 3, 3: 2, 4: 1},  # Paper default
)

# Sequence lengths (from Paper)
CONFIG.update(
    proposer_max_seq_len=4096,
    solver_max_seq_len=3072,
    max_turns=5,
)

# Paths (stored as strings so they can be interpolated into command lines)
CONFIG.update(
    corpus_file=str(CORPUS_DIR / 'pubmed-corpus.jsonl'),
    index_file=str(CORPUS_DIR / 'pubmedbert_index.faiss'),
    train_parquet=str(DATA_DIR / 'training_seeds.parquet'),
    repo_dir=str(REPO_DIR),
    checkpoints_dir=str(CHECKPOINTS_DIR),
)

# Server ports
CONFIG.update(
    retrieval_port=8000,
    solver_port=8001,
)

# User email for NCBI (handed to PubMedCorpusManager by the download cell).
# Surrounding whitespace from copy/paste is removed; an empty value is kept
# but reported, so the problem is visible before the long download starts.
CONFIG['email'] = input("Enter email for NCBI API: ").strip()
if not CONFIG['email']:
    print("  [WARNING] No email entered; the PubMed download will receive an empty email")

# W&B (optional): a key enables logging, an empty answer disables W&B.
# The key is stripped once so the exported variable and the login call agree.
wandb_key = getpass.getpass("W&B API key (Enter to skip): ").strip()
if wandb_key:
    os.environ['WANDB_API_KEY'] = wandb_key
    import wandb
    wandb.login(key=wandb_key)
    print("  [OK] W&B configured")
else:
    os.environ['WANDB_MODE'] = 'disabled'
    print("  [OK] W&B disabled")

# Echo the settings that determine the size of the run
print("\nConfiguration:")
print(f"  Corpus: {CONFIG['corpus_size']:,} papers")
print(f"  Seeds: {CONFIG['training_seeds']:,}")
print(f"  Iterations: {CONFIG['num_iterations']}")
print(f"  Steps/iteration: {CONFIG['steps_per_iteration']}")
print(f"  Effective batch: {CONFIG['micro_batch_size'] * CONFIG['gradient_accumulation']}")
print(f"  Hop ratio: {CONFIG['hop_ratio']}")
print("\n[OK] Configuration complete!")

In [None]:
# Cell 5: Download PubMed Corpus (200K papers)
#
# Reuses the corpus file when it already holds at least 90% of the requested
# number of papers (one JSON record per line); otherwise downloads and saves
# a fresh corpus. Raises RuntimeError when the download returns nothing.

from biomedical import PubMedCorpusManager
from pathlib import Path
import json

corpus_file = Path(CONFIG['corpus_file'])

# Assume a download is required until an adequate existing corpus is found
download_needed = True
if corpus_file.exists():
    with open(corpus_file) as f:
        n_papers = sum(1 for _ in f)
    print(f"Corpus exists: {n_papers:,} papers")
    download_needed = n_papers < CONFIG['corpus_size'] * 0.9  # 90% threshold
    if not download_needed:
        print("Skipping download.")

if not download_needed:
    print("[OK] Using existing corpus")
else:
    print(f"Downloading {CONFIG['corpus_size']:,} PubMed papers...")
    print("Estimated time: 2-4 hours\n")

    manager = PubMedCorpusManager(
        save_path=str(CORPUS_DIR),
        email=CONFIG['email']
    )

    articles = manager.download_pubmed_abstracts(
        query=CONFIG['pubmed_query'],
        max_results=CONFIG['corpus_size'],
        date_range=CONFIG['date_range']
    )

    # An empty result is treated as a failed download
    if not articles:
        raise RuntimeError("Download failed")

    manager.save_corpus(articles)
    print(f"\n[OK] Downloaded {len(articles):,} papers!")

In [None]:
# Cell 6: Build PubMedBERT FAISS Index
#
# Skips the work when the index file already exists. Otherwise constructs a
# BiomedicalRetrieverServer (presumably this is what builds and writes the
# index - confirm in biomedical_retriever), runs one test query, frees the
# GPU memory again and reports the size of the index file.

from biomedical.biomedical_retriever import BiomedicalRetrieverServer
from pathlib import Path
import os
import torch

corpus_file = Path(CONFIG['corpus_file'])
index_file = Path(CONFIG['index_file'])

if not index_file.exists():
    print("Building PubMedBERT FAISS index...")
    print("Estimated time: 1-2 hours for 200K papers\n")

    retriever = BiomedicalRetrieverServer(
        corpus_path=str(corpus_file),
        index_path=str(index_file),
        model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        device="cuda",
        topk=3
    )

    # Smoke test with a single query
    test_hits = retriever.search("BRCA1 mutation breast cancer")
    print(f"Test search: {len(test_hits)} results")

    # Drop the retriever and release cached GPU memory for the next cells
    del retriever
    torch.cuda.empty_cache()

    index_size_gb = os.path.getsize(index_file) / 1e9
    print(f"\n[OK] Index built: {index_size_gb:.1f} GB")
else:
    print(f"Index exists: {index_file}")
    print("Skipping build.")

In [None]:
# Cell 7: Prepare Training Seeds with Hop Distribution
#
# When the seed parquet already exists it is only summarised. Otherwise the
# corpus is loaded, papers are sampled as seeds, each seed is paired with a
# hop count, and the resulting prompts are written to the parquet file.
#
# Robustness notes:
#   - blank lines in the JSONL corpus are skipped instead of crashing
#     json.loads
#   - a missing or null title/abstract is treated as empty text
#   - the 'text' field is used whenever 'abstract' is absent OR empty

import json
import random
from collections import Counter
import pandas as pd
from pathlib import Path
from verl.custom_reward.hop_counter import generate_hop_distribution, assign_hop_to_seeds

corpus_file = Path(CONFIG['corpus_file'])
parquet_file = Path(CONFIG['train_parquet'])

if parquet_file.exists():
    df = pd.read_parquet(parquet_file)
    print(f"Training data exists: {len(df)} examples")

    # Check hop distribution (the hop count is the numeric suffix of data_source)
    if 'data_source' in df.columns:
        hops = df['data_source'].apply(lambda x: int(x.split('_')[-1]) if '_' in x else 1)
        print(f"Hop distribution: {hops.value_counts().sort_index().to_dict()}")
else:
    print(f"Preparing {CONFIG['training_seeds']} training seeds...")
    print(f"Target hop ratio: {CONFIG['hop_ratio']}")

    # Load corpus: one JSON object per non-blank line
    with open(corpus_file, encoding='utf-8') as f:
        corpus = [json.loads(line) for line in f if line.strip()]
    print(f"  Loaded {len(corpus):,} papers")

    # Filter for substantial abstracts (`or ''` guards against null abstracts)
    substantial = [p for p in corpus if len((p.get('abstract') or '').split()) > 100]
    print(f"  {len(substantial):,} have >100 word abstracts")

    # Sample seeds, falling back to the whole corpus when too few long abstracts exist
    source = substantial if len(substantial) >= CONFIG['training_seeds'] else corpus
    seeds = random.sample(source, min(CONFIG['training_seeds'], len(source)))

    # Assign hop counts following 4:3:2:1 ratio
    # NOTE(review): presumably returns one hop count per seed - confirm in hop_counter
    hop_counts = generate_hop_distribution(len(seeds), CONFIG['hop_ratio'])

    # Convert to veRL format
    data = []
    for idx, (seed, hop) in enumerate(zip(seeds, hop_counts)):
        title = seed.get('title') or ''
        # Abstract text is capped at 800 characters
        abstract = (seed.get('abstract') or seed.get('text') or '')[:800]
        pmid = seed.get('pmid', str(idx))

        user_msg = f"Generate a challenging {hop}-hop biomedical question from this paper:\n\nTitle: {title}\n\nAbstract: {abstract}"

        data.append({
            'prompt': [{'role': 'user', 'content': user_msg}],
            'data_source': f'search_biomedical_{hop}',
            'extra_info': {
                'index': idx,
                'pmid': pmid,
                'hop': hop,
                'tools_kwargs': {},
                'interaction_kwargs': {}
            }
        })

    df = pd.DataFrame(data)
    df.to_parquet(parquet_file)

    # Verify hop distribution
    hop_dist = Counter(hop_counts)
    print("\nHop distribution:")
    for h in sorted(hop_dist.keys()):
        pct = hop_dist[h] / len(hop_counts) * 100
        print(f"  {h}-hop: {hop_dist[h]} ({pct:.1f}%)")

    print(f"\n[OK] Saved {len(df)} training seeds")

In [None]:
# Cell 8: Pre-Training Verification
#
# Runs four checks (GPU, required files, training data, project modules) and
# prints one overall verdict. A missing or too-small GPU aborts the cell with
# RuntimeError; every other failure is recorded in `all_ok` so the remaining
# checks and the final summary still run.

import torch
from pathlib import Path
import pandas as pd

print("="*60)
print("PRE-TRAINING VERIFICATION")
print("="*60)

# 1. GPU
print("\n1. GPU Status:")
# Explicit raises instead of `assert`, so the checks also run under `python -O`
if not torch.cuda.is_available():
    raise RuntimeError("Need A100 80GB")
gpu_name = torch.cuda.get_device_name(0)
gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"   {gpu_name}: {gpu_mem:.0f}GB")
if gpu_mem < 70:
    raise RuntimeError("Need A100 80GB")
print("   [OK]")

# 2. Files
print("\n2. Required Files:")
files = [
    ('Corpus', CONFIG['corpus_file']),
    ('Index', CONFIG['index_file']),
    ('Training data', CONFIG['train_parquet']),
]
all_ok = True
for name, path in files:
    exists = Path(path).exists()
    status = "[OK]" if exists else "[MISSING]"
    print(f"   {status} {name}")
    all_ok = all_ok and exists

# 3. Training data
print("\n3. Training Data:")
# Only read the parquet when it exists; the missing file was already reported
# (and counted as a failure) in step 2.
if Path(CONFIG['train_parquet']).exists():
    df = pd.read_parquet(CONFIG['train_parquet'])
    print(f"   Examples: {len(df)}")
    # The hop count is the numeric suffix of data_source
    hops = df['data_source'].apply(lambda x: int(x.split('_')[-1]))
    print(f"   Hop distribution: {hops.value_counts().sort_index().to_dict()}")
else:
    print("   [SKIPPED] Training data file is missing")

# 4. Custom modules
print("\n4. Custom Modules:")
try:
    from verl.custom_reward.dr_zero_reward import compute_difficulty_reward, compute_format_reward
    print("   [OK] dr_zero_reward")
except ImportError as e:
    print(f"   [FAIL] dr_zero_reward: {e}")
    all_ok = False

try:
    from verl.custom_reward.hrpo_advantage import compute_hrpo_advantages, HRPOAdvantageEstimator
    print("   [OK] hrpo_advantage")
except ImportError as e:
    print(f"   [FAIL] hrpo_advantage: {e}")
    all_ok = False

try:
    from verl.custom_reward.hop_counter import count_hops, HopCounter
    print("   [OK] hop_counter")
except ImportError as e:
    print(f"   [FAIL] hop_counter: {e}")
    all_ok = False

print("\n" + "="*60)
if all_ok:
    print("ALL CHECKS PASSED - Ready for training!")
else:
    print("SOME CHECKS FAILED - Fix issues above")
print("="*60)

In [None]:
# Cell 9: Launch Retrieval Server
#
# Starts search/retrieval_server.py as a background process (output goes to
# /tmp/retrieval_server.log) and polls its /health endpoint for up to two
# minutes. The PID and the open log handle are stored in CONFIG.
#
# The wait loop catches only request errors, so Ctrl-C / "Interrupt
# execution" still stops it, and it ends early when the server process has
# already exited instead of waiting out the full timeout.

import subprocess
import time
import requests
from pathlib import Path

print("="*60)
print("LAUNCHING RETRIEVAL SERVER (Port 8000)")
print("="*60)

REPO_DIR = Path('/content/DrPubMedZero')

retrieval_cmd = [
    "python", str(REPO_DIR / "search" / "retrieval_server.py"),
    "--mode=biomedical",
    f"--index_path={CONFIG['index_file']}",
    f"--corpus_path={CONFIG['corpus_file']}",
    "--retriever_model=microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "--faiss_gpu",
    "--topk=3",
    f"--port={CONFIG['retrieval_port']}"
]

# Start server; the log handle stays open for the lifetime of the process
log_file = open('/tmp/retrieval_server.log', 'w')
retrieval_proc = subprocess.Popen(
    retrieval_cmd,
    stdout=log_file,
    stderr=subprocess.STDOUT,
    cwd=str(REPO_DIR)
)
CONFIG['retrieval_pid'] = retrieval_proc.pid
CONFIG['retrieval_log'] = log_file

print(f"Started retrieval server (PID: {retrieval_proc.pid})")
print("Waiting for server to be ready...")

# Wait for server
for i in range(120):  # 2 minutes timeout
    # A finished process can never become healthy - stop waiting immediately
    if retrieval_proc.poll() is not None:
        print(f"\n[FAIL] Retrieval server exited (exit code: {retrieval_proc.returncode}). Check logs:")
        print("  !tail -50 /tmp/retrieval_server.log")
        break
    try:
        resp = requests.get(f"http://127.0.0.1:{CONFIG['retrieval_port']}/health", timeout=5)
        if resp.status_code == 200:
            print(f"\n[OK] Retrieval server ready! ({i+1}s)")
            break
    except requests.RequestException:
        # Connection refused / timeout while the server is still starting up
        pass
    time.sleep(1)
    if i % 10 == 9:
        print(f"  Still waiting... ({i+1}s)")
else:
    print("\n[WARNING] Server may not be ready. Check logs:")
    print("  !tail -50 /tmp/retrieval_server.log")

In [None]:
# Cell 10: Launch Solver Server (Base Model for Iteration 1)
#
# Starts an SGLang server for the base model as a background process (output
# goes to /tmp/solver_server.log) and polls its /health endpoint for up to
# five minutes. PID, log handle and served model name are stored in CONFIG.
#
# The wait loop catches only request errors, so Ctrl-C / "Interrupt
# execution" still stops it, and it ends early when the server process has
# already exited instead of waiting out the full timeout.

import subprocess
import time
import requests

print("="*60)
print("LAUNCHING SOLVER SERVER (Port 8001)")
print("Using base model for Iteration 1")
print("="*60)

# For iteration 1, use base model
solver_model = CONFIG['model_name']

solver_cmd = [
    "python", "-m", "sglang.launch_server",
    f"--model-path={solver_model}",
    f"--port={CONFIG['solver_port']}",
    "--tool-call-parser=qwen25",
    "--mem-fraction-static=0.35",
    "--tp-size=1",
    "--dp-size=1",
]

# Start server; the log handle stays open for the lifetime of the process
log_file = open('/tmp/solver_server.log', 'w')
solver_proc = subprocess.Popen(
    solver_cmd,
    stdout=log_file,
    stderr=subprocess.STDOUT
)
CONFIG['solver_pid'] = solver_proc.pid
CONFIG['solver_log'] = log_file
CONFIG['current_solver_model'] = solver_model

print(f"Started solver server (PID: {solver_proc.pid})")
print(f"Model: {solver_model}")
print("Waiting for server to be ready (may take 3-5 minutes)...")

for i in range(300):  # 5 minutes timeout
    # A finished process can never become healthy - stop waiting immediately
    if solver_proc.poll() is not None:
        print(f"\n[FAIL] Solver server exited (exit code: {solver_proc.returncode}). Check logs:")
        print("  !tail -50 /tmp/solver_server.log")
        break
    try:
        resp = requests.get(f"http://127.0.0.1:{CONFIG['solver_port']}/health", timeout=5)
        if resp.status_code == 200:
            print(f"\n[OK] Solver server ready! ({i+1}s)")
            break
    except requests.RequestException:
        # Connection refused / timeout while the model is still loading
        pass
    time.sleep(1)
    if i % 30 == 29:
        print(f"  Still waiting... ({i+1}s)")
else:
    print("\n[WARNING] Server may not be ready. Check logs:")
    print("  !tail -50 /tmp/solver_server.log")

---
# ITERATION 1
---

In [None]:
# Cell 11: Train Proposer Iteration 1 (HRPO)
#
# Prints a summary of the run, builds the verl.trainer.main_ppo command from
# groups of config overrides and runs it as a blocking subprocess. On a zero
# exit code the expected final checkpoint path is recorded in CONFIG.

import subprocess
import sys
import os
from pathlib import Path

REPO_DIR = Path('/content/DrPubMedZero')
os.chdir(REPO_DIR)
sys.path.insert(0, str(REPO_DIR))

ITERATION = 1
checkpoint_dir = Path(CONFIG['checkpoints_dir']) / f'iter{ITERATION}' / 'proposer'

banner = "=" * 60
for summary_line in [
    banner,
    f"ITERATION {ITERATION}: PROPOSER TRAINING (HRPO)",
    banner,
    "\nAlgorithm: HRPO (Hop-grouped RPO)",
    "  - NO ratio clipping (strictly on-policy)",
    f"  - KL coefficient: {CONFIG['proposer_kl_coef']}",
    f"  - Group size: {CONFIG['proposer_group_size']}",
    f"  - Reward rollouts: {CONFIG['reward_rollout_n']}",
    "\nReward: Difficulty-guided (Paper Eq. 4)",
    "  r = I(0<k<n) × (n-k)/(n-1) + r_format",
    f"\nTraining: {CONFIG['steps_per_iteration']} steps",
    f"Checkpoint: {checkpoint_dir}",
    "",
]:
    print(summary_line)

# Trainer entry point and base config
launcher = [
    sys.executable, "-m", "verl.trainer.main_ppo",
    "--config-path", str(REPO_DIR / "config"),
    "--config-name", "search_multiturn_grpo",
]

# Data
data_overrides = [
    f"data.train_files={CONFIG['train_parquet']}",
    f"data.train_batch_size={CONFIG['micro_batch_size']}",
    f"trainer.gradient_accumulation_steps={CONFIG['gradient_accumulation']}",
    "data.max_prompt_length=1536",
    "data.max_response_length=2560",
    f"data.max_seq_length={CONFIG['proposer_max_seq_len']}",
]

# Model
model_overrides = [
    f"actor_rollout_ref.model.path={CONFIG['model_name']}",
    f"actor_rollout_ref.actor.optim.lr={CONFIG['proposer_lr']}",
    "actor_rollout_ref.actor.grad_clip=0.1",
    "actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.03",
    "actor_rollout_ref.actor.optim.weight_decay=0.01",
]

# HRPO: No ratio clipping, KL=0
# Note: No clip_ratio for HRPO (paper: "omit ratio clipping")
hrpo_overrides = [
    "algorithm.use_kl_in_reward=False",
    "algorithm.adv_estimator=grpo_batch",
    f"actor_rollout_ref.rollout.n={CONFIG['proposer_group_size']}",
    "actor_rollout_ref.actor.use_kl_loss=False",
]

# Single GPU
single_gpu_overrides = [
    "trainer.n_gpus_per_node=1",
    "trainer.nnodes=1",
    "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
    "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2",
    "actor_rollout_ref.rollout.gpu_memory_utilization=0.35",
]

# Memory optimization
memory_overrides = [
    "actor_rollout_ref.actor.fsdp_config.param_offload=True",
    "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
    "actor_rollout_ref.model.enable_gradient_checkpointing=True",
]

# Multi-turn tool use
multi_turn_overrides = [
    f"actor_rollout_ref.rollout.multi_turn.tool_config_path={REPO_DIR}/config/search_tool_config.yaml",
    f"actor_rollout_ref.rollout.max_turns={CONFIG['max_turns']}",
]

# Difficulty reward function (queries the solver server on the solver port)
reward_overrides = [
    "reward_model.reward_manager=batch",
    "custom_reward_function.name=compute_biomedical_challenger_score_batch",
    "custom_reward_function.path=verl/custom_reward/reward_function.py",
    f"custom_reward_function.reward_kwargs.model_name={CONFIG['model_name']}",
    f"custom_reward_function.reward_kwargs.base_url=http://127.0.0.1:{CONFIG['solver_port']}",
    f"custom_reward_function.reward_kwargs.reward_rollout_n={CONFIG['reward_rollout_n']}",
]

# Training schedule
schedule_overrides = [
    "trainer.total_epochs=1",
    f"trainer.total_training_steps={CONFIG['steps_per_iteration']}",
    f"trainer.save_freq={CONFIG['save_freq']}",
    f"trainer.default_hdfs_dir={checkpoint_dir}",
    f"trainer.project_name=drzero-iter{ITERATION}",
    f"trainer.experiment_name=proposer_iter{ITERATION}",
]

proposer_cmd = (
    launcher
    + data_overrides
    + model_overrides
    + hrpo_overrides
    + single_gpu_overrides
    + memory_overrides
    + multi_turn_overrides
    + reward_overrides
    + schedule_overrides
)

print("Starting training...")
print("Expected time: 3-5 hours\n")

try:
    run_result = subprocess.run(proposer_cmd, cwd=str(REPO_DIR))
    if run_result.returncode != 0:
        print(f"\n[FAIL] Training failed (exit code: {run_result.returncode})")
    else:
        print(f"\n[OK] Proposer Iteration {ITERATION} complete!")
        # Expected location of the checkpoint written at the final step
        final_step = CONFIG["steps_per_iteration"]
        CONFIG[f'proposer_iter{ITERATION}_ckpt'] = str(checkpoint_dir / f'global_step_{final_step}')
except KeyboardInterrupt:
    print("\nTraining interrupted.")

In [None]:
# Cell 12: Generate Data with Iteration 1 Proposer
#
# Locates the iteration-1 proposer checkpoint (from CONFIG, or the
# highest-numbered global_step_* folder on disk), then runs
# verl.trainer.main_generation to write the generated parquet file.
# Raises FileNotFoundError when no proposer checkpoint can be found.

import subprocess
import sys
from pathlib import Path

ITERATION = 1
REPO_DIR = Path('/content/DrPubMedZero')


def _checkpoint_step(ckpt_path):
    """Return the step number of a `global_step_<N>` folder (-1 if not numeric)."""
    suffix = ckpt_path.name.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


proposer_ckpt = CONFIG.get(f'proposer_iter{ITERATION}_ckpt')
if not proposer_ckpt:
    # Try to find checkpoint
    ckpt_dir = Path(CONFIG['checkpoints_dir']) / f'iter{ITERATION}' / 'proposer'
    # Sort by step number: a plain string sort would place global_step_100
    # before global_step_25 and pick the wrong "latest" checkpoint.
    ckpts = sorted(ckpt_dir.glob('global_step_*'), key=_checkpoint_step)
    if ckpts:
        proposer_ckpt = str(ckpts[-1])
        CONFIG[f'proposer_iter{ITERATION}_ckpt'] = proposer_ckpt
    else:
        raise FileNotFoundError(f"No proposer checkpoint found for iteration {ITERATION}")

# Generated data is stored in the 'data' folder next to the checkpoints folder
output_parquet = Path(CONFIG['checkpoints_dir']).parent / 'data' / f'iter{ITERATION}_generated.parquet'
output_parquet.parent.mkdir(parents=True, exist_ok=True)

print("="*60)
print(f"ITERATION {ITERATION}: GENERATE DATA")
print("="*60)
print(f"\nProposer: {proposer_ckpt}")
print(f"Output: {output_parquet}")
print(f"Samples per seed: {CONFIG['solver_group_size']}")
print()

gen_cmd = [
    sys.executable, "-m", "verl.trainer.main_generation",
    "--config-path", str(REPO_DIR / "config"),
    "--config-name", "search_multiturn_grpo",

    # Checkpoint, input seeds and output location
    f"+ckpt_path={proposer_ckpt}",
    f"+data.path={CONFIG['train_parquet']}",
    f"+data.output_path={output_parquet}",
    "+data.batch_size=64",

    # Sampling settings
    f"actor_rollout_ref.rollout.n={CONFIG['solver_group_size']}",
    "actor_rollout_ref.rollout.temperature=1.0",
    "actor_rollout_ref.rollout.top_p=1.0",
    "actor_rollout_ref.rollout.gpu_memory_utilization=0.8",
]

print("Generating QA pairs...")
print("Expected time: 1-2 hours\n")

try:
    process = subprocess.run(gen_cmd, cwd=str(REPO_DIR))
    if process.returncode == 0:
        print("\n[OK] Data generation complete!")
        CONFIG[f'iter{ITERATION}_data'] = str(output_parquet)
    else:
        print("\n[FAIL] Generation failed")
except KeyboardInterrupt:
    print("\nGeneration interrupted.")

In [None]:
# Cell 13: Train Solver Iteration 1 (GRPO)
#
# Resolves the generated training data (from CONFIG, or the default output
# path of the generation cell), then runs verl.trainer.main_ppo as a blocking
# subprocess. On a zero exit code the expected final checkpoint path is
# recorded in CONFIG.
# Raises FileNotFoundError when the generated data file does not exist, so a
# missing input is reported before the multi-hour training job is launched.

import subprocess
import sys
from pathlib import Path

ITERATION = 1
REPO_DIR = Path('/content/DrPubMedZero')

generated_data = CONFIG.get(f'iter{ITERATION}_data')
if not generated_data or not Path(generated_data).exists():
    # Fall back to the default location used by the generation cell
    generated_data = str(Path(CONFIG['checkpoints_dir']).parent / 'data' / f'iter{ITERATION}_generated.parquet')
if not Path(generated_data).exists():
    raise FileNotFoundError(f"No generated data found for iteration {ITERATION}: {generated_data}")

checkpoint_dir = Path(CONFIG['checkpoints_dir']) / f'iter{ITERATION}' / 'solver'

print("="*60)
print(f"ITERATION {ITERATION}: SOLVER TRAINING (GRPO)")
print("="*60)
print("\nAlgorithm: GRPO")
print(f"  - WITH ratio clipping (epsilon={CONFIG['solver_clip_ratio']})")
print(f"  - KL coefficient: {CONFIG['solver_kl_coef']}")
print(f"  - Group size: {CONFIG['solver_group_size']}")
print("\nReward: Outcome-based (correctness)")
print(f"\nTraining data: {generated_data}")
print(f"Checkpoint: {checkpoint_dir}")
print()

solver_cmd = [
    sys.executable, "-m", "verl.trainer.main_ppo",
    "--config-path", str(REPO_DIR / "config"),
    "--config-name", "search_multiturn_grpo",

    # Data
    f"data.train_files={generated_data}",
    f"data.train_batch_size={CONFIG['micro_batch_size']}",
    f"trainer.gradient_accumulation_steps={CONFIG['gradient_accumulation']}",
    "data.max_prompt_length=512",
    f"data.max_seq_length={CONFIG['solver_max_seq_len']}",

    # Model
    f"actor_rollout_ref.model.path={CONFIG['model_name']}",
    f"actor_rollout_ref.actor.optim.lr={CONFIG['solver_lr']}",
    "actor_rollout_ref.actor.grad_clip=0.1",
    "actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.03",
    "actor_rollout_ref.actor.optim.weight_decay=0.01",

    # GRPO: With ratio clipping
    "algorithm.adv_estimator=grpo",
    f"actor_rollout_ref.rollout.n={CONFIG['solver_group_size']}",
    "actor_rollout_ref.actor.use_kl_loss=True",
    f"actor_rollout_ref.actor.kl_loss_coef={CONFIG['solver_kl_coef']}",
    f"actor_rollout_ref.actor.clip_ratio={CONFIG['solver_clip_ratio']}",

    # Single GPU
    "trainer.n_gpus_per_node=1",
    "trainer.nnodes=1",
    "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
    "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2",
    "actor_rollout_ref.rollout.gpu_memory_utilization=0.35",

    # Memory optimization
    "actor_rollout_ref.actor.fsdp_config.param_offload=True",
    "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
    "actor_rollout_ref.model.enable_gradient_checkpointing=True",

    # Multi-turn
    f"actor_rollout_ref.rollout.multi_turn.tool_config_path={REPO_DIR}/config/search_tool_config.yaml",
    f"actor_rollout_ref.rollout.max_turns={CONFIG['max_turns']}",

    # Training schedule
    "trainer.total_epochs=1",
    f"trainer.total_training_steps={CONFIG['steps_per_iteration']}",
    f"trainer.save_freq={CONFIG['save_freq']}",
    f"trainer.default_hdfs_dir={checkpoint_dir}",
    f"trainer.project_name=drzero-iter{ITERATION}",
    f"trainer.experiment_name=solver_iter{ITERATION}",
]

print("Starting training...")
print("Expected time: 3-4 hours\n")

try:
    process = subprocess.run(solver_cmd, cwd=str(REPO_DIR))
    if process.returncode == 0:
        print(f"\n[OK] Solver Iteration {ITERATION} complete!")
        # Expected location of the checkpoint written at the final step
        CONFIG[f'solver_iter{ITERATION}_ckpt'] = str(checkpoint_dir / f'global_step_{CONFIG["steps_per_iteration"]}')
    else:
        print("\n[FAIL] Training failed")
except KeyboardInterrupt:
    print("\nTraining interrupted.")

In [None]:
# Cell 14: Convert Solver Iteration 1 to HuggingFace Format
#
# Locates the iteration-1 solver checkpoint (from CONFIG, or the
# highest-numbered global_step_* folder on disk), merges its `actor`
# sub-folder into a HuggingFace model directory with verl.model_merger, and
# loads the result once as a sanity check.
# Raises FileNotFoundError when no solver checkpoint can be found, instead of
# running the merger on the literal path "None/actor".

import subprocess
import sys
from pathlib import Path

ITERATION = 1


def _checkpoint_step(ckpt_path):
    """Return the step number of a `global_step_<N>` folder (-1 if not numeric)."""
    suffix = ckpt_path.name.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


solver_ckpt = CONFIG.get(f'solver_iter{ITERATION}_ckpt')
if not solver_ckpt:
    ckpt_dir = Path(CONFIG['checkpoints_dir']) / f'iter{ITERATION}' / 'solver'
    # Sort by step number: a plain string sort would place global_step_100
    # before global_step_25 and pick the wrong "latest" checkpoint.
    ckpts = sorted(ckpt_dir.glob('global_step_*'), key=_checkpoint_step)
    if ckpts:
        solver_ckpt = str(ckpts[-1])
    else:
        raise FileNotFoundError(f"No solver checkpoint found for iteration {ITERATION}")

output_hf_dir = Path(CONFIG['checkpoints_dir']) / f'iter{ITERATION}' / 'solver' / f'solver_iter{ITERATION}_hf'

print("="*60)
print(f"CONVERT SOLVER ITERATION {ITERATION} TO HF FORMAT")
print("="*60)
print(f"\nInput: {solver_ckpt}")
print(f"Output: {output_hf_dir}")
print()

convert_cmd = [
    sys.executable, "-m", "verl.model_merger", "merge",
    "--backend", "fsdp",
    "--local_dir", f"{solver_ckpt}/actor",
    "--target_dir", str(output_hf_dir),
]

try:
    process = subprocess.run(convert_cmd)
    if process.returncode == 0:
        print("\n[OK] Conversion complete!")
        CONFIG[f'solver_iter{ITERATION}_hf'] = str(output_hf_dir)

        # Verify: the merged directory must load as a causal LM
        from transformers import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(output_hf_dir)
        print(f"  Model loaded successfully: {model.config.model_type}")
        del model
    else:
        print("\n[FAIL] Conversion failed")
except Exception as e:
    # Conversion/verification problems are reported without aborting the cell
    print(f"\n[ERROR] {e}")