In [2]:
"""
Multi-model microbiome-disease extraction benchmark.
Direct API calls — no LAVA proxy.

Requires these env vars in .env:
    ANTHROPIC_API_KEY
    OPENAI_API_KEY
    GEMINI_API_KEY
    MOONSHOT_API_KEY
"""

import json
import os
import random
import re
import time

import requests
from dotenv import load_dotenv

load_dotenv()

# --- API Keys ---
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY")

# --- Config ---
EXCLUDE_INDICES = {554, 638, 1025, 1946, 508, 979, 1597, 1776}
SAMPLE_SIZE = 4
SEED = 42
MAX_CHARS = 48_000
TIMEOUT = 180
RETRY_MAX = 2
RETRY_BACKOFF = 10
CALL_DELAY = 3

INPUT_FILE = "MAIN_DATA.json"
OUTPUT_FILE = "extraction_results.json"

# Quick check
missing = []
if not ANTHROPIC_API_KEY: missing.append("ANTHROPIC_API_KEY")
if not OPENAI_API_KEY: missing.append("OPENAI_API_KEY")
if not GEMINI_API_KEY: missing.append("GEMINI_API_KEY")
if not MOONSHOT_API_KEY: missing.append("MOONSHOT_API_KEY")
if missing:
    print(f"WARNING: Missing API keys: {', '.join(missing)}")
    print("Models with missing keys will be skipped.")
else:
    print("All API keys loaded.")

All API keys loaded.


In [3]:
# ---------------------------------------------------------------------------
# API call functions — one per provider
# ---------------------------------------------------------------------------

def call_anthropic(system_prompt, user_prompt):
    """Claude Opus 4.6 via Anthropic Messages API."""
    resp = requests.post(
        "https://api.anthropic.com/v1/messages",
        headers={
            "x-api-key": ANTHROPIC_API_KEY,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        },
        json={
            "model": "claude-opus-4-6",
            "max_tokens": 8192,
            "temperature": 0.2,
            "system": system_prompt,
            "messages": [{"role": "user", "content": user_prompt}],
        },
        timeout=TIMEOUT,
    )
    resp.raise_for_status()
    return resp.json()["content"][0]["text"]


def call_openai(system_prompt, user_prompt):
    """GPT-5.2 via OpenAI Chat Completions API."""
    resp = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {OPENAI_API_KEY}",
            "Content-Type": "application/json",
        },
        json={
            "model": "gpt-5.2",
            "max_tokens": 8192,
            "temperature": 0.2,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        },
        timeout=TIMEOUT,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]


def call_gemini(system_prompt, user_prompt):
    """Gemini 3 Pro via Google generateContent REST API."""
    resp = requests.post(
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-preview:generateContent",
        headers={
            "x-goog-api-key": GEMINI_API_KEY,
            "Content-Type": "application/json",
        },
        json={
            "system_instruction": {"parts": [{"text": system_prompt}]},
            "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
            "generationConfig": {
                "maxOutputTokens": 8192,
                "temperature": 0.2,
            },
        },
        timeout=TIMEOUT,
    )
    resp.raise_for_status()
    return resp.json()["candidates"][0]["content"]["parts"][0]["text"]


def call_moonshot(system_prompt, user_prompt):
    """Kimi K2.5 via Moonshot OpenAI-compatible API."""
    resp = requests.post(
        "https://api.moonshot.ai/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {MOONSHOT_API_KEY}",
            "Content-Type": "application/json",
        },
        json={
            "model": "kimi-k2.5",
            "max_tokens": 8192,
            "temperature": 0.6,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        },
        timeout=TIMEOUT,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]


# Model registry — maps name -> (call_fn, api_key)
MODELS = {
    "Claude Opus 4.6": (call_anthropic, ANTHROPIC_API_KEY),
    "GPT-5.2":         (call_openai,    OPENAI_API_KEY),
    "Gemini 3 Pro":    (call_gemini,    GEMINI_API_KEY),
    "Kimi K2.5":       (call_moonshot,  MOONSHOT_API_KEY),
}

print(f"Models registered: {list(MODELS.keys())}")

Models registered: ['Claude Opus 4.6', 'GPT-5.2', 'Gemini 3 Pro', 'Kimi K2.5']


In [4]:
# ---------------------------------------------------------------------------
# Prompts
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = (
    "You are an expert medical data extraction specialist with deep knowledge of "
    "microbiome research methodology and statistics. You must output ONLY valid JSON "
    "— no markdown, no explanation, no preamble."
)

