# In [6]:
"""
azure_pubmed_qa_builder_100.py
------------------------------
Generate 100 de-duplicated PubMed-QA samples (English) with Azure OpenAI + PubMed.
"""

from __future__ import annotations
import os, json, time, random, hashlib, requests, itertools
from pathlib import Path
from typing import List, Dict, Any, Optional, Set

from openai import AzureOpenAI
from tqdm.auto import tqdm

# ========== 0. Azure OpenAI ==========
# Deployment name passed as ``model`` on every chat completion request.
AZURE_DEPLOYMENT_NAME = "gpt-4.1-noah"          # ← your deployment name

# Shared client instance; key and endpoint are read from the environment
# when this module is executed.
client = AzureOpenAI(
    api_version="2024-12-01-preview",
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
)

# ========== 1. PubMed helpers ==========
# Common URL prefix for the E-utilities endpoints called below
# (``esearch.fcgi`` and ``esummary.fcgi``).
PUBMED_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

def pubmed_search(term: str, retmax: int, api_key: Optional[str]) -> List[str]:
    """Run an ESearch query against PubMed and return the matching PMIDs.

    Args:
        term: Query string in PubMed search syntax.
        retmax: Maximum number of PMIDs to request.
        api_key: Optional NCBI API key; only sent when truthy.

    Returns:
        The ``idlist`` of the ESearch JSON response, sorted by publication
        date as requested through the ``sort`` parameter.

    Raises:
        requests.HTTPError: If the server replies with an error status.
        KeyError: If the response has no ``esearchresult``/``idlist`` entry.
    """
    params = {
        "db": "pubmed",
        "term": term,
        "retmax": retmax,
        "retmode": "json",
        # ``requests`` percent-encodes parameter values, so the legacy URL
        # spelling "pub+date" would arrive at the server with a literal
        # plus sign. "pub_date" is the documented PubMed sort key.
        "sort": "pub_date",
    }
    if api_key:
        params["api_key"] = api_key
    r = requests.get(f"{PUBMED_BASE}/esearch.fcgi", params=params, timeout=15)
    r.raise_for_status()
    return r.json()["esearchresult"]["idlist"]

def fetch_metadata(pmids: List[str], api_key: Optional[str]) -> Dict[str, Dict[str, Any]]:
    """Fetch ESummary records for the given PMIDs.

    Args:
        pmids: PMIDs to look up; an empty value short-circuits to ``{}``
            without any network request.
        api_key: Optional NCBI API key; only sent when truthy.

    Returns:
        A mapping from each requested PMID that appears in the response's
        ``result`` object to its summary record.

    Raises:
        requests.HTTPError: If the server replies with an error status.
    """
    if not pmids:
        return {}

    query: Dict[str, Any] = dict(db="pubmed", id=",".join(pmids), retmode="json")
    if api_key:
        query.update(api_key=api_key)

    response = requests.get(f"{PUBMED_BASE}/esummary.fcgi", params=query, timeout=30)
    response.raise_for_status()
    summaries = response.json()["result"]

    # Keep only the PMIDs the server actually returned a record for.
    found: Dict[str, Dict[str, Any]] = {}
    for pmid in pmids:
        if pmid in summaries:
            found[pmid] = summaries[pmid]
    return found

# ========== 2. GPT paraphrase ==========
# System message sent ahead of the user prompt in ``paraphrase_query``.
SYSTEM_MSG = ("You are a helpful assistant that creates natural English questions "
              "for a biomedical literature QA dataset.")

