# Dr. Zero Biomedical Training (Lightweight Version)

**Optimized for:** Single Google Drive (<15GB), A100 GPU (<20 hours)

## Overview

This notebook trains a lightweight version of Dr. Zero on PubMed biomedical literature:
- **Corpus**: 10,000 papers (~2 GB)
- **Training**: Single iteration only (~10-15 hours)
- **Storage**: <15 GB on single Google Drive
- **Checkpoints**: Intermediate checkpoints in local temp storage; only the final checkpoint is copied to Drive

## Requirements

- Google Colab Pro/Pro+ (A100 GPU)
- <15 GB Google Drive storage
- Weights & Biases account (optional)
- An email address for the NCBI E-utilities (PubMed) API
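Before starting, it can help to confirm there is enough free space where files will land. Below is a minimal standard-library sketch; `has_free_space` and the 20 GB figure in the usage comment are illustrative assumptions, not part of the repository. `shutil.disk_usage` reports the filesystem that contains `path`. Whether a mounted Drive folder reports your real Drive quota is not guaranteed, so treat results for Drive paths as approximate.

```python
import shutil

def has_free_space(path: str, required_gb: float) -> bool:
    """Return True if the filesystem containing `path` has at least `required_gb` GiB free."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gb * 1024**3

# Hypothetical usage in Colab: the local disk holds the intermediate checkpoints
# print(has_free_space('/content', 20))
```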

In [None]:
# Cell 1: Mount Google Drive & Setup Directories

from google.colab import drive
from pathlib import Path
import os

# Mount single Google Drive
drive.mount('/content/drive')

# Create directory structure
# Google Drive: Only final models and corpus (~12 GB)
DRIVE_BASE = Path('/content/drive/MyDrive/drzero_lite')
CORPUS_DIR = DRIVE_BASE / 'corpus'
FINAL_CHECKPOINT_DIR = DRIVE_BASE / 'final_checkpoint'

# Local temp: Intermediate checkpoints, working files (auto-deleted on disconnect)
LOCAL_BASE = Path('/content/drzero_temp')
LOCAL_CHECKPOINT = LOCAL_BASE / 'checkpoints'
LOCAL_DATA = LOCAL_BASE / 'data'
LOCAL_LOGS = LOCAL_BASE / 'logs'

# Create all directories
for dir_path in [DRIVE_BASE, CORPUS_DIR, FINAL_CHECKPOINT_DIR,
                 LOCAL_BASE, LOCAL_CHECKPOINT, LOCAL_DATA, LOCAL_LOGS]:
    dir_path.mkdir(parents=True, exist_ok=True)

print("Storage Layout:")
print(f"  Google Drive (~12 GB): {DRIVE_BASE}")
print(f"    - Corpus: {CORPUS_DIR}")
print(f"    - Final model: {FINAL_CHECKPOINT_DIR}")
print(f"  Local temp (auto-deleted): {LOCAL_BASE}")
print(f"    - Working checkpoints: {LOCAL_CHECKPOINT}")
print(f"    - Data: {LOCAL_DATA}")
print("\n✅ Directories created!")

In [None]:
# Cell 2: Install Dependencies

import subprocess
import sys
import os

# Standard packages
packages = [
    "torch", "transformers", "accelerate", "datasets",
    "sentence-transformers", "biopython",
    "wandb", "tqdm", "psutil"
]

print("Installing dependencies...")
for pkg in packages:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])
    print(f"  ✓ {pkg}")

# Install FAISS: faiss-cpu installs reliably from pip. Embeddings are still
# computed on the GPU via PyTorch; only the FAISS index search runs on CPU.
print("Installing FAISS...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "faiss-cpu"])
print("  ✓ faiss-cpu")

# Install SGLang (may take a moment)
print("Installing SGLang...")
try:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "sglang[all]"],
                         stderr=subprocess.DEVNULL)
    print("  ✓ sglang")
except subprocess.CalledProcessError:
    # Fallback: install without extras
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "sglang"])
    print("  ✓ sglang (minimal)")

# Install veRL
if not os.path.exists('/content/verl'):
    print("Installing veRL framework...")
    subprocess.check_call(["git", "clone", "-q", "https://github.com/volcengine/verl.git", "/content/verl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-e", "/content/verl"])
    print("  ✓ verl")
else:
    print("  ✓ verl (already installed)")

# Verify GPU
import torch
assert torch.cuda.is_available(), "No GPU detected! Set Runtime -> A100 GPU"
print(f"\n✅ GPU: {torch.cuda.get_device_name(0)}")
print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

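The SGLang step above tries the `[all]` extras first and falls back to the bare package if pip fails. The same pattern can be factored into a small helper. This is a sketch, and `install_with_fallback` is a hypothetical name. The `run` parameter exists only so the logic can be exercised without actually calling pip.

```python
import subprocess
import sys
from typing import Callable, Sequence

def install_with_fallback(primary: str, fallback: str,
                          run: Callable[[Sequence[str]], object] = subprocess.check_call) -> str:
    """pip-install `primary`; if pip fails, install `fallback`. Returns the spec that succeeded."""
    base = [sys.executable, "-m", "pip", "install", "-q"]
    try:
        run(base + [primary])
        return primary
    except subprocess.CalledProcessError:
        # e.g. optional extras such as "sglang[all]" fail to build
        run(base + [fallback])
        return fallback
```

Usage: `install_with_fallback("sglang[all]", "sglang")`. Catching only `CalledProcessError` (rather than a bare `except`) means that an interrupt or an unrelated bug is not silently turned into a fallback install.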
In [None]:
# Cell 3: Clone Repository

REPO_DIR = Path('/content/DrPubMedZero')

if not REPO_DIR.exists():
    subprocess.check_call([
        "git", "clone",
        "https://github.com/ShivaAyyar/DrPubMedZero.git",
        str(REPO_DIR)
    ])
else:
    subprocess.check_call(["git", "-C", str(REPO_DIR), "pull"])

os.chdir(REPO_DIR)
sys.path.insert(0, str(REPO_DIR))

# Verify biomedical module
from biomedical import PubMedCorpusManager, setup_for_colab
setup_for_colab()

print(f"✅ Repository ready: {REPO_DIR}")

In [None]:
# Cell 4: Configuration (Lightweight Version)

import getpass

CONFIG = {
    # Reduced scale to fit in 15GB and <20 GPU hours
    'model_name': 'Qwen/Qwen2.5-3B-Instruct',
    
    # Data - REDUCED
    'corpus_size': 10000,  # 10K papers instead of 50K (~2 GB)
    'training_seeds': 500,  # 500 seeds instead of 2000
    'pubmed_query': '(breast cancer OR lung cancer) AND (gene OR protein)',
    'date_range': ('2022/01/01', '2024/12/31'),  # Recent papers only
    
    # Training - REDUCED for <20 GPU hours
    'batch_size': 32,  # Smaller batch
    'gradient_accumulation': 4,
    'learning_rate': 1e-6,
    'max_steps': 100,  # ~10-15 hours on A100
    'save_freq': 50,  # Save every 50 steps (only 2 checkpoints)
    
    # Paths
    'corpus_path': str(CORPUS_DIR),
    'final_checkpoint': str(FINAL_CHECKPOINT_DIR),
    'temp_checkpoint': str(LOCAL_CHECKPOINT),
    'data_dir': str(LOCAL_DATA),
    'logs_dir': str(LOCAL_LOGS),
    
    # Servers
    'retrieval_port': 8000,
    'solver_port': 8001,
}

# User inputs
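# NCBI asks E-utilities clients to identify themselves with a contact email
CONFIG['email'] = input("Email for NCBI E-utilities (PubMed): ").strip()
assert '@' in CONFIG['email'], "A valid contact email is required for NCBI"
print(f"Email: {CONFIG['email']}")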

wandb_key = getpass.getpass("W&B API key (optional, press Enter to skip): ")
if wandb_key:
    os.environ['WANDB_API_KEY'] = wandb_key
    import wandb
    wandb.login(key=wandb_key)
else:
    os.environ['WANDB_MODE'] = 'disabled'

print("\nüìã Configuration:")
print(f"  Corpus: {CONFIG['corpus_size']} papers (~2 GB)")
print(f"  Training: {CONFIG['max_steps']} steps (~10-15 hours)")
print(f"  Storage: <15 GB total")
print(f"\n  Drive: {DRIVE_BASE}")
print(f"  Temp: {LOCAL_BASE}")

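Here is a quick sanity check on the schedule. It assumes each training step consumes `batch_size` seed prompts. How veRL maps `data.train_batch_size` onto optimizer updates may differ; for example, GRPO samples several responses per prompt. Under that assumption, 100 steps × 32 prompts = 3,200 prompts, or 6.4 passes over the 500 seeds:

```python
def passes_over_seeds(max_steps: int, batch_size: int, n_seeds: int) -> float:
    """How many times training cycles through the seed set, assuming each
    step draws `batch_size` seeds (an assumption, not veRL's documented behavior)."""
    return max_steps * batch_size / n_seeds

print(passes_over_seeds(100, 32, 500))  # 6.4
```

If the model starts memorizing the seed questions, raising `training_seeds` or lowering `max_steps` reduces the number of passes.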
In [None]:
# Cell 5: Download PubMed Corpus (Lightweight)

from biomedical import PubMedCorpusManager
import json

corpus_file = Path(CONFIG['corpus_path']) / 'pubmed-corpus-lite.jsonl'

download = False
if corpus_file.exists():
    with open(corpus_file) as f:
        n_papers = sum(1 for _ in f)
    print(f"✓ Corpus exists: {n_papers} papers")
    if n_papers >= CONFIG['corpus_size']:
        print("  Skipping download")
    else:
        print(f"  Fewer than {CONFIG['corpus_size']} papers, downloading again...")
        download = True
else:
    print(f"Downloading {CONFIG['corpus_size']} PubMed papers...")
    print("⏱️ Estimated time: 15-20 minutes")
    download = True

if download:
    manager = PubMedCorpusManager(
        save_path=CONFIG['corpus_path'],
        email=CONFIG['email']
    )
    
    articles = manager.download_pubmed_abstracts(
        query=CONFIG['pubmed_query'],
        max_results=CONFIG['corpus_size'],
        date_range=CONFIG['date_range']
    )
    
    manager.save_corpus(articles)
    print(f"\n✅ Downloaded {len(articles)} papers!")

CONFIG['corpus_file'] = str(corpus_file)

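`PubMedCorpusManager` lives in the repository, and its internals aren't shown here. For reference, the query and `date_range` map naturally onto NCBI's ESearch E-utility. The sketch below only builds the request URL and makes no network call. It assumes `datetype=pdat` (publication date) is the intended filter, and the `tool` value is a hypothetical name. NCBI's E-utilities documentation states that ESearch returns at most the first 10,000 PubMed records per query, so a `corpus_size` above 10,000 would likely require splitting the query into smaller date windows.

```python
from urllib.parse import urlencode

ESEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"

def build_esearch_url(query: str, max_results: int, date_range: tuple, email: str,
                      tool: str = "drzero-lite") -> str:
    """Build a PubMed ESearch URL restricted to a publication-date window."""
    mindate, maxdate = date_range
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "datetype": "pdat",   # filter on publication date
        "mindate": mindate,   # YYYY/MM/DD
        "maxdate": maxdate,
        "tool": tool,         # hypothetical tool name
        "email": email,
    }
    return f"{ESEARCH_URL}?{urlencode(params)}"

url = build_esearch_url('(breast cancer OR lung cancer) AND (gene OR protein)',
                        10000, ('2022/01/01', '2024/12/31'), 'you@example.org')
```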
In [None]:
# Cell 6: Build FAISS Index

from biomedical import build_biomedical_index

index_path = Path(CONFIG['corpus_path']) / 'index-lite.faiss'

if index_path.exists():
    print("✓ Index exists, skipping build")
else:
    print("Building PubMedBERT index...")
    print("⏱️ Estimated time: 10-15 minutes")
    
    build_biomedical_index(
        corpus_file=str(corpus_file),
        index_save_path=str(index_path),
        use_gpu=False,  # Using faiss-cpu
        batch_size=64
    )
    
    size_mb = os.path.getsize(index_path) / (1024**2)
    print(f"\n✅ Index built: {size_mb:.0f} MB")

CONFIG['index_path'] = str(index_path)

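Here is a rough size estimate for a flat (uncompressed) FAISS index. Each vector is stored as `dim` float32 values, which is 4 × `dim` bytes. BERT-base encoders such as PubMedBERT produce 768-dimensional embeddings, so one vector per paper for 10,000 papers comes to about 29 MiB. The ~0.5 GB figure given under "What You Have" would correspond to roughly 175,000 such vectors (for example, if abstracts are split into passages). Whether `build_biomedical_index` uses a flat index, and how it chunks text, are assumptions here.

```python
def flat_index_bytes(n_vectors: int, dim: int = 768, bytes_per_value: int = 4) -> int:
    """Raw vector storage for a flat float32 index (ignores small file headers)."""
    return n_vectors * dim * bytes_per_value

size = flat_index_bytes(10_000)
print(f"{size:,} bytes = {size / 1024**2:.1f} MiB")  # 30,720,000 bytes = 29.3 MiB
```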
In [None]:
# Cell 7: Prepare Training Seeds

import random
import json

seeds_file = Path(CONFIG['data_dir']) / 'seeds.jsonl'

if seeds_file.exists():
    print("✓ Seeds exist, skipping")
else:
    # Load corpus
    corpus = []
    with open(corpus_file) as f:
        for line in f:
            corpus.append(json.loads(line))
    
    # Sample seeds: use abstracts with >100 words if there are enough of them,
    # otherwise fall back to the whole corpus
    random.seed(42)  # fixed seed so the seed set is reproducible
    substantial = [p for p in corpus if len((p.get('abstract') or '').split()) > 100]
    pool = substantial if len(substantial) >= CONFIG['training_seeds'] else corpus
    seeds = random.sample(pool, min(CONFIG['training_seeds'], len(pool)))
    
    # Save
    with open(seeds_file, 'w') as f:
        for seed in seeds:
            f.write(json.dumps(seed) + '\n')
    
    print(f"✅ Saved {len(seeds)} training seeds")

CONFIG['seeds_file'] = str(seeds_file)

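The cell above falls back to the *entire* corpus whenever there are fewer than `training_seeds` long abstracts, so its preference for longer abstracts is all-or-nothing. Below is a hypothetical alternative that keeps every long abstract and tops up with shorter ones only as needed. The function name and the 100-word threshold default are assumptions carried over from the cell.

```python
import random

def sample_seeds(corpus, k, min_words=100, rng_seed=42):
    """Pick up to k papers, taking long abstracts first and topping up with short ones."""
    rng = random.Random(rng_seed)  # local RNG: reproducible, no global side effects

    def n_words(paper):
        return len((paper.get('abstract') or '').split())

    long_ = [p for p in corpus if n_words(p) > min_words]
    short = [p for p in corpus if n_words(p) <= min_words]
    if len(long_) >= k:
        return rng.sample(long_, k)
    # Not enough long abstracts: keep all of them, fill the rest from short ones
    return long_ + rng.sample(short, min(k - len(long_), len(short)))
```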
In [None]:
# Cell 8: Launch Servers

from colab_helpers import BackgroundServer, kill_port
import time

# Kill existing processes
kill_port(CONFIG['retrieval_port'])
kill_port(CONFIG['solver_port'])
time.sleep(2)

# Launch retrieval server
print("Starting retrieval server...")
retrieval_server = BackgroundServer(
    name="retrieval",
    log_file=str(Path(CONFIG['logs_dir']) / 'retrieval.log')
)

retrieval_server.start([
    "python", "search/retrieval_server.py",
    "--mode", "biomedical",
    "--corpus_path", CONFIG['corpus_file'],
    "--index_path", CONFIG['index_path'],
    "--port", str(CONFIG['retrieval_port'])
], cwd=str(REPO_DIR))

assert retrieval_server.wait_until_ready(
    f"http://localhost:{CONFIG['retrieval_port']}/health",
    timeout=180
), "Retrieval server failed to start"

print("✓ Retrieval server ready")

# Launch solver server
print("\nStarting solver server...")
solver_server = BackgroundServer(
    name="solver",
    log_file=str(Path(CONFIG['logs_dir']) / 'solver.log')
)

solver_server.start([
    "python", "-m", "sglang.launch_server",
    "--model-path", CONFIG['model_name'],
    "--port", str(CONFIG['solver_port']),
    "--tp", "1",
    "--mem-fraction-static", "0.5"
])

assert solver_server.wait_until_ready(
    f"http://localhost:{CONFIG['solver_port']}/health",
    timeout=300
), "Solver server failed to start"

print("✓ Solver server ready")
print("\n✅ Both servers operational!")

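`BackgroundServer.wait_until_ready` comes from the repository's `colab_helpers` module, whose implementation isn't shown. You may need the same readiness check outside that helper, for example to re-check a server after a restart. A standard-library stand-in might look like the sketch below. The function name and polling interval are assumptions, and the sketch only checks for an HTTP 200 from the `/health` URLs used above.

```python
import time
import urllib.request

# Ignore proxy environment variables: both servers run on localhost
_opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))

def wait_for_http_ok(url: str, timeout: float = 180.0, interval: float = 2.0) -> bool:
    """Poll `url` until it answers HTTP 200 or `timeout` seconds have passed."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with _opener.open(url, timeout=interval) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            pass  # not up yet: connection refused, timeout, or non-2xx (HTTPError)
        time.sleep(interval)
    return False
```

Usage: `wait_for_http_ok(f"http://localhost:{CONFIG['solver_port']}/health", timeout=300)`.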
In [None]:
# Cell 9: Train Proposer (Single Iteration)

os.environ['RETRIEVAL_SERVER_URL'] = f"http://localhost:{CONFIG['retrieval_port']}"
os.environ['SOLVER_SERVER_URL'] = f"http://localhost:{CONFIG['solver_port']}"
os.environ['NCBI_EMAIL'] = CONFIG['email']

print("="*80)
print("TRAINING PROPOSER (Lightweight - Single Iteration)")
print("="*80)
print(f"\nSteps: {CONFIG['max_steps']}")
print(f"Estimated time: 10-15 hours on A100")
print(f"\nCheckpoints saved to temp: {CONFIG['temp_checkpoint']}")
print(f"Final checkpoint copied to Drive: {CONFIG['final_checkpoint']}\n")

# Training command
train_cmd = [
    "python", "-m", "verl.trainer.main_ppo",
    "--config-path", "./config",
    "--config-name", "search_multiturn_grpo",
    f"actor_rollout_ref.model.path={CONFIG['model_name']}",
    f"actor_rollout_ref.actor.optim.lr={CONFIG['learning_rate']}",
    f"data.train_files=[{CONFIG['seeds_file']}]",
    f"data.train_batch_size={CONFIG['batch_size']}",
    "trainer.n_gpus_per_node=1",
    f"trainer.gradient_accumulation_steps={CONFIG['gradient_accumulation']}",
    f"trainer.save_freq={CONFIG['save_freq']}",
    f"trainer.total_training_steps={CONFIG['max_steps']}",
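    # veRL writes checkpoints to trainer.default_local_dir (default_hdfs_dir is
    # only an optional HDFS copy target), so point the local dir at temp storage
    f"trainer.default_local_dir={CONFIG['temp_checkpoint']}",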
    "trainer.project_name=drzero-lite",
    "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
]

try:
    subprocess.run(train_cmd, cwd=str(REPO_DIR), check=True)
    
    print("\n" + "="*80)
    print("✅ TRAINING COMPLETE!")
    print("="*80)
    
    # Find the latest checkpoint. Sort by the numeric step suffix: a plain
    # string sort would put "step_100" before "step_50".
    import re
    import shutil
    step_re = re.compile(r"step_(\d+)$")
    checkpoints = sorted(
        (p for p in Path(CONFIG['temp_checkpoint']).glob("*step_*")
         if p.is_dir() and step_re.search(p.name)),
        key=lambda p: int(step_re.search(p.name).group(1)),
    )
    if checkpoints:
        final_ckpt = checkpoints[-1]
        # CONFIG stores paths as strings, so wrap in Path before joining
        dest = Path(CONFIG['final_checkpoint']) / final_ckpt.name
        
        # Copy to Drive for persistence
        print("\nCopying final checkpoint to Drive...")
        shutil.copytree(final_ckpt, dest, dirs_exist_ok=True)
        
        print(f"✅ Final model saved to: {dest}")
        print("\n💡 Intermediate checkpoints in temp storage will be deleted on disconnect")
        print("   Only the final checkpoint is saved to Google Drive")
    else:
        print(f"\n⚠️ No step_* checkpoint folders found in {CONFIG['temp_checkpoint']}")
    
except Exception as e:
    print(f"\n❌ Training failed: {e}")
    print("\nCheck logs:")
    print("  - W&B dashboard")
    print(f"  - {CONFIG['logs_dir']}")
    raise

In [None]:
# Cell 10: Check Storage Usage

def get_size(path):
    """Total size of a file, or of all files under a directory, in GB (GiB)."""
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / (1024**3)
    total = 0
    for entry in path.rglob('*'):
        if entry.is_file():
            total += entry.stat().st_size
    return total / (1024**3)  # bytes -> GiB

print("Storage Usage:")
print("="*50)

# Measure each component instead of assuming fixed sizes
drive_size = get_size(DRIVE_BASE)
print(f"Google Drive: {drive_size:.2f} GB")
print(f"  - Corpus: {get_size(corpus_file):.2f} GB")
print(f"  - Index: {get_size(index_path):.2f} GB")
print(f"  - Final checkpoint: {get_size(FINAL_CHECKPOINT_DIR):.2f} GB")

temp_size = get_size(LOCAL_BASE)
print(f"\nTemp storage: {temp_size:.2f} GB (will be deleted)")

# The free 15 GB Google quota is shared with Gmail and Google Photos
print(f"\nTotal Drive usage: {drive_size:.1f} GB / 15 GB")
if drive_size < 15:
    print("✅ Within 15 GB limit!")
else:
    print("⚠️ Exceeds 15 GB - consider deleting old files")

---
# Training Complete!

## What You Have

- ✅ PubMed corpus (10K papers, ~2 GB)
- ✅ PubMedBERT search index (~0.5 GB)
- ✅ Trained proposer model (~8-10 GB)
- ✅ **Total**: <15 GB on Google Drive

## GPU Time Used

- Corpus download: ~20 min
- Index building: ~15 min
- Training: ~10-15 hours
- **Total**: <20 hours on A100

## Next Steps

1. **Download model**: `!zip -r model.zip {CONFIG['final_checkpoint']}`
2. **Test generation**: Use proposer to generate biomedical questions
3. **Extend training**: Add more steps or iterations if needed
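As a pure-Python alternative to the `!zip` command in step 1, the standard library can build the archive. This is a sketch; `zip_checkpoint` and the output path in the usage note are hypothetical.

```python
import shutil
from pathlib import Path

def zip_checkpoint(src_dir, out_base):
    """Create <out_base>.zip from the contents of src_dir; returns the archive path."""
    src = Path(src_dir)
    if not src.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {src}")
    return shutil.make_archive(str(out_base), "zip", root_dir=str(src))
```

Usage: `zip_checkpoint(CONFIG['final_checkpoint'], '/content/model')`. In Colab, `google.colab.files.download('/content/model.zip')` can then send the archive to your browser. For multi-GB checkpoints, downloading the folder directly from Drive is usually more practical.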

## Storage Cleanup

To free up space:
- Temp files auto-delete on disconnect
- Keep only final checkpoint on Drive
- Archive corpus if needed: `!tar -czf corpus.tar.gz {CORPUS_DIR}`
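If more than one checkpoint ends up on Drive (for example, after copying from a later run), the older ones can be removed. The hypothetical helper below defaults to a dry run. It assumes checkpoint folders end in `step_<N>` (veRL typically names them `global_step_<N>`) and sorts them numerically, because a plain string sort would rank `step_100` before `step_50`.

```python
import re
import shutil
from pathlib import Path

STEP_RE = re.compile(r"step_(\d+)$")

def prune_checkpoints(ckpt_dir, keep: int = 1, dry_run: bool = True):
    """Delete all but the `keep` newest *step_<N> folders in ckpt_dir.
    Returns the names that were (or, in a dry run, would be) deleted."""
    steps = [p for p in Path(ckpt_dir).iterdir() if p.is_dir() and STEP_RE.search(p.name)]
    steps.sort(key=lambda p: int(STEP_RE.search(p.name).group(1)))
    to_delete = steps[:-keep] if keep > 0 else steps
    if not dry_run:
        for p in to_delete:
            shutil.rmtree(p)
    return [p.name for p in to_delete]
```

Usage: call `prune_checkpoints(FINAL_CHECKPOINT_DIR)` to preview, then add `dry_run=False` to delete. Files that are not checkpoint folders are left untouched.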
