#### Data Pre-processing

Load raw trial data

In [1]:
import pandas as pd 

# Raw trial export (one row per trial, no trial id column)
RAW_INPUT_CSV = "data/raw_trials.csv"
data = pd.read_csv(RAW_INPUT_CSV)

In [2]:
# Show which fields the raw export provides
column_index = data.columns
print(column_index)

Index(['title', 'objective', 'outcome_details', 'phase',
       'primary_completion_date', 'primary_endpoints_reported_date',
       'prior_concurrent_therapy', 'start_date', 'study_design',
       'treatment_plan', 'record_type', 'patients_per_site_per_month',
       'primary_endpoint_json', 'other_endpoint_json', 'associated_cro_json',
       'notes_json', 'outcomes_json', 'patient_dispositions_json',
       'results_json', 'study_keywords_json', 'tags_json',
       'primary_drugs_tested_json', 'other_drugs_tested_json',
       'therapeutic_areas_json', 'bmt_other_drugs_tested_json',
       'bmt_primary_drugs_tested_json', 'ct_gov_listed_locations_json',
       'ct_gov_mesh_terms_json'],
      dtype='object')


In [3]:
# Per-column count of missing values, then overall table size
missing_counts = data.isna().sum()
print(missing_counts.to_markdown())
print("Shape:", data.shape)

|                                 |   0 |
|:--------------------------------|----:|
| title                           |   0 |
| objective                       |   3 |
| outcome_details                 | 146 |
| phase                           |   0 |
| primary_completion_date         |  61 |
| primary_endpoints_reported_date | 161 |
| prior_concurrent_therapy        | 184 |
| start_date                      |  45 |
| study_design                    |  16 |
| treatment_plan                  |   1 |
| record_type                     |   0 |
| patients_per_site_per_month     | 119 |
| primary_endpoint_json           |   0 |
| other_endpoint_json             |   0 |
| associated_cro_json             |   0 |
| notes_json                      |   0 |
| outcomes_json                   |   0 |
| patient_dispositions_json       |   0 |
| results_json                    |   0 |
| study_keywords_json             |   0 |
| tags_json                       |   0 |
| primary_drugs_tested_json       

Generate a unique, deterministic hash per trial, since the raw data has no trial ID
- e.g. "tid_0e8fa21079f928135dfc6164a15285f8"

In [4]:
import hashlib
import json
from pathlib import Path

OUTPUT_PATH = Path("cache/raw_trials_with_hash.csv")
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)

# Fields that identify a trial. Changing this tuple changes every hash and
# orphans all per-trial files cached under cache/ that are keyed by trial_hash.
TRIAL_HASH_FIELDS = ("title", "start_date", "phase")


def make_trial_hash(row):
    """
    Deterministic hash for a trial based on stable fields.

    The TRIAL_HASH_FIELDS values are serialized as JSON with sorted keys and
    hashed with MD5 (used purely as an ID fingerprint). Fields absent from the
    row default to ""; NaN values serialize as JSON ``NaN`` and therefore hash
    consistently as well.

    Returns:
        "tid_" followed by 32 lowercase hex characters.
    """
    payload = {field: row.get(field, "") for field in TRIAL_HASH_FIELDS}
    raw = json.dumps(payload, sort_keys=True, ensure_ascii=False)
    return "tid_" + hashlib.md5(raw.encode("utf-8")).hexdigest()


# ---------------------------------------------------------
# If file already exists → skip generation (but load it)
# ---------------------------------------------------------
if OUTPUT_PATH.exists():
    print(f"⚠️ {OUTPUT_PATH} already exists — skipping hash generation.")
    # Load the cached table so `data` carries trial_hash in both branches.
    data = pd.read_csv(OUTPUT_PATH)
else:
    print("Generating raw_trials_with_hash.csv ...")

    # Create trial_hash column
    data["trial_hash"] = data.apply(make_trial_hash, axis=1)

    # The hash uses only a few fields; a collision would make distinct trials
    # share (and overwrite) each other's cached LLM outputs downstream.
    dup_mask = data["trial_hash"].duplicated(keep=False)
    if dup_mask.any():
        raise ValueError(
            f"{int(dup_mask.sum())} rows share a trial_hash; add fields to "
            f"TRIAL_HASH_FIELDS to disambiguate. Example titles: "
            f"{data.loc[dup_mask, 'title'].head(5).tolist()}"
        )

    # Move trial_hash to first column
    data = data[["trial_hash"] + [c for c in data.columns if c != "trial_hash"]]

    print(data.columns)
    print(data.shape)

    # Export
    data.to_csv(OUTPUT_PATH, index=False)
    print(f"✅ Saved to {OUTPUT_PATH}")

Generating raw_trials_with_hash.csv ...
Index(['trial_hash', 'title', 'objective', 'outcome_details', 'phase',
       'primary_completion_date', 'primary_endpoints_reported_date',
       'prior_concurrent_therapy', 'start_date', 'study_design',
       'treatment_plan', 'record_type', 'patients_per_site_per_month',
       'primary_endpoint_json', 'other_endpoint_json', 'associated_cro_json',
       'notes_json', 'outcomes_json', 'patient_dispositions_json',
       'results_json', 'study_keywords_json', 'tags_json',
       'primary_drugs_tested_json', 'other_drugs_tested_json',
       'therapeutic_areas_json', 'bmt_other_drugs_tested_json',
       'bmt_primary_drugs_tested_json', 'ct_gov_listed_locations_json',
       'ct_gov_mesh_terms_json'],
      dtype='object')
(184, 29)
✅ Saved to cache/raw_trials_with_hash.csv


#### Task 1

Using an LLM (chatbot), identify all interventions in each trial. For each intervention:
- label it as the investigational product, an active comparator, or a placebo
- list all of its alternative names
- identify its molecular target
- identify its mechanism of action
- for investigational products, also record:
    - the TrialTrove drug ID
    - the BioMedTracker drug ID

In [5]:
import re
import json
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
from services.openai_wrapper import OpenAIWrapper

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
BASE_DIR = Path("cache")

# Input: raw trials carrying the trial_hash column produced by the hashing cell
TRIALS_WITH_HASH_CSV = BASE_DIR / "raw_trials_with_hash.csv"

# Outputs: per-trial role JSONs, per-trial LLM logs, and one combined master file
DRUG_ROLE_DIR = BASE_DIR / "trial_drug_roles"
DRUG_ROLE_DIR.mkdir(parents=True, exist_ok=True)
DRUG_ROLE_LOG_DIR = BASE_DIR / "trial_drug_roles_log"
DRUG_ROLE_LOG_DIR.mkdir(parents=True, exist_ok=True)
MASTER_ROLES_PATH = BASE_DIR / "trial_drug_roles_master.json"

# LLM settings
MODEL = "gpt-5"
client = OpenAIWrapper()

# Number of trials queried in parallel
MAX_WORKERS = 8

# Trial columns embedded in each prompt (alongside trial_hash).
# NOTE(review): the prompt says the input includes the study design, but
# "study_design" is not listed here — confirm whether it should be added.
RELEVANT_COLS = [
    # free-text descriptions
    "title",
    "objective",
    "outcome_details",
    "treatment_plan",
    # structured JSON blobs
    "notes_json",
    "results_json",
    "primary_drugs_tested_json",
    "other_drugs_tested_json",
    "therapeutic_areas_json",
    "bmt_other_drugs_tested_json",
    "bmt_primary_drugs_tested_json",
    "ct_gov_mesh_terms_json",
]

# -------------------------------------------------
# Helpers
# -------------------------------------------------
def extract_json_object(text: str) -> dict:
    """Extract the first JSON object from model output.

    Tries, in order:
    1. the whole (stripped) text as JSON;
    2. the single JSON value that starts at the first "{", decoded with
       ``json.JSONDecoder.raw_decode`` so any trailing prose after the object
       (even prose containing braces) is ignored.

    Returns:
        The parsed dict, or {} when ``text`` is not a non-empty string, nothing
        decodes, or the decoded value is not a JSON object. Truncated output
        yields {} rather than some nested fragment.
    """
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}

    # Fast path: the whole response is JSON
    try:
        obj = json.loads(text)
    except (ValueError, RecursionError):
        pass
    else:
        if isinstance(obj, dict):
            return obj

    # Fallback: decode exactly one value starting at the first "{".
    # A greedy r"\{.*\}" span would run to the LAST "}" and fail whenever the
    # model appends text that itself contains braces.
    start = text.find("{")
    if start == -1:
        return {}
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        return {}
    return obj if isinstance(obj, dict) else {}


def build_prompt(trial_payload: dict) -> str:
    """
    Render the drug-role extraction prompt for one trial.

    The prompt asks the LLM to:
    - Extract drug names (primarily from the *_drugs_tested_json fields)
    - Canonicalize names by removing company/manufacturer/location qualifiers
    - Deduplicate synonymous names
    - For each canonical drug, return a dict with:
        * role (Investigational Product / Placebo / Active Comparator / Standard of Care)
        * alternative_names (list)
        * molecular_target
        * mechanism
        * tt_drug_id (TrialTrove/PharmaProjects drugId as string)
        * bmt_drug_id (BioMedTracker bmtDrugId as string)

    Args:
        trial_payload: JSON-serializable dict of trial fields (process_trial
            builds it from "trial_hash" plus RELEVANT_COLS). It is embedded
            verbatim, pretty-printed, under "Input JSON:".

    Returns:
        The prompt text with leading/trailing whitespace stripped.

    NOTE(review): the prompt tells the model it receives the study design, but
    "study_design" is not in RELEVANT_COLS — confirm whether it should be sent.
    """
    payload_json = json.dumps(trial_payload, ensure_ascii=False, indent=2)

    # f-string: the literal braces in the example output below are doubled ({{ }}).
    return f"""
You are a clinical trial design and interpretation expert.

You are given structured information about a clinical trial, including:
- Title and objective
- Study design and treatment plan
- JSON fields listing drugs tested in the study:
  - primary_drugs_tested_json
  - other_drugs_tested_json
  - bmt_other_drugs_tested_json
  - bmt_primary_drugs_tested_json
- These JSON fields may also contain metadata such as
  drugApprovalStatus (Approved / Unapproved), mechanisms, etc.
- In the TrialTrove/PharmaProjects JSON blocks, the unique drug identifier
  is usually under a key like "drugId".
- In the BioMedTracker JSON blocks, the unique drug identifier
  is usually under a key like "bmtDrugId".

Your tasks:

1. Identify all DISTINCT physical drug entities explicitly used in the study.
   - Strings in the *_drugs_tested_json fields are drug-name candidates.
   - If these fields contain structured JSON, infer names from keys such as
     "name", "drug_name", "drugName", "drugPrimaryName", "preferred_name", "label", etc.

2. Canonicalize each drug name:
   Remove company names, manufacturer qualifiers, geographic qualifiers,
   dosage-form qualifiers, or parenthetical descriptors that do NOT change
   the name of the underlying drug.
   Examples of correct canonicalization:
   - "AlphaBlocker (CompanyX)" → "AlphaBlocker"
   - "Recombinant Growth Factor (rgf)" → "Recombinant Growth Factor"
   - "DrugX citrate (RegionY)" → "DrugX citrate"
   - "BrandName (compound-42, MakerCorp)" → "BrandName"

   Keep only the essential drug or brand name as the canonical key.

3. Deduplicate synonymous names referring to the SAME drug.
   - If multiple variations refer to one physical drug, keep ONE canonical key.
   - Prefer the simplest, clean name.
   - Collect all other variations in alternative_names.

4. For EACH distinct drug, build an object with SIX fields:

   - "role": one of:
       * "Investigational Product"
       * "Placebo"
       * "Active Comparator"
       * "Standard of Care"

   ROLE ASSIGNMENT GUIDANCE:

   A. "Investigational Product"
      - Use ONLY for the sponsor's proprietary or novel product.
      - Clues: unapproved, new mechanism, highlighted in title/objective.
      - Do NOT label common chemotherapy or widely used drugs this way.

   B. "Standard of Care"
      - Use for established backbone therapies, such as common chemotherapies
        or widely used drugs in the disease area.
      - Examples (fictional): DrugX, Chemo-A, Cytotoxin-7, etc.

   C. "Active Comparator"
      - Use when a non-placebo drug is explicitly the control arm.
      - Clues: terms like "versus", "comparator", "control regimen".

   D. "Placebo"
      - Use for inert or sham treatments.

   SUMMARY:
   - Proprietary or novel study drug → "Investigational Product".
   - Classical or widely used therapy → "Standard of Care".
   - Control regimen (non-placebo) → "Active Comparator".
   - Inert control → "Placebo".

   - "alternative_names": list of synonymous or variant names.
     Examples:
     * ABC-123 → ["Compound-ABC", "ABC123"]
     * BrandX → ["generic compound name"]

   - "molecular_target": e.g., "CD20", "Kinase-A", "Receptor-Z".
     If unknown, use "".

   - "mechanism": e.g., "monoclonal antibody", "kinase inhibitor",
     "fusion protein", "PD-1/LAG-3 bispecific antibody", etc.
     If not inferable, use "".

   - "tt_drug_id": the TrialTrove/PharmaProjects drugId (from fields like
     primary_drugs_tested_json / other_drugs_tested_json) for this drug, as a STRING.
     If no matching ID can be determined, set this to "".

   - "bmt_drug_id": the BioMedTracker bmtDrugId (from fields like
     bmt_primary_drugs_tested_json / bmt_other_drugs_tested_json) for this drug,
     as a STRING. If no matching ID can be determined, set this to "".

   ID ASSIGNMENT GUIDANCE:
   - Only assign non-empty "tt_drug_id" and "bmt_drug_id" when you can confidently
     match the canonical drug name to the corresponding JSON object.
   - It is especially important to assign these IDs for drugs whose "role"
     is "Investigational Product".
   - When converting numeric IDs to strings, do NOT pad or modify them:
     e.g., drugId 170544 → "170544", bmtDrugId 42756 → "42756".

Important rules:
- "role" MUST use only the allowed strings.
- No invented drugs.
- Combination therapies: classify EACH component using the rules above.
- Every drug object MUST contain ALL of the following keys:
  "role", "alternative_names", "molecular_target", "mechanism",
  "tt_drug_id", and "bmt_drug_id".
- If you cannot determine a value, use the empty string "" (for IDs,
  molecular_target, mechanism) or an empty list [] (for alternative_names).

Input JSON:
{payload_json}

Output format (IMPORTANT):
- Return ONLY a valid JSON object with:
    - keys   = canonical drug names
    - values = objects with EXACTLY:
        * "role"
        * "alternative_names"
        * "molecular_target"
        * "mechanism"
        * "tt_drug_id"
        * "bmt_drug_id"

Example output:
{{
  "ABC-123": {{
    "role": "Investigational Product",
    "alternative_names": ["ABC123", "Compound-ABC"],
    "molecular_target": "Receptor-Z",
    "mechanism": "Bispecific antibody",
    "tt_drug_id": "123456",
    "bmt_drug_id": "78901"
  }},
  "DrugX": {{
    "role": "Standard of Care",
    "alternative_names": ["GenericX", "ChemX"],
    "molecular_target": "Enzyme-A",
    "mechanism": "Antimetabolite",
    "tt_drug_id": "",
    "bmt_drug_id": ""
  }},
  "Placebo": {{
    "role": "Placebo",
    "alternative_names": [],
    "molecular_target": "",
    "mechanism": "Inert comparator",
    "tt_drug_id": "",
    "bmt_drug_id": ""
  }}
}}
""".strip()


# Shared counters & master mapping.
# Worker threads mutate these; each mutation happens under the matching lock.
counter = dict.fromkeys(("processed", "skipped_existing", "llm_error", "parse_error"), 0)
counter_lock = threading.Lock()

# trial_hash -> per-trial result dict; mirrored to MASTER_ROLES_PATH
master_roles: dict[str, dict] = {}
master_lock = threading.Lock()


