#### Data Pre-processing

In [36]:
"""
Load raw data from provided CSV file.
"""

import pandas as pd 

data = pd.read_csv("data/raw_trials.csv")

In [37]:
print(data.columns)

Index(['title', 'objective', 'outcome_details', 'phase',
       'primary_completion_date', 'primary_endpoints_reported_date',
       'prior_concurrent_therapy', 'start_date', 'study_design',
       'treatment_plan', 'record_type', 'patients_per_site_per_month',
       'primary_endpoint_json', 'other_endpoint_json', 'associated_cro_json',
       'notes_json', 'outcomes_json', 'patient_dispositions_json',
       'results_json', 'study_keywords_json', 'tags_json',
       'primary_drugs_tested_json', 'other_drugs_tested_json',
       'therapeutic_areas_json', 'bmt_other_drugs_tested_json',
       'bmt_primary_drugs_tested_json', 'ct_gov_listed_locations_json',
       'ct_gov_mesh_terms_json'],
      dtype='object')


In [38]:
print(data.isna().sum().to_markdown())
print("Shape:", data.shape)

|                                 |   0 |
|:--------------------------------|----:|
| title                           |   0 |
| objective                       |   3 |
| outcome_details                 | 146 |
| phase                           |   0 |
| primary_completion_date         |  61 |
| primary_endpoints_reported_date | 161 |
| prior_concurrent_therapy        | 184 |
| start_date                      |  45 |
| study_design                    |  16 |
| treatment_plan                  |   1 |
| record_type                     |   0 |
| patients_per_site_per_month     | 119 |
| primary_endpoint_json           |   0 |
| other_endpoint_json             |   0 |
| associated_cro_json             |   0 |
| notes_json                      |   0 |
| outcomes_json                   |   0 |
| patient_dispositions_json       |   0 |
| results_json                    |   0 |
| study_keywords_json             |   0 |
| tags_json                       |   0 |
| primary_drugs_tested_json       

In [39]:
"""
Generate a deterministic trial hash for each clinical trial and save an
augmented CSV. The hash is unique only as long as no two trials share the same
(title, start_date, phase) triple.

Inputs:
- CSV file: data/raw_trials.csv
    Must contain at least:
    • "title"
    • "start_date"
    • "phase"

Process:
- Load the raw trials into a DataFrame.
- For each row, build a small JSON payload from (title, start_date, phase).
- Compute an MD5 hash of the payload and prefix with "tid_" to form
  a deterministic trial identifier.
- Insert "trial_hash" as the first column.

Outputs:
- CSV written to:
      cache/data_preprocess/raw_trials_with_hash.csv
  containing all original columns plus the leading "trial_hash" column.
"""

import hashlib
import json
from pathlib import Path

import pandas as pd

# -------------------------------------------------
# CONFIG
# -------------------------------------------------

INPUT_PATH = Path("data/raw_trials.csv")
OUTPUT_PATH = Path("cache/data_preprocess/raw_trials_with_hash.csv")
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)

# -------------------------------------------------
# RUN
# -------------------------------------------------

print(f"Loading raw trials from: {INPUT_PATH}")
data = pd.read_csv(INPUT_PATH, dtype=str).fillna("")

print("Generating trial_hash values ...")

def make_trial_hash(row):
    """Deterministic hash for a trial based on stable fields."""
    payload = {
        "title": row.get("title", ""),
        "start_date": row.get("start_date", ""),
        "phase": row.get("phase", ""),
    }
    raw = json.dumps(payload, sort_keys=True, ensure_ascii=False)
    return "tid_" + hashlib.md5(raw.encode("utf-8")).hexdigest()

# Create trial_hash column
data["trial_hash"] = data.apply(make_trial_hash, axis=1)

# Move trial_hash to first column
cols = ["trial_hash"] + [c for c in data.columns if c != "trial_hash"]
data = data[cols]

print("Data columns:", data.columns)
print("Data shape:", data.shape)

# Export
data.to_csv(OUTPUT_PATH, index=False)
print(f"Saved to {OUTPUT_PATH}")

Loading raw trials from: data/raw_trials.csv
Generating trial_hash values ...
Data columns: Index(['trial_hash', 'title', 'objective', 'outcome_details', 'phase',
       'primary_completion_date', 'primary_endpoints_reported_date',
       'prior_concurrent_therapy', 'start_date', 'study_design',
       'treatment_plan', 'record_type', 'patients_per_site_per_month',
       'primary_endpoint_json', 'other_endpoint_json', 'associated_cro_json',
       'notes_json', 'outcomes_json', 'patient_dispositions_json',
       'results_json', 'study_keywords_json', 'tags_json',
       'primary_drugs_tested_json', 'other_drugs_tested_json',
       'therapeutic_areas_json', 'bmt_other_drugs_tested_json',
       'bmt_primary_drugs_tested_json', 'ct_gov_listed_locations_json',
       'ct_gov_mesh_terms_json'],
      dtype='object')
Data shape: (184, 29)
Saved to cache/data_preprocess/raw_trials_with_hash.csv
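
Because the hash is built only from (title, start_date, phase), two rows that agree on all three fields get the same `trial_hash`. The sketch below repeats the hashing recipe from the cell above on a few toy rows (hypothetical values, not taken from `raw_trials.csv`) and shows one way to spot such collisions. The names `toy_rows`, `hashes`, and `duplicated` are illustrative only.

```python
import hashlib
import json
from collections import Counter


def make_trial_hash(row):
    """Same recipe as the cell above: MD5 over a sorted-key JSON payload."""
    payload = {
        "title": row.get("title", ""),
        "start_date": row.get("start_date", ""),
        "phase": row.get("phase", ""),
    }
    raw = json.dumps(payload, sort_keys=True, ensure_ascii=False)
    return "tid_" + hashlib.md5(raw.encode("utf-8")).hexdigest()


# Toy rows (hypothetical); the third repeats the key fields of the first.
toy_rows = [
    {"title": "Trial A", "start_date": "2020-01-01", "phase": "II"},
    {"title": "Trial B", "start_date": "2021-06-15", "phase": "III"},
    {"title": "Trial A", "start_date": "2020-01-01", "phase": "II"},
]

hashes = [make_trial_hash(r) for r in toy_rows]

# Hashes that occur more than once point at rows that would share an ID.
duplicated = {h: n for h, n in Counter(hashes).items() if n > 1}

print(f"{len(set(hashes))} unique hashes across {len(hashes)} rows")
print(f"{len(duplicated)} hash value(s) shared by more than one row")
```

On the real DataFrame, a comparable check would be `data["trial_hash"].duplicated().any()`. If it returns `True`, an extra distinguishing field would be needed in the payload.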


#### Task 1

In [40]:
"""
Use an LLM to extract structured drug-role metadata for each clinical trial.

Inputs:
- `cache/data_preprocess/raw_trials_with_hash.csv`
    One row per trial, including:
    • trial_hash (unique ID)
    • title, objective, treatment_plan
    • *_drugs_tested_json fields
    • other structured or semi-structured metadata used to identify interventions.

Process:
- For each trial, build an LLM prompt using selected columns.
- Ask the model to identify all distinct interventions and classify them.
- For each drug:
    • Assign role (Investigational Product, Active Comparator, Placebo, SOC)
    • List alternative names / synonyms
    • Identify molecular target and mechanism (if known)
    • Assign tt_drug_id and bmt_drug_id only when matchable with high confidence.
- Runs in parallel using ThreadPoolExecutor.
- Skips trials that already have a saved result file (no LLM call is made for them).
- Tracks processed, skipped, LLM errors, and JSON-parse errors.

Outputs:
- Per-trial mapped interventions:
      `cache/task_1/trial_drug_roles/{trial_hash}.json`
- Per-trial log files (prompt + raw response + cost):
      `cache/task_1/trial_drug_roles_log/{trial_hash}.json`
- Aggregated master index of all mappings:
      `cache/task_1/trial_drug_roles_master.json`
"""

import re
import json
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
from services.openai_wrapper import OpenAIWrapper

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
BASE_DIR = Path("cache")

TRIALS_WITH_HASH_CSV = Path("cache/data_preprocess/raw_trials_with_hash.csv")

DRUG_ROLE_DIR = BASE_DIR / "task_1" / "trial_drug_roles"
DRUG_ROLE_DIR.mkdir(parents=True, exist_ok=True)

DRUG_ROLE_LOG_DIR = BASE_DIR / "task_1" / "trial_drug_roles_log"
DRUG_ROLE_LOG_DIR.mkdir(parents=True, exist_ok=True)

MASTER_ROLES_PATH = BASE_DIR / "task_1" / "trial_drug_roles_master.json"

MODEL = "gpt-5"
client = OpenAIWrapper()

MAX_WORKERS = 8

RELEVANT_COLS = [
    "title",
    "objective",
    "outcome_details",
    "treatment_plan",
    "notes_json",
    "results_json",
    "primary_drugs_tested_json",
    "other_drugs_tested_json",
    "therapeutic_areas_json",
    "bmt_other_drugs_tested_json",
    "bmt_primary_drugs_tested_json",
    "ct_gov_mesh_terms_json",
]

# -------------------------------------------------
# Helpers
# -------------------------------------------------
def extract_json_object(text: str) -> dict:
    """Parse model output as a JSON object; fall back to the span from the first '{' to the last '}'."""
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except Exception:
        pass
    m = re.search(r"\{.*\}", text, re.DOTALL)
    if not m:
        return {}
    try:
        obj = json.loads(m.group(0))
        if isinstance(obj, dict):
            return obj
    except Exception:
        return {}
    return {}


def build_prompt(trial_payload: dict) -> str:
    """
    Build prompt asking the LLM to:
    - Extract drug names
    - Canonicalize names by removing company/manufacturer/location qualifiers
    - Deduplicate synonymous names
    - For each canonical drug, return a dict with:
        * role (Investigational Product / Placebo / Active Comparator / Standard of Care)
        * alternative_names (list)
        * molecular_target
        * mechanism
        * tt_drug_id (TrialTrove/PharmaProjects drugId as string)
        * bmt_drug_id (BioMedTracker bmtDrugId as string)
    """
    payload_json = json.dumps(trial_payload, ensure_ascii=False, indent=2)

    return f"""
You are a clinical trial design and interpretation expert.

You are given structured information about a clinical trial, including:
- Title and objective
- Study design and treatment plan
- JSON fields listing drugs tested in the study:
  - primary_drugs_tested_json
  - other_drugs_tested_json
  - bmt_other_drugs_tested_json
  - bmt_primary_drugs_tested_json
- These JSON fields may also contain metadata such as
  drugApprovalStatus (Approved / Unapproved), mechanisms, synonyms, etc.
- In the TrialTrove/PharmaProjects JSON blocks, the unique drug identifier
  is usually under a key like "drugId".
- In the BioMedTracker JSON blocks, the unique drug identifier
  is usually under a key like "bmtDrugId".

Your tasks:

1. Identify all DISTINCT physical drug entities explicitly used in the study.
   - Strings in the *_drugs_tested_json fields are drug-name candidates.
   - If these fields contain structured JSON, infer names from keys such as
     "name", "drug_name", "drugName", "drugPrimaryName", "preferred_name",
     "label", etc.

2. Canonicalize each drug name:
   Remove company names, manufacturer qualifiers, geographic qualifiers,
   dosage-form qualifiers, or parenthetical descriptors that do NOT change
   the name of the underlying drug.

   Examples:
   - "AlphaBlocker (CompanyX)" → "AlphaBlocker"
   - "Recombinant Growth Factor (rgf)" → "Recombinant Growth Factor"
   - "DrugX citrate (RegionY)" → "DrugX citrate"
   - "BrandName (compound-42, MakerCorp)" → "BrandName"

3. Deduplicate synonymous names referring to the SAME drug.
   - Prefer the simplest, most standard canonical name.
   - Collect all other variations under alternative_names.

4. For EACH distinct drug, build an object with SIX fields:

   • "role": one of:
       - "Investigational Product"
       - "Placebo"
       - "Active Comparator"
       - "Standard of Care"

     ROLE ASSIGNMENT RULES (SUMMARY):

     - "Investigational Product":
       * Sponsor's novel or proprietary product, or a regimen whose key component
         is a novel/proprietary or clearly unapproved pipeline agent.
       * The trial objective explicitly focuses on evaluating this new product
         for safety/efficacy, dose-finding, first-in-human, or proof-of-concept.

     - "Standard of Care":
       * Approved, widely used therapies or regimens that represent background,
         conventional, or standard treatment.
       * IMPORTANT: If ALL drugs in a regimen are already approved therapies
         and the trial is mainly about treatment strategy, regimen choice,
         imaging-guided regimen selection, dosing/scheduling, or algorithm
         optimization (rather than developing a NEW drug entity), then:
             → Classify ALL actively dosed drugs as "Standard of Care".
             → Do NOT create any "Investigational Product" entry for that trial.
       * This includes combinations of standard agents (e.g., docetaxel,
         cisplatin, cyclophosphamide, ifosfamide, paclitaxel, fluorouracil,
         and similar approved drugs) when no new molecular entity is being tested.

     - "Active Comparator":
       * A non-placebo comparator arm explicitly contrasted with another
         investigational product in the protocol (e.g., "experimental vs
         active control").

     - "Placebo":
       * Inert control preparations.

   • "alternative_names": list of synonymous variants.
   • "molecular_target": e.g., "CD20", "VEGF-A". If unknown or not disclosed, use "".
   • "mechanism": e.g., "EGFR inhibitor", "Anti-PD-1 antibody", "JAK inhibitor".
       - If the ONLY information is that it is a "small molecule", "biologic",
         "small molecule; mechanism of action not identified", "mechanism unknown",
         "not identified", "not determined", "not disclosed", or similar,
         then treat the mechanism as UNKNOWN and set "mechanism" = "".
       - Do NOT copy meta-statements such as:
           • "Small molecule; mechanism of action not identified"
           • "Mechanism of action unknown/not identified/not yet determined"
         into the mechanism field. Use "" instead in these cases.
   • "tt_drug_id": STRING. If not confidently matchable, "".
   • "bmt_drug_id": STRING. If not confidently matchable, "".

5. ID MISMATCH SAFETY RULE:
   - Do NOT assign tt_drug_id or bmt_drug_id if they clearly belong to a different drug
     (different target, mechanism, indication, modality, or obviously mismatched name).
   - If there is ANY doubt about the correctness of an ID:
       → Set BOTH "tt_drug_id" and "bmt_drug_id" to "".

Output format (IMPORTANT):
Return ONLY a valid JSON object with:
  - keys   = canonical drug names
  - values = objects with EXACTLY:
        * "role"
        * "alternative_names"
        * "molecular_target"
        * "mechanism"
        * "tt_drug_id"
        * "bmt_drug_id"

Example:
{{
  "ABC-123": {{
    "role": "Investigational Product",
    "alternative_names": ["ABC123", "Compound-ABC"],
    "molecular_target": "Receptor-Z",
    "mechanism": "Bispecific antibody",
    "tt_drug_id": "123456",
    "bmt_drug_id": "78901"
  }},
  "DrugX": {{
    "role": "Standard of Care",
    "alternative_names": ["GenericX", "ChemX"],
    "molecular_target": "Enzyme-A",
    "mechanism": "Antimetabolite",
    "tt_drug_id": "",
    "bmt_drug_id": ""
  }}
}}

Input JSON:
{payload_json}
""".strip()


counter = {"processed": 0, "skipped_existing": 0, "llm_error": 0, "parse_error": 0}
counter_lock = threading.Lock()

master_roles: dict[str, dict] = {}
master_lock = threading.Lock()


def process_trial(row: dict, idx: int, total: int) -> None:
    """Process one trial: call LLM, validate output, save role JSON and log."""
    trial_hash = str(row.get("trial_hash", "")).strip()
    if not trial_hash:
        print(f"[{idx}/{total}] Missing trial_hash, skipping")
        return

    out_fp = DRUG_ROLE_DIR / f"{trial_hash}.json"
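    if out_fp.exists():
        # Reuse the cached result so the master index still includes this trial
        # when it is rewritten later in this run.
        try:
            cached = json.loads(out_fp.read_text(encoding="utf-8"))
            with master_lock:
                master_roles[trial_hash] = cached
        except Exception as e:
            print(f"Could not reload cached result for {trial_hash}: {e}")
        with counter_lock:
            counter["skipped_existing"] += 1
        return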

    trial_payload = {"trial_hash": trial_hash}
    for col in RELEVANT_COLS:
        trial_payload[col] = row.get(col, "")

    prompt = build_prompt(trial_payload)

    text_response = ""
    raw_response = None
    total_cost = 0.0
    elapsed = 0.0

    try:
        t0 = time.perf_counter()
        res = client.query(prompt=prompt, model=MODEL)
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)
    except Exception as e:
        print(f"LLM error for {trial_hash}: {e}")
        with counter_lock:
            counter["llm_error"] += 1
        return

    drug_roles = extract_json_object(text_response)
    if not isinstance(drug_roles, dict) or not drug_roles:
        print(f"JSON parse error trial_hash={trial_hash}")
        with counter_lock:
            counter["parse_error"] += 1
        return

    mapped = {
        "trial_hash": trial_hash,
        "title": row.get("title"),
        "drug_roles": drug_roles,
        "source": "llm",
    }

    out_fp.write_text(json.dumps(mapped, ensure_ascii=False, indent=2), encoding="utf-8")

    log_payload = {
        "token": trial_hash,
        "hash_id": trial_hash,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": json.dumps(mapped, ensure_ascii=False, indent=2),
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    (DRUG_ROLE_LOG_DIR / f"{trial_hash}.json").write_text(
        json.dumps(log_payload, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    with master_lock:
        master_roles[trial_hash] = mapped
        MASTER_ROLES_PATH.write_text(
            json.dumps(master_roles, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    with counter_lock:
        counter["processed"] += 1
        if counter["processed"] % 50 == 0:
            print(f"Processed {counter['processed']} trials...")


# -------------------------------------------------
# RUN
# -------------------------------------------------
df_trials = pd.read_csv(TRIALS_WITH_HASH_CSV, dtype=str).fillna("")
rows = df_trials.to_dict(orient="records")
total_trials = len(rows)
print(f"Loaded {total_trials} trials from {TRIALS_WITH_HASH_CSV}")

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
    futures = {
        ex.submit(process_trial, row, idx, total_trials): row.get("trial_hash")
        for idx, row in enumerate(rows, start=1)
    }
    for fut in as_completed(futures):
        try:
            fut.result()
        except Exception as e:
            print(f"Worker error: {e}")

print(
    f"Complete. processed={counter['processed']}, "
    f"skipped={counter['skipped_existing']}, "
    f"llm_error={counter['llm_error']}, "
    f"parse_error={counter['parse_error']}"
)
print(f"Roles directory: {DRUG_ROLE_DIR}")
print(f"Log directory:   {DRUG_ROLE_LOG_DIR}")
print(f"Master roles:    {MASTER_ROLES_PATH}")

Loaded 184 trials from cache/data_preprocess/raw_trials_with_hash.csv
Complete. processed=1, skipped=183, llm_error=0, parse_error=0
Roles directory: cache/task_1/trial_drug_roles
Log directory:   cache/task_1/trial_drug_roles_log
Master roles:    cache/task_1/trial_drug_roles_master.json
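
The fallback in `extract_json_object` uses a greedy regex, so it captures everything from the first `{` to the last `}` in the reply. If the model adds trailing text that itself contains braces, that span no longer parses. A possible alternative, sketched here under the assumption that the reply contains at most some surrounding prose, is to scan with the standard library's `json.JSONDecoder.raw_decode`, which decodes one JSON value starting at a given index. The function name `first_json_object` and the sample `reply` are hypothetical.

```python
import json


def first_json_object(text: str) -> dict:
    """Return the first decodable JSON object found in text, or {} if none."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            # raw_decode parses a single JSON value beginning at index pos
            # and ignores whatever follows it.
            obj, _end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        # Not a valid object at this brace; try the next one.
        pos = text.find("{", pos + 1)
    return {}


# Hypothetical model reply with prose and stray braces around the payload.
reply = 'Here is the result: {"DrugX": {"role": "Placebo"}} Note: {not json}'
print(first_json_object(reply))
```

This is only a sketch. It does not validate that the object has the six expected fields per drug, which the notebook code also leaves unchecked.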


In [41]:
"""
Summarize total LLM usage cost for the previous cell by reading all per-trial log files.

Inputs:
- Directory: cache/task_1/trial_drug_roles_log/
    Each log JSON contains:
        • total_cost (float)
        • other metadata (prompt, raw response, timing, etc.)

Process:
- Load each log file and extract its total_cost value.
- Aggregate total cost, count entries, and compute average cost per trial.
- Sort trials by cost to identify the most expensive prompts.

Outputs:
- Console summary including:
    • Total cost
    • Number of logged trials
    • Average cost per trial
    • Top 10 highest-cost trials (filename + cost)
"""

import json
from pathlib import Path

LOG_DIR = Path("cache/task_1/trial_drug_roles_log")

total_cost = 0.0
num_entries = 0
costs = []

for fp in LOG_DIR.glob("*.json"):
    try:
        log = json.loads(fp.read_text(encoding="utf-8"))
        c = float(log.get("total_cost") or 0.0)
        total_cost += c
        costs.append((fp.name, c))
        num_entries += 1
    except Exception as e:
        print(f"Error reading {fp.name}: {e}")

# Sort descending by cost
costs_sorted = sorted(costs, key=lambda x: x[1], reverse=True)

print("========== LLM COST SUMMARY ==========")
print(f"Total LLM cost:             ${total_cost:,.4f}")
print(f"Number of logged trials:     {num_entries}")
if num_entries > 0:
    print(f"Average cost per trial:      ${total_cost / num_entries:,.4f}")
print("")

print("Top 10 most expensive trials:")
for name, c in costs_sorted[:10]:
    print(f"  {name}: ${c:,.4f}")

print("========================================")

Total LLM cost:             $4.3950
Number of logged trials:     184
Average cost per trial:      $0.0239

Top 10 most expensive trials:
  tid_1158b3369546dc4b16dc21c8c026b619.json: $0.0546
  tid_28a767e788d4d9a4e65b3c10d10585c2.json: $0.0546
  tid_763e3011bc90e46c88c7a2953a39ed2a.json: $0.0518
  tid_6e821d7fbd8539bae7baf3a668d6d080.json: $0.0462
  tid_d372a5464ccae4cf39f41537506a78c0.json: $0.0454
  tid_ff64edc14f04fb1d81451cc7475488fe.json: $0.0448
  tid_8b4d60a5fddc078962af34399d7e342c.json: $0.0436
  tid_e0a77c4ecf93cf781f04cc467c974511.json: $0.0428
  tid_7e80effdd579ba535ef686ac50dcc4bc.json: $0.0416
  tid_837737698a5271d314ea8208addb2d72.json: $0.0410


In [42]:
"""
Aggregate per-trial drug-role JSONs into a wide trial-level product breakdown CSV.

Inputs:
- Directory: cache/task_1/trial_drug_roles/
    Each file: {trial_hash}.json with structure:
        {
          "trial_hash": "<tid_...>",
          "title": "...",
          "drug_roles": {
            "<drug_name>": {
              "role": "Investigational Product" | "Active Comparator" | "Placebo" | "Standard of Care",
              "alternative_names": [...],
              "molecular_target": "...",
              "mechanism": "...",
              "tt_drug_id": "...",
              "bmt_drug_id": "..."
            },
            ...
          }
        }

Process:
- Iterate over all JSONs in cache/task_1/trial_drug_roles/.
- For each trial:
    • Partition drugs into four buckets: investigational, active comparator, placebo, standard of care.
    • Collect, per role:
        - canonical names
        - alternative_names (as list-of-lists)
        - molecular_target
        - mechanism
        - tt_drug_id / bmt_drug_id where applicable.
- Build one row per trial with list-valued columns for each role.

Outputs:
- CSV: cache/task_1/trial_product_breakdown.csv
    One row per trial, columns:
        trial_hash
        investigational_products, investigational_products_alternative_names, ...
        active_comparators, ...
        placebos, ...
        standard_of_care, ...
"""

import json
from pathlib import Path

import pandas as pd

# -------------------------------------------------
# CONFIG
# -------------------------------------------------

# Base directory for task_1 cache + input/output
BASE_DIR = Path("cache/task_1")

# Directory that contains per-trial drug-role JSONs
DRUG_ROLE_DIR = BASE_DIR / "trial_drug_roles"

# Output CSV path
OUT_CSV = BASE_DIR / "trial_product_breakdown.csv"

# -------------------------------------------------
# RUN
# -------------------------------------------------

rows = []

for fp in DRUG_ROLE_DIR.glob("*.json"):
    try:
        obj = json.loads(fp.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"Error reading {fp.name}: {e}")
        continue

    trial_hash = obj.get("trial_hash")
    if not trial_hash:
        print(f"Missing trial_hash in {fp.name}, skipping")
        continue

    drug_roles = obj.get("drug_roles") or {}
    if not isinstance(drug_roles, dict):
        print(f"drug_roles not dict in {fp.name}, skipping")
        continue

    # Containers
    inv_names = []
    inv_alt_names = []          # list of lists
    inv_targets = []
    inv_mechanisms = []
    inv_tt_ids = []
    inv_bmt_ids = []

    ac_names = []
    ac_alt_names = []           # list of lists
    ac_targets = []
    ac_mechanisms = []
    ac_tt_ids = []
    ac_bmt_ids = []

    plc_names = []
    plc_alt_names = []          # list of lists
    plc_targets = []
    plc_mechanisms = []

    soc_names = []
    soc_alt_names = []          # list of lists
    soc_targets = []
    soc_mechanisms = []
    soc_tt_ids = []
    soc_bmt_ids = []

    for drug_name, meta in drug_roles.items():
        if not isinstance(meta, dict):
            continue

        role = (meta.get("role") or "").strip()
        role_norm = role.lower()

        alt_names = meta.get("alternative_names") or []
        if not isinstance(alt_names, list):
            alt_names = [str(alt_names)]

        molecular_target = meta.get("molecular_target") or ""
        mechanism = meta.get("mechanism") or ""

        # IDs are always stored as strings in the LLM output, but be defensive
        tt_id = str(meta.get("tt_drug_id") or "")
        bmt_id = str(meta.get("bmt_drug_id") or "")

        if role_norm == "investigational product":
            inv_names.append(drug_name)
            inv_alt_names.append(alt_names)
            inv_targets.append(molecular_target)
            inv_mechanisms.append(mechanism)
            inv_tt_ids.append(tt_id)
            inv_bmt_ids.append(bmt_id)

        elif role_norm == "active comparator":
            ac_names.append(drug_name)
            ac_alt_names.append(alt_names)
            ac_targets.append(molecular_target)
            ac_mechanisms.append(mechanism)
            ac_tt_ids.append(tt_id)
            ac_bmt_ids.append(bmt_id)

        elif role_norm == "placebo":
            plc_names.append(drug_name)
            plc_alt_names.append(alt_names)
            plc_targets.append(molecular_target)
            plc_mechanisms.append(mechanism)

        elif role_norm == "standard of care":
            soc_names.append(drug_name)
            soc_alt_names.append(alt_names)
            soc_targets.append(molecular_target)
            soc_mechanisms.append(mechanism)
            soc_tt_ids.append(tt_id)
            soc_bmt_ids.append(bmt_id)

    rows.append(
        {
            "trial_hash": trial_hash,

            "investigational_products": inv_names,
            "investigational_products_alternative_names": inv_alt_names,
            "investigational_products_molecular_target": inv_targets,
            "investigational_products_mechanism": inv_mechanisms,
            "investigational_products_tt_drug_id": inv_tt_ids,
            "investigational_products_bmt_drug_id": inv_bmt_ids,

            "active_comparators": ac_names,
            "active_comparators_alternative_names": ac_alt_names,
            "active_comparators_molecular_target": ac_targets,
            "active_comparators_mechanism": ac_mechanisms,
            "active_comparators_tt_drug_id": ac_tt_ids,
            "active_comparators_bmt_drug_id": ac_bmt_ids,

            "placebos": plc_names,
            "placebos_alternative_names": plc_alt_names,
            "placebos_molecular_target": plc_targets,
            "placebos_mechanism": plc_mechanisms,

            "standard_of_care": soc_names,
            "standard_of_care_alternative_names": soc_alt_names,
            "standard_of_care_molecular_target": soc_targets,
            "standard_of_care_mechanism": soc_mechanisms,
            "standard_of_care_tt_drug_id": soc_tt_ids,
            "standard_of_care_bmt_drug_id": soc_bmt_ids,
        }
    )

df_out = pd.DataFrame(rows).sort_values("trial_hash")

OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
df_out.to_csv(OUT_CSV, index=False)

print(f"Saved trial product breakdown to {OUT_CSV}")
print(df_out.head().to_markdown())

Saved trial product breakdown to cache/task_1/trial_product_breakdown.csv
|     | trial_hash                           | investigational_products                           | investigational_products_alternative_names                                                                                                                                                                                                          | investigational_products_molecular_target   | investigational_products_mechanism                                         | investigational_products_tt_drug_id   | investigational_products_bmt_drug_id   | active_comparators   | active_comparators_alternative_names   | active_comparators_molecular_target   | active_comparators_mechanism   | active_comparators_tt_drug_id   | active_comparators_bmt_drug_id   | placebos   | placebos_alternative_names   | placebos_molecular_target   | placebos_mechanism   | standard_of_care   | standard_of_care_alternative_names   | standard_of_c
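
The list-valued columns above do not survive a CSV round trip as lists: each cell is written as the string form of the list, which is why later cells parse them back with `ast.literal_eval`. The sketch below illustrates the idea with the standard `csv` module, which stringifies non-string cells with `str()`. That this mirrors what `to_csv` does for these columns is an assumption here. It also shows JSON-encoding the cells as a possible alternative. The sample values are hypothetical.

```python
import ast
import csv
import io
import json

# Hypothetical values shaped like two of the columns above.
names = ["DrugA", "DrugB"]
alt_names = [["A-1", "Alpha"], []]  # list of lists, aligned with names

# Option 1: let the writer stringify the lists, then parse with ast.
buf = io.StringIO()
csv.writer(buf).writerow([names, alt_names])
buf.seek(0)
cells = next(csv.reader(buf))
print(cells[0])  # the cell is now a plain string, not a list
restored_repr = [ast.literal_eval(c) for c in cells]

# Option 2: encode each cell as JSON explicitly, then decode with json.loads.
buf = io.StringIO()
csv.writer(buf).writerow([json.dumps(names), json.dumps(alt_names)])
buf.seek(0)
restored_json = [json.loads(c) for c in next(csv.reader(buf))]

print(restored_repr == restored_json)
```

Either route recovers the nested lists. The JSON route has the side benefit that the file can be read by tools that do not understand Python literals.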

In [43]:
"""
Identify trials with no investigational products.

Purpose:
- Load the trial_product_breakdown.csv file.
- Parse the stringified list column "investigational_products" into real Python lists.
- Flag all trials where the parsed list is empty (i.e., no investigational product identified).
- Print summary statistics and display rows missing investigational products.

Inputs:
- CSV: cache/task_1/trial_product_breakdown.csv

Outputs:
- Console summary of how many trials lack investigational products.
- Markdown preview of example rows with missing investigational products.
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
import ast
import pandas as pd
from pathlib import Path

BASE_DIR = Path("cache/task_1")
IN_CSV = BASE_DIR / "trial_product_breakdown.csv"

# -------------------------------------------------
# RUN
# -------------------------------------------------

df = pd.read_csv(IN_CSV, dtype=str).fillna("")

def parse_listish(s: str):
    """
    Parse a stringified list like "['A', 'B']" into a Python list.
    Return [] for empty cells. If parsing fails, return the raw string
    wrapped in a single-element list.
    """
    if not isinstance(s, str):
        return []
    s = s.strip()
    if not s or s in ("[]", "[ ]"):
        return []
    try:
        val = ast.literal_eval(s)
        if isinstance(val, list):
            return val
        return [val]
    except Exception:
        return [s]

# Parse the investigational_products column into real lists
df["investigational_products_parsed"] = df["investigational_products"].apply(parse_listish)

# Flag rows with no investigational products
no_inv_mask = df["investigational_products_parsed"].apply(lambda x: len(x) == 0)

num_no_inv = int(no_inv_mask.sum())
total = len(df)

print(f"Rows with NO investigational products: {num_no_inv} / {total}")

print(
    df.loc[no_inv_mask, ["trial_hash", "investigational_products"]]
      .head(20)
      .to_markdown(index=False)
)

Rows with NO investigational products: 5 / 184
| trial_hash                           | investigational_products   |
|:-------------------------------------|:---------------------------|
| tid_4c45730f6411aa1e5a38bb1223d66988 | []                         |
| tid_67de51bf9728e056a6fb42c76e4b0212 | []                         |
| tid_8cab7b7177fcb0d10255bced8b0633ee | []                         |
| tid_bb1e0571142dde8a49976632c349593c | []                         |
| tid_e9e01f51b6680ba4f467ac191bb307c5 | []                         |


Manual checks
- tid_4c45730f6411aa1e5a38bb1223d66988
    - This trial combines three standard-of-care agents into the “DCF” regimen.
- tid_67de51bf9728e056a6fb42c76e4b0212
    - Although Yisaipu is administered in a structured way, it is an approved drug and is not being tested for regulatory approval.
- tid_8cab7b7177fcb0d10255bced8b0633ee
    - The trial studies treatment strategies (regimens, algorithms, imaging-guided regimen selection, or dosing) using only approved standard therapies.
- tid_bb1e0571142dde8a49976632c349593c
    - The trial's focus is on optimizing regimen selection (e.g., TIPy or TCbIPy) via imaging, rather than testing a new drug entity.
- tid_e9e01f51b6680ba4f467ac191bb307c5
    - All drugs are approved, marketed standard-of-care therapies that the sponsor does not own.

These are all confirmed generics or biosimilars.

In [44]:
"""
Aggregate trial-level product data by TrialTrove tt_drug_id and identify products
missing both molecular targets and mechanisms.

Inputs:
- CSV: cache/task_1/trial_product_breakdown.csv
    Contains per-trial lists of products and their tt_drug_id, targets, and mechanisms.

Process:
- Parse list-like columns from strings into Python lists.
- Aggregate across all trials, keyed by tt_drug_id, collecting:
    • drug_names
    • alternative_names
    • molecular_targets
    • product_mechanisms
    • trial_hashes where each product appears.
- Build a product-level master table (one row per tt_drug_id).
- Identify tt_drug_id entries that are missing BOTH molecular_targets
  and product_mechanisms.

Outputs:
- Product master table:
      cache/task_1/product_id_master_by_tt.csv
- Table of products missing both targets and mechanisms:
      cache/task_1/product_id_missing_targets_or_mechs.csv
- Console preview of the first 10 aggregated rows and all missing rows.
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------

import ast
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache/task_1")
IN_CSV = BASE_DIR / "trial_product_breakdown.csv"
OUT_AGG = BASE_DIR / "product_id_master_by_tt.csv"
OUT_MISSING = BASE_DIR / "product_id_missing_targets_or_mechs.csv"

# -------------------------------------------------
# HELPERS
# -------------------------------------------------

def parse_listish(x):
    """Parse a list-like string (e.g. "['a','b']") into a Python list."""
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    s = str(x).strip()
    if not s:
        return []
    try:
        v = ast.literal_eval(s)
        return v if isinstance(v, list) else []
    except Exception:
        return []

# -------------------------------------------------
# RUN
# -------------------------------------------------

# Load trial-level breakdown
df = pd.read_csv(IN_CSV, dtype=str).fillna("")

# Aggregate everything keyed by tt_drug_id
agg = {}  # tt_id -> {"names": set(), "alt_names": set(), "targets": set(), "mechs": set(), "trials": set()}

ROLE_PAIRS = [
    ("investigational_products", "investigational_products_tt_drug_id"),
    ("active_comparators", "active_comparators_tt_drug_id"),
    ("standard_of_care", "standard_of_care_tt_drug_id"),
]

for _, row in df.iterrows():
    trial_hash = str(row.get("trial_hash", "")).strip()

    for base_col, tt_col in ROLE_PAIRS:
        # aligned lists
        names_list   = parse_listish(row.get(base_col, ""))
        alts_list    = parse_listish(row.get(f"{base_col}_alternative_names", ""))
        targets_list = parse_listish(row.get(f"{base_col}_molecular_target", ""))
        mechs_list   = parse_listish(row.get(f"{base_col}_mechanism", ""))
        tt_ids       = parse_listish(row.get(tt_col, ""))

        # iterate by index over tt_ids (they define the products)
        for i, raw_tt in enumerate(tt_ids):
            tt_id = str(raw_tt).strip()
            if not tt_id:
                continue

            # init aggregate bucket if needed
            if tt_id not in agg:
                agg[tt_id] = {
                    "names": set(),
                    "alt_names": set(),
                    "targets": set(),
                    "mechs": set(),
                    "trials": set(),
                }

            # record trial hash if available
            if trial_hash:
                agg[tt_id]["trials"].add(trial_hash)

            # name
            if i < len(names_list):
                name = str(names_list[i]).strip()
                if name:
                    agg[tt_id]["names"].add(name)

            # alternative names (may be nested lists)
            if i < len(alts_list):
                alt_entry = alts_list[i]
                if isinstance(alt_entry, list):
                    for a in alt_entry:
                        a_str = str(a).strip()
                        if a_str:
                            agg[tt_id]["alt_names"].add(a_str)
                else:
                    a_str = str(alt_entry).strip()
                    if a_str:
                        agg[tt_id]["alt_names"].add(a_str)

            # target
            if i < len(targets_list):
                tgt = str(targets_list[i]).strip()
                if tgt:
                    agg[tt_id]["targets"].add(tgt)

            # mechanism
            if i < len(mechs_list):
                mech = str(mechs_list[i]).strip()
                if mech:
                    agg[tt_id]["mechs"].add(mech)

# Build aggregated DataFrame
rows_out = []
for tt_id, payload in agg.items():
    rows_out.append(
        {
            "tt_drug_id": tt_id,
            "drug_names": sorted(payload["names"]),
            "alternative_names": sorted(payload["alt_names"]),
            "molecular_targets": sorted(payload["targets"]),
            "product_mechanisms": sorted(payload["mechs"]),
            "trial_hashes": sorted(payload["trials"]),
        }
    )

grouped_df = pd.DataFrame(rows_out).sort_values("tt_drug_id")

print("Aggregated by tt_drug_id (first 10 rows):")
print(grouped_df.head(10).to_markdown(index=False))

grouped_df.to_csv(OUT_AGG, index=False)
print(f"Saved aggregated tt_drug_id table → {OUT_AGG}")

# Identify rows missing both targets AND mechanisms
missing_mask = (
    grouped_df["molecular_targets"].apply(lambda x: len(x) == 0)
    & grouped_df["product_mechanisms"].apply(lambda x: len(x) == 0)
)

missing_df = grouped_df[missing_mask].copy()

print("Rows missing molecular_targets AND product_mechanisms:")
if missing_df.empty:
    print("No missing values — every tt_drug_id has at least one target or mechanism.")
else:
    print(missing_df.to_markdown(index=False))

missing_df.to_csv(OUT_MISSING, index=False)
print(f"Saved → {OUT_MISSING}")

Aggregated by tt_drug_id (first 10 rows):
|   tt_drug_id | drug_names                           | alternative_names                                                                                                                                                                                                                                                                                                                         | molecular_targets                                                                                                                                   | product_mechanisms                                                                                                                                 | trial_hashes                                                                                                                                                                                                                                     |
|-------------:|:-----------
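
The aggregation in cell 44 can be illustrated on a toy input. This is a minimal sketch, assuming two hypothetical trials that share one tt_drug_id; the rows, drug names, and IDs below are made up for illustration and do not come from trial_product_breakdown.csv. It mirrors the idea of the cell (position-aligned lists collected into per-product sets) rather than reproducing it exactly.

```python
from collections import defaultdict

# Hypothetical per-trial rows with position-aligned lists (not real data).
toy_rows = [
    {"trial_hash": "tid_a", "tt_ids": ["101", "202"],
     "names": ["drug A", "drug B"],
     "targets": ["EGFR", ""], "mechs": ["EGFR inhibitor", ""]},
    {"trial_hash": "tid_b", "tt_ids": ["202"],
     "names": ["Drug-B"], "targets": [""], "mechs": [""]},
]


def aggregate(rows):
    """Collect names, targets, mechanisms, and trials per tt_drug_id."""
    agg = defaultdict(
        lambda: {"names": set(), "targets": set(), "mechs": set(), "trials": set()}
    )
    for row in rows:
        for i, tt_id in enumerate(row["tt_ids"]):
            bucket = agg[tt_id]
            bucket["trials"].add(row["trial_hash"])
            for key in ("names", "targets", "mechs"):
                values = row[key]
                # Lists are aligned by position; skip missing or empty entries.
                if i < len(values) and values[i].strip():
                    bucket[key].add(values[i].strip())
    return agg


toy_agg = aggregate(toy_rows)

# Products with neither a target nor a mechanism across all of their trials.
toy_missing = sorted(
    tt_id for tt_id, b in toy_agg.items() if not b["targets"] and not b["mechs"]
)
print(toy_missing)  # ['202']
```

Because each field is a set, a product seen in several trials keeps every distinct name and only counts as missing when no trial supplies a target or a mechanism.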

In [45]:
"""
Infer molecular targets and mechanisms of action for products missing both fields,
using LLM + web search on trial context.

Inputs:
- Product-level CSV of products missing targets/mechanisms:
      cache/task_1/product_id_missing_targets_or_mechs.csv
    (one row per tt_drug_id; columns include tt_drug_id, drug_names,
     alternative_names, trial_hashes, etc.)
- Trial metadata CSV:
      cache/data_preprocess/raw_trials_with_hash.csv
    (one row per trial; must include trial_hash and core text/JSON fields).

Process:
- Load the "missing products" table from disk.
- Build an index of trial_hash → full trial metadata.
- For each tt_drug_id in the missing-products table:
    • Parse trial_hashes and gather associated trials.
    • Build a prompt with all known drug names and trial context.
    • Call the LLM with web_search tools to infer:
        - molecular_target
        - mechanism
      If the target/mechanism is not publicly disclosed, both fields should be "".
- Save a per-product JSON file and a structured log.
- Maintain a rolling master JSON of all inferred product mechanisms.

Outputs:
- Per-product mechanism JSON:
      cache/task_1/product_mechanism_inference/{tt_drug_id}.json
- Per-product log JSON:
      cache/task_1/product_mechanism_inference_log/{tt_drug_id}.json
- Master mapping:
      cache/task_1/product_mechanism_inference_master.json
- Console summary of processed / skipped / error counts.
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
import ast
import re
import json
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
from services.openai_wrapper import OpenAIWrapper

# Base dir for this task's outputs
BASE_DIR = Path("cache/task_1")

# Trial metadata (shared across tasks)
RAW_TRIALS_CSV = Path("cache/data_preprocess/raw_trials_with_hash.csv")

# Missing-products table (from previous aggregation cell)
MISSING_PRODUCTS_CSV = BASE_DIR / "product_id_missing_targets_or_mechs.csv"

# Output dirs/files for product mechanism inference
PRODUCT_MECH_DIR = BASE_DIR / "product_mechanism_inference"
PRODUCT_MECH_DIR.mkdir(parents=True, exist_ok=True)

PRODUCT_MECH_LOG_DIR = BASE_DIR / "product_mechanism_inference_log"
PRODUCT_MECH_LOG_DIR.mkdir(parents=True, exist_ok=True)

MASTER_PRODUCT_MECH_PATH = BASE_DIR / "product_mechanism_inference_master.json"

MODEL = "gpt-5"
client = OpenAIWrapper()

MAX_WORKERS = 8

# -------------------------------------------------
# HELPERS
# -------------------------------------------------
def extract_json_object(text: str) -> dict:
    """Extract a JSON object from model output.

    Tries the whole text first, then the span from the first '{' to the last '}'.
    """
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}

    # Direct parse first
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except Exception:
        pass

    # Fallback: greedy match from the first '{' to the last '}'
    m = re.search(r"\{.*\}", text, re.DOTALL)
    if not m:
        return {}
    try:
        obj = json.loads(m.group(0))
        if isinstance(obj, dict):
            return obj
    except Exception:
        return {}

    return {}


def safe_parse_listish(val):
    """
    Parse list-like strings back into Python lists, if needed.
    If already a list, return as-is.
    """
    if isinstance(val, list):
        return val
    if val is None:
        return []
    s = str(val).strip()
    if not s:
        return []
    try:
        v = ast.literal_eval(s)
        if isinstance(v, list):
            return v
        return [v]
    except Exception:
        return [s]


def build_product_prompt(row: dict, trial_context: list[dict]) -> str:
    """
    Build prompt asking the LLM (with web_search) to infer
    molecular_target and mechanism for a product, based on:
      - drug names / alternative names
      - full metadata for associated trials
    """
    drug_names = safe_parse_listish(row.get("drug_names", [])) or []
    alt_names = safe_parse_listish(row.get("alternative_names", [])) or []

    # Serialize the name lists; default=str covers any non-JSON-native entries
    drug_names_json = json.dumps(drug_names, ensure_ascii=False, default=str)
    alt_names_json = json.dumps(alt_names, ensure_ascii=False, default=str)

    # Trials context JSON
    trials_json = json.dumps(trial_context, ensure_ascii=False, indent=2)

    return f"""
You are a pharmacology expert with access to web search.

You are given:
- A drug name (and aliases).
- Full metadata for one or more clinical trials in which this drug appears (JSON objects).

Your goal:
Using web search and your domain knowledge, determine:
1. The primary molecular target(s) of the drug (e.g., EGFR, VEGFR2, TNF, CD20, JAK1/2).
2. A concise, standard mechanism of action label (e.g., "EGFR inhibitor", "Anti-PD-1 antibody",
   "JAK inhibitor", "DNA-damaging cytotoxic", etc.).

Rules:
- Try searching for the drug using all known names or aliases.
- If no molecular target or mechanism of action has been publicly disclosed,
  then return empty strings for BOTH fields.

INPUT
-----
drug_name: {drug_names_json}
alternative_names: {alt_names_json}
trial_metadata:
{trials_json}

OUTPUT (JSON only)
------------------
{{
  "molecular_target": "",
  "mechanism": ""
}}
""".strip()


# Shared counters & master mapping
product_counter = {
    "processed": 0,
    "skipped_existing": 0,
    "llm_error": 0,
    "parse_error": 0,
}
product_counter_lock = threading.Lock()

product_master: dict[str, dict] = {}
product_master_lock = threading.Lock()

# Load existing master if present
if MASTER_PRODUCT_MECH_PATH.exists():
    try:
        product_master = json.loads(MASTER_PRODUCT_MECH_PATH.read_text(encoding="utf-8"))
    except Exception:
        product_master = {}

# Load trial metadata and build index
df_trials = pd.read_csv(RAW_TRIALS_CSV, dtype=str).fillna("")
trials_index: dict[str, dict] = {
    str(row["trial_hash"]).strip(): row.to_dict()
    for _, row in df_trials.iterrows()
}


def process_product(row: dict, idx: int, total: int) -> None:
    """Process one tt_drug_id: call LLM+web_search with trial context, save output & log."""
    tt_drug_id = str(row.get("tt_drug_id", "")).strip()
    if not tt_drug_id:
        print(f"[{idx}/{total}] Missing tt_drug_id, skipping")
        return

    out_fp = PRODUCT_MECH_DIR / f"{tt_drug_id}.json"
    if out_fp.exists():
        with product_counter_lock:
            product_counter["skipped_existing"] += 1
        return

    # Get associated trial hashes and build trial context list
    trial_hashes_raw = row.get("trial_hashes", [])
    trial_hashes = safe_parse_listish(trial_hashes_raw)

    trial_context = []
    for th in trial_hashes:
        th_key = str(th).strip()
        if not th_key:
            continue
        trial_row = trials_index.get(th_key)
        if trial_row:
            trial_context.append(trial_row)

    prompt = build_product_prompt(row, trial_context)

    text_response = ""
    raw_response = None
    total_cost = 0.0
    elapsed = 0.0

    # Call LLM with web_search tool
    try:
        t0 = time.perf_counter()
        res = client.query(
            prompt=prompt,
            model=MODEL,
            tools=[{"type": "web_search"}],
        )
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)

    except Exception as e:
        print(f"[{idx}/{total}] LLM error for tt_drug_id={tt_drug_id}: {e}")
        with product_counter_lock:
            product_counter["llm_error"] += 1
        return

    mech_obj = extract_json_object(text_response)

    # Expect a dict with the two keys
    if not isinstance(mech_obj, dict) or not mech_obj:
        print(f"[{idx}/{total}] JSON parse/validity error tt_drug_id={tt_drug_id}, raw={text_response!r}")
        with product_counter_lock:
            product_counter["parse_error"] += 1
        return

    molecular_target = str(mech_obj.get("molecular_target", "") or "").strip()
    mechanism = str(mech_obj.get("mechanism", "") or "").strip()

    mapped = {
        "tt_drug_id": tt_drug_id,
        "drug_names": safe_parse_listish(row.get("drug_names", [])),
        "alternative_names": safe_parse_listish(row.get("alternative_names", [])),
        "trial_hashes": trial_hashes,
        "molecular_target": molecular_target,
        "mechanism": mechanism,
        "source": "llm_web_search",
    }

    # Save per-product JSON
    out_fp.write_text(json.dumps(mapped, ensure_ascii=False, indent=2), encoding="utf-8")

    # Log entry
    log_payload = {
        "tt_drug_id": tt_drug_id,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": json.dumps(mapped, ensure_ascii=False, indent=2),
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    (PRODUCT_MECH_LOG_DIR / f"{tt_drug_id}.json").write_text(
        json.dumps(log_payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    # Update master
    with product_master_lock:
        product_master[tt_drug_id] = mapped
        MASTER_PRODUCT_MECH_PATH.write_text(
            json.dumps(product_master, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    with product_counter_lock:
        product_counter["processed"] += 1
        if product_counter["processed"] % 50 == 0:
            print(f"Progress: processed {product_counter['processed']} products...")


# -------------------------------------------------
# RUN
# -------------------------------------------------
missing_df = pd.read_csv(MISSING_PRODUCTS_CSV, dtype=str).fillna("")
missing_rows = missing_df.to_dict(orient="records")
total_missing = len(missing_rows)
print(f"Loaded {total_missing} products missing targets/mechanisms from {MISSING_PRODUCTS_CSV}")

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
    futures = {
        ex.submit(process_product, row, idx, total_missing): row.get("tt_drug_id")
        for idx, row in enumerate(missing_rows, start=1)
    }
    for fut in as_completed(futures):
        tid = futures[fut]
        try:
            fut.result()
        except Exception as e:
            print(f"Worker error tt_drug_id={tid}: {e}")

print(
    f"Product mechanism inference complete. "
    f"processed={product_counter['processed']}, "
    f"skipped={product_counter['skipped_existing']}, "
    f"llm_error={product_counter['llm_error']}, "
    f"parse_error={product_counter['parse_error']}"
)
print(f"Per-product directory: {PRODUCT_MECH_DIR}")
print(f"Log directory:        {PRODUCT_MECH_LOG_DIR}")
print(f"Master file:          {MASTER_PRODUCT_MECH_PATH}")

Loaded 3 products missing targets/mechanisms from cache/task_1/product_id_missing_targets_or_mechs.csv
Product mechanism inference complete. processed=0, skipped=3, llm_error=0, parse_error=0
Per-product directory: cache/task_1/product_mechanism_inference
Log directory:        cache/task_1/product_mechanism_inference_log
Master file:          cache/task_1/product_mechanism_inference_master.json
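
The regex fallback in extract_json_object matches greedily from the first `{` to the last `}`, so a reply that contains extra braces after the JSON payload can fail to parse. One possible alternative, shown here only as a sketch and not as the notebook's method, is to scan with the standard library's `json.JSONDecoder.raw_decode`, which decodes one JSON value starting at a given index and ignores trailing text. The helper name `first_json_object` and the sample reply are hypothetical.

```python
import json


def first_json_object(text: str) -> dict:
    """Return the first decodable JSON object in text, or {} if none is found."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            # raw_decode parses one JSON value starting at pos and
            # returns (value, end_index); trailing text is ignored.
            obj, _end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        # Not valid JSON at this brace; try the next one.
        pos = text.find("{", pos + 1)
    return {}


# Hypothetical model reply with stray braces after the payload.
reply = 'Result: {"molecular_target": "TNF", "mechanism": "TNF inhibitor"} (see {notes})'
print(first_json_object(reply))
# {'molecular_target': 'TNF', 'mechanism': 'TNF inhibitor'}
```

On this reply the greedy regex span would end at the brace in `{notes}` and fail `json.loads`, whereas the scan stops at the end of the first complete object.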


In [46]:
"""
Build a did-keyed master product dictionary by merging trial-level product info
with (optional) LLM-inferred mechanisms.

Inputs:
- Trial product breakdown CSV:
      cache/task_1/trial_product_breakdown.csv
  Contains, per trial_hash:
      • investigational_products / active_comparators / standard_of_care
      • *_alternative_names
      • *_molecular_target
      • *_mechanism
      • *_tt_drug_id
- Optional LLM mechanism master:
      cache/task_1/product_mechanism_inference_master.json
  Maps tt_drug_id → inferred molecular_target and mechanism.

Process:
- For each role (investigational_products, active_comparators, standard_of_care):
    • Parse list-like columns into Python lists.
    • Aggregate by tt_drug_id:
        - collect all names, alt_names, targets, mechanisms, trial_hashes.
        - fill missing targets/mechanisms from product_mechanism_inference_master.json
    • For entries without tt_drug_id but with target/mechanism:
        - create synthetic product entries keyed by (role, name, target, mechanism).
- Generate a stable deterministic "did_*" ID:
    • For known tt_drug_id: hash of tt_drug_id.
    • For unknown products: hash of composite key (role + data).

Outputs:
- did-keyed JSON:
      cache/task_1/product_id_master_by_did.json
  Structure:
      {
        "did_<hash>": {
          "did": "did_<hash>",
          "tt_drug_id": "<tt_id or ''>",
          "drug_names": [...],
          "alternative_names": [...],
          "molecular_targets": [...],
          "product_mechanisms": [...],
          "trial_hashes": [...]
        },
        ...
      }
"""

# ----------------------------------------
# CONFIG
# ----------------------------------------
import ast
import json
import hashlib
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache/task_1")
IN_BREAKDOWN_CSV = BASE_DIR / "trial_product_breakdown.csv"
MASTER_PRODUCT_MECH_PATH = BASE_DIR / "product_mechanism_inference_master.json"
OUT_JSON = BASE_DIR / "product_id_master_by_did.json"

# ----------------------------------------
# HELPERS
# ----------------------------------------
def parse_listish(x):
    """Parse a list-like string (e.g. "['a','b']") into a Python list."""
    if isinstance(x, list):
        return x
    if x is None:
        return []
    s = str(x).strip()
    if not s or s in ("[]", "[ ]"):
        return []
    try:
        v = ast.literal_eval(s)
        if isinstance(v, list):
            return v
        return [v]
    except Exception:
        return [s]


def pad_to_length(lst, n):
    """Pad list with empty strings so len(lst) >= n."""
    lst = list(lst)
    while len(lst) < n:
        lst.append("")
    return lst


def make_did_from_tt(tt_drug_id: str) -> str:
    """Deterministic drug hash ID based on tt_drug_id (for known IDs)."""
    h = hashlib.md5(tt_drug_id.encode("utf-8")).hexdigest()
    return f"did_{h}"


def make_did_for_unknown(key: str) -> str:
    """
    Deterministic drug hash ID for products without tt_drug_id.
    Key can be any composite string (e.g., role + name + target + mechanism).
    """
    h = hashlib.md5(key.encode("utf-8")).hexdigest()
    return f"did_{h}"


# ----------------------------------------
# RUN
# ----------------------------------------
# Load inputs
df = pd.read_csv(IN_BREAKDOWN_CSV, dtype=str).fillna("")
print(f"Loaded trial breakdown: {IN_BREAKDOWN_CSV}, shape={df.shape}")

if MASTER_PRODUCT_MECH_PATH.exists():
    product_master = json.loads(MASTER_PRODUCT_MECH_PATH.read_text(encoding="utf-8"))
else:
    product_master = {}
    print(f"No product master mech file found at {MASTER_PRODUCT_MECH_PATH}")

# role → (base_name_col, tt_id_col, target_col, mech_col)
ROLE_SPECS = [
    (
        "investigational_products",
        "investigational_products_tt_drug_id",
        "investigational_products_molecular_target",
        "investigational_products_mechanism",
    ),
    (
        "active_comparators",
        "active_comparators_tt_drug_id",
        "active_comparators_molecular_target",
        "active_comparators_mechanism",
    ),
    (
        "standard_of_care",
        "standard_of_care_tt_drug_id",
        "standard_of_care_molecular_target",
        "standard_of_care_mechanism",
    ),
]

# Aggregate per tt_drug_id and per "unknown but has mech/target"
agg_tt = {}       # tt_id -> {...}
agg_unknown = {}  # composite_key -> {...}

for _, row in df.iterrows():
    trial_hash = str(row.get("trial_hash", "")).strip()

    for base_col, tt_col, tgt_col, mech_col in ROLE_SPECS:
        # Skip if any required column is missing
        if tt_col not in df.columns or tgt_col not in df.columns or mech_col not in df.columns:
            continue

        # Base name + alt-name columns
        names_list = parse_listish(row.get(base_col, ""))
        alt_list   = parse_listish(row.get(f"{base_col}_alternative_names", ""))

        tt_ids   = parse_listish(row.get(tt_col, ""))
        targets  = parse_listish(row.get(tgt_col, ""))
        mechs    = parse_listish(row.get(mech_col, ""))

        # Align target/mech lists to tt_ids length
        targets = pad_to_length(targets, len(tt_ids))
        mechs   = pad_to_length(mechs, len(tt_ids))

        for i, raw_tt in enumerate(tt_ids):
            tt_id = str(raw_tt).strip()

            # Name (by position if available)
            name = ""
            if i < len(names_list):
                name = str(names_list[i]).strip()

            # Alternative names (can be list-of-lists or flat)
            alt_names_for_this = []
            if i < len(alt_list):
                alt_entry = alt_list[i]
                if isinstance(alt_entry, list):
                    for a in alt_entry:
                        a_str = str(a).strip()
                        if a_str:
                            alt_names_for_this.append(a_str)
                else:
                    a_str = str(alt_entry).strip()
                    if a_str:
                        alt_names_for_this.append(a_str)

            # Existing target/mechanism from CSV
            csv_target = str(targets[i]).strip()
            csv_mech   = str(mechs[i]).strip()

            # Case 1: Have a tt_drug_id → normal aggregation
            if tt_id:
                if tt_id not in agg_tt:
                    agg_tt[tt_id] = {
                        "names": set(),
                        "alt_names": set(),
                        "targets": set(),
                        "mechs": set(),
                        "trials": set(),
                    }

                # Record trial hash
                if trial_hash:
                    agg_tt[tt_id]["trials"].add(trial_hash)

                # Names
                if name:
                    agg_tt[tt_id]["names"].add(name)

                for a_str in alt_names_for_this:
                    agg_tt[tt_id]["alt_names"].add(a_str)

                # LLM-inferred target/mechanism (if available)
                info = product_master.get(tt_id) or {}
                inferred_target = str(info.get("molecular_target", "") or "").strip()
                inferred_mech   = str(info.get("mechanism", "") or "").strip()

                # Final chosen values for this (trial, index, tt_id)
                final_target = csv_target or inferred_target
                final_mech   = csv_mech   or inferred_mech

                if final_target:
                    agg_tt[tt_id]["targets"].add(final_target)
                if final_mech:
                    agg_tt[tt_id]["mechs"].add(final_mech)

            # Case 2: NO tt_drug_id, but we have target or mechanism
            # → create a synthetic product entry with empty tt_drug_id
            else:
                # Only create a synthetic entry when there is mechanistic info;
                # rows with just a name (or nothing at all) are skipped.
                if not (csv_target or csv_mech):
                    continue

                # Build a composite key to deduplicate unknown products
                composite_key = f"{base_col}||{name}||{csv_target}||{csv_mech}"

                if composite_key not in agg_unknown:
                    agg_unknown[composite_key] = {
                        "names": set(),
                        "alt_names": set(),
                        "targets": set(),
                        "mechs": set(),
                        "trials": set(),
                    }

                if trial_hash:
                    agg_unknown[composite_key]["trials"].add(trial_hash)

                if name:
                    agg_unknown[composite_key]["names"].add(name)

                for a_str in alt_names_for_this:
                    agg_unknown[composite_key]["alt_names"].add(a_str)

                if csv_target:
                    agg_unknown[composite_key]["targets"].add(csv_target)
                if csv_mech:
                    agg_unknown[composite_key]["mechs"].add(csv_mech)

print(f"Aggregated {len(agg_tt)} distinct tt_drug_id entries.")
print(f"Aggregated {len(agg_unknown)} products without tt_drug_id but with target/mechanism.")

# Build did-keyed JSON structure
drug_master_by_did = {}

# 1) Entries with real tt_drug_id
for tt_id, payload in agg_tt.items():
    did = make_did_from_tt(tt_id)
    drug_master_by_did[did] = {
        "did": did,
        "tt_drug_id": tt_id,
        "drug_names": sorted(payload["names"]),
        "alternative_names": sorted(payload["alt_names"]),
        "molecular_targets": sorted(payload["targets"]),
        "product_mechanisms": sorted(payload["mechs"]),
        "trial_hashes": sorted(payload["trials"]),
    }

# 2) Entries without tt_drug_id (tt_drug_id = "")
for composite_key, payload in agg_unknown.items():
    did = make_did_for_unknown(composite_key)
    drug_master_by_did[did] = {
        "did": did,
        "tt_drug_id": "",  # synthetic entries have no TrialTrove drug ID
        "drug_names": sorted(payload["names"]),
        "alternative_names": sorted(payload["alt_names"]),
        "molecular_targets": sorted(payload["targets"]),
        "product_mechanisms": sorted(payload["mechs"]),
        "trial_hashes": sorted(payload["trials"]),
    }

# Save JSON
OUT_JSON.parent.mkdir(parents=True, exist_ok=True)
OUT_JSON.write_text(
    json.dumps(drug_master_by_did, ensure_ascii=False, indent=2),
    encoding="utf-8",
)
print(f"Saved did-keyed drug master JSON → {OUT_JSON}")
print(f"Total drugs: {len(drug_master_by_did)}")

Loaded trial breakdown: cache/task_1/trial_product_breakdown.csv, shape=(184, 23)
Aggregated 114 distinct tt_drug_id entries.
Aggregated 15 products without tt_drug_id but with target/mechanism.
Saved did-keyed drug master JSON → cache/task_1/product_id_master_by_did.json
Total drugs: 129
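
The did scheme used in cell 46 can be checked in isolation. The sketch below assumes the same construction as make_did_from_tt and make_did_for_unknown (the prefix `did_` followed by the MD5 hex digest of the key); the helper name `make_did` and both input keys are hypothetical examples, not values from the data.

```python
import hashlib


def make_did(key: str) -> str:
    """Build 'did_' + MD5 hex digest of the key, as in the cell above."""
    return "did_" + hashlib.md5(key.encode("utf-8")).hexdigest()


# Hypothetical keys for illustration only.
known = make_did("12345")  # a product with a tt_drug_id
unknown = make_did("investigational_products||drug X||EGFR||EGFR inhibitor")

# The same key always yields the same ID, so reruns do not reshuffle IDs.
assert known == make_did("12345")
assert known != unknown

print(len(known))  # 36: 4-character prefix plus 32 hex characters
```

Since the composite key for products without a tt_drug_id starts with the role column, the same unnamed product appearing under two roles would receive two different did values under this scheme.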


#### Task 2

Identify whether the drugs are innovative, generic, or biosimilar.

In [47]:
"""
Classify investigational drugs in each trial as Innovative, Generic, or Biosimilar.

Inputs:
- cache/data_preprocess/raw_trials_with_hash.csv
    • Per-trial metadata (must include trial_hash and RELEVANT_COLS)
- cache/task_1/trial_product_breakdown.csv
    • Per-trial drug-role breakdown (investigational_products, *_tt_drug_id, etc.)
- cache/task_1/product_id_master_by_did.json
    • did-keyed product master including:
        - tt_drug_id
        - drug_names
        - alternative_names
        - molecular_targets
        - product_mechanisms

Process:
- Build tt_drug_id → product metadata from product_id_master_by_did.json.
- Merge trial metadata with trial_product_breakdown on trial_hash.
- For each trial that has investigational products:
    • Build per-drug context (names, alt names, targets, mechanisms, tt_drug_id).
    • Build an LLM prompt with trial text + all breakdown columns.
    • Ask the LLM to classify each investigational product as:
          "Innovative", "Generic", or "Biosimilar",
      and provide a one-sentence explanation + tt_drug_id.
    • Validate JSON, coverage, and required fields.
    • Save per-trial classification JSON, a log JSON, and update a master JSON.

Outputs (to cache/task_2/):
- trial_investigational_drugs_classifications/{trial_hash}.json
- trial_investigational_drugs_classifications_log/{trial_hash}.json
- trial_investigational_drugs_classifications_master.json
- Console summary of processed / skipped / error counts.
"""


import json
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import ast
import re

import pandas as pd
from services.openai_wrapper import OpenAIWrapper

# -------------------------------------------------
# CONFIG
# -------------------------------------------------

IN_BASE_DIR_PRE = Path("cache/data_preprocess")
IN_BASE_DIR_TASK1 = Path("cache/task_1")

TRIALS_WITH_HASH_CSV  = IN_BASE_DIR_PRE / "raw_trials_with_hash.csv"
PRODUCT_BREAKDOWN_CSV = IN_BASE_DIR_TASK1 / "trial_product_breakdown.csv"
PRODUCT_BY_DID_JSON   = IN_BASE_DIR_TASK1 / "product_id_master_by_did.json"

OUT_BASE_DIR = Path("cache/task_2")
OUT_BASE_DIR.mkdir(parents=True, exist_ok=True)

INNOV_DIR = OUT_BASE_DIR / "trial_investigational_drugs_classifications"
INNOV_DIR.mkdir(parents=True, exist_ok=True)

INNOV_LOG_DIR = OUT_BASE_DIR / "trial_investigational_drugs_classifications_log"
INNOV_LOG_DIR.mkdir(parents=True, exist_ok=True)

MASTER_INNOV_PATH = OUT_BASE_DIR / "trial_investigational_drugs_classifications_master.json"

RELEVANT_COLS = [
    "title",
    "objective",
    "outcome_details",
    "treatment_plan",
    "notes_json",
    "results_json",
    "primary_drugs_tested_json",
    "other_drugs_tested_json",
    "therapeutic_areas_json",
    "bmt_other_drugs_tested_json",
    "bmt_primary_drugs_tested_json",
    "ct_gov_mesh_terms_json",
]

MAX_WORKERS_INNOV = 8

MODEL = "gpt-5"
client = OpenAIWrapper()

# -------------------------------------------------
# HELPERS
# -------------------------------------------------
def load_master_innov() -> dict:
    """Load the master innovation classification JSON, or return empty dict if missing/invalid."""
    if not MASTER_INNOV_PATH.exists():
        return {}
    try:
        return json.loads(MASTER_INNOV_PATH.read_text(encoding="utf-8"))
    except Exception:
        return {}


def extract_json_object(text: str) -> dict:
    """Extract a JSON object from model output text.

    Tries the whole text first, then the span from the first '{' to the last '}'.
    """
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}

    # Direct parse first
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except Exception:
        pass

    # Fallback: greedy match from the first '{' to the last '}'
    m = re.search(r"\{.*\}", text, re.DOTALL)
    if not m:
        return {}
    try:
        obj = json.loads(m.group(0))
        if isinstance(obj, dict):
            return obj
    except Exception:
        return {}

    return {}


def parse_listish(s):
    """
    Parse a stringified list like "['A', 'B']" into a Python list.
    Empty or missing cells return []. A non-list literal, or a string that
    cannot be parsed, is wrapped as a single-element list.
    """
    if isinstance(s, list):
        return s
    if s is None:
        return []
    s = str(s).strip()
    if not s:
        return []
    # Common empty-list cases
    if s in ("[]", "[ ]"):
        return []
    try:
        val = ast.literal_eval(s)
        if isinstance(val, list):
            return val
        # If it's something else, treat as a single non-empty token
        return [val]
    except Exception:
        # Fallback: treat non-empty string as a single element
        return [s]


def pad_to_length(lst, n):
    """Pad list with empty strings so len(lst) >= n."""
    lst = list(lst)
    while len(lst) < n:
        lst.append("")
    return lst


def build_innovation_prompt(trial_payload: dict, drug_contexts: list[dict]) -> str:
    """
    Build prompt to classify each investigational product as
    Innovative / Generic / Biosimilar, with one-sentence explanation,
    using extra context about each drug (names, alt names, targets, mechanisms).
    """
    payload_json = json.dumps(trial_payload, ensure_ascii=False, indent=2)
    drugs_json   = json.dumps(drug_contexts, ensure_ascii=False, indent=2)

    return f"""
You are a clinical trial design and drug development expert.

You are given:
1) Structured information about a clinical trial (title, objective, results, etc.).
2) A list of investigational products used in the trial, with extra metadata for each.
   Each investigational drug entry has:
   - "name": the canonical investigational product name in this trial
   - "tt_drug_id": the TrialTrove/PharmaProjects drugId as a string (if known)
   - "drug_names": list of names for this drug
   - "alternative_names": list of alternative or synonym names
   - "molecular_targets": list of known molecular targets
   - "product_mechanisms": list of known mechanisms of action

Your task: For EACH investigational product (by its "name"), classify whether it is:
- "Innovative"
- "Generic"
- "Biosimilar"

and provide:
- a one-sentence concise explanation for your classification
- the tt_drug_id (string; use "" if unknown).

DEFINITIONS / GUIDANCE
----------------------

Innovative:
- A novel or proprietary drug (new or sponsor-specific product).
- New mechanism of action OR new molecular entity OR clearly the sponsor's lead product.
- Often associated with efficacy or superiority language:
  - "evaluate efficacy", "vs placebo", "improve outcomes", etc.
- Not a copy of an already-approved product.

Generic:
- A small-molecule copy of an already-approved branded drug.
- Same active ingredient, strength, dosage form, and route.
- Often explicitly described as generic or equivalent.

Biosimilar:
- A biologic product that is highly similar to an already-approved reference biologic.
- Same target and mechanism as a branded biologic.
- Strong clues:
  - "equivalence", "non-inferiority", "no clinically meaningful differences",
  - direct comparison to a specific branded reference biologic.

You MUST choose ONE of the three labels ("Innovative", "Generic", "Biosimilar") for each drug.
If you are uncertain, you may say so in the one-sentence explanation, but still pick a label.

Use all available information:
- Trial text and design
- Investigational vs comparator roles
- Known targets/mechanisms from the drug metadata.

OUTPUT FORMAT (IMPORTANT)
-------------------------

Return ONLY a valid JSON object, with:
- KEYS   = exactly the investigational product "name" values given below
- VALUES = an object with exactly three fields:
    - "classification": one of "Innovative", "Generic", "Biosimilar"
    - "explanation": a single, concise sentence explaining your reasoning
    - "tt_drug_id": the string tt_drug_id for this drug ("" if unknown)

Example output:
{{
  "DrugA": {{
    "classification": "Innovative",
    "explanation": "DrugA is a novel monoclonal antibody targeting a new receptor and is the sponsor's lead product.",
    "tt_drug_id": "123456"
  }},
  "DrugB": {{
    "classification": "Biosimilar",
    "explanation": "DrugB is tested for equivalence compared to the branded biologic with the same active ingredient.",
    "tt_drug_id": "789012"
  }}
}}

TRIAL PAYLOAD (includes trial text and all drug-role breakdown columns):
{payload_json}

INVESTIGATIONAL DRUG CONTEXTS (you MUST classify EACH by its 'name' key):
{drugs_json}
""".strip()


# Shared master + counters
master_innov = load_master_innov()
master_lock = threading.Lock()

innov_counter = {
    "processed": 0,
    "skipped_existing": 0,
    "llm_error": 0,
    "parse_error": 0,
    "coverage_error": 0,
}
counter_lock = threading.Lock()


def process_innov_row(row: dict, idx: int, total: int, breakdown_cols: list[str], tt_to_drug_meta: dict) -> None:
    """Process a single trial with investigational products and classify them."""
    trial_hash = str(row.get("trial_hash", "")).strip()
    if not trial_hash:
        print(f"⚠️ [{idx}/{total}] Missing trial_hash, skipping")
        return

    # Names as used in the trial
    investigational_products = row.get("investigational_products_parsed") or []
    investigational_products = [str(x).strip() for x in investigational_products if str(x).strip()]

    if not investigational_products:
        # Shouldn't happen due to filtering, but be safe
        return

    out_fp = INNOV_DIR / f"{trial_hash}.json"
    if out_fp.exists():
        with counter_lock:
            innov_counter["skipped_existing"] += 1
        return

    # ------------------------------------
    # Build per-drug context from did JSON
    # ------------------------------------
    inv_tt_raw = row.get("investigational_products_tt_drug_id", "")
    inv_tt_ids = parse_listish(inv_tt_raw)
    inv_tt_ids = pad_to_length(inv_tt_ids, len(investigational_products))

    drug_contexts = []
    for i, name in enumerate(investigational_products):
        tt_id = str(inv_tt_ids[i]).strip() if i < len(inv_tt_ids) else ""
        meta = tt_to_drug_meta.get(tt_id, {}) if tt_id else {}

        drug_contexts.append(
            {
                "name": name,
                "tt_drug_id": tt_id,
                "drug_names": meta.get("drug_names", []),
                "alternative_names": meta.get("alternative_names", []),
                "molecular_targets": meta.get("molecular_targets", []),
                "product_mechanisms": meta.get("product_mechanisms", []),
            }
        )

    # ------------------------------------
    # Build payload from selected columns
    # ------------------------------------
    trial_payload = {"trial_hash": trial_hash}

    # 1) Trial-level textual fields from raw_trials_with_hash.csv
    for col in RELEVANT_COLS:
        trial_payload[col] = row.get(col, "")

    # 2) ALL columns from trial_product_breakdown.csv
    for col in breakdown_cols:
        trial_payload[col] = row.get(col, "")

    prompt = build_innovation_prompt(trial_payload, drug_contexts)

    token = trial_hash
    hash_id = trial_hash

    text_response = ""
    raw_response = None
    total_cost = 0.0
    elapsed = 0.0

    # Call LLM
    try:
        t0 = time.perf_counter()
        res = client.query(prompt=prompt, model=MODEL)
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)
    except Exception as e:
        print(f"⚠️ [{idx}/{total}] LLM error for trial_hash={trial_hash}: {e}")
        with counter_lock:
            innov_counter["llm_error"] += 1
        return

    # Parse JSON
    classifications = extract_json_object(text_response)

    if not isinstance(classifications, dict) or not classifications:
        print(f"⚠️ [{idx}/{total}] JSON parse error trial_hash={trial_hash}, raw={text_response!r}")
        with counter_lock:
            innov_counter["parse_error"] += 1
        return

    # Check coverage: every investigational product must be present as a key
    missing = [d for d in investigational_products if d not in classifications]
    if missing:
        print(
            f"⚠️ [{idx}/{total}] Coverage error for trial_hash={trial_hash}: "
            f"missing classifications for {missing}"
        )
        with counter_lock:
            innov_counter["coverage_error"] += 1
        # DO NOT save this trial so it can be re-run next time
        return

    # Sanity check: each value has classification, explanation, tt_drug_id
    for d in investigational_products:
        meta = classifications.get(d, {})
        if not isinstance(meta, dict):
            print(f"⚠️ [{idx}/{total}] Invalid meta for {d} in trial_hash={trial_hash}")
            with counter_lock:
                innov_counter["parse_error"] += 1
            return
        if ("classification" not in meta) or ("explanation" not in meta) or ("tt_drug_id" not in meta):
            print(f"⚠️ [{idx}/{total}] Missing fields for {d} in trial_hash={trial_hash}")
            with counter_lock:
                innov_counter["parse_error"] += 1
            return

    mapped = {
        "trial_hash": trial_hash,
        "investigational_products": investigational_products,
        "classifications": classifications,
        "source": "llm",
    }

    # Save per-trial JSON
    out_fp.write_text(json.dumps(mapped, ensure_ascii=False, indent=2), encoding="utf-8")

    # Log entry
    log_payload = {
        "token": token,
        "hash_id": hash_id,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": json.dumps(mapped, ensure_ascii=False, indent=2),
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    (INNOV_LOG_DIR / f"{hash_id}.json").write_text(
        json.dumps(log_payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    # Update master
    with master_lock:
        master_innov[trial_hash] = mapped
        MASTER_INNOV_PATH.write_text(
            json.dumps(master_innov, ensure_ascii=False, indent=2),
            encoding="utf-8"
        )

    with counter_lock:
        innov_counter["processed"] += 1
        if innov_counter["processed"] % 50 == 0:
            print(f"Progress: processed {innov_counter['processed']} trials for innovation status...")


# -------------------------------------------------
# RUN
# -------------------------------------------------
# Load per-drug metadata (product_id_master_by_did.json) and build tt_drug_id → metadata
if PRODUCT_BY_DID_JSON.exists():
    product_by_did = json.loads(PRODUCT_BY_DID_JSON.read_text(encoding="utf-8"))
else:
    product_by_did = {}
    print(f"⚠️ No drug master JSON found at {PRODUCT_BY_DID_JSON}")

tt_to_drug_meta: dict[str, dict] = {}
for did, rec in product_by_did.items():
    tt = str(rec.get("tt_drug_id", "")).strip()
    if tt and tt not in tt_to_drug_meta:
        tt_to_drug_meta[tt] = rec

# Load breakdown (investigational products + all drug-role cols)
df_breakdown = pd.read_csv(PRODUCT_BREAKDOWN_CSV, dtype=str).fillna("")

df_breakdown["investigational_products_parsed"] = df_breakdown["investigational_products"].apply(parse_listish)
mask_has_inv = df_breakdown["investigational_products_parsed"].apply(lambda x: len(x) > 0)

# Restrict to rows with investigational products
df_breakdown_sub = df_breakdown.loc[mask_has_inv].copy()

# All columns from trial_product_breakdown.csv except trial_hash (which is already separate)
BREAKDOWN_COLS = [c for c in df_breakdown_sub.columns if c != "trial_hash"]

# Load raw trials (for RELEVANT_COLS)
df_trials = pd.read_csv(TRIALS_WITH_HASH_CSV, dtype=str).fillna("")

# Merge on trial_hash; keep all breakdown columns + investigational_products_parsed + RELEVANT_COLS
df_merged = df_breakdown_sub.merge(
    df_trials[["trial_hash"] + RELEVANT_COLS],
    on="trial_hash",
    how="left",
)

innov_rows = df_merged.to_dict(orient="records")
total_innov = len(innov_rows)
print(f"Loaded {total_innov} trials with investigational products for innovation-status classification.")

# Run concurrently
with ThreadPoolExecutor(max_workers=MAX_WORKERS_INNOV) as ex:
    futures = {
        ex.submit(process_innov_row, row, idx, total_innov, BREAKDOWN_COLS, tt_to_drug_meta): row.get("trial_hash")
        for idx, row in enumerate(innov_rows, start=1)
    }
    for fut in as_completed(futures):
        th = futures[fut]
        try:
            fut.result()
        except Exception as e:
            print(f"⚠️ Worker error (innovation) trial_hash={th}: {e}")

print(
    f"Trial investigational-drug innovation classification complete. "
    f"processed={innov_counter['processed']}, "
    f"skipped={innov_counter['skipped_existing']}, "
    f"llm_error={innov_counter['llm_error']}, "
    f"parse_error={innov_counter['parse_error']}, "
    f"coverage_error={innov_counter['coverage_error']}"
)
print(f"Classifications directory: {INNOV_DIR}")
print(f"Log directory:             {INNOV_LOG_DIR}")
print(f"Master classifications:    {MASTER_INNOV_PATH}")

Loaded 179 trials with investigational products for innovation-status classification.
Progress: processed 50 trials for innovation status...
Progress: processed 100 trials for innovation status...
Progress: processed 150 trials for innovation status...
Trial investigational-drug innovation classification complete. processed=179, skipped=0, llm_error=0, parse_error=0, coverage_error=0
Classifications directory: cache/task_2/trial_investigational_drugs_classifications
Log directory:             cache/task_2/trial_investigational_drugs_classifications_log
Master classifications:    cache/task_2/trial_investigational_drugs_classifications_master.json
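
The classification cell accepts a model response only if it parses to a JSON object, contains a key for every investigational product, and carries the three expected fields per product. The sketch below replays those checks on a mocked response so the failure modes are easy to see in isolation. `validate_classifications`, `REQUIRED_FIELDS`, and the mock text are illustrative names and data, not part of the notebook; the extraction helper follows the same direct-parse-then-regex strategy as the cell above.

```python
import json
import re

REQUIRED_FIELDS = {"classification", "explanation", "tt_drug_id"}


def extract_json_object(text: str) -> dict:
    """Direct parse first, then fall back to the span from the first '{' to the last '}'."""
    text = (text or "").strip()
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except json.JSONDecodeError:
        pass
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        return {}
    try:
        obj = json.loads(match.group(0))
    except json.JSONDecodeError:
        return {}
    return obj if isinstance(obj, dict) else {}


def validate_classifications(classifications: dict, products: list[str]) -> list[str]:
    """Return a list of problems; an empty list means the response is usable."""
    problems = []
    for name in products:
        entry = classifications.get(name)
        if not isinstance(entry, dict):
            problems.append(f"missing entry for {name!r}")
        elif not REQUIRED_FIELDS.issubset(entry):
            problems.append(f"missing fields for {name!r}")
    return problems


# Mocked model output: valid JSON wrapped in chatty prose (made-up data).
mock_response = (
    'Sure, here it is: {"DrugA": {"classification": "Innovative", '
    '"explanation": "Novel antibody.", "tt_drug_id": "123456"}} Hope this helps.'
)

parsed = extract_json_object(mock_response)
problems = validate_classifications(parsed, ["DrugA", "DrugB"])
print(problems)  # ["missing entry for 'DrugB'"]
```

Returning a list of problems rather than a boolean keeps the coverage error and the missing-field error distinguishable, which mirrors the separate `coverage_error` and `parse_error` counters in the cell.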


In [51]:
"""
Summarize total LLM usage cost for previous cell by reading all per-trial log files.

Inputs:
- Directory: cache/task_2/trial_investigational_drugs_classifications_log
    Each log JSON contains:
        • total_cost (float)
        • other metadata (prompt, raw response, timing, etc.)

Process:
- Load each log file and extract its total_cost value.
- Aggregate total cost, count entries, and compute average cost per trial.
- Sort trials by cost to identify the most expensive prompts.

Outputs:
- Console summary including:
    • Total cost
    • Number of logged trials
    • Average cost per trial
    • Top 10 highest-cost trials (filename + cost)
"""

import json
from pathlib import Path

LOG_DIR = Path("cache/task_2/trial_investigational_drugs_classifications_log")

total_cost = 0.0
num_entries = 0
costs = []

for fp in LOG_DIR.glob("*.json"):
    try:
        log = json.loads(fp.read_text(encoding="utf-8"))
        c = float(log.get("total_cost") or 0.0)
        total_cost += c
        costs.append((fp.name, c))
        num_entries += 1
    except Exception as e:
        print(f"Error reading {fp.name}: {e}")

# Sort descending by cost
costs_sorted = sorted(costs, key=lambda x: x[1], reverse=True)

print("========== LLM COST SUMMARY ==========")
print(f"Total LLM cost:             ${total_cost:,.4f}")
print(f"Number of logged trials:     {num_entries}")
if num_entries > 0:
    print(f"Average cost per trial:      ${total_cost / num_entries:,.4f}")
print("")

print("Top 10 most expensive trials:")
for name, c in costs_sorted[:10]:
    print(f"  {name}: ${c:,.4f}")

print("========================================")

Total LLM cost:             $1.8837
Number of logged trials:     179
Average cost per trial:      $0.0105

Top 10 most expensive trials:
  tid_9727cefa81bf0a9c341273bce42d3346.json: $0.0330
  tid_8b4d60a5fddc078962af34399d7e342c.json: $0.0301
  tid_99fac3ebe48aad5ebc1077142f61d5eb.json: $0.0218
  tid_28a767e788d4d9a4e65b3c10d10585c2.json: $0.0215
  tid_43635104c2d64be16c8882a500dd5181.json: $0.0205
  tid_456995e1db18e20bcddc2bdf2938fac3.json: $0.0167
  tid_b29013cdbc706b95776d47be1d6e98e6.json: $0.0167
  tid_a50324f4d36f5cc93b795ec7f8b7005b.json: $0.0165
  tid_69d3a93a71b9ed6021c817d9afa127fa.json: $0.0163
  tid_f48b01eb433692a187251a3a10fa9923.json: $0.0161
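
The reported average ($1.8837 / 179 ≈ $0.0105) sits well below the most expensive trial ($0.0330), which suggests a right-skewed cost distribution, so a median reported alongside the mean may describe a typical trial better. A minimal sketch, assuming each log file carries a `total_cost` field as in the cell above; `summarize_costs` and the temporary demo files are hypothetical and only stand in for the real log directory.

```python
import json
import statistics
import tempfile
from pathlib import Path


def summarize_costs(log_dir: Path) -> dict:
    """Collect total_cost from every *.json log and return basic summary statistics."""
    costs = []
    for fp in sorted(log_dir.glob("*.json")):
        try:
            log = json.loads(fp.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            continue  # skip unreadable or malformed logs
        costs.append(float(log.get("total_cost") or 0.0))

    if not costs:
        return {"n": 0, "total": 0.0, "mean": 0.0, "median": 0.0, "max": 0.0}

    return {
        "n": len(costs),
        "total": sum(costs),
        "mean": statistics.fmean(costs),
        "median": statistics.median(costs),
        "max": max(costs),
    }


# Demo on throwaway files with made-up costs; point log_dir at the real log folder instead.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp)
    for i, cost in enumerate([0.01, 0.02, 0.03]):
        (demo_dir / f"tid_demo_{i}.json").write_text(
            json.dumps({"total_cost": cost}), encoding="utf-8"
        )
    summary = summarize_costs(demo_dir)

print(summary)
```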


In [49]:
"""
Flatten per-trial investigational drug innovation JSON into a single CSV.

Inputs (from cache/task_2/):
- trial_investigational_drugs_classifications/{trial_hash}.json
    Each JSON has:
        • trial_hash
        • investigational_products (list of raw drug strings)
        • classifications: { raw_drug_name_or_key: {classification, explanation, tt_drug_id} }

Process:
- For each trial:
    • For each raw investigational product entry:
        - Parse it as a list if it's a stringified list (e.g. "['inetetamab', 'toripalimab']").
        - Look up its classification using the original key first,
          then fall back to each parsed name.
        - Expand to flat aligned lists: one row-level list of product names,
          and one row-level list of classifications.
- Build a DataFrame with one row per trial:
    • trial_hash
    • investigational_products (JSON stringified flat list of names)
    • investigational_products_classifications (JSON stringified flat list of labels)

Outputs (to cache/task_2/):
- trial_investigational_drugs_classifications.csv
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
import json
import ast
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache/task_2")

INNOV_DIR = BASE_DIR / "trial_investigational_drugs_classifications"
OUT_CSV   = BASE_DIR / "trial_investigational_drugs_classifications.csv"

# -------------------------------------------------
# HELPERS
# -------------------------------------------------
def parse_listish(s):
    """
    Parse a stringified list like "['A', 'B']" into a Python list.
    Empty or missing cells return []. A non-list literal, or a string that
    cannot be parsed, is wrapped as a single-element list.
    """
    if isinstance(s, list):
        return s
    if s is None:
        return []
    s = str(s).strip()
    if not s:
        return []
    if s in ("[]", "[ ]"):
        return []
    try:
        val = ast.literal_eval(s)
        if isinstance(val, list):
            return val
        # If it's something else, treat as a single non-empty token
        return [val]
    except Exception:
        # Fallback: treat non-empty string as a single element
        return [s]

# -------------------------------------------------
# RUN
# -------------------------------------------------
rows = []

for fp in INNOV_DIR.glob("*.json"):
    try:
        obj = json.loads(fp.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"Error reading {fp.name}: {e}")
        continue

    trial_hash = obj.get("trial_hash")
    if not trial_hash:
        print(f"Missing trial_hash in {fp.name}, skipping")
        continue

    inv_products_raw = obj.get("investigational_products") or []
    classifications_map = obj.get("classifications") or {}

    flat_products = []
    flat_classifications = []

    for drug_raw in inv_products_raw:
        # drug_raw might be "['inetetamab', 'toripalimab']" or just "SSGJ-707"
        if isinstance(drug_raw, str):
            parsed_names = parse_listish(drug_raw)
        else:
            parsed_names = [drug_raw]

        # Prefer classification using the exact key that was sent to the model
        meta = classifications_map.get(drug_raw, {})
        cls = meta.get("classification", "")

        # If not found, try each parsed name as a key
        if not cls:
            for name in parsed_names:
                meta_n = classifications_map.get(name, {})
                if "classification" in meta_n:
                    cls = meta_n.get("classification", "")
                    break

        if not cls:
            print(
                f"Missing classification for raw drug {drug_raw!r} in "
                f"trial_hash={trial_hash}, file={fp.name}"
            )

        # Add one entry per parsed name so both lists are flat and aligned
        for name in parsed_names:
            flat_products.append(name)
            flat_classifications.append(cls)

    # Sanity check: lengths must match
    if len(flat_products) != len(flat_classifications):
        print(
            f"Length mismatch for trial_hash={trial_hash}: "
            f"{len(flat_products)} products vs {len(flat_classifications)} classifications"
        )

    rows.append(
        {
            "trial_hash": trial_hash,
            # store as JSON stringified flat lists
            "investigational_products": json.dumps(flat_products, ensure_ascii=False),
            "investigational_products_classifications": json.dumps(flat_classifications, ensure_ascii=False),
        }
    )

df_out = pd.DataFrame(rows).sort_values("trial_hash")

OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
df_out.to_csv(OUT_CSV, index=False)

print(f"Saved investigational drug classifications to {OUT_CSV}")
print(df_out.head().to_markdown(index=False))

Saved investigational drug classifications to cache/task_2/trial_investigational_drugs_classifications.csv
| trial_hash                           | investigational_products                           | investigational_products_classifications   |
|:-------------------------------------|:---------------------------------------------------|:-------------------------------------------|
| tid_0541995757b10e613a42173d6b8ddc09 | ["cinacalcet hydrochloride"]                       | ["Generic"]                                |
| tid_0d6e9b2f3f57c17c0e93610e28853f0c | ["Xenopax"]                                        | ["Biosimilar"]                             |
| tid_0da20e863cfc5f3e369868462bff74e0 | ["recombinant erythropoiesis-stimulating protein"] | ["Innovative"]                             |
| tid_0e698ee5065c49d23fcf57516957a273 | ["SB8"]                                            | ["Biosimilar"]                             |
| tid_0e8fa21079f928135dfc6164a15285f8 | ["SSS-17"]        
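
Both list columns in the CSV are stored as JSON strings, so a downstream reader has to decode them before pairing each product with its label. The sketch below round-trips two made-up rows through an in-memory CSV and yields one `(trial_hash, drug, label)` triple per product. It uses only the standard library; `iter_drug_labels` and the demo rows are hypothetical.

```python
import csv
import io
import json
from collections import Counter

FIELDS = [
    "trial_hash",
    "investigational_products",
    "investigational_products_classifications",
]

# Hypothetical rows in the same shape as the CSV written by the cell above.
demo_rows = [
    {
        "trial_hash": "tid_demo_1",
        "investigational_products": json.dumps(["DrugA", "DrugB"]),
        "investigational_products_classifications": json.dumps(["Innovative", "Biosimilar"]),
    },
    {
        "trial_hash": "tid_demo_2",
        "investigational_products": json.dumps(["DrugC"]),
        "investigational_products_classifications": json.dumps(["Generic"]),
    },
]

# Write the demo rows to an in-memory CSV, then rewind it for reading.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=FIELDS)
writer.writeheader()
writer.writerows(demo_rows)
buffer.seek(0)


def iter_drug_labels(handle):
    """Yield (trial_hash, drug_name, label) for every drug in every trial."""
    for row in csv.DictReader(handle):
        names = json.loads(row["investigational_products"])
        labels = json.loads(row["investigational_products_classifications"])
        if len(names) != len(labels):
            raise ValueError(f"misaligned lists for {row['trial_hash']}")
        for name, label in zip(names, labels):
            yield row["trial_hash"], name, label


triples = list(iter_drug_labels(buffer))
label_counts = Counter(label for _, _, label in triples)
print(label_counts)
```

Raising on misaligned lists is a deliberate choice here: a silent `zip` would truncate to the shorter list and hide a product that lost its label.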

#### Task 3

In [50]:
"""
PubMed search for each drug's mechanism of action (MOA), one JSON per did.

Inputs (from cache/task_1/):
- product_id_master_by_did.json
    Keys: did_*
    Values: {
        "tt_drug_id": str,
        "drug_names": [...],
        "alternative_names": [...],
        "molecular_targets": [...],
        "product_mechanisms": [...]
    }

Process:
- For each did_* entry:
    • Build candidate PubMed search terms from:
        - product_mechanisms (MOA strings; split on ';')
        - molecular_targets
    • Query PubMed (NCBI E-utilities) via:
        - esearch: get top PMIDs by relevance
        - efetch : fetch titles, abstracts, MeSH terms, and substances
    • If no hits from raw mechanisms, use an LLM to refine MOA
      into a more canonical search term and re-try.
    • Aggregate all PMIDs and records per did and store:
        - per-term search breakdown
        - overall union of articles and first-hit term.

Outputs (to cache/task_3/):
- investigational_drug_moa_pubmed_search/{did}.json
    {
      "type": "drug_moa_pubmed_search",
      "did": ...,
      "tt_drug_id": ...,
      "drug_names": [...],
      "alternative_names": [...],
      "molecular_targets": [...],
      "product_mechanisms": [...],
      "mechanism_combined": "...",
      "mechanism_key": "...",
      "tried_terms": [...],
      "llm_refined_mechanism": "... or null",
      "mechanism_search": { term -> {pmids, records} },
      "target_search": { term -> {pmids, records} },
      "match": {
        "term": first_hit_term,
        "pmids": [...],
        "records": { pmid -> {title, abstract, mesh_terms, substances} }
      }
    }

- investigational_drug_moa_pubmed_index.json
    Lightweight summary index keyed by did_* with:
        "pmids", "matched_term", "json_path", etc.
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
import os
import json
import time
import html
import unicodedata
from pathlib import Path
from xml.etree import ElementTree as ET

import requests

from services.openai_wrapper import OpenAIWrapper

BASE_IN_DIR = Path("cache/task_1")
BASE_OUT_DIR = Path("cache/task_3")

PRODUCT_MASTER_PATH = BASE_IN_DIR / "product_id_master_by_did.json"

OUT_DIR = BASE_OUT_DIR / "investigational_drug_moa_pubmed_search"
OUT_DIR.mkdir(parents=True, exist_ok=True)

MASTER_INDEX_PATH = BASE_OUT_DIR / "investigational_drug_moa_pubmed_index.json"

EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
API_KEY = os.getenv("NCBI_API_KEY") or None
EMAIL = os.getenv("NCBI_EMAIL") or None
SLEEP = 0.25
RETRY_MAX = 3
RETRY_WAIT = 1.0

# LLM config (for MOA refinement)
MODEL = "gpt-5-mini"  # adjust if needed
client = OpenAIWrapper()

NAN_STRINGS = {"nan", "none", "null", ""}


# -------------------------------------------------
# HELPERS
# -------------------------------------------------
def _clean(s):
    """Normalize string; treat various 'nan' / empty-like tokens as empty."""
    if s is None:
        return ""
    s_str = str(s).strip()
    return "" if s_str.lower() in NAN_STRINGS else s_str


def norm_text(s: str) -> str:
    """Lowercase, strip, normalize Unicode and whitespace; return '' if nan-like."""
    if not isinstance(s, str):
        return ""
    t = html.unescape(s)
    t = unicodedata.normalize("NFKC", t)
    t = " ".join(t.strip().lower().split())
    return "" if t in NAN_STRINGS else t


def _http_get_with_retry(url: str, params: dict, timeout: int) -> requests.Response:
    """HTTP GET with basic retry logic."""
    last_err = None
    for attempt in range(1, RETRY_MAX + 1):
        try:
            r = requests.get(url, params=params, timeout=timeout)
            r.raise_for_status()
            return r
        except Exception as e:
            last_err = e
            if attempt < RETRY_MAX:
                time.sleep(RETRY_WAIT)
            else:
                raise last_err


def esearch_ids(term: str, n: int = 3) -> list[str]:
    """Run PubMed esearch for a term, return up to n PMIDs."""
    term = _clean(term)
    if not term:
        return []
    params = {
        "db": "pubmed",
        "term": term,
        "retmode": "json",
        "retmax": n,
        "sort": "relevance",
    }
    if API_KEY:
        params["api_key"] = API_KEY
    if EMAIL:
        params["email"] = EMAIL
    r = _http_get_with_retry(f"{EUTILS}/esearch.fcgi", params=params, timeout=30)
    return r.json().get("esearchresult", {}).get("idlist", []) or []


def _parse_xml_with_retry(text: str) -> ET.Element:
    """
    Parse an XML payload into an ElementTree root.

    Parsing is deterministic, so re-parsing the same text cannot recover from
    a ParseError; the error is raised immediately instead of after retries.
    The function name is kept so existing callers do not change.
    """
    return ET.fromstring(text)


def efetch_details(pmids: list[str]) -> dict:
    """
    Fetch article details (title, abstract, MeSH, substances) for a list of PMIDs.
    Returns dict pmid -> {title, abstract, mesh_terms, substances}.
    """
    if not pmids:
        return {}
    params = {"db": "pubmed", "id": ",".join(pmids), "retmode": "xml"}
    if API_KEY:
        params["api_key"] = API_KEY
    if EMAIL:
        params["email"] = EMAIL
    r = _http_get_with_retry(f"{EUTILS}/efetch.fcgi", params=params, timeout=60)
    root = _parse_xml_with_retry(r.text)

    out = {}

    def text_from_el(el):
        return "".join(el.itertext()).strip() if el is not None else ""

    def join_abstract(abs_parent):
        parts = []
        for t in abs_parent.findall("AbstractText"):
            label = t.attrib.get("Label")
            txt = text_from_el(t)
            if txt:
                parts.append(f"{label}: {txt}" if label else txt)
        return "\n".join(parts).strip()

    for art in root.findall(".//PubmedArticle"):
        pmid_el = art.find(".//MedlineCitation/PMID")
        if pmid_el is None or not (pmid_el.text or "").strip():
            continue
        pmid = pmid_el.text.strip()

        title = text_from_el(art.find(".//Article/ArticleTitle"))
        abs_parent = art.find(".//Article/Abstract")
        abstract = join_abstract(abs_parent) if abs_parent is not None else ""

        mesh_terms = []
        for mh in art.findall(".//MedlineCitation/MeshHeadingList/MeshHeading"):
            desc = mh.find("DescriptorName")
            if desc is None or not (desc.text or "").strip():
                continue
            d_text = desc.text.strip()
            d_major = desc.attrib.get("MajorTopicYN") == "Y"
            d_str = f"{d_text}{'*' if d_major else ''}"

            quals = []
            for q in mh.findall("QualifierName"):
                q_text = (q.text or "").strip()
                if q_text:
                    q_major = q.attrib.get("MajorTopicYN") == "Y"
                    quals.append(f"{q_text}{'*' if q_major else ''}")

            mesh_terms.append(d_str if not quals else d_str + " / " + "; ".join(quals))

        substances = []
        for chem in art.findall(".//Chemical"):
            nm_el = chem.find("NameOfSubstance")
            rn_el = chem.find("RegistryNumber")
            # Guard against elements that exist but have no text (text is None)
            nm = (nm_el.text or "").strip() if nm_el is not None else ""
            rn = (rn_el.text or "").strip() if rn_el is not None else ""
            if nm and rn and rn != "0":
                substances.append(f"{nm} [RN:{rn}]")
            elif nm:
                substances.append(nm)
            elif rn and rn != "0":
                substances.append(f"[RN:{rn}]")

        def uniq(xs):
            seen, out_local = set(), []
            for x in xs:
                if x and x not in seen:
                    seen.add(x)
                    out_local.append(x)
            return out_local

        out[pmid] = {
            "title": title,
            "abstract": abstract,
            "mesh_terms": uniq(mesh_terms),
            "substances": uniq(substances),
        }

    return out


def save_json(path: Path, obj: dict):
    """Write JSON to disk with UTF-8 + pretty indent."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(obj, indent=2, ensure_ascii=False), encoding="utf-8")


def load_json_or_empty(path: Path) -> dict:
    """Load JSON or return empty dict on failure / missing file."""
    if not path.exists():
        return {}
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return {}


def split_terms(s: str):
    """
    For MOA strings like:
      'Thrombopoietin receptor agonist (recombinant growth factor); PEGylated recombinant human EPO'
    split on ';' and treat each piece as a candidate search term.
    """
    if not s:
        return []
    raw = [t.strip() for t in str(s).split(";")]
    return [t for t in raw if t and t.lower() not in NAN_STRINGS]


# --------------- LLM refinement helpers ---------------
def build_moa_refinement_prompt(mechanism: str) -> str:
    """
    Prompt the LLM to turn a free-text MOA into a concise, canonical
    mechanism-of-action phrase suitable for PubMed search.
    """
    return f"""
You are an expert clinical pharmacologist and mechanisms-of-action classifier.

Given the following mechanism-of-action (MOA) description from a drug development database:

\"\"\"{mechanism}\"\"\"

Rewrite or condense it into a SHORT, CANONICAL mechanism-of-action term that would work well as a PubMed search term.

Rules:
- Output a concise mechanism class or well-recognized pharmacologic concept, not a full sentence.
- Prefer standard pharmacologic/mechanistic classes (e.g. "Ion Exchange Resins", "Immunocytokines",
  "Kinase Inhibitors", "Antibodies, Monoclonal", "Immune Checkpoint Inhibitors").
- Do NOT include long target listings or extra explanation.
- If the original MOA is already an appropriate concise search term, you may return it unchanged.

Return ONLY the refined mechanism phrase, with no additional explanation or formatting.
""".strip()


def refine_mechanism_with_llm(mechanism: str) -> str | None:
    """
    Use the OpenAIWrapper .query() interface to get a refined mechanism phrase.
    Returns the refined phrase or None on failure.
    """
    mech_clean = _clean(mechanism)
    if not mech_clean:
        return None

    prompt = build_moa_refinement_prompt(mech_clean)

    try:
        res = client.query(prompt=prompt, model=MODEL)
        text = (res.get("text_response") or "").strip()
        # Strip surrounding quotes if the model adds them
        text = text.strip().strip('"').strip("'")
        refined = _clean(text)
        return refined or None
    except Exception as e:
        print(f"LLM refinement failed for mechanism='{mech_clean[:80]}': {e}")
        return None


# -------------------------------------------------
# RUN
# -------------------------------------------------
# Load product master (by did) from task_1
product_master_by_did = load_json_or_empty(PRODUCT_MASTER_PATH)
if not product_master_by_did:
    raise RuntimeError(f"No product entries found in {PRODUCT_MASTER_PATH}")

# Load existing PubMed index (from task_3)
master_index = load_json_or_empty(MASTER_INDEX_PATH) or {}

total = len(product_master_by_did)
print(f"{total} drug entries (did_*) to process")
processed = 0

# Main loop: one PubMed search per DRUG (by did)
for did, rec in product_master_by_did.items():
    # Skip if already indexed
    if did in master_index:
        mech_list = rec.get("product_mechanisms", []) or []
        mech_preview = "; ".join(mech_list)[:60]
        print(f"{mech_preview} || already processed for {did}")
        processed += 1
        continue

    # product_mechanisms is a list; join into a single string for a "combo" key,
    # but we'll search EACH mechanism (and its ';'-split pieces) separately.
    mech_list = rec.get("product_mechanisms", []) or []
    mechanism = _clean("; ".join(mech_list))
    if not mechanism:
        print(f"Empty mechanism list for did={did}, skipping")
        continue

    mech_key = norm_text(mechanism)

    # -----------------------------
    # Build candidate search terms
    # -----------------------------
    # For MOAs: search EACH mechanism string (and each ';'-split subterm).
    mechanism_terms: list[str] = []
    for mech in mech_list:
        mech = _clean(mech)
        if not mech:
            continue
        subterms = split_terms(mech)
        if not subterms:
            subterms = [mech]
        for t in subterms:
            t_clean = _clean(t)
            if t_clean and t_clean not in mechanism_terms:
                mechanism_terms.append(t_clean)

    # For molecular targets: direct terms
    target_terms: list[str] = []
    for tgt in rec.get("molecular_targets", []) or []:
        t_clean = _clean(tgt)
        if t_clean and t_clean not in target_terms:
            target_terms.append(t_clean)

    if not mechanism_terms and not target_terms:
        print(f"No usable mechanism or target terms for did={did}, skipping")
        continue

    tried_terms: list[str] = []
    first_hit_term: str | None = None
    llm_refined: str | None = None

    # Detailed per-term results
    mechanism_search: dict[str, dict] = {}
    target_search: dict[str, dict] = {}

    # Aggregate across all searches for summary
    all_pmids: set[str] = set()
    all_records: dict[str, dict] = {}

    # -----------------------------
    # 1) Mechanism term searches
    # -----------------------------
    for term in mechanism_terms:
        tried_terms.append(term)
        query = f"\"{term}\""
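        try:
            pmids = esearch_ids(query, n=5)
        except Exception as e:
            # Log instead of silently swallowing, so API failures can be told
            # apart from genuine zero-hit searches.
            print(f"esearch failed for term={term!r} (did={did}): {e}")
            pmids = []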

        records = {}
        if pmids:
            try:
                records = efetch_details(pmids)
            except Exception as e:
                records = {"_error": str(e)}

            # Track first term that hits
            if first_hit_term is None and pmids:
                first_hit_term = term

            for p in pmids:
                all_pmids.add(p)
                if p not in all_records and p in records:
                    all_records[p] = records[p]

        mechanism_search[term] = {
            "pmids": pmids,
            "records": records,
        }

    # -----------------------------
    # 2) LLM refinement if NO mechanism hits
    # -----------------------------
    if not all_pmids:
        llm_refined = refine_mechanism_with_llm(mechanism)
        if llm_refined:
            llm_term_key = llm_refined  # store as-is
            tried_terms.append(llm_refined + " [LLM]")
            query = f"\"{llm_refined}\""
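            try:
                pmids = esearch_ids(query, n=5)
            except Exception as e:
                # Log instead of silently swallowing, so API failures can be
                # told apart from genuine zero-hit searches.
                print(f"esearch failed for LLM term={llm_refined!r} (did={did}): {e}")
                pmids = []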

            records = {}
            if pmids:
                try:
                    records = efetch_details(pmids)
                except Exception as e:
                    records = {"_error": str(e)}

                if first_hit_term is None and pmids:
                    first_hit_term = llm_refined

                for p in pmids:
                    all_pmids.add(p)
                    if p not in all_records and p in records:
                        all_records[p] = records[p]

            mechanism_search[llm_term_key] = {
                "pmids": pmids,
                "records": records,
                "llm_refined": True,
            }

    # -----------------------------
    # 3) Molecular target term searches (10 PMIDs each)
    # -----------------------------
    for term in target_terms:
        tried_terms.append(term)
        query = f"\"{term}\""
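        try:
            pmids = esearch_ids(query, n=10)  # ← 10 studies per target term
        except Exception as e:
            # Log instead of silently swallowing, so API failures can be told
            # apart from genuine zero-hit searches.
            print(f"esearch failed for target term={term!r} (did={did}): {e}")
            pmids = []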

        records = {}
        if pmids:
            try:
                records = efetch_details(pmids)
            except Exception as e:
                records = {"_error": str(e)}

            if first_hit_term is None and pmids:
                first_hit_term = term

            for p in pmids:
                all_pmids.add(p)
                if p not in all_records and p in records:
                    all_records[p] = records[p]

        target_search[term] = {
            "pmids": pmids,
            "records": records,
        }

    # 4) If STILL no PMIDs at all, skip saving (so you can rerun later)
    if not all_pmids:
        print(f"No PubMed hits for did={did} after mechanisms + targets + LLM, skipping")
        continue

    # -----------------------------
    # HASH-BASED OUTPUT (BY did)
    # -----------------------------
    fname = f"{did}.json"
    out_path = OUT_DIR / fname

    payload = {
        "type": "drug_moa_pubmed_search",
        "did": did,
        "tt_drug_id": rec.get("tt_drug_id"),
        "drug_names": rec.get("drug_names", []),
        "alternative_names": rec.get("alternative_names", []),
        "molecular_targets": rec.get("molecular_targets", []),
        "product_mechanisms": mech_list,
        "mechanism_combined": mechanism,
        "mechanism_key": mech_key,
        "tried_terms": tried_terms,
        "llm_refined_mechanism": llm_refined,
        # Detailed breakdowns:
        "mechanism_search": mechanism_search,
        "target_search": target_search,
        # Backward-compatible summary:
        "match": {
            "term": first_hit_term,
            "pmids": sorted(all_pmids),
            "records": all_records,
        },
    }

    save_json(out_path, payload)

    # Index entry keyed by did (summary only)
    master_index[did] = {
        "did": did,
        "tt_drug_id": rec.get("tt_drug_id"),
        "drug_names": rec.get("drug_names", []),
        "product_mechanisms": mech_list,
        "mechanism_combined": mechanism,
        "mechanism_key": mech_key,
        "json_path": f"{OUT_DIR.name}/{fname}",
        "pmids": sorted(all_pmids),
        "matched_term": first_hit_term,
        "llm_refined_mechanism": llm_refined,
    }
    save_json(MASTER_INDEX_PATH, master_index)

    processed += 1
    if processed % 50 == 0:
        print(f"Processed {processed}/{total}…")

    time.sleep(SLEEP)

save_json(MASTER_INDEX_PATH, master_index)
print(
    f"Completed {processed} drug entries with at least one PubMed hit. "
    f"Files written to {OUT_DIR}"
)


129 drug entries (did_*) to process
Empty mechanism list for did=did_624e50d2bca9c6314160403c2f83bc0c, skipping
Processed 50/129…
Empty mechanism list for did=did_14f5219735c03c9b814c4b99a887d5f5, skipping
Processed 100/129…
Completed 127 drug entries with at least one PubMed hit. Files written to cache/task_3/investigational_drug_moa_pubmed_search
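
The three search passes in the cell above (mechanism terms, the LLM-refined term, and molecular targets) repeat the same esearch → efetch → aggregate steps. A minimal sketch of how that step could be factored into one helper follows. `search_term`, `merge_hits`, and the two fake fetchers are hypothetical names introduced here for illustration; in the notebook, the real `esearch_ids` and `efetch_details` would be passed in instead of the fakes.

```python
from typing import Callable


def search_term(
    term: str,
    n: int,
    esearch: Callable[..., list[str]],
    efetch: Callable[[list[str]], dict[str, dict]],
) -> dict:
    """Run one quoted PubMed search and fetch details for the hits.

    Hypothetical helper: `esearch` and `efetch` stand in for the notebook's
    esearch_ids / efetch_details and are injected so the sketch runs offline.
    """
    query = f'"{term}"'
    try:
        pmids = esearch(query, n=n)
    except Exception as e:
        print(f"esearch failed for term={term!r}: {e}")
        pmids = []

    records: dict = {}
    if pmids:
        try:
            records = efetch(pmids)
        except Exception as e:
            records = {"_error": str(e)}
    return {"pmids": pmids, "records": records}


def merge_hits(result: dict, all_pmids: set, all_records: dict) -> None:
    """Fold one search result into the running aggregates (in place)."""
    for p in result["pmids"]:
        all_pmids.add(p)
        if p not in all_records and p in result["records"]:
            all_records[p] = result["records"][p]


# Offline stand-ins for the real PubMed calls (illustrative data only).
def fake_esearch(query: str, n: int = 5) -> list[str]:
    return ["111", "222"][:n] if "kinase" in query.lower() else []


def fake_efetch(pmids: list[str]) -> dict[str, dict]:
    return {p: {"title": f"Paper {p}", "mesh_terms": []} for p in pmids}


all_pmids: set[str] = set()
all_records: dict[str, dict] = {}
for term in ["Kinase inhibitor", "Unknown mechanism"]:
    merge_hits(search_term(term, 5, fake_esearch, fake_efetch), all_pmids, all_records)

print(sorted(all_pmids))
```

With a helper of this shape, each pass would reduce to a loop that calls `search_term` with `n=5` or `n=10` and stores the result under the term, which keeps the error handling in one place.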


In [52]:
"""
Select a single, mechanistically appropriate MeSH term for each drug MOA, using PubMed search results.

Inputs (from cache/task_3/):
- investigational_drug_moa_pubmed_search/{did}.json
    Each JSON (one per did_*) should contain:
        • did / moa_id
        • product_mechanisms (list of MOA strings)
        • mechanism_combined (joined MOA string)
        • tried_terms (list of PubMed search terms used)
        • match:
            - pmids: list of PubMed IDs
            - records: {
                  pmid: {
                      "title": str,
                      "abstract": str,
                      "mesh_terms": [list of MeSH descriptor/qualifier strings],
                      "substances": [optional list of chemical/substance labels]
                  },
                  ...
              }

Process:
- For each MOA PubMed-search JSON:
    • Derive a moa_id (prefer payload["moa_id"], then payload["did"], then filename stem).
    • Collect *candidate* MeSH terms by unioning all `mesh_terms` across the matched PMIDs.
    • If there are no candidate MeSH terms, skip that MOA.
    • Build a condensed payload (mechanism text, tried terms, pmids, and titles + mesh_terms only).
    • Call the LLM with a constrained prompt that:
        - Chooses EXACTLY one MeSH term from the candidate list, OR
        - Returns "[none]" if no term meaningfully represents the MOA.
    • Enforce that any non-"[none]" choice must be exactly in the candidate list.
    • Persist:
        - Per-MOA choice to investigational_drug_moa_chosen/{moa_id}.json
        - A log entry with prompt + raw response to investigational_drug_moa_chosen_log/{moa_id}.json
        - A master index of all MOA choices in investigational_drug_moa_chosen_master.json

Outputs (to cache/task_3/):
- investigational_drug_moa_chosen/{moa_id}.json
    {
      "moa_id": "...",
      "mechanism": "...",
      "candidate_mesh_terms": [...],
      "chosen_mesh_term": "<one term from candidates or '[none]'>",
      "source_pmid": "<pmid or null>",
      "rationale": "One concise sentence",
      "source": "llm"
    }

- investigational_drug_moa_chosen_log/{moa_id}.json
    (debug/log record with prompt, raw_response, timing, and cost)

- investigational_drug_moa_chosen_master.json
    {
      "<moa_id>": { ...same structure as per-MOA JSON... },
      ...
    }
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
import json
import re
import time
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

from services.openai_wrapper import OpenAIWrapper  # project-local OpenAI client wrapper

BASE_DIR = Path("cache/task_3")

MOA_PUBMED_DIR          = BASE_DIR / "investigational_drug_moa_pubmed_search"
MOA_CHOICE_DIR          = BASE_DIR / "investigational_drug_moa_chosen"
MOA_CHOICE_LOG_DIR      = BASE_DIR / "investigational_drug_moa_chosen_log"
MASTER_MOA_CHOICES_PATH = BASE_DIR / "investigational_drug_moa_chosen_master.json"

MOA_CHOICE_DIR.mkdir(parents=True, exist_ok=True)
MOA_CHOICE_LOG_DIR.mkdir(parents=True, exist_ok=True)

MAX_WORKERS_MOA = 8
MODEL = "gpt-5"

client = OpenAIWrapper()

# -------------------------------------------------
# HELPERS
# -------------------------------------------------
def extract_json_object(text: str) -> dict:
    """Extract first valid JSON object from model output."""
    if not isinstance(text, str):
        return {}
    text = text.strip()
    if not text:
        return {}

    # Direct parse first
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except Exception:
        pass

    # Fallback: greedy match spanning the first '{' to the last '}'
    m = re.search(r"\{.*\}", text, re.DOTALL)
    if not m:
        return {}
    try:
        obj = json.loads(m.group(0))
        if isinstance(obj, dict):
            return obj
    except Exception:
        return {}

    return {}

def load_master_moa_choices() -> dict:
    """Load the existing master MOA → MeSH choice index, or {} if missing/invalid."""
    if not MASTER_MOA_CHOICES_PATH.exists():
        return {}
    try:
        return json.loads(MASTER_MOA_CHOICES_PATH.read_text(encoding="utf-8"))
    except Exception:
        return {}

def build_moa_mesh_prompt(moa_payload: dict, candidate_mesh_terms: list[str]) -> str:
    """
    Prompt the LLM to choose the best MeSH term that represents the mechanism of action.
    If no suitable MeSH term exists, the model MUST return "[none]".
    """
    payload_json = json.dumps(moa_payload, ensure_ascii=False, indent=2)
    mesh_json    = json.dumps(candidate_mesh_terms, ensure_ascii=False, indent=2)

    return f"""
You are an expert pharmacologist and MeSH annotation specialist.

You are given:
1) A mechanism-of-action (MOA) text string describing how a drug works.
2) A set of PubMed-derived MeSH terms (candidate list).
3) Condensed PubMed records used for MOA search.

Your tasks:

------------------------------------------------------------
TASK 1 — Select the Best MeSH Term
------------------------------------------------------------
Choose EXACTLY ONE MeSH term that best represents the mechanism of action.

Rules:
- You MUST select a term *only* from the candidate list.
- Choose the most mechanistic/specific pharmacologic concept available
  (e.g., "Receptor Antagonists", "Antibodies, Monoclonal", "Kinase Inhibitors").
- Avoid generic terms ("Humans", "Adult", "Neoplasms") unless absolutely no mechanistic term exists.

------------------------------------------------------------
TASK 2 — Handle Cases with No Good Mechanistic Term
------------------------------------------------------------
If NONE of the candidate MeSH terms meaningfully represent the MOA:

You MUST output:

  "chosen_mesh_term": "[none]",
  "source_pmid": null,
  "rationale": "Explain why no term fits."

This is a VALID and EXPECTED outcome.

------------------------------------------------------------
OUTPUT FORMAT  (STRICT)
------------------------------------------------------------

Return ONLY a valid JSON object with EXACTLY these fields:

{{
  "chosen_mesh_term": "<one exact candidate term OR '[none]'>",
  "source_pmid": "<PMID you relied on OR null>",
  "rationale": "One concise sentence explaining your decision."
}}

Constraints:
- If you choose a MeSH term, it MUST MATCH EXACTLY one item from the candidate list.
- If no suitable term exists, return "[none]".
- JSON must be valid and parseable.

------------------------------------------------------------
MOA Payload (input data)
------------------------------------------------------------
{payload_json}

------------------------------------------------------------
Candidate MeSH Terms
------------------------------------------------------------
{mesh_json}
""".strip()


master_moa_choices = load_master_moa_choices()
master_moa_lock = threading.Lock()

moa_counter = {
    "processed": 0,
    "skipped_existing": 0,
    "llm_error": 0,
    "parse_error": 0,
    "coverage_error": 0,   # includes "chosen term not in JSON-derived list"
    "no_candidates": 0,
}
moa_counter_lock = threading.Lock()


def process_moa_file(fp: Path, idx: int, total: int) -> None:
    """Process a single MOA PubMed-search JSON file and select a MeSH term via LLM."""
    try:
        payload = json.loads(fp.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"[{idx}/{total}] Error reading {fp.name}: {e}")
        with moa_counter_lock:
            moa_counter["parse_error"] += 1
        return

    moa_id = (
        payload.get("moa_id")
        or payload.get("did")   # backward-compat for did_* key
        or fp.stem
    )

    mechanism = (
        payload.get("mechanism")        # if already present
        or payload.get("mechanism_combined")
        or "; ".join(payload.get("product_mechanisms", []) or [])
        or ""
    )

    if not moa_id:
        print(f"[{idx}/{total}] Missing moa_id in {fp.name}, skipping")
        return

    out_fp = MOA_CHOICE_DIR / f"{moa_id}.json"
    if out_fp.exists():
        with moa_counter_lock:
            moa_counter["skipped_existing"] += 1
        return

    match = payload.get("match") or {}
    records = match.get("records") or {}
    pmids = match.get("pmids") or []

    # Collect candidate MeSH terms (unique, in stable order) FROM THE JSON ONLY
    candidate_terms = []
    seen_terms = set()
    for pmid, rec in records.items():
        mesh_terms = rec.get("mesh_terms") or []
        for term in mesh_terms:
            if term and term not in seen_terms:
                seen_terms.add(term)
                candidate_terms.append(term)

    if not candidate_terms:
        print(f"[{idx}/{total}] No candidate MeSH terms for moa_id={moa_id}, skipping")
        with moa_counter_lock:
            moa_counter["no_candidates"] += 1
        return

    # Condensed payload for the model (avoid full abstracts to save tokens)
    condensed_records = {
        pmid: {
            "title": (rec.get("title") or ""),
            "mesh_terms": (rec.get("mesh_terms") or []),
        }
        for pmid, rec in records.items()
    }

    moa_payload = {
        "moa_id": moa_id,
        "mechanism": mechanism,
        "tried_terms": payload.get("tried_terms") or [],
        "pmids": pmids,
        "records": condensed_records,
    }

    prompt = build_moa_mesh_prompt(moa_payload, candidate_terms)

    token = moa_id
    hash_id = moa_id

    text_response = ""
    raw_response = None
    total_cost = 0.0
    elapsed = 0.0

    # Call LLM
    try:
        t0 = time.perf_counter()
        res = client.query(prompt=prompt, model=MODEL)
        elapsed = round(time.perf_counter() - t0, 2)

        text_response = (res.get("text_response") or "").strip()
        raw_response = res.get("raw_response")
        total_cost = float(res.get("cost") or 0.0)
    except Exception as e:
        print(f"[{idx}/{total}] LLM error for moa_id={moa_id}: {e}")
        with moa_counter_lock:
            moa_counter["llm_error"] += 1
        return

    # Parse JSON
    obj = extract_json_object(text_response)

    if not isinstance(obj, dict) or not obj:
        print(f"[{idx}/{total}] JSON parse error moa_id={moa_id}, raw={text_response!r}")
        with moa_counter_lock:
            moa_counter["parse_error"] += 1
        return

    chosen_term = obj.get("chosen_mesh_term")
    source_pmid = obj.get("source_pmid")
    rationale = obj.get("rationale")

    # HARD CHECK: chosen term
    if not chosen_term or not isinstance(chosen_term, str):
        print(f"[{idx}/{total}] Missing or invalid chosen_mesh_term for moa_id={moa_id}")
        with moa_counter_lock:
            moa_counter["coverage_error"] += 1
        return

    # "[none]" is the allowed sentinel for "no good term" and is accepted even
    # though it is not in candidate_terms. Any real term MUST come from the
    # JSON-derived candidate list.
    if chosen_term != "[none]":
        if chosen_term not in candidate_terms:
            # Term is absent from the JSON (hallucinated or modified) → reject, do NOT save
            print(
                f"[{idx}/{total}] chosen_mesh_term not in JSON-derived candidate list "
                f"for moa_id={moa_id}: {chosen_term!r}"
            )
            with moa_counter_lock:
                moa_counter["coverage_error"] += 1
            return

        # Optional: source_pmid sanity check (must be one of pmids or None).
        # Compare as strings, since the model may return the PMID as a number.
        if source_pmid is not None and str(source_pmid) not in {str(p) for p in pmids}:
            print(
                f"[{idx}/{total}] source_pmid {source_pmid!r} not in pmids for moa_id={moa_id}; "
                f"still accepting chosen_mesh_term"
            )

    mapped = {
        "moa_id": moa_id,
        "mechanism": mechanism,
        "candidate_mesh_terms": candidate_terms,
        "chosen_mesh_term": chosen_term,
        "source_pmid": source_pmid,
        "rationale": rationale,
        "source": "llm",
    }

    # Save per-MOA JSON
    out_fp.write_text(json.dumps(mapped, ensure_ascii=False, indent=2), encoding="utf-8")

    # Log entry
    log_payload = {
        "token": token,
        "hash_id": hash_id,
        "model": MODEL,
        "prompt": prompt,
        "structured_response": json.dumps(mapped, ensure_ascii=False, indent=2),
        "raw_response": repr(raw_response),
        "total_cost": total_cost,
        "time_elapsed": elapsed,
    }
    (MOA_CHOICE_LOG_DIR / f"{hash_id}.json").write_text(
        json.dumps(log_payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    # Update master
    with master_moa_lock:
        master_moa_choices[moa_id] = mapped
        MASTER_MOA_CHOICES_PATH.write_text(
            json.dumps(master_moa_choices, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    with moa_counter_lock:
        moa_counter["processed"] += 1
        if moa_counter["processed"] % 50 == 0:
            print(f"Progress: processed {moa_counter['processed']} MOA entries...")


# -------------------------------------------------
# RUN
# -------------------------------------------------
moa_files = sorted(MOA_PUBMED_DIR.glob("*.json"))
total_moa = len(moa_files)
print(f"Loaded {total_moa} MOA PubMed-search files for MeSH-term selection.")

with ThreadPoolExecutor(max_workers=MAX_WORKERS_MOA) as ex:
    futures = {
        ex.submit(process_moa_file, fp, idx, total_moa): fp.name
        for idx, fp in enumerate(moa_files, start=1)
    }
    for fut in as_completed(futures):
        name = futures[fut]
        try:
            fut.result()
        except Exception as e:
            print(f"Worker error (MOA MeSH selection) file={name}: {e}")

print(
    f"MOA MeSH-term selection complete. "
    f"processed={moa_counter['processed']}, "
    f"skipped={moa_counter['skipped_existing']}, "
    f"llm_error={moa_counter['llm_error']}, "
    f"parse_error={moa_counter['parse_error']}, "
    f"coverage_error={moa_counter['coverage_error']}, "
    f"no_candidates={moa_counter['no_candidates']}"
)
print(f"Chosen MOA directory: {MOA_CHOICE_DIR}")
print(f"Log directory:        {MOA_CHOICE_LOG_DIR}")
print(f"Master choices:       {MASTER_MOA_CHOICES_PATH}")

Loaded 127 MOA PubMed-search files for MeSH-term selection.
Progress: processed 50 MOA entries...
Progress: processed 100 MOA entries...
MOA MeSH-term selection complete. processed=127, skipped=0, llm_error=0, parse_error=0, coverage_error=0, no_candidates=0
Chosen MOA directory: cache/task_3/investigational_drug_moa_chosen
Log directory:        cache/task_3/investigational_drug_moa_chosen_log
Master choices:       cache/task_3/investigational_drug_moa_chosen_master.json
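
The regex fallback in `extract_json_object` is greedy, so it spans from the first `{` to the last `}` in the response; if the model appends any brace-containing text after its JSON, that span no longer parses and the helper returns `{}`. One possible alternative, sketched below under the assumption that the first parseable object is the one wanted, scans with the standard library's `json.JSONDecoder.raw_decode`. The function name `extract_first_json_object` is hypothetical and not part of the notebook.

```python
import json


def extract_first_json_object(text: str) -> dict:
    """Return the first parseable JSON object found in `text`, else {}."""
    if not isinstance(text, str):
        return {}
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            # raw_decode parses one JSON value starting at `pos` and ignores
            # whatever trails it.
            obj, _end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        # Not a valid object at this brace; try the next one.
        pos = text.find("{", pos + 1)
    return {}


reply = 'Here you go: {"chosen_mesh_term": "[none]", "source_pmid": null} (see {notes})'
print(extract_first_json_object(reply))
```

Because `raw_decode` stops at the end of the first complete value, trailing commentary from the model does not affect the result.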


In [54]:
"""
Attach chosen MeSH mechanism terms to the per-trial product breakdown table.

What is a "did"?
- A "did" is a hashed drug identifier string like "did_2c2fb9efd4b8a1f837bf47004a49ce45".
- Each did represents one drug, aggregating all known names and metadata for that product.
- did records are stored in:
      cache/task_1/product_id_master_by_did.json
  For example:
      "did_2c2fb9efd4b8a1f837bf47004a49ce45": {
        "did": "did_2c2fb9efd4b8a1f837bf47004a49ce45",
        "tt_drug_id": "9084",
        "drug_names": ["ifosfamide"],
        "alternative_names": ["Holoxan", "Ifex", "iphosphamide", "isophosphamide"],
        "molecular_targets": [],
        "product_mechanisms": ["Alkylating agent (DNA crosslinker)"],
        "trial_hashes": ["tid_bb1e0571142dde8a49976632c349593c"]
      }

Inputs:
- cache/task_1/trial_product_breakdown.csv
    Per-trial product lists, including:
      • investigational_products
      • active_comparators
      • standard_of_care
- cache/task_1/product_id_master_by_did.json
    did → {
      "did",
      "tt_drug_id",
      "drug_names",
      "alternative_names",
      "product_mechanisms",
      "trial_hashes",
      ...
    }
- cache/task_3/investigational_drug_moa_chosen_master.json
    did → {
      "chosen_mesh_term",
      "source_pmid",
      "rationale",
      ...
    }

Process:
- Build did → chosen_mesh_term (ignoring "[none]").
- For each did:
    • build a normalized name set from drug_names + alternative_names
    • use trial_hashes to map dids to the trials where they appear
- For each row of trial_product_breakdown.csv:
    • look up which dids are associated with that trial_hash
    • for each product name in each role column:
        - normalize the name
        - match against candidate did name sets
        - collect 0, 1, or multiple MeSH terms per product
- Create three new list-columns aligned with the existing product lists:
    • investigational_products_mechanism_mesh_terms
    • active_comparators_mechanism_mesh_terms
    • standard_of_care_mechanism_mesh_terms

Outputs:
- cache/task_3/trial_product_breakdown_w_chosen_mechanisms.csv
    Same rows as trial_product_breakdown.csv with the three
    *_mechanism_mesh_terms columns added.
"""

# ----------------------------------------
# CONFIG
# ----------------------------------------
import ast
import json
from pathlib import Path

import pandas as pd

IN_BASE_DIR  = Path("cache/task_1")
MOA_BASE_DIR = Path("cache/task_3")

IN_BREAKDOWN_CSV      = IN_BASE_DIR  / "trial_product_breakdown.csv"
PRODUCT_MASTER_BY_DID = IN_BASE_DIR  / "product_id_master_by_did.json"
MOA_MASTER_PATH       = MOA_BASE_DIR / "investigational_drug_moa_chosen_master.json"
OUT_BREAKDOWN_CSV     = MOA_BASE_DIR / "trial_product_breakdown_w_chosen_mechanisms.csv"

# ----------------------------------------
# HELPERS
# ----------------------------------------
def parse_listish(x):
    """Parse strings like "['a','b']" into Python lists."""
    if isinstance(x, list):
        return x
    if x is None:
        return []
    s = str(x).strip()
    if not s or s in ("[]", "[ ]"):
        return []
    try:
        v = ast.literal_eval(s)
        if isinstance(v, list):
            return v
        return [v]
    except Exception:
        return [s]

def insert_after(df: pd.DataFrame, col: str, newcol: str, values):
    """
    Insert a new column `newcol` with `values` immediately after `col`.
    If `col` is missing, append `newcol` at the end.
    """
    cols = list(df.columns)
    if col not in cols:
        df[newcol] = values
        return
    idx = cols.index(col)
    df.insert(idx + 1, newcol, values)

def is_none_term(s: str) -> bool:
    """Return True if the MeSH term is '[none]' or empty/None-like."""
    if not s:
        return True
    s2 = str(s).strip().lower()
    return s2 in ("[none]", "none", "")

def norm_name(s: str) -> str:
    """Normalize a product name for matching: lowercase and stripped."""
    return str(s).strip().lower()

# ----------------------------------------
# RUN
# ----------------------------------------

# Load main trial breakdown
df = pd.read_csv(IN_BREAKDOWN_CSV, dtype=str).fillna("")
print(f"Loaded trial breakdown: {IN_BREAKDOWN_CSV}, shape={df.shape}")

if not PRODUCT_MASTER_BY_DID.exists():
    raise FileNotFoundError(f"Missing product master: {PRODUCT_MASTER_BY_DID}")

if not MOA_MASTER_PATH.exists():
    raise FileNotFoundError(f"Missing MOA master: {MOA_MASTER_PATH}")

product_by_did = json.loads(PRODUCT_MASTER_BY_DID.read_text(encoding="utf-8"))
moa_master     = json.loads(MOA_MASTER_PATH.read_text(encoding="utf-8"))

# did → chosen_mesh_term (only keep non-[none] terms)
did_to_mesh: dict[str, str] = {}
for did, rec in moa_master.items():
    mesh = (rec.get("chosen_mesh_term") or "").strip()
    if mesh and not is_none_term(mesh):
        did_to_mesh[did] = mesh

print(f"did_to_mesh (non-[none]) count: {len(did_to_mesh)}")

# Build:
#   trial_hash → list[did]
#   did → set(normalized names)
trial_to_dids: dict[str, list[str]] = {}
did_to_names_norm: dict[str, set[str]] = {}

for did, rec in product_by_did.items():
    # Only consider dids that actually have a chosen MeSH term
    if did not in did_to_mesh:
        continue

    drug_names = rec.get("drug_names", []) or []
    alt_names  = rec.get("alternative_names", []) or []
    all_names  = set(drug_names) | set(alt_names)

    names_norm = {norm_name(n) for n in all_names if str(n).strip()}
    if not names_norm:
        continue

    did_to_names_norm[did] = names_norm

    trial_hashes = rec.get("trial_hashes", []) or []
    for th in trial_hashes:
        th_str = str(th).strip()
        if not th_str:
            continue
        trial_to_dids.setdefault(th_str, []).append(did)

print(f"trial_to_dids entries: {len(trial_to_dids)}")
print(f"did_to_names_norm entries: {len(did_to_names_norm)}")

# For each trial row, build MeSH lists per role
ROLE_NAME_SPECS = [
    ("investigational_products", "investigational_products_mechanism_mesh_terms"),
    ("active_comparators",      "active_comparators_mechanism_mesh_terms"),
    ("standard_of_care",        "standard_of_care_mechanism_mesh_terms"),
]

# Prepare containers for new columns
new_cols = {spec[1]: [] for spec in ROLE_NAME_SPECS}

for _, row in df.iterrows():
    trial_hash = str(row.get("trial_hash", "")).strip()
    candidate_dids = trial_to_dids.get(trial_hash, [])

    # If no dids linked to this trial, all role columns get empty-string lists
    if not candidate_dids:
        for base_col, new_col in ROLE_NAME_SPECS:
            names_list = parse_listish(row.get(base_col, ""))
            mesh_list = ["" for _ in names_list]
            new_cols[new_col].append(str(mesh_list))
        continue

    # Otherwise, match each product name to candidate dids by normalized name
    for base_col, new_col in ROLE_NAME_SPECS:
        names_list = parse_listish(row.get(base_col, ""))
        mesh_list = []

        for prod_name in names_list:
            nn = norm_name(prod_name)
            if not nn:
                mesh_list.append("")
                continue

            meshes_for_this = set()

            for did in candidate_dids:
                name_set = did_to_names_norm.get(did, set())
                if nn in name_set:
                    mesh = did_to_mesh.get(did, "")
                    if mesh:
                        meshes_for_this.add(mesh)

            if not meshes_for_this:
                mesh_list.append("")
            elif len(meshes_for_this) == 1:
                mesh_list.append(next(iter(meshes_for_this)))
            else:
                # Multiple dids mapping to different MeSH terms
                mesh_list.append("; ".join(sorted(meshes_for_this)))

        new_cols[new_col].append(str(mesh_list))

# Attach new columns next to their corresponding name columns
for base_col, new_col in ROLE_NAME_SPECS:
    insert_after(df, base_col, new_col, new_cols[new_col])

# Save output
OUT_BREAKDOWN_CSV.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUT_BREAKDOWN_CSV, index=False)

print(f"✔️ Wrote trial breakdown with chosen MeSH mechanisms → {OUT_BREAKDOWN_CSV}")
print(df.head(5).to_markdown(index=False))

Loaded trial breakdown: cache/task_1/trial_product_breakdown.csv, shape=(184, 23)
did_to_mesh (non-[none]) count: 126
trial_to_dids entries: 183
did_to_names_norm entries: 126
✔️ Wrote trial breakdown with chosen MeSH mechanisms → cache/task_3/trial_product_breakdown_w_chosen_mechanisms.csv
| trial_hash                           | investigational_products                           | investigational_products_mechanism_mesh_terms                        | investigational_products_alternative_names                                                                                                                                                                                                          | investigational_products_molecular_target   | investigational_products_mechanism                                         | investigational_products_tt_drug_id   | investigational_products_bmt_drug_id   | active_comparators   | active_comparators_mechanism_mesh_terms   | active_comparators_alterna

In [1]:
"""
Download MeSH descriptor & supplementary XML files for 2025.

This utility fetches the core MeSH XML files directly from NLM and saves them
locally so that downstream tasks can parse MeSH tree numbers, term names,
synonyms, and supplementary concept records.

Inputs (remote only):
- Base URL:
      https://nlmpubs.nlm.nih.gov/projects/mesh/MESH_FILES/xmlmesh
- Files downloaded:
      • desc2025.xml  – main descriptor records (MeSH headings, tree numbers)
      • supp2025.xml  – supplementary concept records

Outputs (to cache/task_3/):
- cache/task_3/desc2025.xml
- cache/task_3/supp2025.xml

Behavior:
- If a file already exists locally and is non-empty, it is skipped.
- Otherwise the file is fetched via HTTP GET and saved to disk.
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
from pathlib import Path
import requests

BASE_URL = "https://nlmpubs.nlm.nih.gov/projects/mesh/MESH_FILES/xmlmesh"

BASE_DIR = Path("cache/task_3")
OUT_DIR = BASE_DIR
OUT_DIR.mkdir(parents=True, exist_ok=True)

FILES = ["desc2025.xml", "supp2025.xml"]

# -------------------------------------------------
# HELPERS
# -------------------------------------------------
def download_file(filename: str) -> None:
    """
    Download a single MeSH XML file by name into OUT_DIR.

    Skips download if the file already exists and has non-zero size.
    Raises requests.HTTPError on a 4xx/5xx response; connection and timeout
    errors propagate as other requests exceptions.
    """
    url = f"{BASE_URL}/{filename}"
    out_path = OUT_DIR / filename

    if out_path.exists() and out_path.stat().st_size > 0:
        print(f"Skipping {filename}, already exists at {out_path}.")
        return

    print(f"⬇ Downloading {url} -> {out_path}")
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()

    out_path.write_bytes(resp.content)
    print(f"Downloaded {filename} ({len(resp.content)} bytes).")

# -------------------------------------------------
# RUN
# -------------------------------------------------
for fname in FILES:
    try:
        download_file(fname)
    except Exception as e:
        print(f"❗ Error downloading {fname}: {e}")

print("Done fetching MeSH XML files for 2025.")

Skipping desc2025.xml, already exists at cache/task_3/desc2025.xml.
Skipping supp2025.xml, already exists at cache/task_3/supp2025.xml.
Done fetching MeSH XML files for 2025.
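
The download cell above holds each response fully in memory (`resp.content`) before writing it, which can be costly for large XML files. A minimal sketch of a chunked, atomic write follows. `write_chunks_atomically` is a hypothetical helper, not part of this notebook, and it assumes the caller passes an iterable of byte chunks, for example `resp.iter_content(chunk_size=1 << 20)` from a `requests.get(url, stream=True)` response.

```python
import os
import tempfile
from pathlib import Path
from typing import Iterable


def write_chunks_atomically(chunks: Iterable[bytes], out_path: Path) -> int:
    """Write byte chunks to a temp file, then rename it onto out_path.

    Returns the number of bytes written. A failed or interrupted write
    leaves no partial file at out_path.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=out_path.parent, suffix=".part")
    total = 0
    try:
        with os.fdopen(fd, "wb") as fh:
            for chunk in chunks:
                if chunk:  # skip empty (keep-alive) chunks
                    fh.write(chunk)
                    total += len(chunk)
        os.replace(tmp_name, out_path)  # atomic within one filesystem
    except BaseException:
        # Remove the temp file if anything went wrong before the rename
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
    return total


# Usage with in-memory chunks standing in for resp.iter_content(...)
demo_dir = Path(tempfile.mkdtemp())
n_bytes = write_chunks_atomically([b"<a>", b"", b"</a>"], demo_dir / "demo.xml")
```

Writing to a `.part` file first means the "already exists and is non-empty" check in `download_file` would never mistake a truncated download for a complete one.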


In [2]:
"""
Build an in-memory MeSH term index from descriptor and supplementary XML.

Inputs (from cache/task_3/):
- desc2025.xml  – main MeSH descriptor records (headings, tree numbers, scope notes)
- supp2025.xml  – supplementary concept records (SCRs), often mapped to descriptors

Process:
- Parse DescriptorRecords:
    • Collect DescriptorUI, tree numbers, preferred heading, scope note.
    • Collect all associated terms (preferred + synonyms) and normalize them.
    • Map each normalized term → {mesh_id, tree_numbers, scope_note}.
- Parse SupplementalRecords:
    • Collect all names (record names, concept names, term strings).
    • Use any direct SCR tree numbers.
    • If no trees, follow HeadingMappedTo descriptors to inherit tree numbers
      and scope notes from the mapped descriptors.
    • Map each normalized SCR name → {mesh_id, tree_numbers, scope_note},
      but do not overwrite an existing descriptor (D-UI) mapping.

Outputs:
- TREE_INDEX (in memory):
    dict[normalized_term] = {
        "mesh_id": <DescriptorUI or SupplementalRecordUI>,
        "tree_numbers": [<tree number strings>],
        "scope_note": <cleaned scope note or "">
    }

- Console summary:
    • Number of unique normalized MeSH terms loaded.
    • A few spot-check lookups (e.g., sotatercept, ACE-011, winrevair).
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
from pathlib import Path
import html
import unicodedata
import xml.etree.ElementTree as ET

BASE_DIR = Path("cache/task_3")
DESC_XML = BASE_DIR / "desc2025.xml"
SUPP_XML = BASE_DIR / "supp2025.xml"

# -------------------------------------------------
# HELPERS
# -------------------------------------------------
def norm(s: str) -> str:
    """Normalize a string for lookup: lowercase, unicode-clean, collapse spaces."""
    # Same cleaning as clean_text() (defined below), then lowercased.
    return clean_text(s).lower()


def clean_text(s: str) -> str:
    """Unicode-clean + collapse whitespace (preserve case)."""
    if not isinstance(s, str):
        return ""
    t = html.unescape(s)
    t = unicodedata.normalize("NFKC", t)
    t = (
        t.replace("\u2019", "'")
         .replace("\u2018", "'")
         .replace("\u2032", "'")
         .replace("\u2033", '"')
         .replace("\u201C", '"')
         .replace("\u201D", '"')
         .replace("\u2010", "-")
         .replace("\u2011", "-")
         .replace("\u2012", "-")
         .replace("\u2013", "-")
         .replace("\u2014", "-")
    )
    return " ".join(t.strip().split())


def _dedup(seq):
    """Preserve order while removing duplicates and empty entries."""
    seen = set()
    out = []
    for x in seq:
        if x and x not in seen:
            seen.add(x)
            out.append(x)
    return out


def _extract_scope_note_from_descriptor(rec: ET.Element) -> str:
    """
    Prefer the ScopeNote of the PreferredConcept (PreferredConceptYN='Y'),
    else fall back to the first ScopeNote present under any Concept.
    """
    pref = rec.find(".//ConceptList/Concept[@PreferredConceptYN='Y']/ScopeNote")
    if pref is not None and pref.text:
        return clean_text(pref.text)

    any_sn = rec.find(".//ConceptList/Concept/ScopeNote")
    if any_sn is not None and any_sn.text:
        return clean_text(any_sn.text)

    return ""


def _extract_scope_note_from_supp(rec: ET.Element) -> str:
    """
    For SCRs, ScopeNote can also live under Concept.
    Prefer the PreferredConcept (if flagged), else the first available.
    """
    pref = rec.find(".//ConceptList/Concept[@PreferredConceptYN='Y']/ScopeNote")
    if pref is not None and pref.text:
        return clean_text(pref.text)

    any_sn = rec.find(".//ConceptList/Concept/ScopeNote")
    if any_sn is not None and any_sn.text:
        return clean_text(any_sn.text)

    return ""


def load_mesh_tree_and_id(
    desc_xml_fp: Path,
    supp_xml_fp: Path,
) -> dict[str, dict[str, list[str] | str]]:
    """
    Build a mapping:
        normalized_term -> {"mesh_id", "tree_numbers", "scope_note"}
    by combining MeSH Descriptor and Supplementary Concept XML.
    """
    term_map: dict[str, dict[str, list[str] | str]] = {}

    # Helper maps for SCR fallbacks
    heading_to_tree: dict[str, list[str]] = {}
    ui_to_tree: dict[str, list[str]] = {}
    heading_to_scope: dict[str, str] = {}
    ui_to_scope: dict[str, str] = {}

    # --- Descriptors ---
    if desc_xml_fp.exists():
        root = ET.parse(desc_xml_fp).getroot()
        for rec in root.findall(".//DescriptorRecord"):
            desc_ui = (rec.findtext("DescriptorUI") or "").strip()
            tree_numbers = _dedup(
                [
                    tn.text.strip()
                    for tn in rec.findall(".//TreeNumberList/TreeNumber")
                    if tn.text
                ]
            )

            heading_raw = rec.findtext("DescriptorName/String")
            heading_norm = norm(heading_raw) if heading_raw else ""
            scope_note = _extract_scope_note_from_descriptor(rec)

            if heading_norm:
                heading_to_tree[heading_norm] = tree_numbers
                heading_to_scope[heading_norm] = scope_note
            if desc_ui:
                ui_to_tree[desc_ui] = tree_numbers
                ui_to_scope[desc_ui] = scope_note

            # Collect all terms mapped to this descriptor
            terms = set()
            if heading_raw:
                terms.add(heading_norm)
            for concept in rec.findall(".//Concept"):
                for term in concept.findall(".//Term"):
                    s = term.findtext("String")
                    if s:
                        terms.add(norm(s))

            for term in terms:
                term_map[term] = {
                    "mesh_id": desc_ui,
                    "tree_numbers": tree_numbers,
                    "scope_note": scope_note,
                }

    # --- Supplementary (SCRs) ---
    if supp_xml_fp.exists():
        root = ET.parse(supp_xml_fp).getroot()
        for rec in root.findall(".//SupplementalRecord"):
            supp_ui = (rec.findtext("SupplementalRecordUI") or "").strip()

            # Collect ALL names for this SCR
            names = set()

            for s in rec.findall(".//SupplementalRecordName/String"):
                if s is not None and s.text:
                    names.add(norm(s.text))

            for s in rec.findall(".//ConceptList/Concept/ConceptName/String"):
                if s is not None and s.text:
                    names.add(norm(s.text))

            for s in rec.findall(".//ConceptList/Concept/TermList/Term/String"):
                if s is not None and s.text:
                    names.add(norm(s.text))

            # Direct trees (often none for SCRs)
            tree_numbers = [
                tn.text.strip()
                for tn in rec.findall(".//TreeNumberList/TreeNumber")
                if tn.text
            ]

            # SCR scope note (preferred concept first)
            scr_scope_note = _extract_scope_note_from_supp(rec)

            # Fallback via HeadingMappedTo (names and UIs)
            mapped_scope_note = ""
            if not tree_numbers:
                # Try mapped names
                mapped_names = [
                    n.text.strip()
                    for n in rec.findall(
                        ".//HeadingMappedTo/DescriptorReferredTo/DescriptorName/String"
                    )
                    if n is not None and n.text
                ]
                for m in mapped_names:
                    m_norm = norm(m)
                    tns = heading_to_tree.get(m_norm)
                    if tns:
                        tree_numbers.extend(tns)
                    if (
                        not mapped_scope_note
                        and m_norm in heading_to_scope
                        and heading_to_scope[m_norm]
                    ):
                        mapped_scope_note = heading_to_scope[m_norm]

                # Try mapped UIs
                mapped_uis = [
                    u.text.strip().lstrip("*")
                    for u in rec.findall(
                        ".//HeadingMappedTo/DescriptorReferredTo/DescriptorUI"
                    )
                    if u is not None and u.text
                ]
                for mui in mapped_uis:
                    tns = ui_to_tree.get(mui)
                    if tns:
                        tree_numbers.extend(tns)
                    if (
                        not mapped_scope_note
                        and mui in ui_to_scope
                        and ui_to_scope[mui]
                    ):
                        mapped_scope_note = ui_to_scope[mui]

                tree_numbers = _dedup(tree_numbers)

            final_scope_note = scr_scope_note or mapped_scope_note or ""

            for name in names:
                # Keep Descriptor mapping if already present for same term
                if name in term_map and str(term_map[name].get("mesh_id", "")).startswith(
                    "D"
                ):
                    continue
                term_map[name] = {
                    "mesh_id": supp_ui,
                    "tree_numbers": tree_numbers,
                    "scope_note": final_scope_note,
                }

    return term_map

# -------------------------------------------------
# RUN
# -------------------------------------------------
TREE_INDEX = load_mesh_tree_and_id(DESC_XML, SUPP_XML)
print(f"Loaded MeSH index terms: {len(TREE_INDEX):,} unique normalized terms")

# Quick checks
for q in ["sotatercept", "ACE-011", "winrevair"]:
    k = norm(q)
    print(q, "→", TREE_INDEX.get(k))

Loaded MeSH index terms: 992,888 unique normalized terms
sotatercept → {'mesh_id': 'C542017', 'tree_numbers': ['D12.776.828.300'], 'scope_note': 'Recombinant proteins produced by the GENETIC TRANSLATION of fused genes formed by the combination of NUCLEIC ACID REGULATORY SEQUENCES of one or more genes with the protein coding sequences of one or more genes.'}
ACE-011 → {'mesh_id': 'C542017', 'tree_numbers': ['D12.776.828.300'], 'scope_note': 'Recombinant proteins produced by the GENETIC TRANSLATION of fused genes formed by the combination of NUCLEIC ACID REGULATORY SEQUENCES of one or more genes with the protein coding sequences of one or more genes.'}
winrevair → {'mesh_id': 'C542017', 'tree_numbers': ['D12.776.828.300'], 'scope_note': 'Recombinant proteins produced by the GENETIC TRANSLATION of fused genes formed by the combination of NUCLEIC ACID REGULATORY SEQUENCES of one or more genes with the protein coding sequences of one or more genes.'}
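
The loader above reads each file with `ET.parse`, which builds the whole tree in memory before any record is visited. As a hedged alternative, the sketch below streams records with `xml.etree.ElementTree.iterparse` and clears each record once it has been consumed. The helper name `iter_descriptor_headings` is an assumption, and the inline XML (including the UI `D999999` and its tree numbers) is a made-up placeholder that only mimics the element names the loader already queries.

```python
import io
import xml.etree.ElementTree as ET
from typing import Iterator


def iter_descriptor_headings(source) -> Iterator[tuple[str, str, list[str]]]:
    """Yield (DescriptorUI, heading, tree_numbers) one record at a time."""
    for _event, elem in ET.iterparse(source, events=("end",)):
        if elem.tag != "DescriptorRecord":
            continue
        ui = (elem.findtext("DescriptorUI") or "").strip()
        heading = (elem.findtext("DescriptorName/String") or "").strip()
        trees = [
            tn.text.strip()
            for tn in elem.findall("TreeNumberList/TreeNumber")
            if tn.text
        ]
        yield ui, heading, trees
        elem.clear()  # drop the record's children once it has been consumed


# Placeholder record, not real MeSH content
SAMPLE = b"""<DescriptorRecordSet>
  <DescriptorRecord>
    <DescriptorUI>D999999</DescriptorUI>
    <DescriptorName><String>Example Heading</String></DescriptorName>
    <TreeNumberList>
      <TreeNumber>D12.000.111</TreeNumber>
      <TreeNumber>D27.000.222</TreeNumber>
    </TreeNumberList>
  </DescriptorRecord>
</DescriptorRecordSet>"""

records = list(iter_descriptor_headings(io.BytesIO(SAMPLE)))
```

The same pattern would apply to `SupplementalRecord` elements; whether the memory saving matters depends on the machine running the notebook.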


In [3]:
"""
Map chosen MeSH mechanism terms to MeSH tree numbers for each trial and role.

Inputs (from cache/task_3/):
- trial_product_breakdown_w_chosen_mechanisms.csv
    One row per trial, including:
        • trial_hash
        • investigational_products_mechanism_mesh_terms
        • active_comparators_mechanism_mesh_terms
        • standard_of_care_mechanism_mesh_terms
    Each of the *_mechanism_mesh_terms columns is a list-like structure of
    MeSH headings selected in the previous step (e.g., ["Alkylating Agents", ...]).

- In-memory MeSH index from the MeSH loader cell:
    • TREE_INDEX: dict[normalized_term] = {"mesh_id", "tree_numbers", "scope_note"}
    • norm(): helper used to normalize MeSH headings for lookup

Process:
1. For each MeSH heading in the *_mechanism_mesh_terms columns:
    - Look up all MeSH tree numbers via TREE_INDEX (mesh_heading_to_tree_numbers).
2. For each list of tree numbers belonging to a single MeSH term:
    - Pick one "primary" tree number using a simple heuristic that prefers
      pharmacologically relevant branches (D12, D27, D02, etc.) and, within
      the chosen branch, the deepest (most specific) node.
3. Build six new columns:
    • investigational_products_mechanism_tree_numbers
    • investigational_products_mechanism_primary_tree_numbers
    • active_comparators_mechanism_tree_numbers
    • active_comparators_mechanism_primary_tree_numbers
    • standard_of_care_mechanism_tree_numbers
    • standard_of_care_mechanism_primary_tree_numbers

Outputs (to cache/task_3/):
- trial_mechanism_mesh_mapping.csv
    Same rows as input, plus the six new tree-number columns inserted
    next to the corresponding *_mechanism_mesh_terms columns.
"""

# -----------------------------------------
# CONFIG
# -----------------------------------------
import ast
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache/task_3")
TRIALS_IN_PATH = BASE_DIR / "trial_product_breakdown_w_chosen_mechanisms.csv"
TRIALS_OUT_PATH = BASE_DIR / "trial_mechanism_mesh_mapping.csv"

# -----------------------------------------
# Sanity: TREE_INDEX and norm must already be loaded
# -----------------------------------------
try:
    TREE_INDEX
except NameError:
    raise RuntimeError("TREE_INDEX is not defined — run the MeSH loader cell first.")

try:
    norm
except NameError:
    raise RuntimeError("norm() is not defined — ensure it is defined in the MeSH loader cell.")

# -----------------------------------------
# HELPERS
# -----------------------------------------
def parse_listish(x):
    """Parse a stringified Python list (or empty) into a real list."""
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    s = str(x).strip()
    if not s:
        return []
    try:
        v = ast.literal_eval(s)
        if isinstance(v, list):
            return v
    except Exception:
        pass
    return [s]


def mesh_heading_to_tree_numbers(chosen_mesh_term: str):
    """
    Return *all* tree numbers for the chosen MeSH heading using TREE_INDEX.
    """
    if not chosen_mesh_term or chosen_mesh_term == "[none]":
        return []
    base = chosen_mesh_term.split(" / ")[0].replace("*", "").strip()
    key = norm(base)
    info = TREE_INDEX.get(key)
    if not info:
        return []
    return info.get("tree_numbers", []) or []


# Clinical-pharmacology-ish heuristic for ONE primary tree number
PRIORITY_PREFIXES = [
    "D12.",  # Proteins: receptors, enzymes, cytokines, antibodies (biologics / targets)
    "D27.",  # Chemical Actions and Uses: classic pharmacologic classes
    "D02.",  # Organic Chemicals: small-molecule drugs
    "D09.",  # Carbohydrates
    "D23.",  # Biological Factors
    "D26.",  # Pharmaceutical Preparations
]

def _depth(tn: str) -> int:
    """Depth proxy for a tree number: more segments = deeper (more specific)."""
    return len(tn.split("."))

def choose_primary_tree(tree_numbers):
    """
    Given a list of tree numbers for ONE MeSH term,
    choose a single 'primary' tree that best reflects the
    pharmacologic / target-level concept.
    """
    if not tree_numbers:
        return ""

    tns = [t.strip() for t in tree_numbers if isinstance(t, str) and t.strip()]
    if not tns:
        return ""

    # 1) Prefer specific high-value branches (by prefix)
    for prefix in PRIORITY_PREFIXES:
        candidates = [t for t in tns if t.startswith(prefix)]
        if candidates:
            # choose the deepest (most specific) node in that branch
            return max(candidates, key=_depth)

    # 2) Else prefer any Chemicals & Drugs branch (D*)
    d_candidates = [t for t in tns if t.startswith("D")]
    if d_candidates:
        return max(d_candidates, key=_depth)

    # 3) Fallback: deepest overall
    return max(tns, key=_depth)


def trees_for_mesh_term_list(mesh_term_list):
    """
    Given a list of MeSH headings (already mapped mechanism terms),
    return:
      all_tree_lists : list of [list-of-tree-numbers] per term
      primary_trees  : list of ONE chosen tree number per term ("" if none)
    """
    all_tree_lists = []
    primary_trees = []

    for term in mesh_term_list:
        term_str = (term or "").strip()
        if not term_str:
            all_tree_lists.append([])
            primary_trees.append("")
            continue

        all_trees = mesh_heading_to_tree_numbers(term_str)
        all_tree_lists.append(all_trees)

        primary = choose_primary_tree(all_trees)
        primary_trees.append(primary)

    return all_tree_lists, primary_trees


def insert_after(df, col, newcol, values):
    """
    Insert a new column immediately after `col` if it exists,
    otherwise append it to the end of the DataFrame.
    """
    cols = list(df.columns)
    if col not in cols:
        df[newcol] = values
        return
    idx = cols.index(col)
    df.insert(idx + 1, newcol, values)

# -----------------------------------------
# RUN
# -----------------------------------------
# Load trial dataset
df = pd.read_csv(TRIALS_IN_PATH, dtype=str).fillna("")
print(f"Loaded trials: {TRIALS_IN_PATH}, shape={df.shape}")

# Compute tree-number columns row-wise
inv_trees_all = []
inv_trees_primary = []

ac_trees_all = []
ac_trees_primary = []

soc_trees_all = []
soc_trees_primary = []

for _, row in df.iterrows():
    inv_mesh_terms = parse_listish(row.get("investigational_products_mechanism_mesh_terms"))
    ac_mesh_terms = parse_listish(row.get("active_comparators_mechanism_mesh_terms"))
    soc_mesh_terms = parse_listish(row.get("standard_of_care_mechanism_mesh_terms"))

    inv_all_t, inv_primary_t = trees_for_mesh_term_list(inv_mesh_terms)
    ac_all_t, ac_primary_t = trees_for_mesh_term_list(ac_mesh_terms)
    soc_all_t, soc_primary_t = trees_for_mesh_term_list(soc_mesh_terms)

    inv_trees_all.append(inv_all_t)
    inv_trees_primary.append(inv_primary_t)

    ac_trees_all.append(ac_all_t)
    ac_trees_primary.append(ac_primary_t)

    soc_trees_all.append(soc_all_t)
    soc_trees_primary.append(soc_primary_t)

# Insert columns next to the mechanism_mesh_terms columns
insert_after(
    df,
    "investigational_products_mechanism_mesh_terms",
    "investigational_products_mechanism_tree_numbers",
    inv_trees_all,
)
insert_after(
    df,
    "investigational_products_mechanism_tree_numbers",
    "investigational_products_mechanism_primary_tree_numbers",
    inv_trees_primary,
)

insert_after(
    df,
    "active_comparators_mechanism_mesh_terms",
    "active_comparators_mechanism_tree_numbers",
    ac_trees_all,
)
insert_after(
    df,
    "active_comparators_mechanism_tree_numbers",
    "active_comparators_mechanism_primary_tree_numbers",
    ac_trees_primary,
)

insert_after(
    df,
    "standard_of_care_mechanism_mesh_terms",
    "standard_of_care_mechanism_tree_numbers",
    soc_trees_all,
)
insert_after(
    df,
    "standard_of_care_mechanism_tree_numbers",
    "standard_of_care_mechanism_primary_tree_numbers",
    soc_trees_primary,
)

# Save output
TRIALS_OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(TRIALS_OUT_PATH, index=False)
print(f"✅ Wrote: {TRIALS_OUT_PATH}")
print(df.head(5).to_markdown(index=False))

Loaded trials: cache/task_3/trial_product_breakdown_w_chosen_mechanisms.csv, shape=(184, 26)
✅ Wrote: cache/task_3/trial_mechanism_mesh_mapping.csv
| trial_hash                           | investigational_products                           | investigational_products_mechanism_mesh_terms                        | investigational_products_mechanism_tree_numbers                                 | investigational_products_mechanism_primary_tree_numbers   | investigational_products_alternative_names                                                                                                                                                                                                          | investigational_products_molecular_target   | investigational_products_mechanism                                         | investigational_products_tt_drug_id   | investigational_products_bmt_drug_id   | active_comparators   | active_comparators_mechanism_mesh_terms   | active_comparators_mechanism
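
To make the primary-tree heuristic from the cell above concrete, here is a compact, self-contained restatement applied to hand-written input. `pick_primary` is an illustrative stand-in that mirrors `choose_primary_tree` as coded (first matching priority prefix, then the deepest node within that branch); the tree numbers in `example` are made-up placeholders, not real MeSH entries.

```python
PRIORITY_PREFIXES = ["D12.", "D27.", "D02.", "D09.", "D23.", "D26."]


def depth(tn: str) -> int:
    """Number of dot-separated segments; more segments means more specific."""
    return len(tn.split("."))


def pick_primary(tree_numbers: list[str]) -> str:
    """Return one tree number: priority branch first, deepest node within it."""
    tns = [t.strip() for t in tree_numbers if isinstance(t, str) and t.strip()]
    if not tns:
        return ""
    for prefix in PRIORITY_PREFIXES:
        candidates = [t for t in tns if t.startswith(prefix)]
        if candidates:
            return max(candidates, key=depth)
    # No priority branch: prefer any D* branch, else fall back to everything
    d_candidates = [t for t in tns if t.startswith("D")]
    return max(d_candidates or tns, key=depth)


example = ["D27.505.100", "D12.776.100", "D12.776.100.200", "C04.588"]
primary = pick_primary(example)
```

Because `D12.` is listed before `D27.`, the D27 entry is never considered here, and of the two D12 entries the four-segment one is returned.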

In [4]:
"""
Summarize MeSH mechanism terms by primary tree number across all trial mechanisms.

Inputs:
- CSV: cache/task_3/trial_mechanism_mesh_mapping.csv
    One row per trial, with per-role lists of:
      • MeSH mechanism terms
      • Primary MeSH tree numbers (one primary tree per mechanism term)

Process:
- For each of three mechanism roles:
      • investigational_products
      • active_comparators
      • standard_of_care
  parse list-like columns for:
      • mechanism MeSH terms
      • primary tree numbers
- Iterate over aligned (mesh_term, primary_tree_number) pairs.
- Drop empty / '[none]' mesh terms and empty tree numbers.
- Count how often each (mesh_term, primary_tree_number) pair occurs across trials.

Outputs:
- Table of counts per (mesh_term, primary_tree_number), sorted by descending count:
      cache/task_3/trial_mechanism_mesh_tree_number_counts.csv
- Console preview of the top 20 (mesh_term, tree_number) combinations.
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------

import ast
from collections import Counter, defaultdict
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache/task_3")
INPUT_PATH  = BASE_DIR / "trial_mechanism_mesh_mapping.csv"
OUTPUT_PATH = BASE_DIR / "trial_mechanism_mesh_tree_number_counts.csv"

# Columns with MeSH mechanism terms (already mapped)
MESH_TERM_COLS = [
    "investigational_products_mechanism_mesh_terms",
    "active_comparators_mechanism_mesh_terms",
    "standard_of_care_mechanism_mesh_terms",
]

# Columns with *primary* tree numbers (one tree per mechanism)
TREE_COLS = [
    "investigational_products_mechanism_primary_tree_numbers",
    "active_comparators_mechanism_primary_tree_numbers",
    "standard_of_care_mechanism_primary_tree_numbers",
]

# -------------------------------------------------
# HELPERS
# -------------------------------------------------

def parse_listish(x):
    """Parse list-like strings safely into Python lists."""
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    s = str(x).strip()
    if not s:
        return []
    try:
        val = ast.literal_eval(s)
        return val if isinstance(val, list) else []
    except Exception:
        return []

def is_none_term(s: str) -> bool:
    """Treat '[none]' / 'none' / empty as unusable."""
    if not isinstance(s, str):
        return True
    t = s.strip().lower()
    return t in ("", "[none]", "none")

# -------------------------------------------------
# RUN
# -------------------------------------------------

df = pd.read_csv(INPUT_PATH)

tree_to_mesh_terms = defaultdict(list)
pair_counter = Counter()   # (mesh_term, primary_tree_number) → count

for _, row in df.iterrows():
    # Parse lists for each mechanism category
    mesh_term_lists   = [parse_listish(row.get(c, "[]")) for c in MESH_TERM_COLS]
    tree_number_lists = [parse_listish(row.get(c, "[]")) for c in TREE_COLS]

    # Iterate over the three mechanism categories in parallel
    for mesh_terms, tree_nums in zip(mesh_term_lists, tree_number_lists):
        for mesh_term, primary_tn in zip(mesh_terms, tree_nums):
            if is_none_term(mesh_term):
                continue
            if not isinstance(primary_tn, str) or not primary_tn.strip():
                continue

            tn = primary_tn.strip()
            mt = mesh_term.strip()

            pair_counter[(mt, tn)] += 1
            tree_to_mesh_terms[tn].append(mt)

# Convert to DataFrame
out_rows = [
    {
        "mesh_term": mesh_term,
        "tree_number": tree_num,
        "count": count,
    }
    for (mesh_term, tree_num), count in pair_counter.items()
]

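# Explicit columns keep sort_values from failing when no pairs were found
out_df = pd.DataFrame(
    out_rows, columns=["mesh_term", "tree_number", "count"]
).sort_values("count", ascending=False)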

# Save
out_df.to_csv(OUTPUT_PATH, index=False)

print(f"Saved tree number + term breakdown → {OUTPUT_PATH}")
print("Top 20 combinations:")
print(out_df.head(20).to_markdown(index=False))

Saved tree number + term breakdown → cache/task_3/trial_mechanism_mesh_tree_number_counts.csv
Top 20 combinations:
| mesh_term                                                                    | tree_number                             |   count |
|:-----------------------------------------------------------------------------|:----------------------------------------|--------:|
| Receptors, Erythropoietin / agonists*                                        | D12.776.543.750.705.852.150.200         |      27 |
| Cross-Linking Reagents                                                       | D27.720.470.410.210                     |      13 |
| Interleukin-4 Receptor alpha Subunit / antagonists & inhibitors; immunology* | D12.776.543.750.705.852.420.360.300.200 |      13 |
| Receptors, Thrombopoietin / agonists                                         | D12.776.543.750.705.852.610             |      13 |
| ErbB Receptors / antagonists & inhibitors*; metabolism                       | D12.77

In [5]:
"""
Classify MeSH mechanism terms into 9 high-level MOA super-groups.

Inputs:
- CSV: cache/task_3/trial_mechanism_mesh_tree_number_counts.csv
    Contains one row per (mesh_term, tree_number) pair with a "count" column:
        • mesh_term   : MeSH mechanism term chosen for a drug mechanism.
        • tree_number : Primary MeSH tree number used for that mechanism.
        • count       : Frequency of this (mesh_term, tree_number) across trials.

Process:
- For each row, apply a deterministic heuristic classifier that:
    • Uses both mesh_term text and tree_number branches.
    • Maps each mechanism into one of 9 super-groups:
        G1: cytokine_hormone_receptor_modulators
        G2: immune_checkpoint_immune_modulation
        G3: targeted_pathway_inhibitors
        G4: classical_cytotoxic_chemotherapy
        G5: biologic_antibodies_biologics
        G6: small_molecule_immunomod_antiinflammatory
        G7: metabolic_pathway_modulators
        G8: vaccines_immune_biologics
        G9: supportive_adjunctive_agents

Outputs:
- CSV with one row per (mesh_term, tree_number) and its assigned super-group:
      cache/task_3/trial_mechanism_super_group_mapping.csv

  Columns:
      mesh_term
      tree_number
      count
      mechanism_super_group

- Console preview of the full mapping table (markdown).
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------

import pandas as pd
from pathlib import Path

BASE_DIR = Path("cache/task_3")
INPUT_PATH  = BASE_DIR / "trial_mechanism_mesh_tree_number_counts.csv"
OUTPUT_PATH = BASE_DIR / "trial_mechanism_super_group_mapping.csv"

# -------------------------------------------------
# SUPER-GROUP LABELS
# -------------------------------------------------

# Define super-group labels (9 buckets)
G1 = "cytokine_hormone_receptor_modulators"
G2 = "immune_checkpoint_immune_modulation"
G3 = "targeted_pathway_inhibitors"
G4 = "classical_cytotoxic_chemotherapy"
G5 = "biologic_antibodies_biologics"
G6 = "small_molecule_immunomod_antiinflammatory"
G7 = "metabolic_pathway_modulators"
G8 = "vaccines_immune_biologics"
G9 = "supportive_adjunctive_agents"

# -------------------------------------------------
# HELPERS
# -------------------------------------------------

def classify_super_group(mesh_term: str, tree_number: str) -> str:
    """
    Heuristic mapping of mesh_term + tree_number to one of 9 MOA super-groups.
    The rules approximate how a clinical pharmacologist would group mechanisms,
    while staying deterministic and simple. Rules are checked in order, so the
    first match wins.
    """
    # Robust string coercion
    t = str(mesh_term or "").lower()
    tn = str(tree_number or "").strip()

    # -----------------------
    # GROUP 8: Vaccines & immune biologics (non-mAb)
    # -----------------------
    if "vaccine" in t:
        return G8
    if "recombinant fusion proteins" in t:
        return G8

    # -----------------------
    # GROUP 5: Biologic antibodies (mono / bispecific / CAR / fusion)
    # -----------------------
    if (
        "antibod" in t
        or "immunoconjugate" in t
        or "chimeric antigen" in t
    ):
        return G5
    # Core monoclonal antibody/fusion protein MeSH branches
    if tn.startswith("D12.776.124.486.485.114"):  # Antibodies, Monoclonal*
        return G5
    if tn.startswith("D12.776.124.790.651.114"):  # Therapeutic mAbs under Immunologic Factors
        return G5
    if tn.startswith("D12.776.828.300"):  # Recombinant Fusion Proteins
        return G5

    # -----------------------
    # GROUP 1: Cytokine & hormone receptor modulators
    # (EPO-R, TPO-R, IL-2R, glucocorticoid receptor, Ca-sensing receptor, etc.)
    # -----------------------
    if any(kw in t for kw in [
        "receptors, erythropoietin",
        "erythropoietin",
        "receptors, thrombopoietin",
        "thrombopoietin",
        "receptors, interleukin-2",
        "interleukin-2 receptor alpha subunit",
        "receptors, glucocorticoid",
        "receptors, calcium-sensing",
    ]):
        return G1
    # Hematopoietic / cytokine receptor MeSH branches present in the input table
    if tn.startswith("D12.776.543.750.705.852.150"):  # EPO-R agonists
        return G1
    if tn.startswith("D12.776.543.750.705.852.610"):  # TPO-R agonists
        return G1
    if tn.startswith("D12.776.543.750.705.852.420.320"):  # IL-2R
        return G1
    if tn.startswith("D12.776.826.750.430"):  # Glucocorticoid receptor
        return G1
    if tn.startswith("D12.776.543.750.695.115"):  # Ca-sensing receptor
        return G1

    # -----------------------
    # GROUP 2: Immune checkpoint & immune modulation
    # (Immune checkpoint inhibitors, PD-1/PD-L1, TNF, CD47, IL-33, IL-1RA, lectins, etc.)
    # -----------------------
    if any(kw in t for kw in [
        "immune checkpoint inhibitors",
        "immune checkpoint inhibitor",
        "immune checkpoint",  # generic catch-all
        "programmed cell death 1 receptor", "pd-1", "pd1",
        "pd-l1", "pdl1",
        "tumor necrosis factor-alpha",
        "tnf",
        "tumor necrosis factor ligand superfamily member",
        "cd47 antigen",
        "lectins, c-type",
        "interleukin-33",
        "interleukin 1 receptor antagonist protein",
    ]):
        return G2
    if tn.startswith("D12.776.543.750.705.222.875"):  # PD-1 receptor
        return G2
    if tn.startswith("D12.644.276.374.500.800"):  # TNF-alpha
        return G2
    if tn.startswith("D12.776.395.550.014"):  # CD47 antigen
        return G2
    if tn.startswith("D12.776.503.280"):  # C-type lectins
        return G2
    if tn.startswith("D12.644.276.374.750.720"):  # TNF ligand superfamily member 15
        return G2
    if tn.startswith("D12.644.276.374.465.850"):  # IL-33
        return G2

    # -----------------------
    # GROUP 3: Targeted pathway inhibitors (RTK / JAK-STAT / mTOR, VEGF, HER2)
    # -----------------------
    if any(kw in t for kw in [
        "vascular endothelial growth factor a",
        "receptor, erbb-2",
        "erbb2",
        "vegf",
        "janus kinase inhibitors",
        "jak inhibitor",
        "mtor inhibitors",
        "tor serine-threonine kinases",
    ]):
        return G3
    # VEGF-A branch
    if tn.startswith("D12.644.276.100.800.200"):
        return G3
    # HER2 / ErbB-2 receptor branch
    if tn.startswith("D12.776.543.750.750.400.074.400"):
        return G3
    # JAK / mTOR live under D27.505.519.* but we key by text above.

    # -----------------------
    # GROUP 4: Classical cytotoxic chemotherapy (antimetabolite / alkylator / tubulin / topo)
    # -----------------------
    if any(kw in t for kw in [
        "antimetabolites, antineoplastic",
        "antineoplastic agents, alkylating",
        "topoisomerase i inhibitors",
        "topoisomerase ii inhibitors",
        "vinca alkaloids",
        "vinblastine",
        "taxoids",
        "paclitaxel",
        "tubulin modulators",
    ]):
        return G4
    # Classical chemo branches
    if tn.startswith("D27.505.519.186"):  # antimetabolites, antineoplastic
        return G4
    if tn.startswith("D27.505.519.124"):  # alkylating agents
        return G4
    if tn.startswith("D27.505.519.593.249.500"):  # tubulin modulators
        return G4
    # No tree-number prefixes are listed for topoisomerase inhibitors or
    # anthracyclines; add them here if those branches appear in the input.

    # -----------------------
    # GROUP 6: Small-molecule immunomodulators & anti-inflammatories
    # (PDE4 inhibitors, COX inhibitors, glucocorticoids as drugs, antimalarials)
    # -----------------------
    if any(kw in t for kw in [
        "phosphodiesterase 4 inhibitors",
        "cyclooxygenase inhibitors",
        "histamine h1 antagonists",
        "glucocorticoids* / metabolism; pharmacology",
        "glucocorticoids",  # as a drug class
        "antimalarials",
        "immunosuppressive agents",
        "calcineurin inhibitors",
    ]):
        return G6
    if tn.startswith("D27.505.519.625.375.425.400"):  # H1 antagonists
        return G6
    if tn.startswith("D27.505.696.663.850.014.040.500.500"):  # COX inhibitors
        return G6
    if tn.startswith("D27.505.519.389.735.374"):  # PDE4 inhibitors
        return G6
    if tn.startswith("D27.505.696.477.656"):  # Immunosuppressive Agents*
        return G6
    if tn.startswith("D27.505.954.122.250.100.085"):  # Antimalarials
        return G6

    # -----------------------
    # GROUP 7: Metabolic pathway modulators
    # (gluconeogenesis, metabolic enzymes, etc.)
    # -----------------------
    if any(kw in t for kw in [
        "gluconeogenesis / drug effects",
        "gluconeogenesis",
        "biguanides",
        "imp dehydrogenase",
        "thymidylate synthase",
        "thymidine phosphorylase",
    ]):
        return G7
    if tn.startswith("G02.111.158.500"):  # Gluconeogenesis / drug effects*
        return G7
    if tn.startswith("D08.811.682.047.820.450"):  # IMP dehydrogenase
        return G7

    # -----------------------
    # GROUP 9: Supportive / adjunctive agents
    # (anion exchange resins, leucovorin, antithrombins, etc.)
    # -----------------------
    if any(kw in t for kw in [
        "anion exchange resins",
        "antithrombins",
        "leucovorin",
    ]):
        return G9
    if tn.startswith("D27.720.470.420.050"):  # Anion exchange resins
        return G9
    if tn.startswith("D27.505.519.389.745.800.449"):  # Antithrombins / agonists
        return G9

    # -----------------------
    # Fallbacks:
    # - If D27.505.* and not otherwise classified → treat as metabolic/chemical other
    # -----------------------
    if tn.startswith("D27.505."):
        return G7  # generic chemical/metabolic "other" rather than immuno

    # Absolute default: anything unmatched is labelled supportive/other, so G9
    # counts also include terms that no rule above recognised.
    return G9

# -------------------------------------------------
# RUN
# -------------------------------------------------

# Load counts per (mesh_term, tree_number)
df = pd.read_csv(INPUT_PATH)

# Start from full input and add group label
out_df = df[["mesh_term", "tree_number", "count"]].copy()
out_df["mechanism_super_group"] = [
    classify_super_group(m, tn) for m, tn in zip(out_df["mesh_term"], out_df["tree_number"])
]

# Optional: sort for readability
out_df = out_df.sort_values(["mechanism_super_group", "count"], ascending=[True, False])

# Save
out_df.to_csv(OUTPUT_PATH, index=False)
print(f"Saved mechanism super-group mapping → {OUTPUT_PATH}")
print(out_df.to_markdown(index=False))

Saved mechanism super-group mapping → cache/task_3/trial_mechanism_super_group_mapping.csv
| mesh_term                                                                              | tree_number                             |   count | mechanism_super_group                     |
|:---------------------------------------------------------------------------------------|:----------------------------------------|--------:|:------------------------------------------|
| Antibodies, Bispecific*                                                                | D12.776.124.486.485.114.125             |       4 | biologic_antibodies_biologics             |
| Antibodies, Bispecific* / pharmacology; therapeutic use                                | D12.776.124.486.485.114.125             |       3 | biologic_antibodies_biologics             |
| Antibodies, Monoclonal, Humanized*                                                     | D12.776.124.486.485.114.224.060         |       1 | biologic_antibodie

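The if-chain in `classify_super_group` is an ordered rule list in which the first match wins. The sketch below shows the same idea in table form, assuming a hypothetical `RULES` list and `classify` helper (neither exists in the notebook) and only a few of the keywords and tree prefixes from the cell above. The tree number `D12.776.828.300.999` is made up for illustration. It is meant to show why rule order matters, not to replace the full function.

```python
# Hypothetical, table-driven variant of the first-match-wins classifier.
# RULES, DEFAULT and classify() are illustrative names, not notebook code.
RULES = [
    # (group label, term keywords, tree-number prefixes)
    ("vaccines_immune_biologics", ["vaccine", "recombinant fusion proteins"], []),
    ("biologic_antibodies_biologics", ["antibod"], ["D12.776.828.300"]),
    ("classical_cytotoxic_chemotherapy", ["taxoids"], ["D27.505.519.186"]),
]
DEFAULT = "supportive_adjunctive_agents"


def classify(mesh_term, tree_number):
    """Return the label of the first rule whose keyword or prefix matches."""
    term = (mesh_term or "").lower()
    tree = (tree_number or "").strip()
    for label, keywords, prefixes in RULES:
        # Within a rule, text keywords are tried before tree prefixes.
        if any(kw in term for kw in keywords):
            return label
        if any(tree.startswith(p) for p in prefixes):
            return label
    return DEFAULT


# The text rule for fusion proteins fires before the tree-prefix rule.
print(classify("Recombinant Fusion Proteins*", "D12.776.828.300"))
# Same branch, but no matching text: the prefix rule decides.
print(classify("Example descendant", "D12.776.828.300.999"))
# Nothing matches: the default label is returned.
print(classify("Unknown term", "X99"))
```

Reordering the first two entries of `RULES` would change the label of the first example, which is the practical reason to treat the block order in `classify_super_group` as part of its logic.
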
In [6]:
"""
Attach 9-bucket MOA super-group labels to each mechanism role at the trial level.

Inputs:
- CSV: cache/task_3/trial_mechanism_super_group_mapping.csv
    One row per (mesh_term, tree_number) with columns:
        • mesh_term
        • tree_number
        • count
        • mechanism_super_group

- CSV: cache/task_3/trial_mechanism_mesh_mapping.csv
    Trial-level table with, for each role:
        • investigational_products_mechanism_mesh_terms
        • investigational_products_mechanism_primary_tree_numbers
        • investigational_products_mapped
        • investigational_products_primary_tree_numbers
        • active_comparators_mechanism_mesh_terms
        • active_comparators_mechanism_primary_tree_numbers
        • standard_of_care_mechanism_mesh_terms
        • standard_of_care_mechanism_primary_tree_numbers

Process:
- Build a mapping: primary tree_number → mechanism_super_group
  from trial_mechanism_super_group_mapping.csv.
- For each trial:
    • INVESTIGATIONAL PRODUCTS:
        - Primary: use mechanism-based MeSH (mechanism_mesh_terms / mechanism_primary_tree_numbers).
        - Fallback: if mechanism term is missing / '[none]' / empty, use
          investigational product MeSH mapping (investigational_products_mapped /
          investigational_products_primary_tree_numbers).
        - Map the chosen tree_number to mechanism_super_group.
    • ACTIVE COMPARATORS / STANDARD OF CARE:
        - Use mechanism-based tree_numbers only (no fallback).
- Insert three new list-valued columns:
    • investigational_products_mechanism_super_group
    • active_comparators_mechanism_super_group
    • standard_of_care_mechanism_super_group
  immediately to the right of their respective *_primary_tree_numbers columns.

Outputs:
- CSV: cache/task_3/trial_mechanism_with_super_groups.csv
    Same as input trial_mechanism_mesh_mapping.csv plus the super-group columns.
- Console preview of the first 5 rows (markdown).
"""

# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------

import ast
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache/task_3")

MAP_PATH   = BASE_DIR / "trial_mechanism_super_group_mapping.csv"
TRIALS_IN  = BASE_DIR / "trial_mechanism_mesh_mapping.csv"
TRIALS_OUT = BASE_DIR / "trial_mechanism_with_super_groups.csv"

# ---------------------------------------------------
# LOAD MAPPING: tree_number → super_group
# ---------------------------------------------------

map_df = pd.read_csv(MAP_PATH)

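# Drop rows without a tree_number, then strip whitespace.
# (astype(str) alone would turn NaN into the literal key "nan".)
map_df = map_df.dropna(subset=["tree_number"]).copy()
map_df["tree_number"] = map_df["tree_number"].astype(str).str.strip()

# If a tree_number appears more than once (e.g. the same branch with different
# qualifiers), the first row in file order wins. Labels are expected to agree,
# but this is not guaranteed because classification also uses the term text.
mapping = {}
for tn, sg in zip(map_df["tree_number"], map_df["mechanism_super_group"]):
    if tn and tn not in mapping:
        mapping[tn] = sg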

print(f"Loaded {len(mapping):,} tree_number → super_group mappings")

# ---------------------------------------------------
# HELPERS
# ---------------------------------------------------

def parse_listish(x):
    """Parse list-like strings (e.g. "['a','b']") into Python lists."""
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    s = str(x).strip()
    if not s:
        return []
    try:
        v = ast.literal_eval(s)
    except Exception:
        # Not a Python literal (e.g. a bare, unquoted token): treat as empty
        return []
    if isinstance(v, list):
        return v
    # Parsed to a non-list literal (e.g. "'a'" or "3"): wrap the raw string
    return [s]


def build_super_group_list(mech_terms, mech_trees):
    """
    Given:
      mech_terms : list of MeSH headings (unused except for length)
      mech_trees : list of primary tree numbers (strings)
    Return:
      list of mechanism_super_group strings (same length).
    """
    # Ensure lists
    mech_trees = mech_trees or []
    mech_terms = mech_terms or []

    n = max(len(mech_terms), len(mech_trees))
    out = []

    for i in range(n):
        tn = (mech_trees[i] if i < len(mech_trees) else "") or ""
        tn = tn.strip()
        if not tn:
            out.append("")
            continue
        out.append(mapping.get(tn, ""))  # "" if not found
    return out


def build_super_group_list_with_fallback(
    mech_terms,
    mech_trees,
    fallback_terms,
    fallback_trees,
):
    """
    For investigational products:
    - First try the mechanism-based MeSH pair (mech_terms/mech_trees).
    - If the mechanism term is missing / '[none]' / empty, or it has no tree
      number, fall back to the investigational product MeSH pair
      (fallback_terms/fallback_trees).
    - A chosen tree number that is absent from `mapping` yields ""; there is no
      second fallback at that point.

    Mapping itself is done ONLY on tree_number.
    """
    # Ensure all are lists
    mech_terms     = mech_terms or []
    mech_trees     = mech_trees or []
    fallback_terms = fallback_terms or []
    fallback_trees = fallback_trees or []

    n = max(len(mech_terms), len(mech_trees), len(fallback_terms), len(fallback_trees))
    out = []

    for i in range(n):
        # Primary (mechanism-based)
        term_mech = (mech_terms[i] if i < len(mech_terms) else "") or ""
        tn_mech   = (mech_trees[i] if i < len(mech_trees) else "") or ""

        term_mech = term_mech.strip()
        tn_mech   = tn_mech.strip()

        # Fallback (drug-based)
        term_fb = (fallback_terms[i] if i < len(fallback_terms) else "") or ""
        tn_fb   = (fallback_trees[i] if i < len(fallback_trees) else "") or ""

        term_fb = term_fb.strip()
        tn_fb   = tn_fb.strip()

        # Decide which tree number to use
        chosen_tn = ""
        # 1) Use mechanism term if present and not [none]
        if term_mech and term_mech != "[none]" and tn_mech:
            chosen_tn = tn_mech
        # 2) Else fall back to investigational product MeSH term
        elif term_fb and term_fb != "[none]" and tn_fb:
            chosen_tn = tn_fb

        chosen_tn = chosen_tn.strip()
        if not chosen_tn:
            out.append("")
            continue

        out.append(mapping.get(chosen_tn, ""))  # "" if no mapping

    return out


def insert_after(df, col, newcol, values):
    """Insert a new column immediately after `col`."""
    cols = list(df.columns)
    idx = cols.index(col)
    df.insert(idx + 1, newcol, values)

# ---------------------------------------------------
# RUN
# ---------------------------------------------------

# Load trial-level mechanism mapping
df = pd.read_csv(TRIALS_IN)

inv_sg_list = []
ac_sg_list  = []
soc_sg_list = []

for _, row in df.iterrows():
    # -----------------------------
    # INVESTIGATIONAL PRODUCTS
    # -----------------------------
    # Mechanism-based mapping
    inv_mech_terms = parse_listish(row.get("investigational_products_mechanism_mesh_terms"))
    inv_mech_tn    = parse_listish(row.get("investigational_products_mechanism_primary_tree_numbers"))

    # Fallback: investigational product MeSH mapping
    inv_prod_terms = parse_listish(row.get("investigational_products_mapped"))
    inv_prod_tn    = parse_listish(row.get("investigational_products_primary_tree_numbers"))

    inv_sg_list.append(
        build_super_group_list_with_fallback(
            inv_mech_terms,
            inv_mech_tn,
            inv_prod_terms,
            inv_prod_tn,
        )
    )

    # -----------------------------
    # ACTIVE COMPARATORS
    # (no fallback requested)
    # -----------------------------
    ac_terms = parse_listish(row.get("active_comparators_mechanism_mesh_terms"))
    ac_tn    = parse_listish(row.get("active_comparators_mechanism_primary_tree_numbers"))
    ac_sg_list.append(build_super_group_list(ac_terms, ac_tn))

    # -----------------------------
    # STANDARD OF CARE
    # (no fallback requested)
    # -----------------------------
    soc_terms = parse_listish(row.get("standard_of_care_mechanism_mesh_terms"))
    soc_tn    = parse_listish(row.get("standard_of_care_mechanism_primary_tree_numbers"))
    soc_sg_list.append(build_super_group_list(soc_terms, soc_tn))


# Insert new columns next to the primary_tree_numbers
insert_after(
    df,
    "investigational_products_mechanism_primary_tree_numbers",
    "investigational_products_mechanism_super_group",
    inv_sg_list,
)

insert_after(
    df,
    "active_comparators_mechanism_primary_tree_numbers",
    "active_comparators_mechanism_super_group",
    ac_sg_list,
)

insert_after(
    df,
    "standard_of_care_mechanism_primary_tree_numbers",
    "standard_of_care_mechanism_super_group",
    soc_sg_list,
)

# Save final CSV
df.to_csv(TRIALS_OUT, index=False)
print(f"Wrote final trial-level file with super-groups → {TRIALS_OUT}")
print(df.head(5).to_markdown(index=False))

Loaded 59 tree_number → super_group mappings
Wrote final trial-level file with super-groups → cache/task_3/trial_mechanism_with_super_groups.csv
| trial_hash                           | investigational_products                           | investigational_products_mechanism_mesh_terms                        | investigational_products_mechanism_tree_numbers                                 | investigational_products_mechanism_primary_tree_numbers   | investigational_products_mechanism_super_group   | investigational_products_alternative_names                                                                                                                                                                                                          | investigational_products_molecular_target   | investigational_products_mechanism                                         | investigational_products_tt_drug_id   | investigational_products_bmt_drug_id   | active_comparators   | active_comparators_mecha

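The mechanism-first rule with a product-level fallback can be sketched on toy data as follows. The names `toy_mapping`, `pick_tree_number` and `super_group` are hypothetical and stand in for the notebook's `mapping` dict and `build_super_group_list_with_fallback`; the tree number `Z00.000` is made up. The sketch handles one product at a time rather than whole lists.

```python
# Toy lookup; in the notebook this dict is built from the mapping CSV.
toy_mapping = {
    "D12.776.543.750.705.852.150": "cytokine_hormone_receptor_modulators",
    "D27.505.519.186": "classical_cytotoxic_chemotherapy",
}


def pick_tree_number(mech_term, mech_tree, prod_term, prod_tree):
    """Prefer the mechanism pair; fall back to the product pair; else ''."""
    for term, tree in ((mech_term, mech_tree), (prod_term, prod_tree)):
        term = (term or "").strip()
        tree = (tree or "").strip()
        # A pair is usable only if it has a real term AND a tree number.
        if term and term != "[none]" and tree:
            return tree
    return ""


def super_group(mech_term, mech_tree, prod_term, prod_tree):
    """Map the chosen tree number to a super-group ('' if unmapped)."""
    chosen = pick_tree_number(mech_term, mech_tree, prod_term, prod_tree)
    return toy_mapping.get(chosen, "")


# Mechanism pair is usable, so it is taken.
print(super_group("Receptors, Erythropoietin", "D12.776.543.750.705.852.150", "", ""))
# Mechanism term is '[none]', so the product pair is used instead.
print(super_group("[none]", "", "Antimetabolites, Antineoplastic", "D27.505.519.186"))
# Mechanism pair is usable but unmapped: the result is '' and the product
# pair is NOT consulted, mirroring the behaviour of the cell above.
print(super_group("Some term", "Z00.000", "Antimetabolites, Antineoplastic", "D27.505.519.186"))
```

The last case is worth checking against real data: an unmapped mechanism tree number silently produces an empty label even when the product-level tree number would have mapped.
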
In [7]:
"""
Derive a single 9-bucket MOA super-group label per trial and summarize distribution.

Inputs:
- CSV: cache/task_3/trial_mechanism_with_super_groups.csv
    Trial-level table with list-valued columns:
        • investigational_products_mechanism_super_group
        • active_comparators_mechanism_super_group
        • standard_of_care_mechanism_super_group

Process:
- For each trial:
    1. Parse the list-valued super-group columns from strings into Python lists.
    2. Choose a single trial-level super-group with the following priority:
         • Priority 1: first non-empty investigational_products_mechanism_super_group
         • Priority 2: if none, first non-empty active_comparators_mechanism_super_group
         • Priority 3: if none, first non-empty standard_of_care_mechanism_super_group
    3. Attach the chosen label in a new column:
         • trial_mechanism_super_group
- Aggregate a distribution (count per mechanism_super_group), excluding empty labels.

Outputs:
- CSV: cache/task_3/trial_super_group_distribution.csv
    Columns:
        • mechanism_super_group
        • count
- Console preview of the top categories (markdown).
"""

# -------------------------------------------------
# CONFIG
# -------------------------------------------------

import ast
from collections import Counter
from pathlib import Path

import pandas as pd

BASE_DIR = Path("cache/task_3")
INPUT_PATH  = BASE_DIR / "trial_mechanism_with_super_groups.csv"
OUTPUT_PATH = BASE_DIR / "trial_super_group_distribution.csv"

# -------------------------------------------------
# HELPERS
# -------------------------------------------------

def parse_listish(x):
    """Parse list-like strings safely into Python lists."""
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    s = str(x).strip()
    if not s:
        return []
    try:
        v = ast.literal_eval(s)
        return v if isinstance(v, list) else []
    except Exception:
        return []

def first_non_empty_str(lst):
    """Return the first non-empty string from a list, or '' if none."""
    if not isinstance(lst, list):
        return ""
    for v in lst:
        if isinstance(v, str) and v.strip():
            return v.strip()
    return ""

# -------------------------------------------------
# RUN
# -------------------------------------------------

df = pd.read_csv(INPUT_PATH)

chosen_super_groups = []

for _, row in df.iterrows():
    inv_sg_list = parse_listish(row.get("investigational_products_mechanism_super_group", "[]"))
    ac_sg_list  = parse_listish(row.get("active_comparators_mechanism_super_group", "[]"))
    soc_sg_list = parse_listish(row.get("standard_of_care_mechanism_super_group", "[]"))

    # Priority 1 — investigational product super-group
    chosen = first_non_empty_str(inv_sg_list)

    # Priority 2 — active comparator super-group
    if not chosen:
        chosen = first_non_empty_str(ac_sg_list)

    # Priority 3 — fallback to SOC
    if not chosen:
        chosen = first_non_empty_str(soc_sg_list)

    chosen_super_groups.append(chosen)

# Add per-trial chosen super-group (in memory only; this cell does not write
# df back to disk, only the aggregated distribution below)
df["trial_mechanism_super_group"] = chosen_super_groups

# Count distribution (exclude empty)
dist = Counter(sg for sg in chosen_super_groups if sg)

# most_common() already orders by count (descending); passing explicit column
# names keeps the frame well-formed even when no trial received a label.
dist_df = pd.DataFrame(
    dist.most_common(),
    columns=["mechanism_super_group", "count"],
)

# Save distribution
dist_df.to_csv(OUTPUT_PATH, index=False)

print(f"Saved distribution → {OUTPUT_PATH}")
print("Top categories:\n")
print(dist_df.head(20).to_markdown(index=False))

Saved distribution → cache/task_3/trial_super_group_distribution.csv
Top categories:

| mechanism_super_group                     |   count |
|:------------------------------------------|--------:|
| supportive_adjunctive_agents              |      53 |
| cytokine_hormone_receptor_modulators      |      39 |
| immune_checkpoint_immune_modulation       |      32 |
| targeted_pathway_inhibitors               |      20 |
| metabolic_pathway_modulators              |      15 |
| biologic_antibodies_biologics             |      10 |
| small_molecule_immunomod_antiinflammatory |       7 |
| classical_cytotoxic_chemotherapy          |       5 |
| vaccines_immune_biologics                 |       2 |

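The three-level priority used to pick one label per trial can be illustrated on toy lists. `first_non_empty`, `trial_label` and `toy_trials` are hypothetical names for this sketch; the labels are taken from the table above, but the trials themselves are invented.

```python
from collections import Counter


def first_non_empty(values):
    """Return the first non-blank string in `values`, or ''."""
    for v in values:
        if isinstance(v, str) and v.strip():
            return v.strip()
    return ""


def trial_label(inv, ac, soc):
    """Investigational first, then active comparator, then standard of care."""
    for candidates in (inv, ac, soc):
        label = first_non_empty(candidates)
        if label:
            return label
    return ""


# Each tuple is (investigational, active comparator, standard of care).
toy_trials = [
    (["", "targeted_pathway_inhibitors"], ["classical_cytotoxic_chemotherapy"], []),
    ([""], ["classical_cytotoxic_chemotherapy"], []),
    ([], [], ["supportive_adjunctive_agents"]),
    ([], [""], []),
]

labels = [trial_label(*trial) for trial in toy_trials]
counts = Counter(label for label in labels if label)  # empty labels excluded

print(labels)
print(counts.most_common())
```

Because only the first non-empty label is kept, a combination trial contributes a single count for the group of whichever product happens to be listed first, and trials with no label at all are dropped from the distribution.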

#### Output results

In [8]:
"""
Build a final, human-readable trial results table summarizing:
- Trial title
- Combined investigational (or SOC) drug name(s)
- Cleaned mechanism-of-action (MeSH-based, lightly normalized)
- Innovation status (Innovative / Generic / Biosimilar)
- High-level MOA super-group category

Inputs:
- CSV: cache/task_2/trial_investigational_drugs_classifications.csv
    • trial_hash
    • investigational_products                  (list-like)
    • investigational_products_classifications  (list-like; Innovative/Generic/Biosimilar/etc.)

- CSV: cache/task_3/trial_mechanism_with_super_groups.csv
    • trial_hash
    • investigational_products, standard_of_care (list-like drug names)
    • *_mechanism_mesh_terms, *_mechanism_super_group (list-like MeSH MOA & 9-bucket super-groups)

- CSV: cache/data_preprocess/raw_trials_with_hash.csv
    • trial_hash
    • title  (main trial title)

Process:
1. From the classification table:
   - Parse list-like investigational products & classifications.
   - Join product names and innovation flags with '+' per trial.
2. From the mechanism table:
   - Build per-trial '+'-joined strings for:
       • inv_drug_name_joined
       • inv_moa_joined
       • inv_category_joined (super-groups, unique)
       • soc_drug_name_joined
       • soc_moa_joined
       • soc_category_joined
3. Merge mechanism info with classification info on trial_hash.
4. For each trial, choose a single row:
   - If investigational product(s) exist:
       • drug_name  = inv_drug_name_joined
       • moa        = inv_moa_joined
       • innovation = innovation_joined
       • category   = inv_category_joined (fallback to SOC category if missing)
   - Else if only SOC exists:
       • drug_name  = soc_drug_name_joined
       • moa        = soc_moa_joined
       • category   = soc_category_joined
       • innovation = "Generic"
   - Else: leave fields empty.
5. Attach trial_title from raw_trials_with_hash; if missing, fall back to trial_hash.
6. Clean up:
   - Remove parenthetical text from drug_name (e.g. "Drug X (Company Y)" → "Drug X").
   - Normalize MOA:
       • split on '+'
       • drop text after '/' for each component
       • remove '*' and trim
       • rejoin with '+'
   - Normalize innovation flags to canonical title-case
     (Innovative, Generic, Biosimilar), preserving '+' separators.

Outputs:
- CSV: output/trial_results_table.csv
    Columns:
        trial_title
        drug_name
        moa
        innovation_generic_biosimilar
        category
- Console preview of the first 20 rows (markdown).
"""

# ---------------------------------
# CONFIG
# ---------------------------------

import ast
import re
from pathlib import Path

import pandas as pd

TASK2_DIR = Path("cache/task_2")
TASK3_DIR = Path("cache/task_3")
DATA_DIR  = Path("cache/data_preprocess")
OUT_DIR   = Path("output")

CLASS_PATH  = TASK2_DIR / "trial_investigational_drugs_classifications.csv"
MECH_PATH   = TASK3_DIR / "trial_mechanism_with_super_groups.csv"
TRIALS_PATH = DATA_DIR / "raw_trials_with_hash.csv"
OUTPUT_PATH = OUT_DIR / "trial_results_table.csv"

# ---------------------------------
# HELPERS
# ---------------------------------

def parse_listish(x):
    """Parse a list-like string into a Python list."""
    if isinstance(x, list):
        return x
    if pd.isna(x):
        return []
    s = str(x).strip()
    if not s:
        return []
    try:
        v = ast.literal_eval(s)
        return v if isinstance(v, list) else []
    except Exception:
        return []

def join_plus(lst):
    """Join non-empty strings with '+'."""
    cleaned = [str(x).strip() for x in lst if str(x).strip()]
    return "+".join(cleaned)

def join_unique_plus(lst):
    """Join unique, non-empty strings with '+' (order-preserving)."""
    seen = set()
    out = []
    for x in lst:
        s = str(x).strip()
        if s and s not in seen:
            seen.add(s)
            out.append(s)
    return "+".join(out)

def normalize_innovation(val: str) -> str:
    """
    Normalize to canonical title-case:
      - Innovative
      - Generic
      - Biosimilar
    Anything else is title-cased as a fallback.
    """
    if pd.isna(val):
        return ""
    v = str(val).strip()
    if not v:
        return ""
    low = v.lower()
    if "innov" in low:
        return "Innovative"
    # Require both fragments so that e.g. "Biologic" is not read as Biosimilar
    if "bio" in low and "sim" in low:
        return "Biosimilar"
    if "gener" in low:
        return "Generic"
    # Fallback: just title-case whatever it is.
    return v.title()

def strip_parentheses(s: str) -> str:
    """Remove all parenthetical segments from a drug name."""
    if not isinstance(s, str):
        return s
    return re.sub(r"\s*\([^)]*\)", "", s).strip()

def clean_moa(moa: str) -> str:
    """
    Clean MOA string:
      - split by '+'
      - for each piece, drop everything after '/'
      - remove '*'
      - strip whitespace
      - rejoin by '+'
    """
    if not isinstance(moa, str):
        return ""
    parts = moa.split("+")
    cleaned_parts = []
    for part in parts:
        s = part.strip()
        if not s:
            continue
        # keep only text before first '/'
        if "/" in s:
            s = s.split("/", 1)[0]
        # remove '*' and strip again
        s = s.replace("*", "").strip()
        if s:
            cleaned_parts.append(s)
    return "+".join(cleaned_parts)

# ---------------------------------
# RUN
# ---------------------------------

# Load inputs
df_class  = pd.read_csv(CLASS_PATH)
df_mech   = pd.read_csv(MECH_PATH)
df_titles = pd.read_csv(TRIALS_PATH, dtype=str)[["trial_hash", "title"]]

# Prepare classification info
df_class["drug_name_list"] = df_class["investigational_products"].apply(parse_listish)
df_class["innovation_list"] = df_class["investigational_products_classifications"].apply(parse_listish)

df_class["drug_name_joined"] = df_class["drug_name_list"].apply(join_plus)
df_class["innovation_joined"] = df_class["innovation_list"].apply(
    lambda lst: join_plus([normalize_innovation(v) for v in lst])
)

df_class_slim = df_class[["trial_hash", "drug_name_joined", "innovation_joined"]].copy()

# Prepare mechanism / category info
df_mech["inv_drug_name_joined"] = df_mech["investigational_products"].apply(
    lambda x: join_plus(parse_listish(x))
)
df_mech["inv_moa_joined"] = df_mech["investigational_products_mechanism_mesh_terms"].apply(
    lambda x: join_plus(parse_listish(x))
)
df_mech["inv_category_joined"] = df_mech["investigational_products_mechanism_super_group"].apply(
    lambda x: join_unique_plus(parse_listish(x))
)

df_mech["soc_drug_name_joined"] = df_mech["standard_of_care"].apply(
    lambda x: join_plus(parse_listish(x))
)
df_mech["soc_moa_joined"] = df_mech["standard_of_care_mechanism_mesh_terms"].apply(
    lambda x: join_plus(parse_listish(x))
)
df_mech["soc_category_joined"] = df_mech["standard_of_care_mechanism_super_group"].apply(
    lambda x: join_unique_plus(parse_listish(x))
)

df_mech_slim = df_mech[
    [
        "trial_hash",
        "inv_drug_name_joined",
        "inv_moa_joined",
        "inv_category_joined",
        "soc_drug_name_joined",
        "soc_moa_joined",
        "soc_category_joined",
    ]
].copy()

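# Merge on trial_hash. Trials absent from the classification table get NaN
# after the left join, so fill with "" to keep the string handling below safe.
merged = pd.merge(
    df_mech_slim,
    df_class_slim,
    on="trial_hash",
    how="left",
).fillna("")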

# Build final table
def build_final_row(row):
    inv_drug = (row.get("inv_drug_name_joined") or "").strip()
    soc_drug = (row.get("soc_drug_name_joined") or "").strip()

    if inv_drug:
        drug_name  = inv_drug
        moa        = (row.get("inv_moa_joined") or "").strip()
        innovation = (row.get("innovation_joined") or "").strip()
        category   = (row.get("inv_category_joined") or "").strip()
        if not category:
            category = (row.get("soc_category_joined") or "").strip()

    elif soc_drug:
        drug_name  = soc_drug
        moa        = (row.get("soc_moa_joined") or "").strip()
        category   = (row.get("soc_category_joined") or "").strip()
        # Standard-of-care only: treat as Generic (canonical title-case)
        innovation = "Generic"

    else:
        drug_name = ""
        moa = ""
        innovation = ""
        category = ""

    return pd.Series(
        {
            "drug_name": drug_name,
            "moa": moa,
            "innovation_generic_biosimilar": innovation,
            "category": category,
        }
    )

final_cols = merged.apply(build_final_row, axis=1)
final = pd.concat([merged[["trial_hash"]], final_cols], axis=1)

# Attach titles
final = final.merge(df_titles, on="trial_hash", how="left")
final["trial_title"] = final["title"].fillna(final["trial_hash"])
final.drop(columns=["title"], inplace=True)

# REMOVE PARENTHETICAL TEXT FROM drug_name
final["drug_name"] = final["drug_name"].apply(strip_parentheses)

# CLEAN MOA FIELD
final["moa"] = final["moa"].apply(clean_moa)

# Ensure innovation is always canonical (in case anything slipped through)
final["innovation_generic_biosimilar"] = final["innovation_generic_biosimilar"].apply(
    lambda v: "+".join(
        normalize_innovation(part)
        for part in str(v).split("+")
        if str(part).strip()
    ) if pd.notna(v) and str(v).strip() else ""
)

results = final[
    ["trial_title", "drug_name", "moa", "innovation_generic_biosimilar", "category"]
].copy()

# Save
OUT_DIR.mkdir(parents=True, exist_ok=True)
results.to_csv(OUTPUT_PATH, index=False)
print(f"Saved results table → {OUTPUT_PATH}")
print(results.head(20).to_markdown(index=False))

Saved results table → output/trial_results_table.csv
| trial_title                                                                                                                                                                                                                                 | drug_name                                         | moa                                  | innovation_generic_biosimilar   | category                             |
|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------|:-------------------------------------|:--------------------------------|:-------------------------------------|
| Random, single oral administration, double cycle, double cross, bioequivalence test of Cinacalcet Hydrochloride Tablets in healthy subjec

In [10]:
"""
Inspect the final trial results table for rows with missing key fields.

Inputs:
- CSV: output/trial_results_table.csv
    Columns (at minimum):
        • trial_title
        • drug_name
        • moa
        • innovation_generic_biosimilar
        • category

Process:
- Load the final trial results table.
- Define a helper to treat NaN and empty strings ("") as missing.
- Identify rows where ANY of the key columns are missing.
- Print:
    • The total number of rows with at least one missing value.
    • A markdown preview of all such rows.

Outputs:
- Console-only diagnostics; no files are written.
"""

# ---------------------------------
# CONFIG
# ---------------------------------

from pathlib import Path

import pandas as pd

OUT_DIR = Path("output")
INPUT_PATH = OUT_DIR / "trial_results_table.csv"

# Columns to inspect for missing values
COLS_TO_CHECK = [
    "trial_title",
    "drug_name",
    "moa",
    "innovation_generic_biosimilar",
    "category",
]

# ---------------------------------
# HELPERS
# ---------------------------------

def is_missing(x):
    """Return True if value is NaN or an empty string (after stripping)."""
    return pd.isna(x) or (str(x).strip() == "")

# ---------------------------------
# RUN
# ---------------------------------

# Load results table
results = pd.read_csv(INPUT_PATH)

# Ensure all required columns exist
missing_cols = [c for c in COLS_TO_CHECK if c not in results.columns]
if missing_cols:
    raise ValueError(f"Missing expected columns in input: {missing_cols}")

# Flag rows where any key column is NaN or blank
# (DataFrame.map needs pandas >= 2.1; older versions use applymap)
mask_missing = results[COLS_TO_CHECK].map(is_missing).any(axis=1)

missing_rows = results[mask_missing].copy()

print(f"Found {len(missing_rows)} rows with at least one missing value.\n")

if len(missing_rows) > 0:
    print("Rows with missing values:\n")
    print(missing_rows.to_markdown(index=False))
else:
    print("No rows with missing values in the selected columns.")

Found 1 rows with at least one missing value.

Rows with missing values:

| trial_title                                                 | drug_name   |   moa | innovation_generic_biosimilar   |   category |
|:------------------------------------------------------------|:------------|------:|:--------------------------------|-----------:|
| A Phase I Study of SSS24 in Patients With Colorectal Cancer | SSS-24      |   nan | Innovative                      |        nan |
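
The `nan` cells in this output most likely come from `pd.read_csv`, which by default parses empty fields as NaN. To see which columns drive the missing-row count, a per-column breakdown can help. The following is a minimal sketch on made-up toy data, not part of the pipeline: `missing_counts` and `toy` are hypothetical names, and the rule (NaN or whitespace-only string) is assumed to match `is_missing`.

```python
import pandas as pd


def missing_counts(df, cols):
    """Count NaN-or-blank cells per column (hypothetical helper)."""
    sub = df[cols]
    # Whitespace-only strings count as blank; NaN/None are caught by isna().
    blank = sub.apply(lambda s: s.astype(str).str.strip().eq(""))
    return (sub.isna() | blank).sum()


# Toy data for illustration only; these are not real trial records.
toy = pd.DataFrame(
    {
        "drug_name": ["Drug A", "Drug B", "  "],
        "moa": ["Kinase inhibitor", None, "Antibody"],
        "category": ["", "Oncology", None],
    }
)

counts = missing_counts(toy, ["drug_name", "moa", "category"])
print(counts.to_dict())
```

On the real table, the same call would be `missing_counts(results, COLS_TO_CHECK)`.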


Discovered that for two drugs, "601" and "Inetetamab", Citeline mapped the trial to the wrong drug / drug_id.

----------------------------------------------------------------------

"Inetetamab" was mapped to inotuzumab ozogamicin (Besponsa):
- Target: CD22
- Indication: B-cell ALL, etc.
- Mechanism: antibody–drug conjugate / DNA damaging.

This has nothing to do with:
- HER2
- breast cancer neoadjuvant
- Inetetamab / Inituzumab / Ceputin

----------------------------------------------------------------------

"601" was mapped to AER-601 (Aerami / Dance Biopharm GLP-1 analogue):
- Target: GLP-1 receptor
- Indications: Type 2 diabetes, obesity, appetite/weight control
- Mechanism: GLP-1 receptor agonist, incretin mimetic, insulin secretagogue

This has nothing to do with:
- VEGF-A or VEGF receptors
- Intravitreal ophthalmic anti-VEGF biologics
- Pathological myopic choroidal neovascularization (pmCNV)
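
One way to keep these two known mis-mappings from silently propagating is to flag the affected rows for manual review before trusting their `moa` and `category` values. The following is a minimal sketch under stated assumptions: `KNOWN_MISMAPPED`, `flag_mismapped`, and the `needs_review` column are hypothetical names, matching is done on the exact (case-insensitive) `drug_name`, and the toy frame stands in for the real results table.

```python
import pandas as pd

# Drug names noted above as mis-mapped upstream (lower-cased for matching).
KNOWN_MISMAPPED = {"601", "inetetamab"}


def flag_mismapped(df, name_col="drug_name"):
    """Return a copy of df with a boolean `needs_review` column (hypothetical)."""
    out = df.copy()
    # Normalise names: missing values become "", then trim and lower-case.
    names = out[name_col].fillna("").astype(str).str.strip().str.lower()
    out["needs_review"] = names.isin(KNOWN_MISMAPPED)
    return out


# Toy data for illustration only.
toy = pd.DataFrame({"drug_name": ["Inetetamab", "601", "SSS-24", None]})
flagged = flag_mismapped(toy)
print(flagged)
```

Flagging rather than overwriting leaves the corrected drug and mechanism to be filled in by hand once the correct drug_id has been confirmed.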