# Dr. Zero Biomedical Training (Lightweight Version)

**Optimized for:** Single Google Drive (<15 GB), A100 GPU (<20 hours)

## Overview

This notebook trains a lightweight version of Dr. Zero on PubMed biomedical literature:
- **Corpus**: 10,000 papers (~2 GB)
- **Training**: Single iteration only (~10-15 hours)
- **Storage**: <15 GB on single Google Drive
- **Checkpoints**: Temp storage + final checkpoint to Drive
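
The storage figures above can be sanity-checked with some quick arithmetic. The sketch below is a rough estimate only: it assumes a 3B-parameter model stored with 16-bit weights and no optimizer state (a veRL checkpoint that includes optimizer state would be larger), and it takes the ~2 GB corpus figure from the list above.

```python
# Rough Drive budget for the lightweight run (every figure is an estimate).
PARAMS_BILLION = 3.0     # assumed: approximate parameter count of a 3B model
BYTES_PER_PARAM = 2      # assumed: 16-bit weights, no optimizer state
CORPUS_GB = 2.0          # from the overview: ~2 GB for 10,000 papers
DRIVE_LIMIT_GB = 15.0    # from the overview: single Google Drive budget


def checkpoint_size_gb(params_billion: float, bytes_per_param: int) -> float:
    """Approximate size in GB of a weights-only checkpoint."""
    return params_billion * 1e9 * bytes_per_param / 1e9


ckpt_gb = checkpoint_size_gb(PARAMS_BILLION, BYTES_PER_PARAM)
total_gb = CORPUS_GB + ckpt_gb

print(f"Final checkpoint: ~{ckpt_gb:.1f} GB")
print(f"Corpus + checkpoint: ~{total_gb:.1f} GB of {DRIVE_LIMIT_GB:.0f} GB")
```

Keeping intermediate checkpoints in local temp storage is what keeps the Drive total at a single checkpoint plus the corpus.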

## Requirements

- Google Colab Pro/Pro+ (A100 GPU)
- <15 GB Google Drive storage
- Weights & Biases account (optional)
- Your email for NCBI PubMed API

## Instructions

1. Run Cell 1 (Mount Drive)
2. Run Cell 2 (Install Dependencies) - **THEN RESTART RUNTIME**
3. Run Cell 1 again (re-mount Drive after restart)
4. Skip Cell 2, continue from Cell 3 onwards
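
After the restart in step 2, a quick version check can confirm that the pinned numpy is the one installed. This is a minimal sketch, assuming the `numpy<2.0` pin from Cell 2; `numpy_pin_ok` is a hypothetical helper, not part of the notebook.

```python
import importlib.metadata as md


def numpy_pin_ok(version: str) -> bool:
    """Return True when the version string satisfies the numpy<2.0 pin."""
    major = int(version.split(".")[0])
    return major < 2


try:
    installed = md.version("numpy")
    status = "OK" if numpy_pin_ok(installed) else "unexpected, re-run Cell 2 and restart"
    print(f"numpy {installed}: {status}")
except md.PackageNotFoundError:
    print("numpy is not installed; run Cell 2 first")
```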

In [None]:
# Cell 1: Mount Google Drive & Setup Directories

from google.colab import drive
from pathlib import Path
import os

# Mount single Google Drive
drive.mount('/content/drive')

# Create directory structure
# Google Drive: Only final models and corpus (~12 GB)
DRIVE_BASE = Path('/content/drive/MyDrive/drzero_lite')
CORPUS_DIR = DRIVE_BASE / 'corpus'
FINAL_CHECKPOINT_DIR = DRIVE_BASE / 'final_checkpoint'

# Local temp: Intermediate checkpoints, working files (auto-deleted on disconnect)
LOCAL_BASE = Path('/content/drzero_temp')
LOCAL_CHECKPOINT = LOCAL_BASE / 'checkpoints'
LOCAL_DATA = LOCAL_BASE / 'data'
LOCAL_LOGS = LOCAL_BASE / 'logs'

# Create all directories
for dir_path in [DRIVE_BASE, CORPUS_DIR, FINAL_CHECKPOINT_DIR,
                 LOCAL_BASE, LOCAL_CHECKPOINT, LOCAL_DATA, LOCAL_LOGS]:
    dir_path.mkdir(parents=True, exist_ok=True)

print("Storage Layout:")
print(f"  Google Drive (~12 GB): {DRIVE_BASE}")
print(f"    - Corpus: {CORPUS_DIR}")
print(f"    - Final model: {FINAL_CHECKPOINT_DIR}")
print(f"  Local temp (auto-deleted): {LOCAL_BASE}")
print(f"    - Working checkpoints: {LOCAL_CHECKPOINT}")
print(f"    - Data: {LOCAL_DATA}")
print("\n[OK] Directories created!")

In [None]:
# Cell 2: Install Dependencies
# IMPORTANT: After this cell completes, go to Runtime -> Restart session
# Then skip this cell and continue from Cell 3

import subprocess
import sys
import os

print("Installing dependencies...")
print("This may take 5-10 minutes.\n")

# Upgrade pip first
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "--upgrade", "pip"])

# Install numpy first to avoid binary incompatibility
print("1/6 Installing numpy...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "numpy<2.0"])

# Core ML packages
print("2/6 Installing ML packages...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", 
    "torch", "transformers", "accelerate", "datasets", "sentence-transformers"])

# Utility packages
print("3/6 Installing utilities...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
    "biopython", "wandb", "tqdm", "psutil", "faiss-cpu", "uvicorn", "fastapi", "pydantic", "pandas", "pyarrow"])

# Install SGLang for model serving
print("4/6 Installing SGLang...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "sglang[all]"])

# Clone and install veRL
print("5/6 Installing veRL...")
if not os.path.exists('/content/verl'):
    subprocess.check_call(["git", "clone", "-q", "https://github.com/volcengine/verl.git", "/content/verl"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-e", "/content/verl"])

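# Install repo as package for biomedical modules
# Note: the repo is cloned in Cell 3, so on a fresh runtime this step is skipped
print("6/6 Installing DrPubMedZero...")
if os.path.exists('/content/DrPubMedZero'):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-e", "/content/DrPubMedZero"])
else:
    print("  Skipped: /content/DrPubMedZero not found (Cell 3 clones it and adds it to sys.path)")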

print("\n" + "="*60)
print("IMPORTANT: RESTART RUNTIME NOW")
print("="*60)
print("\n1. Go to: Runtime -> Restart session")
print("2. After restart, run Cell 1 again (mount drive)")
print("3. Then SKIP this cell and continue from Cell 3")
print("\nThis restart is required for numpy to load correctly.")

In [None]:
# Cell 3: Clone Repository & Verify Setup

import subprocess
import sys
import os
from pathlib import Path

# Verify numpy loaded correctly
print("Verifying installations...")
import numpy as np
print(f"  [OK] numpy {np.__version__}")

import torch
print(f"  [OK] torch {torch.__version__}")
assert torch.cuda.is_available(), "ERROR: No GPU! Go to Runtime -> Change runtime type -> A100 GPU"
print(f"  [OK] GPU: {torch.cuda.get_device_name(0)}")

import transformers
print(f"  [OK] transformers {transformers.__version__}")

import faiss
print("  [OK] faiss")

# Clone repository
REPO_DIR = Path('/content/DrPubMedZero')

print("\nSetting up repository...")
if not REPO_DIR.exists():
    subprocess.check_call([
        "git", "clone",
        "https://github.com/ShivaAyyar/DrPubMedZero.git",
        str(REPO_DIR)
    ])
    print("  [OK] Repository cloned")
else:
    subprocess.check_call(["git", "-C", str(REPO_DIR), "pull"], 
                         stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    print("  [OK] Repository updated")

os.chdir(REPO_DIR)
sys.path.insert(0, str(REPO_DIR))

# Verify biomedical module
from biomedical import PubMedCorpusManager
print("  [OK] Biomedical module loaded")

print(f"\n[OK] All checks passed! Repository: {REPO_DIR}")

In [None]:
# Cell 4: Configuration

import os
import getpass
from pathlib import Path

# Re-establish directory paths (in case of restart)
DRIVE_BASE = Path('/content/drive/MyDrive/drzero_lite')
CORPUS_DIR = DRIVE_BASE / 'corpus'
FINAL_CHECKPOINT_DIR = DRIVE_BASE / 'final_checkpoint'
LOCAL_BASE = Path('/content/drzero_temp')
LOCAL_CHECKPOINT = LOCAL_BASE / 'checkpoints'
LOCAL_DATA = LOCAL_BASE / 'data'
LOCAL_LOGS = LOCAL_BASE / 'logs'

# Ensure directories exist
for d in [CORPUS_DIR, FINAL_CHECKPOINT_DIR, LOCAL_CHECKPOINT, LOCAL_DATA, LOCAL_LOGS]:
    d.mkdir(parents=True, exist_ok=True)

CONFIG = {
    # Model
    'model_name': 'Qwen/Qwen2.5-3B-Instruct',
    
    # Data - Lightweight settings
    'corpus_size': 10000,
    'training_seeds': 500,
    'pubmed_query': '(breast cancer OR lung cancer) AND (gene OR protein)',
    'date_range': ('2022/01/01', '2024/12/31'),
    
    # Training - Lightweight settings
    'batch_size': 32,
    'gradient_accumulation': 4,
    'learning_rate': 1e-6,
    'max_steps': 100,
    'save_freq': 50,
    
    # Paths
    'corpus_dir': str(CORPUS_DIR),
    'corpus_file': str(CORPUS_DIR / 'pubmed-corpus.jsonl'),
    'index_file': str(CORPUS_DIR / 'pubmedbert_index.faiss'),
    'final_checkpoint': str(FINAL_CHECKPOINT_DIR),
    'temp_checkpoint': str(LOCAL_CHECKPOINT),
    'data_dir': str(LOCAL_DATA),
    'logs_dir': str(LOCAL_LOGS),
    
    # Server ports
    'retrieval_port': 8000,
    'solver_port': 8001,
}

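# User email for NCBI (replace with your own address before running;
# NCBI uses it to contact you about problems with your API usage)
CONFIG['email'] = 'ssa163@case.edu'
print(f"NCBI Email: {CONFIG['email']}")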

# Weights & Biases (optional)
wandb_key = getpass.getpass("W&B API key (press Enter to skip): ")
if wandb_key.strip():
    os.environ['WANDB_API_KEY'] = wandb_key
    import wandb
    wandb.login(key=wandb_key)
    print("  [OK] W&B configured")
else:
    os.environ['WANDB_MODE'] = 'disabled'
    print("  [OK] W&B disabled")

print(f"\nConfiguration:")
print(f"  Corpus: {CONFIG['corpus_size']} papers")
print(f"  Training: {CONFIG['max_steps']} steps")
print(f"  Model: {CONFIG['model_name']}")
print(f"\n[OK] Configuration complete!")

In [None]:
# Cell 5: Download PubMed Corpus

from biomedical import PubMedCorpusManager
from pathlib import Path
import json

corpus_file = Path(CONFIG['corpus_file'])

# Check if corpus already exists
if corpus_file.exists():
    with open(corpus_file) as f:
        n_papers = sum(1 for _ in f)
    print(f"Corpus already exists: {n_papers} papers")
    
    if n_papers >= CONFIG['corpus_size']:
        print("Skipping download.")
        download_needed = False
    else:
        print(f"Need {CONFIG['corpus_size'] - n_papers} more papers")
        download_needed = True
else:
    print(f"Downloading {CONFIG['corpus_size']} PubMed papers...")
    print("Estimated time: 15-30 minutes\n")
    download_needed = True

if download_needed:
    manager = PubMedCorpusManager(
        save_path=CONFIG['corpus_dir'],
        email=CONFIG['email']
    )
    
    articles = manager.download_pubmed_abstracts(
        query=CONFIG['pubmed_query'],
        max_results=CONFIG['corpus_size'],
        date_range=CONFIG['date_range']
    )
    
    if articles:
        manager.save_corpus(articles)
        print(f"\n[OK] Downloaded {len(articles)} papers!")
    else:
        raise RuntimeError("Download failed - check internet connection and email")
else:
    print("[OK] Using existing corpus")

In [None]:
# Cell 6: Build FAISS Index

from biomedical.biomedical_retriever import BiomedicalRetrieverServer
from pathlib import Path
import os

corpus_file = Path(CONFIG['corpus_file'])
index_file = Path(CONFIG['index_file'])

if index_file.exists():
    print(f"Index already exists: {index_file}")
    print("Skipping index build.")
else:
    print("Building PubMedBERT FAISS index...")
    print("Estimated time: 10-20 minutes\n")
    
    # The BiomedicalRetrieverServer will build and save the index
    retriever = BiomedicalRetrieverServer(
        corpus_path=str(corpus_file),
        index_path=str(index_file),
        model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        device="cuda",
        topk=3
    )
    
    # Test search
    print("\nTesting search...")
    results = retriever.search("TP53 mutation breast cancer")
    print(f"  Found {len(results)} results")
    
    # Clean up to free GPU memory
    del retriever
    import torch
    torch.cuda.empty_cache()
    
    size_mb = os.path.getsize(index_file) / (1024**2)
    print(f"\n[OK] Index built: {size_mb:.0f} MB")

In [None]:
# Cell 7: Prepare Training Seeds

import random
import json
from pathlib import Path

corpus_file = Path(CONFIG['corpus_file'])
seeds_file = Path(CONFIG['data_dir']) / 'seeds.jsonl'

if seeds_file.exists():
    with open(seeds_file) as f:
        n_seeds = sum(1 for _ in f)
    print(f"Seeds already exist: {n_seeds}")
    if n_seeds >= CONFIG['training_seeds']:
        print("Skipping seed preparation.")
        prepare_needed = False
    else:
        prepare_needed = True
else:
    prepare_needed = True

if prepare_needed:
    print(f"Preparing {CONFIG['training_seeds']} training seeds...")
    
    # Load corpus
    corpus = []
    with open(corpus_file) as f:
        for line in f:
            corpus.append(json.loads(line))
    print(f"  Loaded {len(corpus)} papers from corpus")
    
    # Filter for papers with substantial abstracts
    substantial = [p for p in corpus if len(p.get('abstract', '').split()) > 100]
    print(f"  {len(substantial)} papers have >100 word abstracts")
    
    # Sample seeds
    source = substantial if len(substantial) >= CONFIG['training_seeds'] else corpus
    seeds = random.sample(source, min(CONFIG['training_seeds'], len(source)))
    
    # Save seeds
    with open(seeds_file, 'w') as f:
        for seed in seeds:
            f.write(json.dumps(seed) + '\n')
    
    print(f"\n[OK] Saved {len(seeds)} training seeds")
else:
    print("[OK] Using existing seeds")

CONFIG['seeds_file'] = str(seeds_file)

In [None]:
# Cell 7.5: Convert Seeds to Parquet Format for veRL

import json
import pandas as pd
from pathlib import Path

print("="*60)
print("CONVERTING SEEDS TO PARQUET FORMAT")
print("="*60)

seeds_file = Path(CONFIG['seeds_file'])
parquet_file = Path(CONFIG['data_dir']) / 'training_seeds.parquet'

if parquet_file.exists():
    print(f"\nParquet file already exists: {parquet_file}")
    df = pd.read_parquet(parquet_file)
    print(f"  Contains {len(df)} examples")
    CONFIG['train_parquet'] = str(parquet_file)
else:
    print("\nConverting JSONL seeds to parquet...")
    
    # Load seeds
    with open(seeds_file) as f:
        seeds = [json.loads(line) for line in f]
    print(f"  Loaded {len(seeds)} seeds")
    
    # Convert to veRL format
    # veRL expects: prompt (list of messages), data_source, extra_info
    data = []
    for idx, seed in enumerate(seeds):
        title = seed.get('title', '')
        abstract = seed.get('abstract', seed.get('text', ''))[:500]  # Truncate for memory
        pmid = seed.get('pmid', str(idx))
        
        # Create prompt for proposer to generate QA pair
        user_message = f"Generate a challenging biomedical question-answer pair from this paper:\n\nTitle: {title}\n\nAbstract: {abstract}"
        
        data.append({
            'prompt': [
                {"role": "user", "content": user_message}
            ],
            'data_source': 'search_biomedical_1',  # 1-hop reasoning for iteration 1
            'extra_info': {
                'index': idx,
                'pmid': pmid,
                'tools_kwargs': {},
                'interaction_kwargs': {}
            }
        })
    
    # Save as parquet
    df = pd.DataFrame(data)
    df.to_parquet(parquet_file)
    
    print(f"  Created {len(df)} training examples")
    print(f"  Saved to: {parquet_file}")
    
    CONFIG['train_parquet'] = str(parquet_file)

print(f"\n[OK] Training data ready: {CONFIG['train_parquet']}")
print("="*60)

In [None]:
# Cell 9: Train Proposer with veRL GRPO + Biomedical Reward Function

import subprocess
import sys
import os
from pathlib import Path
import pandas as pd

REPO_DIR = Path('/content/DrPubMedZero')
os.chdir(REPO_DIR)

# Ensure repo is in PYTHONPATH for reward function imports
sys.path.insert(0, str(REPO_DIR))

print("="*60)
print("DR. ZERO BIOMEDICAL PROPOSER TRAINING")
print("Using GRPO + compute_biomedical_challenger_score_batch")
print("="*60)

# Calculate dataset size for step estimation
df = pd.read_parquet(CONFIG['train_parquet'])
estimated_steps = len(df) // CONFIG['batch_size']

print(f"\nTraining Configuration:")
print(f"  Algorithm: GRPO (Group Relative Policy Optimization)")
print(f"  Reward: compute_biomedical_challenger_score_batch (paper's function)")
print(f"  Model: {CONFIG['model_name']}")
print(f"  Dataset: {len(df)} examples")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Estimated steps: {estimated_steps}")
print(f"  Solver rollout: 3 attempts per question")
print(f"  Checkpoints: Every {CONFIG['save_freq']} steps")
print(f"  Storage: Temp checkpoints → {CONFIG['temp_checkpoint']}")
print(f"           Final checkpoint → {CONFIG['final_checkpoint']}")
print()

# Training command (from iter1_challenger_biomed.sh, adapted for 1 GPU + 80GB RAM)
train_cmd = [
    sys.executable, "-m", "verl.trainer.main_ppo",
    "--config-path", str(REPO_DIR / "config"),
    "--config-name", "search_multiturn_grpo",

    # Data
    f"data.train_files={CONFIG['train_parquet']}",
    f"data.train_batch_size={CONFIG['batch_size']}",

    # Model
    f"actor_rollout_ref.model.path={CONFIG['model_name']}",
    f"actor_rollout_ref.actor.optim.lr={CONFIG['learning_rate']}",
    "actor_rollout_ref.actor.grad_clip=0.1",
    "actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.03",

    # GRPO algorithm (PAPER'S ALGORITHM!)
    "algorithm.use_kl_in_reward=False",
    "algorithm.adv_estimator=grpo_batch",
    "actor_rollout_ref.rollout.n=1",  # GRPO group size
    "actor_rollout_ref.actor.use_kl_loss=False",

    # Reward function (PAPER'S REWARD FUNCTION!)
    "reward_model.reward_manager=batch",
    "custom_reward_function.name=compute_biomedical_challenger_score_batch",
    "custom_reward_function.path=verl/custom_reward/reward_function.py",
    f"custom_reward_function.reward_kwargs.model_name={CONFIG['model_name']}",
    "custom_reward_function.reward_kwargs.base_url=http://127.0.0.1:8001",
    "custom_reward_function.reward_kwargs.reward_rollout_n=3",  # Solver attempts

    # Single GPU settings (HARDWARE ADAPTATION ONLY)
    "trainer.n_gpus_per_node=1",
    "trainer.nnodes=1",
    "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
    "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2",
    "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2",
    "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2",

    # Memory optimization for 80GB A100
    "actor_rollout_ref.actor.fsdp_config.param_offload=True",
    "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
    "actor_rollout_ref.model.enable_gradient_checkpointing=True",
    "actor_rollout_ref.model.use_remove_padding=True",
    "actor_rollout_ref.rollout.gpu_memory_utilization=0.35",  # Matches solver allocation

    # Multi-turn tool use (PAPER'S APPROACH!)
    f"actor_rollout_ref.rollout.multi_turn.tool_config_path={str(REPO_DIR / 'config' / 'search_tool_config.yaml')}",

    # Training schedule
    "trainer.total_epochs=1",
    f"trainer.save_freq={CONFIG['save_freq']}",
    f"trainer.default_hdfs_dir={CONFIG['temp_checkpoint']}",
    "trainer.project_name=drzero-biomed-poc",
    "trainer.experiment_name=iter1_biomedical_colab",
    "trainer.logger=[\"console\"]",
    "trainer.val_before_train=False",
    "trainer.test_freq=-1",
]

print("="*60)
print("STARTING TRAINING")
print("="*60)
print("\nExpected output:")
print("  - Format scores (XML structure, tool calls)")
print("  - Solver responses (3 attempts per question)")
print("  - Difficulty scores (how much solvers vary)")
print("  - Final rewards (format + difficulty)")
print()
print("This will take 1-2 hours for ~15 steps on A100 80GB")
print("="*60 + "\n")

# Run training
try:
    process = subprocess.run(train_cmd, cwd=str(REPO_DIR))
    
    if process.returncode == 0:
        print("\n" + "="*60)
        print("✓ TRAINING COMPLETE")
        print("="*60)
        
        # Copy final checkpoint to Drive (only final, save space)
        import shutil
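        # Sort by step number so that global_step_100 comes after global_step_50
        checkpoints = sorted(
            Path(CONFIG['temp_checkpoint']).glob("global_step_*"),
            key=lambda p: int(p.name.rsplit("_", 1)[-1]),
        )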
        if checkpoints:
            final_ckpt = checkpoints[-1]
            dest = Path(CONFIG['final_checkpoint']) / final_ckpt.name
            print(f"\nCopying final checkpoint to Google Drive...")
            print(f"  From: {final_ckpt}")
            print(f"  To: {dest}")
            
            if dest.exists():
                shutil.rmtree(dest)
            shutil.copytree(final_ckpt, dest)
            
            # Get size
            size_gb = sum(f.stat().st_size for f in dest.rglob('*') if f.is_file()) / 1e9
            print(f"  Size: {size_gb:.2f} GB")
            print(f"\n[OK] Final checkpoint saved to Drive")
        else:
            print("\n[WARNING] No checkpoints found")
    else:
        print(f"\n" + "="*60)
        print(f"❌ TRAINING FAILED (exit code: {process.returncode})")
        print("="*60)
        print("\nCheck for errors above. Common issues:")
        print("  1. Servers not running (run Cell 8.5 first)")
        print("  2. Out of memory (reduce batch_size in Cell 4)")
        print("  3. Import errors (restart runtime after Cell 2)")
        
except KeyboardInterrupt:
    print("\n\nTraining interrupted by user.")
    print(f"Partial checkpoints saved to: {CONFIG['temp_checkpoint']}")
except Exception as e:
    print(f"\n\nTraining failed with exception: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Cell 9.5: Shutdown Servers and Cleanup

import subprocess
import os
import signal

print("="*60)
print("SHUTTING DOWN SERVERS")
print("="*60)

# Close log files first
if 'retrieval_log' in CONFIG and CONFIG['retrieval_log']:
    try:
        CONFIG['retrieval_log'].close()
    except Exception:
        pass  # Already closed; nothing more to do

if 'solver_log' in CONFIG and CONFIG['solver_log']:
    try:
        CONFIG['solver_log'].close()
    except Exception:
        pass  # Already closed; nothing more to do

# Kill processes
servers_stopped = []

if 'retrieval_pid' in CONFIG:
    try:
        os.kill(CONFIG['retrieval_pid'], signal.SIGTERM)
        servers_stopped.append(f"Retrieval server (PID: {CONFIG['retrieval_pid']})")
    except Exception:
        pass  # Process has already exited

if 'solver_pid' in CONFIG:
    try:
        os.kill(CONFIG['solver_pid'], signal.SIGTERM)
        servers_stopped.append(f"Solver server (PID: {CONFIG['solver_pid']})")
    except Exception:
        pass  # Process has already exited

# Force kill any remaining processes on these ports
subprocess.run("fuser -k 8000/tcp 2>/dev/null || lsof -ti:8000 | xargs kill -9 2>/dev/null || true", shell=True)
subprocess.run("fuser -k 8001/tcp 2>/dev/null || lsof -ti:8001 | xargs kill -9 2>/dev/null || true", shell=True)

if servers_stopped:
    for server in servers_stopped:
        print(f"  ✓ Stopped {server}")
else:
    print("  No servers were running")

# Show temp checkpoint usage
from pathlib import Path
temp_size = 0
temp_path = Path(CONFIG['temp_checkpoint'])
if temp_path.exists():
    temp_size = sum(f.stat().st_size for f in temp_path.rglob('*') if f.is_file()) / 1e9
    print(f"\nTemp checkpoint storage: {temp_size:.2f} GB")
    print(f"  Location: {temp_path}")
    print(f"  (Will be auto-deleted when Colab session ends)")

print("\n" + "="*60)
print("[OK] Cleanup complete")
print("="*60)

In [None]:
# Cell 8: Test Retrieval & Model Loading
# This cell tests that everything works before starting the long training

print("="*60)
print("PRE-TRAINING VERIFICATION")
print("="*60)

# Test 1: Load and test retriever
print("\n1. Testing retriever...")
from biomedical.biomedical_retriever import BiomedicalRetrieverServer
import torch

retriever = BiomedicalRetrieverServer(
    corpus_path=CONFIG['corpus_file'],
    index_path=CONFIG['index_file'],
    device="cuda"
)

results = retriever.search("breast cancer chemotherapy resistance")
print(f"   [OK] Retriever working - found {len(results)} results")
if results:
    print(f"   Sample: {results[0].get('title', 'N/A')[:60]}...")

# Clean up retriever
del retriever
torch.cuda.empty_cache()

# Test 2: Load model
print("\n2. Testing model loading...")
from transformers import AutoTokenizer, AutoModelForCausalLM

print(f"   Loading {CONFIG['model_name']}...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])
model = AutoModelForCausalLM.from_pretrained(
    CONFIG['model_name'],
    torch_dtype=torch.float16,
    device_map="auto"
)

# Quick generation test
inputs = tokenizer("What is breast cancer?", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"   [OK] Model generating: {response[:50]}...")

# Clean up
del model, tokenizer
torch.cuda.empty_cache()

# Test 3: Check GPU memory
print("\n3. GPU Memory Status:")
total = torch.cuda.get_device_properties(0).total_memory / 1e9
used = torch.cuda.memory_allocated() / 1e9
free = total - used
print(f"   Total: {total:.1f} GB")
print(f"   Used: {used:.1f} GB")
print(f"   Free: {free:.1f} GB")

# Test 4: Check files
print("\n4. Required Files:")
from pathlib import Path
files_ok = True
for name, path in [('Corpus', CONFIG['corpus_file']), 
                   ('Index', CONFIG['index_file']),
                   ('Seeds', CONFIG['seeds_file'])]:
    exists = Path(path).exists()
    status = "[OK]" if exists else "[MISSING]"
    print(f"   {status} {name}: {path}")
    if not exists:
        files_ok = False

print("\n" + "="*60)
if files_ok:
    print("ALL CHECKS PASSED - Ready for training!")
    print("\nNext: Run Cell 9 to start training.")
    print("Training will take approximately 10-15 hours on A100.")
else:
    print("SOME CHECKS FAILED - Please fix issues above before training.")
print("="*60)

In [None]:
# Cell 9 (alternative): Train Biomedical Proposer (Simplified for Colab)
# Uses a plain PyTorch training loop on HuggingFace Transformers instead of veRL, for single-GPU compatibility

import os
import sys
import json
import torch
from pathlib import Path
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset

print("="*60)
print("BIOMEDICAL PROPOSER TRAINING")
print("="*60)

# Verify GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# === STEP 1: Create Training Data ===
print("\n[1/4] Preparing training data...")

# Load seeds and create biomedical QA training prompts
seeds_file = Path(CONFIG['seeds_file'])
corpus_file = Path(CONFIG['corpus_file'])

training_examples = []
with open(seeds_file) as f:
    seeds = [json.loads(line) for line in f]

print(f"  Loaded {len(seeds)} seed papers")

# Create multi-hop biomedical question prompts
# The proposer learns to generate research questions from paper contexts
system_prompt = """You are a biomedical research assistant. Given a scientific paper abstract, generate a challenging multi-hop research question that would require searching PubMed to answer. The question should connect multiple biomedical concepts and require reasoning across papers.

Format your response as a single research question."""

for i, seed in enumerate(seeds):
    abstract = seed.get('abstract', seed.get('text', ''))[:1000]  # Truncate long abstracts
    title = seed.get('title', '')
    
    # Create training example: paper context -> research question
    prompt = f"Based on this paper about '{title}':\n\n{abstract}\n\nGenerate a research question:"
    
    # For initial training, we use simple self-supervision
    # The model learns the format of good biomedical questions
    training_examples.append({
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "pmid": seed.get('pmid', str(i))
    })

print(f"  Created {len(training_examples)} training examples")

# === STEP 2: Load Model and Tokenizer ===
print("\n[2/4] Loading model...")

model_name = CONFIG['model_name']
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    use_cache=False  # Required for gradient checkpointing
)
model.gradient_checkpointing_enable()

print(f"  Model loaded: {model_name}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

# === STEP 3: Prepare Dataset ===
print("\n[3/4] Tokenizing dataset...")

def tokenize_function(examples):
    # Apply chat template
    texts = []
    for msgs in examples["messages"]:
        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        texts.append(text)
    
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=1024,
        padding="max_length",
        return_tensors="pt"
    )
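    # Use the input IDs as labels, masking padding so it does not contribute to the loss
    labels = tokenized["input_ids"].clone()
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels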
    return tokenized

dataset = Dataset.from_list(training_examples)
dataset = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,
    remove_columns=["messages", "pmid"],
    desc="Tokenizing"
)
dataset.set_format("torch")

print(f"  Dataset size: {len(dataset)}")

# === STEP 4: Training ===
print("\n[4/4] Starting training...")
print(f"  Steps: {CONFIG['max_steps']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['learning_rate']}")

# Simple training loop for Colab (avoids TRL/veRL complexity)
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# Dataloader
train_dataloader = DataLoader(
    dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    drop_last=True
)

# Optimizer
optimizer = AdamW(model.parameters(), lr=CONFIG['learning_rate'])

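# Scheduler
# One optimizer step happens every `gradient_accumulation` batches, for at most 10 epochs (see the loop below)
steps_per_epoch = len(train_dataloader) // CONFIG['gradient_accumulation']
num_training_steps = max(1, min(CONFIG['max_steps'], 10 * steps_per_epoch))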
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.03 * num_training_steps),
    num_training_steps=num_training_steps
)

# Training loop
model.train()
global_step = 0
total_loss = 0
checkpoint_dir = Path(CONFIG['temp_checkpoint'])

progress_bar = tqdm(total=CONFIG['max_steps'], desc="Training")

for epoch in range(10):  # Max epochs (will break early based on steps)
    for batch_idx, batch in enumerate(train_dataloader):
        # Move to GPU
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        
        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss / CONFIG['gradient_accumulation']
        
        # Backward pass
        loss.backward()
        total_loss += loss.item()
        
        # Update weights every gradient_accumulation steps
        if (batch_idx + 1) % CONFIG['gradient_accumulation'] == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            global_step += 1
            avg_loss = total_loss  # Each micro-batch loss was already divided by gradient_accumulation
            total_loss = 0
            
            progress_bar.update(1)
            progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
            
            # Save checkpoint
            if global_step % CONFIG['save_freq'] == 0:
                ckpt_path = checkpoint_dir / f"step_{global_step}"
                ckpt_path.mkdir(parents=True, exist_ok=True)
                model.save_pretrained(ckpt_path)
                tokenizer.save_pretrained(ckpt_path)
                print(f"\n  Saved checkpoint: {ckpt_path}")
            
            if global_step >= CONFIG['max_steps']:
                break
    
    if global_step >= CONFIG['max_steps']:
        break

progress_bar.close()

# === STEP 5: Save Final Model ===
print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)

# Save to Google Drive
final_path = Path(CONFIG['final_checkpoint']) / "proposer_final"
final_path.mkdir(parents=True, exist_ok=True)
model.save_pretrained(final_path)
tokenizer.save_pretrained(final_path)

print(f"\nFinal model saved to: {final_path}")
print(f"Total steps: {global_step}")

# Free memory
del model, optimizer
torch.cuda.empty_cache()
print("\n[OK] Training complete! Model saved to Google Drive.")

In [None]:
# Cell 10: Check Storage & Results

from pathlib import Path

def get_dir_size(path):
    total = 0
    path = Path(path)
    if path.exists():
        for f in path.rglob('*'):
            if f.is_file():
                total += f.stat().st_size
    return total / (1024**3)  # GB

print("="*60)
print("STORAGE SUMMARY")
print("="*60)

drive_base = Path('/content/drive/MyDrive/drzero_lite')
local_base = Path('/content/drzero_temp')

print("\nGoogle Drive (persistent):")
corpus_size = get_dir_size(drive_base / 'corpus')
ckpt_size = get_dir_size(drive_base / 'final_checkpoint')
print(f"  Corpus: {corpus_size:.2f} GB")
print(f"  Final checkpoint: {ckpt_size:.2f} GB")
print(f"  TOTAL: {corpus_size + ckpt_size:.2f} GB / 15 GB")

print("\nLocal temp (will be deleted):")
temp_size = get_dir_size(local_base)
print(f"  Temp files: {temp_size:.2f} GB")

print("\n" + "="*60)
total_drive = corpus_size + ckpt_size
if total_drive < 15:
    print(f"[OK] Within 15 GB Drive limit ({total_drive:.1f} GB used)")
else:
    print(f"[WARNING] Exceeds 15 GB limit ({total_drive:.1f} GB used)")
print("="*60)

# List checkpoints
print("\nSaved checkpoints:")
ckpt_dir = drive_base / 'final_checkpoint'
if ckpt_dir.exists():
    for item in sorted(ckpt_dir.iterdir()):
        print(f"  {item.name}")
else:
    print("  None yet")

---
# Training Complete!

## What You Have

- PubMed corpus (10K papers)
- PubMedBERT search index
- Trained proposer model
- All saved to Google Drive (<15 GB)

## Download Your Model

```python
# Zip and download
!zip -r model.zip /content/drive/MyDrive/drzero_lite/final_checkpoint
from google.colab import files
files.download('model.zip')
```

## Test Your Model

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

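# Load your trained model (the simplified Cell 9 saves it as "proposer_final";
# the veRL cell copies a "global_step_*" directory to the same folder instead)
model_path = "/content/drive/MyDrive/drzero_lite/final_checkpoint/proposer_final"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")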

# Generate
prompt = "Generate a research question about breast cancer and TP53:"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=200)
print(tokenizer.decode(outputs[0]))
```
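
The simplified training cell formats each example with the tokenizer's chat template, so inference may behave more consistently when the prompt is built the same way. The following is a minimal sketch under that assumption: `build_messages` and `generate_question` are hypothetical helpers, and the prompt wording is copied from the simplified Cell 9.

```python
SYSTEM_PROMPT = (
    "You are a biomedical research assistant. Given a scientific paper abstract, "
    "generate a challenging multi-hop research question that would require "
    "searching PubMed to answer. The question should connect multiple biomedical "
    "concepts and require reasoning across papers.\n\n"
    "Format your response as a single research question."
)


def build_messages(title: str, abstract: str, max_chars: int = 1000) -> list:
    """Build chat messages in the same shape as the training examples."""
    prompt = (
        f"Based on this paper about '{title}':\n\n"
        f"{abstract[:max_chars]}\n\nGenerate a research question:"
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]


def generate_question(model_path: str, title: str, abstract: str,
                      max_new_tokens: int = 128) -> str:
    """Load the saved model and generate one question (needs a GPU runtime)."""
    # Heavy imports stay local so build_messages works without torch installed
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="auto"
    )
    text = tokenizer.apply_chat_template(
        build_messages(title, abstract), tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, not the prompt
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

A call would look like `generate_question("/content/drive/MyDrive/drzero_lite/final_checkpoint/proposer_final", title, abstract)`, with `title` and `abstract` taken from any paper in the corpus.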

## Next Steps

1. Evaluate on PubMedQA benchmark
2. Generate biomedical QA pairs
3. Fine-tune solver (optional - adds ~5 hours)
4. Iterate for better performance
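
For the first step, PubMedQA labels each question as yes, no, or maybe, so a simple accuracy score is a reasonable starting metric. The sketch below assumes predictions arrive as free-form text; `normalize_answer` is a hypothetical rule (the first label word found wins), not an official scoring script.

```python
LABELS = ("yes", "no", "maybe")  # PubMedQA answer classes


def normalize_answer(text: str) -> str:
    """Map free-form model output to a label, or 'unknown' if none is found."""
    cleaned = text.lower().replace(".", " ").replace(",", " ")
    for token in cleaned.split():
        if token in LABELS:
            return token
    return "unknown"


def accuracy(predictions: list, references: list) -> float:
    """Fraction of predictions whose normalized label matches the reference."""
    pairs = list(zip(predictions, references))
    if not pairs:
        return 0.0
    correct = sum(normalize_answer(pred) == ref for pred, ref in pairs)
    return correct / len(pairs)


score = accuracy(["Yes, it does.", "No.", "unclear"], ["yes", "no", "maybe"])
print(f"Accuracy: {score:.2f}")
```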