def paraphrase_query(meta: Dict[str, Any]) -> str:
    """Ask the chat model to phrase an article lookup as a natural question.

    Args:
        meta: Article summary record; the keys ``title``, ``fulljournalname``,
            ``pubdate`` and ``sortfirstauthor`` are read, each defaulting to
            an empty string when absent.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    article_title = meta.get("title", "")
    journal_name = meta.get("fulljournalname", "")
    first_author = meta.get("sortfirstauthor", "")
    # Only the text before the first space of ``pubdate`` is used as the year.
    pub_year = meta.get("pubdate", "").partition(" ")[0]

    instruction = (
        "Rewrite the following search intent into one concise, natural English "
        "question (do not mention PubMed, E-utilities, or search syntax):"
    )
    intent = (
        f"Intent: Find the unique PMID of the article whose title contains "
        f"“{article_title}”, first author {first_author}, "
        f"published in {journal_name} in {pub_year}."
    )
    conversation = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": f"{instruction}\n\n{intent}"},
    ]

    completion = client.chat.completions.create(
        model=AZURE_DEPLOYMENT_NAME,
        messages=conversation,
        temperature=0.6,
        max_tokens=128,
    )
    reply = completion.choices[0].message.content
    return reply.strip()

# ========== 3. Query builder ==========
def build_unique_query(meta: Dict[str, Any]) -> str:
    """Build a PubMed query meant to single out one article.

    The query combines up to three clauses joined by ``AND``: the title text
    before the first colon as a quoted ``[ti]`` phrase, the first author as
    ``[au]`` and the publication year as ``[dp]``. A clause is omitted when
    its source field is missing or blank, so the query never contains an
    empty ``""[ti]``, ``[au]`` or ``[dp]`` term.

    Args:
        meta: Article summary record; the keys ``title``, ``sortfirstauthor``
            and ``pubdate`` are read.

    Returns:
        The query string, or ``""`` when none of the three fields is usable.
    """
    raw_title = str(meta.get("title") or "")
    # An embedded double quote would terminate the quoted phrase early and
    # corrupt the rest of the query, so quotes are dropped from the phrase.
    title_phrase = raw_title.split(":")[0].replace('"', "").strip()

    author = str(meta.get("sortfirstauthor") or "").strip()

    # The year is the first whitespace-separated token of ``pubdate``.
    date_tokens = str(meta.get("pubdate") or "").split()
    year = date_tokens[0] if date_tokens else ""

    clauses: List[str] = []
    if title_phrase:
        clauses.append(f"\"{title_phrase}\"[ti]")
    if author:
        clauses.append(f"{author}[au]")
    if year:
        clauses.append(f"{year}[dp]")
    return " AND ".join(clauses)

def make_tool_call(term: str) -> Dict[str, Any]:
    """Describe a single-result ``pubmed.search`` tool call for ``term``.

    Args:
        term: PubMed query string to place in the call parameters.

    Returns:
        A dict with the keys ``tool`` and ``params``; ``params`` carries the
        term and a fixed ``retmax`` of 1.
    """
    search_params: Dict[str, Any] = {"term": term, "retmax": 1}
    return dict(tool="pubmed.search", params=search_params)

# ========== 4. Seed sampling & dataset ==========
# Topic names cycled through by ``build_dataset``; ``sample_seed_pmids``
# inserts each one into a ``[MeSH Major Topic]`` search term.
MESH_TOPICS = [
    "oncology", "neurology", "cardiology", "immunology", "gastroenterology",
    "endocrinology", "pulmonology", "dermatology", "psychiatry", "genetics"
]

def sample_seed_pmids(topic: str, n: int, api_key: Optional[str]) -> List[str]:
    """Return up to ``n`` candidate PMIDs for a topic.

    Args:
        topic: Name searched as a ``[MeSH Major Topic]`` term.
        n: Passed to ``pubmed_search`` as ``retmax``.
        api_key: Optional NCBI API key forwarded to ``pubmed_search``.

    Returns:
        The PMIDs returned by ``pubmed_search`` for the topic, restricted to
        the ``2023:2025[pdat]`` date range.
    """
    date_window = "2023:2025[pdat]"
    query = f"{topic}[MeSH Major Topic] AND {date_window}"
    return pubmed_search(query, retmax=n, api_key=api_key)

def build_dataset(target_n: int = 100,
                  ncbi_key: Optional[str] = None) -> List[Dict[str, Any]]:
    """Build up to ``target_n`` QA samples, each tied to a distinct PMID.

    The topics in ``MESH_TOPICS`` are visited round-robin. On every further
    pass over the topic list the number of PMIDs requested per topic grows
    by one batch, so later passes can reach articles beyond those already
    used. If a complete pass adds no sample, the topics are considered
    exhausted and the loop stops instead of spinning forever.

    Args:
        target_n: Number of samples to aim for.
        ncbi_key: Optional NCBI API key forwarded to the PubMed helpers.

    Returns:
        A list of sample dicts with the keys ``id``, ``question``,
        ``tool_calls`` and ``answer``. The list is shorter than ``target_n``
        only when no topic yields unseen PMIDs any more.
    """
    seed_batch = 20     # PMIDs requested per topic on the first pass
    pause = 0.34        # seconds between requests (NCBI limit: <= 3 per second)

    dataset: List[Dict[str, Any]] = []
    used_pmids: Set[str] = set()
    pass_no = 0

    with tqdm(total=target_n, desc="Generating QA samples") as pbar:
        while len(dataset) < target_n:
            pass_no += 1
            added_this_pass = 0

            # Round-robin over MESH_TOPICS to increase diversity.
            for topic in MESH_TOPICS:
                # Widen the result window on every pass; asking for the same
                # top results again could never produce an unseen PMID.
                seed_pmids = sample_seed_pmids(
                    topic, n=seed_batch * pass_no, api_key=ncbi_key)
                time.sleep(pause)

                # Remove already used PMIDs.
                fresh_pmids = [pid for pid in seed_pmids if pid not in used_pmids]
                if not fresh_pmids:
                    continue

                metas = fetch_metadata(fresh_pmids, api_key=ncbi_key)
                time.sleep(pause)

                for pid in fresh_pmids:
                    # The first check guards against a PMID listed twice in
                    # the same search result.
                    if pid in used_pmids or pid not in metas:
                        continue

                    meta = metas[pid]
                    term = build_unique_query(meta)
                    question = paraphrase_query(meta)

                    dataset.append({
                        "id": hashlib.md5(term.encode()).hexdigest()[:16],
                        "question": question,
                        "tool_calls": [make_tool_call(term)],
                        "answer": [pid]
                    })
                    used_pmids.add(pid)
                    added_this_pass += 1
                    pbar.update(1)

                    # Same per-sample pause as before, spacing out the
                    # question-generation requests.
                    time.sleep(pause)

                    if len(dataset) >= target_n:
                        break

                if len(dataset) >= target_n:
                    break

            if added_this_pass == 0:
                # A whole pass over every topic produced nothing new.
                tqdm.write(
                    f"No unseen PMIDs left; stopping at {len(dataset)} "
                    f"of {target_n} samples."
                )
                break
    return dataset

# ========== 5. Save ==========
def save_json(data: List[Dict[str, Any]], path: str) -> None:
    """Write the samples to ``path`` as JSON under a top-level ``dataset`` key.

    Args:
        data: Samples to serialise.
        path: Destination file; an existing file is overwritten.
    """
    payload = json.dumps({"dataset": data}, ensure_ascii=False, indent=2)
    # ``ensure_ascii=False`` leaves non-ASCII characters in the payload, so
    # the file encoding is fixed to UTF-8 rather than left to the platform
    # default, which cannot encode them on some systems.
    Path(path).write_text(payload, encoding="utf-8")
    print(f"\n✅ Saved {len(data)} unique samples → {path}")

# ========== 6. Run ==========
if __name__ == "__main__":
    # The NCBI key is optional; the lookup yields None when it is unset.
    ncbi_key = os.environ.get("NCBI_API_KEY")
    qa_data = build_dataset(ncbi_key=ncbi_key, target_n=100)
    save_json(data=qa_data, path="pubmed_qa_dataset_100.json")


# Output of the run above:
# Generating QA samples:   0%|          | 0/100 [00:00<?, ?it/s]
#
#
# ✅ Saved 100 unique samples → pubmed_qa_dataset_100.json