USER_PROMPT_TEMPLATE = """You are an expert medical data extraction specialist with deep knowledge of microbiome research methodology and statistics.

Your task: Extract microbiome-disease relationships from research papers with complete accuracy.

## STEP 1: IDENTIFY STUDY TYPE
First, determine what type of study this is:
- Disease characterization (comparing diseased vs healthy)
- Treatment/intervention study (testing drugs, supplements, transplants)
- Observational/longitudinal study
- Other

## STEP 2: EXTRACT DISEASE INFORMATION
- Primary disease or condition being studied
- Related conditions mentioned
- Control groups or comparison conditions

## STEP 3: EXTRACT ALL BACTERIAL CHANGES
For EACH bacterium mentioned with quantitative or qualitative changes:

Extract at ALL taxonomic levels present:
- Phylum level (if mentioned)
- Family level (if mentioned)
- Genus level (if mentioned)
- Species level (if mentioned)

For each bacterium, record:
- Name (exact as written)
- Taxonomic level (phylum/family/genus/species)
- Direction: "increased", "decreased", "unchanged", or "unclear"
- Quantitative data if available (percentages, fold-changes)
- Statistical significance (p-value, confidence level)
- Context (disease vs control, pre vs post treatment, etc.)

## STEP 4: DISTINGUISH CAUSALITY
CRITICAL: Determine if bacterial changes are:
- Associated with DISEASE state (disease vs healthy)
- Result of TREATMENT/INTERVENTION (pre vs post treatment)
- Correlational only
- Unknown/unclear

## STEP 5: VERIFY COMPLETENESS
- Did you extract EVERY bacterium mentioned with changes?
- Did you check all tables, figures, and text?
- Did you note if the paper says "X bacteria and Y others" (indicating incomplete listing)?
- Did you check for contradictions between sections?

## STEP 6: VALIDATE LOGIC
- Do the directions make biological sense?
- Are there any contradictory statements in the paper?
- Is the statistical significance adequate (adjust for multiple comparisons)?
- Were any bacteria mentioned in discussion but not measured in results?

## OUTPUT FORMAT:
Return a JSON object with this structure:

{{
  "study_type": "disease_characterization | treatment_intervention | observational | other",
  "study_design": "brief description",
  "primary_disease": "disease name or null",
  "related_conditions": ["condition1", "condition2"],
  "sample_size": "number or not specified",
  "statistical_methods": "brief description of analysis methods",

  "bacteria_relationships": [
    {{
      "taxon_name": "exact name from paper",
      "taxonomic_level": "phylum | family | genus | species",
      "direction": "increased | decreased | unchanged | unclear",
      "change_context": "disease_vs_control | treatment_effect | temporal | other",
      "quantitative_data": {{
        "disease_group": "percentage or value",
        "control_group": "percentage or value",
        "fold_change": "X-fold or null",
        "p_value": "value or null",
        "statistical_significance": "significant | not_significant | not_reported"
      }},
      "location_in_paper": "Table X | Figure Y | Results section | Discussion",
      "confidence": "high | medium | low",
      "notes": "any important context or caveats"
    }}
  ],

  "extraction_metadata": {{
    "total_bacteria_found": 0,
    "completeness_assessment": "complete | partial | unclear",
    "potential_missing_data": "description if incomplete",
    "contradictions_found": ["list any contradictions"],
    "limitations": ["any extraction limitations"]
  }}
}}

## IMPORTANT RULES:
1. Extract EVERY bacterium mentioned, even if changes are small or not significant
2. If a paper says "10 species changed" but only lists 7, note the missing 3
3. NEVER fabricate data - if direction is unclear, mark as "unclear"
4. Distinguish between disease effects and treatment effects
5. Note if findings didn't reach statistical significance after multiple comparison adjustment
6. Include bacteria mentioned in discussion even if not in main results (note as "discussion_only")
7. If genus-level and species-level data both exist, include both
8. Check for bacteria that DECREASED in one section but are described differently elsewhere

Now, extract from this paper:

Paper text: {text}

Output ONLY the JSON object, nothing else."""

