# Dr. Zero Full 3-Iteration Training (Google Colab)

**Complete Implementation** of the Dr. Zero paper (arXiv:2601.07055)

## Overview

This notebook implements the **full 3-iteration self-evolution pipeline**:
- **HRPO** (Hop-grouped Relative Policy Optimization) for proposer training
- **GRPO** for solver training
- **Difficulty-guided reward** (Paper Equation 4)
- **4:3:2:1 hop ratio** for question distribution

## Structure

- **Cell 1**: Install all dependencies (run once, then restart runtime)
- **Cell 2**: Full training pipeline (runs continuously until completion)
- **Cell 3**: Evaluation and results

## Requirements

- Google Colab Pro+ (A100 80GB GPU)
- 2TB Google Drive storage
- ~28-42 hours total runtime

## Paper Fidelity

| Parameter | Paper | This Implementation |
|-----------|-------|--------------------|
| Algorithm (Proposer) | HRPO | HRPO |
| Algorithm (Solver) | GRPO | GRPO |
| Steps/Iteration | 50 | 50 |
| Iterations | 3 | 3 |
| Hop Ratio | 4:3:2:1 | 4:3:2:1 |
| Reward Rollout N | 5 | 5 |
| Effective Batch | 256 | 256 (32×8) |

**Only hardware parameters are adapted** (8 GPUs → 1 GPU)

In [None]:
# =============================================================================
# CELL 1: INSTALL DEPENDENCIES
# =============================================================================
# Run this cell ONCE, then restart runtime (Runtime -> Restart session)
# After restart, skip this cell and run Cell 2
# =============================================================================

import subprocess
import sys
import os

# Tell the user up front that a runtime restart is needed once this cell ends.
banner_rule = "=" * 70
print(banner_rule)
print(" INSTALLING DEPENDENCIES")
print(" After completion, restart runtime and run Cell 2")
print(banner_rule)

# Drive has to be mounted before any folder can be created beneath it.
from google.colab import drive
drive.mount('/content/drive')

# Lay out the project tree on Drive: four top-level folders plus a
# proposer/solver checkpoint folder for each of iterations 1-3.
from pathlib import Path
DRIVE_BASE = Path('/content/drive/MyDrive/drzero_full')
top_level_dirs = [DRIVE_BASE / name for name in ('corpus', 'data', 'checkpoints', 'logs')]
iteration_dirs = [
    DRIVE_BASE / 'checkpoints' / f'iter{n}' / role
    for n in range(1, 4)
    for role in ('proposer', 'solver')
]
for folder in top_level_dirs + iteration_dirs:
    # exist_ok keeps re-runs of this cell harmless.
    folder.mkdir(parents=True, exist_ok=True)
print(f"\n[OK] Directory structure created at {DRIVE_BASE}")