def _write_json_atomic(path: Path, obj) -> None:
    """Write obj as indented UTF-8 JSON to path via a sibling temp file + rename.

    The per-trial output file doubles as the "already done" marker on re-runs,
    so a crash mid-write must never leave a truncated file behind.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
    tmp_path.replace(path)


def _update_master_roles(trial_hash: str, mapped: dict) -> None:
    """Record one trial in master_roles and rewrite MASTER_ROLES_PATH, under master_lock."""
    with master_lock:
        master_roles[trial_hash] = mapped
        _write_json_atomic(MASTER_ROLES_PATH, master_roles)


def process_trial(row: dict, idx: int, total: int) -> None:
    """Process one trial: prompt LLM, save output & log (only if valid).

    - If DRUG_ROLE_DIR/<trial_hash>.json already exists, the trial is counted as
      skipped and its cached content is folded into the master file.
    - Otherwise the LLM is queried; when its reply parses to a non-empty JSON
      object, the log entry, the per-trial file (written last, as the completion
      marker) and the refreshed master file are saved.

    Args:
        row: one record of the trials CSV (must contain "trial_hash").
        idx: 1-based position of the trial, used only in messages.
        total: total number of trials, used only in messages.

    LLM and parse failures are counted in `counter` and reported, not raised;
    file-system errors propagate to the caller.
    """
    trial_hash = str(row.get("trial_hash", "")).strip()
    if not trial_hash:
        print(f"⚠️ [{idx}/{total}] Missing trial_hash, skipping")
        return

    out_fp = DRUG_ROLE_DIR / f"{trial_hash}.json"
    if out_fp.exists():
        with counter_lock:
            counter["skipped_existing"] += 1
        # master_roles starts empty on every run; fold cached results back in so
        # the master file is not overwritten with only this run's new trials.
        try:
            cached = json.loads(out_fp.read_text(encoding="utf-8"))
        except (OSError, ValueError) as e:
            print(f"⚠️ [{idx}/{total}] Could not read cached {out_fp.name}: {e}")
        else:
            _update_master_roles(trial_hash, cached)
        return

    # Build payload from selected columns
    trial_payload = {"trial_hash": trial_hash}
    for col in RELEVANT_COLS:
        trial_payload[col] = row.get(col, "")

    prompt = build_prompt(trial_payload)

    # Call LLM
    try:
        t0 = time.perf_counter()
        res = client.query(prompt=prompt, model=MODEL)
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)
    except Exception as e:
        print(f"⚠️ [{idx}/{total}] LLM error for trial_hash={trial_hash}: {e}")
        with counter_lock:
            counter["llm_error"] += 1
        return

    drug_roles = extract_json_object(text_response)

    # Treat non-dict OR empty dict as invalid → do NOT save anything
    if not isinstance(drug_roles, dict) or not drug_roles:
        print(f"⚠️ [{idx}/{total}] JSON parse/validity error trial_hash={trial_hash}, raw={text_response!r}")
        with counter_lock:
            counter["parse_error"] += 1
        return

    mapped = {
        "trial_hash": trial_hash,
        "title": row.get("title"),
        "drug_roles": drug_roles,
        "source": "llm",
    }

    # Log entry first, so a trial marked done always has its cost/prompt logged
    log_payload = {
        "token": trial_hash,
        "hash_id": trial_hash,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": json.dumps(mapped, ensure_ascii=False, indent=2),
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    _write_json_atomic(DRUG_ROLE_LOG_DIR / f"{trial_hash}.json", log_payload)

    # Per-trial roles JSON: its existence marks the trial as done on re-runs
    _write_json_atomic(out_fp, mapped)

    # Update master roles
    _update_master_roles(trial_hash, mapped)

    with counter_lock:
        counter["processed"] += 1
        if counter["processed"] % 50 == 0:
            print(f"Progress: processed {counter['processed']} trials...")


# -------------------------------------------------
# RUN CONCURRENTLY
# -------------------------------------------------
df_trials = pd.read_csv(TRIALS_WITH_HASH_CSV, dtype=str).fillna("")

# Every output/log file is keyed by trial_hash. Duplicate hashes would make two
# workers race on the same files and silently collapse distinct trials, so stop
# here instead. (Blank hashes are skipped by process_trial and ignored here.)
if "trial_hash" in df_trials.columns:
    _hashes = df_trials["trial_hash"].str.strip()
    _dup_hashes = sorted(set(_hashes[(_hashes != "") & _hashes.duplicated()]))
    if _dup_hashes:
        raise ValueError(
            f"{len(_dup_hashes)} duplicate trial_hash value(s) in "
            f"{TRIALS_WITH_HASH_CSV}, e.g. {_dup_hashes[:5]}"
        )

rows = df_trials.to_dict(orient="records")
total_trials = len(rows)
print(f"Loaded {total_trials} trials from {TRIALS_WITH_HASH_CSV}")

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
    futures = {
        ex.submit(process_trial, row, idx, total_trials): row.get("trial_hash")
        for idx, row in enumerate(rows, start=1)
    }
    for fut in as_completed(futures):
        th = futures[fut]
        try:
            fut.result()
        except Exception as e:
            print(f"⚠️ Worker error trial_hash={th}: {e}")

print(
    f"✅ Trial drug-role mapping complete. "
    f"processed={counter['processed']}, "
    f"skipped={counter['skipped_existing']}, "
    f"llm_error={counter['llm_error']}, "
    f"parse_error={counter['parse_error']}"
)
print(f"Roles directory: {DRUG_ROLE_DIR}")
print(f"Log directory:   {DRUG_ROLE_LOG_DIR}")
print(f"Master roles:    {MASTER_ROLES_PATH}")

Loaded 184 trials from cache/raw_trials_with_hash.csv
Progress: processed 50 trials...
Progress: processed 100 trials...
Progress: processed 150 trials...
✅ Trial drug-role mapping complete. processed=184, skipped=0, llm_error=0, parse_error=0
Roles directory: cache/trial_drug_roles
Log directory:   cache/trial_drug_roles_log
Master roles:    cache/trial_drug_roles_master.json


In [6]:
import json
from pathlib import Path

LOG_DIR = Path("cache/trial_drug_roles_log")

# (log file name, cost) for every log entry that parsed cleanly
costs = []
for log_fp in LOG_DIR.glob("*.json"):
    try:
        log_entry = json.loads(log_fp.read_text(encoding="utf-8"))
        entry_cost = float(log_entry.get("total_cost") or 0.0)
    except Exception as e:
        print(f"⚠️ Error reading {log_fp.name}: {e}")
        continue
    costs.append((log_fp.name, entry_cost))

# Totals are summed in file order, starting from 0.0
total_cost = sum((c for _, c in costs), 0.0)
num_entries = len(costs)

# Most expensive first
costs_sorted = sorted(costs, key=lambda item: item[1], reverse=True)

print("========== LLM COST SUMMARY ==========")
print(f"Total LLM cost:             ${total_cost:,.4f}")
print(f"Number of logged trials:     {num_entries}")
if num_entries:
    print(f"Average cost per trial:      ${total_cost / num_entries:,.4f}")
print("")

print("Top 10 most expensive trials:")
for name, c in costs_sorted[:10]:
    print(f"  {name}: ${c:,.4f}")

print("========================================")

Total LLM cost:             $4.3453
Number of logged trials:     184
Average cost per trial:      $0.0236

Top 10 most expensive trials:
  tid_261f0233308ca080d1c60e3fda61ca85.json: $0.0692
  tid_1158b3369546dc4b16dc21c8c026b619.json: $0.0606
  tid_e0a77c4ecf93cf781f04cc467c974511.json: $0.0522
  tid_94883aa2d583afced004e22a7991ef3e.json: $0.0519
  tid_196721abc2d5ee98883da9bfcf5bb255.json: $0.0486
  tid_e9e01f51b6680ba4f467ac191bb307c5.json: $0.0467
  tid_8b4d60a5fddc078962af34399d7e342c.json: $0.0452
  tid_763e3011bc90e46c88c7a2953a39ed2a.json: $0.0447
  tid_837737698a5271d314ea8208addb2d72.json: $0.0440
  tid_7e80effdd579ba535ef686ac50dcc4bc.json: $0.0431


In [3]:
import json
from pathlib import Path

import pandas as pd

# -------------------------------------------------
# Build trial_product_breakdown.csv
# -------------------------------------------------
# One row per trial. For each role, parallel list columns hold the drug names,
# alternative names (list of lists), targets, mechanisms and — except for
# placebos — TrialTrove / BioMedTracker drug IDs. Lists are written via their
# Python repr and parsed back with ast.literal_eval in later cells.

# Base directory for cache + input/output
BASE_DIR = Path("cache")

# Directory that contains per-trial drug-role JSONs
DRUG_ROLE_DIR = BASE_DIR / "trial_drug_roles"

# Output CSV path
OUT_CSV = BASE_DIR / "trial_product_breakdown.csv"

# (lower-cased role label from the LLM, output column prefix, export drug IDs?)
# The prefix column itself holds the canonical drug names; list order fixes
# the CSV column order.
ROLE_COLUMN_SPECS = [
    ("investigational product", "investigational_products", True),
    ("active comparator", "active_comparators", True),
    ("placebo", "placebos", False),
    ("standard of care", "standard_of_care", True),
]


def breakdown_drug_roles(drug_roles: dict, source_name: str = "") -> dict:
    """Split one trial's ``drug_roles`` mapping into per-role parallel lists.

    Args:
        drug_roles: canonical drug name -> metadata dict with keys "role",
            "alternative_names", "molecular_target", "mechanism",
            "tt_drug_id" and "bmt_drug_id".
        source_name: label used in warnings (e.g. the JSON file name).

    Returns:
        Ordered dict of column name -> list, one entry per column implied by
        ROLE_COLUMN_SPECS; within a role, index i of every list is the same
        drug. Entries whose metadata is not a dict are skipped silently; entries
        whose role is not recognised are skipped with a warning.
    """
    columns: dict[str, list] = {}
    for _label, prefix, with_ids in ROLE_COLUMN_SPECS:
        columns[prefix] = []
        columns[f"{prefix}_alternative_names"] = []   # list of lists
        columns[f"{prefix}_molecular_target"] = []
        columns[f"{prefix}_mechanism"] = []
        if with_ids:
            columns[f"{prefix}_tt_drug_id"] = []
            columns[f"{prefix}_bmt_drug_id"] = []

    spec_by_role = {label: (prefix, with_ids) for label, prefix, with_ids in ROLE_COLUMN_SPECS}

    for drug_name, meta in drug_roles.items():
        if not isinstance(meta, dict):
            continue

        role = str(meta.get("role") or "").strip()
        spec = spec_by_role.get(role.lower())
        if spec is None:
            # Surface unexpected labels instead of silently losing the drug
            print(f"⚠️ Unrecognized role {role!r} for drug {drug_name!r} in {source_name}, dropping")
            continue
        prefix, with_ids = spec

        alt_names = meta.get("alternative_names") or []
        if not isinstance(alt_names, list):
            alt_names = [str(alt_names)]

        columns[prefix].append(drug_name)
        columns[f"{prefix}_alternative_names"].append(alt_names)
        columns[f"{prefix}_molecular_target"].append(meta.get("molecular_target") or "")
        columns[f"{prefix}_mechanism"].append(meta.get("mechanism") or "")
        if with_ids:
            # IDs should already be strings in the LLM output, but be defensive
            columns[f"{prefix}_tt_drug_id"].append(str(meta.get("tt_drug_id") or ""))
            columns[f"{prefix}_bmt_drug_id"].append(str(meta.get("bmt_drug_id") or ""))

    return columns


rows = []

for fp in DRUG_ROLE_DIR.glob("*.json"):
    try:
        obj = json.loads(fp.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"⚠️ Error reading {fp.name}: {e}")
        continue

    if not isinstance(obj, dict):
        print(f"⚠️ Top-level JSON is not an object in {fp.name}, skipping")
        continue

    trial_hash = obj.get("trial_hash")
    if not trial_hash:
        print(f"⚠️ Missing trial_hash in {fp.name}, skipping")
        continue

    drug_roles = obj.get("drug_roles") or {}
    if not isinstance(drug_roles, dict):
        print(f"⚠️ drug_roles not dict in {fp.name}, skipping")
        continue

    rows.append({"trial_hash": trial_hash, **breakdown_drug_roles(drug_roles, fp.name)})

# Fix the column set up front so an empty cache still yields a well-formed
# (header-only) CSV instead of a KeyError from sort_values.
OUT_COLUMNS = ["trial_hash", *breakdown_drug_roles({})]
df_out = pd.DataFrame(rows, columns=OUT_COLUMNS).sort_values("trial_hash")

OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
df_out.to_csv(OUT_CSV, index=False)

print(f"Saved trial product breakdown to {OUT_CSV}")
if df_out.empty:
    print(f"⚠️ No drug-role files found in {DRUG_ROLE_DIR}")
else:
    print(df_out.head().to_markdown())

Saved trial product breakdown to cache/trial_product_breakdown.csv
|     | trial_hash                           | investigational_products                      | investigational_products_alternative_names                                                                                                                                                 | investigational_products_molecular_target   | investigational_products_mechanism                                    | investigational_products_tt_drug_id   | investigational_products_bmt_drug_id   | active_comparators                       | active_comparators_alternative_names                                                                                                                                                                                            | active_comparators_molecular_target   | active_comparators_mechanism                                 | active_comparators_tt_drug_id   | active_comparators_bmt_drug_id   | placebos 

Manually check the rows with no investigational products

In [4]:
import ast
from pathlib import Path

import pandas as pd

# Define the cache location here so this cell does not depend on BASE_DIR
# leaking from an earlier cell.
BASE_DIR = Path("cache")
IN_CSV = BASE_DIR / "trial_product_breakdown.csv"

df = pd.read_csv(IN_CSV, dtype=str).fillna("")

def parse_listish(s: str):
    """
    Turn a stringified Python list (e.g. "['A', 'B']") back into a list.

    - Non-string input, blank strings, "[]" and "[ ]" give [].
    - A literal that evaluates to something other than a list is wrapped as [value].
    - Text that cannot be evaluated at all is returned as [stripped_text].
    """
    if not isinstance(s, str):
        return []
    text = s.strip()
    if text in ("", "[]", "[ ]"):
        return []
    try:
        parsed = ast.literal_eval(text)
    except Exception:
        return [text]
    return parsed if isinstance(parsed, list) else [parsed]

# Turn each stringified list back into a real Python list
df["investigational_products_parsed"] = df["investigational_products"].apply(parse_listish)

# A trial has no investigational product when its parsed list is empty
no_inv_mask = df["investigational_products_parsed"].apply(lambda items: not items)

num_no_inv = int(no_inv_mask.sum())
total = len(df)

print(f"Rows with NO investigational products: {num_no_inv} / {total}")

# Preview (at most 20) of the flagged trials
no_inv_preview = df.loc[no_inv_mask, ["trial_hash", "investigational_products"]].head(20)
print(no_inv_preview.to_markdown(index=False))

Rows with NO investigational products: 4 / 184
| trial_hash                           | investigational_products   |
|:-------------------------------------|:---------------------------|
| tid_4c45730f6411aa1e5a38bb1223d66988 | []                         |
| tid_67de51bf9728e056a6fb42c76e4b0212 | []                         |
| tid_8cab7b7177fcb0d10255bced8b0633ee | []                         |
| tid_bb1e0571142dde8a49976632c349593c | []                         |


Manual checks 
- tid_4c45730f6411aa1e5a38bb1223d66988
    - This trial is combining three standard-of-care agents into a regimen “DCF”
- tid_67de51bf9728e056a6fb42c76e4b0212
    - Even though they administer Yisaipu in a structured way, it is an approved drug and not being tested for regulatory approval.
- tid_8cab7b7177fcb0d10255bced8b0633ee
    - The trial is studying treatment strategies, regimens, algorithms, imaging-guided regimen selection, or dosing, using only approved standard therapies.
- tid_bb1e0571142dde8a49976632c349593c
    - The trial's focus is on optimizing regimen selection (e.g., TIPy or TCbIPy) via imaging, rather than testing a new drug entity.

All four were confirmed to use only approved generics/biosimilars, so an empty investigational-product list is expected for them.

Identify and group trials by unique products

In [18]:
import ast
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache")
IN_CSV = BASE_DIR.joinpath("trial_product_breakdown.csv")

# ---------------------------
# Helpers
# ---------------------------
def parse_listish(x):
    """Decode a list-like string (e.g. "['a','b']") into a Python list.

    Lists pass through unchanged. Missing values (None/NaN), blank strings,
    unparseable text and literals that are not lists all yield [].

    NOTE: this redefines the parse_listish from an earlier cell, which instead
    wraps non-list literals / unparseable text in a one-element list.
    """
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    text = str(x).strip()
    if not text:
        return []
    try:
        value = ast.literal_eval(text)
    except Exception:
        value = None
    return value if isinstance(value, list) else []


# ---------------------------
# Load
# ---------------------------
df = pd.read_csv(IN_CSV, dtype=str).fillna("")

# Everything is aggregated per tt_drug_id:
#   tt_id -> {"names", "alt_names", "targets", "mechs", "trials"}, each a set of str
agg = {}

ROLE_PAIRS = [
    ("investigational_products", "investigational_products_tt_drug_id"),
    ("active_comparators", "active_comparators_tt_drug_id"),
    ("standard_of_care", "standard_of_care_tt_drug_id"),
]


def _add_clean(bucket: set, value) -> None:
    """Add str(value).strip() to bucket unless it is empty."""
    cleaned = str(value).strip()
    if cleaned:
        bucket.add(cleaned)


for _, row in df.iterrows():
    trial_hash = str(row.get("trial_hash", "")).strip()

    for base_col, tt_col in ROLE_PAIRS:
        # Parallel lists: position i describes the same drug in every column
        names_list = parse_listish(row.get(base_col, ""))
        alts_list = parse_listish(row.get(f"{base_col}_alternative_names", ""))
        targets_list = parse_listish(row.get(f"{base_col}_molecular_target", ""))
        mechs_list = parse_listish(row.get(f"{base_col}_mechanism", ""))
        tt_ids = parse_listish(row.get(tt_col, ""))

        # The tt_ids list defines which products are counted
        for i, raw_tt in enumerate(tt_ids):
            tt_id = str(raw_tt).strip()
            if not tt_id:
                continue

            bucket = agg.setdefault(
                tt_id,
                {"names": set(), "alt_names": set(), "targets": set(), "mechs": set(), "trials": set()},
            )

            if trial_hash:
                bucket["trials"].add(trial_hash)

            if i < len(names_list):
                _add_clean(bucket["names"], names_list[i])

            # Alternative names are normally a nested list per drug
            if i < len(alts_list):
                alt_entry = alts_list[i]
                for alt in (alt_entry if isinstance(alt_entry, list) else [alt_entry]):
                    _add_clean(bucket["alt_names"], alt)

            if i < len(targets_list):
                _add_clean(bucket["targets"], targets_list[i])

            if i < len(mechs_list):
                _add_clean(bucket["mechs"], mechs_list[i])

# ---------------------------
# Build aggregated DataFrame
# ---------------------------
GROUPED_COLUMNS = [
    "tt_drug_id",
    "drug_names",
    "alternative_names",
    "molecular_targets",
    "product_mechanisms",
    "trial_hashes",
]

rows_out = [
    {
        "tt_drug_id": tt_id,
        "drug_names": sorted(payload["names"]),
        "alternative_names": sorted(payload["alt_names"]),
        "molecular_targets": sorted(payload["targets"]),
        "product_mechanisms": sorted(payload["mechs"]),
        "trial_hashes": sorted(payload["trials"]),
    }
    for tt_id, payload in agg.items()
]

# Explicit columns keep the frame well-formed (and sortable) when agg is empty.
# IDs are strings, so sort them numerically ("99" before "100"); IDs that are
# not numeric coerce to NaN and go last.
grouped_df = (
    pd.DataFrame(rows_out, columns=GROUPED_COLUMNS)
    .sort_values(
        "tt_drug_id",
        key=lambda ids: pd.to_numeric(ids, errors="coerce"),
        kind="stable",
    )
    .reset_index(drop=True)
)

print("Aggregated by tt_drug_id (first 10 rows):")
print(grouped_df.head(10).to_markdown(index=False))

OUT_AGG = BASE_DIR / "product_id_master_by_tt.csv"
grouped_df.to_csv(OUT_AGG, index=False)
print(f"Saved aggregated tt_drug_id table → {OUT_AGG}")

# ---------------------------------------------
# Cell: Print all rows missing targets or mechanisms
# ---------------------------------------------

# A row is "missing" if EITHER list is empty (no target OR no mechanism).
missing_mask = (
    grouped_df["molecular_targets"].map(len).eq(0)
    | grouped_df["product_mechanisms"].map(len).eq(0)
)

missing_df = grouped_df.loc[missing_mask].copy()

print("Rows missing molecular_targets or product_mechanisms:")
if missing_df.empty:
    print("✅ No missing values — every tt_drug_id has targets and mechanisms.")
else:
    # Pretty print full table
    print(missing_df.to_markdown(index=False))

# Optionally save for debugging
OUT_MISSING = BASE_DIR / "product_id_missing_targets_or_mechs.csv"
missing_df.to_csv(OUT_MISSING, index=False)
print(f"Saved → {OUT_MISSING}")

Aggregated by tt_drug_id (first 10 rows):
|   tt_drug_id | drug_names                           | alternative_names                                                                                                                                                     | molecular_targets                                                                    | product_mechanisms                                                                                                                                                  | trial_hashes                                                                                                                                                                                                                                     |
|-------------:|:-------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------

In [24]:
import ast
import re
import json
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
from services.openai_wrapper import OpenAIWrapper

# -------------------------------------------------
# CONFIG — paths, model and concurrency for product MoA inference
# -------------------------------------------------
BASE_DIR = Path("cache")

# Input: raw trials keyed by trial_hash
RAW_TRIALS_CSV = BASE_DIR / "raw_trials_with_hash.csv"

# Outputs: one JSON per tt_drug_id, one prompt/response log per tt_drug_id,
# and a combined master mapping
PRODUCT_MECH_DIR = BASE_DIR / "product_mechanism_inference"
PRODUCT_MECH_DIR.mkdir(parents=True, exist_ok=True)
PRODUCT_MECH_LOG_DIR = BASE_DIR / "product_mechanism_inference_log"
PRODUCT_MECH_LOG_DIR.mkdir(parents=True, exist_ok=True)
MASTER_PRODUCT_MECH_PATH = BASE_DIR / "product_mechanism_inference_master.json"

# LLM settings
MODEL = "gpt-5"
MAX_WORKERS = 8
client = OpenAIWrapper()

# -------------------------------------------------
# Helpers
# -------------------------------------------------
def extract_json_object(text: str) -> dict:
    """
    Extract the first valid JSON object from model output.

    The whole (stripped) text is parsed first. If that fails or is not a dict,
    every ``{`` is tried in turn as the start of a JSON value, and the first
    one that decodes to a dict is returned. Leading prose, Markdown code fences
    and any trailing text (including a second object) are therefore ignored.

    Args:
        text: Raw model reply; non-string input is tolerated.

    Returns:
        The first JSON object found, or ``{}`` if none can be decoded.
    """
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}

    # Fast path: the reply is exactly one JSON object
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except (ValueError, RecursionError):
        pass

    # Fallback: raw_decode stops at the end of the first complete value, so text
    # after the object does not break parsing (a greedy r"\{.*\}" match would
    # swallow it and fail).
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)

    return {}


def safe_parse_listish(val):
    """
    Coerce a list-like value into a Python list.

    - A list is returned unchanged (same object).
    - ``None`` or a blank/whitespace-only string gives ``[]``.
    - Anything else is stringified and evaluated with ``ast.literal_eval``:
      a list result is returned, any other literal is wrapped as ``[value]``,
      and text that cannot be evaluated comes back as ``[text]``.
    """
    if isinstance(val, list):
        return val
    text = "" if val is None else str(val).strip()
    if text == "":
        return []
    try:
        parsed = ast.literal_eval(text)
    except Exception:
        return [text]
    return parsed if isinstance(parsed, list) else [parsed]


def build_product_prompt(row: dict, trial_context: list[dict]) -> str:
    """
    Build the prompt asking the LLM (with web_search) to infer the molecular
    target and mechanism of action of one product.

    Args:
        row: Aggregated product record. Reads ``drug_names`` and
            ``alternative_names`` (lists, or anything JSON-serializable).
            Both are sent so the model can search under every alias.
        trial_context: Full metadata dicts for the trials the product appears in.

    Returns:
        Prompt text requesting a JSON object with ``molecular_target`` and
        ``mechanism`` keys.
    """

    def _names_json(names) -> str:
        # Fall back to stringified entries if any element is not JSON-serializable.
        try:
            return json.dumps(names, ensure_ascii=False)
        except TypeError:
            return json.dumps([str(x) for x in names], ensure_ascii=False)

    drug_names_json = _names_json(row.get("drug_names", []) or [])
    # The rules below ask the model to search "using all known names or aliases",
    # so the aliases must actually be included in the input.
    alt_names_json = _names_json(row.get("alternative_names", []) or [])

    # Trials context JSON (default=str keeps the prompt buildable for non-JSON values)
    trials_json = json.dumps(trial_context, ensure_ascii=False, indent=2, default=str)

    return f"""
You are a pharmacology expert with access to web search.

You are given:
- A drug name and its known alternative names (aliases), if any.
- Full metadata for one or more clinical trials in which this drug appears (JSON objects).

Your goal:
Using web search and your domain knowledge, determine:
1. The primary molecular target(s) of the drug (e.g., EGFR, VEGFR2, TNF, CD20, JAK1/2).
2. A concise, standard mechanism of action label (e.g., "EGFR inhibitor", "Anti-PD-1 antibody", 
   "JAK inhibitor", "DNA-damaging cytotoxic", etc.).

Rules:
- Try searching for the drug using all known names or aliases (drug_name and alternative_names).
- If **no molecular target or mechanism of action has been disclosed publicly**, then return **empty strings** for those fields.

INPUT
-----
drug_name: {drug_names_json}
alternative_names: {alt_names_json}
trial_metadata:
{trials_json}

OUTPUT (JSON only)
------------------
{{
  "molecular_target": "",
  "mechanism": ""
}}
""".strip()


# Shared counters & master mapping (mutated by worker threads; hold the matching lock)
product_counter = {
    "processed": 0,
    "skipped_existing": 0,
    "llm_error": 0,
    "parse_error": 0,
}
product_counter_lock = threading.Lock()

product_master: dict[str, dict] = {}
product_master_lock = threading.Lock()

# Load existing master if present. The master is rewritten in full on every
# update, so a silent reset to {} would drop every previously inferred product
# from it. Report the failure instead of swallowing it.
if MASTER_PRODUCT_MECH_PATH.exists():
    try:
        _loaded_master = json.loads(MASTER_PRODUCT_MECH_PATH.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        print(f"⚠️ Could not read {MASTER_PRODUCT_MECH_PATH} ({e}); rebuilding from per-product files")
        _loaded_master = {}
    if isinstance(_loaded_master, dict):
        product_master = _loaded_master
    else:
        print(f"⚠️ {MASTER_PRODUCT_MECH_PATH} is not a JSON object; ignoring it")

# Backfill products that have a per-product JSON (named f"{tt_drug_id}.json")
# but are absent from the master. process_product skips products whose file
# exists, so without this they would never be re-added to the master.
_backfilled = 0
for _fp in sorted(PRODUCT_MECH_DIR.glob("*.json")):
    if _fp.stem in product_master:
        continue
    try:
        _entry = json.loads(_fp.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        print(f"⚠️ Could not read {_fp}: {e}")
        continue
    if isinstance(_entry, dict):
        product_master[_fp.stem] = _entry
        _backfilled += 1

if _backfilled:
    MASTER_PRODUCT_MECH_PATH.write_text(
        json.dumps(product_master, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(f"Backfilled {_backfilled} product(s) into {MASTER_PRODUCT_MECH_PATH}")

# -------------------------------------------------
# LOAD TRIAL METADATA AND BUILD INDEX
# -------------------------------------------------
df_trials = pd.read_csv(RAW_TRIALS_CSV, dtype=str).fillna("")

# Stripped trial_hash -> full trial record (all columns as strings).
# If a hash repeats, the last record wins.
trials_index: dict[str, dict] = {
    str(record["trial_hash"]).strip(): record
    for record in df_trials.to_dict(orient="records")
}


def _write_json_atomic(path: Path, obj) -> None:
    """Write ``obj`` as pretty UTF-8 JSON to ``path`` via a temp file + rename.

    ``Path.replace`` swaps the fully written temp file into place, so an
    interrupted run does not leave a truncated JSON file at ``path``.
    """
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
    tmp.replace(path)


def process_product(row: dict, idx: int, total: int) -> None:
    """
    Infer molecular target and mechanism for one tt_drug_id (LLM + web_search).

    Products whose per-product JSON already exists are skipped. Otherwise the
    row's ``trial_hashes`` are resolved to full trial metadata via
    ``trials_index``, the model is queried, and the reply is validated. Then the
    log entry and master file are written, and the per-product JSON is written
    last, because its existence is what marks the product as done. Outcomes are
    tallied in ``product_counter``.

    Args:
        row: A record from ``missing_df``; reads ``tt_drug_id``, ``trial_hashes``,
            ``drug_names`` and ``alternative_names``.
        idx: 1-based position of this product, used in messages only.
        total: Number of products in the run, used in messages only.
    """
    tt_drug_id = str(row.get("tt_drug_id", "")).strip()
    if not tt_drug_id:
        print(f"⚠️ [{idx}/{total}] Missing tt_drug_id, skipping")
        return

    out_fp = PRODUCT_MECH_DIR / f"{tt_drug_id}.json"
    if out_fp.exists():
        with product_counter_lock:
            product_counter["skipped_existing"] += 1
        return

    # Associated trial hashes -> full trial metadata (unknown hashes are dropped)
    trial_hashes = safe_parse_listish(row.get("trial_hashes", []))
    trial_context = []
    for th in trial_hashes:
        th_key = str(th).strip()
        if not th_key:
            continue
        trial_row = trials_index.get(th_key)
        if trial_row:
            trial_context.append(trial_row)

    prompt = build_product_prompt(row, trial_context)

    # Call LLM with web_search tool
    try:
        t0 = time.perf_counter()
        res = client.query(
            prompt=prompt,
            model=MODEL,
            tools=[{"type": "web_search"}],
        )
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)
    except Exception as e:
        print(f"⚠️ [{idx}/{total}] LLM error for tt_drug_id={tt_drug_id}: {e}")
        with product_counter_lock:
            product_counter["llm_error"] += 1
        return

    mech_obj = extract_json_object(text_response)

    # Require BOTH expected keys. Empty-string values are fine (the prompt asks
    # for them when nothing is disclosed), but a reply without the keys would
    # otherwise be saved as two empty strings and then skipped on every re-run.
    if (
        not isinstance(mech_obj, dict)
        or "molecular_target" not in mech_obj
        or "mechanism" not in mech_obj
    ):
        print(f"⚠️ [{idx}/{total}] JSON parse/validity error tt_drug_id={tt_drug_id}, raw={text_response!r}")
        with product_counter_lock:
            product_counter["parse_error"] += 1
        return

    molecular_target = str(mech_obj.get("molecular_target", "") or "").strip()
    mechanism = str(mech_obj.get("mechanism", "") or "").strip()

    mapped = {
        "tt_drug_id": tt_drug_id,
        "drug_names": row.get("drug_names", []),
        "alternative_names": row.get("alternative_names", []),
        "trial_hashes": trial_hashes,
        "molecular_target": molecular_target,
        "mechanism": mechanism,
        "source": "llm_web_search",
    }

    # Log entry
    log_payload = {
        "tt_drug_id": tt_drug_id,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": json.dumps(mapped, ensure_ascii=False, indent=2),
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    _write_json_atomic(PRODUCT_MECH_LOG_DIR / f"{tt_drug_id}.json", log_payload)

    # Update master (full rewrite under the lock)
    with product_master_lock:
        product_master[tt_drug_id] = mapped
        _write_json_atomic(MASTER_PRODUCT_MECH_PATH, product_master)

    # Per-product JSON last: if anything above failed, the product is retried.
    _write_json_atomic(out_fp, mapped)

    with product_counter_lock:
        product_counter["processed"] += 1
        if product_counter["processed"] % 50 == 0:
            print(f"Progress: processed {product_counter['processed']} products...")


# -------------------------------------------------
# RUN CONCURRENTLY ON MISSING PRODUCTS
# -------------------------------------------------
# Relies on `missing_df` (including trial_hashes) from the previous cell.
missing_rows = missing_df.to_dict(orient="records")
total_missing = len(missing_rows)
print(f"Loaded {total_missing} products missing targets/mechanisms")

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
    future_to_drug = {}
    for position, record in enumerate(missing_rows, start=1):
        fut = pool.submit(process_product, record, position, total_missing)
        future_to_drug[fut] = record.get("tt_drug_id")

    for done in as_completed(future_to_drug):
        drug_id = future_to_drug[done]
        try:
            done.result()
        except Exception as e:
            print(f"⚠️ Worker error tt_drug_id={drug_id}: {e}")

summary_counts = ", ".join(
    f"{label}={product_counter[key]}"
    for label, key in (
        ("processed", "processed"),
        ("skipped", "skipped_existing"),
        ("llm_error", "llm_error"),
        ("parse_error", "parse_error"),
    )
)
print(f"✅ Product mechanism inference complete. {summary_counts}")
print(f"Per-product directory: {PRODUCT_MECH_DIR}")
print(f"Log directory:        {PRODUCT_MECH_LOG_DIR}")
print(f"Master file:          {MASTER_PRODUCT_MECH_PATH}")

Loaded 2 products missing targets/mechanisms
✅ Product mechanism inference complete. processed=2, skipped=0, llm_error=0, parse_error=0
Per-product directory: cache/product_mechanism_inference
Log directory:        cache/product_mechanism_inference_log
Master file:          cache/product_mechanism_inference_master.json


In [26]:
import ast
import json
from pathlib import Path

import pandas as pd

# ----------------------------------------
# CONFIG — inputs/outputs for back-filling inferred targets/mechanisms
# ----------------------------------------
BASE_DIR = Path("cache")
IN_BREAKDOWN_CSV = BASE_DIR / "trial_product_breakdown.csv"
# Filled copy lives next to the input breakdown
OUT_FILLED_CSV = IN_BREAKDOWN_CSV.with_name("trial_product_breakdown_filled.csv")
MASTER_PRODUCT_MECH_PATH = BASE_DIR / "product_mechanism_inference_master.json"

# ----------------------------------------
# Helpers
# ----------------------------------------
def parse_listish(x):
    """
    Coerce a list-like value (e.g. the string "['a', 'b']") into a Python list.

    Lists are returned unchanged. ``None``, blank strings and "[]" / "[ ]" give
    ``[]``. Any other value is stringified and passed to ``ast.literal_eval``:
    a list result is returned, other literals are wrapped as ``[value]``, and
    text that cannot be evaluated comes back as ``[text]``.
    """
    if isinstance(x, list):
        return x
    if x is None:
        return []
    text = str(x).strip()
    if text in {"", "[]", "[ ]"}:
        return []
    try:
        value = ast.literal_eval(text)
    except Exception:
        return [text]
    return value if isinstance(value, list) else [value]


def pad_to_length(lst, n):
    """Return a new list copy of ``lst``, right-padded with "" up to length ``n``.

    Lists already at least ``n`` long are copied unchanged (never truncated).
    """
    padded = list(lst)
    shortfall = n - len(padded)
    if shortfall > 0:
        padded.extend([""] * shortfall)
    return padded


# ----------------------------------------
# Load inputs: trial breakdown (all cells as strings) + inferred product MoAs
# ----------------------------------------
df = pd.read_csv(IN_BREAKDOWN_CSV, dtype=str).fillna("")
print(f"Loaded trial breakdown: {IN_BREAKDOWN_CSV}, shape={df.shape}")

product_master = {}
if not MASTER_PRODUCT_MECH_PATH.exists():
    print(f"⚠️ No product master mech file found at {MASTER_PRODUCT_MECH_PATH}")
else:
    product_master = json.loads(MASTER_PRODUCT_MECH_PATH.read_text(encoding="utf-8"))

# role → (tt_drug_id column, molecular_target column, mechanism column).
# Every column is the role name followed by a fixed suffix.
# Note: placebos have no tt_drug_id column, so we can't enrich them here.
ROLE_SPECS = [
    (role, f"{role}_tt_drug_id", f"{role}_molecular_target", f"{role}_mechanism")
    for role in ("investigational_products", "active_comparators", "standard_of_care")
]

# ----------------------------------------
# Enrich targets/mechanisms from product_master
# ----------------------------------------
def _is_blank(value) -> bool:
    """True when a target/mechanism slot holds no value yet (None, blank str, empty list)."""
    if value is None:
        return True
    if isinstance(value, str):
        return not value.strip()
    if isinstance(value, (list, tuple)):
        return len(value) == 0
    return False


updated_rows = 0
filled_targets = 0
filled_mechs = 0

for idx, row in df.iterrows():
    row_changed = False

    for _role, tt_col, tgt_col, mech_col in ROLE_SPECS:
        # If column missing (defensive), skip
        if tt_col not in df.columns or tgt_col not in df.columns or mech_col not in df.columns:
            continue

        tt_ids = parse_listish(row.get(tt_col, ""))
        if not tt_ids:
            continue

        # Align target/mechanism lists with tt_ids by index
        targets = pad_to_length(parse_listish(row.get(tgt_col, "")), len(tt_ids))
        mechs = pad_to_length(parse_listish(row.get(mech_col, "")), len(tt_ids))

        # Track changes per role: with only a row-level flag, once one role was
        # filled every later role's columns were rewritten (padded/re-serialized)
        # even when nothing was filled for them.
        role_changed = False

        for i, raw_tt in enumerate(tt_ids):
            tt_id = str(raw_tt).strip()
            if not tt_id:
                continue

            info = product_master.get(tt_id)
            if not info:
                continue

            inferred_target = str(info.get("molecular_target", "") or "").strip()
            inferred_mech = str(info.get("mechanism", "") or "").strip()

            # Only fill if currently empty and inference has a non-empty value
            if _is_blank(targets[i]) and inferred_target:
                targets[i] = inferred_target
                filled_targets += 1
                role_changed = True

            if _is_blank(mechs[i]) and inferred_mech:
                mechs[i] = inferred_mech
                filled_mechs += 1
                role_changed = True

        # Write back updated lists (as Python-literal strings) for this role only
        if role_changed:
            df.at[idx, tgt_col] = repr(targets)
            df.at[idx, mech_col] = repr(mechs)
            row_changed = True

    if row_changed:
        updated_rows += 1

print(f"Updated {updated_rows} trial rows with inferred mechanisms/targets.")
print(f"Filled targets: {filled_targets}, filled mechanisms: {filled_mechs}")

# ----------------------------------------
# Save filled CSV
# ----------------------------------------
OUT_FILLED_CSV.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUT_FILLED_CSV, index=False)
print(f"✅ Saved filled trial breakdown to {OUT_FILLED_CSV}")


Loaded trial breakdown: cache/trial_product_breakdown.csv, shape=(184, 23)
Updated 1 trial rows with inferred mechanisms/targets.
Filled targets: 1, filled mechanisms: 1
✅ Saved filled trial breakdown to cache/trial_product_breakdown_filled.csv


#### Task 2

Identify whether each investigational drug is innovative, generic, or a biosimilar.

In [None]:
import json
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from services.openai_wrapper import OpenAIWrapper
import ast

import pandas as pd

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
BASE_DIR = Path("cache")

# Inputs
TRIALS_WITH_HASH_CSV    = BASE_DIR / "raw_trials_with_hash.csv"
PRODUCT_BREAKDOWN_CSV   = BASE_DIR / "trial_product_breakdown.csv"

# Outputs: per-trial JSONs, per-trial logs, and a combined master mapping
INNOV_DIR = BASE_DIR / "trial_investigational_drugs_classifications"
INNOV_DIR.mkdir(parents=True, exist_ok=True)
INNOV_LOG_DIR = BASE_DIR / "trial_investigational_drugs_classifications_log"
INNOV_LOG_DIR.mkdir(parents=True, exist_ok=True)
MASTER_INNOV_PATH = BASE_DIR / "trial_investigational_drugs_classifications_master.json"

# Trial-level columns from raw_trials_with_hash.csv that go into each prompt
RELEVANT_COLS = [
    "title",
    "objective",
    "outcome_details",
    "treatment_plan",
    "notes_json",
    "results_json",
    "primary_drugs_tested_json",
    "other_drugs_tested_json",
    "therapeutic_areas_json",
    "bmt_other_drugs_tested_json",
    "bmt_primary_drugs_tested_json",
    "ct_gov_mesh_terms_json",
]

# LLM settings
MODEL = "gpt-5"
MAX_WORKERS_INNOV = 8
client = OpenAIWrapper()

# -------------------------------------------------
# Helpers
# -------------------------------------------------
def load_master_innov(path=None) -> dict:
    """
    Load the master innovation-classification mapping (trial_hash -> record).

    Args:
        path: JSON file to read; defaults to ``MASTER_INNOV_PATH``.

    Returns:
        The decoded mapping, or ``{}`` when the file is absent, unreadable,
        not valid JSON, or not a JSON object. Read/format failures are reported
        rather than silently swallowed, because the master is rewritten in full
        on the next update and a silent reset would drop its old contents.
    """
    fp = MASTER_INNOV_PATH if path is None else Path(path)
    if not fp.exists():
        return {}
    try:
        data = json.loads(fp.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        print(f"⚠️ Could not read master file {fp}: {e}")
        return {}
    if not isinstance(data, dict):
        print(f"⚠️ Master file {fp} is not a JSON object; ignoring it")
        return {}
    return data

def extract_json_object(text: str) -> dict:
    """
    Extract the first valid JSON object from model output.

    The whole (stripped) text is parsed first. If that fails or is not a dict,
    every ``{`` is tried in turn as the start of a JSON value, and the first
    one that decodes to a dict is returned. Leading prose, Markdown code fences
    and any trailing text (including a second object) are therefore ignored.
    Only ``json`` is used, so this cell does not depend on ``re`` having been
    imported by another cell.

    Args:
        text: Raw model reply; non-string input is tolerated.

    Returns:
        The first JSON object found, or ``{}`` if none can be decoded.
    """
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}

    # Fast path: the reply is exactly one JSON object
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except (ValueError, RecursionError):
        pass

    # Fallback: raw_decode stops at the end of the first complete value, so text
    # after the object does not break parsing (a greedy r"\{.*\}" match would
    # swallow it and fail).
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)

    return {}

def parse_listish(s: str):
    """
    Turn a stringified list cell (e.g. "['A', 'B']") into a Python list.

    Non-string input, blank strings and the literal empty lists "[]" / "[ ]"
    give ``[]``. A literal that evaluates to a non-list is wrapped as
    ``[value]``; text that cannot be evaluated is returned as ``[text]``.
    """
    if not isinstance(s, str):
        return []
    text = s.strip()
    if text in ("", "[]", "[ ]"):
        return []
    try:
        value = ast.literal_eval(text)
    except Exception:
        # Not a Python literal: keep the non-empty string as one element
        return [text]
    return value if isinstance(value, list) else [value]

# Shared state mutated by worker threads — always hold the matching lock.
master_innov = load_master_innov()
master_lock = threading.Lock()

innov_counter = dict.fromkeys(
    ("processed", "skipped_existing", "llm_error", "parse_error", "coverage_error"),
    0,
)
counter_lock = threading.Lock()


def build_innovation_prompt(trial_payload: dict, investigational_products: list[str]) -> str:
    """
    Build the prompt that asks the model to label each investigational product
    as "Innovative", "Generic" or "Biosimilar", with a one-sentence explanation.

    Args:
        trial_payload: Trial fields shown to the model (trial text columns plus
            drug-role breakdown columns); embedded as indented JSON.
        investigational_products: Product names; the prompt tells the model to
            use these exact strings as the keys of its JSON reply.

    Returns:
        The prompt text with leading/trailing whitespace stripped.
    """
    # Pretty-printed JSON blocks embedded verbatim in the prompt below
    payload_json = json.dumps(trial_payload, ensure_ascii=False, indent=2)
    drugs_json   = json.dumps(investigational_products, ensure_ascii=False, indent=2)

    # Doubled braces ({{ / }}) in the example output are f-string escapes for literal braces.
    return f"""
You are a clinical trial design and drug development expert.

You are given:
1) Structured information about a clinical trial (title, objective, results, etc.).
2) A list of investigational products used in the trial.
3) Structured fields describing how drugs are classified in the study
   (investigational_products, active_comparators, placebos, standard_of_care, etc.).