In [5]:
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def parse_extraction(raw_text):
    """Parse extraction JSON from model response."""
    text = raw_text.strip()
    # Strip markdown fences
    match = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
    if match:
        text = match.group(1).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Find first { ... } block by brace matching
    brace_start = text.find("{")
    if brace_start == -1:
        return None
    depth = 0
    for i in range(brace_start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                try:
                    return json.loads(text[brace_start : i + 1])
                except json.JSONDecodeError:
                    return None
    return None


def is_junk_chunk(chunk):
    """Return True if a chunk should be skipped."""
    text = chunk.strip()
    if len(text) < 30:
        return True
    if len(re.findall(r"\[DOI\]|\[PubMed\]|\[PMC", text)) >= 3:
        return True
    if re.match(r"10\.\d{4,}/", text):
        return True
    if len(re.findall(r"[\w.+-]+@[\w.-]+\.\w+", text)) >= 3:
        return True
    return False


def merge_chunks(chunks):
    """Filter junk, merge with smart joining, truncate."""
    clean = [c for c in chunks if not is_junk_chunk(c)]
    if not clean:
        return ""
    parts = [clean[0]]
    for prev, cur in zip(clean, clean[1:]):
        prev_stripped = prev.rstrip()
        if prev_stripped and prev_stripped[-1] in ".!?":
            parts.append(" " + cur)
        else:
            parts.append(". " + cur)
    merged = "".join(parts)
    return merged[:MAX_CHARS] if len(merged) > MAX_CHARS else merged


def call_with_retry(call_fn, system_prompt, user_prompt, model_name):
    """Wrap a call function with retry logic."""
    last_err = None
    for attempt in range(1, RETRY_MAX + 2):
        try:
            return call_fn(system_prompt, user_prompt)
        except Exception as e:
            last_err = str(e)
            if attempt <= RETRY_MAX:
                print(f"    Attempt {attempt} failed: {last_err}. Retrying in {RETRY_BACKOFF}s...")
                time.sleep(RETRY_BACKOFF)
            else:
                raise RuntimeError(f"[{model_name}] All {RETRY_MAX + 1} attempts failed. Last: {last_err}")

print("Helpers loaded.")

Helpers loaded.


In [6]:
# ---------------------------------------------------------------------------
# Load data & sample papers
# ---------------------------------------------------------------------------

print("Loading data...")
with open(INPUT_FILE) as f:
    data = json.load(f)

keys = list(data.keys())
valid_keys = [k for i, k in enumerate(keys) if i not in EXCLUDE_INDICES]
print(f"Total papers: {len(keys)}, valid: {len(valid_keys)}, excluded: {len(keys) - len(valid_keys)}")

random.seed(SEED)
sampled_keys = random.sample(valid_keys, SAMPLE_SIZE)

print(f"\nSampled papers:")
for k in sampled_keys:
    print(f"  [{k}] {data[k]['name'][:80]}")

# Prepare merged text
prepared = {}
for k in sampled_keys:
    paper = data[k]
    merged = merge_chunks(paper["chunks"])
    prepared[k] = {"title": paper["name"], "text": merged, "orig_chunks": len(paper["chunks"])}
    print(f"  Paper {k}: {len(paper['chunks'])} chunks -> {len(merged)} chars")

print(f"\n{len(prepared)} papers ready for extraction.")

Loading data...
Total papers: 2026, valid: 2018, excluded: 8

Sampled papers:
  [1320] Oral Microbiota Linking Associations of Dietary Factors with Recurrent Oral Ulce
  [230] Particle Size, Mass Concentration, and Microbiota in Dental Aerosols
  [52] Evaluation of co-circulating pathogens and microbiome from COVID-19 infections
  [1529] Rifaximin ameliorates influenza A virus infection-induced lung barrier damage by
  Paper 1320: 122 chunks -> 38701 chars
  Paper 230: 93 chunks -> 32628 chars
  Paper 52: 133 chunks -> 48000 chars
  Paper 1529: 139 chunks -> 48000 chars

4 papers ready for extraction.


In [7]:
# ---------------------------------------------------------------------------
# Run extraction across all models
# ---------------------------------------------------------------------------

results = {}
timings = {}

for model_name, (call_fn, api_key) in MODELS.items():
    if not api_key:
        print(f"\n{'='*60}")
        print(f"SKIPPING {model_name} — no API key")
        print(f"{'='*60}")
        continue

    results[model_name] = {}
    timings[model_name] = {}
    print(f"\n{'='*60}")
    print(f"Model: {model_name}")
    print(f"{'='*60}")

    for k in sampled_keys:
        info = prepared[k]
        print(f"  [{model_name}] Paper {k}: {info['title'][:50]}...")

        user_prompt = USER_PROMPT_TEMPLATE.format(text=info["text"])

        try:
            t0 = time.time()
            raw = call_with_retry(call_fn, SYSTEM_PROMPT, user_prompt, model_name)
            elapsed = time.time() - t0
            timings[model_name][k] = round(elapsed, 1)

            parsed = parse_extraction(raw)
            if parsed is None:
                print(f"    WARNING: JSON parse failed. Storing raw snippet.")
                results[model_name][k] = {"error": "JSON parse failed", "raw": raw[:500]}
            else:
                n_rels = len(parsed.get("bacteria_relationships", []))
                print(f"    OK: {n_rels} relationships extracted ({elapsed:.1f}s)")
                results[model_name][k] = parsed
        except Exception as e:
            print(f"    FAILED: {e}")
            results[model_name][k] = {"error": str(e)}
            timings[model_name][k] = None

        time.sleep(CALL_DELAY)

print("\nExtraction complete.")


Model: Claude Opus 4.6
  [Claude Opus 4.6] Paper 1320: Oral Microbiota Linking Associations of Dietary Fa...
    OK: 19 relationships extracted (52.5s)
  [Claude Opus 4.6] Paper 230: Particle Size, Mass Concentration, and Microbiota ...
    OK: 0 relationships extracted (16.3s)
  [Claude Opus 4.6] Paper 52: Evaluation of co-circulating pathogens and microbi...
  [Claude Opus 4.6] Paper 1529: Rifaximin ameliorates influenza A virus infection-...
    OK: 18 relationships extracted (53.2s)

Model: GPT-5.2
  [GPT-5.2] Paper 1320: Oral Microbiota Linking Associations of Dietary Fa...
    Attempt 1 failed: Invalid header value b'Bearer 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\n'. Retrying in 10s...
    Attempt 2 failed: Invalid header value b'Bearer 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\n'. Retrying in 10s.

In [8]:
# ---------------------------------------------------------------------------
# Save results
# ---------------------------------------------------------------------------

output = {
    "papers_sampled": sampled_keys,
    "timings": timings,
    "results": results,
}
with open(OUTPUT_FILE, "w") as f:
    json.dump(output, f, indent=2)
print(f"Saved to {OUTPUT_FILE}")

Saved to extraction_results.json


In [9]:
# ---------------------------------------------------------------------------
# Benchmark summary
# ---------------------------------------------------------------------------

print(f"\n{'Paper':>8} | {'Model':<20} | {'Rels':>5} | {'Time':>7} | Status")
print(f"{'-'*8}-+-{'-'*20}-+-{'-'*5}-+-{'-'*7}-+-{'-'*10}")

for k in sampled_keys:
    for model_name in MODELS:
        if model_name not in results:
            continue
        entry = results[model_name].get(k, {})
        t = timings.get(model_name, {}).get(k)
        t_str = f"{t}s" if t else "—"

        if "error" in entry:
            print(f"{k:>8} | {model_name:<20} | {'ERR':>5} | {t_str:>7} | {entry['error'][:40]}")
        else:
            n = len(entry.get("bacteria_relationships", []))
            print(f"{k:>8} | {model_name:<20} | {n:>5} | {t_str:>7} | OK")

# Per-model aggregates
print(f"\n{'Model':<20} | {'Avg Rels':>8} | {'Avg Time':>8} | {'Errors':>6}")
print(f"{'-'*20}-+-{'-'*8}-+-{'-'*8}-+-{'-'*6}")

for model_name in MODELS:
    if model_name not in results:
        continue
    rels = []
    times = []
    errs = 0
    for k in sampled_keys:
        entry = results[model_name].get(k, {})
        if "error" in entry:
            errs += 1
        else:
            rels.append(len(entry.get("bacteria_relationships", [])))
        t = timings.get(model_name, {}).get(k)
        if t:
            times.append(t)

    avg_r = f"{sum(rels)/len(rels):.1f}" if rels else "—"
    avg_t = f"{sum(times)/len(times):.1f}s" if times else "—"
    print(f"{model_name:<20} | {avg_r:>8} | {avg_t:>8} | {errs:>6}")


   Paper | Model                |  Rels |    Time | Status
---------+----------------------+-------+---------+-----------
    1320 | Claude Opus 4.6      |    19 |   52.5s | OK
    1320 | GPT-5.2              |   ERR |       — | [GPT-5.2] All 3 attempts failed. Last: I
    1320 | Gemini 3 Pro         |    17 |   60.0s | OK
    1320 | Kimi K2.5            |   ERR |       — | [Kimi K2.5] All 3 attempts failed. Last:
     230 | Claude Opus 4.6      |     0 |   16.3s | OK
     230 | GPT-5.2              |   ERR |       — | [GPT-5.2] All 3 attempts failed. Last: I
     230 | Gemini 3 Pro         |     0 |   28.1s | OK
     230 | Kimi K2.5            |   ERR |       — | [Kimi K2.5] All 3 attempts failed. Last:
      52 | Claude Opus 4.6      |   ERR |   77.4s | JSON parse failed
      52 | GPT-5.2              |   ERR |       — | [GPT-5.2] All 3 attempts failed. Last: I
      52 | Gemini 3 Pro         |   ERR |   66.1s | JSON parse failed
      52 | Kimi K2.5            |   ERR |       — | 