def _pip_install(*args, silent=False):
    """Run ``pip install -q <args>`` using the notebook's own interpreter.

    Args:
        *args: Package specifiers and/or extra pip flags, passed verbatim.
        silent: When true, pip's stdout and stderr are discarded (used for
            the faiss-gpu attempt, which is allowed to fail).

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    streams = {}
    if silent:
        streams = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", *args], **streams)


def _clone_if_missing(url, dest):
    """Clone *url* into *dest* unless *dest* already exists on disk."""
    if not os.path.exists(dest):
        subprocess.check_call(["git", "clone", "-q", url, dest])


# Upgrade pip
_pip_install("--upgrade", "pip")

# Install numpy (version constraint for compatibility)
print("\n[1/7] Installing numpy...")
_pip_install("numpy<2.0")

# Install ML packages
print("[2/7] Installing ML packages (torch, transformers, etc.)...")
_pip_install("torch", "transformers", "accelerate", "datasets", "sentence-transformers")

# Install utilities
print("[3/7] Installing utilities...")
_pip_install(
    "biopython", "wandb", "tqdm", "psutil", "uvicorn",
    "fastapi", "pydantic", "pandas", "pyarrow", "httpx", "openai", "requests")

# Install FAISS (try GPU first, fall back to CPU)
print("[4/7] Installing FAISS...")
try:
    _pip_install("faiss-gpu", silent=True)
    print("       Installed faiss-gpu")
except subprocess.CalledProcessError:
    # Only a failed pip run triggers the fallback. A bare `except:` here
    # would also swallow KeyboardInterrupt and make the cell hard to stop.
    try:
        import faiss  # noqa: F401 -- imported only to probe availability
        print("       faiss already available")
    except ImportError:
        _pip_install("faiss-cpu")
        print("       Installed faiss-cpu")

# Install SGLang
print("[5/7] Installing SGLang...")
_pip_install("sglang[all]")

# Install veRL (editable install from a local clone)
print("[6/7] Installing veRL...")
_clone_if_missing("https://github.com/volcengine/verl.git", "/content/verl")
_pip_install("-e", "/content/verl")

# Install DrPubMedZero (editable install from a local clone)
print("[7/7] Installing DrPubMedZero...")
_clone_if_missing("https://github.com/ShivaAyyar/DrPubMedZero.git", "/content/DrPubMedZero")
_pip_install("-e", "/content/DrPubMedZero")

# Closing banner: installation is done, and the runtime must be restarted
# before the training cell is executed.
rule = "=" * 70
closing_lines = (
    "\n" + rule,
    " INSTALLATION COMPLETE!",
    rule,
    "\nNEXT STEPS:",
    "  1. Runtime -> Restart session",
    "  2. Skip this cell",
    "  3. Run Cell 2 (Training)",
    rule,
)
for line in closing_lines:
    print(line)

In [None]:
# =============================================================================
# CELL 2: FULL TRAINING PIPELINE
# =============================================================================
# This cell runs the COMPLETE training pipeline:
#   1. Mount Drive & verify GPU
#   2. Download PubMed corpus (if needed)
#   3. Build FAISS index (if needed)
#   4. Prepare training seeds
#   5. Run 3 iterations of Proposer-Solver training
#   6. Save all checkpoints to Google Drive
#
# Total time: ~28-42 hours (runs continuously)
# =============================================================================

import subprocess
import sys
import os
from pathlib import Path
import getpass

# Mount Google Drive (the training command below is pointed at a Drive path).
from google.colab import drive
drive.mount('/content/drive')

# Verify GPU: a CUDA device must be present and report at least 70 (decimal)
# GB of memory, which is the cut-off used here to require an 80GB A100.
import torch
assert torch.cuda.is_available(), "ERROR: No GPU available!"
gpu_name = torch.cuda.get_device_name(0)
total_bytes = torch.cuda.get_device_properties(0).total_memory
gpu_mem = total_bytes / 1e9
print(f"GPU: {gpu_name} ({gpu_mem:.0f}GB)")
assert gpu_mem >= 70, f"ERROR: Need A100 80GB, got {gpu_mem:.0f}GB"
print("[OK] GPU verified\n")

def _ask_positive_int(label, default):
    """Prompt for a positive integer, re-asking until the reply is usable.

    Args:
        label: Text shown before the bracketed default in the prompt.
        default: Value returned when the user just presses Enter.

    Returns:
        The parsed integer, or *default* for an empty reply.
    """
    while True:
        reply = input(f"{label} [{default}]: ").strip()
        if not reply:
            return default
        try:
            value = int(reply)
        except ValueError:
            # A typo must not abort the whole cell after earlier answers.
            print(f"  '{reply}' is not a whole number; please try again.")
            continue
        if value > 0:
            return value
        print("  Value must be greater than zero; please try again.")


# Get user inputs
print("="*70)
print(" CONFIGURATION")
print("="*70)

# Re-prompt instead of asserting: the value is user input, and surrounding
# whitespace would otherwise end up inside the --email argument.
EMAIL = input("Enter email for NCBI API: ").strip()
while '@' not in EMAIL:
    print("Please enter a valid email")
    EMAIL = input("Enter email for NCBI API: ").strip()

# W&B setup (optional)
wandb_key = getpass.getpass("W&B API key (press Enter to skip): ").strip()
if wandb_key:
    os.environ['WANDB_API_KEY'] = wandb_key
    # A previous run of this cell may have disabled W&B in this process;
    # clear that flag so the key just entered actually takes effect.
    if os.environ.get('WANDB_MODE') == 'disabled':
        del os.environ['WANDB_MODE']
    print("[OK] W&B configured")
else:
    os.environ['WANDB_MODE'] = 'disabled'
    print("[OK] W&B disabled")

# Configuration options
print("\nTraining Configuration (press Enter for defaults):")

CORPUS_SIZE = _ask_positive_int("Corpus size", 200000)
TRAINING_SEEDS = _ask_positive_int("Training seeds", 5000)
NUM_ITERATIONS = _ask_positive_int("Number of iterations", 3)
STEPS_PER_ITER = _ask_positive_int("Steps per iteration", 50)

print("\nConfiguration:")
print(f"  Corpus size: {CORPUS_SIZE:,}")
print(f"  Training seeds: {TRAINING_SEEDS:,}")
print(f"  Iterations: {NUM_ITERATIONS}")
print(f"  Steps/iteration: {STEPS_PER_ITER}")
print("  Effective batch: 256 (32 x 8)")

# Update repo: pull when a checkout already exists, otherwise clone it.
print("\nUpdating repository...")
REPO_DIR = Path('/content/DrPubMedZero')
if REPO_DIR.exists():
    pull = subprocess.run(["git", "-C", str(REPO_DIR), "pull"],
                          capture_output=True, text=True)
    if pull.returncode != 0:
        # Keep going (the existing checkout may still be usable), but make
        # it visible that training will run on code that was not updated.
        print(f"WARNING: 'git pull' failed (exit code {pull.returncode}); "
              "continuing with the existing checkout.")
        pull_error = pull.stderr.strip()
        if pull_error:
            print(f"  {pull_error}")
else:
    subprocess.check_call(["git", "clone",
        "https://github.com/ShivaAyyar/DrPubMedZero.git", str(REPO_DIR)])

# Work from inside the repository and make it importable from this notebook.
os.chdir(REPO_DIR)
sys.path.insert(0, str(REPO_DIR))
print(f"[OK] Repository ready at {REPO_DIR}")

# Verify custom modules: try to import the three reward/advantage helpers.
# A failed import is reported as a warning only; execution continues.
print("\nVerifying custom modules...")
try:
    from verl.custom_reward.dr_zero_reward import compute_difficulty_reward
    from verl.custom_reward.hrpo_advantage import compute_hrpo_advantages
    from verl.custom_reward.hop_counter import count_hops
except ImportError as import_err:
    print(f"WARNING: {import_err}")
else:
    print("[OK] All custom modules loaded")

# Run training script: announce the long-running phase.
start_rule = "=" * 70
print("\n" + start_rule)
print(" STARTING FULL TRAINING PIPELINE")
print(" This will run continuously for ~28-42 hours")
print(start_rule + "\n")

# Build command: every option is passed to the training script as
# --name=value, in the order listed here.
cli_options = {
    "email": EMAIL,
    "corpus_size": CORPUS_SIZE,
    "training_seeds": TRAINING_SEEDS,
    "iterations": NUM_ITERATIONS,
    "steps": STEPS_PER_ITER,
    "drive_path": "/content/drive/MyDrive/drzero_full",
    "repo_dir": REPO_DIR,
}
train_cmd = [sys.executable, str(REPO_DIR / "train_drzero_full.py")]
train_cmd.extend(f"--{flag}={value}" for flag, value in cli_options.items())

print(f"Command: {' '.join(train_cmd)}\n")

# Run training: the script runs as a child process inside the repository
# directory; its exit code decides which summary is printed.
rule = "=" * 70
try:
    result = subprocess.run(train_cmd, cwd=str(REPO_DIR))
    exit_code = result.returncode

    if exit_code != 0:
        print(f"\nTraining exited with code: {exit_code}")
        print("Check logs for errors.")
    else:
        print("\n" + rule)
        print(" TRAINING COMPLETED SUCCESSFULLY!")
        print(rule)
        print("\nCheckpoints saved to:")
        for iteration in range(1, NUM_ITERATIONS + 1):
            print(f"  /content/drive/MyDrive/drzero_full/checkpoints/iter{iteration}/solver/solver_iter{iteration}_hf")

except KeyboardInterrupt:
    # Manual stop from the notebook UI.
    print("\n" + rule)
    print(" TRAINING INTERRUPTED")
    print(rule)
    print("\nCheckpoints have been saved to Google Drive.")
    print("To resume, re-run this cell.")

except Exception as err:
    # Anything else (e.g. the child could not be launched): show the trace.
    print(f"\nERROR: {err}")
    import traceback
    traceback.print_exc()

In [None]:
# =============================================================================
# CELL 3: EVALUATION AND RESULTS
# =============================================================================
# Run this cell after training completes to:
#   1. Verify all checkpoints
#   2. Load and test the final model
#   3. Generate sample outputs
#   4. Display storage summary
# =============================================================================

import os
import sys
from pathlib import Path
import torch

# Mount Drive if needed. Mounting stays best-effort (the rest of the cell
# still runs), but failures are reported instead of being silently dropped,
# and KeyboardInterrupt is no longer swallowed.
try:
    from google.colab import drive
except ImportError:
    print("[WARN] google.colab is not available; skipping Drive mount.")
else:
    try:
        drive.mount('/content/drive')
    except Exception as mount_err:
        print(f"[WARN] Could not mount Google Drive: {mount_err}")

# Locations written by the training cell.
DRIVE_BASE = Path('/content/drive/MyDrive/drzero_full')
CHECKPOINTS = DRIVE_BASE / 'checkpoints'

print("="*70)
print(" TRAINING RESULTS")
print("="*70)

# =============================================================================
# 1. CHECKPOINT VERIFICATION
# =============================================================================

print("\n[1/4] Checkpoint Verification")
print("-" * 50)

import re

# Iterations 1-3 are always reported. Any other `iter<N>` directory present
# under the checkpoint folder is included as well, because the training cell
# lets the user choose an iteration count other than 3.
iteration_ids = set(range(1, 4))
if CHECKPOINTS.is_dir():
    for entry in CHECKPOINTS.iterdir():
        iter_match = re.fullmatch(r'iter(\d+)', entry.name)
        if iter_match and entry.is_dir():
            iteration_ids.add(int(iter_match.group(1)))

# (iteration number, HF checkpoint path) pairs in ascending iteration order,
# so the last entry is always the most recent iteration.
found_checkpoints = []
for i in sorted(iteration_ids):
    hf_path = CHECKPOINTS / f'iter{i}' / 'solver' / f'solver_iter{i}_hf'
    config_file = hf_path / 'config.json'

    # The presence of config.json is what marks a checkpoint as found.
    if config_file.exists():
        found_checkpoints.append((i, str(hf_path)))
        # Get model size (sum of all files under the checkpoint directory)
        total_size = sum(f.stat().st_size for f in hf_path.rglob('*') if f.is_file())
        print(f"  [OK] Iteration {i}: {hf_path.name} ({total_size/1e9:.1f} GB)")
    else:
        print(f"  [--] Iteration {i}: Not found")

if not found_checkpoints:
    print("\n  ERROR: No checkpoints found!")
    print("  Training may not have completed.")
else:
    print(f"\n  Found {len(found_checkpoints)} checkpoint(s)")

# =============================================================================
# 2. LOAD FINAL MODEL
# =============================================================================

print("\n[2/4] Model Loading")
print("-" * 50)

if not found_checkpoints:
    # Nothing to load; later sections check `model is None`.
    model = None
    print("  Skipping (no checkpoints)")
else:
    # The last entry of found_checkpoints is the one that gets loaded.
    final_iter, final_path = found_checkpoints[-1]
    print(f"  Loading best model (Iteration {final_iter})...")

    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # The tokenizer comes from the hub model id, the weights from Drive.
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
        load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
        model = AutoModelForCausalLM.from_pretrained(final_path, **load_kwargs)

        print(f"  [OK] Model loaded: {model.config.model_type}")
        print(f"  Parameters: {model.num_parameters()/1e9:.2f}B")

    except Exception as load_err:
        # Report the failure and fall through with no model.
        print(f"  [ERROR] Failed to load model: {load_err}")
        model = None

# =============================================================================
# 3. GENERATE SAMPLE OUTPUT
# =============================================================================

print("\n[3/4] Sample Generation")
print("-" * 50)

if model is None:
    print("  Skipping (no model loaded)")
else:
    test_prompts = [
        "Generate a challenging 2-hop biomedical question about BRCA1 mutations and breast cancer treatment.",
        "What is the relationship between p53 tumor suppressor gene and cancer development?",
        "Explain how insulin resistance leads to Type 2 diabetes."
    ]

    for idx, prompt in enumerate(test_prompts, start=1):
        print(f"\n  Test {idx}:")
        print(f"  Prompt: {prompt[:60]}...")

        # Each prompt is attempted independently; one failure does not stop
        # the remaining prompts.
        try:
            chat = [{"role": "user", "content": prompt}]
            rendered = tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(rendered, return_tensors="pt").to(model.device)
            prompt_len = inputs['input_ids'].shape[1]

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Decode only the tokens that follow the prompt.
            new_tokens = outputs[0][prompt_len:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"  Response: {response[:150]}...")

        except Exception as gen_err:
            print(f"  [ERROR] Generation failed: {gen_err}")

# =============================================================================
# 4. STORAGE SUMMARY
# =============================================================================

print("\n[4/4] Storage Summary")
print("-" * 50)

import shutil

# Calculate storage for each component (label, directory under DRIVE_BASE)
components = [
    ("Corpus", DRIVE_BASE / 'corpus'),
    ("Training Data", DRIVE_BASE / 'data'),
    ("Checkpoints", DRIVE_BASE / 'checkpoints'),
    ("Logs", DRIVE_BASE / 'logs'),
]

total_used = 0
for name, path in components:
    if path.exists():
        # Recursive sum of file sizes; directories themselves are skipped.
        size = sum(f.stat().st_size for f in path.rglob('*') if f.is_file())
        total_used += size
        print(f"  {name}: {size/1e9:.2f} GB")
    else:
        print(f"  {name}: --")

print("  " + "-"*30)
print(f"  Total: {total_used/1e9:.2f} GB")

# Drive space. Still best-effort, but only filesystem errors are tolerated
# and they are reported rather than silently discarded.
try:
    total, used, free = shutil.disk_usage('/content/drive')
except OSError as usage_err:
    print(f"\n  Drive Space: unavailable ({usage_err})")
else:
    print("\n  Drive Space:")
    print(f"    Total: {total/1e12:.1f} TB")
    print(f"    Used: {used/1e12:.2f} TB")
    print(f"    Free: {free/1e12:.2f} TB")

# =============================================================================
# FINAL SUMMARY
# =============================================================================

print("\n" + "="*70)
print(" SUMMARY")
print("="*70)

# Number of iterations a finished run should have produced. NUM_ITERATIONS
# only exists if the training cell ran in this same session; otherwise fall
# back to the notebook's default of 3.
expected_iterations = globals().get('NUM_ITERATIONS', 3)

if found_checkpoints:
    latest_iter, latest_path = found_checkpoints[-1]
    found_ids = {iter_no for iter_no, _ in found_checkpoints}
    missing = [n for n in range(1, expected_iterations + 1) if n not in found_ids]

    # Having some checkpoint is not the same as having finished: report the
    # run as complete only when every expected iteration has one.
    if missing:
        missing_text = ", ".join(str(n) for n in missing)
        print(f"\nTraining Status: PARTIAL (no checkpoint for iteration(s): {missing_text})")
    else:
        print("\nTraining Status: COMPLETE")
    print(f"Best Model: Iteration {latest_iter}")
    print(f"Model Path: {latest_path}")
    print("\nTo use the model:")
    print("  from transformers import AutoModelForCausalLM, AutoTokenizer")
    print(f"  model = AutoModelForCausalLM.from_pretrained('{latest_path}')")
    print("  tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct')")
    if missing:
        print("\nRe-run Cell 2 to continue training.")
else:
    print("\nTraining Status: INCOMPLETE")
    print("No completed checkpoints found.")
    print("Re-run Cell 2 to continue training.")

print("\n" + "="*70)