Your task: For EACH investigational product, classify whether it is:
- "Innovative"
- "Generic"
- "Biosimilar"

and provide a one-sentence concise explanation for your classification.

DEFINITIONS / GUIDANCE
----------------------

Innovative:
- A novel or proprietary drug.
- New mechanism of action OR new biological entity OR clearly sponsor's lead product.
- Often associated with superiority or efficacy language:
  - "evaluate efficacy", "vs placebo", "improve outcomes", etc.
- Not a copy of an already-approved product.

Generic:
- A small-molecule copy of an already-approved branded drug.
- Same active ingredient, strength, dosage form, and route.
- Often associated with:
  - language like "generic", "copy", "equivalent",
  - OR clear indication that the product is a non-branded version.

Biosimilar:
- A biologic product that is highly similar to an already-approved reference biologic.
- Same target and mechanism as a branded biologic.
- Strong clues:
  - "equivalence", "non-inferiority", "no clinically meaningful differences",
  - direct comparison to a specific branded reference biologic with the SAME active ingredient.

Task 2 — Classify Innovation Status
-----------------------------------

Use clues within the text to determine whether each investigational drug is:
- "Innovative"
- "Generic"
- "Biosimilar"

Examples of helpful cues:
- Innovative:
  - Superiority/efficacy language ("versus placebo", "evaluate efficacy").
  - Novel or advanced mechanism, new target, or first-in-class description.
- Biosimilar:
  - Equivalence or non-inferiority language.
  - Direct comparison to a branded reference product with the same active ingredient.
- Generic:
  - Explicitly described as generic.
  - Non-biologic small-molecule copy of an existing branded product.

If the information is incomplete, choose the MOST LIKELY label based on the text and typical drug naming patterns.
You MUST still choose ONE of the three labels ("Innovative", "Generic", "Biosimilar") for each drug.
If you are uncertain, you may say so in the one-sentence explanation.

OUTPUT FORMAT (IMPORTANT)
-------------------------

Return ONLY a valid JSON object, with:
- KEYS   = exactly the investigational product names as provided in the list below
- VALUES = an object with exactly two fields:
    - "classification": one of "Innovative", "Generic", "Biosimilar"
    - "explanation": a single, concise sentence explaining your reasoning

You MUST provide a classification for EVERY investigational product name.

Example output:
{{
  "DrugA": {{
    "classification": "Innovative",
    "explanation": "DrugA is a novel monoclonal antibody targeting a new receptor and is the sponsor's lead product."
  }},
  "DrugB": {{
    "classification": "Biosimilar",
    "explanation": "DrugB is tested for non-inferiority compared to the branded biologic with the same target."
  }}
}}

TRIAL PAYLOAD (includes trial text and all drug-role breakdown columns):
{payload_json}

INVESTIGATIONAL PRODUCTS (you MUST classify EACH of these):
{drugs_json}
""".strip()


# Canonical innovation labels; model replies are matched case-insensitively.
INNOV_LABELS = ("Innovative", "Generic", "Biosimilar")


def _write_json_atomic(path: Path, obj) -> None:
    """Write ``obj`` as pretty UTF-8 JSON to ``path`` via a temp file + rename.

    ``Path.replace`` swaps the fully written temp file into place, so an
    interrupted run does not leave a truncated JSON file at ``path``.
    """
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
    tmp.replace(path)


def process_innov_row(row: dict, idx: int, total: int, breakdown_cols: list[str]) -> None:
    """
    Classify every investigational product of one trial via the LLM.

    Trials whose per-trial JSON already exists are skipped. Otherwise a payload
    is built from ``RELEVANT_COLS`` and ``breakdown_cols``, and the model is
    queried. The result is saved only when every product is covered by a dict
    holding ``explanation`` and a ``classification`` from ``INNOV_LABELS``
    (case-insensitive; normalized to canonical casing before saving).
    Failures are counted in ``innov_counter`` and left unsaved, so the trial is
    retried on the next run. The per-trial JSON is written last because its
    existence marks the trial as done.

    Args:
        row: Merged trial record; reads ``trial_hash``,
            ``investigational_products_parsed`` and the payload columns.
        idx: 1-based position, used in messages only.
        total: Number of trials in the run, used in messages only.
        breakdown_cols: Columns from trial_product_breakdown.csv to include.
    """
    trial_hash = str(row.get("trial_hash", "")).strip()
    if not trial_hash:
        print(f"⚠️ [{idx}/{total}] Missing trial_hash, skipping")
        return

    investigational_products = row.get("investigational_products_parsed") or []
    investigational_products = [str(x).strip() for x in investigational_products if str(x).strip()]

    if not investigational_products:
        # Shouldn't happen due to filtering, but be safe
        return

    out_fp = INNOV_DIR / f"{trial_hash}.json"
    if out_fp.exists():
        with counter_lock:
            innov_counter["skipped_existing"] += 1
        return

    # Payload: trial text fields (raw_trials_with_hash.csv) followed by all
    # breakdown columns (trial_product_breakdown.csv)
    trial_payload = {"trial_hash": trial_hash}
    for col in [*RELEVANT_COLS, *breakdown_cols]:
        trial_payload[col] = row.get(col, "")

    prompt = build_innovation_prompt(trial_payload, investigational_products)

    # Call LLM
    try:
        t0 = time.perf_counter()
        res = client.query(prompt=prompt, model=MODEL)
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)
    except Exception as e:
        print(f"⚠️ [{idx}/{total}] LLM error for trial_hash={trial_hash}: {e}")
        with counter_lock:
            innov_counter["llm_error"] += 1
        return

    # Parse JSON
    classifications = extract_json_object(text_response)

    if not isinstance(classifications, dict) or not classifications:
        print(f"⚠️ [{idx}/{total}] JSON parse error trial_hash={trial_hash}, raw={text_response!r}")
        with counter_lock:
            innov_counter["parse_error"] += 1
        return

    # Check coverage: every investigational product must be present as a key
    missing = [d for d in investigational_products if d not in classifications]
    if missing:
        print(
            f"⚠️ [{idx}/{total}] Coverage error for trial_hash={trial_hash}: "
            f"missing classifications for {missing}"
        )
        with counter_lock:
            innov_counter["coverage_error"] += 1
        # DO NOT save this trial so it can be re-run next time
        return

    # Validate each entry: dict with both fields and one of the allowed labels.
    canonical = {label.lower(): label for label in INNOV_LABELS}
    for d in investigational_products:
        meta = classifications[d]
        if not isinstance(meta, dict):
            print(f"⚠️ [{idx}/{total}] Invalid meta for {d} in trial_hash={trial_hash}")
            with counter_lock:
                innov_counter["parse_error"] += 1
            return
        if "classification" not in meta or "explanation" not in meta:
            print(f"⚠️ [{idx}/{total}] Missing fields for {d} in trial_hash={trial_hash}")
            with counter_lock:
                innov_counter["parse_error"] += 1
            return
        label = canonical.get(str(meta["classification"]).strip().lower())
        if label is None:
            print(
                f"⚠️ [{idx}/{total}] Unexpected classification {meta['classification']!r} "
                f"for {d} in trial_hash={trial_hash}"
            )
            with counter_lock:
                innov_counter["parse_error"] += 1
            return
        meta["classification"] = label

    mapped = {
        "trial_hash": trial_hash,
        "investigational_products": investigational_products,
        "classifications": classifications,
        "source": "llm",
    }

    # Log entry ("token" and "hash_id" both carry the trial hash)
    log_payload = {
        "token": trial_hash,
        "hash_id": trial_hash,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": json.dumps(mapped, ensure_ascii=False, indent=2),
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    _write_json_atomic(INNOV_LOG_DIR / f"{trial_hash}.json", log_payload)

    # Update master (full rewrite under the lock)
    with master_lock:
        master_innov[trial_hash] = mapped
        _write_json_atomic(MASTER_INNOV_PATH, master_innov)

    # Per-trial JSON last: if anything above failed, the trial is retried.
    _write_json_atomic(out_fp, mapped)

    with counter_lock:
        innov_counter["processed"] += 1
        if innov_counter["processed"] % 50 == 0:
            print(f"Progress: processed {innov_counter['processed']} trials for innovation status...")


# -------------------------------------------------
# LOAD & MERGE DATA
# -------------------------------------------------
# Load breakdown (investigational products + all drug-role cols)
df_breakdown = pd.read_csv(PRODUCT_BREAKDOWN_CSV, dtype=str).fillna("")

# parse_listish is defined at the top of this cell
df_breakdown["investigational_products_parsed"] = df_breakdown["investigational_products"].apply(parse_listish)
mask_has_inv = df_breakdown["investigational_products_parsed"].apply(lambda x: len(x) > 0)

# Restrict to rows with investigational products
df_breakdown_sub = df_breakdown.loc[mask_has_inv].copy()

# All columns from trial_product_breakdown.csv except trial_hash (which is already separate)
BREAKDOWN_COLS = [c for c in df_breakdown_sub.columns if c != "trial_hash"]

# Load raw trials (for RELEVANT_COLS). Keep one row per trial_hash so the merge
# cannot duplicate a trial; duplicates would send two workers after the same
# output files.
df_trials = pd.read_csv(TRIALS_WITH_HASH_CSV, dtype=str).fillna("")
n_dupes = int(df_trials["trial_hash"].duplicated().sum())
if n_dupes:
    print(f"⚠️ {n_dupes} duplicate trial_hash rows in {TRIALS_WITH_HASH_CSV}; keeping the first of each")
    df_trials = df_trials.drop_duplicates(subset="trial_hash", keep="first")

# Merge on trial_hash; keep all breakdown columns + investigational_products_parsed + RELEVANT_COLS
df_merged = df_breakdown_sub.merge(
    df_trials[["trial_hash"] + RELEVANT_COLS],
    on="trial_hash",
    how="left",
    indicator=True,
)
n_unmatched = int((df_merged["_merge"] == "left_only").sum())
if n_unmatched:
    print(f"⚠️ {n_unmatched} breakdown rows have no matching trial in {TRIALS_WITH_HASH_CSV}")
df_merged = df_merged.drop(columns="_merge")
# Left-join misses come back as NaN; use "" like the rest of the data so
# prompts never contain NaN.
df_merged = df_merged.fillna({col: "" for col in RELEVANT_COLS})

innov_rows = df_merged.to_dict(orient="records")
total_innov = len(innov_rows)
print(f"Loaded {total_innov} trials with investigational products for innovation-status classification.")

# -------------------------------------------------
# RUN CONCURRENTLY
# -------------------------------------------------

with ThreadPoolExecutor(max_workers=MAX_WORKERS_INNOV) as pool:
    future_to_hash = {}
    for position, record in enumerate(innov_rows, start=1):
        fut = pool.submit(process_innov_row, record, position, total_innov, BREAKDOWN_COLS)
        future_to_hash[fut] = record.get("trial_hash")

    for done in as_completed(future_to_hash):
        trial_id = future_to_hash[done]
        try:
            done.result()
        except Exception as e:
            print(f"⚠️ Worker error (innovation) trial_hash={trial_id}: {e}")

innov_summary = ", ".join(
    f"{label}={innov_counter[key]}"
    for label, key in (
        ("processed", "processed"),
        ("skipped", "skipped_existing"),
        ("llm_error", "llm_error"),
        ("parse_error", "parse_error"),
        ("coverage_error", "coverage_error"),
    )
)
print(f"✅ Trial investigational-drug innovation classification complete. {innov_summary}")
print(f"Classifications directory: {INNOV_DIR}")
print(f"Log directory:             {INNOV_LOG_DIR}")
print(f"Master classifications:    {MASTER_INNOV_PATH}")

Loaded 180 trials with investigational products for innovation-status classification.


In [None]:
import json
from pathlib import Path

LOG_DIR = Path("cache/trial_investigational_drugs_classifications_log")

# (log file name, cost) for every log entry that could be read
costs = []
for log_fp in LOG_DIR.glob("*.json"):
    try:
        entry = json.loads(log_fp.read_text(encoding="utf-8"))
        entry_cost = float(entry.get("total_cost") or 0.0)
    except Exception as e:
        print(f"⚠️ Error reading {log_fp.name}: {e}")
        continue
    costs.append((log_fp.name, entry_cost))

num_entries = len(costs)
total_cost = sum((cost for _, cost in costs), 0.0)

# Most expensive first
costs_sorted = sorted(costs, key=lambda item: item[1], reverse=True)

print("========== LLM COST SUMMARY ==========")
print(f"Total LLM cost:             ${total_cost:,.4f}")
print(f"Number of logged trials:     {num_entries}")
if num_entries:
    print(f"Average cost per trial:      ${total_cost / num_entries:,.4f}")
print("")

print("Top 10 most expensive trials:")
for name, c in costs_sorted[:10]:
    print(f"  {name}: ${c:,.4f}")

print("========================================")

In [None]:
import json
import pandas as pd

OUT_CSV = BASE_DIR / "trial_investigational_drugs_classifications.csv"

rows = []

for fp in sorted(INNOV_DIR.glob("*.json")):
    try:
        obj = json.loads(fp.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"⚠️ Error reading {fp.name}: {e}")
        continue

    trial_hash = obj.get("trial_hash") if isinstance(obj, dict) else None
    if not trial_hash:
        print(f"⚠️ Missing trial_hash in {fp.name}, skipping")
        continue

    inv_products_raw = obj.get("investigational_products") or []
    classifications_map = obj.get("classifications") or {}
    if not isinstance(classifications_map, dict):
        classifications_map = {}

    flat_products = []
    flat_classifications = []

    for drug_raw in inv_products_raw:
        # drug_raw might be "['inetetamab', 'toripalimab']" or just "SSGJ-707"
        if isinstance(drug_raw, str):
            parsed_names = parse_listish(drug_raw)  # defined in an earlier cell (uses ast.literal_eval)
        else:
            parsed_names = [drug_raw]
        # JSON keys are always strings, and literal_eval can turn a bare name
        # such as "123" into an int; keep names as text for lookup and output.
        parsed_names = [str(name) for name in parsed_names]

        # Prefer classification using the exact key that was sent to the model
        meta = classifications_map.get(drug_raw) if isinstance(drug_raw, str) else None
        cls = meta.get("classification", "") if isinstance(meta, dict) else ""

        # If not found, try each parsed name as a key
        if not cls:
            for name in parsed_names:
                meta_n = classifications_map.get(name)
                if isinstance(meta_n, dict) and "classification" in meta_n:
                    cls = meta_n.get("classification", "")
                    break

        if not cls:
            print(
                f"⚠️ Missing classification for raw drug {drug_raw!r} in "
                f"trial_hash={trial_hash}, file={fp.name}"
            )

        # One entry per parsed name, so both lists stay flat and index-aligned
        for name in parsed_names:
            flat_products.append(name)
            flat_classifications.append(cls)

    rows.append(
        {
            "trial_hash": trial_hash,
            # store as JSON stringified flat lists
            "investigational_products": json.dumps(flat_products, ensure_ascii=False),
            "investigational_products_classifications": json.dumps(flat_classifications, ensure_ascii=False),
        }
    )

# Explicit columns: with no readable files, a column-less frame would make
# sort_values("trial_hash") raise KeyError; this yields a header-only CSV instead.
df_out = pd.DataFrame(
    rows,
    columns=["trial_hash", "investigational_products", "investigational_products_classifications"],
).sort_values("trial_hash")

OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
df_out.to_csv(OUT_CSV, index=False)

print(f"✅ Saved investigational drug classifications to {OUT_CSV}")
print(df_out.head().to_markdown(index=False))

#### Task 3

In [None]:
import json
import hashlib
import pandas as pd
from pathlib import Path

# -----------------------------------
# CONFIG
# -----------------------------------
BASE_DIR = Path("cache")
# Input: per-trial product breakdown table (read below with every column as str).
BREAKDOWN_CSV = BASE_DIR / "trial_product_breakdown.csv"
# Output: JSON dict {moa_id: {"moa_id", "mechanism", "source"}} written at the end of this cell.
MOA_MASTER_PATH = BASE_DIR / "investigational_drug_moa_master.json"

# -----------------------------------
# Helper: deterministic hash
# -----------------------------------
def make_hash_id(text: str) -> str:
    """
    Return a stable "mid_<md5 hex>" identifier for a mechanism string.

    The digest is computed over the stripped, lower-cased UTF-8 text, so
    inputs that differ only in surrounding whitespace or letter case share
    one ID.
    """
    canonical = text.strip().lower()
    digest = hashlib.md5(canonical.encode("utf-8")).hexdigest()
    return f"mid_{digest}"

# Reuse parse_listish from an earlier cell if it is already defined; otherwise
# define a local literal_eval-based fallback so this cell also runs on its own.
try:
    parse_listish
except NameError:
    import ast

    def parse_listish(s: str):
        """
        Turn a list-like string such as "['a', 'b']" into a Python list.

        - Non-strings, "", "[]" and "[ ]" (after stripping) give [].
        - A literal that evaluates to a non-list is wrapped: "'x'" -> ["x"].
        - Text that is not a Python literal is returned as [stripped_text].
        """
        if not isinstance(s, str):
            return []
        text = s.strip()
        if text in ("", "[]", "[ ]"):
            return []
        try:
            parsed = ast.literal_eval(text)
        except Exception:
            return [text]
        if isinstance(parsed, list):
            return parsed
        return [parsed]

# -----------------------------------
# LOAD DATA
# -----------------------------------
# Every column is read as str and NaN cells become "", so parse_listish always sees text.
df = pd.read_csv(BREAKDOWN_CSV, dtype=str).fillna("")

# Columns
INV_PRODUCTS_COL = "investigational_products"
INV_MOA_COL      = "investigational_products_mechanism"

MOA_COLS_EXTRA = [
    "active_comparators_mechanism",
    "standard_of_care_mechanism",
]

expected_cols = [INV_PRODUCTS_COL, INV_MOA_COL, *MOA_COLS_EXTRA]
missing_cols = [c for c in expected_cols if c not in df.columns]
if missing_cols:
    print(f"⚠️ Missing expected columns: {missing_cols}")

# Extra MOA columns that actually exist (absent ones are simply skipped).
present_extra_cols = [c for c in MOA_COLS_EXTRA if c in df.columns]

# -----------------------------------
# COLLECT UNIQUE MOAs
# -----------------------------------
unique_moas = set()

for record in df.to_dict("records"):
    # 1) Investigational products: the product and the mechanism at the same
    #    index form a pair; a blank/missing mechanism falls back to the product
    #    name as its "mechanism" stand-in. The shorter list is padded with "".
    products = list(parse_listish(record.get(INV_PRODUCTS_COL, "")))
    mechanisms = list(parse_listish(record.get(INV_MOA_COL, "")))
    width = max(len(products), len(mechanisms))
    products += [""] * (width - len(products))
    mechanisms += [""] * (width - len(mechanisms))

    for product, mechanism in zip(products, mechanisms):
        mech_text = str(mechanism).strip()
        product_text = str(product).strip()
        if mech_text:
            unique_moas.add(mech_text)
        elif product_text:
            unique_moas.add(product_text)

    # 2) Comparator / standard-of-care mechanism columns: taken as-is, no fallback.
    for col in present_extra_cols:
        for item in parse_listish(record.get(col, "")):
            item_text = str(item).strip()
            if item_text:
                unique_moas.add(item_text)

print(f"Found {len(unique_moas)} unique MOA strings (including fallbacks to product names).")

# -----------------------------------
# BUILD MASTER DICT
# -----------------------------------
# make_hash_id() lower-cases its input, so MOA strings that differ only by
# letter case map to the same moa_id. Sorted iteration keeps the result
# deterministic: the variant that sorts last is the one stored. Such merges
# are counted and reported so the final entry count is not a surprise.
moa_master = {}
merged_variants = 0

for moa in sorted(unique_moas):
    hash_id = make_hash_id(moa)
    if hash_id in moa_master:
        merged_variants += 1
    moa_master[hash_id] = {
        "moa_id": hash_id,
        "mechanism": moa,
        "source": "trial_product_breakdown_or_fallback",
    }

if merged_variants:
    print(f"ℹ️ {merged_variants} MOA strings differed from another only by case and share its moa_id.")
print(f"{len(moa_master)} unique moa_ids in master.")

# -----------------------------------
# SAVE MASTER FILE
# -----------------------------------
MOA_MASTER_PATH.parent.mkdir(parents=True, exist_ok=True)
MOA_MASTER_PATH.write_text(
    json.dumps(moa_master, ensure_ascii=False, indent=2),
    encoding="utf-8"
)

print(f"✅ Saved MOA master to {MOA_MASTER_PATH}")

In [None]:
# PubMed search for each MOA (HASH-BASED OUTPUT FILENAMES)

import os, json, time, html, unicodedata
from pathlib import Path
import requests
from xml.etree import ElementTree as ET

from services.openai_wrapper import OpenAIWrapper

# -----------------------------
# Paths / Config
# -----------------------------
BASE_DIR = Path("cache")

# Input: MOA master from the previous cell ({moa_id: {"mechanism": ..., ...}}).
MOA_MASTER_PATH   = BASE_DIR / "investigational_drug_moa_master.json"
# Output: one "<moa_id>.json" PubMed search result per MOA.
OUT_DIR           = BASE_DIR / "investigational_drug_moa_pubmed_search"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Index of finished moa_ids; ids already present here are skipped on rerun.
MASTER_INDEX_PATH = BASE_DIR / "investigational_drug_moa_pubmed_index.json"

# NCBI E-utilities base URL; API key / email are optional and read from the
# environment (an empty value is treated as unset).
EUTILS     = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
API_KEY    = os.getenv("NCBI_API_KEY") or None
EMAIL      = os.getenv("NCBI_EMAIL") or None
SLEEP      = 0.25  # seconds the main loop pauses after each saved MOA
RETRY_MAX  = 3     # attempts per HTTP request / XML parse
RETRY_WAIT = 1.0   # seconds between retry attempts

# LLM config (for MOA refinement)
MODEL  = "gpt-5-mini"   # adjust if needed
client = OpenAIWrapper()

# -----------------------------
# Helpers
# -----------------------------
NAN_STRINGS = {"nan", "none", "null", ""}

def _clean(s):
    if s is None:
        return ""
    s_str = str(s).strip()
    return "" if s_str.lower() in NAN_STRINGS else s_str

def norm_text(s: str) -> str:
    if not isinstance(s, str):
        return ""
    t = html.unescape(s)
    t = unicodedata.normalize("NFKC", t)
    t = " ".join(t.strip().lower().split())
    return "" if t in NAN_STRINGS else t

# Monotonic timestamp of the most recent request sent through _http_get_with_retry.
_last_ncbi_request_ts = 0.0


def _http_get_with_retry(url: str, params: dict, timeout: int, min_interval: float | None = None) -> requests.Response:
    """
    GET `url` with query `params`, throttled and retried.

    Throttling: every attempt (retries included) starts at least
    `min_interval` seconds after the previous request made through this
    helper. The default follows NCBI's E-utilities policy -- at most 3
    requests/second without an API key, 10/second with one -- so
    back-to-back esearch/efetch calls stay within the limit even on code
    paths that do not sleep themselves.

    Retries: up to RETRY_MAX attempts in total (always at least one). The
    pause between attempts doubles each time, starting at RETRY_WAIT
    seconds, which gives HTTP 429/5xx responses room to clear.

    Returns the successful Response; re-raises the last error otherwise.
    """
    global _last_ncbi_request_ts
    interval = min_interval if min_interval is not None else (0.11 if API_KEY else 0.34)
    attempts = max(1, RETRY_MAX)
    for attempt in range(1, attempts + 1):
        wait = interval - (time.monotonic() - _last_ncbi_request_ts)
        if wait > 0:
            time.sleep(wait)
        _last_ncbi_request_ts = time.monotonic()
        try:
            r = requests.get(url, params=params, timeout=timeout)
            r.raise_for_status()
            return r
        except Exception:
            if attempt == attempts:
                raise
            time.sleep(RETRY_WAIT * 2 ** (attempt - 1))

def esearch_ids(term: str, n: int = 3) -> list[str]:
    """
    Run a PubMed esearch for `term` and return up to `n` PMIDs (relevance order).

    A blank/NaN-like term short-circuits to [] without any request. The
    configured NCBI api_key / email are attached only when set.
    """
    cleaned = _clean(term)
    if not cleaned:
        return []
    params = dict(db="pubmed", term=cleaned, retmode="json", retmax=n, sort="relevance")
    for key, value in (("api_key", API_KEY), ("email", EMAIL)):
        if value:
            params[key] = value
    resp = _http_get_with_retry(f"{EUTILS}/esearch.fcgi", params=params, timeout=30)
    search_result = resp.json().get("esearchresult", {})
    return search_result.get("idlist", []) or []

def _parse_xml_with_retry(text: str) -> ET.Element:
    last_err = None
    for attempt in range(1, RETRY_MAX + 1):
        try:
            return ET.fromstring(text)
        except ET.ParseError as e:
            last_err = e
            if attempt < RETRY_MAX:
                time.sleep(RETRY_WAIT)
            else:
                raise last_err

def efetch_details(pmids: list[str]) -> dict:
    if not pmids:
        return {}
    params = {"db": "pubmed", "id": ",".join(pmids), "retmode": "xml"}
    if API_KEY:
        params["api_key"] = API_KEY
    if EMAIL:
        params["email"] = EMAIL
    r = _http_get_with_retry(f"{EUTILS}/efetch.fcgi", params=params, timeout=60)
    root = _parse_xml_with_retry(r.text)

    out = {}

    def text_from_el(el):
        return "".join(el.itertext()).strip() if el is not None else ""

    def join_abstract(abs_parent):
        parts = []
        for t in abs_parent.findall("AbstractText"):
            label = t.attrib.get("Label")
            txt = text_from_el(t)
            if txt:
                parts.append(f"{label}: {txt}" if label else txt)
        return "\n".join(parts).strip()

    for art in root.findall(".//PubmedArticle"):
        pmid_el = art.find(".//MedlineCitation/PMID")
        if pmid_el is None or not (pmid_el.text or "").strip():
            continue
        pmid = pmid_el.text.strip()

        title = text_from_el(art.find(".//Article/ArticleTitle"))
        abs_parent = art.find(".//Article/Abstract")
        abstract = join_abstract(abs_parent) if abs_parent is not None else ""

        mesh_terms = []
        for mh in art.findall(".//MedlineCitation/MeshHeadingList/MeshHeading"):
            desc = mh.find("DescriptorName")
            if desc is None or not (desc.text or "").strip():
                continue
            d_text = desc.text.strip()
            d_major = desc.attrib.get("MajorTopicYN") == "Y"
            d_str = f"{d_text}{'*' if d_major else ''}"

            quals = []
            for q in mh.findall("QualifierName"):
                q_text = (q.text or "").strip()
                if q_text:
                    q_major = q.attrib.get("MajorTopicYN") == "Y"
                    quals.append(f"{q_text}{'*' if q_major else ''}")

            mesh_terms.append(d_str if not quals else d_str + " / " + "; ".join(quals))

        substances = []
        for chem in art.findall(".//Chemical"):
            nm_el = chem.find("NameOfSubstance")
            rn_el = chem.find("RegistryNumber")
            nm = nm_el.text.strip() if nm_el is not None else ""
            rn = rn_el.text.strip() if rn_el is not None else ""
            if nm and rn and rn != "0":
                substances.append(f"{nm} [RN:{rn}]")
            elif nm:
                substances.append(nm)
            elif rn and rn != "0":
                substances.append(f"[RN:{rn}]")

        def uniq(xs):
            seen, out_local = set(), []
            for x in xs:
                if x and x not in seen:
                    seen.add(x)
                    out_local.append(x)
            return out_local

        out[pmid] = {
            "title": title,
            "abstract": abstract,
            "mesh_terms": uniq(mesh_terms),
            "substances": uniq(substances),
        }

    return out

def save_json(path: Path, obj: dict):
    """
    Write `obj` to `path` as pretty-printed UTF-8 JSON, creating parent dirs.

    The JSON is written to a sibling "<name>.tmp" file and then moved over
    `path` with os.replace (atomic when both are on the same filesystem).
    An interrupted run therefore leaves either the previous file or the new
    one, never a half-written index that a later load would treat as empty.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(obj, indent=2, ensure_ascii=False)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(text, encoding="utf-8")
    os.replace(tmp_path, path)

def load_json_or_empty(path: Path) -> dict:
    """
    Load a JSON object (dict) from `path`.

    Returns {} when the file does not exist. When it exists but cannot be
    read/parsed, or holds a non-object (e.g. a list), a warning is printed
    and {} is returned: callers treat the result as a dict, and a silently
    empty index would make the next run redo and overwrite all work.
    """
    path = Path(path)
    if not path.exists():
        return {}
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"⚠️ Could not read JSON from {path}: {e}; treating as empty")
        return {}
    if not isinstance(data, dict):
        print(f"⚠️ Expected a JSON object in {path}, got {type(data).__name__}; treating as empty")
        return {}
    return data

def split_terms(s: str):
    """
    Split a MOA string on ';' into candidate PubMed search terms.

    Example:
      'Thrombopoietin receptor agonist (recombinant growth factor); PEGylated recombinant human EPO'
      -> ['Thrombopoietin receptor agonist (recombinant growth factor)',
          'PEGylated recombinant human EPO']

    Pieces are stripped; empty and NaN-like pieces are dropped. Falsy input gives [].
    """
    if not s:
        return []
    terms = []
    for piece in str(s).split(";"):
        piece = piece.strip()
        if piece and piece.lower() not in NAN_STRINGS:
            terms.append(piece)
    return terms

# --------------- NEW: LLM refinement helpers ---------------

def build_moa_refinement_prompt(mechanism: str) -> str:
    """
    Build the LLM prompt that condenses a free-text MOA into a short,
    canonical mechanism-of-action phrase suitable as a PubMed search term.

    Args:
        mechanism: the MOA description; it is embedded verbatim between
            triple double-quotes inside the prompt.

    Returns:
        The prompt text with leading/trailing whitespace stripped.

    Intended examples (as given to the model in the rules):
      - "Monoclonal antibody–IL-15 fusion (bifunctional immunocytokine) targeting B7-H3"
            -> "Immunocytokines"
      - "Non-absorbed phosphate-binding polymer (anion exchange resin)"
            -> "Ion Exchange Resins"
    """
    # NOTE(review): a mechanism that itself contains '"""' would close the quoted
    # block early; no escaping is done here.
    return f"""
You are an expert clinical pharmacologist and mechanisms-of-action classifier.

Given the following mechanism-of-action (MOA) description from a drug development database:

\"\"\"{mechanism}\"\"\"

Rewrite or condense it into a SHORT, CANONICAL mechanism-of-action term that would work well as a PubMed search term.

Rules:
- Output a concise mechanism class or well-recognized pharmacologic concept, not a full sentence.
- Prefer standard pharmacologic/mechanistic classes (e.g. "Ion Exchange Resins", "Immunocytokines",
  "Kinase Inhibitors", "Antibodies, Monoclonal", "Immune Checkpoint Inhibitors").
- Do NOT include long target listings or extra explanation.
- If the original MOA is already an appropriate concise search term, you may return it unchanged.

Return ONLY the refined mechanism phrase, with no additional explanation or formatting.
""".strip()

def refine_mechanism_with_llm(mechanism: str) -> str | None:
    """
    Ask the LLM (client.query with MODEL) for a concise, PubMed-friendly MOA phrase.

    Returns the cleaned phrase, or None when `mechanism` is blank/NaN-like,
    when the model returns nothing usable, or when the query raises (the
    error is printed, not propagated).
    """
    mech_clean = _clean(mechanism)
    if not mech_clean:
        return None

    prompt = build_moa_refinement_prompt(mech_clean)
    try:
        res = client.query(prompt=prompt, model=MODEL)
        raw_text = res.get("text_response") or ""
        # Drop surrounding whitespace and any quotes the model wraps its answer in.
        unquoted = raw_text.strip().strip('"').strip("'")
        return _clean(unquoted) or None
    except Exception as e:
        print(f"⚠️ LLM refinement failed for mechanism='{mech_clean[:80]}': {e}")
        return None

# -----------------------------
# Load MOA master
# -----------------------------
moa_master = load_json_or_empty(MOA_MASTER_PATH)
if not moa_master:
    raise RuntimeError(f"No MOA entries found in {MOA_MASTER_PATH}")

master_index = load_json_or_empty(MASTER_INDEX_PATH) or {}

total = len(moa_master)
print(f"{total} MOA entries to process")
processed = 0


def _search_and_fetch(term: str) -> tuple[list[str], dict | None]:
    """
    Phrase-search PubMed for `term` (top 5 by relevance) and fetch the hits.

    Returns:
      ([], None)        no hits, or the esearch request failed (error printed)
      (pmids, records)  hits found and their details fetched
      (pmids, None)     hits found but efetch failed (error printed); the
                        caller must not persist this result
    """
    # Embedded double quotes would terminate the phrase early; drop them.
    query = '"' + term.replace('"', " ") + '"'
    try:
        pmids = esearch_ids(query, n=5)
    except Exception as e:
        print(f"⚠️ esearch failed for term={term[:80]!r}: {e}")
        return [], None
    if not pmids:
        return [], None
    try:
        return pmids, efetch_details(pmids)
    except Exception as e:
        print(f"⚠️ efetch failed for pmids={pmids}: {e}")
        return pmids, None


# -----------------------------
# Main loop: one PubMed search per MOA
# -----------------------------
for moa_id, rec in moa_master.items():
    # HASHES ARE PRESERVED — we use moa_id directly as filename and key
    if moa_id in master_index:
        mech = rec.get("mechanism", "")
        print(f"{mech[:60]} || already processed")
        processed += 1
        continue

    mechanism = _clean(rec.get("mechanism", ""))
    if not mechanism:
        print(f"⚠️ Empty mechanism for moa_id={moa_id}, skipping")
        continue

    # Normalized key (for index / dedup purposes, if needed later)
    mech_key = norm_text(mechanism)

    # Candidate search terms: the ';'-separated pieces, else the whole string
    terms = split_terms(mechanism) or [mechanism]

    tried_terms: list[str] = []
    matched_term: str | None = None
    pmids: list[str] = []
    records: dict | None = None
    llm_refined: str | None = None

    # 1) First pass: raw/split terms, stopping at the first term with hits
    for t in terms:
        t_clean = _clean(t)
        if not t_clean:
            continue
        tried_terms.append(t_clean)
        pmids, records = _search_and_fetch(t_clean)
        if pmids:
            matched_term = t_clean
            break

    # 2) If no hits, ask LLM to refine the mechanism and try again
    if not pmids:
        llm_refined = refine_mechanism_with_llm(mechanism)
        if llm_refined:
            tried_terms.append(llm_refined + " [LLM]")
            pmids, records = _search_and_fetch(llm_refined)
            if pmids:
                matched_term = llm_refined

    # 3) Nothing is saved for misses or failed fetches, so a rerun retries them.
    #    A failed fetch used to be stored as {"_error": ...}, which marked the
    #    MOA as done forever and broke downstream record parsing.
    #    These paths also hit the API, so they pause too.
    if not pmids:
        print(f"⚠️ No PubMed hits for moa_id={moa_id} after raw + LLM refinement, skipping")
        time.sleep(SLEEP)
        continue
    if records is None:
        print(f"⚠️ Could not fetch PubMed records for moa_id={moa_id}, skipping (rerun to retry)")
        time.sleep(SLEEP)
        continue

    # -----------------------------
    # HASH-BASED OUTPUT
    # -----------------------------
    fname = f"{moa_id}.json"
    out_path = OUT_DIR / fname

    payload = {
        "type": "moa_pubmed_search",
        "moa_id": moa_id,
        "mechanism": mechanism,
        "mechanism_key": mech_key,
        "tried_terms": tried_terms,
        "llm_refined_mechanism": llm_refined,
        "match": {
            "term": matched_term,
            "pmids": pmids,
            "records": records,
        },
    }

    save_json(out_path, payload)

    master_index[moa_id] = {
        "mechanism": mechanism,
        "mechanism_key": mech_key,
        "json_path": f"{OUT_DIR.name}/{fname}",
        "pmids": pmids,
        "matched_term": matched_term,
        "llm_refined_mechanism": llm_refined,
    }
    save_json(MASTER_INDEX_PATH, master_index)

    processed += 1
    if processed % 50 == 0:
        print(f"Processed {processed}/{total}…")

    time.sleep(SLEEP)

save_json(MASTER_INDEX_PATH, master_index)
print(f"✅ Completed {processed} MOA entries with at least one PubMed hit. Files written to {OUT_DIR}")


In [None]:
import json
import re
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

from services.openai_wrapper import OpenAIWrapper  # your wrapper

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
BASE_DIR = Path("cache")

# Input: per-MOA PubMed search JSON files written by the previous cell.
MOA_PUBMED_DIR        = BASE_DIR / "investigational_drug_moa_pubmed_search"
# Outputs: one chosen-MeSH JSON per moa_id (an existing file makes that MOA be
# skipped), one LLM call log per moa_id, and a combined master dict of choices.
MOA_CHOICE_DIR        = BASE_DIR / "investigational_drug_moa_chosen"
MOA_CHOICE_LOG_DIR    = BASE_DIR / "investigational_drug_moa_chosen_log"
MASTER_MOA_CHOICES_PATH = BASE_DIR / "investigational_drug_moa_chosen_master.json"

MOA_CHOICE_DIR.mkdir(parents=True, exist_ok=True)
MOA_CHOICE_LOG_DIR.mkdir(parents=True, exist_ok=True)

MAX_WORKERS_MOA = 8  # thread-pool size for concurrent per-file LLM calls
MODEL = "gpt-5"   # adjust if needed

# One client shared by all worker threads (assumes OpenAIWrapper is thread-safe — TODO confirm).
client = OpenAIWrapper()

# -------------------------------------------------
# Helpers
# -------------------------------------------------
def extract_json_object(text: str) -> dict:
    """
    Extract the first JSON object (dict) embedded in model output.

    The whole string is tried first. Otherwise each "{" is visited in order
    and json.JSONDecoder.raw_decode parses one complete value starting there;
    the first dict found is returned. Unlike a greedy "{.*}" regex, this
    still works when the object is surrounded by prose or code fences, or is
    followed by more text containing braces.

    Returns {} for non-strings, blank text, or when no JSON object is found.
    """
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}

    try:
        obj = json.loads(text)
    except (ValueError, RecursionError):
        pass
    else:
        if isinstance(obj, dict):
            return obj

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)

    return {}

def load_master_moa_choices() -> dict:
    """
    Load the master {moa_id: choice} dict, backfilled from per-MOA choice files.

    process_moa_file skips any MOA whose "<moa_id>.json" already exists in
    MOA_CHOICE_DIR. If the master file is missing, unreadable, not a JSON
    object, or lacks entries, those MOAs would otherwise never reach the
    master again. Every readable per-MOA file whose stem is missing from the
    master is therefore re-added here. Unreadable files print a warning
    instead of being silently ignored.
    """
    master: dict = {}
    if MASTER_MOA_CHOICES_PATH.exists():
        try:
            loaded = json.loads(MASTER_MOA_CHOICES_PATH.read_text(encoding="utf-8"))
        except Exception as e:
            print(f"⚠️ Could not read {MASTER_MOA_CHOICES_PATH}: {e}; rebuilding from per-MOA files")
        else:
            if isinstance(loaded, dict):
                master = loaded
            else:
                print(f"⚠️ {MASTER_MOA_CHOICES_PATH} is not a JSON object; rebuilding from per-MOA files")

    backfilled = 0
    for fp in sorted(MOA_CHOICE_DIR.glob("*.json")):
        moa_id = fp.stem  # per-MOA files are named "<moa_id>.json"
        if moa_id in master:
            continue
        try:
            entry = json.loads(fp.read_text(encoding="utf-8"))
        except Exception as e:
            print(f"⚠️ Could not read choice file {fp.name}: {e}")
            continue
        if isinstance(entry, dict):
            master[moa_id] = entry
            backfilled += 1

    if backfilled:
        print(f"ℹ️ Backfilled {backfilled} MOA choices from {MOA_CHOICE_DIR}")
    return master

def build_moa_mesh_prompt(moa_payload: dict, candidate_mesh_terms: list[str]) -> str:
    """
    Build the prompt asking the LLM to pick the single best MeSH term for an MOA.

    The model must choose a term verbatim from `candidate_mesh_terms`, or
    return the sentinel "[none]" (with source_pmid null) when no candidate
    fits. It must answer with a JSON object holding exactly
    "chosen_mesh_term", "source_pmid" and "rationale".

    Args:
        moa_payload: condensed MOA context; serialized as indented JSON into
            the "MOA Payload" section.
        candidate_mesh_terms: the allowed choices; serialized as a JSON list
            into the "Candidate MeSH Terms" section.

    Returns:
        The prompt text with leading/trailing whitespace stripped.
    """

    payload_json = json.dumps(moa_payload, ensure_ascii=False, indent=2)
    mesh_json    = json.dumps(candidate_mesh_terms, ensure_ascii=False, indent=2)

    # Doubled braces ({{ }}) in the f-string render as literal JSON braces.
    return f"""
You are an expert pharmacologist and MeSH annotation specialist.

You are given:
1) A mechanism-of-action (MOA) text string describing how a drug works.
2) A set of PubMed-derived MeSH terms (candidate list).
3) Condensed PubMed records used for MOA search.

Your tasks:

------------------------------------------------------------
TASK 1 — Select the Best MeSH Term
------------------------------------------------------------
Choose EXACTLY ONE MeSH term that best represents the mechanism of action.

Rules:
- You MUST select a term *only* from the candidate list.
- Choose the most mechanistic/specific pharmacologic concept available
  (e.g., "Receptor Antagonists", "Antibodies, Monoclonal", "Kinase Inhibitors").
- Avoid generic terms ("Humans", "Adult", "Neoplasms") unless absolutely no mechanistic term exists.

------------------------------------------------------------
TASK 2 — Handle Cases with No Good Mechanistic Term
------------------------------------------------------------
If NONE of the candidate MeSH terms meaningfully represent the MOA:

You MUST output:

  "chosen_mesh_term": "[none]",
  "source_pmid": null,
  "rationale": "Explain why no term fits."

This is a VALID and EXPECTED outcome.

------------------------------------------------------------
OUTPUT FORMAT  (STRICT)
------------------------------------------------------------

Return ONLY a valid JSON object with EXACTLY these fields:

{{
  "chosen_mesh_term": "<one exact candidate term OR '[none]'>",
  "source_pmid": "<PMID you relied on OR null>",
  "rationale": "One concise sentence explaining your decision."
}}

Constraints:
- If you choose a MeSH term, it MUST MATCH EXACTLY one item from the candidate list.
- If no suitable term exists, return "[none]".
- JSON must be valid and parseable.

------------------------------------------------------------
MOA Payload (input data)
------------------------------------------------------------
{payload_json}

------------------------------------------------------------
Candidate MeSH Terms
------------------------------------------------------------
{mesh_json}
""".strip()


# Shared across worker threads: the master dict of choices (guarded by
# master_moa_lock) and per-outcome tallies (guarded by moa_counter_lock).
master_moa_choices = load_master_moa_choices()
master_moa_lock = threading.Lock()

moa_counter = dict.fromkeys(
    (
        "processed",
        "skipped_existing",
        "llm_error",
        "parse_error",
        "coverage_error",  # includes "chosen term not in JSON-derived list"
        "no_candidates",
    ),
    0,
)
moa_counter_lock = threading.Lock()


def _bump_moa_counter(key: str) -> None:
    """Increment moa_counter[key] while holding moa_counter_lock."""
    with moa_counter_lock:
        moa_counter[key] += 1


def _collect_candidate_mesh_terms(records: dict) -> list[str]:
    """Unique non-empty MeSH strings across the record dicts in `records`, in first-seen order."""
    candidates: list[str] = []
    seen: set[str] = set()
    for rec in records.values():
        for term in rec.get("mesh_terms") or []:
            if term and term not in seen:
                seen.add(term)
                candidates.append(term)
    return candidates


def process_moa_file(fp: Path, idx: int, total: int) -> None:
    """
    Choose one MeSH term for a single MOA PubMed-search JSON file via the LLM.

    Steps: read `fp`, collect candidate MeSH terms from its records, ask MODEL
    to pick one (or "[none]"), and validate that a real pick is an exact
    candidate. Then write the call log, update the master dict/file, and
    finally write the per-MOA choice file. That file is written last because
    its existence is what makes later runs skip this MOA; a failure at any
    earlier step leaves the MOA to be retried.

    Per-file problems are printed and tallied in moa_counter, not raised.
    `idx`/`total` are used only in log messages.
    """
    try:
        payload = json.loads(fp.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"⚠️ [{idx}/{total}] Error reading {fp.name}: {e}")
        _bump_moa_counter("parse_error")
        return

    moa_id = payload.get("moa_id") or fp.stem
    mechanism = payload.get("mechanism", "") or ""

    if not moa_id:
        print(f"⚠️ [{idx}/{total}] Missing moa_id in {fp.name}, skipping")
        return

    out_fp = MOA_CHOICE_DIR / f"{moa_id}.json"
    if out_fp.exists():
        _bump_moa_counter("skipped_existing")
        return

    match = payload.get("match") or {}
    raw_records = match.get("records") or {}
    pmids = match.get("pmids") or []

    # Keep only real article records. A failed efetch may have been saved as
    # {"_error": "<message>"}; calling .get() on that string value would crash.
    if not isinstance(raw_records, dict):
        raw_records = {}
    records = {pmid: rec for pmid, rec in raw_records.items() if isinstance(rec, dict)}

    # Candidate MeSH terms come FROM THE JSON ONLY, exactly as stored there
    # (including any "*" major-topic marks and " / qualifier" suffixes).
    candidate_terms = _collect_candidate_mesh_terms(records)
    if not candidate_terms:
        print(f"⚠️ [{idx}/{total}] No candidate MeSH terms for moa_id={moa_id}, skipping")
        _bump_moa_counter("no_candidates")
        return

    # Condensed payload for the model (avoid full abstracts to save tokens)
    condensed_records = {
        pmid: {
            "title": (rec.get("title") or ""),
            "mesh_terms": (rec.get("mesh_terms") or []),
        }
        for pmid, rec in records.items()
    }

    moa_payload = {
        "moa_id": moa_id,
        "mechanism": mechanism,
        "tried_terms": payload.get("tried_terms") or [],
        "pmids": pmids,
        "records": condensed_records,
    }

    prompt = build_moa_mesh_prompt(moa_payload, candidate_terms)

    # Call LLM
    try:
        t0 = time.perf_counter()
        res = client.query(prompt=prompt, model=MODEL)
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)
    except Exception as e:
        print(f"⚠️ [{idx}/{total}] LLM error for moa_id={moa_id}: {e}")
        _bump_moa_counter("llm_error")
        return

    # Parse JSON
    obj = extract_json_object(text_response)
    if not isinstance(obj, dict) or not obj:
        print(f"⚠️ [{idx}/{total}] JSON parse error moa_id={moa_id}, raw={text_response!r}")
        _bump_moa_counter("parse_error")
        return

    chosen_term = obj.get("chosen_mesh_term")
    rationale = obj.get("rationale")
    source_pmid = obj.get("source_pmid")
    # PMIDs in `pmids` are strings, but models sometimes emit them as numbers.
    if source_pmid is not None:
        source_pmid = str(source_pmid).strip() or None

    # HARD CHECK: chosen term must be a non-blank string
    if not isinstance(chosen_term, str) or not chosen_term.strip():
        print(f"⚠️ [{idx}/{total}] Missing or invalid chosen_mesh_term for moa_id={moa_id}")
        _bump_moa_counter("coverage_error")
        return
    chosen_term = chosen_term.strip()

    # "[none]" is the allowed sentinel for "no good term" and is accepted as-is.
    # Any real term MUST come from the JSON-derived candidate list.
    if chosen_term != "[none]":
        if chosen_term not in candidate_terms:
            # DNE in JSON (hallucinated or modified term) → reject, do NOT save
            print(
                f"⚠️ [{idx}/{total}] chosen_mesh_term not in JSON-derived candidate list "
                f"for moa_id={moa_id}: {chosen_term!r}"
            )
            _bump_moa_counter("coverage_error")
            return

        # Soft check: source_pmid should be one of the searched PMIDs (or None)
        if source_pmid is not None and source_pmid not in {str(p) for p in pmids}:
            print(
                f"⚠️ [{idx}/{total}] source_pmid {source_pmid!r} not in pmids for moa_id={moa_id}; "
                f"still accepting chosen_mesh_term"
            )

    mapped = {
        "moa_id": moa_id,
        "mechanism": mechanism,
        "candidate_mesh_terms": candidate_terms,
        "chosen_mesh_term": chosen_term,
        "source_pmid": source_pmid,
        "rationale": rationale,
        "source": "llm",
    }
    mapped_json = json.dumps(mapped, ensure_ascii=False, indent=2)

    # Log entry (token and hash_id are both the moa_id)
    log_payload = {
        "token": moa_id,
        "hash_id": moa_id,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": mapped_json,
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    (MOA_CHOICE_LOG_DIR / f"{moa_id}.json").write_text(
        json.dumps(log_payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    # Update master: write to a temp file and rename so a crash mid-write
    # cannot leave an unparseable master behind.
    with master_moa_lock:
        master_moa_choices[moa_id] = mapped
        tmp_master = MASTER_MOA_CHOICES_PATH.with_name(MASTER_MOA_CHOICES_PATH.name + ".tmp")
        tmp_master.write_text(
            json.dumps(master_moa_choices, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        tmp_master.replace(MASTER_MOA_CHOICES_PATH)

    # Save per-MOA JSON last: its existence is the "done" marker checked above.
    out_fp.write_text(mapped_json, encoding="utf-8")

    with moa_counter_lock:
        moa_counter["processed"] += 1
        if moa_counter["processed"] % 50 == 0:
            print(f"Progress: processed {moa_counter['processed']} MOA entries...")


# -------------------------------------------------
# RUN CONCURRENTLY OVER MOA PUBMED SEARCH FILES
# -------------------------------------------------
moa_files = sorted(MOA_PUBMED_DIR.glob("*.json"))
total_moa = len(moa_files)
print(f"Loaded {total_moa} MOA PubMed-search files for MeSH-term selection.")

# Fan the files out to the pool. Each future remembers its file name so a crash
# inside one worker is reported without stopping the others.
with ThreadPoolExecutor(max_workers=MAX_WORKERS_MOA) as ex:
    futures = {}
    for idx, fp in enumerate(moa_files, start=1):
        futures[ex.submit(process_moa_file, fp, idx, total_moa)] = fp.name

    for fut in as_completed(futures):
        try:
            fut.result()
        except Exception as e:
            print(f"⚠️ Worker error (MOA MeSH selection) file={futures[fut]}: {e}")

# (printed label, moa_counter key) pairs for the summary line
summary_fields = (
    ("processed", "processed"),
    ("skipped", "skipped_existing"),
    ("llm_error", "llm_error"),
    ("parse_error", "parse_error"),
    ("coverage_error", "coverage_error"),
    ("no_candidates", "no_candidates"),
)
summary = ", ".join(f"{label}={moa_counter[key]}" for label, key in summary_fields)
print(f"✅ MOA MeSH-term selection complete. {summary}")
print(f"Chosen MOA directory: {MOA_CHOICE_DIR}")
print(f"Log directory:        {MOA_CHOICE_LOG_DIR}")
print(f"Master choices:       {MASTER_MOA_CHOICES_PATH}")

In [None]:
from pathlib import Path
import requests

BASE_URL = "https://nlmpubs.nlm.nih.gov/projects/mesh/MESH_FILES/xmlmesh"
OUT_DIR = Path("cache")
OUT_DIR.mkdir(parents=True, exist_ok=True)

FILES = ["desc2025.xml", "supp2025.xml"]
CHUNK_SIZE = 1 << 20  # 1 MiB per streamed chunk

for fname in FILES:
    url = f"{BASE_URL}/{fname}"
    out_path = OUT_DIR / fname

    # Skip if already downloaded. Only complete downloads are ever renamed to
    # out_path, so an existing non-empty file is a finished one.
    if out_path.exists() and out_path.stat().st_size > 0:
        print(f"Skipping {fname}, already exists.")
        continue

    print(f"⬇Downloading {url} -> {out_path}")
    # Stream to a ".part" file instead of holding the whole (large) XML in
    # memory, and rename only after the download finished.
    part_path = out_path.with_name(out_path.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with part_path.open("wb") as fh:
                for chunk in r.iter_content(chunk_size=CHUNK_SIZE):
                    fh.write(chunk)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a partial file behind.
        part_path.unlink(missing_ok=True)
        raise
    part_path.replace(out_path)
    print(f"✅ Downloaded {fname}")

print("Done.")

In [None]:
import os, xml.etree.ElementTree as ET
import html, unicodedata

# MeSH XML files downloaded by the previous cell: descriptors and
# supplementary concept records (SCRs).
DESC_XML  = "cache/desc2025.xml"
SUPP_XML  = "cache/supp2025.xml"

# Ensure output/ exists (not referenced in the visible part of this cell —
# presumably used further below; TODO confirm).
os.makedirs("output", exist_ok=True)

# Typographic quotes, primes and dashes folded to ASCII by both normalizers
# below. Applied after NFKC, which has already expanded U+2033 (double prime)
# into two U+2032 primes.
_PUNCT_TO_ASCII = str.maketrans({
    "\u2018": "'", "\u2019": "'", "\u2032": "'",
    "\u201C": '"', "\u201D": '"', "\u2033": '"',
    "\u2010": "-", "\u2011": "-", "\u2012": "-", "\u2013": "-", "\u2014": "-",
})

def _unicode_fold(s: str) -> str:
    """HTML-unescape, NFKC-normalize and ASCII-fold typographic punctuation (case preserved)."""
    return unicodedata.normalize("NFKC", html.unescape(s)).translate(_PUNCT_TO_ASCII)

def norm(s: str) -> str:
    """
    Lookup key for MeSH terms: Unicode/punctuation-folded, lower-cased,
    whitespace collapsed to single spaces. Non-strings give "".
    """
    if not isinstance(s, str):
        return ""
    return " ".join(_unicode_fold(s).lower().split())

def clean_text(s: str) -> str:
    """Unicode-clean + collapse whitespace (preserve case). Non-strings give ""."""
    if not isinstance(s, str):
        return ""
    return " ".join(_unicode_fold(s).split())

def _dedup(seq):
    seen = set(); out = []
    for x in seq:
        if x and x not in seen:
            seen.add(x); out.append(x)
    return out

def _extract_scope_note_from_descriptor(rec: ET.Element) -> str:
    """
    Prefer the ScopeNote of the PreferredConcept (PreferredConceptYN='Y'),
    else fall back to the first ScopeNote present under any Concept.
    """
    # Preferred concept first
    pref = rec.find(".//ConceptList/Concept[@PreferredConceptYN='Y']/ScopeNote")
    if pref is not None and pref.text:
        return clean_text(pref.text)

    # Any concept scope note as fallback
    any_sn = rec.find(".//ConceptList/Concept/ScopeNote")
    if any_sn is not None and any_sn.text:
        return clean_text(any_sn.text)

    return ""

def _extract_scope_note_from_supp(rec: ET.Element) -> str:
    """
    For SCRs, ScopeNote can also live under Concept.
    Prefer the PreferredConcept (if flagged), else the first available.
    """
    pref = rec.find(".//ConceptList/Concept[@PreferredConceptYN='Y']/ScopeNote")
    if pref is not None and pref.text:
        return clean_text(pref.text)

    any_sn = rec.find(".//ConceptList/Concept/ScopeNote")
    if any_sn is not None and any_sn.text:
        return clean_text(any_sn.text)

    return ""

def load_mesh_tree_and_id(desc_xml_fp: str, supp_xml_fp: str) -> dict[str, dict[str, list[str] | str]]:
    """
    Build a lookup from normalized MeSH term strings to their MeSH UI, tree numbers and scope note.

    Descriptor records are indexed first (the heading plus every Term/String under each Concept).
    Supplementary Concept Records (SCRs) are indexed second. They inherit tree numbers from the
    descriptors named in their HeadingMappedTo. They also inherit a scope note when they lack
    their own.

    Args:
        desc_xml_fp: path to the MeSH descriptor XML; silently skipped if the file does not exist.
        supp_xml_fp: path to the MeSH supplementary-record XML; silently skipped if missing.

    Returns:
        {norm(term): {"mesh_id": <UI>, "tree_numbers": [...], "scope_note": <str>}}

    Precedence / caveats:
        - Among descriptors, a term shared by several records keeps the LAST record parsed.
        - An SCR name never replaces an entry whose mesh_id starts with "D" (a descriptor UI).
          Among SCRs, the last record parsed wins.
        - All terms of one record share the same tree_numbers list object; treat it as read-only.
        - NOTE(review): both XML files are loaded fully into memory with ET.parse. The full MeSH
          exports are large, so ET.iterparse may be worth considering.
    """
    # term_map[normalized_term] = {"mesh_id": <UI>, "tree_numbers": [..], "scope_note": <str>}
    term_map: dict[str, dict[str, list[str] | str]] = {}

    # Helper maps for fallbacks/joins (filled from descriptors, consumed by the SCR pass)
    heading_to_tree: dict[str, list[str]] = {}
    ui_to_tree: dict[str, list[str]] = {}
    heading_to_scope: dict[str, str] = {}
    ui_to_scope: dict[str, str] = {}

    # --- Descriptors ---
    if os.path.exists(desc_xml_fp):
        root = ET.parse(desc_xml_fp).getroot()
        for rec in root.findall(".//DescriptorRecord"):
            desc_ui = (rec.findtext("DescriptorUI") or "").strip()
            tree_numbers = _dedup([tn.text.strip() for tn in rec.findall(".//TreeNumberList/TreeNumber") if tn.text])

            heading_raw = rec.findtext("DescriptorName/String")
            heading_norm = norm(heading_raw) if heading_raw else ""
            scope_note = _extract_scope_note_from_descriptor(rec)

            # Register this descriptor for the SCR HeadingMappedTo fallbacks (by name and by UI)
            if heading_norm:
                heading_to_tree[heading_norm] = tree_numbers
                heading_to_scope[heading_norm] = scope_note
            if desc_ui:
                ui_to_tree[desc_ui] = tree_numbers
                ui_to_scope[desc_ui] = scope_note

            # Collect all terms mapped to this descriptor: the heading plus every
            # Term/String under every Concept
            terms = set()
            if heading_raw:
                terms.add(heading_norm)
            for concept in rec.findall(".//Concept"):
                for term in concept.findall(".//Term"):
                    s = term.findtext("String")
                    if s:
                        terms.add(norm(s))

            # Later descriptors overwrite earlier ones for the same normalized term
            for term in terms:
                term_map[term] = {
                    "mesh_id": desc_ui,
                    "tree_numbers": tree_numbers,
                    "scope_note": scope_note
                }

    # --- Supplementary (SCRs) ---
    if os.path.exists(supp_xml_fp):
        root = ET.parse(supp_xml_fp).getroot()
        for rec in root.findall(".//SupplementalRecord"):
            supp_ui = (rec.findtext("SupplementalRecordUI") or "").strip()

            # Collect ALL names for this SCR (record name, concept names, concept terms)
            names = set()

            for s in rec.findall(".//SupplementalRecordName/String"):
                if s is not None and s.text:
                    names.add(norm(s.text))

            for s in rec.findall(".//ConceptList/Concept/ConceptName/String"):
                if s is not None and s.text:
                    names.add(norm(s.text))

            for s in rec.findall(".//ConceptList/Concept/TermList/Term/String"):
                if s is not None and s.text:
                    names.add(norm(s.text))

            # Direct trees (often none for SCRs)
            # NOTE(review): not de-duplicated here; _dedup only runs in the fallback branch below.
            tree_numbers = [tn.text.strip() for tn in rec.findall(".//TreeNumberList/TreeNumber") if tn.text]

            # SCR scope note (preferred concept first)
            scr_scope_note = _extract_scope_note_from_supp(rec)

            # Fallback via HeadingMappedTo (names and UIs)
            # NOTE(review): the mapped scope-note fallback also lives inside this branch. It therefore
            # only runs when the SCR has no direct tree numbers. An SCR with its own trees but no
            # scope note gets no mapped note.
            mapped_scope_note = ""
            if not tree_numbers:
                # Try mapped names (descriptor headings, normalized the same way as heading_to_tree keys)
                mapped_names = [n.text.strip() for n in rec.findall(".//HeadingMappedTo/DescriptorReferredTo/DescriptorName/String") if n is not None and n.text]
                for m in mapped_names:
                    m_norm = norm(m)
                    tns = heading_to_tree.get(m_norm)
                    if tns:
                        tree_numbers.extend(tns)
                    # First non-empty mapped scope note wins
                    if not mapped_scope_note and m_norm in heading_to_scope and heading_to_scope[m_norm]:
                        mapped_scope_note = heading_to_scope[m_norm]

                # Try mapped UIs. A leading "*" is stripped so the UI matches ui_to_tree keys
                # (presumably MeSH's major-topic marker; confirm against the XML).
                mapped_uis = [u.text.strip().lstrip("*") for u in rec.findall(".//HeadingMappedTo/DescriptorReferredTo/DescriptorUI") if u is not None and u.text]
                for mui in mapped_uis:
                    tns = ui_to_tree.get(mui)
                    if tns:
                        tree_numbers.extend(tns)
                    if not mapped_scope_note and mui in ui_to_scope and ui_to_scope[mui]:
                        mapped_scope_note = ui_to_scope[mui]

                # Name- and UI-based lookups usually hit the same descriptor, so de-duplicate
                tree_numbers = _dedup(tree_numbers)

            # The SCR's own scope note takes priority over one inherited from a mapped descriptor
            final_scope_note = scr_scope_note or mapped_scope_note or ""

            for name in names:
                # Keep Descriptor mapping if already present for same term
                if name in term_map and str(term_map[name].get("mesh_id", "")).startswith("D"):
                    continue
                term_map[name] = {
                    "mesh_id": supp_ui,
                    "tree_numbers": tree_numbers,
                    "scope_note": final_scope_note
                }

    return term_map

# -------- Run: build the normalized-term → MeSH index once for the later cells --------
TREE_INDEX = load_mesh_tree_and_id(DESC_XML, SUPP_XML)
print(f"✅ Loaded MeSH index terms: {len(TREE_INDEX):,} unique normalized terms")

# Quick checks: a generic name, a development code and a brand name
for probe in ("sotatercept", "ACE-011", "winrevair"):
    print(probe, "→", TREE_INDEX.get(norm(probe)))

In [None]:
import json
import ast
from pathlib import Path
import pandas as pd

BASE_DIR = Path("cache")

# Input / output trial tables for this cell
TRIALS_IN_PATH = BASE_DIR / "trial_product_breakdown.csv"
TRIALS_OUT_PATH = BASE_DIR / "trial_mechanism_mesh_mapping.csv"

# Curated mechanism → chosen-MeSH-heading records
MASTER_MOA_CHOICES_PATH = BASE_DIR / "investigational_drug_moa_chosen_master.json"

# -----------------------------------------
# Sanity: this cell reuses TREE_INDEX and norm() from the MeSH loader cell
# -----------------------------------------
if "TREE_INDEX" not in globals():
    raise RuntimeError("TREE_INDEX is not defined — run the MeSH loader cell first.")
if "norm" not in globals():
    raise RuntimeError("norm() is not defined — ensure it is defined in the MeSH loader cell.")

# -----------------------------------------
# Load MOA → MeSH choices
# -----------------------------------------
def load_master_moa_choices() -> dict:
    if not MASTER_MOA_CHOICES_PATH.exists():
        return {}
    try:
        return json.loads(MASTER_MOA_CHOICES_PATH.read_text(encoding="utf-8"))
    except Exception:
        return {}

master_moa_choices = load_master_moa_choices()

# Map from mechanism string → choice record
# (mechanism strings may be either true MOAs or drug names like "narfurine hydrochloride").
# Keys are stripped because map_mechanism_list() strips the strings it looks up;
# unstripped keys with stray whitespace could otherwise never be matched.
mechanism_to_choice: dict[str, dict] = {}
for rec in master_moa_choices.values():
    if not isinstance(rec, dict):
        continue  # malformed entry: nothing usable to map
    mech = rec.get("mechanism")
    if isinstance(mech, str) and mech.strip():
        # first occurrence wins, keep deterministic mapping
        mechanism_to_choice.setdefault(mech.strip(), rec)

# -----------------------------------------
# Helpers
# -----------------------------------------
def parse_listish(x):
    """
    Coerce one cell value into a list.

    - lists are returned unchanged; NaN/None and blank strings give [];
    - a string holding a Python list literal is parsed into that list;
    - anything else (a non-list literal, unparseable text) becomes a
      one-item list containing the stripped text.
    """
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    text = str(x).strip()
    if not text:
        return []
    try:
        parsed = ast.literal_eval(text)
    except Exception:
        parsed = None
    return parsed if isinstance(parsed, list) else [text]


def mesh_heading_to_tree_numbers(chosen_mesh_term: str):
    """
    Return *all* tree numbers for the chosen MeSH heading.

    The heading part is the text before " / " (dropping any qualifier list), with
    "*" markers removed. It is normalized with norm() and looked up in TREE_INDEX.
    Blank or "[none]" headings, and headings absent from the index, give [].
    """
    if not chosen_mesh_term or chosen_mesh_term == "[none]":
        return []
    heading = chosen_mesh_term.split(" / ")[0]
    entry = TREE_INDEX.get(norm(heading.replace("*", "").strip()))
    if entry:
        return entry.get("tree_numbers", []) or []
    return []


# Clinical-pharmacology-ish heuristic for ONE primary tree number
PRIORITY_PREFIXES = [
    "D12.",  # Proteins: receptors, enzymes, cytokines, antibodies (biologics / targets)
    "D27.",  # Chemical Actions and Uses: classic pharmacologic classes
    "D02.",  # Organic Chemicals: small-molecule drugs
    "D09.",  # Carbohydrates
    "D23.",  # Immunologic Factors
    "D26.",  # Biological Factors
]

def _depth(tn: str) -> int:
    # fewer segments = higher-level class
    return len(tn.split("."))

def choose_primary_tree(tree_numbers):
    """
    Given a list of tree numbers for ONE MeSH term,
    choose a single 'primary' tree that best reflects the
    pharmacologic / target-level concept.
    """
    if not tree_numbers:
        return ""

    tns = [t.strip() for t in tree_numbers if isinstance(t, str) and t.strip()]
    if not tns:
        return ""

    # 1) Prefer specific high-value branches (by prefix)
    for prefix in PRIORITY_PREFIXES:
        candidates = [t for t in tns if t.startswith(prefix)]
        if candidates:
            # choose the highest-level (shortest depth) node in that branch
            return max(candidates, key=_depth)

    # 2) Else prefer any Chemicals & Drugs branch (D*)
    d_candidates = [t for t in tns if t.startswith("D")]
    if d_candidates:
        return max(d_candidates, key=_depth)

    # 3) Fallback: shortest overall
    return max(tns, key=_depth)


def map_mechanism_list(mech_list, fallback_terms=None):
    """
    Map a list of MOA strings to chosen MeSH headings and tree numbers.

    For each slot i of `mech_list`:
      - a non-empty mechanism that is not "[none]" is used as the lookup key;
      - otherwise, when `fallback_terms` is provided, fallback_terms[i] (e.g. the
        investigational product name) is used instead, if it is non-empty and not "[none]";
      - the key is looked up in mechanism_to_choice; its "chosen_mesh_term" is
        resolved to tree numbers via mesh_heading_to_tree_numbers().
    Blank / "[none]" slots are never looked up, and neither are non-string
    entries (both are treated as missing).

    Returns:
      mapped_terms   : list of MeSH headings (one per mechanism, or "")
      all_tree_lists : list of [list-of-tree-numbers] per mechanism
      primary_trees  : list of ONE chosen tree number per mechanism ("" if none)
    All three lists have len(mech_list) entries.
    """
    mapped_terms   = []
    all_tree_lists = []
    primary_trees  = []

    fallback_terms = fallback_terms or []

    def _usable(value) -> str:
        # Stripped text, or "" for non-strings and the "[none]" placeholder.
        text = value.strip() if isinstance(value, str) else ""
        return "" if text == "[none]" else text

    for idx, mech in enumerate(mech_list or []):
        mech_str = _usable(mech)

        # If mechanism is missing / "[none]", try fallback = investigational product name
        if not mech_str and idx < len(fallback_terms):
            mech_str = _usable(fallback_terms[idx])

        choice = mechanism_to_choice.get(mech_str) if mech_str else None
        if not choice:
            # Nothing usable, or no mapping found in master MOA choices
            mapped_terms.append("")
            all_tree_lists.append([])
            primary_trees.append("")
            continue

        chosen = choice.get("chosen_mesh_term") or ""
        all_trees = mesh_heading_to_tree_numbers(chosen)

        mapped_terms.append(chosen)
        all_tree_lists.append(all_trees)
        primary_trees.append(choose_primary_tree(all_trees))

    return mapped_terms, all_tree_lists, primary_trees


# -----------------------------------------
# Load trial dataset
# -----------------------------------------
df = pd.read_csv(TRIALS_IN_PATH)

# -----------------------------------------
# Map each trial's mechanism lists to MeSH. Every category collects three
# parallel per-row lists:
#   *_mapped        → chosen MeSH heading per mechanism ("" if unmapped)
#   *_trees_all     → every tree number of that heading
#   *_trees_primary → the single primary tree number
# -----------------------------------------
inv_mapped, inv_trees_all, inv_trees_primary = [], [], []
ac_mapped, ac_trees_all, ac_trees_primary = [], [], []
soc_mapped, soc_trees_all, soc_trees_primary = [], [], []

for _, row in df.iterrows():
    # Investigational products fall back to the product name when a mechanism is
    # missing; active comparators and standard of care use their mechanisms only.
    per_category = (
        (
            (inv_mapped, inv_trees_all, inv_trees_primary),
            map_mechanism_list(
                parse_listish(row.get("investigational_products_mechanism")),
                fallback_terms=parse_listish(row.get("investigational_products")),
            ),
        ),
        (
            (ac_mapped, ac_trees_all, ac_trees_primary),
            map_mechanism_list(parse_listish(row.get("active_comparators_mechanism")), fallback_terms=None),
        ),
        (
            (soc_mapped, soc_trees_all, soc_trees_primary),
            map_mechanism_list(parse_listish(row.get("standard_of_care_mechanism")), fallback_terms=None),
        ),
    )
    for sinks, results in per_category:
        for sink, value in zip(sinks, results):
            sink.append(value)


# -----------------------------------------
# Insert columns next to original mechanism columns (logical order)
# -----------------------------------------
def insert_after(df, col, newcol, values):
    """Insert `newcol` (holding `values`) immediately to the right of `col`; mutates `df` in place."""
    df.insert(list(df.columns).index(col) + 1, newcol, values)

# (anchor column, new column, values) in insertion order. Each new column is
# anchored on the previous one, so a category's three columns stay adjacent.
for _anchor_col, _new_col, _values in (
    ("investigational_products_mechanism", "investigational_products_mechanism_mapped", inv_mapped),
    ("investigational_products_mechanism_mapped", "investigational_products_mechanism_tree_numbers", inv_trees_all),
    ("investigational_products_mechanism_tree_numbers", "investigational_products_mechanism_primary_tree_numbers", inv_trees_primary),
    ("active_comparators_mechanism", "active_comparators_mechanism_mapped", ac_mapped),
    ("active_comparators_mechanism_mapped", "active_comparators_mechanism_tree_numbers", ac_trees_all),
    ("active_comparators_mechanism_tree_numbers", "active_comparators_mechanism_primary_tree_numbers", ac_trees_primary),
    ("standard_of_care_mechanism", "standard_of_care_mechanism_mapped", soc_mapped),
    ("standard_of_care_mechanism_mapped", "standard_of_care_mechanism_tree_numbers", soc_trees_all),
    ("standard_of_care_mechanism_tree_numbers", "standard_of_care_mechanism_primary_tree_numbers", soc_trees_primary),
):
    insert_after(df, _anchor_col, _new_col, _values)

# -----------------------------------------
# Save output
# -----------------------------------------
df.to_csv(TRIALS_OUT_PATH, index=False)
print(f"✅ Wrote: {TRIALS_OUT_PATH}")
print(df.head(5).to_markdown(index=False))

In [None]:
import ast
import pandas as pd
from pathlib import Path
from collections import Counter, defaultdict

BASE_DIR = Path("cache")
INPUT_PATH = BASE_DIR / "trial_mechanism_mesh_mapping.csv"
OUTPUT_PATH = BASE_DIR / "trial_mechanism_mesh_tree_number_counts.csv"

# Per-category mapped MeSH headings ...
MESH_TERM_COLS = [
    "investigational_products_mechanism_mapped",
    "active_comparators_mechanism_mapped",
    "standard_of_care_mechanism_mapped",
]
# ... paired positionally with the *primary* tree number (one per mechanism)
TREE_COLS = [
    "investigational_products_mechanism_primary_tree_numbers",
    "active_comparators_mechanism_primary_tree_numbers",
    "standard_of_care_mechanism_primary_tree_numbers",
]

df = pd.read_csv(INPUT_PATH)

def parse_list(x):
    """
    Parse a list-like cell into a Python list.

    Lists pass through; NaN/None give []; a string holding a list literal is
    parsed. Anything that fails to parse, or parses to a non-list, gives [].
    """
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    try:
        parsed = ast.literal_eval(x)
    except Exception:
        return []
    return parsed if isinstance(parsed, list) else []

# tree_number → every mesh_term occurrence seen with it (not written out; kept for inspection)
tree_to_mesh_terms = defaultdict(list)
pair_counter = Counter()   # (mesh_term, primary_tree_number) → count

for _, row in df.iterrows():
    # Parse lists for each mechanism category
    mesh_term_lists = [parse_list(row[c]) for c in MESH_TERM_COLS]
    tree_number_lists = [parse_list(row[c]) for c in TREE_COLS]

    # Iterate over the three mechanism categories in parallel
    for mesh_terms, tree_nums in zip(mesh_term_lists, tree_number_lists):
        # mesh_terms: e.g. ["Antibodies, Bispecific* / pharmacology; immunology", "[none]", ...]
        # tree_nums:  e.g. ["D12.776.124.486.485.114.125", "", ...]
        for mesh_term, primary_tn in zip(mesh_terms, tree_nums):
            # Skip unusable entries: non-string / blank / placeholder terms, missing tree numbers
            if not isinstance(mesh_term, str) or not mesh_term or mesh_term == "[none]":
                continue
            if not isinstance(primary_tn, str) or not primary_tn.strip():
                continue

            tn = primary_tn.strip()
            pair_counter[(mesh_term, tn)] += 1
            tree_to_mesh_terms[tn].append(mesh_term)

# Convert to DataFrame
out_rows = [
    {
        "mesh_term": mesh_term,
        "tree_number": tree_num,
        "count": count,
    }
    for (mesh_term, tree_num), count in pair_counter.items()
]

# Explicit columns keep the schema (and sort_values) valid even when nothing matched.
# Secondary keys make the order of equal counts deterministic.
out_df = pd.DataFrame(out_rows, columns=["mesh_term", "tree_number", "count"]).sort_values(
    ["count", "mesh_term", "tree_number"], ascending=[False, True, True]
)

# Save
out_df.to_csv(OUTPUT_PATH, index=False)

print(f"✅ Saved tree number + term breakdown → {OUTPUT_PATH}")
print("Top 20 combinations:")
print(out_df.head(20).to_markdown())

In [None]:
import pandas as pd
from pathlib import Path

BASE_DIR = Path("cache")
INPUT_PATH = BASE_DIR / "trial_mechanism_mesh_tree_number_counts.csv"   # (mesh_term, tree_number, count)
OUTPUT_PATH = BASE_DIR / "trial_mechanism_super_group_mapping.csv"

df = pd.read_csv(INPUT_PATH)

import re

# -------------------------------------------------------------------
# Define super-group labels (5 buckets)
# -------------------------------------------------------------------
G1 = "hematopoietic_growth_factor_esa"
G2 = "antibody_or_bispecific_therapy"
G3 = "classic_cytotoxic_or_antimetabolite"
G4 = "small_molecule_immunomodulator"
G5 = "metabolic_enzyme_endocrine_other"

# Short abbreviations must match as whole words. As bare substrings, "epo" also
# hits e.g. "epoprostenol" and "depot", which are not hematopoietic growth factors.
_G1_ABBREVIATION_RE = re.compile(r"\b(?:epo|tpo)\b")


def classify_super_group(mesh_term: str, tree_number: str) -> str:
    """
    Heuristic mapping of mesh_term + tree_number to one of 5 MOA super-groups.
    Think like a clinical pharmacologist, but keep it deterministic and simple.

    Rules are tried in group order G1 → G5: keyword matches on the lower-cased
    term first, then tree-number prefixes. The first hit wins; anything
    unmatched falls back to G5.

    Args:
        mesh_term: MeSH heading (qualifiers such as "* / therapeutic use" allowed);
            non-strings (None, NaN) are treated as blank.
        tree_number: primary MeSH tree number; non-strings are treated as blank.

    Returns:
        One of G1..G5.
    """
    t = mesh_term.lower() if isinstance(mesh_term, str) else ""
    tn = tree_number.strip() if isinstance(tree_number, str) else ""

    # -----------------------
    # GROUP 1: Hematopoietic growth factors / ESAs
    # -----------------------
    if any(kw in t for kw in [
        "erythropoietin", "epoetin", "thrombopoietin", "hematopoietic"
    ]) or _G1_ABBREVIATION_RE.search(t):
        return G1
    if tn.startswith("D12.776.543.750.705.852.150") or tn.startswith("D12.776.543.750.705.852.610"):
        return G1
    if tn.startswith("D12.644.276.374.410.240.150") or tn.startswith("D12.644.276.374.410.240.750"):
        return G1

    # -----------------------
    # GROUP 2: Antibodies & bispecifics (checkpoint, cytokine, receptor targeting)
    # -----------------------
    # Obvious antibody keywords
    if "antibod" in t or "immunoconjugate" in t or "chimeric antigen" in t or "cancer vaccines" in t:
        return G2
    # MeSH branches that are essentially antibody land or fusion proteins
    if tn.startswith("D12.776.124.486.485.114") or tn.startswith("D12.776.124.790.651.114"):
        return G2
    # Checkpoint / cytokine biology where in practice these are nearly always mAbs / fusion biologics.
    # All keywords must be lower-case because `t` is lower-cased ("erbb-2" matches "Receptor, ErbB-2").
    if any(kw in t for kw in [
        "programmed cell death 1 receptor", "pd-1", "pd1",
        "pd-l1", "pdl1",
        "erbb-2", "erbb2",
        "tumor necrosis factor", "tnf",
        "interleukin-4 receptor", "interleukin-5", "interleukin-17",
        "cd47 antigen", "lectins, c-type"
    ]):
        return G2

    # -----------------------
    # GROUP 3: Classic cytotoxic / antimetabolite oncology drugs
    # -----------------------
    if any(kw in t for kw in [
        "antimetabolites, antineoplastic",
        "antineoplastic agents, alkylating",
        "topoisomerase i inhibitors",
        "topoisomerase ii inhibitors",
        "vinca alkaloids",
        "vinblastine",
        "taxoids",
        "paclitaxel",
        "fluorouracil",
        "phosphoramide mustards",
        "tubulin modulators"
    ]):
        return G3
    # If the tree number clearly sits in classic chemo small-molecule branches
    if tn.startswith("D27.505.519.186") or tn.startswith("D27.505.519.124"):
        return G3
    if tn.startswith("D02.455.526.728") or tn.startswith("D03.633.100.496.500.500.681.827"):
        return G3
    if tn.startswith("D01.710"):  # platinum compounds
        return G3
    if tn.startswith("D27.505.519.593.249.500"):  # tubulin modulators
        return G3

    # -----------------------
    # GROUP 4: Small-molecule immunomodulators / signaling modifiers
    # -----------------------
    if any(kw in t for kw in [
        "calcineurin inhibitors",
        "janus kinase inhibitors",
        "immunosuppressive agents",
        "glucocorticoids",
        "adrenal cortex hormones",
        "tor serine-threonine kinases",
        "hypoxia-inducible factor 1",
        "hypoxia-inducible factor-proline dioxygenases",
        "interleukin-2",
    ]):
        return G4
    # HIF pathway signaling
    if tn.startswith("D12.776.260.103.625") or tn.startswith("D08.811.682.690.416.617.500"):
        return G4

    # -----------------------
    # GROUP 5: Metabolic / enzyme / endocrine / other
    # -----------------------
    # Every remaining path returns G5. The explicit checks document what is
    # expected here, and keep working if the default ever changes.
    if any(kw in t for kw in [
        "biguanides",
        "urate oxidase",
        "phosphodiesterase 4 inhibitors",
        "histamine h1 antagonists",
        "chelating agents",
        "heparin",
        "antimalarials",
        "imp dehydrogenase",
        "thymidine phosphorylase",
        "thymidylate synthase",
        "leucovorin",
        "cross-linking reagents",
        "membrane glycoproteins"
    ]):
        return G5
    # Heuristic: if it's in D27.505.* but we didn't classify as chemo or immunomod,
    # it's likely a metabolic / enzyme / other chemical agent
    if tn.startswith("D27.505."):
        return G5

    # Default fallback
    return G5


# Apply classifier to every (mesh_term, tree_number) pair
out_df = df[["mesh_term", "tree_number"]].copy()
out_df["mechanism_super_group"] = list(
    map(classify_super_group, out_df["mesh_term"], out_df["tree_number"])
)

# Save
out_df.to_csv(OUTPUT_PATH, index=False)
print(f"✅ Saved mechanism super-group mapping → {OUTPUT_PATH}")
print(out_df.head(20).to_markdown())


In [None]:
import ast
import pandas as pd
from pathlib import Path

BASE_DIR = Path("cache")

MAP_PATH      = BASE_DIR / "trial_mechanism_super_group_mapping.csv"
TRIALS_IN     = BASE_DIR / "trial_mechanism_mesh_mapping.csv"
TRIALS_OUT    = BASE_DIR / "trial_mechanism_with_super_groups.csv"

# ---------------------------------------------------
# Load mapping: (mesh_term, tree_number) → super_group
# ---------------------------------------------------
map_df = pd.read_csv(MAP_PATH)

# Normalize BOTH key parts (str + strip). The build_super_group_list* helpers
# strip the term and tree number they look up, so the keys must be stripped too.
map_df["mesh_term"] = map_df["mesh_term"].astype(str).str.strip()
map_df["tree_number"] = map_df["tree_number"].astype(str).str.strip()

# Duplicate (mesh_term, tree_number) pairs: the last row wins.
mapping = dict(
    zip(
        zip(map_df["mesh_term"], map_df["tree_number"]),
        map_df["mechanism_super_group"],
    )
)

print(f"Loaded {len(mapping):,} (mesh_term, tree_number) → super_group mappings")


# ---------------------------------------------------
# Helpers
# ---------------------------------------------------
def parse_list(x):
    """
    Parse list-like strings (e.g. "['a','b']") into Python lists.

    - lists pass through; NaN/None and blank text give [];
    - text that parses to a list literal returns that list;
    - text that fails to parse returns [];
    - text that parses to a scalar literal (e.g. "5") is kept as a one-item
      list holding the stripped raw text.
    """
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    text = str(x).strip()
    if text == "":
        return []
    try:
        parsed = ast.literal_eval(text)
    except Exception:
        return []
    return parsed if isinstance(parsed, list) else [text]


def build_super_group_list(mech_terms, mech_trees, lookup=None):
    """
    Given:
      mech_terms : list of MeSH headings (strings)
      mech_trees : list of single tree numbers (strings)
      lookup     : optional {(mesh_term, tree_number): super_group} table;
                   defaults to the module-level `mapping`.
    Return:
      list of mechanism_super_group strings, one per slot. Its length is the
      longer of the two inputs; a slot with a blank / "[none]" term, a missing
      tree number, or no mapping gets "".
    """
    table = mapping if lookup is None else lookup
    mech_terms = mech_terms or []
    mech_trees = mech_trees or []

    def _text(value) -> str:
        # Non-strings (None, numbers) count as blank instead of crashing on .strip().
        return value.strip() if isinstance(value, str) else ""

    out = []
    # Pad the shorter list rather than truncating as zip() would, so the result
    # lines up with the term list (as build_super_group_list_with_fallback does).
    for i in range(max(len(mech_terms), len(mech_trees))):
        term = _text(mech_terms[i]) if i < len(mech_terms) else ""
        tn = _text(mech_trees[i]) if i < len(mech_trees) else ""
        if not term or term == "[none]" or not tn:
            out.append("")
            continue
        out.append(table.get((term, tn), ""))  # "" if not found
    return out


def build_super_group_list_with_fallback(
    mech_terms,
    mech_trees,
    fallback_terms,
    fallback_trees,
    lookup=None,
):
    """
    Super-groups for investigational products, with a product-level fallback.

    For each slot i (up to the longest of the four lists):
    - First try mechanism-based MeSH mapping (mech_terms/mech_trees).
    - If the mechanism term is missing / '[none]' / empty (or has no tree
      number), fall back to the investigational product MeSH mapping
      (fallback_terms/fallback_trees).
    - The chosen (term, tree) pair is looked up in `lookup` (default: the
      module-level `mapping`); "" when nothing usable or no mapping exists.
      A usable mechanism pair that has no mapping does NOT fall back.

    Non-string entries are treated as blank.
    """
    table = mapping if lookup is None else lookup

    # Ensure all are lists
    mech_terms     = mech_terms or []
    mech_trees     = mech_trees or []
    fallback_terms = fallback_terms or []
    fallback_trees = fallback_trees or []

    def _slot(values, i) -> str:
        # Stripped string at position i; "" when out of range or not a string.
        value = values[i] if i < len(values) else ""
        return value.strip() if isinstance(value, str) else ""

    n = max(len(mech_terms), len(mech_trees), len(fallback_terms), len(fallback_trees))
    out = []

    for i in range(n):
        # Primary (mechanism-based) and fallback (drug-based) candidates
        term_mech, tn_mech = _slot(mech_terms, i), _slot(mech_trees, i)
        term_fb, tn_fb = _slot(fallback_terms, i), _slot(fallback_trees, i)

        # 1) Use mechanism term if present and not [none]
        if term_mech and term_mech != "[none]" and tn_mech:
            key = (term_mech, tn_mech)
        # 2) Else fall back to investigational product MeSH term
        elif term_fb and term_fb != "[none]" and tn_fb:
            key = (term_fb, tn_fb)
        else:
            out.append("")
            continue

        out.append(table.get(key, ""))  # "" if no mapping

    return out


# ---------------------------------------------------
# Load trials and build super-group columns
# ---------------------------------------------------
df = pd.read_csv(TRIALS_IN)

# One list of super-groups per trial row, for each mechanism category
inv_sg_list = []
ac_sg_list = []
soc_sg_list = []

for _, row in df.iterrows():
    # INVESTIGATIONAL PRODUCTS: mechanism-based MeSH mapping first,
    # investigational-product MeSH mapping as the per-slot fallback.
    inv_sg_list.append(
        build_super_group_list_with_fallback(
            parse_list(row.get("investigational_products_mechanism_mapped")),
            parse_list(row.get("investigational_products_mechanism_primary_tree_numbers")),
            parse_list(row.get("investigational_products_mapped")),
            parse_list(row.get("investigational_products_primary_tree_numbers")),
        )
    )

    # ACTIVE COMPARATORS and STANDARD OF CARE: mechanism mapping only (no fallback)
    for sink, terms_col, trees_col in (
        (ac_sg_list, "active_comparators_mechanism_mapped", "active_comparators_mechanism_primary_tree_numbers"),
        (soc_sg_list, "standard_of_care_mechanism_mapped", "standard_of_care_mechanism_primary_tree_numbers"),
    ):
        sink.append(build_super_group_list(parse_list(row.get(terms_col)), parse_list(row.get(trees_col))))


# ---------------------------------------------------
# Insert new columns next to the primary_tree_numbers
# ---------------------------------------------------
def insert_after(df, col, newcol, values):
    """Insert `newcol` (holding `values`) immediately to the right of `col`; mutates `df` in place."""
    position = list(df.columns).index(col) + 1
    df.insert(position, newcol, values)

# Each super-group column sits right after its category's primary-tree-number column
for _anchor_col, _new_col, _values in (
    ("investigational_products_mechanism_primary_tree_numbers", "investigational_products_mechanism_super_group", inv_sg_list),
    ("active_comparators_mechanism_primary_tree_numbers", "active_comparators_mechanism_super_group", ac_sg_list),
    ("standard_of_care_mechanism_primary_tree_numbers", "standard_of_care_mechanism_super_group", soc_sg_list),
):
    insert_after(df, _anchor_col, _new_col, _values)

# ---------------------------------------------------
# Save final CSV
# ---------------------------------------------------
df.to_csv(TRIALS_OUT, index=False)
print(f"✅ Wrote final trial-level file with super-groups → {TRIALS_OUT}")
print(df.head(5).to_markdown(index=False))

In [None]:
import ast
import pandas as pd
from pathlib import Path
from collections import Counter

BASE_DIR = Path("cache")
INPUT_PATH, OUTPUT_PATH = (
    BASE_DIR / "trial_mechanism_with_super_groups.csv",   # per-trial super-group lists
    BASE_DIR / "trial_super_group_distribution.csv",      # one row per super-group with its count
)

df = pd.read_csv(INPUT_PATH)

def parse_list(x):
    """
    Parse list-like strings safely into Python lists.

    Lists pass through; NaN/None give []. Values that fail to parse, or that
    parse to something other than a list, also give [].
    """
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    try:
        value = ast.literal_eval(x)
    except Exception:
        value = None
    return value if isinstance(value, list) else []

def _first_non_empty(values) -> str:
    """Return the first element of `values` that is a non-blank string (stripped), else ""."""
    for value in values:
        if isinstance(value, str) and value.strip():
            return value.strip()
    return ""

super_group_list = []

for _, row in df.iterrows():
    # Local names differ from the previous cell's inv_sg_list / soc_sg_list globals on purpose,
    # so re-running this cell does not clobber them.
    inv_groups = parse_list(row.get("investigational_products_mechanism_super_group"))
    soc_groups = parse_list(row.get("standard_of_care_mechanism_super_group"))

    # Priority 1 — first non-empty investigational product super-group.
    # Priority 2 — fallback to the first non-empty SOC super-group.
    # Every slot is scanned, not only [0]. This matches the first-non-empty rule
    # used for the results-table category, so both outputs agree.
    super_group_list.append(_first_non_empty(inv_groups) or _first_non_empty(soc_groups))

# Count distribution
dist = Counter(super_group_list)

# Convert to DataFrame (explicit columns keep sort_values valid when nothing was found)
dist_df = pd.DataFrame(
    [{"mechanism_super_group": sg, "count": count}
     for sg, count in dist.items() if sg],  # exclude empty
    columns=["mechanism_super_group", "count"],
).sort_values("count", ascending=False)

# Save
dist_df.to_csv(OUTPUT_PATH, index=False)

print(f"✅ Saved distribution → {OUTPUT_PATH}")
print("Top categories:\n")
print(dist_df.head(20).to_markdown())

#### Output results

In [None]:
import ast
import pandas as pd
from pathlib import Path

BASE_DIR = Path("cache")

# Inputs: per-trial innovation classifications and per-trial mechanism / super-group table
CLASS_PATH = BASE_DIR / "trial_investigational_drugs_classifications.csv"
MECH_PATH = BASE_DIR / "trial_mechanism_with_super_groups.csv"
# Output: the final per-trial results table
OUTPUT_PATH = BASE_DIR / "trial_results_table.csv"

# ---------------------------------
# Helpers
# ---------------------------------
def parse_listish(x):
    """
    Parse a list-like string into a Python list.

    Lists pass through; NaN/None and blank text give []. Text that fails to
    parse, or that parses to a non-list, also gives [].
    """
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    text = str(x).strip()
    if not text:
        return []
    try:
        parsed = ast.literal_eval(text)
    except Exception:
        return []
    if isinstance(parsed, list):
        return parsed
    return []

def get_first_or_blank(x):
    """First element of the parsed list, or "" when the list is empty."""
    return next(iter(parse_listish(x)), "")

# ---------------------------------
# Load inputs
# ---------------------------------
df_class = pd.read_csv(CLASS_PATH)
df_mech = pd.read_csv(MECH_PATH)

# ---------------------------------
# Prepare classification info: first investigational product and its innovation label
# ---------------------------------
first_product = df_class["investigational_products"].apply(get_first_or_blank)
first_classification = df_class["investigational_products_classifications"].apply(get_first_or_blank)
df_class["drug_name"] = first_product
df_class["innovation"] = first_classification

# By default every innovation label is kept (just stripped). To strictly limit the
# output to "Innovative" / "Biosimilar" and blank out everything else (e.g. "Generic"),
# call normalize_innovation(val, allowed={"Innovative", "Biosimilar"}).
def normalize_innovation(val: str, allowed=None) -> str:
    """
    Clean one innovation label.

    Args:
        val: raw label; non-strings such as None / NaN become "".
        allowed: optional collection of labels to keep. When given, any other
            label is returned as "". When None (default), every label is kept.

    Returns:
        The stripped label, or "" when it is missing or filtered out.
    """
    v = val.strip() if isinstance(val, str) else ""
    if allowed is not None and v not in allowed:
        return ""
    return v

# Clean up the innovation labels in place on the classification table.
df_class["innovation"] = (
    df_class["innovation"]
    .apply(normalize_innovation)
)

# ---------------------------------
# Prepare mechanism / category info
# ---------------------------------
# MOA: the first listed investigational mechanism description ("" when none)
df_mech["moa"] = (
    df_mech["investigational_products_mechanism"]
    .apply(get_first_or_blank)
)

# Category: first non-empty investigational super-group;
# if empty, fall back to first non-empty SOC super-group
def pick_category(row):
    """Choose the mechanism category (super-group) for one trial row.

    Returns the first non-blank entry of the investigational-product
    super-group list. If there is none, returns the first non-blank entry of
    the standard-of-care super-group list. Otherwise returns "".

    String entries are stripped before use, so whitespace-only strings are
    treated as blank. This matches the `.strip()` blank check applied to
    super-groups in the distribution cell above.
    """
    for column in (
        "investigational_products_mechanism_super_group",
        "standard_of_care_mechanism_super_group",
    ):
        # Parsed lazily: the SOC list is only read when no investigational value is found.
        for val in parse_listish(row.get(column)):
            if isinstance(val, str):
                val = val.strip()
            if val:
                return val
    return ""

df_mech["category"] = df_mech.apply(pick_category, axis=1)

# First drug name according to the mechanism file. It is carried through the
# merge below and compared with the classification file's first drug name, to
# sanity-check that both inputs line up per trial.
df_mech["drug_name_mech"] = df_mech["investigational_products"].apply(get_first_or_blank)

# ---------------------------------
# Merge on trial_hash
# ---------------------------------
# A repeated trial_hash in the classification table would silently duplicate
# rows in the left join, so report it instead of letting it pass unnoticed.
n_dup_class = int(df_class["trial_hash"].duplicated().sum())
if n_dup_class:
    print(f"⚠️ {n_dup_class} duplicated trial_hash value(s) in {CLASS_PATH.name}; "
          "the merged table will contain repeated trials.")

merged = pd.merge(
    df_mech[["trial_hash", "moa", "category", "drug_name_mech"]],
    df_class[["trial_hash", "drug_name", "innovation"]],
    on="trial_hash",
    how="left",
)

# Trials absent from the classification table come back as NaN from the left
# join. Use "" so they match how blanks appear in the other columns.
merged[["drug_name", "innovation"]] = merged[["drug_name", "innovation"]].fillna("")

# Sanity check: where the classification table names a drug, it should be the
# same first drug that the mechanism table lists for that trial.
drug_mismatch = merged["drug_name"].ne("") & merged["drug_name"].ne(merged["drug_name_mech"])
if drug_mismatch.any():
    print(f"⚠️ {int(drug_mismatch.sum())} trial(s) list a different first drug "
          "in the classification vs mechanism tables.")

# ---------------------------------
# Build final table
# ---------------------------------
# NOTE: we don't have a title column here, so we use trial_hash as a proxy.
# If you have a separate trials file with titles, you can join it here and replace trial_hash.
merged["trial_title"] = merged["trial_hash"]

results = merged[["trial_title", "drug_name", "moa", "innovation", "category"]].copy()

# Save
results.to_csv(OUTPUT_PATH, index=False)
print(f"✅ Saved results table → {OUTPUT_PATH}")
print(results.head(20).to_markdown(index=False))


In [None]:
# ---------------------------------
# Find rows with missing values
# ---------------------------------

# Treat "" (and whitespace-only strings) as missing, alongside NaN/None.
cols_to_check = ["trial_title", "drug_name", "moa", "innovation", "category"]

def is_missing(x):
    """Return True when `x` is NaN/None or a blank (whitespace-only) value."""
    return (pd.isna(x)) or (str(x).strip() == "")

# DataFrame.applymap is deprecated since pandas 2.1 (renamed DataFrame.map).
# Mapping column by column with Series.map gives the same element-wise result
# on every pandas version.
mask_missing = (
    results[cols_to_check]
    .apply(lambda col: col.map(is_missing))
    .any(axis=1)
)

missing_rows = results[mask_missing].copy()

print(f"Found {len(missing_rows)} rows with at least one missing value.")

# Save for debugging
MISSING_OUTPUT_PATH = BASE_DIR / "trial_results_table_missing_rows.csv"
missing_rows.to_csv(MISSING_OUTPUT_PATH, index=False)

print(f"❗ Missing rows saved to → {MISSING_OUTPUT_PATH}")
# Show only a preview; the full list is in the CSV written above.
print(missing_rows.head(20).to_markdown(index=False))
