# **PlaNet**

**Predicting Drug Outcomes via Clinical Knowledge Graph and Pre-trained Model**

**Paper Source**: [https://www.medrxiv.org/content/10.1101/2024.03.06.24303800v2](https://www.medrxiv.org/content/10.1101/2024.03.06.24303800v2)

**Overview**
**PlaNet** is a geometric deep learning framework that predicts treatment outcomes by reasoning over population variability, disease biology, and drug chemistry using a clinical knowledge graph of 330,915 nodes and 13,928,443 edges. The system can predict drug efficacy (E), safety (S), and adverse events (AE) for any combination of drugs, diseases, and populations, including novel experimental drugs.

**Key Concepts**

**Core Innovation**

  * **Clinical Knowledge Graph**: Represents treatment information as (drug, condition, population) triplets.
  * **Population-Aware Predictions**: Accounts for how the same treatment affects different patient populations differently.
  * **Geometric Deep Learning**: Uses graph neural networks to learn from complex biomedical relationships.

**Knowledge Graph Structure**
- **Foreground Clinical KG**: Clinical trials data structured as drug-disease-population triplets.
- **Background Biological KG**: Integrates 9 biological/chemical databases including:
  * Disease biology and genomic variants.
  * Drug targets and chemical similarities.
  * Protein-protein interactions.
  * Molecular and cellular phenotypes.

**Overall Statistics:**

The final knowledge graph contains:
- 330,915 unique nodes
- 13,928,443 edges

Clinical Knowledge Graph (Foreground):
- Primary Source: The graph is built from a snapshot of the ClinicalTrials.gov database from February 14th, 2021.
- Scope: It includes 69,595 interventional clinical trials, resulting in 205,809 individual trial arms.

Core Structure: 
- Information is structured as (drug, condition, population) triplets extracted from trial text.
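To make the triplet idea concrete, here is a minimal sketch of how one trial arm could be held in memory and expanded into graph edges. The class name, relation names (`arm-drug`, `arm-condition`, `arm-population`), the placeholder arm ID, and the eligibility strings are all illustrative assumptions, not the identifiers used by PlaNet.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrialArm:
    """One trial arm viewed as a (drug, condition, population) triplet."""
    arm_id: str            # hypothetical identifier for the arm node
    drugs: frozenset       # standardized drug names
    conditions: frozenset  # standardized disease terms
    population: frozenset  # eligibility-criteria concepts


def to_edges(arm):
    """Expand an arm into (head, relation, tail) edges.

    Relation names are assumptions made for this sketch.
    """
    edges = []
    for relation, members in (
        ("arm-drug", arm.drugs),
        ("arm-condition", arm.conditions),
        ("arm-population", arm.population),
    ):
        for member in sorted(members):  # sorted for a deterministic order
            edges.append((arm.arm_id, relation, member))
    return edges


# Illustrative arm; the ID and eligibility concepts are placeholders.
example_arm = TrialArm(
    arm_id="NCT00000000/arm-1",
    drugs=frozenset({"remdesivir"}),
    conditions=frozenset({"COVID-19"}),
    population=frozenset({"adult", "hospitalized"}),
)
example_edges = to_edges(example_arm)
for edge in example_edges:
    print(edge)
```

Each arm becomes a small star of edges around its own node, which is what lets the same drug be attached to different populations in different arms.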

Information Extraction:
- Drugs: Extracted using the Medex tool and standardized against DrugBank, PubChem, and RxNorm.
- Diseases: Mapped to the MeSH database.
- Eligibility Criteria: Processed using the Criteria2Query tool and mapped to UMLS.
- Adverse Events: Mapped to the MedDRA hierarchy.
- Primary Outcomes: A custom vocabulary of 3,048 terms was created by clustering phrases from all trials.

Biological Knowledge Graph (Background): This component integrates knowledge from 9 external databases (see the components table below) to provide biological and chemical context.

Node Type Counts:
- Diseases: 5,751 (from MeSH) 
- Drugs: 14,300 (from DrugBank) 
- Drug Classes: 4,825 (from ClassyFire) 
- Proteins: 17,660 (from a multiscale-interactome) 
- Biological Functions: 28,734 (from Gene Ontology) 
- Population Concepts: 30,913 (from UMLS) 

Key Networks Integrated:
- Disease-Disease Network: Built from the MeSH hierarchy.
- Drug Chemical Hierarchy: Built from ClassyFire and DrugBank.
- Protein-Protein Interactions: Sourced from a multiscale-interactome that compiles seven major databases.
- Drug-Protein Network: Created using drug-target information from DrugBank.
- Disease-Protein Network: Built from curated DisGeNET data.

**Number of Entities in the Knowledge Graph:**

| Node Type | Count |
| :--- | :--- |
| **Trial Arm** | 205,809 |
| **Population** | 30,913 |
| **Protein** | 17,660 |
| **Drug** | 14,300 |
| **Disease** | 5,751 |
| **Drug Class** | 4,825  |
| **Outcome** | 3,048  |
| **Function** | 28,734  |
| **Interventional Clinical Trials** | 69,595 |
| **Total Unique Nodes** | **330,915**|
| **Total Edges (relations)** | **13,928,443**|


**PlaNet's Knowledge Graph Components**

This table details the external databases and pre-trained models integrated to build the PlaNet KG, forming the foundation for its predictive power.

| Component Name | Type | Role in the Workflow & Key Information | Component Details | Source File(s) / Identifier | Link |
| :--- | :--- | :--- | :--- | :--- | :--- |
| **UMLS (Unified Medical Language System)** | Biomedical Database | **Parsing**: A meta-thesaurus that links different vocabularies. It's used to normalize medical concepts from eligibility criteria and connect different biological networks together. | A foundational vocabulary that integrates and maps between different terminologies (e.g., MeSH, RxNorm) to create a unified concept space. The KG uses it to build a population-population network of related medical conditions and patient characteristics. | `data/population_data/umls-install/` | [NLM UMLS Homepage](https://www.nlm.nih.gov/research/umls/index.html) |
| **MeSH (Medical Subject Headings)** | Biomedical Database | **Parsing**: A comprehensive medical vocabulary used to standardize disease names from the trial's `condition` field. | A hierarchical thesaurus from the National Library of Medicine. The KG uses the 'Diseases' and 'Psychology' branches to build a network of 5,751 disease nodes with "is-a" relationships. | `data/disease_data/2021/` | [NLM MeSH Homepage](https://www.nlm.nih.gov/mesh/meshhome.html) |
| **DrugBank / RxNorm / PubChem** | Drug Databases | **Parsing**: Used to standardize drug and chemical names, providing information on drug targets, enzymes, and chemical classifications. | DrugBank provides detailed data on drugs, including synonyms, products, and protein targets. RxNorm provides normalized names for clinical drugs. Together, they help map drug names from trials to 14,300 unique drug nodes in the KG. | `data/drug_data/` | [DrugBank](https://go.drugbank.com/), [RxNorm](https://www.nlm.nih.gov/research/umls/rxnorm/index.html) |
| **DisGeNET** | Association Database | **Parsing**: Enriches the KG by providing curated associations between genes and human diseases, including data on genomic alterations and expression changes. | This database compiles gene-disease associations from expert-curated repositories. This is used to build the disease-protein interaction network in the KG. | `data/kg_data/external_data/disgenet/` | [DisGeNET](https://www.disgenet.org/) |
| **Gene Ontology (GO)** | Association Database | **Parsing**: Provides a structured vocabulary for the functions of genes and proteins, adding a layer of biological process information to the KG. | A collaborative bioinformatics project to create a standardized vocabulary of gene and protein functions. The KG uses the "Biological Processes" subnetwork to build a hierarchy of 28,734 biological functions. | `data/kg_data/external_data/go/` | [Gene Ontology](http://geneontology.org/) |
| **ClassyFire** | Chemical Taxonomy | **Parsing**: A database used to classify drugs based on their chemical structure, creating a hierarchical drug-drug network. | An automated system that provides a comprehensive chemical taxonomy called ChemOnt. It's used to connect drugs to 4,825 chemical classes, forming the basis of the drug chemical hierarchy in the KG. | Referenced in Supplementary Note 2 | [ClassyFire](http://classyfire.wishartlab.com/) |
| **Medex & Criteria2Query** | NLP Tools | **Parsing**: External Java programs that extract structured information (e.g., drug dosage, medical concepts) from the free-text portions of clinical trials. | Medex extracts medication information from clinical text. Criteria2Query parses eligibility criteria and links entities to UMLS concepts using tf-idf matching. | `resources/` | [Columbia DBMI](https://www.dbmi.columbia.edu/) |
| **BiomedBERT** | **Foundational Model** | **Prediction**: Converts the text descriptions of trial arms into numerical vectors (embeddings) that capture their semantic meaning. | A large language model pre-trained on biomedical text. It provides a deep understanding of medical language, which is used to generate feature vectors for diseases, drugs, outcomes, and trial arms for the prediction models. | `microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext` | [Hugging Face](https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext) |
| **PlaNet Graph Encoder** | Task-Specific Model | **Prediction**: The core graph neural network that learns from the knowledge graph structure. This is the base for the prediction models. | A Relational Graph Convolutional Network (R-GCN) that is pre-trained on the entire KG using a link prediction task. This creates general-purpose embeddings for all 330,915 nodes before fine-tuning. | `data/models/3u7di6ag/ckpt.pt` | N/A (Project-specific model) |
| **Adverse Event (AE) Predictor** | Task-Specific Model | **Prediction**: A model fine-tuned to predict the probability of specific adverse events for a given trial arm. | This model predicts 554 different adverse event categories defined by the MedDRA hierarchy. It achieves an average AUROC of 0.85 across these categories. | `data/models/ae_model_shxo9bgw/ckpt.pt` | N/A (Project-specific model) |
| **Safety Predictor** | Task-Specific Model | **Prediction**: A model fine-tuned to predict the overall probability of a safety concern for a trial arm. | This model predicts the likelihood of a "serious adverse event" occurring in the treatment arm compared to the placebo arm. It achieves an AUROC of 0.79. | `data/models/safety_model_1xekl810/ckpt.pt` | N/A (Project-specific model) |
| **Efficacy Predictor** | Task-Specific Model | **Prediction**: A model fine-tuned to compare two trial arms and predict which one is likely to have better efficacy. | This model focuses on predicting survival-related outcomes. It was trained on 1,307 labeled trial arms and achieves an AUROC of 0.70. | `data/models/efficacy_model_34l5ms9m/ckpt.pt` | N/A (Project-specific model) |

**Comparison of Data and Model Types:**

| Category | Primary Function | Nature | Output | Analogy | Example from This Project |
| :--- | :--- | :--- | :--- | :--- | :--- |
| **Database** | Stores and provides access to structured facts. | Passive & Specialized | Raw data upon query | 📚 A library full of textbooks | MeSH, DrugBank, UMLS |
| **NLP Tool** | Performs a single, well-defined language task. | Active & Highly Specialized | Extracted or annotated text | 🔬 A specific lab instrument | `Medex`, `Criteria2Query` |
| **Foundational Model** | Provides a broad, general understanding of language. | Active & General-Purpose | Numerical vectors (embeddings) | 🎓 A medical school graduate | `BiomedBERT` |
| **Task-Specific Model**| Makes a prediction for one specific goal. | Active & Highly Specialized | A specific prediction or score | 🩺 A specialized heart surgeon | Adverse Event Predictor |

**Model Architecture**

**Encoder: Relational Graph Convolutional Network (R-GCN)**

```
h^(l+1)_i = σ(∑_{r∈R} ∑_{j∈N^r_i} (1/|N^r_i|) W^(l)_r h^(l)_j + W^(l)_0 h^(l)_i)
```
  * **Multi-layer message passing** with relation-specific transformations.
  * **Basis decomposition** for parameter efficiency.
  * **Entity attributes**: Text embeddings (PubMedBERT, the model now distributed as BiomedBERT), structured features.
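The update rule above can be traced by hand on a toy graph. The sketch below is a dependency-free illustration of a single R-GCN layer with ReLU as the activation; it omits basis decomposition, and the node names, relation names, and weight matrices are made up for the example rather than taken from PlaNet.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def rgcn_layer(h, edges, w_rel, w_self):
    """One R-GCN layer.

    h:      dict node -> feature vector h_i
    edges:  list of (source j, relation r, target i); messages flow j -> i
    w_rel:  dict relation -> weight matrix W_r
    w_self: self-loop weight matrix W_0
    """
    # Group sources by (target, relation) to obtain each neighborhood N_i^r.
    neighborhoods = {}
    for j, r, i in edges:
        neighborhoods.setdefault((i, r), []).append(j)

    new_h = {}
    for i, h_i in h.items():
        total = matvec(w_self, h_i)  # W_0 h_i
        for (target, r), sources in neighborhoods.items():
            if target != i:
                continue
            norm = 1.0 / len(sources)  # 1 / |N_i^r|
            for j in sources:
                message = matvec(w_rel[r], h[j])  # W_r h_j
                total = [t + norm * m for t, m in zip(total, message)]
        new_h[i] = [max(0.0, t) for t in total]  # sigma = ReLU
    return new_h


# Toy graph: a drug node receives messages from a protein and a disease.
features = {"drug": [1.0, 0.0], "protein": [0.0, 1.0], "disease": [1.0, 1.0]}
toy_edges = [
    ("protein", "drug-protein", "drug"),
    ("disease", "drug-disease", "drug"),
]
identity = [[1.0, 0.0], [0.0, 1.0]]
weights = {"drug-protein": identity, "drug-disease": [[0.5, 0.0], [0.0, 0.5]]}

updated = rgcn_layer(features, toy_edges, weights, identity)
print(updated["drug"])
```

Nodes without incoming edges keep only their self-loop term, so in this toy graph only the drug embedding changes.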

**Self-Supervised Pretraining**
  * **Link prediction task**: Predict the existence of edges in the knowledge graph.
  * **Negative sampling** with self-adversarial training.
  * Learns general-purpose embeddings for all entities.
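As a rough illustration of the pretraining objective, the sketch below scores one true edge against two corrupted edges and combines them with a self-adversarial weighting, where negatives that the model currently scores highly receive more weight. The DistMult scoring function, the temperature `alpha`, and the toy embeddings are assumptions chosen for brevity; the summary above does not state which scoring function PlaNet uses.

```python
import math


def distmult(head, relation, tail):
    """DistMult triple score (an assumed scoring function for this sketch)."""
    return sum(h * r * t for h, r, t in zip(head, relation, tail))


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def self_adversarial_loss(pos_score, neg_scores, alpha=1.0):
    """Link-prediction loss with self-adversarial negative weighting.

    Negatives are weighted by a softmax over their own scores, so
    harder negatives contribute more to the loss.
    """
    top = max(neg_scores)
    exps = [math.exp(alpha * (s - top)) for s in neg_scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    pos_term = -log_sigmoid(pos_score)
    neg_term = -sum(w * log_sigmoid(-s) for w, s in zip(weights, neg_scores))
    return pos_term + neg_term, weights


# Toy embeddings (illustrative values only).
head, rel, true_tail = [0.5, 1.0], [1.0, 1.0], [1.0, 0.5]
corrupted_tails = [[-1.0, 0.0], [1.0, 1.0]]

pos_score = distmult(head, rel, true_tail)
neg_scores = [distmult(head, rel, t) for t in corrupted_tails]
loss, neg_weights = self_adversarial_loss(pos_score, neg_scores)
print(loss, neg_weights)
```

In a real training loop the weights would be treated as constants (no gradient flows through them), which this plain-Python sketch does not need to model.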

**Downstream Tasks**

1. **Efficacy Prediction**: Which treatment arm has better survival outcomes.
2. **Safety Prediction**: Probability of serious adverse events.
3. **Adverse Event Categories**: Specific types of side effects (554 categories).

**Evaluation & Results**

**Efficacy Prediction**

  * **Dataset**: 1,307 labeled trial arms across 625 trials.
  * **Metric**: AUROC = 0.70 (15% improvement over PubMedBERT).
  * **Enhanced version (PlaNetLM)**: Achieved a further 5% improvement.

**Safety Prediction**

  * **Serious Adverse Events**: AUROC = 0.79.
  * **Adverse Event Categories**: Average AUROC = 0.85 across 554 categories.
  * **Training Data**: 18,583 labeled trial arms.
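For readers less familiar with the metric, AUROC is the probability that a randomly chosen positive example is scored above a randomly chosen negative one. The sketch below computes it by direct pairwise comparison and then averages over categories; treating the reported 554-category figure as an unweighted (macro) average is our assumption, and the labels and scores are toy values.

```python
def auroc(labels, scores):
    """AUROC by pairwise comparison; ties count as half a win."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


def macro_auroc(label_columns, score_columns):
    """Unweighted mean of per-category AUROC values."""
    values = [auroc(y, s) for y, s in zip(label_columns, score_columns)]
    return sum(values) / len(values)


# Toy data for two adverse-event categories over four trial arms.
labels_a, scores_a = [1, 0, 1, 0], [0.9, 0.3, 0.4, 0.6]
labels_b, scores_b = [0, 0, 1, 1], [0.1, 0.2, 0.8, 0.7]

single = auroc(labels_a, scores_a)
average = macro_auroc([labels_a, labels_b], [scores_a, scores_b])
print(single, average)
```

The pairwise loop is quadratic in the number of arms, which is fine for illustration; a rank-based formula gives the same value on large datasets.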

**Key Capabilities**

**Generalization to Novel Drugs**
  * **Performance**: Comparable results on 224 never-before-seen drugs.
  * **Mechanism**: Leverages chemical and biological similarities through KG connections.
  * **Example**: Correctly predicted tasisulam-sodium toxicity despite never seeing the drug in labeled data.

**Population Effect Analysis**
  * **Matched Trials**: 91% correct probability adjustments when population characteristics change.
  * **Population Ranking**: Identifies which patient eligibility criteria most influence the risk of adverse events.
  * **Clinical Insight**: The same drug can have different safety profiles in different populations.

**Validation Examples**
  * **COVID-19 Trials**: Correctly predicted hemorrhage/breathing difficulty for remdesivir (a drug and disease never seen during training).
  * **Cancer Trials**: Accurate predictions for novel drug combinations.
  * **Temporal Validation**: Similar performance on future trials (after June 2017) vs. historical data.

**Technical Implementation**

**Data Processing Pipeline**
1. **Entity Extraction**: Named entity recognition from clinical trial text.
2. **Standardization**: Mapping to controlled vocabularies (DrugBank, MeSH, UMLS).
3. **Knowledge Integration**: Connecting clinical and biological networks.
4. **Quality Control**: Filtering and validation of extracted relationships.

**Model Training**
  * **Pretraining**: 20,000 steps with a batch size of 8,192.
  * **Fine-tuning**: Task-specific classifiers with a shared encoder.
  * **Optimization**: Learning rate scheduling, gradient clipping, dropout regularization.

**Significance & Impact**

**Clinical Applications**
  * **Drug Development**: Identify promising treatments and predict failures early.
  * **Trial Design**: Optimize patient selection to reduce adverse events.
  * **Precision Medicine**: Personalized treatment recommendations based on population characteristics.

**Technical Contributions**
  * **Scalable Framework**: Handles massive heterogeneous biomedical data.
  * **Interpretable Predictions**: Explains which population factors influence outcomes.
  * **Generalizable**: Works across diseases, drugs, and population types without retraining.

The framework represents a significant advancement in AI-guided clinical decision making, offering valuable insights for realizing precision medicine through comprehensive biomedical knowledge integration.

# **LC Clinical Trials**

## **First Steps**

**1️⃣ Initial Query Construction**

We started by designing a broad **Boolean search query** to capture *all Long COVID-related trials* across ClinicalTrials.gov and similar repositories.

```text
("Long COVID" OR "Post-Acute Sequelae of SARS-CoV-2" OR "Post-Acute Sequelae of COVID-19" OR 
"PASC" OR "Post COVID Syndrome" OR "Post COVID Condition" OR "Long Haul COVID" OR 
"Chronic COVID" OR "COVID-19 Sequelae")
```
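A small helper like the one below can keep the synonym list in one place and regenerate the Boolean expression on demand. This is a sketch of how the string above could be assembled, not the script we actually ran; the function name is hypothetical, and URL-encoding is shown only because the expression usually ends up in a search URL.

```python
from urllib.parse import quote

# Synonyms taken from the query above.
LONG_COVID_TERMS = [
    "Long COVID",
    "Post-Acute Sequelae of SARS-CoV-2",
    "Post-Acute Sequelae of COVID-19",
    "PASC",
    "Post COVID Syndrome",
    "Post COVID Condition",
    "Long Haul COVID",
    "Chronic COVID",
    "COVID-19 Sequelae",
]


def build_boolean_query(terms):
    """Quote each term and join them with OR inside parentheses."""
    quoted = ['"{}"'.format(term) for term in terms]
    return "(" + " OR ".join(quoted) + ")"


query = build_boolean_query(LONG_COVID_TERMS)
encoded_query = quote(query)  # percent-encoded form for use in a URL
print(query)
```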

This query was used to identify *interventional* trials testing drugs, biologics, or supplements intended to treat Long COVID symptoms.

---

**2️⃣ Data Filtering on ClinicalTrials.gov**

Once trials were retrieved, the following filters were applied:

| Filter Category       | Selected Option(s)                                                   |
| --------------------- | -------------------------------------------------------------------- |
| **Study Type**        | Interventional (Clinical Trial)                                      |
| **Phase**             | Phase 1, 2, 3, or Early Phase 1                                      |
| **Status**            | Completed, Active, Recruiting, or Terminated (excluding “Withdrawn”) |
| **Condition/Disease** | Long COVID, Post-Acute COVID-19, PASC, Post-COVID Syndrome, etc.     |
| **Intervention Type** | Drug, Biologic, Dietary Supplement                                   |
| **Availability**      | Trials with posted results or uploaded results documents             |
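The same filters can be expressed as a predicate over already-retrieved records, which is handy for re-checking an export. In the sketch below, the flat record layout (`study_type`, `phases`, `status`, `intervention_types`, `has_results`) and the two sample records are hypothetical; they are not the ClinicalTrials.gov schema.

```python
ALLOWED_PHASES = {"Early Phase 1", "Phase 1", "Phase 2", "Phase 3"}
ALLOWED_STATUSES = {"Completed", "Active", "Recruiting", "Terminated"}
ALLOWED_INTERVENTIONS = {"Drug", "Biologic", "Dietary Supplement"}


def passes_filters(trial):
    """Return True when a record satisfies every filter in the table above."""
    return (
        trial["study_type"] == "Interventional"
        and bool(ALLOWED_PHASES & set(trial["phases"]))
        and trial["status"] in ALLOWED_STATUSES
        and bool(ALLOWED_INTERVENTIONS & set(trial["intervention_types"]))
        and trial["has_results"]
    )


# Hypothetical records used only to exercise the predicate.
sample_trials = [
    {
        "id": "trial-A",
        "study_type": "Interventional",
        "phases": ["Phase 2"],
        "status": "Completed",
        "intervention_types": ["Drug"],
        "has_results": True,
    },
    {
        "id": "trial-B",
        "study_type": "Interventional",
        "phases": ["Phase 2"],
        "status": "Withdrawn",  # excluded by the Status filter
        "intervention_types": ["Drug"],
        "has_results": False,
    },
]

kept_ids = [t["id"] for t in sample_trials if passes_filters(t)]
print(kept_ids)
```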

---

**3️⃣ Manual Selection and Export**

* We manually exported the JSON data for **18 trials** meeting the inclusion criteria.
* Each file (e.g., `NCT04880161.json`, `NCT04678830.json`) contains complete metadata and posted results.

---

**4️⃣ Data Normalization**

All JSON files were parsed and normalized into a **uniform summary format** containing:

| Key Field                   | Description                                 |
| --------------------------- | ------------------------------------------- |
| **Trial_ID**                | NCT number of the study                     |
| **Drug**                    | Investigational compound name               |
| **AE**                      | Treatment-emergent adverse events summary   |
| **E**                       | Efficacy measure or primary endpoint change |
| **S**                       | Serious adverse events summary              |
| **Design**                  | Allocation, model, masking, and sample size |
| **Sponsor / Phase / Dates** | Metadata for cross-comparison               |
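The normalization step can be sketched as a function that pulls a few nested fields out of each exported JSON file and tolerates missing sections. The key paths below reflect our understanding of the ClinicalTrials.gov JSON export and should be checked against the actual files; mapping `otherNumAffected` to the AE column is an assumption, the E field is left out because efficacy results differ per trial, and the sample study is synthetic.

```python
def dig(obj, *keys, default=None):
    """Follow a chain of dict keys, returning default if any is missing."""
    for key in keys:
        if not isinstance(obj, dict) or key not in obj:
            return default
        obj = obj[key]
    return obj


def summarize(study):
    """Reduce one study record to a flat summary row."""
    groups = dig(study, "resultsSection", "adverseEventsModule",
                 "eventGroups", default=[])
    interventions = dig(study, "protocolSection", "armsInterventionsModule",
                        "interventions", default=[])
    phases = dig(study, "protocolSection", "designModule", "phases", default=[])

    def counts(affected_key, at_risk_key):
        # "NA" mirrors how trials without posted results are recorded.
        if not groups:
            return "NA"
        return "; ".join(
            "{}: {}/{}".format(g.get("title", "?"),
                               g.get(affected_key, "NA"),
                               g.get(at_risk_key, "NA"))
            for g in groups
        )

    return {
        "Trial_ID": dig(study, "protocolSection", "identificationModule",
                        "nctId", default="NA"),
        "Drug": "; ".join(i.get("name", "NA") for i in interventions) or "NA",
        "AE": counts("otherNumAffected", "otherNumAtRisk"),
        "S": counts("seriousNumAffected", "seriousNumAtRisk"),
        "Phase": "/".join(phases) or "NA",
    }


# Synthetic study record (not a real trial).
sample_study = {
    "protocolSection": {
        "identificationModule": {"nctId": "NCT00000000"},
        "designModule": {"phases": ["PHASE2"]},
        "armsInterventionsModule": {"interventions": [{"name": "Drug X"}]},
    },
    "resultsSection": {"adverseEventsModule": {"eventGroups": [
        {"title": "Active", "otherNumAffected": 4, "otherNumAtRisk": 10,
         "seriousNumAffected": 0, "seriousNumAtRisk": 10},
        {"title": "Placebo", "otherNumAffected": 3, "otherNumAtRisk": 10,
         "seriousNumAffected": 1, "seriousNumAtRisk": 10},
    ]}},
}

summary = summarize(sample_study)
print(summary)
```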

---

**5️⃣ Derived Tables**

Two tables were created for clarity:

* **Full Comparison Table:** Detailed information including trial design, conditions, and interventions.
* **Compact AE/E/S Table:** Focused on safety and efficacy outcomes.
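Once the summaries exist, the compact table is a straightforward rendering step. The sketch below shows one way to emit it as markdown; the function name and column order are our own choices, and the single row reuses the Leronlimab values reported later in this document.

```python
COLUMNS = ["Trial_ID", "Drug", "AE", "E", "S"]


def compact_table(rows):
    """Render summary rows as a markdown table; missing fields become NA."""
    lines = [
        "| " + " | ".join(COLUMNS) + " |",
        "| " + " | ".join("---" for _ in COLUMNS) + " |",
    ]
    for row in rows:
        cells = [str(row.get(column, "NA")) for column in COLUMNS]
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)


rows = [
    {
        "Trial_ID": "NCT04678830",
        "Drug": "Leronlimab (700mg)",
        "AE": "22/28 vs 20/28",
        "E": "-16.3 vs -8.1",
        "S": "0/28 vs 1/28",
    },
    {"Trial_ID": "NCT05618587", "Drug": "Lithium"},  # no results posted
]

table = compact_table(rows)
print(table)
```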

---

✅ **Outcome:**
We successfully harmonized all 18 uploaded trials (phases 1–3) testing drug, biologic, or supplement interventions for Long COVID / PASC into consistent markdown summaries and comparative tables ready for inclusion in a systematic review or supplementary material.

## **🟩NCT04678830 — Leronlimab**

* Double Blind, Placebo Controlled Study of Safety and Efficacy of Leronlimab in Patients With "Long" COVID-19
* **Sponsor:** CytoDyn, Inc.
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2021-03-01; Primary Completion 2021-06-05; Completion 2021-07-08
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 56; Condition: Coronavirus Disease 2019
* **Arms:** Placebo: PLACEBO_COMPARATOR | 700mg Leronlimab: EXPERIMENTAL
* **Interventions:** DRUG: Placebos; DRUG: Leronlimab (700mg)
* **Primary endpoint(s):** Changes From Baseline in Daily COVID-19-related Symptom Severity Score Through Day 56
* **Results posted:** Yes
* **Efficacy (as posted):** Symptom score change: Leronlimab=-16.3, Placebo=-8.1
* **Adverse events (any):** 700mg Leronlimab: 22/28; Placebo: 20/28
* **Serious AEs:** 700mg Leronlimab: 0/28; Placebo: 1/28

---

## **🟩NCT04880161 — Ampion**

* A Study to Evaluate Ampion in Patients With Prolonged Respiratory Symptoms Due to COVID-19
* **Sponsor:** Ampio Pharmaceuticals, Inc.
* **Phase / Status:** Phase 1 / Completed  
* **Dates:** Start 2021-07-26; Primary Completion 2021-12-22; Completion 2022-02-21
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: **32**; Condition: Covid19
* **Arms:** Active: EXPERIMENTAL | Control: PLACEBO_COMPARATOR
* **Interventions:** BIOLOGICAL: Ampion; OTHER: Placebo
* **Primary endpoint(s):** Treatment-Emergent Adverse Events (TEAEs)
* **Results posted:** Yes
* **Efficacy (as posted):** Safety study - primary was TEAEs
* **Adverse events (any):** Active: **8/15**; Control: **9/16**
* **Serious AEs:** Active: 0/15; Control: 0/16

---

## **🟩NCT05633407 — Efgartigimod**

* Efficacy and Safety Study of **Efgartigimod** in Adults With Post-acute Sequelae of COVID-19 *[NOT Montelukast]*
* **Sponsor:** **argenx** 
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2022-09-23; Primary Completion 2024-04-18; Completion 2024-04-18
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: **53**; Condition: Postural Orthostatic Tachycardia Syndrome
* **Arms:** Efgartigimod: EXPERIMENTAL | Placebo: PLACEBO_COMPARATOR
* **Interventions:** DRUG: Efgartigimod; DRUG: Placebo
* **Primary endpoint(s):** Safety and tolerability measures
* **Results posted:** Yes
* **Efficacy (as posted):** Data available in results
* **Adverse events (any):** Efgartigimod: 31/36; Placebo: 14/17
* **Serious AEs:** Efgartigimod: 0/36; Placebo: 0/17

---

## **🟩NCT05126563 — Allogeneic**

* Randomized Double-Blind Phase 2 Study of **Allogeneic** Adipose Derived MSCs for Long COVID
* **Sponsor:** Hope Biosciences Research Foundation
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2021-07-30; Primary Completion 2022-09-14; Completion 2022-10-07
* **Design:** RANDOMIZED PARALLEL; Masking: TRIPLE; Enrollment: 40; Condition: Post COVID-19
* **Arms:** HB-adMSCs (allogeneic): EXPERIMENTAL | Placebo: PLACEBO_COMPARATOR
* **Interventions:** DRUG: HB-adMSCs (allogeneic); DRUG: Placebo
* **Primary endpoint(s):** Incidence of TEAEs and SAEs
* **Results posted:** Yes
* **Efficacy (as posted):** Safety primary endpoint
* **Adverse events (any):** HB-adMSCs: 3/21; Placebo: 3/19
* **Serious AEs:** HB-adMSCs: 0/21; Placebo: 0/19

---

## **🟩NCT05576662 — Paxlovid (Nirmatrelvir/Ritonavir)**

* Paxlovid (Nirmatrelvir/Ritonavir) for Treatment of Long COVID
* **Sponsor:** Stanford University
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2022-11-08; Primary Completion 2023-08-14; Completion 2023-09-12
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: **168**
* **Arms:** Nirmatrelvir Plus Ritonavir; Placebo Plus Ritonavir
* **Interventions:** DRUG: Nirmatrelvir; DRUG: Placebo; DRUG: Ritonavir
* **Primary endpoint(s):** Symptom severity composite
* **Results posted:** Yes
* **Efficacy (as posted):** Results available
* **Adverse events (any):** Nirmatrelvir+Ritonavir: 100/102; Placebo+Ritonavir: 48/53
* **Serious AEs:** Nirmatrelvir+Ritonavir: 3/102; Placebo+Ritonavir: 1/53

---

## **🟩NCT05047952 — Vortioxetine**

* Vortioxetine for Post-COVID-19 Condition
* **Sponsor:** **Brain and Cognition Discovery Foundation**
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2021-09-16; Primary Completion 2023-02-22; Completion 2023-02-22
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 149
* **Arms:** Vortioxetine; Placebo
* **Interventions:** DRUG: Vortioxetine; DRUG: Placebo
* **Primary endpoint(s):** Cognitive/symptom scale change
* **Results posted:** Yes
* **Efficacy (as posted):** Results available
* **Adverse events (any):** Vortioxetine: 47/75; Placebo: 34/74
* **Serious AEs:** Vortioxetine: 0/75; Placebo: 0/74

---

## **🟨NCT04871815 — Sodium Pyruvate Nasal Spray** 

* Effects of **Sodium Pyruvate Nasal Spray** in COVID-19 *[NOT Naltrexone]*
* **Sponsor:** **Cellular Sciences, Inc.**
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2021-05-17; Primary Completion 2022-01-25; Completion 2022-01-25
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 122; Condition: COVID-19 Recovery
* **Arms:** Sodium Pyruvate: EXPERIMENTAL | Placebo: PLACEBO_COMPARATOR
* **Interventions:** DRUG: sodium pyruvate nasal spray
* **Primary endpoint(s):** Change in energy/fatigue scale
* **Results posted:** Yes
* **Efficacy (as posted):** Results available
* **Adverse events (any):** Data in results
* **Serious AEs:** Data in results

---

## **🟨NCT05592418 — Ampligen (Rintatolimod)**

* Study to Evaluate the Efficacy and Safety of Ampligen (Rintatolimod) in Long COVID
* **Sponsor:** AIM ImmunoTech Inc.
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2023-01-17; Primary Completion 2024-02-21; Completion 2024-02-21
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 80
* **Arms:** Placebo; Ampligen
* **Interventions:** DRUG: Rintatolimod; DRUG: Placebo/Normal Saline
* **Primary endpoint(s):** Change From Baseline to Week 13 in PROMIS Fatigue (T-score)
* **Results posted:** Yes
* **Efficacy (as posted):** Results available
* **Adverse events (any):** Data in results
* **Serious AEs:** Data in results

---

## **🟨NCT03554265 — Somatropin**

* Brain and Gut Plasticity in Mild TBI with **Somatropin** *[NOT Metoprolol]*
* **Sponsor:** **The University of Texas Medical Branch, Galveston**
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2018-06-01; Primary Completion 2023-02-28; Completion 2023-02-28
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 60
* **Arms:** Active Drug; Placebo
* **Interventions:** DRUG: Somatropin
* **Primary endpoint(s):** Lean Body Mass by DXA; Brain Network Connectivity
* **Results posted:** Yes
* **Efficacy (as posted):** Results available
* **Adverse events (any):** Data in results
* **Serious AEs:** Data in results

---

## **🟨NCT05121766 — Omega-3**

* Feasibility Pilot Clinical Trial of Omega-3 Supplementation for Long COVID
* **Sponsor:** **Hackensack Meridian Health**
* **Phase / Status:** Phase 1 / Terminated
* **Dates:** Start 2021-11-04; Primary Completion 2023-03-30; Completion 2023-03-30
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 10
* **Arms:** Omega-3; Placebo
* **Interventions:** DRUG: Omega-3 (EPA+DHA); DRUG: Placebo
* **Primary endpoint(s):** Feasibility endpoints; Inflammatory biomarkers
* **Results posted:** Yes
* **Efficacy (as posted):** Feasibility study
* **Adverse events (any):** Data in results
* **Serious AEs:** Data in results

---

## **🟨NCT05472090 — TNX-102 SL**

* Phase 2 Study to Evaluate TNX-102 SL in Long COVID
* **Sponsor:** Tonix Pharmaceuticals, Inc.
* **Phase / Status:** Phase 2 / Completed
* **Dates:** Start 2022-08-08; Primary Completion 2023-12-19; Completion 2023-12-19
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 63
* **Arms:** TNX-102 SL; Placebo
* **Interventions:** DRUG: TNX-102 SL; DRUG: Placebo SL Tablet
* **Primary endpoint(s):** Daily Diary Pain NRS
* **Results posted:** Yes
* **Efficacy (as posted):** Results available
* **Adverse events (any):** Data in results
* **Serious AEs:** Data in results

---

## **NCT05618587 — Lithium** 

* Effect of **Lithium** Therapy on Long COVID Symptoms *[NOT Nicotinamide Riboside/L-Carnitine]*
* **Sponsor:** **State University of New York at Buffalo**
* **Phase / Status:** Phase 1 / Unknown
* **Dates:** Start NA; Primary Completion NA; Completion NA
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: NA
* **Arms:** Lithium: EXPERIMENTAL | Placebo: PLACEBO_COMPARATOR
* **Interventions:** DRUG: Lithium; DRUG: Placebo
* **Primary endpoint(s):** Fatigue measures
* **Results posted:** No
* **Efficacy (as posted):** NA
* **Adverse events (any):** NA
* **Serious AEs:** NA

---

## **NCT05597722 — Amphetamine-Dextroamphetamine**

* **Digital Cognitive Behavioral Intervention + Amphetamine-Dextroamphetamine** for Long COVID *[NOT Guanfacine/NAC]*
* **Sponsor:** **Eva Szigethy**
* **Phase / Status:** NA / Active
* **Dates:** Multiple start dates by site
* **Design:** RANDOMIZED PARALLEL; Masking: None; Enrollment: 100
* **Arms:** Digital CBT: EXPERIMENTAL | Amphetamine-Dextroamphetamine: EXPERIMENTAL
* **Interventions:** BEHAVIORAL: Digital cognitive behavioral intervention-RxWell; DRUG: Amphetamine-Dextroamphetamine
* **Primary endpoint(s):** Cognitive assessments
* **Results posted:** No
* **Efficacy (as posted):** NA
* **Adverse events (any):** NA
* **Serious AEs:** NA

---

## **NCT04809974 — Niagen**

* Clinical Trial of **Niagen** (Nicotinamide Riboside) in Long COVID *[NOT Colchicine]*
* **Sponsor:** **Massachusetts General Hospital**
* **Phase / Status:** Phase 2 / Unknown
* **Dates:** Start 2022-01-14; Primary Completion NA; Completion NA
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: NA
* **Arms:** Niagen: EXPERIMENTAL | Placebo: PLACEBO_COMPARATOR
* **Interventions:** DIETARY SUPPLEMENT: Niagen
* **Primary endpoint(s):** Fatigue reduction
* **Results posted:** No
* **Efficacy (as posted):** NA
* **Adverse events (any):** NA
* **Serious AEs:** NA

---

## **NCT05074888 — Prospekta**

* Clinical Trial of **Prospekta** for Long COVID *[NOT Ivabradine]*
* **Sponsor:** **Materia Medica Holding**
* **Phase / Status:** Phase 3 / Unknown
* **Dates:** Start 2021-12-03; Primary Completion 2022-05-07; Completion 2022-05-07
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 476
* **Arms:** Prospekta: EXPERIMENTAL | Placebo: PLACEBO_COMPARATOR
* **Interventions:** DRUG: Prospekta; DRUG: Placebo
* **Primary endpoint(s):** Fatigue reduction
* **Results posted:** No
* **Efficacy (as posted):** NA
* **Adverse events (any):** NA
* **Serious AEs:** NA

---

## **NCT05946551 — Cetirizine + Famotidine**

* Treatment of Long COVID (TLC Study) with **Cetirizine + Famotidine** *[NOT Fluvoxamine]*
* **Sponsor:** **Emory University**
* **Phase / Status:** Phase 2 / Unknown
* **Dates:** Various
* **Design:** FACTORIAL; Masking: QUADRUPLE; Enrollment: NA
* **Arms:** 2x2 factorial design with Cetirizine and Famotidine
* **Interventions:** DRUG: Cetirizine; DRUG: Famotidine; DRUG: Cetirizine Placebo; DRUG: Famotidine Placebo
* **Primary endpoint(s):** Symptom burden
* **Results posted:** No
* **Efficacy (as posted):** NA
* **Adverse events (any):** NA
* **Serious AEs:** NA

---

## **NCT05096884 — Metoprolol**

* Post-Acute Sequelae of COVID-19 with Metoprolol
* **Sponsor:** **Hackensack Meridian Health**
* **Phase / Status:** Phase 2 / Unknown
* **Dates:** Start 2022-01-12; Primary Completion 2022-12-31; Completion 2022-12-31
* **Design:** RANDOMIZED PARALLEL; Masking: TRIPLE; Enrollment: NA
* **Arms:** Metoprolol; Placebo
* **Interventions:** DRUG: Metoprolol Succinate
* **Primary endpoint(s):** Change in 6-Minute Walk Test
* **Results posted:** No
* **Efficacy (as posted):** NA
* **Adverse events (any):** NA
* **Serious AEs:** NA

---

## **NCT05877508 — Monoclonal Antibodies (AER002)**

* Anti-SARS-CoV-2 Monoclonal Antibodies (AER002) for Long COVID
* **Sponsor:** **Michael Peluso, MD**
* **Phase / Status:** Phase 2 / Active
* **Dates:** Start 2023-08-18; Primary Completion NA; Completion NA
* **Design:** RANDOMIZED PARALLEL; Masking: QUADRUPLE; Enrollment: 120
* **Arms:** AER002; Placebo
* **Interventions:** DRUG: AER002; OTHER: Placebo
* **Primary endpoint(s):** PROMIS Fatigue change
* **Results posted:** No
* **Efficacy (as posted):** NA
* **Adverse events (any):** NA
* **Serious AEs:** NA

---

## **Full Comparison**

| Trial_ID | Title | Phase | Status | Enrollment | Masking | Start_Date | Completion_Date | Primary_Outcome | Drug | AE (any) | E (efficacy note) | S (serious AEs) |
|----------|-------|-------|--------|------------|---------|------------|-----------------|-----------------|------|----------|-------------------|-----------------|
| 🟩 NCT04678830 | Double Blind, Placebo Controlled Study of Safety and Efficacy of Leronlimab | 2 | Completed | 56 | QUADRUPLE | 2021-03-01 | 2021-07-08 | Day-56 symptom score change | Leronlimab (700mg) | 22/28 (active) vs 20/28 (placebo) | Symptom score Δ: −16.3 vs −8.1 | 0/28 (active) vs 1/28 (placebo) |
| 🟩 NCT04880161 | A Study to Evaluate Ampion in Prolonged Respiratory Symptoms | 1 | Completed | 32 | QUADRUPLE | 2021-07-26 | 2022-02-21 | Treatment-Emergent Adverse Events | Ampion (inhalation) | 8/15 (active) vs 9/16 (placebo) | Safety study - TEAEs | 0/15 (active) vs 0/16 (placebo) |
| 🟩 NCT05633407 | Efficacy and Safety of Efgartigimod in Adults With POTS | 2 | Completed | 53 | QUADRUPLE | 2022-09-23 | 2024-04-18 | Safety and tolerability | Efgartigimod | 31/36 (active) vs 14/17 (placebo) | Safety/tolerability primary | 0/36 (active) vs 0/17 (placebo) |
| 🟩 NCT05126563 | Allogeneic Adipose Derived MSCs for Long COVID | 2 | Completed | 40 | TRIPLE | 2021-07-30 | 2022-10-07 | Incidence of TEAEs and SAEs | HB-adMSCs (allogeneic) | 3/21 (active) vs 3/19 (placebo) | Safety primary endpoint | 0/21 (active) vs 0/19 (placebo) |
| 🟩 NCT05576662 | Paxlovid for Treatment of Long COVID | 2 | Completed | 168 | QUADRUPLE | 2022-11-08 | 2023-09-12 | Symptom severity composite | Nirmatrelvir/ritonavir | 100/102 (active) vs 48/53 (placebo) | Symptom severity composite | 3/102 (active) vs 1/53 (placebo) |
| 🟩 NCT05047952 | Vortioxetine for Post-COVID-19 Condition | 2 | Completed | 149 | QUADRUPLE | 2021-09-16 | 2023-02-22 | Cognitive/symptom scale | Vortioxetine | 47/75 (active) vs 34/74 (placebo) | Cognitive/symptom scale change | 0/75 (active) vs 0/74 (placebo) |
| 🟨 NCT04871815 | Effects of Sodium Pyruvate Nasal Spray in COVID-19 | 2 | Completed | 122 | QUADRUPLE | 2021-05-17 | 2022-01-25 | Change in energy/fatigue scale | Sodium Pyruvate | Data in results | Energy/fatigue scale change | Data in results |
| 🟨 NCT05592418 | Efficacy and Safety of Ampligen in Long COVID | 2 | Completed | 80 | QUADRUPLE | 2023-01-17 | 2024-02-21 | PROMIS Fatigue T-score | Ampligen (rintatolimod) | Data in results | PROMIS Fatigue change | Data in results |
| 🟨 NCT03554265 | Brain and Gut Plasticity in Mild TBI/PASC | 2 | Completed | 60 | QUADRUPLE | 2018-06-01 | 2023-02-28 | Lean Mass/Brain Connectivity | Somatropin | Data in results | Lean mass/brain connectivity | Data in results |
| 🟨 NCT05121766 | Omega-3 Supplementation for Long COVID | 1 | Terminated | 10 | QUADRUPLE | 2021-11-04 | 2023-03-30 | Feasibility/Inflammatory biomarkers | Omega-3 (EPA+DHA) | Data in results | Feasibility study | Data in results |
| 🟨 NCT05472090 | TNX-102 SL in Long COVID | 2 | Completed | 63 | QUADRUPLE | 2022-08-08 | 2023-12-19 | Daily Diary Pain NRS | TNX-102 SL | Data in results | Daily pain NRS | Data in results |
| NCT05618587 | Effect of Lithium Therapy on Long COVID | 1 | Unknown | Unknown | QUADRUPLE | Unknown | Unknown | Fatigue measures | Lithium | NA | NA | NA |
| NCT05597722 | Addressing Cognitive Fog in Long-COVID | No Phase | Active | 100 | NONE | Various | Unknown | Cognitive assessments | Digital CBT + Amphetamine | NA | NA | NA |
| NCT04809974 | Clinical Trial of Niagen for Long-COVID | 2 | Unknown | Unknown | QUADRUPLE | 2022-01-14 | Unknown | Fatigue reduction | Niagen (Nicotinamide Riboside) | NA | NA | NA |
| NCT05074888 | Clinical Trial of Prospekta | 3 | Unknown | 476 | QUADRUPLE | 2021-12-03 | 2022-05-07 | Fatigue reduction | Prospekta | NA | NA | NA |
| NCT05946551 | Treatment of Long COVID (TLC Study) | 2 | Recruiting | Unknown | QUADRUPLE | Various | Unknown | Symptom burden | Cetirizine + Famotidine | NA | NA | NA |
| NCT05096884 | Post-Acute Sequelae of COVID-19 - Metoprolol | 2 | Unknown | Unknown | TRIPLE | 2022-01-12 | 2022-12-31 | 6-Minute Walk Test | Metoprolol Succinate | NA | NA | NA |
| NCT05877508 | Anti-SARS-CoV-2 mAb (AER002) for Long COVID | 2 | Recruiting | 120 | QUADRUPLE | 2023-08-18 | Unknown | PROMIS Fatigue change | AER002 | NA | NA | NA |

## **AE/E/S Comparison**

| Trial_ID    | Drug                                        | AE (any)                           | E (efficacy note)                                  | S (serious AEs)                  |
| ----------- | ------------------------------------------- | ---------------------------------- | -------------------------------------------------- | -------------------------------- |
| 🟩 NCT04678830 | Leronlimab (PRO140)                        | 22/28 (active) vs 20/28 (placebo) | Symptom score Δ: −16.3 (active) vs −8.1 (placebo) | 0/28 (active) vs 1/28 (placebo) |
| 🟩 NCT04880161 | Ampion (inhalation)                        | 8/15 (active) vs 9/16 (placebo)   | Safety study - primary was TEAEs                  | 0/15 (active) vs 0/16 (placebo) |
| 🟩 NCT05633407 | Efgartigimod                               | 31/36 (active) vs 14/17 (placebo) | Safety/tolerability primary                       | 0/36 (active) vs 0/17 (placebo) |
| 🟩 NCT05126563 | HB-adMSCs (allogeneic)                     | 3/21 (active) vs 3/19 (placebo)   | Safety primary endpoint                           | 0/21 (active) vs 0/19 (placebo) |
| 🟩 NCT05576662 | Nirmatrelvir/ritonavir (Paxlovid)         | 100/102 (active) vs 48/53 (placebo)| Symptom severity composite                       | 3/102 (active) vs 1/53 (placebo)|
| 🟩 NCT05047952 | Vortioxetine                               | 47/75 (active) vs 34/74 (placebo) | Cognitive/symptom scale change                    | 0/75 (active) vs 0/74 (placebo) |
| 🟨 NCT04871815 | Sodium Pyruvate Nasal Spray                | Data in results                    | Energy/fatigue scale change                       | Data in results                  |
| 🟨 NCT05592418 | Ampligen (rintatolimod)                   | Data in results                    | PROMIS Fatigue change                             | Data in results                  |
| 🟨 NCT03554265 | Somatropin                                 | Data in results                    | Lean mass/brain connectivity                      | Data in results                  |
| 🟨 NCT05121766 | Omega-3 (EPA+DHA)                         | Data in results                    | Feasibility study                                 | Data in results                  |
| 🟨 NCT05472090 | TNX-102 SL (cyclobenzaprine SL)           | Data in results                    | Daily pain NRS                                    | Data in results                  |
| NCT05597722 | Digital CBT + Amphetamine-Dextroamphetamine| NA                                 | NA                                                 | NA                               |
| NCT04809974 | Niagen (Nicotinamide Riboside)            | NA                                 | NA                                                 | NA                               |
| NCT05074888 | Prospekta                                  | NA                                 | NA                                                 | NA                               |
| NCT05946551 | Cetirizine + Famotidine                   | NA                                 | NA                                                 | NA                               |
| NCT05096884 | Metoprolol Succinate                       | NA                                 | NA                                                 | NA                               |
| NCT05877508 | AER002 (anti-SARS-CoV-2 mAb)              | NA                                 | NA                                                 | NA                               |
| NCT05618587 | Lithium                                    | NA                                 | NA                                                 | NA                               |

**Summary:**
* **Total:** 18 studies
* 🟩: 6 studies with AE/E/S values available in the `.json` files
* 🟨: 5 studies whose AE/E/S values may be available in other documents
* No marker: 7 studies without AE/E/S values
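
The count strings in the AE and S columns can be turned into rates for a quick active-versus-placebo comparison. The sketch below is only an illustration: the function name `parse_counts` and the regular expression are assumptions made for this example, and the two count strings are copied from the table above.

```python
import re

# Matches count strings such as "22/28 (active) vs 20/28 (placebo)".
COUNT_PATTERN = re.compile(
    r"(\d+)/(\d+)\s*\(active\)\s*vs\s*(\d+)/(\d+)\s*\(placebo\)"
)


def parse_counts(cell):
    """Return (active_rate, placebo_rate, risk_difference), or None if no counts are present."""
    match = COUNT_PATTERN.search(cell)
    if match is None:
        return None  # cells such as "NA" or "Data in results"
    a_events, a_total, p_events, p_total = map(int, match.groups())
    active_rate = a_events / a_total
    placebo_rate = p_events / p_total
    return active_rate, placebo_rate, active_rate - placebo_rate


# AE (any) cells copied from the comparison table.
ae_cells = {
    "NCT04678830": "22/28 (active) vs 20/28 (placebo)",
    "NCT05576662": "100/102 (active) vs 48/53 (placebo)",
    "NCT04871815": "Data in results",
}

for trial_id, cell in ae_cells.items():
    parsed = parse_counts(cell)
    if parsed is None:
        print(f"{trial_id}: no counts available")
    else:
        active, placebo, diff = parsed
        print(f"{trial_id}: active={active:.3f} placebo={placebo:.3f} diff={diff:+.3f}")
```

A raw risk difference like this ignores sample size, so it is best read as a sanity check before comparing against model predictions.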

# **Environment**

## **Conda**

**Environments:**
- `planet`: for parsing clinical trials.
- `planet_clean`: for running predictions.
- `gene_mapper`: for mapping causal genes to drugs.

1.  Create the Conda (or venv) environment in Bash:

```
conda create -n planet python=3.8
conda activate planet
```

# ============================================================================
# Core Scientific Computing & Data Analysis
# ============================================================================
numpy==1.24.4
scipy==1.10.1
pandas==2.0.3
scikit-learn==1.3.2
statsmodels==0.14.1
patsy==1.0.1

# ============================================================================
# Deep Learning & PyTorch Ecosystem
# ============================================================================
torch==1.13.1+cpu
torchvision==0.14.1+cpu
torch-geometric==2.6.1
torch_cluster==1.6.3
torch_scatter==2.1.2
torch_sparse==0.6.18
torch_spline_conv==1.2.2

# ============================================================================
# Machine Learning & NLP
# ============================================================================
transformers==4.19.4
tokenizers==0.12.1
datasets==3.1.0
huggingface-hub==0.33.5
gensim==4.3.3
nltk==3.6.6
sacremoses @ file:///home/conda/feedstock_root/build_artifacts/sacremoses_1651557636210/work

# ============================================================================
# Visualization
# ============================================================================
matplotlib @ file:///croot/matplotlib-suite_1693812469450/work
matplotlib-inline==0.1.7
contourpy @ file:///opt/conda/conda-bld/contourpy_1663827406301/work
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
fonttools @ file:///croot/fonttools_1713551344105/work
kiwisolver @ file:///croot/kiwisolver_1672387140495/work
pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1719903565503/work

# ============================================================================
# Experiment Tracking & Monitoring
# ============================================================================
tensorboard==2.14.0
tensorboard-data-server==0.7.2
wandb==0.23.0
sentry-sdk==2.44.0

# ============================================================================
# Jupyter & IPython
# ============================================================================
ipython==8.12.3
ipykernel==6.29.5
jupyter_client==8.6.3
jupyter_core==5.8.1
nest-asyncio==1.6.0

# ============================================================================
# Bioinformatics & Chemistry
# ============================================================================
goatools==1.0.15
obonet==0.2.5
PubChemPy==1.0.4
ogb==1.3.6

# ============================================================================
# Graph & Network Analysis
# ============================================================================
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1680692919326/work
pydot==4.0.1

# ============================================================================
# Data Processing & Storage
# ============================================================================
pyarrow==17.0.0
dill==0.3.8
fsspec==2024.9.0
xxhash==3.6.0
multiprocess==0.70.16

# ============================================================================
# File Format Support
# ============================================================================
xlrd==1.2.0
xlsxwriter==3.2.5
lxml==6.0.0
PyYAML==6.0.2

# ============================================================================
# HTTP & Network
# ============================================================================
requests==2.32.4
urllib3==2.2.3
aiohttp==3.10.11
aiohappyeyeballs==2.4.4
aiosignal==1.3.1
async-timeout==5.0.1
frozenlist==1.5.0
multidict==6.1.0
yarl==1.15.2
propcache==0.2.0

# ============================================================================
# Authentication & Security
# ============================================================================
certifi==2025.7.14
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1721521265753/work
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1722587090966/work
google-auth==2.43.0
google-auth-oauthlib==1.0.0
oauthlib==3.3.1
requests-oauthlib==2.0.0
rsa==4.9.1
pyasn1==0.6.1
pyasn1_modules==0.4.2

# ============================================================================
# Compression & Encoding
# ============================================================================
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1666788425425/work
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1695621656497/work
zstandard @ file:///home/conda/feedstock_root/build_artifacts/zstandard_1667296101734/work

# ============================================================================
# Protocol & Communication
# ============================================================================
grpcio==1.70.0
protobuf==5.29.5
h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1634280454336/work
hpack==4.0.0
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1619110129307/work

# ============================================================================
# Version Control & Git
# ============================================================================
GitPython==3.1.45
gitdb==4.0.12
hf-xet==1.1.5

# ============================================================================
# GUI Framework
# ============================================================================
PyQt5==5.15.10
PyQt5-sip @ file:///croot/pyqt-split_1698769088074/work/pyqt_sip
sip @ file:///croot/sip_1698675935381/work

# ============================================================================
# Utilities & Helpers
# ============================================================================
click @ file:///home/conda/feedstock_root/build_artifacts/click_1692311806742/work
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
tqdm==4.67.1
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1714665484399/work
cachetools==5.5.2
filelock==3.16.1
wget==3.2
docopt==0.6.2
littleutils==0.2.4
outdated==0.2.2
wrapt==1.17.2

# ============================================================================
# Text Processing & Parsing
# ============================================================================
charset-normalizer==2.1.1
idna==3.10
regex==2024.11.6
stemming==1.0.1
Pygments==2.19.2
pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
ply==3.11

# ============================================================================
# Template & Markup
# ============================================================================
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1715127149914/work
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1706899923320/work
Markdown==3.7

# ============================================================================
# Validation & Type Checking
# ============================================================================
pydantic==2.10.6
pydantic_core==2.27.2
annotated-types==0.7.0
eval_type_backport==0.2.2

# ============================================================================
# Python Compatibility & Extensions
# ============================================================================
typing_extensions==4.13.2
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1726082825846/work
importlib_resources @ file:///croot/importlib_resources-suite_1720641103994/work
dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work
zipp @ file:///croot/zipp_1729012354496/work

# ============================================================================
# Time & Date
# ============================================================================
python-dateutil @ file:///croot/python-dateutil_1716495738603/work
pytz==2025.2
tzdata==2025.2

# ============================================================================
# System & Process Management
# ============================================================================
psutil==7.0.0
platformdirs==4.3.6

# ============================================================================
# Conda & Environment Management
# ============================================================================
conda-pack @ file:///home/conda/feedstock_root/build_artifacts/conda-pack_1719323969738/work

# ============================================================================
# Debugging & Development
# ============================================================================
debugpy==1.8.14
jedi==0.19.2
parso==0.8.4
executing==2.2.0
stack-data==0.6.3
pure_eval==0.2.3
asttokens==3.0.0
decorator==5.2.1
backcall==0.2.0
pickleshare==0.7.5
prompt_toolkit==3.0.51
wcwidth==0.2.13
traitlets==5.14.3

# ============================================================================
# Terminal & PTY
# ============================================================================
pexpect==4.9.0
ptyprocess==0.7.0

# ============================================================================
# Communication & IPC
# ============================================================================
pyzmq==27.0.0
tornado @ file:///croot/tornado_1718740109488/work
comm==0.2.2

# ============================================================================
# Web Framework
# ============================================================================
Werkzeug==3.0.6

# ============================================================================
# Low-level Dependencies
# ============================================================================
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1723018376978/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1711811537435/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
unicodedata2 @ file:///croot/unicodedata2_1713212950228/work

# ============================================================================
# Math & Symbolic Computing
# ============================================================================
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1715527302982/work
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1728484478345/work

# ============================================================================
# Misc Dependencies
# ============================================================================
absl-py==2.3.1
attrs==25.3.0
packaging==25.0
six @ file:///tmp/build/80754af9/six_1644875935023/work
smart_open==7.3.0.post1
smmap==5.0.2
threadpoolctl==3.5.0


## **Packages**

2.  Install the packages used by the XML/JSON parsing utilities:
```
pip install -r parsing_package/requirements.txt
```

Packages:
- `lxml, beautifulsoup4`: XML parsing
- `tqdm, pandas, numpy`: generic data wrangling
- `transformers==4.x`: BioMed-BERT encoder
- `scikit-learn, umls-client`: small helper libs

3.  Add the graph-learning stack for the GCN models:

`gcn_models` depends mainly on PyTorch + PyTorch‑Geometric.

Install a CUDA or CPU build depending on the hardware:

3‑A.  Core PyTorch:

CPU build (works everywhere):
```
pip install torch==1.13.1+cpu torchvision==0.14.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
```

GPU build (example for CUDA 11.7; PyTorch 1.13.1 wheels are published for CUDA 11.6 and 11.7, see pytorch.org for other versions):
```
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
```

3‑B.  PyTorch‑Geometric and its companions:

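```
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv \
            -f https://data.pyg.org/whl/torch-1.13.1+cpu.html      # swap +cpu for +cu117 if using CUDA
pip install torch-geometric==2.4.0
```

The environment listing above records `torch-geometric==2.6.1`, so pin whichever version you actually validate against.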

**Note:** The code in `gcn_models` is built on the functions and tools provided by PyTorch and PyTorch-Geometric.

**PyTorch**: This is the primary machine learning framework used. It provides the fundamental building blocks for neural networks, such as tensors (multi-dimensional arrays) and the automatic differentiation required for training models. The code frequently imports it with `import torch`.

**PyTorch-Geometric (PyG)**: This is a specialized library that extends PyTorch for working specifically with graphs. The `gcn_models` code relies on it for essential graph-related components, including:
- **Graph Convolutional Layers**: It provides pre-built layers necessary for Graph Convolutional Networks (GCNs), such as `RGCNConv` and `RGATConv`.
- **Graph Data Handling**: It defines how graph data, including nodes and edges, is structured and processed.
- **Graph Sampling**: It includes tools like `NeighborSampler` to efficiently sample subgraphs for training on large graphs.

In short, PyTorch provides the general deep learning capabilities, while PyTorch-Geometric provides the specialized tools needed to apply those capabilities to graph data, which is central to how the `gcn_models` work.
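
To make the idea of a relational graph convolution concrete, here is a toy, pure-Python sketch of what one R-GCN-style layer computes: each node adds a transformed copy of its own features to the per-relation mean of its transformed neighbour features, followed by a ReLU. This is not the code in `gcn_models` and not PyG's `RGCNConv`; the function names, the relation names (`targets`, `associated_with`, `treats`), and the weights are hypothetical and chosen only for illustration.

```python
from collections import defaultdict


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def rgcn_layer(features, edges, rel_weights, self_weight):
    """One relational graph-convolution step with per-relation mean aggregation.

    features:    {node: [float, ...]}
    edges:       [(source, relation, target), ...]; messages flow source -> target
    rel_weights: {relation: matrix}
    self_weight: matrix applied to each node's own features
    """
    # Group incoming neighbours by (target, relation).
    incoming = defaultdict(list)
    for source, relation, target in edges:
        incoming[(target, relation)].append(source)

    # Start from the self-connection term.
    updated = {node: matvec(self_weight, h) for node, h in features.items()}

    # Add the mean of the transformed neighbour features, one relation at a time.
    for (target, relation), sources in incoming.items():
        for source in sources:
            message = matvec(rel_weights[relation], features[source])
            for k, value in enumerate(message):
                updated[target][k] += value / len(sources)

    # ReLU nonlinearity.
    return {node: [max(0.0, v) for v in h] for node, h in updated.items()}


identity = [[1.0, 0.0], [0.0, 1.0]]
features = {"drug": [1.0, 0.0], "protein": [0.0, 2.0], "disease": [1.0, 1.0]}
edges = [
    ("drug", "targets", "protein"),
    ("protein", "associated_with", "disease"),
    ("drug", "treats", "disease"),
]
rel_weights = {
    "targets": identity,
    "associated_with": [[0.5, 0.0], [0.0, 0.5]],
    "treats": [[0.0, 1.0], [1.0, 0.0]],
}

out = rgcn_layer(features, edges, rel_weights, identity)
print(out)
```

In a real model the weight matrices are learned and the layer runs on tensors, which is what PyTorch and PyG take care of.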

4. Install optional scientific helpers:

```
pip install pyyaml networkx
```
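
After the install steps, a quick check of the installed versions can catch a mismatched wheel early. The following is a minimal sketch using the standard-library `importlib.metadata`; the helper name `check_pins` and the choice of which pins to check are assumptions for this example, with the version numbers taken from the commands above.

```python
from importlib import metadata

# Pins taken from the install commands above; adjust them if you install other versions.
EXPECTED = {
    "torch": "1.13.1",
    "torch-geometric": "2.4.0",
}


def check_pins(expected, get_version=metadata.version):
    """Map each package name to 'ok', 'missing', or 'mismatch (found X)'."""
    report = {}
    for name, wanted in expected.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        # Local build tags such as "+cpu" or "+cu117" are ignored in the comparison.
        base = found.split("+", 1)[0]
        report[name] = "ok" if base == wanted else f"mismatch (found {found})"
    return report


for name, status in check_pins(EXPECTED).items():
    print(f"{name}: {status}")
```

Run it inside the activated environment; a `mismatch` line shows the version that pip actually resolved.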

## **GitHub Repo Data**

5. Download the full data and parsing bundles from the SNAP website (the copy in the GitHub repo is smaller):

```
wget https://snap.stanford.edu/planet/data.zip
wget https://snap.stanford.edu/planet/parsing_package.zip
```

6. Unzip the bundles:

```
unzip -o data.zip             # populates ./data/… with drug_data/, disease_data/, …
unzip -o parsing_package.zip  # adds resources inside ./parsing_package/
```
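
A short check that the archives unpacked where expected can save a failed run later. This sketch only tests for the directories named in the comments above; the helper name `missing_paths` is an assumption for this example, and the list can be extended with any other paths your scripts rely on.

```python
from pathlib import Path

# Directories named in the unzip comments above; extend the list as needed.
EXPECTED_DIRS = ["data/drug_data", "data/disease_data", "parsing_package"]


def missing_paths(root, expected=EXPECTED_DIRS):
    """Return the expected directories that do not exist under root."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).is_dir()]


# Run from the repository root, after both archives have been unzipped.
missing = missing_paths(".")
if missing:
    print("Missing after unzip:", ", ".join(missing))
else:
    print("All expected directories are present.")
```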

## **GitHub Repo Tree**

```
.
├── LICENSE
├── README.md
├── gcn_models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── clf_train.cpython-38.pyc
│   │   ├── clf_trainer.cpython-38.pyc
│   │   ├── data_loader.cpython-38.pyc
│   │   ├── decoders.cpython-38.pyc
│   │   ├── encoders.cpython-38.pyc
│   │   ├── evaluator.cpython-38.pyc
│   │   ├── layers.cpython-38.pyc
│   │   ├── link_pred_models.cpython-38.pyc
│   │   ├── sampler.cpython-38.pyc
│   │   ├── train.cpython-38.pyc
│   │   ├── trainers.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── clf_train.py
│   ├── clf_trainer.py
│   ├── conv_layers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── hgt_conv.cpython-38.pyc
│   │   │   ├── residual_rgcn.cpython-38.pyc
│   │   │   ├── rgat_conv.cpython-38.pyc
│   │   │   ├── rgat_conv_simple.cpython-38.pyc
│   │   │   └── rgcn_concat.cpython-38.pyc
│   │   ├── hgt_conv.py
│   │   ├── residual_rgcn.py
│   │   ├── rgat_conv.py
│   │   ├── rgat_conv_simple.py
│   │   ├── rgcn_concat.py
│   │   └── rgcn_conv.py
│   ├── data_loader.py
│   ├── decoders.py
│   ├── encoders.py
│   ├── evaluator.py
│   ├── layers.py
│   ├── link_pred_models.py
│   ├── node_classification_models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── classification_model.cpython-38.pyc
│   │   └── classification_model.py
│   ├── sampler.py
│   ├── train.py
│   ├── trainers.py
│   └── utils.py
├── knowledge_graph
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── kg.cpython-38.pyc
│   ├── kg.py
│   └── node_features
│       ├── __init__.py
│       ├── map_merge_feats.py
│       ├── text_bert_features.py
│       └── trial_attribute_features.py
├── notebooks
│   ├── demo.ipynb
│   ├── parse_clinical_trial.ipynb
│   ├── predict_all_for_new_clinial_trial.py
│   ├── predict_for_new_clinial_trial.ipynb
│   ├── predict_for_new_clinial_trial__CPU.ipynb
│   ├── small_data
│   │   ├── ae1017_idx2aename.pkl
│   │   ├── ae_kgid2name.pkl
│   │   └── trial_data_NCT02370680.pkl
│   └── utils
│       ├── __pycache__
│       │   └── demo_utils.cpython-38.pyc
│       ├── demo_utils.py
│       └── text_bert_features.py
├── parsing_package
│   ├── README
│   ├── data_parsers
│   │   ├── __init__.py
│   │   ├── disease_extract.py
│   │   ├── external_tools
│   │   │   ├── __init__.py
│   │   │   ├── criteria2query
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── criteria_input.py
│   │   │   │   └── run_crit2query.py
│   │   │   ├── medex
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── index.html
│   │   │   │   ├── medex_input.py
│   │   │   │   ├── medex_output.py
│   │   │   │   └── run_medex.py
│   │   │   └── umls_search
│   │   │       ├── __init__.py
│   │   │       ├── api.py
│   │   │       └── auth.py
│   │   ├── medex_drug_extract.py
│   │   ├── outcome_measure.py
│   │   ├── population_extract.py
│   │   ├── result_arm_extract.py
│   │   └── umls_utils.py
│   ├── knowledge_graph
│   │   ├── __init__.py
│   │   ├── build_graph.py
│   │   ├── external_networks.py
│   │   ├── kg.py
│   │   └── node_features
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   └── index.html
│   │       ├── map_merge_feats.py
│   │       ├── text_bert_features.py
│   │       └── trial_attribute_features.py
│   ├── parse_trial.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── ctrials_data.py
│   │   ├── disease_data.py
│   │   ├── drug_data.py
│   │   ├── network_construction.py
│   │   ├── network_old.py
│   │   └── trials_disease_drug_match.py
│   └── requirements.txt
└── scripts
    ├── train_ae.sh
    ├── train_efficacy.sh
    └── train_safety.sh
```

# **Main Scripts**

## **PARSE**

### **Step 1: Setup & Data Preparation**
1. Go to https://clinicaltrials.gov/.
2. Download the JSON records for the 18 Long COVID trials listed above.
3. Create the KG for all trials with the trial-parsing script (`parsing_package/parse_trial.py`).
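
The download in item 2 can be scripted. The sketch below uses the public ClinicalTrials.gov API v2 study endpoint and the standard library only; the helper names `study_url` and `download_study` are assumptions for this example, and whether the API's JSON layout matches what the parsing script expects is also an assumption that should be checked against one of the files already collected.

```python
import json
import re
import urllib.request
from pathlib import Path

API_ROOT = "https://clinicaltrials.gov/api/v2/studies"
NCT_PATTERN = re.compile(r"NCT\d{8}")


def study_url(nct_id):
    """Build the API URL for one trial, rejecting malformed identifiers."""
    if NCT_PATTERN.fullmatch(nct_id) is None:
        raise ValueError(f"not a valid NCT identifier: {nct_id!r}")
    return f"{API_ROOT}/{nct_id}"


def download_study(nct_id, out_dir="."):
    """Fetch one study record and save it as <NCT id>.json (requires network access)."""
    with urllib.request.urlopen(study_url(nct_id), timeout=30) as response:
        record = json.load(response)
    out_path = Path(out_dir) / f"{nct_id}.json"
    out_path.write_text(json.dumps(record, indent=2), encoding="utf-8")
    return out_path


# Example (uncomment to download; needs network access):
# for nct_id in ["NCT05576662", "NCT05047952"]:
#     print(download_study(nct_id))
print(study_url("NCT05576662"))
```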

### **Step 2: Validation with Completed Trials (🟩 Green - Complete Data)**
4. Run predictions for `AE`, `E`, and `S` on the trials with complete data:
   - **NCT05576662** (Paxlovid) - Phase 2, completed, QUADRUPLE blind, N=168
   - **NCT05047952** (Vortioxetine) - Phase 2, completed, QUADRUPLE blind, N=149
   - **NCT04678830** (Leronlimab) - Phase 2, completed, QUADRUPLE blind, N=56
   - **NCT05633407** (Efgartigimod) - Phase 2, completed, QUADRUPLE blind, N=53
   - **NCT04880161** (Ampion) - Phase 1, completed, QUADRUPLE blind, N=32
   - **NCT05126563** (HB-adMSCs) - Phase 2, completed, TRIPLE blind, N=40

### **Step 3: Extract & Validate Partial Data Trials (🟨 Yellow - Data in Results)**
5. Extract and process results for trials with data available:
   - **NCT04871815** (Sodium Pyruvate) - Phase 2, completed, QUADRUPLE blind, N=122
   - **NCT05592418** (Ampligen) - Phase 2, completed, QUADRUPLE blind, N=80
   - **NCT05472090** (TNX-102 SL) - Phase 2, completed, QUADRUPLE blind, N=63
   - **NCT03554265** (Somatropin) - Phase 2, completed, QUADRUPLE blind, N=60
   - **NCT05121766** (Omega-3) - Phase 1, terminated, QUADRUPLE blind, N=10

### **Step 4: Sub-population Analysis**
6. Run predictions for sub-populations on validated trials:
   - Gender differences (male vs. female)
   - Age-group stratification
   - Comorbidity presence/absence
   - Disease severity at baseline

### **Step 5: Prospective Predictions (No Results Yet)**
7. Run predictions for ongoing/recruiting trials:
   - **NCT05946551** (Cetirizine + Famotidine) - Recruiting, factorial design
   - **NCT05877508** (AER002 mAb) - Recruiting, QUADRUPLE blind, N=120
   
8. Run predictions for trials with unknown status:
   - **NCT05618587** (Lithium) - Phase 1, QUADRUPLE blind
   - **NCT04809974** (Niagen) - Phase 2, QUADRUPLE blind
   - **NCT05074888** (Prospekta) - Phase 3, QUADRUPLE blind, N=476
   - **NCT05096884** (Metoprolol) - Phase 2, TRIPLE blind

### **Step 6: Weaker Evidence Trials**
9. Run predictions for the trial with a non-standard design:
   - **NCT05597722** (Digital CBT + Amphetamine) - No blinding, no phase listed, N=100

**Analysis Priority:**
- **Strongest validation**: NCT05576662, NCT05047952 (Phase 2, completed, large N)
- **Strong validation**: Other completed Phase 2 trials with QUADRUPLE blinding
- **Moderate evidence**: Completed Phase 1 or terminated trials
- **Prospective prediction**: Recruiting/ongoing trials
- **Weakest evidence**: Unblinded or non-randomized trials
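
The tiers above can be expressed as a small rule function so that every trial is assigned consistently. This is a sketch under stated assumptions, not a rule set defined by the source: the `LARGE_N` cut-off of 100, the `results_in_json` flag (🟩 trials), the order of the checks, and the `unclassified` fallback are all choices made for this example.

```python
from dataclasses import dataclass
from typing import Optional

LARGE_N = 100  # assumed cut-off for "large N"; the tier list gives no number


@dataclass
class Trial:
    nct_id: str
    phase: Optional[int]       # None when no phase is listed
    status: str                # e.g. "Completed", "Recruiting", "Terminated", "Unknown"
    masking: str               # e.g. "QUADRUPLE", "TRIPLE", "NONE"
    enrollment: Optional[int]  # None when unknown
    results_in_json: bool      # True for the 🟩 trials


def priority(trial):
    """Assign an analysis tier; checks run from weakest to strongest evidence."""
    if trial.masking == "NONE":
        return "weakest evidence"
    if trial.status in {"Recruiting", "Active", "Unknown"}:
        return "prospective prediction"
    if trial.status == "Terminated" or trial.phase == 1:
        return "moderate evidence"
    if trial.status == "Completed" and trial.phase == 2:
        large = trial.enrollment is not None and trial.enrollment >= LARGE_N
        if large and trial.results_in_json:
            return "strongest validation"
        if trial.masking == "QUADRUPLE":
            return "strong validation"
    return "unclassified"  # combinations the tier list does not cover


trials = [
    Trial("NCT05576662", 2, "Completed", "QUADRUPLE", 168, True),
    Trial("NCT04678830", 2, "Completed", "QUADRUPLE", 56, True),
    Trial("NCT05121766", 1, "Terminated", "QUADRUPLE", 10, False),
    Trial("NCT05877508", 2, "Recruiting", "QUADRUPLE", 120, False),
    Trial("NCT05597722", None, "Active", "NONE", 100, False),
]

for trial in trials:
    print(trial.nct_id, "->", priority(trial))
```

Trials that are completed Phase 2 but TRIPLE-blind, such as NCT05126563, fall through to `unclassified` here because the tier list does not place them explicitly.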

### **Step 7: Prepare the Folder Tree**

```
.
├── NCT05576662.json
├── NCT05576662_results.pkl
├── NCT05576662_summary.json
├── PlaNet_Parse.e145679978
├── PlaNet_Parse.o145679978
├── PlaNet_Parse_log_145679978.gadi-pbs.txt
├── PlaNet_Parse_output_145679978.gadi-pbs.txt
├── README
├── __pycache__
│   └── parse_trial.cpython-38.pyc
├── data
│   ├── README.txt
│   ├── analysis
│   │   └── trial_embeddings
│   │       ├── disease_metadata.pkl
│   │       ├── disgrp.pkl
│   │       └── trial_embeddings_3u7di6ag.pkl
│   ├── clf_data
│   │   └── PT
│   │       ├── AE_embdict_OR2.pkl
│   │       ├── aes.dict
│   │       ├── data.pkl
│   │       ├── efficacy_survival_pairs.pickle
│   │       └── unique_arms.pkl
│   ├── data_31_7_21_umls.pkl
│   ├── disease_data
│   │   └── 2021
│   │       ├── c2021.bin
│   │       ├── d2021.bin
│   │       ├── disease_mappings.tsv
│   │       └── q2021.bin
│   ├── drug_data
│   │   ├── RXNCONSO.RRF
│   │   ├── drugs_all_03_04_21.pkl
│   │   ├── pubchem-drugbankid-synonyms.json
│   │   └── rxnorm2drugbank-umls.pkl
│   ├── drug_split.pkl
│   ├── graph
│   │   ├── KG_node2name.pkl
│   │   ├── entities.dict
│   │   ├── nctinfo.pkl
│   │   ├── relations.dict
│   │   ├── test.tsv
│   │   ├── train.tsv
│   │   ├── unique_arms.pkl
│   │   ├── valid.tsv
│   │   └── withdrawn_drugs.pkl
│   ├── kg_data
│   │   ├── external_data
│   │   │   ├── classyfire
│   │   │   │   └── ChemOnt_2_1.obo
│   │   │   ├── disgenet
│   │   │   │   ├── curated_gene_disease_associations.tsv
│   │   │   │   └── disease_mappings.tsv
│   │   │   ├── drug-phenotype.tsv
│   │   │   ├── drug-protein.tsv
│   │   │   ├── go
│   │   │   │   └── go-basic.obo
│   │   │   ├── multiscale-interactome
│   │   │   │   └── protein_to_protein.tsv
│   │   │   └── protein-function.tsv
│   │   ├── kg-entity2cid-31_7_21.pkl
│   │   └── node_features_armtext.pkl
│   ├── knowledge-graph-31_7_21.pkl
│   ├── models
│   │   ├── 3u7di6ag
│   │   │   ├── ckpt.pt
│   │   │   └── config.json
│   │   ├── ae_model_shxo9bgw
│   │   │   ├── ckpt.pt
│   │   │   └── config.json
│   │   ├── dragon
│   │   │   ├── config.json
│   │   │   ├── pytorch_model.bin
│   │   │   ├── special_tokens_map.json
│   │   │   ├── tokenizer.json
│   │   │   ├── tokenizer_config.json
│   │   │   └── vocab.txt
│   │   ├── efficacy_model_34l5ms9m
│   │   │   ├── ckpt.pt
│   │   │   └── config.json
│   │   └── safety_model_1xekl810
│   │       ├── ckpt.pt
│   │       └── config.json
│   ├── new_data_trial_arm_text.pkl
│   ├── node_features_armtext.pkl
│   ├── outcome_data
│   │   ├── clusters-outcome-measures.txt
│   │   ├── outcome_measures_phrase_bigram_model.pkl
│   │   └── outcome_measures_phrase_trigram_model.pkl
│   └── population_data
│       ├── tfidf_matcher_state.pkl
│       ├── umls-install
│       │   └── 2020AB
│       │       └── META
│       │           ├── MRCONSO.RRF
│       │           └── MRREL.RRF
│       ├── umls_cache.tar.gz
│       ├── umls_graph_clipper_output.pkl
│       └── umls_search_cache
│           └── 2020AB
│               ├── ckpt_0_10000.pkl
│               ├── ckpt_1000000_1010000.pkl
│               ├── ckpt_100000_110000.pkl
│               ├── ckpt_10000_20000.pkl
│               ├── ckpt_1010000_1020000.pkl
│               ├── ckpt_1020000_1030000.pkl
│               ├── ckpt_1030000_1040000.pkl
│               ├── ckpt_1040000_1050000.pkl
│               ├── ckpt_1050000_1060000.pkl
│               ├── ckpt_1060000_1070000.pkl
│               ├── ckpt_1070000_1080000.pkl
│               ├── ckpt_1080000_1090000.pkl
│               ├── ckpt_1090000_1100000.pkl
│               ├── ckpt_1100000_1110000.pkl
│               ├── ckpt_110000_120000.pkl
│               ├── ckpt_1110000_1120000.pkl
│               ├── ckpt_1120000_1130000.pkl
│               ├── ckpt_1130000_1140000.pkl
│               ├── ckpt_1140000_1150000.pkl
│               ├── ckpt_1150000_1160000.pkl
│               ├── ckpt_1160000_1170000.pkl
│               ├── ckpt_1170000_1180000.pkl
│               ├── ckpt_1180000_1190000.pkl
│               ├── ckpt_1190000_1200000.pkl
│               ├── ckpt_1200000_1210000.pkl
│               ├── ckpt_120000_130000.pkl
│               ├── ckpt_1210000_1220000.pkl
│               ├── ckpt_1220000_1230000.pkl
│               ├── ckpt_1230000_1240000.pkl
│               ├── ckpt_1240000_1250000.pkl
│               ├── ckpt_1250000_1260000.pkl
│               ├── ckpt_1260000_1270000.pkl
│               ├── ckpt_1270000_1280000.pkl
│               ├── ckpt_1280000_1290000.pkl
│               ├── ckpt_1290000_1300000.pkl
│               ├── ckpt_1300000_1310000.pkl
│               ├── ckpt_130000_140000.pkl
│               ├── ckpt_1310000_1320000.pkl
│               ├── ckpt_1320000_1330000.pkl
│               ├── ckpt_1330000_1340000.pkl
│               ├── ckpt_1340000_1350000.pkl
│               ├── ckpt_1350000_1360000.pkl
│               ├── ckpt_1360000_1370000.pkl
│               ├── ckpt_1370000_1380000.pkl
│               ├── ckpt_140000_150000.pkl
│               ├── ckpt_150000_160000.pkl
│               ├── ckpt_160000_170000.pkl
│               ├── ckpt_170000_180000.pkl
│               ├── ckpt_180000_190000.pkl
│               ├── ckpt_190000_200000.pkl
│               ├── ckpt_200000_210000.pkl
│               ├── ckpt_20000_30000.pkl
│               ├── ckpt_210000_220000.pkl
│               ├── ckpt_220000_230000.pkl
│               ├── ckpt_230000_240000.pkl
│               ├── ckpt_240000_250000.pkl
│               ├── ckpt_250000_260000.pkl
│               ├── ckpt_260000_270000.pkl
│               ├── ckpt_270000_280000.pkl
│               ├── ckpt_280000_290000.pkl
│               ├── ckpt_290000_300000.pkl
│               ├── ckpt_300000_310000.pkl
│               ├── ckpt_30000_40000.pkl
│               ├── ckpt_310000_320000.pkl
│               ├── ckpt_320000_330000.pkl
│               ├── ckpt_330000_340000.pkl
│               ├── ckpt_340000_350000.pkl
│               ├── ckpt_350000_360000.pkl
│               ├── ckpt_360000_370000.pkl
│               ├── ckpt_370000_380000.pkl
│               ├── ckpt_380000_390000.pkl
│               ├── ckpt_390000_400000.pkl
│               ├── ckpt_400000_410000.pkl
│               ├── ckpt_40000_50000.pkl
│               ├── ckpt_410000_420000.pkl
│               ├── ckpt_420000_430000.pkl
│               ├── ckpt_430000_440000.pkl
│               ├── ckpt_440000_450000.pkl
│               ├── ckpt_450000_460000.pkl
│               ├── ckpt_460000_470000.pkl
│               ├── ckpt_470000_480000.pkl
│               ├── ckpt_480000_490000.pkl
│               ├── ckpt_490000_500000.pkl
│               ├── ckpt_500000_510000.pkl
│               ├── ckpt_50000_60000.pkl
│               ├── ckpt_510000_520000.pkl
│               ├── ckpt_520000_530000.pkl
│               ├── ckpt_530000_540000.pkl
│               ├── ckpt_540000_550000.pkl
│               ├── ckpt_550000_560000.pkl
│               ├── ckpt_560000_570000.pkl
│               ├── ckpt_570000_580000.pkl
│               ├── ckpt_580000_590000.pkl
│               ├── ckpt_590000_600000.pkl
│               ├── ckpt_600000_610000.pkl
│               ├── ckpt_60000_70000.pkl
│               ├── ckpt_610000_620000.pkl
│               ├── ckpt_620000_630000.pkl
│               ├── ckpt_630000_640000.pkl
│               ├── ckpt_640000_650000.pkl
│               ├── ckpt_650000_660000.pkl
│               ├── ckpt_660000_670000.pkl
│               ├── ckpt_670000_680000.pkl
│               ├── ckpt_680000_690000.pkl
│               ├── ckpt_690000_700000.pkl
│               ├── ckpt_700000_710000.pkl
│               ├── ckpt_70000_80000.pkl
│               ├── ckpt_710000_720000.pkl
│               ├── ckpt_720000_730000.pkl
│               ├── ckpt_730000_740000.pkl
│               ├── ckpt_740000_750000.pkl
│               ├── ckpt_750000_760000.pkl
│               ├── ckpt_760000_770000.pkl
│               ├── ckpt_770000_780000.pkl
│               ├── ckpt_780000_790000.pkl
│               ├── ckpt_790000_800000.pkl
│               ├── ckpt_800000_810000.pkl
│               ├── ckpt_80000_90000.pkl
│               ├── ckpt_810000_820000.pkl
│               ├── ckpt_820000_830000.pkl
│               ├── ckpt_830000_840000.pkl
│               ├── ckpt_840000_850000.pkl
│               ├── ckpt_850000_860000.pkl
│               ├── ckpt_860000_870000.pkl
│               ├── ckpt_870000_880000.pkl
│               ├── ckpt_880000_890000.pkl
│               ├── ckpt_890000_900000.pkl
│               ├── ckpt_900000_910000.pkl
│               ├── ckpt_90000_100000.pkl
│               ├── ckpt_910000_920000.pkl
│               ├── ckpt_920000_930000.pkl
│               ├── ckpt_930000_940000.pkl
│               ├── ckpt_940000_950000.pkl
│               ├── ckpt_950000_960000.pkl
│               ├── ckpt_960000_970000.pkl
│               ├── ckpt_970000_980000.pkl
│               ├── ckpt_980000_990000.pkl
│               ├── ckpt_990000_1000000.pkl
│               └── state.json
├── data_parsers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── disease_extract.cpython-312.pyc
│   │   ├── disease_extract.cpython-38.pyc
│   │   ├── medex_drug_extract.cpython-312.pyc
│   │   ├── medex_drug_extract.cpython-38.pyc
│   │   ├── outcome_measure.cpython-38.pyc
│   │   ├── population_extract.cpython-38.pyc
│   │   ├── result_arm_extract.cpython-38.pyc
│   │   └── umls_utils.cpython-38.pyc
│   ├── disease_extract.py
│   ├── external_tools
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-38.pyc
│   │   ├── criteria2query
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── criteria_input.cpython-38.pyc
│   │   │   ├── criteria_input.py
│   │   │   └── run_crit2query.py
│   │   ├── medex
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── medex_input.cpython-38.pyc
│   │   │   │   └── medex_output.cpython-38.pyc
│   │   │   ├── index.html
│   │   │   ├── medex_input.py
│   │   │   ├── medex_output.py
│   │   │   └── run_medex.py
│   │   └── umls_search
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   ├── api.cpython-38.pyc
│   │       │   └── auth.cpython-38.pyc
│   │       ├── api.py
│   │       └── auth.py
│   ├── medex_drug_extract.py
│   ├── outcome_measure.py
│   ├── population_extract.py
│   ├── result_arm_extract.py
│   └── umls_utils.py
├── debug_conditions.py
├── knowledge_graph
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── build_graph.cpython-38.pyc
│   │   └── kg.cpython-38.pyc
│   ├── build_graph.py
│   ├── external_networks.py
│   ├── kg.py
│   └── node_features
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-37.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── index.html
│       │   ├── map_merge_feats.cpython-37.pyc
│       │   ├── map_merge_feats.cpython-38.pyc
│       │   ├── text_bert_features.cpython-37.pyc
│       │   ├── trial_attribute_features.cpython-37.pyc
│       │   └── trial_attribute_features.cpython-38.pyc
│       ├── map_merge_feats.py
│       ├── text_bert_features.py
│       └── trial_attribute_features.py
├── nltk_data
│   ├── corpora
│   │   ├── stopwords
│   │   │   ├── README
│   │   │   ├── albanian
│   │   │   ├── arabic
│   │   │   ├── azerbaijani
│   │   │   ├── basque
│   │   │   ├── belarusian
│   │   │   ├── bengali
│   │   │   ├── catalan
│   │   │   ├── chinese
│   │   │   ├── danish
│   │   │   ├── dutch
│   │   │   ├── english
│   │   │   ├── finnish
│   │   │   ├── french
│   │   │   ├── german
│   │   │   ├── greek
│   │   │   ├── hebrew
│   │   │   ├── hinglish
│   │   │   ├── hungarian
│   │   │   ├── indonesian
│   │   │   ├── italian
│   │   │   ├── kazakh
│   │   │   ├── nepali
│   │   │   ├── norwegian
│   │   │   ├── portuguese
│   │   │   ├── romanian
│   │   │   ├── russian
│   │   │   ├── slovene
│   │   │   ├── spanish
│   │   │   ├── swedish
│   │   │   ├── tajik
│   │   │   ├── tamil
│   │   │   └── turkish
│   │   └── stopwords.zip
│   ├── taggers
│   │   ├── averaged_perceptron_tagger
│   │   │   └── averaged_perceptron_tagger.pickle
│   │   └── averaged_perceptron_tagger.zip
│   └── tokenizers
│       └── punkt
│           ├── PY3
│           │   ├── README
│           │   ├── czech.pickle
│           │   ├── danish.pickle
│           │   ├── dutch.pickle
│           │   ├── english.pickle
│           │   ├── estonian.pickle
│           │   ├── finnish.pickle
│           │   ├── french.pickle
│           │   ├── german.pickle
│           │   ├── greek.pickle
│           │   ├── italian.pickle
│           │   ├── malayalam.pickle
│           │   ├── norwegian.pickle
│           │   ├── polish.pickle
│           │   ├── portuguese.pickle
│           │   ├── russian.pickle
│           │   ├── slovene.pickle
│           │   ├── spanish.pickle
│           │   ├── swedish.pickle
│           │   └── turkish.pickle
│           ├── README
│           ├── czech.pickle
│           ├── danish.pickle
│           ├── dutch.pickle
│           ├── english.pickle
│           ├── estonian.pickle
│           ├── finnish.pickle
│           ├── french.pickle
│           ├── german.pickle
│           ├── greek.pickle
│           ├── italian.pickle
│           ├── malayalam.pickle
│           ├── norwegian.pickle
│           ├── polish.pickle
│           ├── portuguese.pickle
│           ├── russian.pickle
│           ├── slovene.pickle
│           ├── spanish.pickle
│           ├── swedish.pickle
│           └── turkish.pickle
├── parse_trial.py
├── planet_parse.pbs
├── preprocessing
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── disease_data.cpython-312.pyc
│   │   └── disease_data.cpython-38.pyc
│   ├── ctrials_data.py
│   ├── disease_data.py
│   ├── drug_data.py
│   ├── network_construction.py
│   ├── network_old.py
│   └── trials_disease_drug_match.py
├── quick_check_keys.py
├── requirements.txt
├── resources
│   ├── criteria2query.jar
│   ├── medex
│   │   └── Medex_UIMA_1.3.8
│   │       ├── bin
│   │       │   ├── org
│   │       │   │   └── apache
│   │       │   │       ├── NLPTools
│   │       │   │       │   ├── BZToken.class
│   │       │   │       │   ├── CFGparser
│   │       │   │       │   │   ├── EarleyChart.class
│   │       │   │       │   │   ├── EarleyEntry.class
│   │       │   │       │   │   ├── EarleyParser.class
│   │       │   │       │   │   ├── EntryItem.class
│   │       │   │       │   │   ├── Grammar.class
│   │       │   │       │   │   ├── MidGrammar.class
│   │       │   │       │   │   ├── Rule.class
│   │       │   │       │   │   ├── TreeNode.class
│   │       │   │       │   │   └── med_parser_grammar.class
│   │       │   │       │   ├── Document.class
│   │       │   │       │   ├── Global$ArrayIndexComparator.class
│   │       │   │       │   ├── Global$CharacterComparator.class
│   │       │   │       │   ├── Global$Section.class
│   │       │   │       │   ├── Global$SuffixArrayCaseMode.class
│   │       │   │       │   ├── Global$SuffixArrayMode.class
│   │       │   │       │   ├── Global$SuffixArrayNode.class
│   │       │   │       │   ├── Global$SuffixArrayResult.class
│   │       │   │       │   ├── Global$TextSectionType.class
│   │       │   │       │   ├── Global$Util.class
│   │       │   │       │   ├── Global.class
│   │       │   │       │   ├── Main.class
│   │       │   │       │   ├── Sentence.class
│   │       │   │       │   ├── SentenceBoundary.class
│   │       │   │       │   ├── SentenceBoundaryMain.class
│   │       │   │       │   ├── Stemmer.class
│   │       │   │       │   ├── Tag.class
│   │       │   │       │   ├── TextSection.class
│   │       │   │       │   ├── Token.class
│   │       │   │       │   └── Util.class
│   │       │   │       ├── RuleEngine
│   │       │   │       │   ├── Main.class
│   │       │   │       │   ├── MatchResult.class
│   │       │   │       │   ├── NormPattern.class
│   │       │   │       │   ├── Patterns.class
│   │       │   │       │   ├── ProcessingEngine.class
│   │       │   │       │   ├── Rule.class
│   │       │   │       │   └── SingleRule.class
│   │       │   │       ├── TIMEX
│   │       │   │       │   ├── Main.class
│   │       │   │       │   ├── MatchResult.class
│   │       │   │       │   ├── NormPattern.class
│   │       │   │       │   ├── Patterns.class
│   │       │   │       │   ├── ProcessingEngine.class
│   │       │   │       │   ├── Rule.class
│   │       │   │       │   └── SingleRule.class
│   │       │   │       ├── UIMA
│   │       │   │       │   ├── CPE
│   │       │   │       │   │   └── medex
│   │       │   │       │   │       ├── MedexUIMACPE.class
│   │       │   │       │   │       ├── SourceDocumentInformation.class
│   │       │   │       │   │       ├── SourceDocumentInformation_Type$1.class
│   │       │   │       │   │       ├── SourceDocumentInformation_Type.class
│   │       │   │       │   │       ├── UIMAMedexCollectionReader.class
│   │       │   │       │   │       └── UIMAMedexConsumer.class
│   │       │   │       │   └── medex
│   │       │   │       │       ├── BrandName.class
│   │       │   │       │       ├── BrandName_Type$1.class
│   │       │   │       │       ├── BrandName_Type.class
│   │       │   │       │       ├── DoseAmount.class
│   │       │   │       │       ├── DoseAmount_Type$1.class
│   │       │   │       │       ├── DoseAmount_Type.class
│   │       │   │       │       ├── Drug.class
│   │       │   │       │       ├── DrugAnnotator.class
│   │       │   │       │       ├── Drug_Type$1.class
│   │       │   │       │       ├── Drug_Type.class
│   │       │   │       │       ├── Duration.class
│   │       │   │       │       ├── Duration_Type$1.class
│   │       │   │       │       ├── Duration_Type.class
│   │       │   │       │       ├── Form.class
│   │       │   │       │       ├── Form_Type$1.class
│   │       │   │       │       ├── Form_Type.class
│   │       │   │       │       ├── Frequency.class
│   │       │   │       │       ├── Frequency_Type$1.class
│   │       │   │       │       ├── Frequency_Type.class
│   │       │   │       │       ├── Neccessity.class
│   │       │   │       │       ├── Neccessity_Type$1.class
│   │       │   │       │       ├── Neccessity_Type.class
│   │       │   │       │       ├── Route.class
│   │       │   │       │       ├── Route_Type$1.class
│   │       │   │       │       ├── Route_Type.class
│   │       │   │       │       ├── Strength.class
│   │       │   │       │       ├── Strength_Type$1.class
│   │       │   │       │       └── Strength_Type.class
│   │       │   │       ├── algorithms
│   │       │   │       │   ├── Classifier.class
│   │       │   │       │   ├── SuffixArray.class
│   │       │   │       │   ├── VectorSpaceModel$1.class
│   │       │   │       │   ├── VectorSpaceModel$2.class
│   │       │   │       │   ├── VectorSpaceModel$3.class
│   │       │   │       │   ├── VectorSpaceModel$SimDistanceType.class
│   │       │   │       │   └── VectorSpaceModel.class
│   │       │   │       └── medex
│   │       │   │           ├── DrugTag.class
│   │       │   │           ├── Encoder.class
│   │       │   │           ├── Lexicon.class
│   │       │   │           ├── Main.class
│   │       │   │           ├── MedTagger$1.class
│   │       │   │           ├── MedTagger.class
│   │       │   │           ├── RegexParser.class
│   │       │   │           ├── SemanticRuleEngine.class
│   │       │   │           ├── Util.class
│   │       │   │           └── semantic_rules
│   │       │   │               ├── Tag_rule_engine.class
│   │       │   │               ├── disambiguation.drl
│   │       │   │               ├── result_bean.class
│   │       │   │               ├── tags.class
│   │       │   │               ├── transformation1.drl
│   │       │   │               └── transformation2.drl
│   │       │   └── patch
│   │       │       ├── bugfixes.py
│   │       │       ├── drugnames_ct.txt
│   │       │       ├── drugnames_drugbank.txt
│   │       │       ├── drugnames_synonyms_pubchem.txt
│   │       │       ├── forms_drugbank.txt
│   │       │       ├── forms_rxnorm.txt
│   │       │       ├── medex-patch.py
│   │       │       ├── routes_drugbank.txt
│   │       │       ├── routes_fda.txt
│   │       │       ├── routes_wikipedia.txt
│   │       │       └── units_drugbank.txt
│   │       ├── lib
│   │       │   ├── alt-rt.jar
│   │       │   ├── antlr-2.7.7.jar
│   │       │   ├── antlr-3.3.jar
│   │       │   ├── antlr-runtime-3.3.jar
│   │       │   ├── bcmail-jdk14-138.jar
│   │       │   ├── bcprov-jdk14-138.jar
│   │       │   ├── charsets.jar
│   │       │   ├── commons-lang-2.6.jar
│   │       │   ├── deploy.jar
│   │       │   ├── dom4j-1.6.1.jar
│   │       │   ├── drools-clips-5.4.0.Final.jar
│   │       │   ├── drools-compiler-5.4.0.Final.jar
│   │       │   ├── drools-core-5.4.0.Final.jar
│   │       │   ├── drools-decisiontables-5.4.0.Final.jar
│   │       │   ├── drools-jsr94-5.4.0.Final.jar
│   │       │   ├── drools-persistence-jpa-5.4.0.Final.jar
│   │       │   ├── drools-templates-5.4.0.Final.jar
│   │       │   ├── drools-verifier-5.4.0.Final.jar
│   │       │   ├── ecj-3.5.1.jar
│   │       │   ├── guava-r06.jar
│   │       │   ├── hibernate-jpa-2.0-api-1.0.1.Final.jar
│   │       │   ├── itext-2.1.2.jar
│   │       │   ├── jVinci.jar
│   │       │   ├── jackson-annotations-2.12.2.jar
│   │       │   ├── jackson-core-2.12.2.jar
│   │       │   ├── jackson-databind-2.12.2.jar
│   │       │   ├── javassist-3.14.0-GA.jar
│   │       │   ├── javatuples-1.1.jar
│   │       │   ├── javaws.jar
│   │       │   ├── jce.jar
│   │       │   ├── jsr94-1.1.jar
│   │       │   ├── jsse.jar
│   │       │   ├── jta-1.1.jar
│   │       │   ├── jxl-2.6.10.jar
│   │       │   ├── knowledge-api-5.4.0.Final.jar
│   │       │   ├── knowledge-internal-api-5.4.0.Final.jar
│   │       │   ├── log4j-1.2.14.jar
│   │       │   ├── management-agent.jar
│   │       │   ├── mvel2-2.1.0.drools16.jar
│   │       │   ├── plugin.jar
│   │       │   ├── protobuf-java-2.4.1.jar
│   │       │   ├── resources.jar
│   │       │   ├── rt.jar
│   │       │   ├── slf4j-api-1.6.4.jar
│   │       │   ├── stringtemplate-3.2.1.jar
│   │       │   ├── uima-adapter-soap.jar
│   │       │   ├── uima-adapter-vinci.jar
│   │       │   ├── uima-core.jar
│   │       │   ├── uima-cpe.jar
│   │       │   ├── uima-document-annotation.jar
│   │       │   ├── uima-examples.jar
│   │       │   ├── uima-tools.jar
│   │       │   ├── uimaj-bootstrap.jar
│   │       │   ├── xml-apis-1.3.04.jar
│   │       │   ├── xmlpull-1.1.3.1.jar
│   │       │   ├── xpp3_min-1.1.4c.jar
│   │       │   └── xstream-1.4.1.jar
│   │       └── resources
│   │           ├── TIMEX
│   │           │   ├── norm_patterns
│   │           │   │   ├── Action
│   │           │   │   ├── ApproximatePatterns
│   │           │   │   ├── DaySection
│   │           │   │   ├── DaySectionLy
│   │           │   │   ├── DayUnit
│   │           │   │   ├── DayUnitLy
│   │           │   │   ├── DayWord
│   │           │   │   ├── Eat
│   │           │   │   ├── ExtactPatterns
│   │           │   │   ├── FREQDayUnit
│   │           │   │   ├── FrequentPatterns
│   │           │   │   ├── Holiday
│   │           │   │   ├── MonthWord
│   │           │   │   ├── Next
│   │           │   │   ├── NormFREQAlaphaNum
│   │           │   │   ├── NormFREQword
│   │           │   │   ├── NormThisThatTheOBJ
│   │           │   │   ├── NumWord
│   │           │   │   ├── Season
│   │           │   │   ├── ThisThatThe
│   │           │   │   ├── TimeDateBeginModifier
│   │           │   │   ├── TimeDateEndModifier
│   │           │   │   ├── TimeUnit
│   │           │   │   └── WeekDay
│   │           │   ├── patterns
│   │           │   │   ├── About
│   │           │   │   ├── ActionPatterns
│   │           │   │   ├── AfterTomorrow
│   │           │   │   ├── ApproximatePatterns
│   │           │   │   ├── BeforeYesterday
│   │           │   │   ├── DaySection
│   │           │   │   ├── DaySectionLy
│   │           │   │   ├── DayUnit
│   │           │   │   ├── DayUnitLy
│   │           │   │   ├── DayWord
│   │           │   │   ├── DurationPrep
│   │           │   │   ├── Eat
│   │           │   │   ├── End
│   │           │   │   ├── ExtactPatterns
│   │           │   │   ├── FREQAlaphaNum
│   │           │   │   ├── FREQAlaphaNum~
│   │           │   │   ├── FREQDayUnit
│   │           │   │   ├── FREQword
│   │           │   │   ├── FrequentPatterns
│   │           │   │   ├── Holiday
│   │           │   │   ├── HourWord
│   │           │   │   ├── LaborDay
│   │           │   │   ├── MonthWord
│   │           │   │   ├── More
│   │           │   │   ├── Next
│   │           │   │   ├── NumWord1D
│   │           │   │   ├── NumWord2D
│   │           │   │   ├── Number
│   │           │   │   ├── Per
│   │           │   │   ├── PostOP
│   │           │   │   ├── Previous
│   │           │   │   ├── Season
│   │           │   │   ├── Start
│   │           │   │   ├── ThisThatThe
│   │           │   │   ├── ThisThatTheOBJ
│   │           │   │   ├── TimeDateBeginModifier
│   │           │   │   ├── TimeDateEndModifier
│   │           │   │   ├── TimeUnit
│   │           │   │   ├── Today
│   │           │   │   ├── Tomorrow
│   │           │   │   ├── WeekDay
│   │           │   │   ├── Yesterday
│   │           │   │   └── duration_patterns
│   │           │   └── rules
│   │           │       └── frequency_rules
│   │           ├── abbr.txt
│   │           ├── brand_generic.cfg
│   │           ├── code.cfg
│   │           ├── findFPlexicion.py
│   │           ├── forms.txt
│   │           ├── grammar.txt
│   │           ├── lexicon.cfg
│   │           ├── norm.cfg
│   │           ├── rxcui_generic.cfg
│   │           ├── semantic_rules
│   │           │   ├── disambiguation.drl
│   │           │   ├── temp
│   │           │   ├── transformation1.drl
│   │           │   └── transformation2.drl
│   │           └── word.txt
│   └── medex.zip
├── setup_nltk_container.py
└── tmp
    └── trial_data_NCT02370680.pkl
```

### **Step 8: Check databases KG**

1. Clinical Trial JSON: Load the trial record from a local file or the ClinicalTrials.gov API.
2. Basic Parsing: Extract fields (interventions, conditions, etc.).
3. MedEx: Identify drug entities using a pre-trained model.
4. Criteria2Query: Parse eligibility criteria using Stanford NLP models.
5. Disease Matching: Map conditions to MeSH using the disease database.
6. Drug Matching: Map drugs to DrugBank/RxNorm IDs.
7. Outcome Extraction: Classify outcomes using pre-trained clusters.
8. Population Extraction: Map eligibility criteria to UMLS concepts using TF-IDF.
9. Knowledge Graph Building: Connect the extracted entities through the biomedical KG.
10. Structured Output: Represent the trial as knowledge graph edges.
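
The stages above run strictly in order, each one enriching the same trial record. A minimal sketch of that control flow follows. The stage functions and the `MESH:TOY-0001` identifier are hypothetical stand-ins, not the PlaNet implementation, which calls MedEx, Criteria2Query, and the real lookup databases.

```python
from typing import Any, Callable, Dict, List, Tuple

Trial = Dict[str, Any]
Stage = Tuple[str, Callable[[Trial], Trial]]


def run_pipeline(trial: Trial, stages: List[Stage]) -> Trial:
    """Apply each stage in order, recording which stages completed."""
    for name, stage in stages:
        try:
            trial = stage(trial)
        except Exception as exc:
            # Re-raise with the stage name so a failure is easy to localize.
            raise RuntimeError(f"stage '{name}' failed: {exc}") from exc
        trial.setdefault("completed_stages", []).append(name)
    return trial


# Hypothetical stand-in for step 2 (basic parsing).
def basic_parsing(trial: Trial) -> Trial:
    trial["condition"] = [c.strip() for c in trial.get("raw_conditions", [])]
    return trial


# Hypothetical stand-in for step 5; the toy table replaces the MeSH database.
def disease_matching(trial: Trial) -> Trial:
    toy_mesh = {"lung cancer": "MESH:TOY-0001"}
    trial["disease_ids"] = [toy_mesh.get(c.lower()) for c in trial["condition"]]
    return trial


STAGES: List[Stage] = [
    ("basic_parsing", basic_parsing),
    ("disease_matching", disease_matching),
]

result = run_pipeline({"raw_conditions": [" Lung Cancer "]}, STAGES)
print(result["completed_stages"])
```

Naming each stage makes it clear which external tool failed when a long-running job aborts partway through.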

### **Step 9: Define Prerequisites**
- Access to the GADI HPC system
- A Singularity image built from the definition file:

```bash
singularity build --fakeroot planet.sif planet.def
```
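
Building the image can be slow, so it may help to rebuild only when the definition file has changed. The sketch below is one possible way to do that, assuming `planet.sif` and `planet.def` sit in the current directory. The helper names are illustrative, and the command list simply mirrors the shell command above.

```python
from pathlib import Path
from typing import List


def needs_rebuild(image: Path, definition: Path) -> bool:
    """True if the image is absent or older than its definition file."""
    if not image.exists():
        return True
    return image.stat().st_mtime < definition.stat().st_mtime


def build_command(image: Path, definition: Path) -> List[str]:
    # Mirrors the shell command above; pass the list to subprocess.run to execute.
    return ["singularity", "build", "--fakeroot", str(image), str(definition)]


image, definition = Path("planet.sif"), Path("planet.def")
if definition.exists() and needs_rebuild(image, definition):
    print(" ".join(build_command(image, definition)))
```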

### **Step 10: Set Up Directory Structure**
```bash
# Create working directory
mkdir -p /scratch/sq95/sp6154/planet/parsing_package
cd /scratch/sq95/sp6154/planet/parsing_package
```

### **Step 11: Copy/Transfer Required Files**
Ensure the following are in place:
- `/scratch/sq95/sp6154/planet/planet.sif` (Singularity image)
- All data directories under `parsing_package/data/`
- Python scripts (`parse_trial.py`, etc.)
- Resource files (`resources/medex/`, `resources/criteria2query.jar`)
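
A quick pre-flight check can catch a missing file before a job is submitted. This is a minimal sketch, assuming the layout listed above. The `REQUIRED` list and the `missing_paths` helper are illustrative and not part of the PlaNet code.

```python
from pathlib import Path
from typing import Iterable, List

# Relative paths taken from the checklist above; adjust if your layout differs.
REQUIRED = [
    "data",
    "parse_trial.py",
    "resources/medex",
    "resources/criteria2query.jar",
]


def missing_paths(root: Path, required: Iterable[str]) -> List[str]:
    """Return the entries of `required` that do not exist under `root`."""
    return [rel for rel in required if not (root / rel).exists()]


if __name__ == "__main__":
    package_root = Path("/scratch/sq95/sp6154/planet/parsing_package")
    missing = missing_paths(package_root, REQUIRED)
    # The Singularity image lives one level above the parsing package.
    image = package_root.parent / "planet.sif"
    if not image.is_file():
        missing.append(str(image))
    for entry in missing:
        print(f"[MISSING] {entry}")
    print("OK" if not missing else f"{len(missing)} item(s) missing")
```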

### **Step 12: Download NLTK Data (One-time setup)**
```bash
# On GADI login node
cd /scratch/sq95/sp6154/planet/parsing_package
mkdir -p nltk_data/corpora nltk_data/taggers

# Download stopwords
cd nltk_data/corpora
wget --no-check-certificate https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/stopwords.zip
unzip stopwords.zip && rm stopwords.zip

# Download POS taggers
cd ../taggers
wget --no-check-certificate https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip
unzip averaged_perceptron_tagger.zip && rm averaged_perceptron_tagger.zip
cd ../..
```
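
To confirm that both archives unpacked where expected, the check below looks for one representative file from each resource. It is a sketch that assumes it is run from the `parsing_package` directory. The file locations follow the directory tree shown earlier, and the helper name is illustrative.

```python
from pathlib import Path
from typing import Dict

# Files the two unzip commands above should leave behind.
EXPECTED = {
    "stopwords": "corpora/stopwords/english",
    "averaged_perceptron_tagger": (
        "taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle"
    ),
}


def check_nltk_data(nltk_root: Path) -> Dict[str, bool]:
    """Map each expected resource name to whether its file is present."""
    return {name: (nltk_root / rel).is_file() for name, rel in EXPECTED.items()}


status = check_nltk_data(Path("nltk_data"))
for name, present in status.items():
    print(f"{name}: {'found' if present else 'missing'}")
```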

### **Step 13: Update parse_trial.py**
Add NLTK path at the top of the file:
```python
#!/usr/bin/env python3
import nltk
nltk.data.path.append("/app/nltk_data")
```

Increase the Java heap size (the `-Xmx` flag) in two functions:
- In `run_medex_and_parse_output`: Change `-Xmx1024m` to `-Xmx4096m`
- In `parse_eligiility_criteria` (the name is spelled this way in the script): Change `-Xmx1024m` to `-Xmx8192m`
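
Rather than editing the script each time the memory limit changes, the heap size could be read from the environment. The sketch below shows one way to do this. `PLANET_JAVA_HEAP_MB` is a hypothetical variable name and `java_heap_flag` is an illustrative helper; neither exists in PlaNet.

```python
import os
import shlex


def java_heap_flag(default_mb: int, env_var: str = "PLANET_JAVA_HEAP_MB") -> str:
    """Build a -Xmx flag, letting an environment variable override the default.

    PLANET_JAVA_HEAP_MB is a hypothetical variable name, not part of PlaNet.
    """
    raw = os.environ.get(env_var, "")
    # Fall back to the default unless the override is a positive integer.
    heap_mb = int(raw) if raw.isdecimal() and int(raw) > 0 else default_mb
    return f"-Xmx{heap_mb}m"


# Same command shape as the Criteria2Query call in parse_trial.py;
# the input and output paths here are placeholders.
args = (
    f"java {java_heap_flag(8192)} -jar resources/criteria2query.jar "
    "--input crit_input.txt --outputDir out"
)
print(shlex.split(args))
```

The heap size should stay below the memory requested for the job, since the JVM needs additional memory beyond the heap.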

Add the conditions fix in the `main()` function, before disease matching.

In [None]:
import argparse
import requests
import os
import pickle
import math
import numpy as np
import tempfile
import json
import subprocess
import shlex
import pathlib
from typing import Any

from data_parsers.external_tools.medex import medex_input
from data_parsers.external_tools import medex

from data_parsers import DiseaseExtract
from data_parsers import CriteriaOutputParser
from data_parsers import DrugMatcher, get_intervention_drug_ids
from data_parsers import OutcomeMeasureExtract
from data_parsers import UMLSConceptSearcher

from data_parsers import UMLSTFIDFMatcher
from data_parsers.umls_utils import UMLSUtils

from knowledge_graph import KnowledgeGraphBuilder
from knowledge_graph.kg import UnionFind
from knowledge_graph.build_graph import TrialGraphBuilder
from knowledge_graph.node_features import TrialAttributeFeatures

DATA_DIR = "data"
RESULTS_ROOT = os.environ.get("RESULTS_DIR", "LC_Results")

# -------------------------
# Helpers
# -------------------------

def ensure_dir(p: str) -> str:
    pathlib.Path(p).mkdir(parents=True, exist_ok=True)
    return p

def save_pkl(obj: Any, path: str) -> None:
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def _to_jsonable(obj: Any) -> Any:
    """
    Recursively convert objects into JSON-serializable forms:
      - set -> sorted list
      - numpy scalars -> Python scalars
      - numpy arrays -> lists
      - bytes -> utf-8 (best-effort)
      - pathlib.Path -> str
      - dict/list/tuple -> recurse
      - fallback: str(obj)
    """
    # Fast-path for common types
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj

    # numpy scalar
    if isinstance(obj, np.generic):
        return obj.item()

    # numpy array
    if isinstance(obj, np.ndarray):
        return obj.tolist()

    # sets -> sorted list for determinism
    if isinstance(obj, set):
        try:
            return sorted(_to_jsonable(v) for v in obj)
        except Exception:
            return [_to_jsonable(v) for v in obj]

    # bytes
    if isinstance(obj, (bytes, bytearray)):
        try:
            return obj.decode("utf-8", errors="ignore")
        except Exception:
            return str(obj)

    # pathlib
    if isinstance(obj, (pathlib.Path, )):
        return str(obj)

    # dict
    if isinstance(obj, dict):
        return {str(k): _to_jsonable(v) for k, v in obj.items()}

    # list/tuple
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]

    # fallback
    return str(obj)

def save_json(obj: Any, path: str) -> None:
    with open(path, "w") as f:
        json.dump(_to_jsonable(obj), f, indent=2)

# -------------------------
# Data loading
# -------------------------

def get_clinical_trial_data(nctid: str):
    """
    Try local JSON first, else ClinicalTrials.gov API v2.
    """
    candidates = []
    if nctid.lower().endswith(".json"):
        candidates.append(nctid if os.path.isabs(nctid) else os.path.join(os.getcwd(), nctid))
    else:
        candidates.append(os.path.join("/app/studies", f"{nctid}.json"))
        candidates.append(os.path.join(os.getcwd(), "LC_Clinical_Trials", f"{nctid}.json"))
        candidates.append(os.path.join(os.getcwd(), f"{nctid}.json"))

    for path in candidates:
        if os.path.isfile(path):
            try:
                with open(path, "r") as f:
                    print(f"[INFO] Loading local JSON for {nctid} from {path}")
                    return json.load(f)
            except Exception as e:
                return {"error": f"Failed to read local JSON {path}: {e}"}

    base_url = "https://clinicaltrials.gov/api/v2/studies"
    request_url = f"{base_url}/{nctid}"
    try:
        print(f"[INFO] Fetching {nctid} from ClinicalTrials.gov API: {request_url}")
        response = requests.get(request_url, timeout=30)
        if response.status_code == 200:
            return response.json()
        else:
            return {"error": f"Failed to fetch data. Status code: {response.status_code}, Message: {response.text}"}
    except Exception as e:
        return {"error": str(e)}

# -------------------------
# Parsing / enrichment
# -------------------------

def parse(nctid: str):
    trial_data = get_clinical_trial_data(nctid)
    if "error" in trial_data:
        raise RuntimeError(f"Error fetching data: {trial_data['error']}")

    attributes = {
        "nct_id": "protocolSection.identificationModule.nctId",
        "arm_group": "protocolSection.armsInterventionsModule.armGroups",
        "intervention": "protocolSection.armsInterventionsModule.interventions",
        "condition": "protocolSection.conditionsModule.conditions",
        "intervention_mesh_terms": "derivedSection.conditionBrowseModule.meshes",
        "event_groups": "resultsSection.adverseEventsModule.eventGroups",
        "primary_outcome": "protocolSection.outcomesModule.primaryOutcomes",
        "secondary_outcome": "protocolSection.outcomesModule.secondaryOutcomes",
        "eligibility_criteria": "protocolSection.eligibilityModule.eligibilityCriteria",
        "brief_summary": "protocolSection.descriptionModule.briefSummary",
        "phase": "protocolSection.designModule.phases",
        "enrollment": "protocolSection.designModule.enrollmentInfo",
        "gender_sex": "protocolSection.eligibilityModule.sex",
        "minimum_age": "protocolSection.eligibilityModule.minimumAge",
        "maximum_age": "protocolSection.eligibilityModule.maximumAge",
    }

    parsed_trial = {}
    for attribute, path in attributes.items():
        val = trial_data
        for component in path.split("."):
            if not isinstance(val, dict) or component not in val:
                val = None
                break
            val = val[component]
        parsed_trial[attribute] = val

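    # Missing fields were stored as None above, so setdefault() would not
    # replace them; substitute empty lists so the loops below can iterate.
    for key in ("arm_group", "intervention"):
        if parsed_trial[key] is None:
            parsed_trial[key] = []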

    for arm_group in parsed_trial["arm_group"]:
        if "label" in arm_group:
            arm_group["arm_group_label"] = arm_group.pop("label")

    for intervention in parsed_trial["intervention"]:
        if "type" in intervention:
            intervention["intervention_type"] = intervention.pop("type").title()
        if "name" in intervention:
            intervention["intervention_name"] = intervention.pop("name")
        if "otherNames" in intervention:
            intervention["other_name"] = intervention.pop("otherNames")
        if "armGroupLabels" in intervention:
            intervention["arm_group_label"] = intervention.pop("armGroupLabels")

    parsed_trial["clinical_results"] = (
        {"reported_events": {"group_list": {"group": parsed_trial.pop("event_groups")}}}
        if parsed_trial.get("event_groups") is not None
        else {}
    )
    return parsed_trial

def run_medex_and_parse_output(parsed_trial):
    result = {}
    classpath = (
        "resources/medex/Medex_UIMA_1.3.8/bin:"
        "resources/medex/Medex_UIMA_1.3.8/lib/*"
    )
    # Heap size raised from 1024m to 4096m, as described in Step 13.
    args_template = (
        "java -Xmx4096m -cp {0} org.apache.medex.Main "
        "-i {1} -o {2} -b n -f y -d y -t n"
    )

    with tempfile.TemporaryDirectory() as basedir:
        medex_input._generate_medex_inputs(parsed_trial, result)
        input_dir = os.path.join(basedir, "inputs")
        os.makedirs(input_dir)
        with open(os.path.join(input_dir, "medex_input.json"), "w") as f:
            json.dump(result, f)

        output_path = os.path.join(basedir, "outputs")
        os.makedirs(os.path.join(output_path, "data"))
        args = args_template.format(
            classpath,
            os.path.join(input_dir, "medex_input.json"),
            os.path.join(output_path, "data"),
        )
        print(args)
        try:
            subprocess.run(shlex.split(args), check=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Medex execution failed with error: {e}")

        medex_output_parser = medex.MedexOutputParser(base_paths=[output_path])
        medex_output_parser.fill_medex_info(parsed_trial)

    return parsed_trial

def parse_eligiility_criteria(parsed_trial):
    args_template = (
        "java -Xmx8192m -jar resources/criteria2query.jar --input {0} --outputDir {1}"
    )

    with tempfile.TemporaryDirectory() as basedir:
        input_dir = os.path.join(basedir, "inputs")
        os.makedirs(input_dir)
        with open(os.path.join(input_dir, "crit_input.txt"), "w") as f:
            f.write(parsed_trial.get("eligibility_criteria", "") or "")

        output_path = os.path.join(basedir, "outputs")
        os.makedirs(os.path.join(output_path, "data"))
        args = args_template.format(
            os.path.join(input_dir, "crit_input.txt"),
            os.path.join(output_path, "data"),
        )
        print(args)
        try:
            subprocess.run(shlex.split(args), check=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Crit2Query execution failed with error: {e}")

        parsed_trial["ec_umls"] = CriteriaOutputParser.parse_crit_output_from_file(
            os.path.join(output_path, "data", "output.json")
        )
    return parsed_trial

def extract_outcomes(parsed_trial):
    outcome_extractor = OutcomeMeasureExtract(
        f"{DATA_DIR}/outcome_data/clusters-outcome-measures.txt"
    )
    outcome_extractor.load_phrase_models(f"{DATA_DIR}/outcome_data")
    outcome_extractor.populate_cids(parsed_trial)
    return parsed_trial

def population_extraction(umls_utils, parsed_trial):
    umls_concept_searcher = UMLSConceptSearcher(
        api_key="",
        version="2020AB",
        cache_dir=f"{DATA_DIR}/population_data/umls_search_cache",
    )
    umls_concept_searcher.set_umls_search(False)

    criteria_all = parsed_trial["ec_umls"]
    for category in criteria_all:
        for inclusion in criteria_all[category]:
            for criterion in criteria_all[category][inclusion]:
                criterion.map_concept(umls_concept_searcher)

    umls_utils.cuid2parents = {}
    for category in criteria_all:
        for inclusion in criteria_all[category]:
            for criterion in criteria_all[category][inclusion]:
                if criterion.concept is not None:
                    criterion.parents = umls_utils.parents(criterion.concept["ui"])

    tfidf_matcher = UMLSTFIDFMatcher(
        umls_utils.cuid2concept, f"{DATA_DIR}/population_data", None
    )
    tfidf_matcher.populate_result_single(parsed_trial["ec_umls"])
    return parsed_trial

def _phase_feature_vec(phases):
    v = [0] * 5
    for phase in phases:
        if phase in ["EARLY_PHASE1", "PHASE1"]:
            v[1] = 1
        elif phase == "N/A":
            v[0] = 1
        elif phase == "PHASE2":
            v[2] = 1
        elif phase == "PHASE3":
            v[3] = 1
        elif phase == "PHASE4":
            v[4] = 1
        else:
            raise RuntimeError(f"Unknown phase: {phase}")
    return v

def _enrollment_feat(enrollment):
    is_anticipated = False
    if isinstance(enrollment, dict):
        if enrollment.get("type") == "ANTICIPATED":
            is_anticipated = True
        return [math.log(1 + enrollment.get("count", 0)), int(is_anticipated)]
    if isinstance(enrollment, float) and np.isnan(enrollment):
        return [0, 0]
    return [math.log(1 + enrollment), 0]

def _sex_vec(sex):
    if sex is None or isinstance(sex, float):
        return [0, 0, 0]
    sex_to_feats = {
        "ALL": [1, 0, 0],
        "MALE": [0, 1, 0],
        "FEMALE": [0, 0, 1],
    }
    return sex_to_feats.get(sex, [0, 0, 0])

def extract_trial_features(extractor, trial_row):
    data = {}
    data["phase_vec"] = _phase_feature_vec(trial_row["phase"])
    data["enrollment_vec"] = _enrollment_feat(trial_row["enrollment"])
    data["gender_sex_vec"] = _sex_vec(trial_row["gender_sex"])
    data["minimum_age_vec"] = extractor._age_vec(trial_row["minimum_age"] or 0.0)
    data["maximum_age_vec"] = extractor._age_vec(trial_row["maximum_age"] or 0.0)

    def merge_vecs(row):
        feats = []
        for attribute in extractor.attributes:
            if attribute == "phase":
                feats.extend(row["phase_vec"])
            elif attribute == "enrollment":
                feats.extend(row["enrollment_vec"])
            elif attribute == "gender":
                feats.extend(row["gender_sex_vec"])
            elif attribute == "age":
                feats.extend(row["minimum_age_vec"])
                feats.extend(row["maximum_age_vec"])
            elif attribute == "age_class":
                feats.extend(row["age_vec_2"])
            else:
                raise RuntimeError(f"Unknown attributes ({attribute}) for features")
        return np.array(feats)

    return merge_vecs(data)

def get_arm_text(row):
    arm2text = {}
    nct2text = {}
    summary = row.get("brief_summary", "") or ""
    disease_text = ""
    for disease in row.get("condition", []) or []:
        disease_text += f"{disease} "
    outcome_text = ""
    if not isinstance(row.get("primary_outcome"), float):
        for pom in row.get("primary_outcome", []) or []:
            outcome_text += pom.get("measure", "") + " "
    criteria = row.get("eligibility_criteria")
    if isinstance(criteria, float) or criteria is None:
        criteria = ""

    arm2intervention = {}
    for intervention in row.get("intervention", []) or []:
        # `or ""` also covers keys that are present but set to None
        intervention_text = (intervention.get("intervention_name") or "") + " "
        intervention_desc = (intervention.get("description") or "") + " "
        arm_group_label = intervention.get("arm_group_label", ["default"])
        if not isinstance(arm_group_label, list):
            arm_group_label = ["default"]
        for arm_label in arm_group_label:
            arm_label = arm_label.lower()
            arm2intervention[arm_label] = (intervention_text, intervention_desc)

    arms = row.get("arm_group", [])
    if not isinstance(arms, list) or not arms:
        arms = [{"arm_group_label": "default", "arm_group_type": ""}]

    for idx, arm in enumerate(arms):
        label = arm.get("arm_group_label", "default")
        arm_text_val = label + " " + (arm.get("description") or "")
        if label.lower() in arm2intervention:
            intervention_text, intervention_desc = arm2intervention[label.lower()]
        else:
            intervention_text, intervention_desc = "", ""
        all_text = " ".join(
            [intervention_text, disease_text, outcome_text, arm_text_val, summary, intervention_desc, criteria]
        )
        arm2text[row["nct_id"], idx] = all_text
        nct2text[row["nct_id"]] = [disease_text, outcome_text, summary, criteria]
    return arm2text, nct2text

def build_trial_arms(disease_matcher, drug_matcher, umls_utils, cuid2term, parsed_trial):
    entity2cid_path = f"{DATA_DIR}/kg_data/kg-entity2cid-31_7_21.pkl"
    with open(entity2cid_path, "rb") as f:
        entity2cid = pickle.load(f)

    ext_basepath = f"{DATA_DIR}/kg_data/external_data"
    builder = KnowledgeGraphBuilder(
        disease_matcher.mesh_dis_data,
        drug_matcher.drug_data,
        ext_basepath,
        cuid2term,
        umls_utils,
        umls_graph_clip_threshold=10,
        build_ae=False,
    )

    builder.build_external_networks()
    builder._mesh_children()

    parsed_trial["has_results"] = False

    trial_builder = TrialGraphBuilder(builder, parsed_trial)
    trial_builder.build(use_population=True)

    uf = UnionFind()
    for u, v, data in builder.biokg.graph.edges(data=True):
        if data["relation"] == "KG-MERGE-SAME":
            uf.union(u, v)

    trial_attribute_featurizer = TrialAttributeFeatures(
        attributes=("age", "gender", "enrollment", "phase")
    )
    trial_attribute_feats = extract_trial_features(trial_attribute_featurizer, parsed_trial)

    trial_data = []
    arm2text, _ = get_arm_text(parsed_trial)
    for arm_label, arm_idx in trial_builder.arm_labels.items():
        trial_arm_data = []
        for u, v, k, data in builder.biokg.graph.edges(
            nbunch=[trial_builder.arm_key(arm_idx)],
            data=True,
            keys=True,
        ):
            trial_arm_data.append(
                {
                    "kg_id": entity2cid[uf.find_parent(v)],
                    "relation": data["relation"],
                    "key": k,
                    "data": data,
                }
            )
        trial_data.append(
            {
                "nct_id": parsed_trial["nct_id"],
                "arm_label": arm_label,
                "arm_idx": arm_idx,
                "trial_arm_edges": trial_arm_data,
                "arm_text": arm2text[parsed_trial["nct_id"], arm_idx],
                "trial_attribute_feats_vec": trial_attribute_feats,
            }
        )

    return trial_data

def load_cuid2term():
    basedir = f"{DATA_DIR}/population_data"
    filepath = os.path.join(basedir, "umls_graph_clipper_output.pkl")
    with open(filepath, "rb") as f:
        g_clipper_state = pickle.load(f)
        cuid2term = g_clipper_state["cuid2term"]
    return cuid2term

def main(nct_id: str):
    drug_matcher = DrugMatcher(
        data_paths={
            "drug_data": f"{DATA_DIR}/drug_data/drugs_all_03_04_21.pkl",
            "pubchem_synonyms": f"{DATA_DIR}/drug_data/pubchem-drugbankid-synonyms.json",
            "rxnorm2drugbank-umls": f"{DATA_DIR}/drug_data/rxnorm2drugbank-umls.pkl",
            "RXNCONSO": f"{DATA_DIR}/drug_data/RXNCONSO.RRF",
        }
    )
    disease_matcher = DiseaseExtract(data_dir=DATA_DIR, data_year=2021)
    umls_utils = UMLSUtils(f"{DATA_DIR}/population_data/umls-install/2020AB")
    umls_utils.load_relations()
    cuid2term = load_cuid2term()

    trial = parse(nct_id)
    trial = run_medex_and_parse_output(trial)
    trial = parse_eligiility_criteria(trial)
    trial["mesh_ids"] = disease_matcher.get_disease_ids(trial)

    for intervention in trial['intervention']:
        if "medex_out" in intervention:
            get_intervention_drug_ids(drug_matcher, intervention, trial)

    trial = extract_outcomes(trial)
    trial = population_extraction(umls_utils, trial)

    trial_data = build_trial_arms(disease_matcher, drug_matcher, umls_utils, cuid2term, trial)
    return trial, trial_data

# -------------------------
# CLI
# -------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process clinical trial data.")
    parser.add_argument("nctid", type=str, help="NCT ID of the clinical trial (or path to JSON)")
    args = parser.parse_args()

    nct_id = args.nctid.split("/")[-1].replace(".json", "")

    outroot = ensure_dir(RESULTS_ROOT)
    outdir = ensure_dir(os.path.join(outroot, nct_id))
    print(f"[INFO] output root: {outroot}")
    print(f"[INFO] output trial dir: {outdir}")

    enriched_trial, trial_data = main(args.nctid)

    # 1) Always write PKLs first (so you get your pickles even if JSON serialization fails)
    arms_pkl_new = os.path.join(outroot, f"trial_data_{nct_id}.pkl")
    save_pkl(trial_data, arms_pkl_new)
    print(f"[INFO] wrote {arms_pkl_new}")

    arms_pkl_compat = os.path.join(outroot, f"{nct_id}_results.pkl")
    save_pkl(trial_data, arms_pkl_compat)
    print(f"[INFO] wrote {arms_pkl_compat}")

    # 2) JSON outputs (now safe-serialized)
    parsed_trial_path = os.path.join(outroot, f"parsed_trial_{nct_id}.json")
    save_json(enriched_trial, parsed_trial_path)
    print(f"[INFO] wrote {parsed_trial_path}")

    summary = {
        "nct_id": nct_id,
        "num_arms": len(trial_data),
        "keys_first_arm": sorted(list(trial_data[0].keys())) if trial_data else [],
    }
    summary_path = os.path.join(outroot, f"{nct_id}_summary.json")
    save_json(summary, summary_path)
    print(f"[INFO] wrote {summary_path}")

    print("[INFO] done.")

**Explanation**

**Step 1: Load & Parse Clinical Trial Data**

```python
trial = parse(nct_id)
```

* Internally calls `get_clinical_trial_data(nctid)`

  * Tries local JSON files (e.g. `/app/studies/{nctid}.json`, `LC_Clinical_Trials/{nctid}.json`, `./{nctid}.json`)
  * Falls back to ClinicalTrials.gov API v2 (`https://clinicaltrials.gov/api/v2/studies/{nctid}`)
* Extracts key fields into `parsed_trial`:

  * `nct_id`, arm groups, interventions
  * conditions, outcomes (primary/secondary)
  * eligibility criteria, brief summary
  * phase, enrollment, sex, min/max age
  * adverse event groups (wrapped into `clinical_results`)
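
The lookup order described above (local file first, API second) can be sketched as follows. This is a minimal illustration, not the code of `get_clinical_trial_data`: the function name `load_trial_json`, its parameters, and the injectable `fetch` callable are hypothetical. Only the candidate directories and the API URL come from the list above.

```python
import json
import os
import urllib.request

API_URL = "https://clinicaltrials.gov/api/v2/studies/{nctid}"

def load_trial_json(nctid, search_dirs=("/app/studies", "LC_Clinical_Trials", "."), fetch=None):
    """Return the raw study JSON for `nctid`, preferring a local copy (hypothetical helper)."""
    for directory in search_dirs:
        path = os.path.join(directory, f"{nctid}.json")
        if os.path.isfile(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)

    # No local file found: fall back to the API.
    # `fetch` is injectable so the function can be exercised offline.
    if fetch is None:
        def fetch(url):
            with urllib.request.urlopen(url, timeout=30) as resp:
                return resp.read().decode("utf-8")
    return json.loads(fetch(API_URL.format(nctid=nctid)))
```

Preferring local files matters on compute nodes without internet access, where only the pre-downloaded JSON is reachable.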

---

**Step 2: Run MedEx (Medical Entity Extraction)**

```python
trial = run_medex_and_parse_output(trial)
```

* Uses Java-based MedEx tool
* Writes a temporary `medex_input.json` and calls:

  * `java -Xmx1024m -cp resources/medex/Medex_UIMA_1.3.8/bin:resources/medex/Medex_UIMA_1.3.8/lib/* org.apache.medex.Main ...`
* Extracts drug names, dosages, frequencies from text
* Populates `medex_out` information back into the `intervention` entries
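
To see what is actually handed to `subprocess.run`, the command template can be expanded on its own. The sketch below reuses the classpath and template strings from `run_medex_and_parse_output`; the helper name `medex_argv` and the use of `shlex.quote` are additions for illustration (quoting keeps a temporary path that contains spaces in one piece).

```python
import shlex

CLASSPATH = (
    "resources/medex/Medex_UIMA_1.3.8/bin:"
    "resources/medex/Medex_UIMA_1.3.8/lib/*"
)
ARGS_TEMPLATE = (
    "java -Xmx1024m -cp {0} org.apache.medex.Main "
    "-i {1} -o {2} -b n -f y -d y -t n"
)

def medex_argv(input_path, output_dir):
    """Expand the command template into an argument list (hypothetical helper)."""
    command = ARGS_TEMPLATE.format(
        CLASSPATH,
        shlex.quote(input_path),   # protect spaces before shlex.split
        shlex.quote(output_dir),
    )
    return shlex.split(command)

argv = medex_argv("/tmp/work/inputs/medex_input.json", "/tmp/work/outputs/data")
print(argv)
```

Because the list is passed to `subprocess.run` without a shell, the `lib/*` entry is not glob-expanded by the shell; the Java launcher expands classpath wildcards itself.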

---

**Step 3: Parse Eligibility Criteria (Criteria2Query)**

```python
trial = parse_eligiility_criteria(trial)
```

* Writes raw eligibility text to a temporary `crit_input.txt`
* Runs Criteria2Query Java tool:

  * `java -Xmx8192m -jar resources/criteria2query.jar --input crit_input.txt --outputDir ...`
* Converts free-text eligibility criteria into structured UMLS concepts
* Stores parsed output in `trial["ec_umls"]`
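
The loops in `population_extraction` treat `trial["ec_umls"]` as a two-level mapping: category, then inclusion group, then a list of criteria. A small sketch of walking that shape is shown below. The group keys `"inclusion"` and `"exclusion"` and the placeholder strings are assumptions for illustration; the real entries are criterion objects produced by `CriteriaOutputParser`.

```python
from collections import Counter

def count_criteria(ec_umls):
    """Count parsed criteria per (category, inclusion-group) pair."""
    counts = Counter()
    for category, groups in ec_umls.items():
        for inclusion, criteria in groups.items():
            counts[(category, inclusion)] += len(criteria)
    return counts

# Placeholder structure; strings stand in for criterion objects.
example = {
    "Condition": {"inclusion": ["c1", "c2"], "exclusion": ["c3"]},
    "Drug": {"inclusion": [], "exclusion": ["d1"]},
}
for key, n in sorted(count_criteria(example).items()):
    print(key, n)
```

A count like this is a quick sanity check that Criteria2Query produced output for a trial before the UMLS mapping steps run.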

---

**Step 4: Disease Matching (MeSH IDs)**

```python
disease_matcher = DiseaseExtract(data_dir=DATA_DIR, data_year=2021)
trial["mesh_ids"] = disease_matcher.get_disease_ids(trial)
```

* Loads disease dictionaries (e.g. `c2021.bin`, `d2021.bin`)
* Maps trial conditions to MeSH (Medical Subject Headings) terms
* Adds `mesh_ids` to the trial object

---

**Step 5: Drug Matching (DrugBank / RxNorm / PubChem)**

```python
drug_matcher = DrugMatcher(
    data_paths={
        "drug_data": f"{DATA_DIR}/drug_data/drugs_all_03_04_21.pkl",
        "pubchem_synonyms": f"{DATA_DIR}/drug_data/pubchem-drugbankid-synonyms.json",
        "rxnorm2drugbank-umls": f"{DATA_DIR}/drug_data/rxnorm2drugbank-umls.pkl",
        "RXNCONSO": f"{DATA_DIR}/drug_data/RXNCONSO.RRF",
    }
)

for intervention in trial["intervention"]:
    if "medex_out" in intervention:
        get_intervention_drug_ids(drug_matcher, intervention, trial)
```

* Loads multiple drug knowledge sources:

  * DrugBank (pickled `drugs_all_03_04_21.pkl`)
  * RxNorm (`RXNCONSO.RRF`)
  * PubChem–DrugBank synonym mappings
* Uses MedEx output to normalize intervention drug names
* Maps interventions to standardized drug IDs (DrugBank, RxNorm, UMLS where available)

---

**Step 6: Outcome Extraction & Clustering**

```python
trial = extract_outcomes(trial)
```

* Uses `OutcomeMeasureExtract` with:

  * `data/outcome_data/clusters-outcome-measures.txt`
  * pre-trained phrase models in `data/outcome_data`
* Extracts and clusters outcome measures into standardized categories
* Populates outcome cluster IDs (CIDs) into the trial object

---

**Step 7: Population / UMLS Processing (TF-IDF Matching)**

```python
trial = population_extraction(umls_utils, trial)
```

* Uses `UMLSUtils` with local UMLS installation (`data/population_data/umls-install/2020AB`)

  * Loads UMLS relations and parent–child structure
* `UMLSConceptSearcher`:

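  * Created with an empty `api_key` and `set_umls_search(False)`, so no live UMLS API calls are made; a search cache directory is configured at `data/population_data/umls_search_cache`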
  * Handles mapping of parsed criteria to UMLS concepts
* For each criterion in `trial["ec_umls"]`:

  * Assigns UMLS concept (`criterion.concept`)
  * Attaches parent CUIs via `umls_utils.parents(...)`
* `UMLSTFIDFMatcher`:

  * Builds TF-IDF representation over UMLS concepts (`cuid2concept`)
  * Applies TF-IDF similarity to populate standardized population-related concepts
* Result: eligibility criteria enriched with UMLS concepts and TF-IDF–based population features
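
For readers unfamiliar with TF-IDF matching, the general idea can be shown in a few lines of pure Python. This is a generic sketch of TF-IDF plus cosine similarity, not the implementation of `UMLSTFIDFMatcher`; the function names and the three concept names are made up for the example.

```python
import math
import re
from collections import Counter

def tokenize(text):
    """Lowercase and split on non-alphanumeric characters."""
    return re.findall(r"[a-z0-9]+", text.lower())

def build_index(concept_names):
    """Return (TF-IDF vectors, IDF table) for a list of concept names."""
    tokenized = [tokenize(name) for name in concept_names]
    n_docs = len(tokenized)
    doc_freq = Counter(tok for toks in tokenized for tok in set(toks))
    # Smoothed IDF: tokens that occur in fewer names get larger weights.
    idf = {tok: math.log((1 + n_docs) / (1 + df)) + 1.0 for tok, df in doc_freq.items()}
    vectors = [
        {tok: count * idf[tok] for tok, count in Counter(toks).items()}
        for toks in tokenized
    ]
    return vectors, idf

def cosine(a, b):
    """Cosine similarity between two sparse vectors stored as dicts."""
    dot = sum(w * b.get(tok, 0.0) for tok, w in a.items())
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def best_match(query, concept_names):
    """Return the concept name most similar to `query` and its score."""
    vectors, idf = build_index(concept_names)
    q_counts = Counter(tokenize(query))
    q_vec = {tok: c * idf[tok] for tok, c in q_counts.items() if tok in idf}
    scores = [cosine(q_vec, vec) for vec in vectors]
    best = max(range(len(scores)), key=scores.__getitem__)
    return concept_names[best], scores[best]

names = ["Diabetes Mellitus, Type 2", "Hypertension", "Asthma"]
print(best_match("type 2 diabetes", names))
```

Word order does not matter to this kind of matcher, which is why "type 2 diabetes" still lands on "Diabetes Mellitus, Type 2".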

---

**Step 8: Build Trial-Specific Knowledge Graph Arms**

```python
cuid2term = load_cuid2term()
trial_data = build_trial_arms(disease_matcher, drug_matcher, umls_utils, cuid2term, trial)
```

* Loads preprocessed UMLS term dictionary:

  * `data/population_data/umls_graph_clipper_output.pkl` → `cuid2term`

* Initializes `KnowledgeGraphBuilder` with:

  * disease data (`disease_matcher.mesh_dis_data`)
  * drug data (`drug_matcher.drug_data`)
  * external KG data (`data/kg_data/external_data`)
  * UMLS utilities and clipped UMLS graph

* Steps inside `build_trial_arms`:

  * Builds/augments external biomedical knowledge graph
  * Runs `_mesh_children()` to propagate MeSH term structure
  * Constructs a `TrialGraphBuilder` for the current trial (with `use_population=True`)
  * Uses `UnionFind` on KG edges with relation `KG-MERGE-SAME` to merge equivalent entities
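  * Builds the trial attribute feature vector with `extract_trial_features`, using a `TrialAttributeFeatures` instance configured with `attributes=("age", "gender", "enrollment", "phase")`; the blocks are concatenated in that order:

    * Min/max age
    * Sex (all/male/female)
    * Enrollment (log(1 + count), anticipated flag)
    * Phase (N/A, Phase 1–4; `EARLY_PHASE1` shares the Phase 1 slot)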
  * Generates arm-level text with `get_arm_text(trial)`:

    * Intervention text + condition text + outcome measures
    * Arm descriptions + brief summary + eligibility criteria
  * For each arm:

    * Extracts KG edges connected to that arm node
    * Maps KG nodes to compact IDs using `kg-entity2cid-31_7_21.pkl`
    * Assembles:

      * `nct_id`
      * `arm_label` and `arm_idx`
      * `trial_arm_edges` (KG edges for this arm)
      * `arm_text` (combined clinical text)
      * `trial_attribute_feats_vec` (numerical feature vector)

* Output (`trial_data`) is a list of arm-level records, one per trial arm
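
A worked example makes the attribute vector easier to read. The sketch below is a condensed restatement of `_sex_vec`, `_enrollment_feat`, and `_phase_feature_vec` from the script; the shortened function names are for illustration only. The age blocks are left out because `extractor._age_vec` is not shown in this section, so their length is unknown here.

```python
import math

PHASE_SLOT = {
    "N/A": 0,
    "EARLY_PHASE1": 1,
    "PHASE1": 1,
    "PHASE2": 2,
    "PHASE3": 3,
    "PHASE4": 4,
}
SEX_VEC = {"ALL": [1, 0, 0], "MALE": [0, 1, 0], "FEMALE": [0, 0, 1]}

def phase_vec(phases):
    """Multi-hot vector: [N/A, Phase 1 (incl. early), Phase 2, Phase 3, Phase 4]."""
    v = [0] * 5
    for phase in phases:
        v[PHASE_SLOT[phase]] = 1  # raises KeyError on an unknown phase
    return v

def enrollment_vec(count, anticipated=False):
    """[log(1 + count), 1 if the count is anticipated rather than actual]."""
    return [math.log(1 + count), int(anticipated)]

# Example: a Phase 2 trial with 56 participants (actual), open to all sexes.
# Block order follows attributes=("age", "gender", "enrollment", "phase"),
# with the age blocks omitted.
features = SEX_VEC["ALL"] + enrollment_vec(56) + phase_vec(["PHASE2"])
print(len(features), features)
```

A trial listed under two phases (for example Phase 1 and Phase 2) sets both slots, which is why the phase block is multi-hot rather than strictly one-hot.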

---

**Final Output (Files & In-Memory Objects)**

The script returns from `main`:

* `enriched_trial` – the full parsed and enriched trial dict
* `trial_data` – list of arm-level KG/feature summaries

When run via CLI, it also writes to disk (under `RESULTS_DIR` or `LC_Results/` by default):

* **Pickles**

  * `LC_Results/trial_data_<NCTID>.pkl`
  * `LC_Results/<NCTID>_results.pkl`
    (the same arm-level KG + feature data, written a second time under the file name expected by the original pipeline)

* **JSON**

  * `LC_Results/parsed_trial_<NCTID>.json`

    * Full enriched trial with MedEx, Criteria2Query, outcomes, UMLS/TF-IDF population info, MeSH and drug mappings
  * `LC_Results/<NCTID>_summary.json`

    * Lightweight summary:

      * `nct_id`
      * `num_arms`
      * keys present in the first arm record

These artifacts provide:

* Trial arms with knowledge graph connections
* Arm-level feature vectors for downstream ML models
* Standardized medical concept mappings (UMLS, MeSH, DrugBank/RxNorm, outcome clusters)

### **Step 14: Download Clinical Trial Datasets**

Example:
```bash
# Download fresh trial data from ClinicalTrials.gov
curl -o NCT05576662.json "https://clinicaltrials.gov/api/v2/studies/NCT05576662"
```

### **Step 15: Create PBS Script**

Create `planet_parse.pbs` (for several clinical trials):
```bash
#!/bin/bash
#PBS -P sq95
#PBS -q normal
#PBS -l ncpus=1
#PBS -l mem=48GB
#PBS -l jobfs=1GB
#PBS -l walltime=08:00:00
#PBS -l wd
#PBS -r y
#PBS -N PlaNet_Parse_Array
#PBS -J 1-6

# --- Main Paths ---
SIF=/scratch/sq95/sp6154/planet/planet.sif
WORKDIR=/scratch/sq95/sp6154/planet/parsing_package
STUDY_DIR=$WORKDIR/LC_Clinical_Trials
NLTK_DATA_DIR=$WORKDIR/nltk_data
RESULTS_DIR=$WORKDIR/LC_Results

# Ensure results directory exists
mkdir -p "$RESULTS_DIR"

# --- Load Modules ---
module load singularity
module load java/jdk-17.0.2

echo "==== Job Array ID: $PBS_ARRAY_INDEX ===="
echo "Start time: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"

# --- Use trials.list to map array index -> JSON file ---
LIST_FILE="$STUDY_DIR/trials.list"

# If trials.list doesn't exist, build it from *.json in STUDY_DIR
if [ ! -f "$LIST_FILE" ]; then
    echo "trials.list not found. Building from *.json in $STUDY_DIR"
    ls "$STUDY_DIR"/*.json | xargs -n1 basename | sort > "$LIST_FILE"
fi

# Get JSON name for this array index (1-based)
JSON_NAME=$(sed -n "${PBS_ARRAY_INDEX}p" "$LIST_FILE")

if [ -z "$JSON_NAME" ]; then
    echo "No JSON entry for PBS_ARRAY_INDEX=$PBS_ARRAY_INDEX in $LIST_FILE. Exiting."
    exit 1
fi

json_file="$JSON_NAME"
nct_id=$(basename "$json_file" .json)

echo "Processing index $PBS_ARRAY_INDEX -> $json_file"
echo "Input JSON: $STUDY_DIR/$json_file"
echo "Results will be collected in: $RESULTS_DIR"
echo "----------------------------------------------------"

# --- Run parser for this single trial ---
singularity exec \
    --bind $WORKDIR:/app \
    --bind $STUDY_DIR:/app/studies \
    --bind $NLTK_DATA_DIR:/app/nltk_data \
    --bind $RESULTS_DIR:/app/results \
    --bind /apps/java:/apps/java \
    $SIF \
    /bin/bash -c "\
        export NLTK_DATA=/app/nltk_data && \
        export JAVA_HOME=/apps/java/jdk-17.0.2 && \
        export PATH=\$JAVA_HOME/bin:\$PATH && \
        export RESULTS_DIR=/app/results && \
        cd /app && \
        python /app/parse_trial.py $nct_id && \
        mv trial_data_${nct_id}.pkl /app/results/ 2>/dev/null || true && \
        mv ${nct_id}_results.pkl /app/results/ 2>/dev/null || true && \
        mv parsed_trial_${nct_id}.json /app/results/ 2>/dev/null || true && \
        mv ${nct_id}_summary.json /app/results/ 2>/dev/null || true" \
    2>&1 | grep -v 'Network is unreachable' > "$STUDY_DIR/${nct_id}.log"

echo "Finished processing $nct_id. Log saved to ${nct_id}.log"
echo "End time: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"
```
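
The script above builds `trials.list` on first use inside an array subjob. When several subjobs start at the same time they may all try to write the file, so it can be safer to generate it once before `qsub`. A minimal sketch is shown below; the helper name `write_trials_list` is hypothetical, and it mirrors the `ls ... | xargs -n1 basename | sort` line of the PBS script.

```python
from pathlib import Path

def write_trials_list(study_dir):
    """Write trials.list (one JSON file name per line) and return the names."""
    study_dir = Path(study_dir)
    names = sorted(p.name for p in study_dir.glob("*.json"))
    list_path = study_dir / "trials.list"
    list_path.write_text("".join(f"{name}\n" for name in names), encoding="utf-8")
    return names

if __name__ == "__main__":
    study_dir = Path("LC_Clinical_Trials")
    if study_dir.is_dir():
        names = write_trials_list(study_dir)
        # The array range in the PBS header should match this count.
        print(f"{len(names)} trials -> use '#PBS -J 1-{len(names)}'")
```

Running this first also tells you the array range to request, since each line of `trials.list` corresponds to one `PBS_ARRAY_INDEX`.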

### **Step 16: Submit the Job**
```bash
qsub planet_parse.pbs
```

### **Step 17: Monitor Progress**
```bash
# Check job status
qstat -u $USER

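# Input data (trial JSON files and per-trial logs) are here:
ls LC_Clinical_Trials

# Once completed, check the outputs
cat LC_Results/parsed_trial_NCT04678830.json
cat LC_Results/NCT04678830_summary.json
ls -lh LC_Results/NCT04678830_results.pkl   # binary pickle: list it, do not cat it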
```

### **Step 18: Load Results**

For each trial, the script generates:
- `<NCTID>_results.pkl` and `trial_data_<NCTID>.pkl` - Arm-level results with knowledge graph data (binary pickles)
- `<NCTID>_summary.json` and `parsed_trial_<NCTID>.json` - Human-readable summary and full enriched trial
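
The pickles are not human-readable, but they can be summarised in a few lines. The sketch below assumes the arm-record keys written by `build_trial_arms` (`nct_id`, `arm_label`, `trial_arm_edges`, `trial_attribute_feats_vec`); the helper name `summarize_arms` is hypothetical. Loading the real files also needs NumPy installed, since the feature vector is stored as a NumPy array.

```python
import pickle

def summarize_arms(pkl_path):
    """Load an arm-level pickle and return one small summary dict per arm."""
    with open(pkl_path, "rb") as f:
        # Only unpickle files you produced yourself: pickle can run code on load.
        trial_data = pickle.load(f)
    return [
        {
            "nct_id": arm["nct_id"],
            "arm_label": arm["arm_label"],
            "num_edges": len(arm["trial_arm_edges"]),
            "num_features": len(arm["trial_attribute_feats_vec"]),
            "relations": sorted({edge["relation"] for edge in arm["trial_arm_edges"]}),
        }
        for arm in trial_data
    ]
```

Usage, with the path adjusted to your results folder: `summarize_arms("LC_Results/NCT04678830_results.pkl")`.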

In [None]:
import os
import json
import pprint

# Folder that contains your *_results.pkl and *_summary.json files
folder = '/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/LC_Results/'

# List all files in the folder
files = sorted(os.listdir(folder))

if not files:
    print(f"No files found in '{folder}'.")
else:
    for filename in files:
        file_path = os.path.join(folder, filename)

        # Skip subdirectories, just in case
        if not os.path.isfile(file_path):
            continue

        # Only process pkl and json files
        if not (filename.endswith('.pkl') or filename.endswith('.json')):
            continue

        print("\n" + "=" * 80)
        print(f"Found file: {file_path}")
        print("=" * 80)

        # Skip pickle files (do not attempt to read/unpickle)
        if filename.endswith('.pkl'):
            print("Skipping .pkl file (binary pickle, not printing content).")
            continue

        # Handle JSON files: read and pretty-print
        if filename.endswith('.json'):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                print("--- JSON File Content ---")
                pprint.pprint(data)
            except Exception as e:
                print(f"[ERROR] An error occurred while reading '{file_path}': {e}")


Found file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/LC_Results/NCT03554265_results.pkl
Skipping .pkl file (binary pickle, not printing content).

Found file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/LC_Results/NCT03554265_summary.json
--- JSON File Content ---
{'keys_first_arm': ['arm_idx',
                    'arm_label',
                    'arm_text',
                    'nct_id',
                    'trial_arm_edges',
                    'trial_attribute_feats_vec'],
 'nct_id': 'NCT03554265',
 'num_arms': 3}

Found file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/LC_Results/NCT04678830_results.pkl
Skipping .pkl file (binary pickle, not printing content).

Found file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/LC_Results/NCT04678830_summary.json
--- JSON File Content -

In [None]:
import json
from pathlib import Path

# ========= CONFIG =========
DATA_DIR = Path(r"/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/LC_Results")
OUTPUT_HTML = DATA_DIR / "lc_results_json.html"

# ========= BUILD HTML =========
html_parts = [
    "<!doctype html>",
    "<html>",
    "<head>",
    "  <meta charset='utf-8'>",
    "  <title>LC Results JSON Files</title>",
    "  <style>",
    "    body { font-family: monospace; white-space: pre-wrap; }",
    "    h2 { margin-top: 2em; border-bottom: 1px solid #ccc; }",
    "    pre { background: #f8f8f8; padding: 10px; border-radius: 4px; }",
    "  </style>",
    "</head>",
    "<body>",
    "  <h1>LC Results JSON Files</h1>",
]

for path in sorted(DATA_DIR.glob("*.json")):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Pretty-print JSON
    pretty = json.dumps(data, indent=2)

    # Escape HTML special chars
    pretty_escaped = (
        pretty.replace("&", "&amp;")
              .replace("<", "&lt;")
              .replace(">", "&gt;")
    )

    html_parts.append(f"  <h2>{path.name}</h2>")
    html_parts.append("  <pre>")
    html_parts.append(pretty_escaped)
    html_parts.append("  </pre>")

html_parts.append("</body>")
html_parts.append("</html>")

# ========= WRITE FILE =========
with open(OUTPUT_HTML, "w", encoding="utf-8") as f:
    f.write("\n".join(html_parts))

print(f"HTML written to: {OUTPUT_HTML}")

HTML written to: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/LC_Results/lc_results_json.html


### **Step 19: Check Logs**

In [None]:
from pathlib import Path
from collections import deque

logs_dir = Path("/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/parse/Logs")

def tail(path, n=40):
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return list(deque(f, maxlen=n))

# Choose all files (you can filter later if needed)
log_files = sorted(
    [p for p in logs_dir.iterdir() if p.is_file()],
    key=lambda p: p.stat().st_mtime,
    reverse=True
)

if not log_files:
    print("⚠️ No files found in Logs directory.")
else:
    for lf in log_files:
        print("\n" + "="*80)
        print(f"📄 {lf.name}")
        print("="*80)
        for line in tail(lf, n=30):
            print(line.rstrip())


📄 PlaNet_Parse_Array.o154973436.6
==== Job Array ID: 6 ====
Start time: 2025-11-21T06:32:18Z
Processing index 6 -> NCT05047952.json
Input JSON: /scratch/sq95/sp6154/planet/parsing_package/Input_ALL/NCT05047952.json
Results will be collected in: /scratch/sq95/sp6154/planet/parsing_package/LC_Results
----------------------------------------------------
Finished processing NCT05047952. Log saved to NCT05047952.log
End time: 2025-11-21T06:38:01Z

                  Resource Usage on 2025-11-21 17:38:04:
   Job Id:             154973436[6].gadi-pbs
   Project:            sq95
   Exit Status:        0
   Service Units:      2.29
   NCPUs Requested:    1                      NCPUs Used: 1
                                           CPU Time Used: 00:05:31
   Memory Requested:   48.0GB                Memory Used: 33.43GB
   Walltime requested: 08:00:00            Walltime Used: 00:05:44
   JobFS requested:    1.0GB                  JobFS used: 4.2KB

📄 PlaNet_Parse_Array.o154973436.5
==== Job A
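
The resource-usage block at the end of each PBS log is useful for right-sizing the next submission (here 33.43GB was used of the 48GB requested). A small sketch for pulling those figures out of the log text is shown below. The helper name `parse_usage` is hypothetical, and the patterns only cover the layout and `GB` units seen in the log above; other units would come back as `None`.

```python
import re

USAGE_PATTERNS = {
    "exit_status": r"Exit Status:\s+(-?\d+)",
    "mem_requested_gb": r"Memory Requested:\s+([\d.]+)GB",
    "mem_used_gb": r"Memory Used:\s+([\d.]+)GB",
    "walltime_used": r"Walltime Used:\s+(\d+:\d{2}:\d{2})",
}

def parse_usage(log_text):
    """Return the matched fields as strings, or None where a field is absent."""
    result = {}
    for key, pattern in USAGE_PATTERNS.items():
        match = re.search(pattern, log_text)
        result[key] = match.group(1) if match else None
    return result

sample_log = """
   Exit Status:        0
   NCPUs Requested:    1                      NCPUs Used: 1
   Memory Requested:   48.0GB                Memory Used: 33.43GB
   Walltime requested: 08:00:00            Walltime Used: 00:05:44
"""
print(parse_usage(sample_log))
```

Combined with the `tail` helper above, this can be applied to each log file to spot non-zero exit statuses without reading every log by hand.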

### **Step 20: Interpret Results**

**1. Overview of What the Script Did**

For each trial, the listing cell in Step 18 handled files as follows:

* `*_results.pkl` – **binary result files** (arm-level / graph-level outputs).

  * **Detected but not opened**: the script logs
    `Skipping .pkl file (binary pickle, not printing content).`
* `*_summary.json` – compact summaries with high-level trial metadata.

  * **Loaded and printed**.
* `parsed_trial_*.json` – detailed parsed JSON from ClinicalTrials.gov (arms, outcomes, eligibility, etc.).

  * **Loaded and printed**.

---

**2. File-Level Status by Trial**

| NCT ID      | `*_results.pkl`                       | `*_summary.json`          | `parsed_trial_*.json` |
| ----------- | ------------------------------------- | ------------------------- | --------------------- |
| NCT03554265 | ✅ Found (binary, skipped in this log) | ✅ Loaded (`num_arms = 3`) | ✅ Loaded              |
| NCT04678830 | ✅ Found (binary, skipped in this log) | ✅ Loaded (`num_arms = 2`) | ✅ Loaded              |
| NCT04809974 | ✅ Found (binary, skipped in this log) | ✅ Loaded (`num_arms = 2`) | ✅ Loaded              |
| NCT04871815 | ✅ Found (binary, skipped in this log) | ✅ Loaded (`num_arms = 1`) | ✅ Loaded              |
| NCT04880161 | ✅ Found (binary, skipped in this log) | ✅ Loaded (`num_arms = 2`) | ✅ Loaded              |
| NCT05047952 | ✅ Found (binary, skipped in this log) | ✅ Loaded (`num_arms = 2`) | ✅ Loaded              |
| NCT05633407 | ✅ Found (binary, skipped in this log) | ✅ Loaded (`num_arms = 2`) | ✅ Loaded              |

**Key point:**
The **full trial structure** is available from `parsed_trial_*.json`, and `*_summary.json` adds the arm count and the keys of the arm records. The `.pkl` files are present but **not inspected in this step**; they can be used later by other parts of the pipeline if needed.

---

**3. Common Structure Across the Parsed Trials**

All seven `parsed_trial_*.json` files share a similar schema, including:

* **Core trial metadata**

  * `nct_id`, `brief_summary`, `condition`, `phase`, `enrollment` (count + type).
* **Arms and interventions**

  * `arm_group`: labels (e.g. *Placebo*, *Experimental*), descriptions, and whether they are non-drug or active treatments.
  * `intervention`: drug/biologic name, type, and mapping to arm labels.
* **Outcomes**

  * `primary_outcome` / `secondary_outcome`: each with `measure`, `description`, `timeFrame`, and concept/cluster IDs.
* **Safety data**

  * `clinical_results.reported_events.group_list.group` with, per arm:

    * `deathsNumAffected` / `deathsNumAtRisk`
    * `seriousNumAffected` / `seriousNumAtRisk`
    * `otherNumAffected` / `otherNumAtRisk`
* **Eligibility (free text + UMLS)**

  * `eligibility_criteria`: raw inclusion/exclusion text.
  * `ec_umls`: structured mapping of inclusion/exclusion to UMLS concepts, grouped into:

    * `Condition`, `Demographic`, `Drug`, `Measurement`, `Observation`, `Procedure`.
* **Terminology**

  * `mesh_ids` and `intervention_mesh_terms`: MeSH terms for diseases/conditions and interventions.

This is enough to:

* Compare trials on **symptom domains**, **inclusion/exclusion logic**, and **outcome measures**.
* Extract **safety signals** from reported events (per arm).
* Build higher-level summaries and features **without touching the `.pkl` files**.
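
As an illustration of the safety-signal point, per-arm event rates can be computed directly from the `clinical_results.reported_events.group_list.group` list. The sketch below uses the `otherNumAffected` / `otherNumAtRisk` fields named above; the `id` key and the helper name `other_ae_rates` are assumptions, and the example counts are the ones reported for NCT03554265 in section 4.1, read as affected / at risk.

```python
def other_ae_rates(groups):
    """Map each event-group id to otherNumAffected / otherNumAtRisk (None if nobody at risk)."""
    rates = {}
    for group in groups:
        affected = int(group.get("otherNumAffected") or 0)
        at_risk = int(group.get("otherNumAtRisk") or 0)
        rates[group["id"]] = affected / at_risk if at_risk else None
    return rates

# Counts as listed for NCT03554265 (EG000 mTBI, EG001 controls, EG002 PASC).
groups = [
    {"id": "EG000", "otherNumAffected": 26, "otherNumAtRisk": 28},
    {"id": "EG001", "otherNumAffected": 0, "otherNumAtRisk": 27},
    {"id": "EG002", "otherNumAffected": 17, "otherNumAtRisk": 17},
]
for group_id, rate in other_ae_rates(groups).items():
    print(group_id, None if rate is None else f"{rate:.0%}")
```

The same function can be pointed at the `serious*` fields by changing the two key names.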

---

**4. Trial Summaries**

**4.1 NCT03554265 – Somatropin (rhGH) in mTBI and PASC**

**Basic info**

* **Intervention**: Recombinant human growth hormone (somatropin, Genotropin).
* **Conditions**: Traumatic brain injury, fatigue, cognitive impairment, **COVID-19/PASC**.
* **Phase**: Phase 3.
* **Enrollment**: 72 (actual).
* **Arms (3)**:

  * *mTBI subjects* – rhGH daily for 6 months.
  * *PASC subjects* – rhGH daily for 9 months.
  * *Household/community controls* – no intervention.

**Population & inclusion (high level)**

* Adults 18–70 years.
* Separate inclusion/exclusion sets for mTBI, PASC, and controls.
* PASC arm: prior confirmed COVID-19, ≥ 6 months since diagnosis, clinically significant fatigue (BFI score ≥ 3).

**Key exclusion themes**

* Significant heart, liver, kidney, blood, or respiratory disease.
* Uncontrolled diabetes, recent cancers, inflammatory bowel disease / Celiac / diverticular disease.
* Current alcohol/drug abuse, significant psychiatric history, pregnancy.
* Recent anabolic steroids, corticosteroids, or drugs that confound outcomes.

**Main outcomes**

* **Primary**: Lean body mass and fat mass (DEXA) at baseline and 6 months.
* **Secondary**: Resting energy expenditure (metabolic rate) at baseline and 6 months.

**Reported events**

* **mTBI subjects (EG000)**: 0 deaths / 0 serious AEs; `otherNumAffected` 26 / 28.
* **Community controls (EG001)**: 0 deaths / 0 serious AEs; `otherNumAffected` 0 / 27.
* **PASC subjects (EG002)**: 0 deaths / 0 serious AEs; `otherNumAffected` 17 / 17.

**Interpretation**

* Non-serious AEs are **very frequent** in the treated groups (26 of 28 mTBI subjects, about 93%, and all 17 PASC subjects) and absent in the 27 controls.
* This trial is important for **body composition / metabolic** effects of GH in PASC, and for your AE vs benefit mapping.

---

**4.2 NCT04678830 – Leronlimab for Long COVID Symptoms**

**Basic info**

* **Intervention**: Leronlimab (PRO 140), a humanized IgG4 monoclonal antibody targeting **CCR5**.
* **Condition**: Prolonged COVID-19 symptoms (> 12 weeks).
* **Phase**: Phase 2.
* **Enrollment**: 56 (actual).
* **Arms (2)**:

  * *700 mg Leronlimab* – weekly subcutaneous injections.
  * *Placebo* – saline syringes.

**Population & inclusion**

* Adults ≥ 18 years with prior confirmed COVID-19.
* Symptom score ≥ 6 and ≥ 2 symptoms of moderate+ severity.
* Symptoms persisting > 12 weeks across respiratory, neurological, CV/GI, musculoskeletal, and general “immune response” domains.

**Key exclusion themes**

* Moderate/severe pulmonary disease (COPD, pulmonary fibrosis), NYHA III–IV heart failure.
* Significant liver/kidney disease, serious systemic illnesses.
* Major psychiatric disorders (bipolar, schizophrenia, uncontrolled major depression).
* Pre-existing CFS or fibromyalgia.
* Recent immunosuppressive / immunomodulatory therapy above allowed thresholds.

**Primary outcome**

* **Change in daily COVID-19-related symptom severity score through Day 56**:

  * ~24 symptom items, most scored 0–3 and two scored 0–2; max ≈ 70, min 0.
  * **Negative change = improvement**; more negative = greater improvement.

**Selected secondary outcomes**

* Duration of symptoms, symptom-free days, symptom progression.
* **PROMIS Fatigue**, **PROMIS Cognitive Function**, **PROMIS Sleep Disturbance**.
* Hospitalisation count and duration.

**Reported events**

* **700 mg Leronlimab (EG000)**: 0 deaths, 0 serious AEs, `otherNumAffected` 22 / 28.
* **Placebo (EG001)**: 0 deaths, 1 serious AE, `otherNumAffected` 20 / 28.

**Interpretation**

* Non-serious AEs are common in both groups; one serious AE reported in the placebo arm in this dataset.
* Rich **multi-symptom diary + PROMIS** structure makes this trial central for your symptom-based benefit scoring.

---

**4.3 NCT04809974 – Niagen (Nicotinamide Riboside) for Cognitive Long COVID**

**Basic info**

* **Intervention**: Niagen (nicotinamide riboside, NR) vs placebo.
* **Conditions**: COVID-19 with persistent cognitive symptoms (“brain fog”) and other neurological/physical sequelae.
* **Phase**: Phase 4.
* **Enrollment**: 72 (actual).
* **Arms / sequence**:

  * Two-week **placebo lead-in** for all participants.
  * Randomised phase:

    * *Niagen* – 2000 mg/day in capsules.
    * *Placebo* – matching capsules, with later cross-over to Niagen.

**Population & inclusion**

* History of PCR-confirmed SARS-CoV-2 ≥ 2 months prior.
* SARS-CoV-2 PCR negative at entry.
* Persistent cognitive difficulties (esp. brain fog) plus ≥ 2 ongoing neurological/physical symptoms (fatigue, weakness, loss of smell, SOB, palpitations, musculoskeletal pain, etc.).
* Not pregnant or lactating, able to complete assessments.

**Key exclusion themes**

* Major central nervous system diseases (stroke, brain tumour, NPH, etc.).
* Clinically significant unstable medical conditions.
* Intubation during acute COVID-19.
* Major active/chronic unstable psychiatric illness within the last year.
* Alcohol/substance abuse within 2 years.
* Psychoactive medications likely to worsen cognition.
* Known hypersensitivity to NR or its metabolite NMN.
* Recent investigational agents or MRI contraindication (for MRI sub-study).

**Primary outcome**

* **Effect of Niagen on NAD⁺ and cognitive functioning**:

  * ECog, RBANS, and TMT-B:

    * ECog: higher = worse subjective cognition.
    * RBANS: higher = better cognitive performance.
    * TMT-B: longer time = worse performance.
  * Change scores from Baseline → Week 10 → Week 20.
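
Because the three primary scales point in different directions, change scores need a common sign before they can be compared. A minimal sketch of that alignment, using only the directions stated above; the helper name `benefit_change` and the example values are illustrative, not trial data:

```python
# Direction of benefit for the three primary scales, as stated above.
# +1: higher score is better; -1: higher score (or longer time) is worse.
DIRECTION = {"ECog": -1, "RBANS": +1, "TMT-B": -1}

def benefit_change(scale, baseline, follow_up):
    """Signed change in which a positive value always means improvement."""
    return DIRECTION[scale] * (follow_up - baseline)

# Made-up illustrative values, not trial data.
print(benefit_change("RBANS", 85.0, 92.0))   # 7.0  (score rose: improvement)
print(benefit_change("TMT-B", 95.0, 80.0))   # 15.0 (time fell: improvement)
print(benefit_change("ECog", 2.0, 2.5))      # -0.5 (score rose: worsening)
```

The same lookup can be extended with other scales in this section (for example COMPASS-31 or MaPS, where a decrease is the improvement).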

**Selected secondary outcomes**

* **Depression**: Beck Depression Inventory (BDI).
* **Anxiety**: Beck Anxiety Inventory (BAI).
* **Other COVID-related symptoms**: Fatigue Severity Scale (FSS) and Pittsburgh Sleep Quality Index (PSQI).

**Reported events**

* **Lead-in phase (EG000)**: 0 deaths, 0 serious AEs, `otherNumAffected` 1 / 72.
* **Placebo (EG001)**: 0 deaths, 0 serious AEs, `otherNumAffected` 19 / 21.
* **Niagen (EG002)**: 0 deaths, 1 serious AE, `otherNumAffected` 25 / 37.

**Interpretation**

* Non-serious AEs are common in both active and placebo arms; one serious AE recorded in the Niagen arm.
* This trial is valuable for **cognitive, mood, fatigue, and sleep outcomes** under a NAD⁺-targeting intervention.

---

**4.4 NCT04871815 – Sodium Pyruvate Nasal Spray in Long COVID**

**Basic info**

* **Intervention**: Sodium pyruvate nasal spray (N115) – single-arm study.
* **Condition**: Long COVID (“Long Haulers”) with respiratory and systemic symptoms.
* **Phase**: Phase 2/3.
* **Enrollment**: 22 (actual).
* **Arm (1)**:

  * *Treatment of Long Covid with sodium pyruvate nasal spray* – open-label.

**Population & inclusion**

* Prior confirmed positive COVID-19 test plus **lingering symptoms** consistent with CDC long COVID list (fatigue, brain fog, headache, loss of smell/taste, palpitations, SOB, chest pain, musculoskeletal pain, depression, anxiety, post-exertional symptom worsening).

**Key exclusion themes**

* Viral infections other than COVID-19.
* Significant cardiac disease (including uncontrolled CHF, unstable angina).
* Pregnancy; inadequate contraception; lactation.
* Systemic corticosteroids within 1 month.
* Recent hospitalisation for airway disease; significant abnormal chest X-ray.
* Recent medication changes or investigational drug use.
* Current alcohol or recreational drug abuse.
* Recent dietary supplements containing pyruvate.

**Primary outcomes**

* **Change in symptom score over 14 days**:

  * Likert 0–10 per symptom; daily logs.
  * Total 7-day baseline score (days 1–7) vs 7-day treatment score (days 8–14).
* **Physiological parameters**:

  * Body temperature.
  * Pulse rate.
  * Blood oxygenation (SaO₂).
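
The symptom-score comparison above (7-day baseline total vs 7-day treatment total) can be sketched as follows. This assumes one list of per-symptom Likert scores per day and that higher values mean worse symptoms; the function name and the example log are hypothetical:

```python
def weekly_totals(daily_scores):
    """daily_scores: 14 lists (days 1-14), each holding 0-10 Likert scores per symptom."""
    if len(daily_scores) != 14:
        raise ValueError("expected 14 daily logs")
    day_totals = [sum(day) for day in daily_scores]
    baseline = sum(day_totals[:7])    # days 1-7, before treatment
    treatment = sum(day_totals[7:])   # days 8-14, on treatment
    return baseline, treatment, treatment - baseline

# Made-up log for three symptoms; not trial data.
log = [[6, 5, 4]] * 7 + [[4, 3, 3]] * 7
baseline, treatment, change = weekly_totals(log)
print(baseline, treatment, change)   # 105 70 -35
```

Under the stated assumption, a negative change indicates a lower symptom burden during the treatment week.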

**Reported events**

* **Treatment phase (EG000)**: 0 deaths, 0 serious AEs, `otherNumAffected` 0 / 22.
* **Baseline / no treatment (EG001)**: 0 deaths, 0 serious AEs, `otherNumAffected` 0 / 22.

**Interpretation**

* No AEs reported in this dataset; purely descriptive for symptom and vital-sign changes pre-/post- N115.
* Useful as a **single-arm longitudinal** example for symptom trajectory features.

---

**4.5 NCT04880161 – Inhaled Ampion for Respiratory Long COVID**

**Basic info**

* **Intervention**: Inhaled **Ampion** vs inhaled saline placebo.
* **Condition**: Long COVID with prolonged respiratory symptoms.
* **Phase**: Phase 1.
* **Enrollment**: 32 (actual).
* **Arms (2)**:

  * *Active*: Inhaled nebulized Ampion (8 mL) four times daily for 5 days.
  * *Control*: Inhaled nebulized saline (8 mL) four times daily for 5 days.

**Population & inclusion**

* Adults ≥ 18 years.
* COVID-19 clinical diagnosis ≥ 4 weeks prior (symptoms + RT-PCR/equivalent).
* ≥ 2 respiratory symptoms (cough, sore throat, runny/stuffy nose, SOB, chest tightness, low exercise tolerance) with score ≥ 2 for ≥ 4 weeks.

**Key exclusion themes**

* Need for hospitalisation.
* Severe COPD/restrictive lung disease, chronic renal failure, significant liver disease.
* Pre-existing CFS.
* Chronic immunosuppressive medication.
* Major surgery expected within the study window.
* Allergy to human albumin or excipients of 5% human albumin.
* Pregnancy/breastfeeding, major ECG abnormalities, severe comorbidity limiting outcome assessments.

**Primary outcome**

* **Number of participants with treatment-emergent adverse events (TEAEs) and serious AEs (SAEs)** from baseline to Day 28.

**Reported events**

* **Ampion (EG000)**: 0 deaths, 0 serious AEs, `otherNumAffected` 8 / 15.
* **Placebo (EG001)**: 0 deaths, 0 serious AEs, `otherNumAffected` 9 / 16.

**Interpretation**

* No deaths or SAEs in either arm; non-serious AEs in roughly half of participants in both groups.
* A clean example of a primarily **safety-focused** Long COVID respiratory trial.

---

**4.6 NCT05047952 – Vortioxetine for Cognitive Impairment in Post-COVID**

**Basic info**

* **Intervention**: **Vortioxetine** vs placebo.
* **Conditions**: Post-COVID-19 condition/syndrome with **cognitive impairment**.
* **Phase**: Phase 2.
* **Enrollment**: 149 (actual).
* **Arms (2)**:

  * *Vortioxetine*:

    * 18–64 years: 10 → 20 mg/day.
    * 65+ years: 5 → 10 mg/day.
  * *Placebo*: once-daily matching pill for 8 weeks.

**Population & inclusion**

* Age ≥ 18.
* Meets **WHO post-COVID-19 condition definition**, with symptoms ≥ 3 months from onset, lasting ≥ 2 months, and impacting daily function.
* Documented SARS-CoV-2 infection (PCR/antigen or clinical diagnosis).
* Subjective cognitive complaints (PDQ-5).
* Eligibility confirmed > 12 weeks from acute onset or positive test.

**Key exclusion themes**

* Cognitive symptoms fully explained by major depressive or bipolar disorder.
* Major neurocognitive disorder, schizophrenia, CFS, encephalitis/meningitis.
* Current alcohol/substance use disorder.
* Primary comorbid psychiatric diagnosis (per MINI 7.0.2).
* Medications that substantially affect cognition (psychostimulants, certain antidepressants, benzodiazepines near testing).
* Significant neurological disease (moderate/severe head trauma, uncontrolled epilepsy).
* Pregnancy/breastfeeding, high suicide risk, recent ECT.
* Concomitant MAOIs, linezolid, IV methylene blue.
* Prior intolerance or inefficacy of vortioxetine.

**Primary outcome**

* **Least squares mean change in DSST z-score from baseline to Week 8**:

  * Higher LS mean change = better adjusted cognitive performance.
  * LS mean ≈ 0 = no average change; values can be positive or negative.

**Selected secondary outcomes**

* Change in **WHO-5 Wellbeing**.
* Change in **QIDS-SR-16** (depression severity):

  * Negative LS mean = improvement in depressive symptoms.

**Reported events**

* **Vortioxetine (EG000)**: 0 deaths, 0 serious AEs, `otherNumAffected` 47 / 75.
* **Placebo (EG001)**: 0 deaths, 0 serious AEs, `otherNumAffected` 34 / 74.

**Interpretation**

* Non-serious AEs are more frequent in the vortioxetine arm, consistent with systemic antidepressant exposure.
* Very rich **cognition + mood** outcome structure, ideal for your efficacy side of the benefit–risk scoring.

---

**4.7 NCT05633407 – Efgartigimod for Post-COVID POTS**

**Basic info**

* **Intervention**: **Efgartigimod** IV 10 mg/kg vs placebo.
* **Conditions**: Post-COVID-19 **postural orthostatic tachycardia syndrome (POTS)**; post-acute COVID-19 syndrome is explicitly tagged in MeSH.
* **Phase**: Phase 2.
* **Enrollment**: 53 (actual).
* **Arms (2)**:

  * *Efgartigimod* – IV 10 mg/kg once weekly for 24 weeks.
  * *Placebo* – IV placebo once weekly for 24 weeks.
  * Randomisation is **2:1** (efgartigimod : placebo).

**Population & inclusion**

Key inclusion themes (new-onset **post-COVID POTS**):

* Adults (≥ 18 years), all sexes (gender = ALL).

* **History of COVID-19**, based on a previous positive PCR or rapid antigen test (documented or participant-reported).

* **New-onset POTS post-COVID-19**, confirmed by tilt-table or orthostatic vital signs consistent with consensus criteria:

  * Sustained HR increase ≥ 30 bpm within 10 min of standing or head-up tilt (≥ 40 bpm for ages 18–19) and/or HR > 120 bpm within 10 min.
  * **No** sustained ≥ 20 mmHg drop in systolic BP (i.e. excludes classical orthostatic hypotension).

* **Persistent POTS symptoms > 12 weeks** after COVID-19 diagnosis or hospital discharge, with at least 3 symptoms across:

  * **Vasomotor / orthostatic symptoms** – fatigue, orthostatic intolerance, brain fog, exertional dyspnoea, difficulty concentrating, venous pooling, exercise intolerance.
  * **Sympathetic over-compensation symptoms** – palpitations, heat intolerance, nausea ± vomiting, insomnia, anxiety, poor appetite, chest pain, diaphoresis.

* **COMPASS-31 ≥ 35** at screening, indicating clinically significant autonomic symptom burden.

* **BMI < 35 kg/m²**.

* Standard contraception rules: females of child-bearing potential require negative pregnancy tests and effective contraception; males have no specific contraception requirement beyond local regulations.
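
The heart-rate and blood-pressure part of the POTS definition above lends itself to a simple screening rule. A rough sketch under simplifying assumptions: "sustained" is taken as already resolved by whoever supplies the inputs, and the function name and arguments are hypothetical rather than fields from the trial record:

```python
def meets_pots_hr_criteria(age, supine_hr, peak_standing_hr, max_sbp_drop):
    """Screen one orthostatic test against the criteria listed above.

    peak_standing_hr: highest sustained HR within 10 min of standing or tilt.
    max_sbp_drop: largest sustained fall in systolic BP (mmHg), positive = drop.
    """
    threshold = 40 if age <= 19 else 30       # larger rise required for ages 18-19
    hr_rise = peak_standing_hr - supine_hr
    tachycardia = hr_rise >= threshold or peak_standing_hr > 120
    no_orthostatic_hypotension = max_sbp_drop < 20
    return age >= 18 and tachycardia and no_orthostatic_hypotension

print(meets_pots_hr_criteria(25, 70, 105, 5))    # True: rise of 35 bpm, stable BP
print(meets_pots_hr_criteria(19, 70, 105, 5))    # False: rise below 40 bpm at age 19
print(meets_pots_hr_criteria(30, 70, 105, 25))   # False: systolic drop of 25 mmHg
```

This covers only the haemodynamic criteria; symptom duration, COMPASS-31, and BMI thresholds would be separate checks.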

**Key exclusion themes**

Pre-existing autonomic/neurological disorders and major comorbidities are carefully excluded, including:

* Pre-COVID diagnosis or treatment for: peripheral neuropathy, POTS, myalgic encephalomyelitis/chronic fatigue syndrome, genetically confirmed Ehlers-Danlos syndrome, autonomic neuropathy, multiple sclerosis, stroke, spinal cord injury, or other major CNS lesions on imaging or exam.

* Significant **cardio-pulmonary disease**, such as ongoing cardiac arrhythmias, heart failure, myocarditis, pulmonary embolism requiring anticoagulation, pulmonary fibrosis, or critical-illness related polyneuropathy/myopathy.

* **Autoimmune disease** that could confound symptom assessment or increase risk (investigator judgement).

* **Primary immunodeficiency** (e.g. common variable immunodeficiency) or **HIV**.

* Malignancy unless adequately treated with ≥ 3 years without recurrence, except allowed in situ / early stage cancers (e.g. basal/squamous cell skin cancer, cervical/breast carcinoma in situ, incidental low-stage prostate cancer).

* Clinically significant uncontrolled active or chronic **bacterial, viral, or fungal infections**, or positive SARS-CoV-2 PCR at screening.

* Positive screening tests indicating active **HBV**, **HCV**, or **HIV** infection.

* **Total IgG < 4 g/L** at screening (important for a drug targeting the IgG pathway).

* Recent or conflicting therapies:

  * Investigational products within 12 weeks or 5 half-lives (whichever longer).
  * IV/SC immunoglobulin or plasmapheresis/plasma exchange within 12 weeks.
  * Live or live-attenuated vaccines within 4 weeks.

* Known **hypersensitivity** to efgartigimod or excipients; prior exposure to efgartigimod in earlier trials.

* Participation in other interventional studies.

* Alcohol or drug abuse within 12 months.

* Pregnancy, lactation, or plans to become pregnant during the study.

* Inability/unwillingness to remain on a stable medication regimen or to avoid initiating new rehabilitation/exercise programmes during the 24-week treatment period.

**Primary outcomes**

The trial has a mixed efficacy + safety primary focus:

1. **Change from Baseline to Week 24 in COMPASS-31 (2-week recall version)**

   * Composite autonomic symptom score (31 questions, 6 weighted domains).
   * Total score ranges **0–100**, with higher scores = **more severe autonomic symptoms**.
   * Improvement = **decrease** in COMPASS-31 score.

2. **Change from Baseline to Week 24 in the Malmö POTS Symptom Score (MaPS)**

   * POTS-specific symptom scale (12 items, 0–10 VAS per item; total 0–120).
   * Captures both orthostatic and non-orthostatic symptoms (e.g. GI, insomnia, concentration difficulties).
   * Higher MaPS score = more severe symptom burden; improvement = **decrease** in MaPS.

3. **Number of participants with TEAEs and TESAEs**

   * TEAEs = adverse events emerging after first dose up to 60 days after last dose.
   * TESAEs = serious treatment-emergent adverse events (death, life-threatening, hospitalisation, disability, congenital anomaly, or other medically important events).
   * Time window: Day 1 through 60 days post last dose (up to 236 days total).

**Selected secondary outcomes**

* **Global patient-reported severity and change**:

  * **PGI-S** (Patient Global Impression–Severity): 4-point scale (none → severe).

    * “Improved PGI-S” defined as a change from baseline of –1, –2, or –3 (lower = less severe).
  * **PGI-C** (Patient Global Impression–Change): 7-point scale (much better → much worse).

    * “Improved PGI-C” defined as categories 1–3 (better to much better).

* **PROMIS Fatigue Short Form 8a**

  * 8-item scale (1–5 per item; 8–40 total, higher = more fatigue).
  * Converted to T-scores (mean 50, SD 10); **decrease in T-score** = improved fatigue.

* **PROMIS Cognitive Function Short Form 6a**

  * 6 items about perceived cognitive difficulties (concentration, memory, mental acuity).
  * Raw score 6–30; higher raw score = worse perceived cognition, but T-score interpretation is specified such that **increase in T-score** = better cognitive function.

* **Immunology / PK / immunogenicity**:

  * **Percent change in total IgG** from baseline to Week 24.
  * **Serum concentration of efgartigimod** at multiple timepoints (PK profile).
  * **Number of participants with anti-drug antibodies (ADAs)** (treatment-induced or treatment-boosted).
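
The PGI-S and PGI-C responder definitions above reduce to two small predicates. A minimal sketch, assuming PGI-S is coded as consecutive integers from none to severe and PGI-C as 1 (much better) to 7 (much worse); the function names are illustrative:

```python
def pgi_s_improved(baseline, follow_up):
    """PGI-S responder: severity category dropped by 1, 2, or 3 from baseline."""
    return (follow_up - baseline) in (-1, -2, -3)

def pgi_c_improved(category):
    """PGI-C responder: category 1-3 on the 7-point change scale."""
    return category in (1, 2, 3)

print(pgi_s_improved(3, 2))   # True: one category less severe
print(pgi_s_improved(2, 2))   # False: unchanged
print(pgi_c_improved(4))      # False: category 4 is outside the improved range
```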

**Reported events**

From `clinical_results.reported_events.group_list.group`:

* **Efgartigimod (EG000)**:

  * Deaths: 0 / 36
  * Serious AEs: 0 / 36
  * Other (non-serious) AEs: 31 / 36

* **Placebo (EG001)**:

  * Deaths: 0 / 17
  * Serious AEs: 0 / 17
  * Other (non-serious) AEs: 14 / 17

**Interpretation**

* No deaths or serious AEs reported in either arm; **non-serious AEs are very common** in both groups (≈ 86% for efgartigimod, ≈ 82% for placebo), which is consistent with a sick, highly symptomatic population undergoing intensive IV treatment and close monitoring.

* This trial is particularly important because it:

  * Directly targets **post-COVID POTS**, a defined autonomic phenotype within the broader Long COVID spectrum.
  * Uses **autonomic symptom scales (COMPASS-31, MaPS)** as primary readouts, providing high-resolution data on orthostatic and autonomic symptom domains.
  * Includes **fatigue and cognitive PROs (PROMIS Fatigue and Cognitive Function)**, making it comparable to Niagen and vortioxetine trials on the symptom-cognition axis.
  * Adds **mechanistic immunology endpoints** (IgG reduction, PK, ADAs), which link the clinical readouts to the drug’s FcRn-targeting mechanism.

* For your PlaNet-based benefit–risk analysis, NCT05633407 contributes:

  * A **distinct autonomic/vascular Long COVID phenotype (post-COVID POTS)** with its own eligibility and comorbidity filters in `ec_umls`.
  * Detailed **symptom severity changes** across autonomic, fatigue, and cognitive domains.
  * A robust **safety profile** dominated by non-serious AEs, with explicit TEAE/TESAE definitions and IgG-related constraints in the eligibility criteria.
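
The per-arm counts in sections 4.1–4.7 can be put on a common footing as non-serious AE rates. A minimal sketch with the counts transcribed from the "Reported events" blocks above; the dictionary layout and the helper `ae_rate` are illustrative and not part of the PlaNet pipeline:

```python
# Non-serious AE counts (affected, at risk) per arm, transcribed from
# the "Reported events" blocks in sections 4.1-4.7.
OTHER_AE = {
    "NCT03554265": {"mTBI (rhGH)": (26, 28), "Controls": (0, 27), "PASC (rhGH)": (17, 17)},
    "NCT04678830": {"Leronlimab": (22, 28), "Placebo": (20, 28)},
    "NCT04809974": {"Lead-in": (1, 72), "Placebo": (19, 21), "Niagen": (25, 37)},
    "NCT04871815": {"Treatment": (0, 22), "Baseline": (0, 22)},
    "NCT04880161": {"Ampion": (8, 15), "Placebo": (9, 16)},
    "NCT05047952": {"Vortioxetine": (47, 75), "Placebo": (34, 74)},
    "NCT05633407": {"Efgartigimod": (31, 36), "Placebo": (14, 17)},
}

def ae_rate(affected, at_risk):
    """Fraction of participants with at least one non-serious AE."""
    return affected / at_risk if at_risk else float("nan")

for nct, arms in OTHER_AE.items():
    for arm, (affected, at_risk) in arms.items():
        print(f"{nct}  {arm:<14} {affected:>2}/{at_risk:<2}  {ae_rate(affected, at_risk):.1%}")
```

These rates are descriptive only. Several arms are small, and follow-up windows differ between trials, so differences between arms should not be read as statistical comparisons.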

---

**5. `ec_umls`: Eligibility Criteria as Structured Concepts (High Level)**

Each `parsed_trial_*.json` (including NCT05633407) includes an `ec_umls` block that:

* Splits eligibility into **inclusion** and **exclusion** lists.
* Groups concepts by domain:

  * `Condition`, `Demographic`, `Drug`, `Measurement`, `Observation`, `Procedure`.
* Stores original phrases plus mapped **UMLS concepts**, which enables:

  * Programmatic filtering of trials by comorbidities (e.g. CFS, COPD, psychiatric disorders), drugs (e.g. immunosuppressants, corticosteroids, IVIG), or measurements (e.g. SpO₂, symptom scores, COMPASS-31 thresholds).
  * Comparative analysis of **eligibility profiles** between trials (e.g. which trials exclude CFS, severe cardiac disease, or specific immunodeficiencies).
  * Construction of higher-level features such as “cardiometabolic stringency”, “psychiatric exclusion strictness”, or “immunosuppression exclusion”.

These structured criteria, combined with the outcome and AE summaries above, give you everything you need to link **drug–indication–eligibility–outcome–safety** in the PlaNet-based analysis, without relying on the `.pkl` contents at this interpretation step.
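
As an illustration of the programmatic filtering described above, the sketch below finds trials whose exclusion criteria mention a given term. The exact JSON layout of `ec_umls` is not shown in this document, so the structure used here (inclusion/exclusion, then domain, then a list of entries with a `phrase` and a `cui` field) is an assumption, as are the field names and both helper functions:

```python
def excluded_phrases(trial, domain):
    """Original phrases recorded under one domain of the exclusion list."""
    entries = trial.get("ec_umls", {}).get("exclusion", {}).get(domain, [])
    return [entry["phrase"] for entry in entries]

def trials_excluding(trials, domain, keyword):
    """NCT IDs whose exclusion criteria mention `keyword` in the given domain."""
    kw = keyword.lower()
    return [t["nct_id"] for t in trials
            if any(kw in phrase.lower() for phrase in excluded_phrases(t, domain))]

# Toy records in the ASSUMED layout; phrases paraphrase the summaries above
# and the UMLS identifiers are left unset rather than invented.
toy_trials = [
    {"nct_id": "NCT04678830",
     "ec_umls": {"exclusion": {"Condition": [{"phrase": "chronic fatigue syndrome", "cui": None}]}}},
    {"nct_id": "NCT04871815",
     "ec_umls": {"exclusion": {"Condition": [{"phrase": "unstable angina", "cui": None}]}}},
]
print(trials_excluding(toy_trials, "Condition", "chronic fatigue"))  # ['NCT04678830']
```

Matching on mapped UMLS concepts rather than on phrases would be more robust to wording differences; the key names would need to be adapted to the real `parsed_trial_*.json` files.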

## **PREDICT AE/S/E**

### **Step 1: Prepare .def file**

```bash
Bootstrap: docker
From: continuumio/miniconda3:4.12.0

%files
    planet_env.tar.gz /planet_env.tar.gz

%post
    set -e
    echo "📦 Installing system dependencies..."
    apt-get update \
      && apt-get install -y --no-install-recommends libstdc++6 \
      && apt-get clean \
      && rm -rf /var/lib/apt/lists/*

    echo "🗜️ Unpacking conda env…"
    mkdir -p /opt/conda/envs/planet
    tar -xzf /planet_env.tar.gz -C /opt/conda/envs/planet

    echo "🔧 Running conda‑unpack to fix prefixes…"
    /opt/conda/envs/planet/bin/conda-unpack

    echo "🧹 Cleaning up…"
    rm /planet_env.tar.gz

    # Prepare non‑interactive activation
    echo 'source /opt/conda/etc/profile.d/conda.sh' >> /environment
    # Activate by path to avoid name‑mismatch
    echo 'conda activate /opt/conda/envs/planet' >> /environment
    chmod +x /environment

%environment
    # Every container run will source+activate your env
    source /opt/conda/etc/profile.d/conda.sh
    conda activate /opt/conda/envs/planet
    export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH"

%labels
    Author "Sindy Piñero"
    Version "planet-cpu-2025-07-24-final"

%help
    Final PlaNet Singularity container. Includes all dependencies for parsing and prediction.
```

### **Step 2: Create .sif file**

Prerequisites:

- Access to the GADI HPC system
- Access to the PlaNet data and the Singularity definition file from Step 1

```bash
singularity build --fakeroot planet_2.sif planet_2.def
```

### **Step 3: Create .pbs file**

```bash
#!/bin/bash
#PBS -P sq95
#PBS -q normal
#PBS -l ncpus=24
#PBS -l mem=64GB
#PBS -l jobfs=1GB
#PBS -l walltime=20:00:00
#PBS -l wd
#PBS -M pinsy007@mymail.unisa.edu.au
#PBS -m abe
#PBS -N PlaNet_Predict_Many

# ── Paths (host) ────────────────────────────────────────────────────────
PLANET_ROOT=/scratch/sq95/sp6154/planet
NOTEBOOKS_DIR=$PLANET_ROOT/notebooks
PARSE_DIR=$PLANET_ROOT/parsing_package
# NOTE: Step 2 builds planet_2.sif; make sure the filename below matches the image you built.
SIF=$PLANET_ROOT/planet_v2.sif
BERT_MODEL_DIR=$PLANET_ROOT/bert_model

# ── Trials to run (edit this list as needed) ────────────────────────────
TRIALS=(
  NCT04678830
  NCT04880161
  NCT05047952
  NCT05633407
)

# Log + results (host)
LOG_DIR="$PARSE_DIR/logs"
RESULTS_DIR="$PARSE_DIR/results"

# ── Load modules & env fixes ────────────────────────────────────────────
module load singularity
export SINGULARITYENV_LD_LIBRARY_PATH="/opt/conda/envs/planet/lib:${LD_LIBRARY_PATH:-}"
export SINGULARITYENV_TRANSFORMERS_CACHE=$BERT_MODEL_DIR
export SINGULARITYENV_HF_HOME=$BERT_MODEL_DIR
export SINGULARITYENV_TRANSFORMERS_OFFLINE=1

# ── Prep output dirs ────────────────────────────────────────────────────
mkdir -p "$RESULTS_DIR" "$LOG_DIR"

# ── Debug info (once) ──────────────────────────────────────────────────
MASTER_LOG="$LOG_DIR/PlaNet_Predict_master.log"
: > "$MASTER_LOG"
echo "[$(date)] Starting multi-trial prediction job" | tee -a "$MASTER_LOG"
echo "  notebooks/data → $(readlink "$NOTEBOOKS_DIR/data" 2>/dev/null)" | tee -a "$MASTER_LOG"
echo "  Host models under parsing_package/data/models:" | tee -a "$MASTER_LOG"
ls -1 "$PARSE_DIR/data/models" 2>&1 | tee -a "$MASTER_LOG"

# ── Loop over trials ───────────────────────────────────────────────────
for TRIAL in "${TRIALS[@]}"; do
  echo "[$(date)] ===== Trial: $TRIAL =====" | tee -a "$MASTER_LOG"

  # Relative + host pkl path
  INPUT_PKL_REL="LC_Results/${TRIAL}_results.pkl"
  INPUT_PKL="$PARSE_DIR/$INPUT_PKL_REL"    # /scratch/.../parsing_package/LC_Results/TRIAL_results.pkl

  # Sanity check: file exists
  if [[ ! -f "$INPUT_PKL" ]]; then
    echo "[$(date)] WARNING: Missing file $INPUT_PKL – skipping." | tee -a "$MASTER_LOG"
    continue
  fi

  # Trial-specific basename, log and expected JSON
  BASENAME=$(basename "$INPUT_PKL" .pkl)   # e.g. NCT04678830_results
  LOG="$LOG_DIR/PlaNet_Predict_${BASENAME}.log"
  JSON_OUT="$RESULTS_DIR/result_${BASENAME}.json"

  # Init log for this trial
  : > "$LOG"
  echo "[$(date)] Starting prediction for $INPUT_PKL" | tee -a "$LOG"

  # Run inside container
  echo "[$(date)] Running prediction in Singularity..." | tee -a "$LOG"
  singularity exec \
    --bind "$PLANET_ROOT":/planet \
    --bind "$PARSE_DIR/data":/planet/parsing_package/data \
    --bind "$PARSE_DIR/data":/planet/data \
    --bind "$BERT_MODEL_DIR":/bert_model \
    --pwd /planet/notebooks \
    "$SIF" \
      /opt/conda/envs/planet/bin/python \
        /planet/parsing_package/predict_all_for_new_clinial_trial.py \
          --pklpath "/planet/parsing_package/$INPUT_PKL_REL" \
          --output-dir "/planet/parsing_package/results" \
    >> "$LOG" 2>&1

  STATUS=$?   # exit code of the singularity exec call above
  echo "[$(date)] Prediction for $TRIAL exited with status $STATUS." | tee -a "$LOG"
  echo "JSON output (expected): $JSON_OUT" | tee -a "$LOG"
  echo "[$(date)] Finished $TRIAL with status $STATUS (log: $LOG)" | tee -a "$MASTER_LOG"
done

echo "[$(date)] All trials processed (or skipped if missing)." | tee -a "$MASTER_LOG"
```
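
After the job finishes, it helps to confirm that each trial actually produced its result file. A small sketch, assuming the prediction script writes the file that the PBS script logs as `JSON_OUT` (`result_<NCT>_results.json` in the results directory); the helper names are hypothetical:

```python
from pathlib import Path

# Same trial list as in the PBS script above.
TRIALS = ["NCT04678830", "NCT04880161", "NCT05047952", "NCT05633407"]

def expected_json(results_dir, trial):
    """Path the PBS script reports as JSON_OUT for one trial."""
    return Path(results_dir) / f"result_{trial}_results.json"

def missing_outputs(results_dir, trials=TRIALS):
    """Trials whose expected result file is absent."""
    return [t for t in trials if not expected_json(results_dir, t).is_file()]

# Point this at $PARSE_DIR/results on the host.
for trial in missing_outputs("results"):
    print(f"missing: {expected_json('results', trial)}")
```

Any trial listed as missing is worth checking in its per-trial log under `$LOG_DIR`.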

### **Step 4: Python Script**

In [None]:
#!/usr/bin/env python3
"""
predict_all_for_new_clinial_trial.py

Runs AE, safety, and efficacy predictions for clinical trial data using the PlaNet pipeline.
Usage:
    python predict_all_for_new_clinial_trial.py \
        --pklpath path/to/trial.pkl \
        [--output-dir results] \
        [--model-dir data/models] \
        [--device cpu]
"""

import os
import sys
import json
import argparse
import logging
import time
import threading
from contextlib import contextmanager
from copy import deepcopy
import re
import pickle as pkl

# Third-party imports
import torch
from torch.utils.data import TensorDataset
import networkx as nx
import numpy as np
if not hasattr(np, 'bool'):
    np.bool = bool
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# Configure logging
def setup_logging():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    return logging.getLogger(__name__)

log = setup_logging()


def configure_paths():
    """Insert project directories into sys.path for local imports."""
    cwd = os.getcwd()  # /planet/notebooks when launched via the PBS script (--pwd)
    for path in [cwd, '/notebooks', '/planet']:
        if os.path.isdir(path) and path not in sys.path:
            sys.path.insert(0, path)
            log.info(f"Added '{path}' to sys.path")


def find_file(filename, search_paths=None):
    """Search for a file in multiple base directories."""
    if search_paths is None:
        search_paths = ['.', '/app', '/planet', '/planet/parsing_package']
    for base in search_paths:
        candidate = os.path.join(base, filename)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"Could not find '{filename}' in {search_paths}")


def load_kg_vocab():
    kgid2x, x2kgid = {}, {}
    relname2etype, etype2relname = {}, {}

    entities_file = find_file('data/graph/entities.dict')
    log.info(f"Using entities file: {entities_file}")
    with open(entities_file) as f:
        for line in f:
            x, kgid = line.split()
            kgid2x[kgid] = int(x)
            x2kgid[int(x)] = kgid

    relations_file = find_file('data/graph/relations.dict')
    log.info(f"Using relations file: {relations_file}")
    with open(relations_file) as f:
        for line in f:
            etype, name = line.split()
            name = name[name.find('rel-name-') + len('rel-name-'):]
            relname2etype[name] = int(etype)
            etype2relname[int(etype)] = name

    return kgid2x, x2kgid, relname2etype, etype2relname


@contextmanager
def suppress_output_with_progress(description):
    """
    Suppresses stdout/stderr at the file-descriptor level while showing a
    live elapsed-time bar on the original stderr.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    null_fd = os.open(os.devnull, os.O_WRONLY)
    saved_out, saved_err = os.dup(1), os.dup(2)
    # tqdm writes to sys.stderr (fd 2) by default, which is about to be
    # silenced, so give it its own handle on the original stderr.
    bar_stream = os.fdopen(os.dup(saved_err), 'w')

    pbar = tqdm(desc=f"{description}... Time elapsed", bar_format="{desc}: {n:.1f}s",
                ncols=40, file=bar_stream)
    os.dup2(null_fd, 1)
    os.dup2(null_fd, 2)

    start = time.time()
    running = threading.Event()
    running.set()

    def update():
        while running.is_set():
            pbar.n = time.time() - start
            pbar.refresh()
            time.sleep(0.1)

    t = threading.Thread(target=update, daemon=True)
    t.start()

    try:
        yield
    finally:
        running.clear()
        t.join()
        pbar.close()
        bar_stream.close()
        # Flush anything buffered while silenced, then restore the real fds.
        sys.stdout.flush()
        sys.stderr.flush()
        os.dup2(saved_out, 1)
        os.dup2(saved_err, 2)
        os.close(null_fd)
        os.close(saved_out)
        os.close(saved_err)


def load_kg_utils():
    """Import and verify the knowledge_graph package."""
    try:
        # Imported here rather than at module level so that sys.path can be
        # extended first (see configure_paths).
        import knowledge_graph
        from knowledge_graph import kg
        log.info("Imported knowledge_graph successfully")
    except ImportError as e:
        log.error(f"Failed to import knowledge_graph: {e}")
        sys.exit(1)


def get_trial_feature(bert_model, new_trial):
    emb = bert_model._embed(new_trial['arm_text'])
    return np.concatenate([emb, new_trial['trial_attribute_feats_vec']])


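def get_new_edges(new_trial, the_x, kgid2x, relname2etype):
    """Build bidirectional edges between the trial-arm node and its KG neighbours."""
    new_edges, new_etypes, seen = [], [], set()
    for edge in new_trial['trial_arm_edges']:
        h, t = the_x, kgid2x[edge['kg_id']]
        r = relname2etype[edge['relation']]
        new_edges += [[h, t], [t, h]]
        new_etypes += [r, r + 26]  # the reverse edge uses the type ID offset by 26
        seen.add(r)
    # Relation types 21-25 that the trial did not supply get a placeholder
    # edge pair to node 0, so each of these types is present for the arm.
    for r in [21, 22, 23, 24, 25]:
        if r not in seen:
            new_edges += [[the_x, 0], [0, the_x]]
            new_etypes += [r, r + 26]
    return torch.tensor(new_edges).t(), torch.tensor(new_etypes)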


def add_new_trial_to_dataset(dataset, encoder, new_trial, kgid2x, relname2etype, bert_model):
    df = dataset.df[dataset.df['split']=='test'].head(1)
    dataset.df = df
    the_x = df.iloc[0]['x']
    the_kgid = df.iloc[0]['kgid']

    node_feats = deepcopy(dataset.node_feats)
    pos = node_feats[node_feats['node_id']==the_kgid].index[0]
    node_feats.at[pos, 'emb'] = get_trial_feature(bert_model, new_trial)
    dataset.node_feats = node_feats

    graph = deepcopy(dataset.graph)
    new_edges, new_etypes = get_new_edges(new_trial, the_x, kgid2x, relname2etype)
    mask = graph.data.edge_index.eq(the_x).any(dim=0)
    graph.data.edge_index = torch.cat([graph.data.edge_index[:, ~mask], new_edges], dim=1)
    graph.data.edge_type  = torch.cat([graph.data.edge_type[~mask], new_etypes], dim=0)
    dataset.graph = graph

    x = dataset._get_data_x(dataset.df)
    ds = TensorDataset(
        x,
        dataset.task_ys[0][0].repeat(len(x),1),
        dataset.sample_weight_masks[0][0].repeat(len(x),1)
    )
    dataset.datasets['test'] = ds
    return dataset, encoder


def add_new_trial_to_efficacy_dataset(dataset, encoder, trial1, trial2, kgid2x, relname2etype, bert_model):
    df = dataset.efficacy_df[dataset.efficacy_df['split']=='test'].head(1)
    dataset.efficacy_df = df
    x1, x2 = df.iloc[0]['x1'], df.iloc[0]['x2']
    kg1, kg2 = df.iloc[0]['kgid1'], df.iloc[0]['kgid2']

    node_feats = deepcopy(dataset.node_feats)
    pos1 = node_feats[node_feats['node_id']==kg1].index[0]
    node_feats.at[pos1, 'emb'] = get_trial_feature(bert_model, trial1)
    pos2 = node_feats[node_feats['node_id']==kg2].index[0]
    node_feats.at[pos2, 'emb'] = get_trial_feature(bert_model, trial2)
    dataset.node_feats = node_feats

    graph = deepcopy(dataset.graph)
    edges1, types1 = get_new_edges(trial1, x1, kgid2x, relname2etype)
    edges2, types2 = get_new_edges(trial2, x2, kgid2x, relname2etype)
    mask1 = graph.data.edge_index.eq(x1).any(dim=0)
    mask2 = graph.data.edge_index.eq(x2).any(dim=0)
    graph.data.edge_index = torch.cat([
        graph.data.edge_index[:, ~mask1 & ~mask2],
        edges1, edges2
    ], dim=1)
    graph.data.edge_type = torch.cat([
        graph.data.edge_type[~mask1 & ~mask2],
        types1, types2
    ], dim=0)
    dataset.graph = graph

    x = dataset._get_data_x(df)
    ds = TensorDataset(
        x,
        dataset.task_ys[0][0].repeat(len(x),1),
        dataset.sample_weight_masks[0][0].repeat(len(x),1)
    )
    dataset.datasets['test'] = ds
    return dataset, encoder


def predict_top_ae(pred, k=5):
    # 1) coerce to 1‑d NumPy array
    arr = np.asarray(pred).ravel()
    # 2) if empty, return empty dict
    if arr.size == 0:
        return {}
    # 3) argsort descending
    idx_desc = np.argsort(arr)[::-1]
    # 4) take up to k (or the array’s length, whichever is smaller)
    topk = idx_desc[: min(k, arr.size)]
    # 5) build result dict
    return {int(i): float(arr[i]) for i in topk}


def run_task(name, model_path, func, *args):
    log.info(f"Starting {name} task")
    try:
        with suppress_output_with_progress(f"{name} prediction"):
            result = func(model_path, *args)
        log.info(f"Completed {name} task")
        return result
    except Exception as e:
        log.error(f"Error in {name} task: {e}", exc_info=True)
        sys.exit(1)


def loader_ae(model_path, new_trial, kgid2x, relname2etype, bert_model, device):
    from utils.demo_utils import load_model_and_data, model_inference, prepare_runner
    (dataset, _), encoder, bert_enc, model, args_, runner = load_model_and_data(model_path, device=device)
    dataset, encoder = add_new_trial_to_dataset(dataset, encoder, new_trial, kgid2x, relname2etype, bert_model)
    args_, runner, encoder = prepare_runner(args_, dataset, encoder, bert_enc, model, device=device)
    _, y_pred, _ = model_inference(runner, mode='test')
    return predict_top_ae(y_pred[0], k=100)


def loader_safety(model_path, new_trial, kgid2x, relname2etype, bert_model, device):
    from utils.demo_utils import load_model_and_data, model_inference, prepare_runner
    (dataset, _), encoder, bert_enc, model, args_, runner = load_model_and_data(model_path, device=device)
    dataset, encoder = add_new_trial_to_dataset(dataset, encoder, new_trial, kgid2x, relname2etype, bert_model)
    args_, runner, encoder = prepare_runner(args_, dataset, encoder, bert_enc, model, device=device)
    _, y_pred, _ = model_inference(runner, mode='test')
    return y_pred[0].item()


def loader_efficacy(model_path, trial1, trial2, kgid2x, relname2etype, bert_model, device):
    from utils.demo_utils import load_model, model_inference, prepare_runner
    (dataset, _), encoder, bert_enc, model, args_ = load_model(model_path)
    dataset, encoder = add_new_trial_to_efficacy_dataset(dataset, encoder, trial1, trial2, kgid2x, relname2etype, bert_model)
    args_, runner, encoder = prepare_runner(args_, dataset, encoder, bert_enc, model, device=device)
    _, y_pred, _ = model_inference(runner, mode='test')
    return y_pred[0].item()


def main():
    parser = argparse.ArgumentParser(description="Run PlaNet tasks for clinical trial data")
    parser.add_argument('-p', '--pklpath', type=str, required=True, help="Path to the input pickle file")
    parser.add_argument('-o', '--output-dir', type=str, default='results', help="Directory to save results")
    parser.add_argument('-m', '--model-dir', type=str, default='/planet/data/models', help="Base directory for model checkpoints")
    parser.add_argument('-d', '--device', type=str, default='cpu', help="Device for computation (cpu or cuda)")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    configure_paths()
    load_kg_utils()
    kgid2x, x2kgid, relname2etype, etype2relname = load_kg_vocab()

    from utils.text_bert_features import TextBertFeatures
    from gcn_models.utils import set_seed

    set_seed(24)
    bert_model = TextBertFeatures(
        bert_model='/bert_model',
        device=args.device
    )
    log.info("Loaded BERT model")
    # Use a context manager so the pickle file handle is closed promptly
    with open(args.pklpath, 'rb') as f:
        new_trial_data = pkl.load(f)
    log.info(f"Loaded {len(new_trial_data)} trial(s) from {args.pklpath}")

    # AE predictions
    ae_model_path = find_file(os.path.join(args.model_dir, 'ae_model_shxo9bgw', 'ckpt.pt'))
    AE_preds = [
        run_task('AE', ae_model_path, loader_ae, trial, kgid2x, relname2etype, bert_model, args.device)
        for trial in new_trial_data
    ]

    # Safety predictions
    saf_model_path = find_file(os.path.join(args.model_dir, 'safety_model_1xekl810', 'ckpt.pt'))
    safety_preds = [
        run_task('Safety', saf_model_path, loader_safety, trial, kgid2x, relname2etype, bert_model, args.device)
        for trial in new_trial_data
    ]

    # Efficacy prediction: compares only the first two arms (requires 2+ trials)
    efficacy_pred = None
    if len(new_trial_data) > 1:
        eff_model_path = find_file(os.path.join(args.model_dir, 'efficacy_model_34l5ms9m', 'ckpt.pt'))
        efficacy_pred = run_task(
            'Efficacy', eff_model_path, loader_efficacy,
            new_trial_data[0], new_trial_data[1],
            kgid2x, relname2etype, bert_model, args.device
        )
    else:
        log.info("Skipping efficacy (only one trial provided)")

    # Build and save results
    result = {'meta': {}, 'AE': {}, 'safety': {}, 'efficacy': {}}
    for i, trial in enumerate(new_trial_data, start=1):
        result['meta'][f'trial_{i}_label'] = trial['arm_label']
        result['meta'][f'trial_{i}_text']  = trial['arm_text']
    for i, pred in enumerate(AE_preds, start=1):
        result['AE'][f'trial_{i}_ae'] = pred
    for i, pred in enumerate(safety_preds, start=1):
        result['safety'][f'trial_{i}_safety'] = pred
    if efficacy_pred is not None:
        result['efficacy']['prob_trial1_gt_trial2'] = efficacy_pred

    out_file = os.path.join(args.output_dir,
                            f"result_{os.path.basename(args.pklpath).split('.')[0]}.json")
    with open(out_file, 'w') as f:
        json.dump(result, f, indent=2)
    log.info(f"Saved results to {out_file}")

    print(json.dumps(result, indent=2))

if __name__ == "__main__":
    main()

### **Step 5: Load Results**

In [None]:
import os
import json
import pprint

# Folder that contains the prediction outputs (.pkl and .json files)
folder = '/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/prediction/results'

# List all files in the folder
files = sorted(os.listdir(folder))

if not files:
    print(f"No files found in '{folder}'.")
else:
    for filename in files:
        file_path = os.path.join(folder, filename)

        # Skip subdirectories, just in case
        if not os.path.isfile(file_path):
            continue

        # Only process pkl and json files
        if not (filename.endswith('.pkl') or filename.endswith('.json')):
            continue

        print("\n" + "=" * 80)
        print(f"Found file: {file_path}")
        print("=" * 80)

        # Skip pickle files (do not attempt to read/unpickle)
        if filename.endswith('.pkl'):
            print("Skipping .pkl file (binary pickle, not printing content).")
            continue

        # Handle JSON files: read and pretty-print
        if filename.endswith('.json'):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                print("--- JSON File Content ---")
                pprint.pprint(data)
            except Exception as e:
                print(f"[ERROR] An error occurred while reading '{file_path}': {e}")


Found file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/prediction/results/result_NCT03554265_results.json
--- JSON File Content ---
{'AE': {'trial_1_ae': {'1': 0.04310673847794533,
                       '101': 0.09017280489206314,
                       '102': 0.08278005570173264,
                       '104': 0.11740996688604355,
                       '11': 0.0867893323302269,
                       '114': 0.1782989203929901,
                       '116': 0.06525116413831711,
                       '123': 0.1289212852716446,
                       '126': 0.08748872578144073,
                       '130': 0.0510839968919754,
                       '131': 0.17666888236999512,
                       '132': 0.03557215631008148,
                       '138': 0.04951423034071922,
                       '14': 0.15189671516418457,
                       '144': 0.04723656550049782,
                       '147': 0.07295051962137222,
            

In [None]:
import json
from pathlib import Path

# ========= CONFIG =========
DATA_DIR = Path(r"/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/prediction/results")
OUTPUT_HTML = DATA_DIR / "lc_results_json.html"

# ========= BUILD HTML =========
html_parts = [
    "<!doctype html>",
    "<html>",
    "<head>",
    "  <meta charset='utf-8'>",
    "  <title>LC Results JSON Files</title>",
    "  <style>",
    "    body { font-family: monospace; white-space: pre-wrap; }",
    "    h2 { margin-top: 2em; border-bottom: 1px solid #ccc; }",
    "    pre { background: #f8f8f8; padding: 10px; border-radius: 4px; }",
    "  </style>",
    "</head>",
    "<body>",
    "  <h1>LC Results JSON Files</h1>",
]

for path in sorted(DATA_DIR.glob("*.json")):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Pretty-print JSON
    pretty = json.dumps(data, indent=2)

    # Escape HTML special chars
    pretty_escaped = (
        pretty.replace("&", "&amp;")
              .replace("<", "&lt;")
              .replace(">", "&gt;")
    )

    html_parts.append(f"  <h2>{path.name}</h2>")
    html_parts.append("  <pre>")
    html_parts.append(pretty_escaped)
    html_parts.append("  </pre>")

html_parts.append("</body>")
html_parts.append("</html>")

# ========= WRITE FILE =========
with open(OUTPUT_HTML, "w", encoding="utf-8") as f:
    f.write("\n".join(html_parts))

print(f"HTML written to: {OUTPUT_HTML}")

HTML written to: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/prediction/results/lc_results_json.html


### **Step 6: Check Logs**

In [None]:
from pathlib import Path
from collections import deque

logs_dir = Path("/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/4_Experiment/prediction/logs")

def tail(path, n=40):
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return list(deque(f, maxlen=n))

# Choose all files (you can filter later if needed)
log_files = sorted(
    [p for p in logs_dir.iterdir() if p.is_file()],
    key=lambda p: p.stat().st_mtime,
    reverse=True
)

if not log_files:
    print("⚠️ No files found in Logs directory.")
else:
    for lf in log_files:
        print("\n" + "="*80)
        print(f"📄 {lf.name}")
        print("="*80)
        for line in tail(lf, n=30):
            print(line.rstrip())


📄 PlaNet_Predict_master.log
[Sat Nov 22 06:21:52 AEDT 2025] Starting multi-trial prediction job
  notebooks/data → ../data
  Host models under parsing_package/data/models:
3u7di6ag
ae_model_shxo9bgw
dragon
efficacy_model_34l5ms9m
safety_model_1xekl810
[Sat Nov 22 06:21:52 AEDT 2025] Found 6 *_results.pkl files in /scratch/sq95/sp6154/planet/parsing_package/LC_Results
[Sat Nov 22 06:21:52 AEDT 2025] ===== Trial: NCT03554265 (file: NCT03554265_results.pkl) =====
[Sat Nov 22 06:29:53 AEDT 2025] Finished NCT03554265 (log: /scratch/sq95/sp6154/planet/parsing_package/logs/PlaNet_Predict_NCT03554265_results.log)
[Sat Nov 22 06:29:53 AEDT 2025] ===== Trial: NCT04678830 (file: NCT04678830_results.pkl) =====
[Sat Nov 22 06:36:17 AEDT 2025] Finished NCT04678830 (log: /scratch/sq95/sp6154/planet/parsing_package/logs/PlaNet_Predict_NCT04678830_results.log)
[Sat Nov 22 06:36:17 AEDT 2025] ===== Trial: NCT04809974 (file: NCT04809974_results.pkl) =====
[Sat Nov 22 06:41:19 AEDT 2025] Finished NCT0480

### **Step 7: Interpret Results**

**Clinical Efficacy and Safety Results Table:**

| Trial ID        | Drug / Intervention                            | Phase / Design         | Population & Arms                                                                                                | Treatment Duration                      | Safety Score*                                           | Efficacy (Probability Treatment > Placebo/Control)** | Interpretation                                                                       |
| --------------- | ---------------------------------------------- | ---------------------- | ---------------------------------------------------------------------------------------------------------------- | --------------------------------------- | ------------------------------------------------------- | ---------------------------------------------------- | ------------------------------------------------------------------------------------ |
| **NCT04678830** | Leronlimab<br>700mg SC weekly                  | Phase 2, RCT           | 56 long COVID patients<br>(28 leronlimab, 28 placebo)                                                            | 8 weeks                                 | **0.405**<br>(Treatment = placebo)                      | **0.483**                                            | No efficacy signal<br>(probability <0.5 suggests placebo performed slightly better)  |
| **NCT04880161** | Inhaled Ampion<br>8 mL QID x 5 days            | Phase 1, RCT           | 32 long COVID patients<br>(15 Ampion, 16 placebo)                                                                | 5 days treatment<br>28 days follow-up   | **0.608**<br>(Treatment = placebo)                      | **0.479**                                            | No efficacy signal<br>(placebo marginally better)                                    |
| **NCT05047952** | Vortioxetine<br>10–20mg daily                  | Phase 2, RCT           | 149 post-COVID cognitive impairment<br>(75 vortioxetine, 74 placebo)                                             | 8 weeks                                 | **0.609**<br>(Treatment = placebo)                      | **0.488**                                            | No efficacy signal<br>(no cognitive benefit over placebo)                            |
| **NCT05633407** | Efgartigimod<br>10mg/kg IV weekly              | Phase 2, RCT           | 53 POTS / autonomic long COVID<br>(36 efgartigimod, 17 placebo)                                                  | 24 weeks                                | **0.687**<br>(Treatment = placebo)                      | **0.493**                                            | No efficacy signal<br>(closest to null, still <0.5)                                  |
| **NCT04809974** | Niagen (nicotinamide riboside)<br>2000mg daily | Phase 2, RCT           | Long COVID with cognitive / neuropsychiatric symptoms<br>(~60% Niagen, 40% placebo)                              | Up to 6–7 months total follow-up        | **0.549**<br>(Niagen = placebo)                         | **0.481**                                            | No efficacy signal on cognitive / functional outcomes                                |
| **NCT04871815** | Sodium pyruvate nasal spray<br>3× daily        | Single-arm, open-label | Long COVID “long haulers” with respiratory and systemic symptoms                                                 | 7 days treatment (after 7-day baseline) | **0.399**                                               | – (no comparator arm)                                | Good tolerability but **no comparative efficacy estimate** (single-arm design)       |
| **NCT03554265** | Somatropin (Genotropin)<br>rhGH replacement    | Phase 2, multi-arm     | mTBI subjects (rhGH, 6 months)<br>Household controls (no treatment)<br>PASC/long COVID subjects (rhGH, 9 months) | 6–9 months                              | **0.524**<br>**mTBI vs controls vs PASC all identical** | **0.483**<br>(probability mTBI ≥ controls)           | No clear efficacy advantage of rhGH over household controls (on composite endpoints) |

\* *Safety Score: model-derived probability that the trial is safely tolerated (higher = safer).*

\*\* *Efficacy probability <0.5 indicates the control/placebo arm performs marginally better than the active arm; values ≈0.5 indicate no difference.*
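
The Safety Score and Efficacy columns above come from the per-trial JSON files written by the prediction script. A minimal sketch for pulling those two fields out of a result file follows. The helper names `summarize_result` and `summarize_dir` are hypothetical, the key names mirror the `result` dictionary built in `main()`, and the example values are made up for illustration rather than taken from any model run.

```python
import json
from pathlib import Path


def summarize_result(result: dict) -> dict:
    """Extract per-arm safety scores and the pairwise efficacy probability
    from one result dictionary written by the prediction script."""
    safety = result.get("safety", {})
    efficacy = result.get("efficacy", {})
    return {
        # One safety score per arm, keyed as trial_<i>_safety
        "safety_by_arm": dict(sorted(safety.items())),
        # Present only when at least two arms were supplied
        "prob_trial1_gt_trial2": efficacy.get("prob_trial1_gt_trial2"),
        # True when every arm received exactly the same safety score
        "arms_identical": len(set(safety.values())) <= 1,
    }


def summarize_dir(results_dir) -> dict:
    """Map each result_*.json file name in a directory to its summary."""
    summaries = {}
    for path in sorted(Path(results_dir).glob("result_*.json")):
        with open(path, "r", encoding="utf-8") as f:
            summaries[path.name] = summarize_result(json.load(f))
    return summaries


# Illustrative input only: these numbers are placeholders, not model output.
example = {
    "meta": {"trial_1_label": "Active", "trial_2_label": "Placebo"},
    "AE": {},
    "safety": {"trial_1_safety": 0.5, "trial_2_safety": 0.5},
    "efficacy": {"prob_trial1_gt_trial2": 0.25},
}
summary = summarize_result(example)
print(summary)
```

Calling `summarize_dir` on the results folder would give one summary per trial file, which can then be merged with the manually curated design columns.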

---

**Key Findings**

**Safety Profile**

* Across the five **randomized long COVID trials** (Leronlimab, Ampion, Niagen, Vortioxetine, Efgartigimod), the **rhGH replacement (Genotropin) trial** (mTBI, household controls, PASC), and the **sodium pyruvate nasal spray study**:

  * Safety scores ranged from **0.399 to 0.687**, all compatible with **acceptable tolerability**.
* **Highest safety scores:**

  * **Efgartigimod** (**0.687**) and **Vortioxetine** (**0.609**) – long treatment windows (8–24 weeks) with intensive monitoring.
  * **Ampion** (**0.608**) – short 5-day inhaled course.
  * **Niagen** (**0.549**) – dietary supplement with benign safety profile.
* **Lower but acceptable scores:**

  * **Leronlimab** (**0.405**) and **sodium pyruvate nasal spray** (**0.399**) still fall in a tolerable range, with **no major safety signal**.
  * **Somatropin (Genotropin) trial** shows an **identical safety probability (0.524)** for **mTBI subjects, household controls, and PASC subjects**, indicating **no arm-specific safety penalty**.
* For all RCTs, **safety scores were identical between active and control arms**, suggesting:

  * Balanced randomization
  * Adverse event (AE) patterns driven by background disease and trial context rather than the intervention itself
  * No major drug-specific toxicity captured by the model.

---

**Efficacy Outcomes:**

**Critical finding: none of the placebo-controlled trials demonstrate superiority of the active treatment over placebo.**

* All **efficacy probabilities (P[treatment > placebo])** are **<0.5**, tightly clustered between **0.479–0.493**.
* This pattern implies that, if anything, **placebo arms perform marginally better** than treatment arms, but differences are extremely small and compatible with **no true treatment effect**.
* The **rhGH / Genotropin trial (NCT03554265)** compares **mTBI vs household controls vs PASC**:

  * The model’s efficacy probability **P(mTBI ≥ household controls) = 0.483**, again showing **no superiority** of somatropin replacement on the aggregated clinical endpoints.
  * AE rankings and safety scores are **identical across all three arms**, reinforcing a **neutral efficacy–safety balance**.

The **remarkable consistency** of near-0.5 probabilities across independent trials, populations and mechanisms suggests:

* Absence of a robust treatment effect beyond placebo
* Well-randomized and well-controlled studies
* **Strong placebo and natural recovery components** in long COVID and related post-infectious sequelae.
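
As a quick arithmetic check, the efficacy probabilities of the five placebo-controlled trials can be copied from the results table and summarized directly. This is only a sketch over the tabulated values; the variable names are illustrative. NCT03554265 is left out because its comparator is household controls rather than placebo.

```python
# Efficacy probabilities copied from the results table (placebo-controlled RCTs only).
efficacy = {
    "NCT04678830": 0.483,  # Leronlimab
    "NCT04880161": 0.479,  # Ampion
    "NCT05047952": 0.488,  # Vortioxetine
    "NCT05633407": 0.493,  # Efgartigimod
    "NCT04809974": 0.481,  # Niagen
}

# Trials with the lowest and highest predicted probability
lowest = min(efficacy, key=efficacy.get)
highest = max(efficacy, key=efficacy.get)

# Does every trial fall below the 0.5 "no difference" line?
all_below_half = all(p < 0.5 for p in efficacy.values())

# Largest distance from 0.5 across the five trials
max_shortfall = max(0.5 - p for p in efficacy.values())

print(f"lowest:  {lowest} ({efficacy[lowest]})")
print(f"highest: {highest} ({efficacy[highest]})")
print(f"all below 0.5: {all_below_half}")
print(f"largest shortfall from 0.5: {max_shortfall:.3f}")
```

The largest shortfall from 0.5 is about two percentage points, which is why the differences are read as compatible with no treatment effect rather than as evidence that placebo is better.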

---

**Trial-Specific Observations:**

**Leronlimab (CCR5 antagonist, NCT04678830)**

* **Efficacy probability 0.483** → no evidence that CCR5 blockade improves symptom trajectories versus placebo.
* AE distributions and safety scores are **almost perfectly overlapping** between arms, indicating **no clear safety trade-off**.

**Ampion (biologic anti-inflammatory, NCT04880161)**

* **Efficacy probability 0.479** – the **lowest** among the RCTs.
* Aggressive but **short (5-day) dosing schedule** with follow-up to 28 days.
* Suggests that **short-course anti-inflammatory nebulization is insufficient** to shift long COVID respiratory outcomes beyond placebo.

**Vortioxetine (pro-cognitive antidepressant, NCT05047952)**

* **Efficacy probability 0.488** – no clinically meaningful improvement in DSST-based cognitive scores over placebo, despite:

  * Robust trial size (**n=149**)
  * Reasonable treatment duration (**8 weeks**).
* Provides a relatively **high-quality negative result** for targeting cognitive impairment in post-COVID condition with vortioxetine.

**Efgartigimod (FcRn antagonist, NCT05633407)**

* **Efficacy probability 0.493** – closest to 0.5, still **non-superior to placebo**.
* Despite **24 weeks of IV therapy** aimed at reducing pathogenic autoantibodies in POTS-like phenotypes, **no treatment signal** emerges.

**Niagen (nicotinamide riboside, NCT04809974)**

* **Efficacy probability 0.481** with **safety = 0.549** (identical arms).
* Targeted mitochondrial / NAD⁺ pathways in long COVID with cognitive and physical symptoms.
* Results suggest **no detectable clinical advantage** over placebo in this setting, despite a biologically plausible mechanism.

**Sodium Pyruvate Nasal Spray (NCT04871815)**

* **Single-arm open-label design** → **no probability of superiority vs placebo** can be computed.
* Safety score **0.399** suggests:

  * Acceptable tolerability
  * But, in the absence of a control group, **any symptomatic improvement cannot be disambiguated** from natural recovery or placebo effects.
* Useful as a **safety / feasibility signal**, not as definitive efficacy evidence.

**Somatropin (Genotropin, rhGH replacement, NCT03554265)**

* Complex, **three-arm design**:

  * **mTBI subjects** on rhGH for 6 months
  * **Household control subjects** with no intervention
  * **PASC / long COVID subjects** on rhGH for 9 months
* The model shows:

  * **Identical AE distributions** across `trial_1_ae`, `trial_2_ae`, and `trial_3_ae`
  * **Identical safety scores** for mTBI, controls, and PASC (**0.524** for all three arms)
  * **Efficacy probability P(trial 1 ≥ trial 2) = 0.483**, indicating **no superiority** of treatment over controls at the aggregated endpoint level.
* Interpretation:

  * **rhGH replacement appears tolerable** across mTBI and PASC cohorts
  * However, **no clear efficacy advantage** is detected when compared to well-matched household controls within the PlaNet framework.

---

**Clinical and Research Implications**

**Negative Trial Results Across Multiple Mechanisms**

Across these trials, the following mechanistic classes **all fail to outperform placebo** in long COVID or closely related post-COVID conditions:

1. **Immune modulation** – Leronlimab (CCR5 blockade)
2. **Biologic anti-inflammatory** – Ampion
3. **Neurotransmitter modulation / pro-cognitive antidepressant** – Vortioxetine
4. **Autoantibody reduction via FcRn blockade** – Efgartigimod
5. **Mitochondrial / NAD⁺ augmentation** – Niagen
6. **Metabolic neuromodulation (rhGH replacement)** – Somatropin / Genotropin
7. **Redox / anti-inflammatory nasal therapy** – Sodium pyruvate nasal spray (single-arm, no controlled signal)

**Possible Explanations for Universal Failure**

1. **Mis-specified targets**
   The true drivers of long COVID and PASC may lie outside the pathways targeted by these agents, or require more nuanced modulation.

2. **Heterogeneous biology**
   Long COVID likely comprises **multiple endotypes** (e.g. viral persistence, dysautonomia, autoimmunity, microvascular dysfunction, central sensitization) which are **not captured by undifferentiated recruitment**.

3. **Timing and disease stage**
   Interventions administered months after infection may miss critical windows for preventing or reversing pathophysiological changes.

4. **Dose / duration constraints**

   * Ampion: very short exposure (5 days)
   * Other agents: potentially insufficient dose or exposure relative to underlying biology.

5. **Strong placebo and natural recovery**
   High symptom variability, regression to the mean, and expectation effects substantially **flatten treatment–control differences**.

6. **Outcome measurement limitations**
   Heavy reliance on **subjective symptom scales and global cognitive tests** may dilute subtle but real physiological improvements.

---

**Conclusion**

The extended PlaNet analysis, now including **Leronlimab, Ampion, Vortioxetine, Efgartigimod, Niagen, Somatropin (Genotropin), and Sodium pyruvate nasal spray**, paints a **coherent negative picture**: despite targeting diverse and biologically plausible pathways, **no evaluated monotherapy demonstrates superiority over placebo** in long COVID or related PASC contexts. Safety profiles are generally acceptable and remarkably similar between arms, but the **absence of efficacy signals across mechanisms and trial designs** underscores both the **complexity of long COVID biology** and the **urgent need for mechanistically stratified, biomarker-guided therapeutic strategies**.

# **LC Causal Genes**

## **Summary**

**Step 1: Prepare the Gene List:**
- Prepare the final lists of Long COVID-associated genes with the official gene symbols (e.g., `ACE2`, `TMPRSS2`, `IFNG`).

**Step 2: Query Drug-Target Databases:**
- Use the gene lists to search databases that contain information on drug-protein interactions. The most effective approach is to query them one by one.

Work through the steps below manually for one gene in DrugBank and DisGeNET first, then automate the search for the remaining genes.

**DrugBank**
DrugBank is a primary resource for detailed information on drug targets.

1.  **Go to the DrugBank website**: [https://go.drugbank.com/](https://go.drugbank.com/)
2.  **Search for a Gene**: In the main search bar, enter the symbol for one of your genes (e.g., `ACE2`) and search.
3.  **Identify Associated Drugs**: On the results page for that gene/protein, look for sections like **"Associated Drugs"** or **"Drug Targets"**.
4.  **Record the Drugs**: For each gene, record the names of the drugs that are listed as interacting with it (e.g., inhibitors, antagonists, binders).

Repeat this process for each gene on each list.
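
The manual DrugBank lookup above could be automated against a local copy of the DrugBank XML export. The sketch below assumes that the export nests gene symbols under `drug/targets/target/polypeptide/gene-name` in the `http://www.drugbank.ca` namespace; both should be checked against the downloaded file. The function name `map_genes_to_drugs`, the sample XML, and the drug names are placeholders for illustration, not real DrugBank content.

```python
import xml.etree.ElementTree as ET

# Assumed namespace and element path of the DrugBank XML export.
NS = {"db": "http://www.drugbank.ca"}
GENE_PATH = "db:targets/db:target/db:polypeptide/db:gene-name"

# Tiny stand-in for the real export; the drug names are placeholders.
SAMPLE_XML = """<drugbank xmlns="http://www.drugbank.ca">
  <drug>
    <name>ExampleDrugA</name>
    <targets>
      <target><polypeptide><gene-name>ACE2</gene-name></polypeptide></target>
    </targets>
  </drug>
  <drug>
    <name>ExampleDrugB</name>
    <targets>
      <target><polypeptide><gene-name>IFNG</gene-name></polypeptide></target>
      <target><polypeptide><gene-name>ACE2</gene-name></polypeptide></target>
    </targets>
  </drug>
</drugbank>
"""


def map_genes_to_drugs(xml_text, genes):
    """Return {gene: sorted drug names} for drugs whose targets list that gene."""
    # Match symbols case-insensitively but report them as given by the caller
    wanted = {g.strip().upper(): g for g in genes}
    hits = {g: set() for g in genes}
    root = ET.fromstring(xml_text)
    for drug in root.findall("db:drug", NS):
        name = (drug.findtext("db:name", default="", namespaces=NS) or "").strip()
        if not name:
            continue
        for gene_el in drug.findall(GENE_PATH, NS):
            key = (gene_el.text or "").strip().upper()
            if key in wanted:
                hits[wanted[key]].add(name)
    return {gene: sorted(names) for gene, names in hits.items()}


gene_to_drugs = map_genes_to_drugs(SAMPLE_XML, ["ACE2", "TMPRSS2"])
print(gene_to_drugs)
```

Genes without any matching target are kept with an empty list so that they remain visible in the output. A full export is large, so a streaming parser such as `xml.etree.ElementTree.iterparse` may be preferable to loading the whole file as one string.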

**DisGeNET**
DisGeNET is excellent for exploring gene-disease and gene-drug relationships.

1.  **Go to the DisGeNET website**: [https://www.disgenet.org/](https://www.disgenet.org/)
2.  **Search for a Gene**: Use the search bar to find one of your genes.
3.  **Explore Associations**: On the gene's page, explore the "Summary of Associations" and look for curated drug-gene interactions.

**Step 3: Consolidate the Drug Candidate List:**

After querying the databases with all the genes, we need to consolidate the results.

1.  **Combine Results**: Create a master list of all the drug names we recorded from DrugBank and DisGeNET.
2.  **Find Unique Drugs**: Remove any duplicate drug names to create a final, unique list of drug candidates.

This final list is the biologically-informed set of drugs that are hypothesized to be relevant to Long COVID. We can now use these drugs in our PlaNet prediction pipeline.
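
The consolidation step can be sketched as a case-insensitive merge of the per-database lists. The function name `consolidate` and the drug names below are placeholders, not query results; real drug names may also need synonym or identifier matching, which this sketch does not attempt.

```python
def consolidate(*drug_lists):
    """Merge several drug-name lists into one de-duplicated, sorted list.

    Names are compared case-insensitively after collapsing whitespace;
    the first spelling encountered is the one that is kept.
    """
    seen = {}
    for drugs in drug_lists:
        for name in drugs:
            cleaned = " ".join(name.split())
            if not cleaned:
                continue  # skip empty entries
            seen.setdefault(cleaned.casefold(), cleaned)
    return sorted(seen.values(), key=str.casefold)


# Placeholder names standing in for the recorded database hits.
drugbank_hits = ["Drug A", "drug b ", "Drug C"]
disgenet_hits = ["DRUG B", "Drug D", ""]

candidates = consolidate(drugbank_hits, disgenet_hits)
print(candidates)
```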

## **Steps**

**Step 1:** Create the environment_causal_genes.yml File

Location: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes 

```yaml
name: gene_mapper
channels:
  - conda-forge
  - defaults

dependencies:
  - python=3.9
  - requests
  - beautifulsoup4
```

**Step 2:** Create the Conda Environment

Open a terminal, navigate to the directory where the environment_causal_genes.yml is saved, and create a new environment named gene_mapper using the instructions in the file.

```bash
conda env create -f environment_causal_genes.yml
```

**Step 3:** Activate the New Environment

Once the environment is created, activate it to start using it.

```bash
conda activate gene_mapper
```

**Step 4:** Prepare the lists of causal genes as separate .csv files and save them in the following location:
- Location: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists

Then install pandas in the activated environment:
```bash
pip install pandas
```
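
Once the gene lists are saved as .csv files, they could be read back into one dictionary per file along the following lines. This sketch uses the standard-library `csv` module so that it runs without extra packages (`pandas.read_csv` would serve the same purpose). The column name `gene_symbol`, the function name `load_gene_lists`, and the demo file name are assumptions.

```python
import csv
import tempfile
from pathlib import Path


def load_gene_lists(directory, column="gene_symbol"):
    """Read every .csv file in a directory into {file stem: [gene symbols]}.

    Symbols are stripped of surrounding whitespace and de-duplicated
    while keeping their original order and capitalization.
    """
    gene_lists = {}
    for path in sorted(Path(directory).glob("*.csv")):
        symbols = []
        with open(path, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                symbol = (row.get(column) or "").strip()
                if symbol and symbol not in symbols:
                    symbols.append(symbol)
        gene_lists[path.stem] = symbols
    return gene_lists


# Self-contained demo using a temporary directory and a placeholder file.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "example_list.csv").write_text(
        "gene_symbol\nACE2\n TMPRSS2 \nACE2\n", encoding="utf-8"
    )
    demo = load_gene_lists(tmp)

print(demo)
```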

**Step 5:** Create the Python scripts to map drugs for each gene. The table below lists candidate resources for drug–gene mappings.

| Resource                                      | Primary focus & content                                                                                                     | Why it’s useful / strengths                                                                                                      | Access / notes                                                                                                                                                        |
| --------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **DrugBank**                                  | Curated drug-centric database linking drugs to their targets (including gene/protein), mechanisms, interactions, approvals. | Rich metadata, commercial & experimental drugs, includes target gene associations and interaction data.                          | API + downloadable XML (this workflow already uses a local dump). Requires license for some uses; academic free tier exists. ([docs.drugbank.com][1]) ([dev.drugbank.com][2]) |
| **ChEMBL**                                    | Bioactivity database of small molecules with annotated target proteins (gene-level), including measured activities.         | High-throughput and literature-curated drug–target (gene) interactions with activity values; open.                               | Webservices / bulk download; good for mapping a gene to compounds with binding evidence. ([ebi.ac.uk][3]) ([Oxford Academic][4])                                      |
| **DrugCentral**                               | Integrated resource of approved and investigational drugs with mechanism-of-action targets.                                 | Up-to-date on approvals, mechanism-of-action gene targets, bioactivity profiles; downloadable.                                   | Open access, includes drug–target mappings and cross-references. ([drugcentral.org][5]) ([SpringerLink][6])                                                           |
| **Therapeutic Target Database (TTD)**         | Curated known and explored therapeutic targets and their corresponding drugs, with druggability info.                       | Focused on target (gene/protein) druggability and therapeutic indications; extensive download.                                   | Has recent updates (2024); bulk download available. ([idrblab.org][7], [idrblab.net][8])                                                                              |
| **IUPHAR/BPS Guide to Pharmacology (GtoPdb)** | Expert-curated quantitative interactions between ligands (drugs) and targets (genes/proteins).                              | High-quality ligand-target (gene) interaction data with pharmacology context.                                                    | REST API and downloadable tables under open license. ([guidetopharmacology.org][9], [guidetopharmacology.org][10])                                                    |
| **DGIdb (Drug–Gene Interaction Database)**    | Aggregator that harmonizes and exposes known/predicted drug–gene interactions from many sources.                            | Unified query over many underlying resources (DrugBank, PharmGKB, ChEMBL, etc.), with druggability annotations; easy list input. | API + web; good “first pass” mapping. Latest version improvements for precision medicine. ([Oxford Academic][11], [Oxford Academic][12])                              |
| **Drug Target Commons (DTC)**                 | Crowd-sourced, standardized bioactivity profiles for compound–target (gene) interactions.                                   | Community-curated binding/activity data used to refine consensus on drug–gene binding; interoperable with other systems.         | Downloadable, API, integrates into other platforms (and used by DGIdb). ([Oxford Academic][13], [drugtargetcommons.fimm.fi][14])                                      |
| **STITCH**                                    | Known and predicted chemical–protein associations (including drugs to gene products) integrating multiple evidence types.   | Combines experimental, text-mined, and inferred drug–gene interactions; useful for network/contextual inference.                 | Web interface and bulk data; includes confidence scores. ([stitch.embl.de][15], [Oxford Academic][16])                                                                |
| **BindingDB**                                 | Experimentally measured binding affinities between small molecules (drugs) and protein targets (genes).                     | Gold standard for physical interaction strength; can filter for your gene and retrieve drugs/compounds with affinity data.       | Download & programmatic access; FAIR-updated 2024. ([Oxford Academic][17], [bdb1.ucsd.edu][18])                                                                       |

[1]: https://docs.drugbank.com/v1/ "API Reference | DrugBank Help Center"
[2]: https://dev.drugbank.com/guides/tutorials/interactions "API | DrugBank Help Center"
[3]: https://www.ebi.ac.uk/chembl/ "ChEMBL - ChEMBL"
[4]: https://academic.oup.com/nar/article/40/D1/D1100/2903401 "ChEMBL: a large-scale bioactivity database for drug discovery"
[5]: https://drugcentral.org/ "Drug Central"
[6]: https://link.springer.com/article/10.1007/s10822-023-00529-x "Exploring DrugCentral: from molecular structures to clinical effects ..."
[7]: https://idrblab.org/Publication/P162-37713619.pdf "TTD: Therapeutic Target Database describing target druggability information"
[8]: https://idrblab.net/ttd/full-data-download "Full Data Download | Therapeutic Target Database"
[9]: https://www.guidetopharmacology.org/ "Home | IUPHAR/BPS Guide to PHARMACOLOGY"
[10]: https://www.guidetopharmacology.org/databaseContent.jsp "Database Content | IUPHAR/BPS Guide to PHARMACOLOGY"
[11]: https://academic.oup.com/nar/article/52/D1/D1227/7416371 "DGIdb 5.0: rebuilding the drug–gene interaction database for precision ..."
[12]: https://academic.oup.com/nar/article/49/D1/D1144/6006193 "Integration of the Drug–Gene Interaction Database (DGIdb 4.0) with open ..."
[13]: https://academic.oup.com/database/article/doi/10.1093/database/bay083/5096727 "Drug Target Commons 2.0: a community platform for systematic analysis ..."
[14]: https://drugtargetcommons.fimm.fi/ "Home [drugtargetcommons.fimm.fi]"
[15]: https://stitch.embl.de/ "STITCH: chemical association networks"
[16]: https://academic.oup.com/nar/article/42/D1/D401/1061025 "STITCH 4: integration of protein–chemical interactions with user data ..."
[17]: https://academic.oup.com/nar/article/53/D1/D1633/7906836 "BindingDB in 2024: a FAIR knowledgebase of protein-small molecule ..."
[18]: https://bdb1.ucsd.edu/bind/index.jsp "Binding Database"


## **Unique Genes**

In [None]:
import pandas as pd
import glob
import os
from collections import defaultdict

# Path to the folder containing your gene list CSVs
folder_path = r"/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists"

# Output filenames
output_union_with_sources = os.path.join(folder_path, "all_unique_genes_with_sources.csv")
output_union_genes_only = os.path.join(folder_path, "all_unique_genes.csv")

# Gather all CSVs except previous union outputs, so they are not read back in as inputs
csv_files = [
    f for f in glob.glob(os.path.join(folder_path, "*.csv"))
    if os.path.basename(f) not in {
        os.path.basename(output_union_with_sources),
        os.path.basename(output_union_genes_only),
    }
]

gene_sources = defaultdict(set)  # gene_symbol -> set of source labels

for file in csv_files:
    label = os.path.splitext(os.path.basename(file))[0]  # e.g., "CT" from "CT.csv"
    try:
        df = pd.read_csv(file)
        if 'Gene_Symbol' in df.columns:
            genes = df['Gene_Symbol'].dropna().astype(str).str.strip()
            for g in genes:
                if g:
                    gene_sources[g].add(label)
        else:
            print(f"⚠️ Column 'Gene_Symbol' not found in {file}")
    except Exception as e:
        print(f"Error reading {file}: {e}")

# Build DataFrame with Sources column (sorted, semicolon-separated)
records = []
for gene in sorted(gene_sources):
    sources = ";".join(sorted(gene_sources[gene]))
    records.append({"Gene_Symbol": gene, "Sources": sources})

final_df = pd.DataFrame.from_records(records, columns=["Gene_Symbol", "Sources"])

# Save union with sources
final_df.to_csv(output_union_with_sources, index=False)

# Also save the plain unique gene list (no sources) for backward compatibility
genes_only_df = final_df[["Gene_Symbol"]].drop_duplicates().sort_values("Gene_Symbol")
genes_only_df.to_csv(output_union_genes_only, index=False)

# Print preview and counts
print("First rows of the unique gene list with sources:")
print(final_df.head())  # First few rows
print(f"\nTotal number of unique genes: {len(final_df)}")
print(f"\n✅ Done! Saved:\n - With sources: {output_union_with_sources}\n - Genes only: {output_union_genes_only}")

First rows of the unique gene list with sources:
  Gene_Symbol Sources
0         A2M      CT
1       ABCA1      CT
2       ABCF3      CT
3        ABL1      CT
4        ABL2      CT

Total number of unique genes: 1725

✅ Done! Saved:
 - With sources: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv
 - Genes only: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes.csv


## **Helper Function**

In [None]:
import os
import json
import pandas as pd

# Path to the union gene list with source annotations (must exist)
UNION_WITH_SOURCES = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
# Fallback if that doesn't exist, use plain union
UNION_SIMPLE = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes.csv"
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

def load_genes_with_sources():
    if os.path.isfile(UNION_WITH_SOURCES):
        df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
        if GENE_COLUMN not in df.columns or SOURCES_COLUMN not in df.columns:
            raise ValueError(f"{UNION_WITH_SOURCES} must contain columns '{GENE_COLUMN}' and '{SOURCES_COLUMN}'")
        df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
        df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("").astype(str)
        genes = df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()
    elif os.path.isfile(UNION_SIMPLE):
        df = pd.read_csv(UNION_SIMPLE, dtype=str)
        if GENE_COLUMN not in df.columns:
            raise ValueError(f"{UNION_SIMPLE} missing column '{GENE_COLUMN}'")
        df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
        df[SOURCES_COLUMN] = ""  # empty sources
        genes = df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()
    else:
        raise FileNotFoundError("Union gene list not found")
    return genes  # DataFrame with columns Gene_Symbol, Sources
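
The `Sources` column is written as a sorted, semicolon-separated string, so downstream steps may need to split it back into individual labels. A minimal sketch of one way to do that; `split_sources` is a hypothetical helper, not part of the notebook, and `OTHER` is a placeholder label.

```python
def split_sources(value):
    """Turn 'A;B' into {'A', 'B'}; empty or non-string values give an empty set."""
    if not isinstance(value, str):
        return set()
    return {part.strip() for part in value.split(";") if part.strip()}

# "CT" appears in the preview output; "OTHER" is a made-up label for illustration
labels = sorted(split_sources("CT; OTHER"))
empty = split_sources("")
print(labels, empty)
```

Applied to the DataFrame returned by `load_genes_with_sources()`, this could be used as `genes["Sources"].map(split_sources)`.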

## **Databases**

### **🟩DrugBank**

**Links:**
- Home: https://go.drugbank.com/
- Download: https://go.drugbank.com/releases/latest
  
**Characteristics:**
- Comprehensive, curated repository of drugs (small molecules, biologics, investigational agents) and their molecular targets, including gene/protein targets, mechanisms of action, approval status, and interactions. 
- Ideal for high-confidence drug–gene target mapping and mechanism annotation. 
- Requires appropriate licensing for some use cases; local XML dumps can be used for bulk offline parsing.

In [None]:
# Check DrugBank XML structure
# This script inspects the DrugBank XML structure to understand its tags and attributes.
# It samples a few drug entries to demonstrate how to extract relevant information.

# Import modules for XML parsing and filesystem checks
import xml.etree.ElementTree as ET  # for reading and traversing XML files
import os  # for file path handling and checking if files exist

# Path to your local DrugBank XML file
DRUGBANK_XML = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/Full_DataBases/full_database_DrugBank.xml"

# Extract the XML namespace from a tag (e.g., "{namespace}tagname" → "namespace")
def get_namespace(tag):
    if tag.startswith("{"):  # Namespaced tags start with "{"
        return tag.split("}")[0].strip("{")
    return ""  # No namespace

# Remove the namespace from a tag (e.g., "{namespace}tagname" → "tagname")
def strip_ns(tag):
    return tag.split("}", 1)[1] if "}" in tag else tag

def sample_drug_structure(xml_path, sample_n=3):
    """
    Inspect the DrugBank XML structure by printing:
      - The XML namespace (if any)
      - For the first `sample_n` <drug> entries:
          • Drug name
          • Primary DrugBank ID
          • Drug type (from XML attribute)
          • Up to 3 target entries with:
              ◦ Target ID
              ◦ Gene symbol
              ◦ Organism
              ◦ Synonyms
      - A shallow summary of tag paths under <drug> (to understand structure)
    """

    # Ensure the file exists before proceeding
    if not os.path.isfile(xml_path):
        print(f"ERROR: '{xml_path}' not found.")
        return

    # Create an iterator to parse XML progressively (saves memory)
    context = ET.iterparse(xml_path, events=("start", "end"))

    namespace = None  # Will hold the XML namespace string
    root = None       # Will hold the root element
    drugs_sampled = 0 # Counter for number of drugs processed
    tag_paths = set() # Set of unique tag paths found under <drug>

    print("Scanning XML for namespace and sampling drug entries...\n")

    # Iterate over each XML element in the file
    for event, elem in context:

        # Capture the root element and detect namespace at the very beginning
        if root is None and event == "start":
            root = elem
            namespace = get_namespace(elem.tag)
            if namespace:
                print(f"Detected XML namespace: '{namespace}'")
            else:
                print("No namespace detected.")

        # Helper function to format tag names with the current namespace
        def qname(t):
            return f"{{{namespace}}}{t}" if namespace else t

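        # When finishing a <drug> element, process its contents.
        # Note: this tag test also matches <drug> elements nested inside a record;
        # those typically lack a primary ID, a type attribute, and <targets>.
        if event == "end" and elem.tag == qname("drug"):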
            if drugs_sampled < sample_n:  # Only sample up to `sample_n` drugs
                print(f"\n--- Sample drug #{drugs_sampled + 1} ---")

                # 1) Get the drug's main name
                name = elem.findtext(qname("name"), default="<no name>")
                print(f"Drug name: {name}")

                # 2) Get the primary DrugBank ID
                db_id = None
                for id_elem in elem.findall(qname("drugbank-id")):
                    if id_elem.get("primary") == "true":  # Attribute marks primary ID
                        db_id = id_elem.text
                        break
                print(f"Primary DrugBank ID: {db_id}")

                # 3) Get the drug type from the XML attribute
                dtype = elem.get("type", "<no type attr>")
                print(f"Drug type attribute: {dtype}")

                # 4) Record tag paths under this <drug> (depth 1 and 2)
                for child in elem:
                    tag_paths.add(strip_ns(child.tag))
                    for sub in child:
                        tag_paths.add(f"{strip_ns(child.tag)}/{strip_ns(sub.tag)}")

                # 5) Get the drug's targets
                targets = elem.find(qname("targets"))
                if targets is None:
                    print("No <targets> element.")
                else:
                    print("Targets (showing up to 3):")
                    tcount = 0
                    for target in targets.findall(qname("target")):
                        if tcount >= 3:  # Only display first 3 targets for this sample
                            break

                        # Get <polypeptide> section containing gene info
                        polypeptide = target.find(qname("polypeptide"))
                        # Get target ID (from <id> tag)
                        target_id = target.findtext(qname("id"), default="<no target id>")
                        print(f"  - Target ID: {target_id}")

                        if polypeptide is not None:
                            # Extract gene symbol
                            gene_name = polypeptide.findtext(qname("gene-name"), default="<no gene-name>")
                            # Extract organism name
                            organism = polypeptide.findtext(qname("organism"), default="<no organism>")
                            print(f"    * Gene symbol: {gene_name}")
                            print(f"    * Organism: {organism}")

                            # Extract synonyms if they exist
                            syn_parent = polypeptide.find(qname("synonyms"))
                            if syn_parent is not None:
                                syns = [s.text for s in syn_parent.findall(qname("synonym")) if s.text]
                                print(f"    * Synonyms: {syns[:5]}{'...' if len(syns) > 5 else ''}")
                        else:
                            print("    * <polypeptide> missing.")

                        tcount += 1

                drugs_sampled += 1  # Increase the sample counter

            # Clear the processed <drug> element from memory to save RAM
            elem.clear()

            # Stop if we have processed enough sample drugs
            if drugs_sampled >= sample_n:
                break

    # Print all unique tag paths found under <drug> for structural reference
    print("\n--- Shallow tag path summary under <drug> (examples) ---")
    for path in sorted(tag_paths):
        print(f"  {path}")

    print("\nDone sampling.")

# Run the sampling function when the script is executed directly
if __name__ == "__main__":
    sample_drug_structure(DRUGBANK_XML, sample_n=5)

Scanning XML for namespace and sampling drug entries...

Detected XML namespace: 'http://www.drugbank.ca'

--- Sample drug #1 ---
Drug name: Lepirudin
Primary DrugBank ID: None
Drug type attribute: <no type attr>
No <targets> element.

--- Sample drug #2 ---
Drug name: Phylloquinone
Primary DrugBank ID: None
Drug type attribute: <no type attr>
No <targets> element.

--- Sample drug #3 ---
Drug name: Calcium
Primary DrugBank ID: None
Drug type attribute: <no type attr>
No <targets> element.

--- Sample drug #4 ---
Drug name: Lepirudin
Primary DrugBank ID: DB00001
Drug type attribute: biotech
Targets (showing up to 3):
  - Target ID: BE0000048
    * Gene symbol: F2
    * Organism: Humans
    * Synonyms: ['3.4.21.5', 'Coagulation factor II']

--- Sample drug #5 ---
Drug name: Cetuximab
Primary DrugBank ID: None
Drug type attribute: <no type attr>
No <targets> element.

--- Shallow tag path summary under <drug> (examples) ---
  absorption
  affected-organisms
  affected-organisms/affected-
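
In the sample output above, several entries show no primary ID, no `type` attribute, and no `<targets>`, while Lepirudin reappears as sample #4 with ID `DB00001`. This pattern is consistent with the tag test also matching `<drug>` elements nested inside a record, since `iterparse` emits `end` events for inner elements before their parents. A minimal sketch of one possible remedy, assuming top-level entries are direct children of the root element: track the nesting depth and yield only those. `iter_top_level` is a hypothetical helper, and the inline XML (including the `<nested-section>` wrapper) is a toy stand-in rather than real DrugBank content.

```python
import io
import xml.etree.ElementTree as ET

def iter_top_level(source, local_name):
    """Yield elements named `local_name` that are direct children of the root."""
    depth = 0
    for event, elem in ET.iterparse(source, events=("start", "end")):
        if event == "start":
            depth += 1
            continue
        # "end" event: the element is closing at the current depth
        is_match = elem.tag.rsplit("}", 1)[-1] == local_name
        if depth == 2 and is_match:
            yield elem
            elem.clear()  # free memory once the caller has consumed the record
        depth -= 1

# Toy document (hypothetical content) with one <drug> nested inside a record
TOY_XML = b"""<drugbank xmlns="http://www.drugbank.ca">
  <drug type="biotech"><name>A</name>
    <nested-section><drug><name>B</name></drug></nested-section>
  </drug>
  <drug type="small molecule"><name>C</name></drug>
</drugbank>"""

NS = "{http://www.drugbank.ca}"
top_level_names = [
    d.findtext(NS + "name") for d in iter_top_level(io.BytesIO(TOY_XML), "drug")
]
print(top_level_names)  # the nested entry "B" is skipped
```

Counting depth avoids relying on any particular attribute being present on top-level records, at the cost of listening to `start` events as well.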

In [None]:
#!/usr/bin/env python3
import os
import json
import pandas as pd
import xml.etree.ElementTree as ET
from collections import defaultdict

# === CONFIGURATION ===
BASE_DIR = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes"
GENE_LIST_WITH_SOURCES = os.path.join(BASE_DIR, "gene_lists", "all_unique_genes_with_sources.csv")
DRUGBANK_XML = os.path.join(BASE_DIR, "Full_DataBases/full_database_DrugBank.xml")
OUTPUT_DIR = os.path.join(BASE_DIR, "gene_lists", "drugbank_results_with_sources")
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- XML helpers ---
def _get_namespace(tag):
    if tag.startswith("{"):
        return tag.split("}")[0].strip("{")
    return ""

def _qname(tag, ns):
    return f"{{{ns}}}{tag}" if ns else tag

def parse_drugbank_xml(xml_path, genes_of_interest=None):
    """
    Parse DrugBank XML, restrict to genes_of_interest (case-insensitive, including synonyms),
    return canonical gene -> set(drug names), and alias->canonical mapping.
    Only human targets are retained.
    """
    genes_filter = {g.strip().upper() for g in genes_of_interest} if genes_of_interest else None
    gene_to_drugs = defaultdict(set)
    alias_to_canonical = {}

    context = ET.iterparse(xml_path, events=("start", "end"))
    namespace = None

    for event, elem in context:
        if namespace is None and event == "start":
            namespace = _get_namespace(elem.tag)
        if event == "end" and elem.tag == _qname("drug", namespace):
            # Drug name
            drug_name_elem = elem.find(_qname("name", namespace))
            if drug_name_elem is None or not drug_name_elem.text:
                elem.clear()
                continue
            drug_name = drug_name_elem.text.strip()

            # Targets
            targets_parent = elem.find(_qname("targets", namespace))
            if targets_parent is not None:
                for target in targets_parent.findall(_qname("target", namespace)):
                    polypeptide = target.find(_qname("polypeptide", namespace))
                    if polypeptide is None:
                        continue

                    # Filter to human
                    organism = polypeptide.findtext(_qname("organism", namespace), default="").lower()
                    if "homo sapiens" not in organism and "human" not in organism:
                        continue

                    # Primary gene name
                    gene_name = polypeptide.findtext(_qname("gene-name", namespace))
                    if not gene_name:
                        continue
                    canonical = gene_name.strip().upper()

                    # Collect aliases (primary + synonyms)
                    aliases = {canonical}
                    syn_parent = polypeptide.find(_qname("synonyms", namespace))
                    if syn_parent is not None:
                        for syn in syn_parent.findall(_qname("synonym", namespace)):
                            if syn is not None and syn.text:
                                aliases.add(syn.text.strip().upper())

                    # If filtering by a provided list, skip if no intersection
                    if genes_filter and not (aliases & genes_filter):
                        continue

                    # Register alias -> canonical
                    for alias in aliases:
                        alias_to_canonical[alias] = canonical

                    # Associate drug
                    gene_to_drugs[canonical].add(drug_name)

            elem.clear()

    return gene_to_drugs, alias_to_canonical

def map_genes_to_drugs_with_sources(genes_df, gene_to_drugs, alias_to_canonical):
    """
    Input: DataFrame with columns Gene_Symbol and Sources.
    Returns mapping and flat rows.
    """
    mapping = {}
    longform_rows = []
    summary_rows = []

    for _, row in genes_df.iterrows():
        gene = str(row[GENE_COLUMN]).strip()
        sources = str(row[SOURCES_COLUMN]).strip()
        if not gene:
            continue
        upper = gene.upper()
        canonical = alias_to_canonical.get(upper, upper)
        drugs = sorted(gene_to_drugs.get(canonical, []))
        mapping[gene] = {"sources": sources, "drugs": drugs}

        # long form
        if drugs:
            for d in drugs:
                longform_rows.append({
                    "Gene_Symbol": gene,
                    "Gene_Sources": sources,
                    "Drug": d
                })
        else:
            longform_rows.append({
                "Gene_Symbol": gene,
                "Gene_Sources": sources,
                "Drug": ""
            })

        # summary row
        summary_rows.append({
            "Gene_Symbol": gene,
            "Gene_Sources": sources,
            "Drug_Count": len(drugs),
            "Drugs": ";".join(drugs)
        })

    return mapping, longform_rows, summary_rows

def main():
    # Load gene list with sources
    if not os.path.isfile(GENE_LIST_WITH_SOURCES):
        print(f"ERROR: gene list with sources not found at {GENE_LIST_WITH_SOURCES}")
        return
    if not os.path.isfile(DRUGBANK_XML):
        print(f"ERROR: DrugBank XML not found at {DRUGBANK_XML}")
        return

    genes_df = pd.read_csv(GENE_LIST_WITH_SOURCES, dtype=str)
    if GENE_COLUMN not in genes_df.columns or SOURCES_COLUMN not in genes_df.columns:
        print(f"ERROR: Expected columns '{GENE_COLUMN}' and '{SOURCES_COLUMN}' in {GENE_LIST_WITH_SOURCES}")
        return
    genes_df = genes_df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()

    union_genes = genes_df[GENE_COLUMN].dropna().astype(str).str.strip().unique().tolist()
    print(f"Loaded {len(union_genes)} genes with sources.")

    # Parse XML once
    print("Parsing DrugBank XML...")
    gene_to_drugs, alias_to_canonical = parse_drugbank_xml(DRUGBANK_XML, genes_of_interest=union_genes)
    print(f"Genes with at least one drug: {len(gene_to_drugs)}")

    # Map
    mapping, longform_rows, summary_rows = map_genes_to_drugs_with_sources(
        genes_df, gene_to_drugs, alias_to_canonical
    )

    # Save JSON
    json_path = os.path.join(OUTPUT_DIR, "drugbank_gene_drug_with_sources.json")
    with open(json_path, "w") as f:
        json.dump(mapping, f, indent=2)
    print(f"Saved JSON to {json_path}")

    # Save long-form CSV
    long_csv = os.path.join(OUTPUT_DIR, "drugbank_gene_drug_with_sources_longform.csv")
    pd.DataFrame(longform_rows).to_csv(long_csv, index=False)
    print(f"Saved long-form CSV to {long_csv}")

    # Save summary CSV
    summary_csv = os.path.join(OUTPUT_DIR, "drugbank_gene_drug_with_sources_summary.csv")
    pd.DataFrame(summary_rows).to_csv(summary_csv, index=False)
    print(f"Saved summary CSV to {summary_csv}")

    # Preview
    print("\nSample (first 10 genes):")
    count = 0
    for gene, info in mapping.items():
        print(f"  {gene} ({info['sources']}): {len(info['drugs'])} drug(s) -> {info['drugs'][:5]}")
        count += 1
        if count >= 10:
            break

if __name__ == "__main__":
    main()

Loaded 1725 genes with sources.
Parsing DrugBank XML...
Genes with at least one drug: 654
Saved JSON to /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/drugbank_results_with_sources/drugbank_gene_drug_with_sources.json
Saved long-form CSV to /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/drugbank_results_with_sources/drugbank_gene_drug_with_sources_longform.csv
Saved summary CSV to /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/drugbank_results_with_sources/drugbank_gene_drug_with_sources_summary.csv

Sample (first 10 genes):
  A2M (CT): 10 drug(s) -> ['Anacaulase', 'Bacitracin', 'Becaplermin', 'Cisplatin', 'Ocriplasmin']
  ABCA1 (CT): 3 drug(s) -> ['ATP', 'Glyburide', 'Probucol']
  ABCF3 (CT): 0 drug(s) -> []
  ABL1 (CT): 25 drug(s) -> ['1-[4-(PYRIDIN-4-YLOXY)PHENYL]-3-[3-(TRIFLUOROMETHYL)PHENYL]UREA', '2-amino-5-[3-(

### **🟩ChEMBL**

**Links:**
- Home: https://www.ebi.ac.uk/chembl/
- Download: https://ftp.ebi.ac.uk/pub/databases/chembl/ChEMBLdb/latest/chembl_35_postgresql.tar.gz


**Characteristics:**
- Open bioactivity database of small molecules with experimentally measured activities against protein targets. 
- Provides drug/compound-target associations with development-phase metadata (`max_phase`) so you can prioritize approved or late-stage compounds. 
- Useful for evidence-weighted mapping (binding/activity) and for discovering compounds with known potency.

**Linux Setup & Execution Steps:**

1. Prerequisites
- PostgreSQL installed (e.g., version 14).
- Python environment with `psycopg2-binary`, `pandas` installed.
- ChEMBL PostgreSQL dump already restored into the database (for example with `pg_restore` from the downloaded dump file).
- Union gene list with source labels present at:  
  `gene_lists/all_unique_genes_with_sources.csv`

2. Database setup

a. Create (or confirm) the database and set ownership

```bash
# If the database does not already exist:
sudo -u postgres createdb -O sindypin chembl

# If it already exists, change the owner to your user:
sudo -u postgres psql -c "ALTER DATABASE chembl OWNER TO sindypin;"
```

b. (Optionally) Reassign existing objects from `postgres` to `sindypin`

```bash
sudo -u postgres psql -d chembl -c "REASSIGN OWNED BY postgres TO sindypin;"
```

c. Grant read/access privileges (alternative to full ownership)

```bash
sudo -u postgres psql -d chembl -c "GRANT USAGE ON SCHEMA public TO sindypin;"
sudo -u postgres psql -d chembl -c "GRANT SELECT ON ALL TABLES IN SCHEMA public TO sindypin;"
sudo -u postgres psql -d chembl -c "ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO sindypin;"
```

3. Verify database contents and permissions

```bash
# As the postgres superuser, check the number of molecules:
sudo -u postgres psql -d chembl -c "SELECT count(*) FROM molecule_dictionary;"

# As the target user (sindypin), verify you can query (permissions must be set):
psql -U sindypin -d chembl -c "SELECT count(*) FROM molecule_dictionary;"

# List available tables to confirm schema restored:
sudo -u postgres psql -d chembl -c "\dt"
```

4. Schema adjustment discovery

* The restored schema **does not contain** an `assay2target` table, so the original SQL join had to be revised.
* Targets are linked to assays through the `assays.tid` column instead.

5. Python script updates

a. Connection configuration

Configured PostgreSQL connection for peer auth (running as `sindypin`):

```python
PG_CONN_INFO = {
    "dbname": "chembl",
    "user": "sindypin",
    # "password": "<if using password auth, set here>"
    # omit "host" if relying on socket/peer authentication
}
```

b. Robust argument parsing for Jupyter

Used `parse_known_args()` so the script can be run inside Jupyter/VSCode without failing on kernel-added arguments:

```python
args, _unknown = parser.parse_known_args()
```
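
To illustrate why this helps, here is a small self-contained sketch. The `--refresh-cache` flag comes from the script's usage notes; `build_parser` and the simulated kernel arguments are assumptions made for the example.

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description="Map genes to ChEMBL drugs")
    parser.add_argument(
        "--refresh-cache",
        action="store_true",
        help="ignore the cache and re-query every gene",
    )
    return parser

# Simulated argv containing a Jupyter-style kernel argument the script does not define
simulated_argv = ["-f", "/tmp/kernel-123.json", "--refresh-cache"]

# parse_args() would exit with "unrecognized arguments"; parse_known_args()
# returns the leftovers separately instead of failing
args, unknown = build_parser().parse_known_args(simulated_argv)
print(args.refresh_cache)
print(unknown)
```

The recognized flag is parsed as usual, and anything the parser does not know ends up in the second return value, which the script can simply ignore.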

c. Updated gene→drug SQL logic

Replaced the original (broken) query with one that:

* Matches gene symbols to targets via:

  * Direct `target_dictionary.pref_name` equality.
  * Component synonyms → `target_components` → `target_dictionary`.
* Links targets to assays using `assays.tid` (since `assay2target` was missing).
* Traverses `assays` → `activities` → `molecule_dictionary`.
* Filters for approved drugs (`molecule_dictionary.max_phase >= 4`).

d. Added features

* **Per-gene caching** to avoid re-querying already-resolved genes (`chembl_local_gene_drug_cache.json`).
* **Logging** to both console and file (`chembl_mapper.log`), with INFO summary and DEBUG details.
* **Explicit printed summary & preview** for interactive consumption (first few rows, counts).
* Graceful handling of transaction errors with `conn.rollback()` so batch processing continues.
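
The per-gene caching idea can be sketched as follows. This is an illustration under stated assumptions, not the script's actual implementation: `load_cache`, `resolve_with_cache`, and `fake_query` are hypothetical names, the cache is assumed to be a JSON object mapping gene symbols to drug lists, and the gene and drug names are placeholders.

```python
import json
import os
import tempfile

def load_cache(path):
    """Return the cached gene -> drug-list mapping, or an empty dict if absent."""
    if os.path.isfile(path):
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return {}

def resolve_with_cache(genes, cache_path, query_fn, refresh_cache=False):
    """Query only genes missing from the cache, then persist the merged result."""
    cache = {} if refresh_cache else load_cache(cache_path)
    pending = [g for g in genes if g not in cache]
    if pending:
        fresh = query_fn(pending)
        for g in pending:
            # Store empty results too, so unmapped genes are not re-queried
            cache[g] = sorted(fresh.get(g, []))
        with open(cache_path, "w", encoding="utf-8") as fh:
            json.dump(cache, fh, indent=2)
    return {g: cache[g] for g in genes}

# Stand-in for the database query; records which batches were requested
calls = []
def fake_query(batch):
    calls.append(list(batch))
    return {"GENE_A": {"DRUG_X"}}

with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "cache.json")
    first = resolve_with_cache(["GENE_A", "GENE_B"], cache_path, fake_query)
    second = resolve_with_cache(["GENE_A", "GENE_C"], cache_path, fake_query)

print(calls)  # the second run only queries the gene that was not cached yet
```

Caching empty results is a design choice: it saves repeated queries for genes with no hits, but it also means the cache should be refreshed whenever the underlying database or query logic changes.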

6. Run the script

From shell or within Jupyter/VSCode:

```bash
# Standard run (uses existing cache if present)
python map_genes_to_chembl_local.py

# Force re-query all genes (ignore cache)
python map_genes_to_chembl_local.py --refresh-cache
```

If inside a notebook, you can call:

```python
from map_genes_to_chembl_local import main
main(refresh_cache=False)
```

7. Output

* JSON mapping: `chembl_local_results/chembl_local_gene_drug_with_sources.json`
* Long-form CSV: `chembl_local_results/chembl_local_gene_drug_with_sources.csv`
* Cache file: `chembl_local_results/chembl_local_gene_drug_cache.json`
* Log: `chembl_local_results/chembl_mapper.log`

8. Example sanity checks after run

```bash
# View a few lines of the CSV
head -n 10 chembl_local_results/chembl_local_gene_drug_with_sources.csv

# Count genes with at least one mapped drug (in Python or use awk/cut)
```
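
For the count mentioned in the last comment, a minimal Python sketch is shown here. It assumes the CSV has `Gene_Symbol` and `Drug` columns with an empty `Drug` cell for unmapped genes, mirroring the DrugBank long-form output; the ChEMBL CSV layout is not shown in this section, so adjust the column names as needed. `count_genes_with_drugs` is a hypothetical helper and the sample rows are placeholders.

```python
import csv
import io

def count_genes_with_drugs(rows, gene_col="Gene_Symbol", drug_col="Drug"):
    """Count distinct genes that have at least one non-empty drug entry."""
    genes = {row[gene_col] for row in rows if (row.get(drug_col) or "").strip()}
    return len(genes)

# Placeholder data standing in for the real CSV
sample_csv = io.StringIO(
    "Gene_Symbol,Gene_Sources,Drug\n"
    "GENE_A,CT,DRUG_X\n"
    "GENE_A,CT,DRUG_Y\n"
    "GENE_B,CT,\n"
)
n_mapped = count_genes_with_drugs(csv.DictReader(sample_csv))
print(n_mapped)
```

For the real file, open it with `open(path, newline="")` and pass `csv.DictReader(fh)` to the same function.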

Notes

* If new ChEMBL versions are restored or schema differs, revisit the SQL CTEs to ensure table/column names match.
* The caching mechanism can be invalidated with `--refresh-cache` when schema/data changes.

In [None]:
#!/usr/bin/env python3
"""
Map gene symbols to approved ChEMBL drugs (max_phase >= 4) using a local PostgreSQL ChEMBL dump.
Features:
  - Per-gene caching to avoid redundant queries.
  - Logging to console and file.
  - Batch SQL to resolve multiple genes at once.
  - Output: JSON + CSV, with summary, preview, and explicit printed output for notebooks/VSCode.

Usage (CLI):
  python map_genes_to_chembl_local.py [--refresh-cache]

In Jupyter/VSCode you can import and call main(refresh_cache=False) directly or let the
argument parser ignore the kernel's extra args by using parse_known_args().
"""
import os
import json
import argparse
import logging
import time
from collections import defaultdict

import pandas as pd
import psycopg2

# --- CONFIGURATION ---
UNION_WITH_SOURCES = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

# PostgreSQL connection (adjust if using password authentication; peer auth if running as sindypin)
PG_CONN_INFO = {
    "dbname": "chembl",
    "user": "sindypin",
    # "password": "your_password_if_set",  # uncomment if using password-based auth
    # omit "host" to use unix socket / peer auth; include "host": "localhost" if you want TCP
}

# Output paths
OUTPUT_DIR = os.path.join(os.path.dirname(UNION_WITH_SOURCES), "chembl_local_results")
CACHE_FILE = os.path.join(OUTPUT_DIR, "chembl_local_gene_drug_cache.json")
JSON_OUT = os.path.join(OUTPUT_DIR, "chembl_local_gene_drug_with_sources.json")
CSV_OUT = os.path.join(OUTPUT_DIR, "chembl_local_gene_drug_with_sources.csv")
LOG_FILE = os.path.join(OUTPUT_DIR, "chembl_mapper.log")

os.makedirs(OUTPUT_DIR, exist_ok=True)


def setup_logger():
    logger = logging.getLogger("chembl_mapper")
    logger.setLevel(logging.DEBUG)
    fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
    # console handler: FileHandler subclasses StreamHandler, so compare the exact
    # type to detect an existing console handler when the cell is re-run
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(fmt)
        logger.addHandler(ch)
    # file handler: create it only if this log file is not attached yet, so
    # re-runs neither duplicate log lines nor leave extra open file handles
    log_path = os.path.abspath(LOG_FILE)
    if not any(isinstance(h, logging.FileHandler) and h.baseFilename == log_path for h in logger.handlers):
        fh = logging.FileHandler(LOG_FILE)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(fmt)
        logger.addHandler(fh)
    return logger


def load_genes_with_sources():
    if not os.path.isfile(UNION_WITH_SOURCES):
        raise FileNotFoundError(f"Union gene list not found at {UNION_WITH_SOURCES}")
    df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
    if GENE_COLUMN not in df.columns or SOURCES_COLUMN not in df.columns:
        raise ValueError(f"Expected columns '{GENE_COLUMN}' and '{SOURCES_COLUMN}' in {UNION_WITH_SOURCES}")
    df[GENE_COLUMN] = df[GENE_COLUMN].astype(str).str.strip()
    df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("")
    return df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()


def query_gene_drug_mapping(gene_list_upper, conn, logger):
    """
    Bulk query mapping uppercase gene symbols to approved drug names (max_phase >= 4).
    Resolves gene -> target via direct pref_name or component_synonym, then
    uses assays.tid to link target to assays, and proceeds to activities -> drugs.
    """
    if not gene_list_upper:
        return {}

    # Validate required tables (note: no assay2target here; use assays.tid)
    with conn.cursor() as check_cur:
        check_cur.execute("""
            SELECT tablename
            FROM pg_catalog.pg_tables
            WHERE schemaname='public' AND tablename IN (
                'target_dictionary', 'component_synonyms', 'target_components',
                'assays', 'activities', 'molecule_dictionary'
            );
        """)
        existing = {r[0] for r in check_cur.fetchall()}
    required = {
        "target_dictionary",
        "component_synonyms",
        "target_components",
        "assays",
        "activities",
        "molecule_dictionary",
    }
    missing = required - existing
    if missing:
        logger.error("Schema mismatch: missing required table(s): %s", ", ".join(sorted(missing)))
        raise RuntimeError(f"Missing required ChEMBL table(s): {', '.join(sorted(missing))}")

    sql = """
    WITH input_genes AS (
        SELECT unnest(%s::text[]) AS gene_symbol
    ),
    matched_targets AS (
        -- direct pref_name match
        SELECT DISTINCT ig.gene_symbol, t.tid
        FROM input_genes ig
        JOIN target_dictionary t ON lower(t.pref_name) = lower(ig.gene_symbol)
        UNION
        -- via component synonym
        SELECT DISTINCT ig.gene_symbol, t.tid
        FROM input_genes ig
        JOIN component_synonyms cs 
          ON lower(cs.component_synonym) = lower(ig.gene_symbol)
        JOIN target_components tc ON cs.component_id = tc.component_id
        JOIN target_dictionary t ON tc.tid = t.tid
    ),
    gene_drug_pairs AS (
        SELECT
            mt.gene_symbol,
            m.pref_name AS drug_name
        FROM matched_targets mt
        JOIN assays a ON mt.tid = a.tid
        JOIN activities act ON a.assay_id = act.assay_id
        JOIN molecule_dictionary m ON act.molregno = m.molregno
        WHERE m.max_phase >= 4
    )
    SELECT gene_symbol, drug_name
    FROM gene_drug_pairs
    ORDER BY gene_symbol, drug_name;
    """

    mapping = defaultdict(set)
    try:
        with conn.cursor() as cur:
            cur.execute(sql, (gene_list_upper,))
            rows = cur.fetchall()
        for gene_sym, drug in rows:
            if drug:
                mapping[gene_sym.upper()].add(drug)
        logger.debug(
            "Queried %d genes; received %d gene-drug associations.",
            len(gene_list_upper),
            sum(len(v) for v in mapping.values()),
        )
    except Exception as e:
        conn.rollback()  # clear failed transaction so subsequent batches can proceed
        logger.exception("Error querying gene batch: %s", e)
    return mapping

def main(refresh_cache=False):
    logger = setup_logger()
    start_time = time.time()
    logger.info("Starting ChEMBL local gene->drug mapping. Refresh cache: %s", refresh_cache)

    # Load gene list
    try:
        genes_df = load_genes_with_sources()
    except Exception as e:
        logger.exception("Failed to load gene list: %s", e)
        print(f"ERROR: Failed to load gene list: {e}")
        return

    gene_symbols = genes_df[GENE_COLUMN].dropna().astype(str).str.strip().unique().tolist()
    gene_symbols_upper = [g.upper() for g in gene_symbols if g]
    logger.info(f"Loaded {len(gene_symbols_upper)} unique genes (uppercased).")

    # Load or initialize cache
    if os.path.isfile(CACHE_FILE) and not refresh_cache:
        try:
            with open(CACHE_FILE, "r") as f:
                cache = json.load(f)
            cached = {k.upper(): set(v) for k, v in cache.items()}
            logger.info(f"Loaded existing cache with {len(cached)} genes.")
        except Exception:
            logger.warning("Could not load existing cache, starting fresh.")
            cached = {}
    else:
        cached = {}

    to_query = [g for g in gene_symbols_upper if g not in cached]
    logger.info(f"{len(to_query)} gene(s) will be queried from DB (skipping {len(gene_symbols_upper)-len(to_query)} via cache).")

    # Connect to PostgreSQL
    try:
        conn = psycopg2.connect(**PG_CONN_INFO)
    except Exception as e:
        logger.exception("Failed to connect to PostgreSQL: %s", e)
        print(f"ERROR: Could not connect to PostgreSQL: {e}")
        return

    try:
        new_results = {}
        batch_size = 100
        for i in range(0, len(to_query), batch_size):
            batch = to_query[i : i + batch_size]
            logger.info(f"Querying batch {i}-{i+len(batch)-1} ({len(batch)} genes)...")
            batch_map = query_gene_drug_mapping(batch, conn, logger)
            for gene_up, drugs in batch_map.items():
                new_results[gene_up] = drugs
        # Merge into cache
        for gene_up, drugs in new_results.items():
            cached.setdefault(gene_up, set()).update(drugs)
    finally:
        conn.close()

    # Persist updated cache
    serializable_cache = {k: sorted(list(v)) for k, v in cached.items()}
    try:
        with open(CACHE_FILE + ".tmp", "w") as f:
            json.dump(serializable_cache, f, indent=2)
        os.replace(CACHE_FILE + ".tmp", CACHE_FILE)
        logger.info(f"Cache saved to {CACHE_FILE} ({len(cached)} genes).")
    except Exception as e:
        logger.exception("Failed to save cache: %s", e)

    # Build final mapping + long-form rows
    final_mapping = {}
    longform_rows = []
    for _, row in genes_df.iterrows():
        gene_orig = row[GENE_COLUMN]
        sources = row[SOURCES_COLUMN]
        gene_up = gene_orig.upper()
        drugs = sorted(cached.get(gene_up, []))
        final_mapping[gene_orig] = {"sources": sources, "drugs": drugs}
        if drugs:
            for d in drugs:
                longform_rows.append({
                    "Gene_Symbol": gene_orig,
                    "Gene_Sources": sources,
                    "Drug": d
                })
        else:
            longform_rows.append({
                "Gene_Symbol": gene_orig,
                "Gene_Sources": sources,
                "Drug": ""
            })

    # Write outputs
    try:
        with open(JSON_OUT, "w") as f:
            json.dump(final_mapping, f, indent=2)
        pd.DataFrame(longform_rows).to_csv(CSV_OUT, index=False)
    except Exception as e:
        logger.exception("Failed saving output files: %s", e)

    # Summary calculations
    total_genes = len(final_mapping)
    genes_with_drugs = sum(1 for v in final_mapping.values() if v["drugs"])
    non_empty_pairs = sum(1 for r in longform_rows if r["Drug"])
    unique_drugs = len({r["Drug"] for r in longform_rows if r["Drug"]})

    # Log summary
    logger.info("=== Summary ===")
    logger.info(f"Total genes processed: {total_genes}")
    logger.info(f"Genes with ≥1 drug: {genes_with_drugs}")
    logger.info(f"Total gene–drug pairs (non-empty): {non_empty_pairs}")
    logger.info(f"Unique drugs across all genes: {unique_drugs}")
    preview_df = pd.DataFrame(longform_rows).head(10)
    logger.info("First few rows:\n%s", preview_df.to_string(index=False))

    # Explicit print (good for notebooks)
    print("\n=== Summary ===")
    print(f"Total genes processed: {total_genes}")
    print(f"Genes with ≥1 drug: {genes_with_drugs}")
    print(f"Total gene–drug pairs (non-empty): {non_empty_pairs}")
    print(f"Unique drugs across all genes: {unique_drugs}")
    print("\nFirst few rows:")
    print(preview_df.to_string(index=False))
    print(f"\nResults saved to:\n  JSON: {JSON_OUT}\n  CSV:  {CSV_OUT}")

    elapsed = time.time() - start_time
    logger.info(f"Finished in {elapsed:.1f}s. Outputs: JSON={JSON_OUT}, CSV={CSV_OUT}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Local ChEMBL gene→drug mapper with cache and logging.")
    parser.add_argument("--refresh-cache", action="store_true", help="Ignore existing cache and re-query all genes.")
    # use parse_known_args so Jupyter/kernel extra args (like --f=...) don't error out
    args, _unknown = parser.parse_known_args()
    main(refresh_cache=args.refresh_cache)

2025-08-05 04:43:33,937 [INFO] Starting ChEMBL local gene->drug mapping. Refresh cache: False


2025-08-05 04:43:33,954 [INFO] Loaded 1725 unique genes (uppercased).
2025-08-05 04:43:33,960 [INFO] Loaded existing cache with 0 genes.
2025-08-05 04:43:33,962 [INFO] 1725 gene(s) will be queried from DB (skipping 0 via cache).
2025-08-05 04:43:33,967 [INFO] Querying batch 0-99 (100 genes)...
2025-08-05 04:43:58,548 [INFO] Querying batch 100-199 (100 genes)...
2025-08-05 04:44:03,922 [INFO] Querying batch 200-299 (100 genes)...
2025-08-05 04:44:08,458 [INFO] Querying batch 300-399 (100 genes)...
2025-08-05 04:44:12,803 [INFO] Querying batch 400-499 (100 genes)...
2025-08-05 04:44:17,024 [INFO] Querying batch 500-599 (100 genes)...
2025-08-05 04:44:21,182 [INFO] Querying batch 600-699 (100 genes)...
2025-08-05 04:44:26,005 [INFO] Querying batch 700-799 (100 genes)...
2025-08-05 04:44:31,031 [INFO] Querying batch 800-899 (100 genes)...
2025-08-05 04:44:35,278 [INFO] Querying batch 900-999 (100 genes)...
2025-08-05 04:44:41,827 [INFO] Querying batch 1000-1099 (100 genes)...
2025-08-05 04


=== Summary ===
Total genes processed: 1725
Genes with ≥1 drug: 549
Total gene–drug pairs (non-empty): 80702
Unique drugs across all genes: 2626

First few rows:
Gene_Symbol Gene_Sources          Drug
        A2M           CT              
      ABCA1           CT              
      ABCF3           CT              
       ABL1           CT   ABEMACICLIB
       ABL1           CT   ABROCITINIB
       ABL1           CT ACALABRUTINIB
       ABL1           CT     ADAGRASIB
       ABL1           CT     ADAPALENE
       ABL1           CT     ADENOSINE
       ABL1           CT      AFATINIB

Results saved to:
  JSON: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/chembl_local_results/chembl_local_gene_drug_with_sources.json
  CSV:  /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/chembl_local_results/chembl_local_gene_drug_with_sources.csv
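
The query above keeps every `max_phase >= 4` molecule that has any activity row in an assay against a matched target, with no potency condition. That may be why compounds such as ADAPALENE or ADENOSINE appear for ABL1 in the preview. Below is a minimal sketch of a stricter variant, assuming the local ChEMBL build populates `activities.pchembl_value`. The helper name `build_potency_filtered_query` and the 6.0 cutoff (roughly 1 µM) are illustrative choices, not part of the original script.

```python
def build_potency_filtered_query(gene_list_upper, min_pchembl=6.0):
    """Return (sql, params) for a gene -> approved-drug query with a potency floor.

    Hypothetical helper: same joins as query_gene_drug_mapping, plus a
    pchembl_value condition. The 6.0 default is an arbitrary illustration.
    """
    sql = """
    WITH input_genes AS (
        SELECT unnest(%s::text[]) AS gene_symbol
    ),
    matched_targets AS (
        SELECT DISTINCT ig.gene_symbol, t.tid
        FROM input_genes ig
        JOIN target_dictionary t ON lower(t.pref_name) = lower(ig.gene_symbol)
        UNION
        SELECT DISTINCT ig.gene_symbol, tc.tid
        FROM input_genes ig
        JOIN component_synonyms cs
          ON lower(cs.component_synonym) = lower(ig.gene_symbol)
        JOIN target_components tc ON cs.component_id = tc.component_id
    )
    SELECT DISTINCT mt.gene_symbol, m.pref_name AS drug_name
    FROM matched_targets mt
    JOIN assays a ON mt.tid = a.tid
    JOIN activities act ON a.assay_id = act.assay_id
    JOIN molecule_dictionary m ON act.molregno = m.molregno
    WHERE m.max_phase >= 4
      AND act.pchembl_value >= %s   -- assumed potency floor, tune as needed
    ORDER BY 1, 2;
    """
    # Two placeholders: the gene array and the potency threshold.
    return sql, (list(gene_list_upper), float(min_pchembl))


sql, params = build_potency_filtered_query(["ABL1", "ABL2"], min_pchembl=6.0)
print(params)
```

The returned pair can be passed to `cur.execute(sql, params)` in place of the original statement. Rows without a `pchembl_value` are dropped by the comparison, so the result is a subset of what the original query returns.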


### **🟨DGIdb**

**Links:**
- Home: https://dgidb.org/
- Download: https://dgidb.org/downloads
  
**Characteristics:**
- Aggregator that harmonizes drug–gene interaction information from multiple underlying sources (e.g., DrugBank, PharmGKB, ChEMBL) and annotates druggability and interaction types. 
- Good first-pass for broad coverage and for consolidating disparate evidence into unified gene→drug interaction sets.

In [None]:
#!/usr/bin/env python3
"""
DGIdb gene->drug mapping using the downloaded TSV (interactions.tsv),
preserving gene source labels. Avoids live GraphQL by relying on the
bulk download from https://www.dgidb.org/downloads.

Outputs:
  - JSON per-gene mapping: sources + drugs
  - Long-form CSV: one row per gene-drug pair with interaction types / claim sources
  - Summary printed to console including first few rows
"""
import os
import json
import pandas as pd

# === CONFIGURATION ===
BASE_DIR = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes"
GENE_WITH_SOURCES_CSV = os.path.join(BASE_DIR, "gene_lists", "all_unique_genes_with_sources.csv")
UNION_SIMPLE = os.path.join(BASE_DIR, "gene_lists", "all_unique_genes.csv")
INTERACTIONS_TSV = os.path.join(BASE_DIR, "Full_DataBases", "full_database_DGIdb.tsv")  # <-- adjust to your downloaded file
OUTPUT_DIR = os.path.join(BASE_DIR, "gene_lists", "dgidb_results_local")
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Candidate column names in DGIdb interactions file
POSSIBLE_GENE_COLS = [
    "gene_name", "gene_symbol", "entrez_gene_symbol", "gene", "gene_id", "approved_symbol"
]
POSSIBLE_DRUG_COLS = [
    "drug_name", "drug_primary_name", "drug", "drug_concept_id", "approved_drug_name"
]
POSSIBLE_INTERACTION_TYPE_COLS = [
    "interaction_types", "interaction_type"
]
POSSIBLE_CLAIM_SOURCE_COLS = [
    "interaction_claim_source", "interaction_claim_sources", "source", "sources"
]


def load_genes_with_sources():
    if os.path.isfile(GENE_WITH_SOURCES_CSV):
        df = pd.read_csv(GENE_WITH_SOURCES_CSV, dtype=str)
        if GENE_COLUMN not in df.columns or SOURCES_COLUMN not in df.columns:
            raise ValueError(f"{GENE_WITH_SOURCES_CSV} must have columns '{GENE_COLUMN}' and '{SOURCES_COLUMN}'")
        df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
        df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("")
    elif os.path.isfile(UNION_SIMPLE):
        df = pd.read_csv(UNION_SIMPLE, dtype=str)
        if GENE_COLUMN not in df.columns:
            raise ValueError(f"{UNION_SIMPLE} missing column '{GENE_COLUMN}'")
        df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
        df[SOURCES_COLUMN] = ""
    else:
        raise FileNotFoundError("Union gene list not found (with or without sources)")
    return df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()


def pick_column(df_cols, candidates):
    for c in candidates:
        if c in df_cols:
            return c
    return None


def main():
    # --- load gene list ---
    genes_df = load_genes_with_sources()
    union_genes_up = {g.upper() for g in genes_df[GENE_COLUMN].dropna().astype(str).str.strip().unique().tolist()}

    # --- load interactions TSV ---
    if not os.path.isfile(INTERACTIONS_TSV):
        print(f"ERROR: interactions.tsv not found at '{INTERACTIONS_TSV}'.")
        print("Download it from https://www.dgidb.org/downloads (latest interactions.tsv).")
        return

    inter = pd.read_csv(INTERACTIONS_TSV, sep="\t", dtype=str, low_memory=False)

    # autodetect columns
    gene_col = pick_column(inter.columns, POSSIBLE_GENE_COLS)
    drug_col = pick_column(inter.columns, POSSIBLE_DRUG_COLS)
    interaction_type_col = pick_column(inter.columns, POSSIBLE_INTERACTION_TYPE_COLS)
    claim_source_col = pick_column(inter.columns, POSSIBLE_CLAIM_SOURCE_COLS)

    if gene_col is None or drug_col is None:
        print("ERROR: Could not auto-detect gene or drug column in interactions.tsv.")
        print("Available columns:", list(inter.columns)[:50])
        return

    # Normalize
    inter["gene_up"] = inter[gene_col].fillna("").astype(str).str.strip().str.upper()
    inter["drug_clean"] = inter[drug_col].fillna("").astype(str).str.strip()
    if interaction_type_col:
        inter["interaction_types_norm"] = inter[interaction_type_col].fillna("").astype(str).str.strip()
    else:
        inter["interaction_types_norm"] = ""
    if claim_source_col:
        inter["claim_sources_norm"] = inter[claim_source_col].fillna("").astype(str).str.strip()
    else:
        inter["claim_sources_norm"] = ""

    # Filter to genes of interest
    filtered = inter[inter["gene_up"].isin(union_genes_up)].copy()
    if filtered.empty:
        print("WARNING: No matching interactions found for your gene list in interactions.tsv.")
    # Build mapping: gene -> drug -> aggregates
    gene_drug_info = {}  # gene_upper -> {drug: {interaction_types:set, claim_sources:set}}

    # Build uppercase to original gene symbol & sources
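    # Rows with a missing gene symbol are skipped so .upper() is never called on NaN.
    gene_meta = {
        row[GENE_COLUMN].upper(): (row[GENE_COLUMN], row[SOURCES_COLUMN])
        for _, row in genes_df.dropna(subset=[GENE_COLUMN]).iterrows()
    }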

    for _, row in filtered.iterrows():
        gene_up = row["gene_up"]
        drug = row["drug_clean"]
        if not gene_up or not drug:
            continue
        itypes = set([x.strip() for x in row["interaction_types_norm"].split(";") if x.strip()]) if row["interaction_types_norm"] else set()
        csources = set([x.strip() for x in row["claim_sources_norm"].split(";") if x.strip()]) if row["claim_sources_norm"] else set()

        gene_entry = gene_drug_info.setdefault(gene_up, {})
        drug_entry = gene_entry.setdefault(drug, {"interaction_types": set(), "claim_sources": set()})
        drug_entry["interaction_types"].update(itypes)
        drug_entry["claim_sources"].update(csources)

    # Prepare outputs
    nested_mapping = {}  # original gene symbol -> {sources, drugs: [names]}
    longform_rows = []
    for gene_up, (orig_gene, sources) in gene_meta.items():
        drugs_dict = gene_drug_info.get(gene_up, {})
        drugs_list = sorted(drugs_dict.keys())
        nested_mapping[orig_gene] = {"sources": sources, "drugs": drugs_list}

        if drugs_list:
            for drug in drugs_list:
                info = drugs_dict.get(drug, {})
                interaction_types = ";".join(sorted(info.get("interaction_types", [])))
                claim_sources = ";".join(sorted(info.get("claim_sources", [])))
                longform_rows.append({
                    "Gene_Symbol": orig_gene,
                    "Gene_Sources": sources,
                    "Drug": drug,
                    "Interaction_Types": interaction_types,
                    "Claim_Sources": claim_sources
                })
        else:
            longform_rows.append({
                "Gene_Symbol": orig_gene,
                "Gene_Sources": sources,
                "Drug": "",
                "Interaction_Types": "",
                "Claim_Sources": ""
            })

    # Save JSON
    json_path = os.path.join(OUTPUT_DIR, "dgidb_gene_drug_with_sources.json")
    with open(json_path, "w") as f:
        json.dump(nested_mapping, f, indent=2)

    # Save long-form CSV
    df_long = pd.DataFrame(longform_rows)
    csv_path = os.path.join(OUTPUT_DIR, "dgidb_gene_drug_with_sources_longform.csv")
    df_long.to_csv(csv_path, index=False)

    # Summary
    total_genes = len(nested_mapping)
    genes_with_drugs = sum(1 for v in nested_mapping.values() if v["drugs"])
    non_empty_pairs = df_long[df_long["Drug"].astype(str).str.strip() != ""].shape[0]
    unique_drugs = df_long.loc[df_long["Drug"].astype(str).str.strip() != "", "Drug"].nunique()

    print("\n=== Summary ===")
    print(f"Total genes processed: {total_genes}")
    print(f"Genes with ≥1 drug: {genes_with_drugs}")
    print(f"Total non-empty gene–drug pairs: {non_empty_pairs}")
    print(f"Unique drugs across all genes: {unique_drugs}")

    print("\nFirst few rows of the long-form result:")
    if not df_long.empty:
        print(df_long.head(10).to_string(index=False))
    else:
        print("  <no rows>")

    print(f"\nDGIdb (local TSV) results saved to:\n  JSON: {json_path}\n  CSV:  {csv_path}")

if __name__ == "__main__":
    main()


=== Summary ===
Total genes processed: 1725
Genes with ≥1 drug: 936
Total non-empty gene–drug pairs: 20984
Unique drugs across all genes: 9288

First few rows of the long-form result:
Gene_Symbol Gene_Sources                    Drug Interaction_Types Claim_Sources
        A2M           CT          (+)-WAY 100135         inhibitor              
        A2M           CT      (R)-FLUROCARAZOLOL         inhibitor              
        A2M           CT      (S)-FLUROCARAZOLOL         inhibitor              
        A2M           CT    1-NAPHTHYLPIPERAZINE           agonist              
        A2M           CT           2-METHYL-5-HT           agonist              
        A2M           CT 5-(NONYLOXY)-TRYPTAMINE           agonist              
        A2M           CT                    5-CT           agonist              
        A2M           CT           5-HT-MODULINE         modulator              
        A2M           CT     5-HYDROXYTRYPTAMINE           agonist              
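
The `Claim_Sources` column is blank in every previewed row. Going by the script, that happens either when none of the names in `POSSIBLE_CLAIM_SOURCE_COLS` exists in the TSV or when the detected column is itself empty. Here is a small sketch for reporting what the auto-detection matched. It assumes nothing about the real DGIdb header: the header list is made up for the demonstration, and `report_detected_columns` is a hypothetical helper.

```python
def report_detected_columns(columns, candidates_by_role):
    """Map each role to the first candidate present in `columns`, or None."""
    available = list(columns)
    return {
        role: next((c for c in candidates if c in available), None)
        for role, candidates in candidates_by_role.items()
    }


# Made-up header for illustration only; inspect inter.columns for the real one.
demo_header = ["gene_name", "interaction_type", "drug_name", "some_provenance_field"]
candidates = {
    "gene": ["gene_name", "gene_symbol"],
    "drug": ["drug_name", "drug"],
    "interaction_type": ["interaction_types", "interaction_type"],
    "claim_source": ["interaction_claim_source", "source", "sources"],
}
detected = report_detected_columns(demo_header, candidates)
for role, col in detected.items():
    print(f"{role}: {col if col is not None else '<not found>'}")
```

Printing the detected columns before filtering makes a silent fallback to empty strings visible, so the candidate lists can be extended with whatever name the downloaded file actually uses.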
     

### **🟩DrugCentral**

**Links:**
- Home: https://drugcentral.org/
- Download: https://drugcentral.org/download
  
**Characteristics:**
- Integrated compendium of approved and investigational drugs with annotated mechanisms of action and their gene/protein targets.
- Combines regulatory approval status with pharmacological action and chemical structure, making it a reliable source for current approved-drug target mappings.

In [None]:
#!/usr/bin/env python3
"""
DrugCentral mapping using downloaded target–drug file, preserving gene list sources.
User must have downloaded the DrugCentral TSV (e.g., full_database_DrugCentral.tsv)
that contains at least GENE and DRUG_NAME columns.

Prints summary (counts) and shows first few rows of the long-form result.
"""
import os
import json
import pandas as pd
from tqdm import tqdm

# CONFIG - adjust the path to your downloaded DrugCentral file
UNION_WITH_SOURCES = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
DRUGCENTRAL_TSV = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/Full_DataBases/full_database_DrugCentral.tsv"
OUTPUT_DIR = os.path.join(os.path.dirname(UNION_WITH_SOURCES), "drugcentral_results")
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

# Match the actual headers you inspected
TARGET_SYMBOL_COL = "GENE"          # primary gene symbol column
FALLBACK_TARGET_COL = "TARGET_NAME"  # fallback if GENE has no match
DRUG_NAME_COL = "DRUG_NAME"          # drug name column

os.makedirs(OUTPUT_DIR, exist_ok=True)


def load_genes_with_sources():
    df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
    if GENE_COLUMN not in df.columns:
        raise ValueError(f"Expected column '{GENE_COLUMN}' in {UNION_WITH_SOURCES}")
    df[GENE_COLUMN] = df[GENE_COLUMN].astype(str).str.strip()
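    # df.get(..., "") would return a plain str when the column is absent, and
    # str has no .fillna(); handle the missing-column case explicitly instead.
    if SOURCES_COLUMN in df.columns:
        df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("")
    else:
        df[SOURCES_COLUMN] = ""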
    return df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()


def main():
    genes_df = load_genes_with_sources()
    union_genes = set(genes_df[GENE_COLUMN].str.upper())

    if not os.path.isfile(DRUGCENTRAL_TSV):
        print(f"DrugCentral file missing at {DRUGCENTRAL_TSV}")
        return

    dc = pd.read_csv(DRUGCENTRAL_TSV, sep="\t", dtype=str, low_memory=False)

    # Validate expected columns
    missing = [c for c in (TARGET_SYMBOL_COL, DRUG_NAME_COL) if c not in dc.columns]
    if missing:
        print(f"Adjust TARGET_SYMBOL_COL / DRUG_NAME_COL: missing columns {missing}. Available columns:", list(dc.columns)[:30])
        return

    # Prepare normalized matching fields
    dc["gene_primary_up"] = dc[TARGET_SYMBOL_COL].fillna("").astype(str).str.strip().str.upper()
    dc["gene_fallback_up"] = dc[FALLBACK_TARGET_COL].fillna("").astype(str).str.strip().str.upper() if FALLBACK_TARGET_COL in dc.columns else ""
    dc["drug_clean"] = dc[DRUG_NAME_COL].fillna("").astype(str).str.strip()

    mapping = {}
    rows = []

    for _, row in tqdm(genes_df.iterrows(), total=len(genes_df), desc="DrugCentral genes"):
        gene = row[GENE_COLUMN]
        sources = row[SOURCES_COLUMN]
        gene_up = gene.upper()

        # Primary match on GENE column
        subset = dc[dc["gene_primary_up"] == gene_up]
        # Fallback: if none, try TARGET_NAME
        if subset.empty and FALLBACK_TARGET_COL in dc.columns:
            subset = dc[dc["gene_fallback_up"] == gene_up]

        drugs = sorted(set(subset["drug_clean"].dropna()))
        mapping[gene] = {"sources": sources, "drugs": drugs}
        if drugs:
            for d in drugs:
                rows.append({"Gene_Symbol": gene, "Gene_Sources": sources, "Drug": d})
        else:
            rows.append({"Gene_Symbol": gene, "Gene_Sources": sources, "Drug": ""})

    # Save outputs
    json_path = os.path.join(OUTPUT_DIR, "drugcentral_gene_drug_with_sources.json")
    csv_path = os.path.join(OUTPUT_DIR, "drugcentral_gene_drug_with_sources.csv")
    with open(json_path, "w") as f:
        json.dump(mapping, f, indent=2)
    df_rows = pd.DataFrame(rows)
    df_rows.to_csv(csv_path, index=False)

    # Summary / preview
    total_genes = len(mapping)
    genes_with_drugs = sum(1 for v in mapping.values() if v["drugs"])
    non_empty_pairs = df_rows[df_rows["Drug"].astype(str).str.strip() != ""].shape[0]
    unique_drugs = df_rows.loc[df_rows["Drug"].astype(str).str.strip() != "", "Drug"].nunique()

    print("\n=== Summary ===")
    print(f"Total genes processed: {total_genes}")
    print(f"Genes with ≥1 drug: {genes_with_drugs}")
    print(f"Total non-empty gene–drug pairs: {non_empty_pairs}")
    print(f"Unique drugs across all genes: {unique_drugs}")

    print("\nFirst few rows of the long-form result:")
    if not df_rows.empty:
        print(df_rows.head(10).to_string(index=False))
    else:
        print("  <no rows>")

    # Sample per-gene (first 5)
    print("\nSample per-gene (first 5):")
    shown = 0
    for gene, info in mapping.items():
        print(f"  {gene} ({info['sources']}): {len(info['drugs'])} drug(s) -> {info['drugs'][:5]}")
        shown += 1
        if shown >= 5:
            break

    print(f"\nDrugCentral results saved to:\n  JSON: {json_path}\n  CSV:  {csv_path}")


if __name__ == "__main__":
    main()

DrugCentral genes: 100%|██████████| 1725/1725 [00:05<00:00, 287.82it/s]



=== Summary ===
Total genes processed: 1725
Genes with ≥1 drug: 390
Total non-empty gene–drug pairs: 5769
Unique drugs across all genes: 1443

First few rows of the long-form result:
Gene_Symbol Gene_Sources       Drug
        A2M           CT           
      ABCA1           CT   probucol
      ABCF3           CT           
       ABL1           CT   afatinib
       ABL1           CT   axitinib
       ABL1           CT  bosutinib
       ABL1           CT  ceritinib
       ABL1           CT crizotinib
       ABL1           CT  dasatinib
       ABL1           CT  erlotinib

Sample per-gene (first 5):
  A2M (CT): 0 drug(s) -> []
  ABCA1 (CT): 1 drug(s) -> ['probucol']
  ABCF3 (CT): 0 drug(s) -> []
  ABL1 (CT): 27 drug(s) -> ['afatinib', 'axitinib', 'bosutinib', 'ceritinib', 'crizotinib']
  ABL2 (CT): 14 drug(s) -> ['axitinib', 'bosutinib', 'crizotinib', 'dasatinib', 'erlotinib']

DrugCentral results saved to:
  JSON: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third
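
The script matches each gene by exact equality against the `GENE` column. If the downloaded file lists several symbols in one cell for multi-subunit targets (assumed here to be pipe-delimited, which should be checked against the actual TSV), those rows would never match. The following is a hedged sketch of splitting such cells before matching; `explode_gene_column` and the sample rows are hypothetical.

```python
import pandas as pd


def explode_gene_column(dc, gene_col="GENE", sep="|"):
    """Return a copy of `dc` with one row per individual gene symbol.

    Adds an uppercase `gene_up` column; rows with no symbol are dropped.
    The separator is an assumption about the file layout.
    """
    out = dc.copy()
    # A single-character separator is treated literally by str.split.
    out["gene_up"] = out[gene_col].fillna("").astype(str).str.upper().str.split(sep)
    out = out.explode("gene_up")
    out["gene_up"] = out["gene_up"].str.strip()
    return out[out["gene_up"] != ""]


# Illustrative rows only; the drug names are placeholders.
demo = pd.DataFrame({
    "GENE": ["GABRA1|GABRA2", "ABL1", None],
    "DRUG_NAME": ["drug_a", "drug_b", "drug_c"],
})
exploded = explode_gene_column(demo)
print(exploded[["gene_up", "DRUG_NAME"]].to_string(index=False))
```

With the exploded frame, a single `groupby("gene_up")` can replace the per-gene filtering loop, though the counts reported above would change if multi-symbol cells are present.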

### **🟩TTD**

**Links:**
- Home: https://idrblab.net/ttd/
- Download: https://idrblab.net/ttd/full-data-download

**Datasets to download:**
- Target to drug mapping with mode of action: P1-07-Drug-TargetMapping.xlsx. Holds the core TargetID ↔ DrugID table.
- Target information in raw format: P1-01-TTD_target_download.txt. Holds TargetID → gene symbol/name mappings. This is the only file the script below reads; per its docstring, P1-02 and P1-07 are not needed for this workflow.
- Drug information in raw format: P1-02-TTD_drug_download.txt. Holds DrugID → drug name mappings.
  
**Characteristics:**
- TTD: Therapeutic Target Database.
- Curated information on known and explored therapeutic targets (genes/proteins), their associated diseases, and corresponding drugs. 
- Emphasizes target druggability and therapeutic context, providing target-to-drug mappings with mode-of-action details. 
- Bulk downloads are available for integration.

In [None]:
#!/usr/bin/env python3
"""
TTD gene -> drug mapping using the modern, all-in-one TTD data file.
This script uses P1-01-TTD_target_download.txt as the single source of truth
for gene-target-drug information, as is correct for the 2024 TTD data format.

Requires:
  - Your list of genes with source labels (CSV).
  - The TTD target info file (P1-01-TTD_target_download.txt).

The other TTD files (P1-02, P1-07) are no longer needed for this workflow.
"""
import os
import json
import re
from collections import defaultdict

import pandas as pd

# === USER CONFIGURATION ===
UNION_WITH_SOURCES   = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
TTD_TARGET_INFO_FILE = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/Full_DataBases/full_database_TTD/P1-01-TTD_target_download.txt"

OUTPUT_DIR           = os.path.join(os.path.dirname(UNION_WITH_SOURCES), "ttd_results")
GENE_COLUMN          = "Gene_Symbol"
SOURCES_COLUMN       = "Sources"

os.makedirs(OUTPUT_DIR, exist_ok=True)


def parse_ttd_data_file(path):
    """
    Parses the TTD data format (e.g., P1-01-TTD_target_download.txt), which
    can be separated by '---' or '___'.
    """
    print(f"Parsing TTD data file: {os.path.basename(path)}")
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()

    separator_pattern = r'[-_]{50,}'
    parts = re.split(separator_pattern, content, maxsplit=2)

    if len(parts) < 3:
        raise ValueError(f"Could not find data section in {os.path.basename(path)}. Expected long separator lines.")
    
    data_content = parts[2]
    records = defaultdict(dict)
    record_order = []

    for line in data_content.strip().split('\n'):
        if not line.strip():
            continue
        
        parts = line.strip().split('\t', 2)
        if len(parts) != 3:
            continue
        
        main_id, field, value = parts
        
        if main_id not in records:
            record_order.append(main_id)

        field_key = field.lower()
        # Handle multi-value fields like DRUGINFO by creating a list
        if field_key in records[main_id]:
            if not isinstance(records[main_id][field_key], list):
                records[main_id][field_key] = [records[main_id][field_key]]
            records[main_id][field_key].append(value)
        else:
            records[main_id][field_key] = value

    ordered_records = [records[rid] for rid in record_order]
    df = pd.DataFrame(ordered_records)
    print(f"Parsed {len(df)} records from {os.path.basename(path)}")
    return df


def main():
    # 1) Load user's gene list and create a map for easy lookup
    user_genes_df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
    user_genes_df[GENE_COLUMN] = user_genes_df[GENE_COLUMN].str.strip()
    # Create a lookup map of {UPPERCASE_GENE: Original_Gene_Symbol}
    gene_case_map = {g.upper(): g for g in user_genes_df[GENE_COLUMN].unique()}
    union_genes_upper = set(gene_case_map.keys())

    # 2) Parse the main TTD target data file
    ttd_df = parse_ttd_data_file(TTD_TARGET_INFO_FILE)

    # 3) Extract gene-drug relationships
    output_records = []
    processed_genes = set()

    print("Finding gene-drug relationships...")
    for _, row in ttd_df.iterrows():
        if 'genename' not in row or pd.isna(row['genename']):
            continue

        # Find which of the user's genes match this TTD record
        ttd_genes = {g.strip().upper() for g in str(row['genename']).replace(";", ",").split(",")}
        matching_genes = ttd_genes & union_genes_upper
        
        if not matching_genes:
            continue

        # Record the match before looking at drug info: processed_genes therefore
        # counts genes matched to any TTD target record, with or without drugs.
        processed_genes.update(matching_genes)

        if 'druginfo' in row:
            drug_info_value = row['druginfo']
            # Proceed if the cell holds a list of drug entries or a single non-missing string.
            if isinstance(drug_info_value, list) or pd.notna(drug_info_value):
                drug_list = drug_info_value
                if not isinstance(drug_list, list):
                    drug_list = [drug_list]  # normalize to a list for uniform processing

                for drug_entry in drug_list:
                    # drug_entry format is 'DRUG_ID\tDRUG_NAME\tSTATUS'
                    drug_parts = str(drug_entry).split('\t')
                    if len(drug_parts) > 1:
                        drug_name = drug_parts[1].strip()
                        for upper_gene in matching_genes:
                            output_records.append({
                                GENE_COLUMN: gene_case_map[upper_gene],
                                "Drug": drug_name
                            })

    # Note: this count covers genes matched to a TTD target, so it can exceed
    # "Genes with ≥1 drug" in the summary, which counts only genes with drug entries.
    print(f"Found drugs for {len(processed_genes)} of your genes.")
    
    # 4) Create final DataFrame and merge with original sources
    if not output_records:
        print("\nNo drug associations were found for any of the genes in your list.")
        return

    results_df = pd.DataFrame(output_records).drop_duplicates()
    final_df = pd.merge(results_df, user_genes_df, on=GENE_COLUMN, how="left")

    # 5) Write CSV output
    csv_path = os.path.join(OUTPUT_DIR, "ttd_gene_drug_with_sources.csv")
    # Ensure consistent column order
    final_df = final_df[[GENE_COLUMN, SOURCES_COLUMN, "Drug"]]
    final_df.to_csv(csv_path, index=False)

    # 6) Generate JSON output
    json_path = os.path.join(OUTPUT_DIR, "ttd_gene_drug_with_sources.json")
    gene_to_drugs_map = defaultdict(lambda: {'sources': '', 'drugs': set()})
    
    for _, row in final_df.iterrows():
        gene = row[GENE_COLUMN]
        source = row.get(SOURCES_COLUMN, '') # Use .get for safety
        drug = row['Drug']
        gene_to_drugs_map[gene]['sources'] = source
        gene_to_drugs_map[gene]['drugs'].add(drug)
    
    # Convert sets to sorted lists for clean JSON output
    final_json = {
        g: {'sources': v['sources'], 'drugs': sorted(v['drugs'])}
        for g, v in sorted(gene_to_drugs_map.items())
    }

    with open(json_path, 'w', encoding='utf-8') as jf:
        json.dump(final_json, jf, indent=2)

    # 7) Final Summary
    print("\n=== Summary ===")
    print(f"Total genes processed: {len(user_genes_df[GENE_COLUMN].unique())}")
    print(f"Genes with ≥1 drug:    {len(final_json)}")
    print(f"Total non-empty pairs:   {len(results_df)}")
    unique_drugs = len(results_df['Drug'].unique())
    print(f"Unique drugs found:      {unique_drugs}")
    print(f"\nResults saved to:\n  JSON: {json_path}\n  CSV:  {csv_path}")


if __name__ == "__main__":
    main()

Parsing TTD data file: P1-01-TTD_target_download.txt
Parsed 4298 records from P1-01-TTD_target_download.txt
Finding gene-drug relationships...
Found drugs for 738 of your genes.

=== Summary ===
Total genes processed: 1725
Genes with ≥1 drug:    592
Total non-empty pairs:   15885
Unique drugs found:      12131

Results saved to:
  JSON: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/ttd_results/ttd_gene_drug_with_sources.json
  CSV:  /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/ttd_results/ttd_gene_drug_with_sources.csv
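
The TTD script above pulls the drug name out of entries formatted as `DRUG_ID\tDRUG_NAME\tSTATUS`. A minimal sketch of that parsing step as a standalone helper is shown here. The function name `parse_ttd_drug_entry` and the sample entries are hypothetical and are not taken from the TTD file.

```python
def parse_ttd_drug_entry(entry):
    """Return the drug name from a 'DRUG_ID<TAB>DRUG_NAME<TAB>STATUS' string, or None."""
    parts = str(entry).split("\t")
    if len(parts) < 2:
        return None          # no tab-separated name field present
    name = parts[1].strip()
    return name or None      # treat a blank name as missing


# Made-up entries for illustration only.
examples = [
    "D0001\tExampleDrugA\tApproved",
    "D0002\t  ExampleDrugB  \tPhase 2",
    "malformed-entry",
]
names = [parse_ttd_drug_entry(e) for e in examples]
print(names)
```

Returning `None` for malformed entries keeps the caller's loop simple: it can skip anything falsy instead of checking the number of fields itself.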


### **🟩GtoPDB/IUPHAR**

**Links:**
- Home: https://www.guidetopharmacology.org/
- Download: https://www.guidetopharmacology.org/download.jsp

**Datasets to download:**
1. Target and gene information ("Target and family list", TSV, 1.9 MB): contains the target_id to human_gene_symbol mapping, which is essential. Likely named targets_and_families.tsv.
2. Ligand names ("Ligand list", TSV, 6.5 MB): maps each ligand_id to the compound/drug name. Likely named ligands.tsv.
3. Interaction data ("Interaction data for ligands and targets", TSV, 7.1 MB): the central file linking a ligand_id to a target_id, which forms the basis of the entire analysis. Likely named interactions.tsv.
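
To make the role of the three files concrete, here is a minimal sketch of how they join into a gene-to-ligand map. The DataFrames are toy stand-ins with made-up values, and the column names (`target_id`, `ligand_id`, `human_gene_symbol`, `name`) are assumptions; the real headers in the downloaded files may differ.

```python
import pandas as pd

# Toy stand-ins for the three GtoPdb downloads (all values are made up).
targets = pd.DataFrame({"target_id": [1, 2], "human_gene_symbol": ["GENE1", "GENE2"]})
ligands = pd.DataFrame({"ligand_id": [10, 11, 12], "name": ["ligand A", "ligand B", "ligand C"]})
interactions = pd.DataFrame({"target_id": [1, 1, 2, 3], "ligand_id": [10, 11, 11, 12]})

# interactions is the bridge: target_id -> gene symbol, ligand_id -> ligand name.
gene_ligand = (
    interactions
    .merge(targets, on="target_id")   # inner join drops target 3 (no gene symbol)
    .merge(ligands, on="ligand_id")
)

gene_to_ligands = (
    gene_ligand.groupby("human_gene_symbol")["name"]
    .apply(lambda s: sorted(set(s)))
    .to_dict()
)
print(gene_to_ligands)
```

The inner joins mean that interactions whose target has no human gene symbol simply drop out, which is the behaviour wanted when mapping a human gene list.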
  
**Characteristics:**
- GtoPdb: IUPHAR/BPS Guide to PHARMACOLOGY.
- Expert-curated resource linking pharmacological targets to ligands (including approved drugs and experimental compounds) with quantitative interaction data. 
- Strong for high-quality, literature-backed ligand–target relationships and pharmacological context. 
- Offers web services for programmatic retrieval.

In [None]:
#!/usr/bin/env python3
"""
GtoPdb (IUPHAR) mapping via LOCAL FILES for maximum speed.
This version is robust to minor changes in column names in the downloaded files.
"""
import os
import json
import pandas as pd

# --- USER CONFIGURATION ---
UNION_WITH_SOURCES = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
GTOPDB_DATA_DIR   = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/Full_DataBases/full_database_GtoPbd_IUPHAR"
OUTPUT_DIR        = os.path.join(os.path.dirname(UNION_WITH_SOURCES), "gtopdb_results_fast")
GENE_COLUMN       = "Gene_Symbol"
SOURCES_COLUMN    = "Sources"

# --- SETUP ---
os.makedirs(OUTPUT_DIR, exist_ok=True)


def load_genes_with_sources():
    """Loads the user's master list of genes with their sources."""
    df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
    df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
    df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("").str.strip()
    return df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()


def find_column(columns, candidates):
    """Finds the first matching column name from a list of candidates."""
    for c in candidates:
        if c in columns:
            return c
    return None


def build_gtopdb_map_from_files(data_dir):
    """
    Builds the gene->ligand map by reading and merging local GtoPdb files.
    """
    print("Building GtoPdb map from local files...")

    paths = {
        "interactions": os.path.join(data_dir, "interactions.tsv"),
        "targets":      os.path.join(data_dir, "targets_and_families.tsv"),
        "ligands":      os.path.join(data_dir, "ligands.tsv")
    }

    for name, path in paths.items():
        if not os.path.isfile(path):
            print(f"ERROR: File not found at '{path}'")
            return None

    print("Loading data into memory...")
    try:
        common_parser_options = {
            'sep':           '\t',
            'engine':        'python',
            'comment':       '#',
            'on_bad_lines':  'warn',
            'encoding':      'utf-8-sig'
        }

        interactions_df = pd.read_csv(paths["interactions"], **common_parser_options)
        targets_df      = pd.read_csv(paths["targets"],      **common_parser_options)
        ligands_df      = pd.read_csv(paths["ligands"],      **common_parser_options)

        # Strip stray quotes/whitespace from headers
        for df in (interactions_df, targets_df, ligands_df):
            df.columns = df.columns.str.strip('"').str.strip()

    except Exception as e:
        print(f"Error reading data files: {e}")
        return None

    # === updated candidate column names ===
    CANDIDATE_COLS = {
        'interaction_target_id': ['Target id', 'target_id', 'Target ID'],
        'interaction_ligand_id': ['Ligand id', 'ligand_id', 'Ligand ID'],
        'interaction_species':   ['Target Species', 'Target species', 'species', 'Species'],
        'target_id':             ['Target id', 'target_id', 'Target ID'],
        'target_gene_symbol':    ['Human gene symbol', 'human_gene_symbol', 'Gene symbol', 'HGNC symbol'],
        'ligand_id':             ['Ligand id', 'ligand_id', 'Ligand ID'],
        'ligand_name':           ['Name', 'name', 'Ligand name']
    }

    dataframes = {
        "interactions": interactions_df,
        "targets":      targets_df,
        "ligands":      ligands_df
    }

    actual_cols = {
        key: find_column(dataframes[file].columns, CANDIDATE_COLS[key])
        for key, file in [
            ('interaction_target_id', 'interactions'),
            ('interaction_ligand_id', 'interactions'),
            ('interaction_species', 'interactions'),
            ('target_id', 'targets'),
            ('target_gene_symbol', 'targets'),
            ('ligand_id', 'ligands'),
            ('ligand_name', 'ligands'),
        ]
    }

    # Report missing columns
    for name, col in actual_cols.items():
        if col is None:
            file_key = 'interactions' if 'interaction' in name else ('targets' if 'target' in name else 'ligands')
            print(f"ERROR: Could not find required column for '{name}'.")
            print(f"Available columns in '{paths[file_key]}': {list(dataframes[file_key].columns)}")
            return None

    # Filter for human interactions only
    print("Filtering for human interactions...")
    human_interactions = interactions_df[
        interactions_df[actual_cols['interaction_species']] == 'Human'
    ].copy()

    # Subset columns
    targets_subset      = targets_df[[actual_cols['target_id'], actual_cols['target_gene_symbol']]].dropna()
    ligands_subset      = ligands_df[[actual_cols['ligand_id'], actual_cols['ligand_name']]].dropna()
    interactions_subset = human_interactions[
        [actual_cols['interaction_target_id'], actual_cols['interaction_ligand_id']]
    ]

    # Merge
    print("Merging interactions with targets...")
    merged_df = pd.merge(
        interactions_subset,
        targets_subset,
        left_on=actual_cols['interaction_target_id'],
        right_on=actual_cols['target_id']
    )

    print("Merging with ligands...")
    final_df = pd.merge(
        merged_df,
        ligands_subset,
        left_on=actual_cols['interaction_ligand_id'],
        right_on=actual_cols['ligand_id']
    )

    # Build map
    print("Grouping data to create final map...")
    final_df[actual_cols['target_gene_symbol']] = (
        final_df[actual_cols['target_gene_symbol']].str.upper()
    )
    gtop_map = final_df.groupby(
        actual_cols['target_gene_symbol']
    )[actual_cols['ligand_name']].unique().apply(list).to_dict()

    print("Finished building GtoPdb gene->ligand map.")
    return gtop_map


def main():
    genes_df = load_genes_with_sources()
    gtop_map = build_gtopdb_map_from_files(GTOPDB_DATA_DIR)
    if gtop_map is None:
        return

    print(f"Mapping {len(genes_df)} genes against GtoPdb data...")

    mapping_results = {}
    longform_rows   = []
    for _, row in genes_df.iterrows():
        gene    = row[GENE_COLUMN]
        sources = row[SOURCES_COLUMN]
        ligands = gtop_map.get(gene.upper(), [])
        mapping_results[gene] = {"sources": sources, "drugs": sorted(ligands)}

        if ligands:
            for ligand_name in sorted(ligands):
                longform_rows.append({
                    "Gene_Symbol":  gene,
                    "Gene_Sources": sources,
                    "Ligand":       ligand_name
                })
        else:
            longform_rows.append({
                "Gene_Symbol":  gene,
                "Gene_Sources": sources,
                "Ligand":       ""
            })

    json_path = os.path.join(OUTPUT_DIR, "gtopdb_gene_drug_with_sources.json")
    csv_path  = os.path.join(OUTPUT_DIR, "gtopdb_gene_drug_with_sources.csv")

    # Write JSON
    with open(json_path, "w") as f:
        json.dump({k: mapping_results[k] for k in sorted(mapping_results)}, f, indent=2)

    # Build & preview long table
    df_long = pd.DataFrame(longform_rows)
    print("\nFirst 5 rows of the gene–ligand table:")
    print(df_long.head())

    # Save CSV
    df_long.to_csv(csv_path, index=False)

    # Summary
    total_genes        = len(mapping_results)
    genes_with_ligands = sum(1 for v in mapping_results.values() if v["drugs"])
    non_empty_pairs    = df_long[df_long["Ligand"].astype(bool)].shape[0]
    unique_ligands     = df_long.query("Ligand != ''")["Ligand"].nunique()

    print("\n=== GtoPdb Summary ===")
    print(f"Total genes processed: {total_genes}")
    print(f"Genes with ≥1 ligand:  {genes_with_ligands}")
    print(f"Total non-empty gene–ligand pairs: {non_empty_pairs}")
    print(f"Unique ligands found: {unique_ligands}")
    print(f"\nGtoPdb results saved to:\n  JSON: {json_path}\n  CSV:  {csv_path}")


if __name__ == "__main__":
    main()

Building GtoPdb map from local files...
Loading data into memory...
Filtering for human interactions...
Merging interactions with targets...
Merging with ligands...
Grouping data to create final map...
Finished building GtoPdb gene->ligand map.
Mapping 1725 genes against GtoPdb data...

First 5 rows of the gene–ligand table:
  Gene_Symbol Gene_Sources                               Ligand
0         A2M           CT                                     
1       ABCA1           CT  bihelical apoA-I mimetic peptide 5A
2       ABCA1           CT                             probucol
3       ABCF3           CT                                     
4        ABL1           CT                                  BO1

=== GtoPdb Summary ===
Total genes processed: 1725
Genes with ≥1 ligand:  485
Total non-empty gene–ligand pairs: 6984
Unique ligands found: 4559

GtoPdb results saved to:
  JSON: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/gtopd

### **🟨DTC**

- DTC: Drug Target Commons.
- Community-driven, crowdsourced platform that collects and curates compound–target bioactivity measurements from the literature and from existing resources such as ChEMBL.
- Useful for gene→compound mapping backed by quantitative bioactivity data; the script below works from a downloaded DTC bioactivity export.

In [None]:
#!/usr/bin/env python3
"""
DTC gene -> compound mapping with gene source labels via downloaded DTC bioactivity export.
"""
import os
import json
import pandas as pd

UNION_WITH_SOURCES = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
DTC_FILE = "/path/to/DTC_data.csv"  # user must download from Drug Target Commons
OUTPUT_DIR = os.path.join(os.path.dirname(UNION_WITH_SOURCES), "dtc_results")
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

# Adjust based on actual DTC export headers
DTC_TARGET_GENE_COL = "target_name"
DTC_COMPOUND_NAME_COL = "compound_name"

os.makedirs(OUTPUT_DIR, exist_ok=True)

def load_genes_with_sources():
    df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
    df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
    df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("")
    return df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()

def main():
    genes_df = load_genes_with_sources()
    if not os.path.isfile(DTC_FILE):
        print(f"DTC file not found at {DTC_FILE}")
        return
    dtc = pd.read_csv(DTC_FILE, dtype=str, low_memory=False)
    if DTC_TARGET_GENE_COL not in dtc.columns or DTC_COMPOUND_NAME_COL not in dtc.columns:
        print("Adjust column names to match DTC export. Available:", list(dtc.columns)[:20])
        return

    dtc["gene_up"] = dtc[DTC_TARGET_GENE_COL].str.strip().str.upper()
    dtc["compound_clean"] = dtc[DTC_COMPOUND_NAME_COL].str.strip()

    mapping = {}
    rows = []
    for _, row in genes_df.iterrows():
        gene = row[GENE_COLUMN]
        sources = row[SOURCES_COLUMN]
        gene_up = gene.upper()
        subset = dtc[dtc["gene_up"] == gene_up]
        compounds = sorted(set(subset["compound_clean"].dropna()))
        mapping[gene] = {"sources": sources, "drugs": compounds}
        if compounds:
            for c in compounds:
                rows.append({"Gene_Symbol": gene, "Gene_Sources": sources, "Compound": c})
        else:
            rows.append({"Gene_Symbol": gene, "Gene_Sources": sources, "Compound": ""})

    with open(os.path.join(OUTPUT_DIR, "dtc_gene_drug_with_sources.json"), "w") as f:
        json.dump(mapping, f, indent=2)
    pd.DataFrame(rows).to_csv(os.path.join(OUTPUT_DIR, "dtc_gene_drug_with_sources.csv"), index=False)
    print("DTC results saved.")

if __name__ == "__main__":
    main()

DTC file not found at /path/to/DTC_data.csv


**NOTE:** The DTC website was down, so I couldn't download the export file or run this script.

### **🟨STITCH**

- Integrates known and predicted protein–chemical (including drug) interactions using experimental data, literature/text mining, and inference across species. 
- Offers confidence scores for associations, useful for network-based inference and expanding beyond strictly curated pairs.
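
As a rough illustration of how the confidence score is used, the sketch below keeps only associations at or above a cutoff. It assumes scores on a 0 to 1000 scale, matching the `REQUIRED_SCORE = 700` setting in the script that follows. The table, its column names, and the helper `filter_by_confidence` are hypothetical.

```python
import pandas as pd

# Made-up protein-chemical links with STITCH-style combined scores (0-1000).
links = pd.DataFrame({
    "gene": ["GENE1", "GENE1", "GENE2"],
    "chemical": ["chem A", "chem B", "chem C"],
    "score": [950, 420, 700],
})


def filter_by_confidence(df, min_score=700):
    """Keep rows whose score is at least min_score (inclusive)."""
    return df[df["score"] >= min_score].reset_index(drop=True)


high_conf = filter_by_confidence(links)
print(high_conf)
```

Raising the cutoff trades coverage for precision: text-mined or inferred links tend to sit at lower scores than experimentally supported ones.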

In [None]:
import requests, certifi
SESSION = requests.Session()
SESSION.verify = certifi.where()

In [None]:
#!/usr/bin/env python3
import os
import json
import time

import pandas as pd
import requests

UNION_WITH_SOURCES = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
OUTPUT_DIR = os.path.join(os.path.dirname(UNION_WITH_SOURCES), "stitch_results_api")
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

BASE = "https://stitch.embl.de/api"  # STITCH API base
SPECIES = 9606
REQUIRED_SCORE = 700      # 0–1000; 700 = high confidence
BATCH = 200               # genes per API POST
SLEEP = 0.34              # be nice to the server (~3 req/s)

os.makedirs(OUTPUT_DIR, exist_ok=True)

def load_genes_with_sources():
    df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
    df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
    df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("")
    return df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()

def post_api(path, params, identifiers):
    """
    POST helper for endpoints that accept 'identifiers' body separated by newlines.
    """
    url = f"{BASE}/json/{path}"
    data = {"identifiers": "\n".join(identifiers)}
    resp = requests.post(url, params=params, data=data, timeout=120)
    resp.raise_for_status()
    return resp.json()

def resolve_genes_to_string_ids(genes):
    """
    Map gene symbols -> STRING IDs via STITCH 'resolve' endpoint.
    Returns dict {input_gene_upper: [string_id1, string_id2, ...]}.
    """
    out = {}
    for i in range(0, len(genes), BATCH):
        chunk = genes[i:i+BATCH]
        r = post_api("resolve", {"species": SPECIES}, chunk)
        # r is a list of mappings; group by 'input'
        for row in r:
            inp = str(row.get("input", "")).upper()
            sid = row.get("stringId")
            if not inp or not sid:
                continue
            out.setdefault(inp, []).append(sid)
        time.sleep(SLEEP)
    return out

def fetch_partners_for_string_ids(string_ids):
    """
    Query interaction partners for STRING IDs; return list of rows.
    We try v11 method name first, then the older name as fallback.
    """
    rows = []
    methods = ["interaction_partners", "interactors"]
    for method in methods:
        try:
            for i in range(0, len(string_ids), BATCH):
                chunk = string_ids[i:i+BATCH]
                params = {
                    "species": SPECIES,
                    "required_score": REQUIRED_SCORE,
                    "network_flavor": "evidence"
                }
                r = post_api(method, params, chunk)
                # Each element describes (A,B, score, etc.)
                for row in r:
                    a = row.get("stringId_A") or row.get("stringIda")
                    b = row.get("stringId_B") or row.get("stringIdb")
                    prefA = row.get("preferredName_A") or row.get("preferredNamea")
                    prefB = row.get("preferredName_B") or row.get("preferredNameb")
                    extA = row.get("externalId_A") or row.get("externalIda")
                    extB = row.get("externalId_B") or row.get("externalIdb")
                    score = row.get("score")
                    rows.append({
                        "stringId_A": a, "preferred_A": prefA, "external_A": extA,
                        "stringId_B": b, "preferred_B": prefB, "external_B": extB,
                        "score": score,
                    })
                time.sleep(SLEEP)
            if rows:
                break
        except requests.HTTPError:
            # Re-raise if the last method name also failed; otherwise discard
            # any partial results and try the next method name.
            if method == methods[-1]:
                raise
            rows = []
            time.sleep(SLEEP)
    return rows

def main():
    genes_df = load_genes_with_sources()
    genes = genes_df[GENE_COLUMN].dropna().astype(str).str.strip().tolist()
    genes_up = [g.upper() for g in genes]

    # 1) resolve gene symbols -> STRING IDs
    mapping_ids = resolve_genes_to_string_ids(genes_up)

    # 2) fetch partners for all STRING IDs (batched)
    all_string_ids = sorted({sid for sids in mapping_ids.values() for sid in sids})
    partner_rows = fetch_partners_for_string_ids(all_string_ids)
    partners_df = pd.DataFrame(partner_rows)

    if partners_df.empty:
        print("No partners returned from STITCH.")
        return

    # 3) keep only CHEMICAL partners (STITCH chemicals use external IDs like CID00000…)
    # A chemical can be either side; gather both directions.
    def is_chem(external):
        return isinstance(external, str) and external.startswith("CID")

    chems_A = partners_df[partners_df["external_B"].map(is_chem)][["stringId_A","preferred_A","external_B","preferred_B","score"]].rename(
        columns={"stringId_A":"protein_id", "preferred_A":"protein_name", "external_B":"chemical_ext", "preferred_B":"chemical_name"}
    )
    chems_B = partners_df[partners_df["external_A"].map(is_chem)][["stringId_B","preferred_B","external_A","preferred_A","score"]].rename(
        columns={"stringId_B":"protein_id", "preferred_B":"protein_name", "external_A":"chemical_ext", "preferred_A":"chemical_name"}
    )
    chem_links = pd.concat([chems_A, chems_B], ignore_index=True).drop_duplicates()

    # 4) join back to your original gene list via the resolved IDs
    # Build reverse index STRING ID -> list of input gene symbols (upper)
    sid_to_genes = {sid: [g for g,u in zip(genes, genes_up) if sid in mapping_ids.get(u, [])] for sid in all_string_ids}

    rows = []
    for _, r in chem_links.iterrows():
        for gene in sid_to_genes.get(r["protein_id"], []):
            sources = genes_df.loc[genes_df[GENE_COLUMN]==gene, SOURCES_COLUMN].iloc[0]
            rows.append({
                "Gene_Symbol": gene,
                "Gene_Sources": sources,
                "Chemical": r["chemical_name"],
                "Chemical_External": r["chemical_ext"],
                "Score": r["score"],
                "Protein_Name": r["protein_name"],
                "Protein_STRING": r["protein_id"],
            })

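    if not rows:
        # An empty DataFrame has no columns, so sort_values would raise KeyError.
        print("No chemical partners found for the input genes.")
        return

    out_df = pd.DataFrame(rows).sort_values(["Gene_Symbol","Score"], ascending=[True,False])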

    # 5) build your JSON mapping {gene: {sources, drugs:[...]}}
    mapping = {}
    for gene, sub in out_df.groupby("Gene_Symbol"):
        mapping[gene] = {
            "sources": sub["Gene_Sources"].iloc[0],
            "drugs": sorted(set(sub["Chemical"].dropna()))
        }

    with open(os.path.join(OUTPUT_DIR, "stitch_gene_drug_with_sources.json"), "w") as f:
        json.dump(mapping, f, indent=2)
    out_df.to_csv(os.path.join(OUTPUT_DIR, "stitch_gene_drug_with_sources.csv"), index=False)
    print(f"STITCH API results saved to {OUTPUT_DIR}")

if __name__ == "__main__":
    main()

### **🟨BindingDB**

- Public repository of experimentally measured binding affinities between small molecules and protein targets. 
- A strong source for physical interaction evidence, especially when you need binding strength context (e.g., Kd, Ki) or want to validate target engagement.
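
Binding strengths such as Ki or Kd are often compared on a log scale: pKi = -log10(Ki in mol/L), which for a value given in nM is 9 - log10(Ki_nM). The sketch below applies that conversion. It assumes affinities arrive as strings in nM that may carry a leading `>` or `<` qualifier; the helper name `affinity_nm_to_p` is hypothetical, and the real BindingDB column names and formats should be checked against the downloaded file.

```python
import math


def affinity_nm_to_p(value):
    """Convert an affinity in nM (string or number) to -log10(molar).

    Returns None for missing, non-numeric, or non-positive values.
    A leading '>' or '<' is stripped, so the result is then only a bound.
    """
    if value is None:
        return None
    text = str(value).strip().lstrip("<>=~ ")
    try:
        nm = float(text)
    except ValueError:
        return None
    if math.isnan(nm) or nm <= 0:
        return None
    return 9.0 - math.log10(nm)


for raw in ["10", ">1000", "", None]:
    print(repr(raw), "->", affinity_nm_to_p(raw))
```

A higher value means tighter binding; a 10 nM binder maps to 8, while a 1 µM (1000 nM) binder maps to 6.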

In [None]:
#!/usr/bin/env python3
"""
BindingDB gene -> ligand mapping with gene source labels.
Needs:
  - BindingDB full TSV (e.g., BindingDB_All_YYYYMM.tsv)
  - UniProt mapping of gene symbols to human UniProt accessions
Outputs JSON and CSV.
"""
import os
import json
import time
import pandas as pd
import requests

UNION_WITH_SOURCES = "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/all_unique_genes_with_sources.csv"
BINDINGDB_TSV = "/path/to/BindingDB_All_202507.tsv"  # adapt to actual name
OUTPUT_DIR = os.path.join(os.path.dirname(UNION_WITH_SOURCES), "bindingdb_results")
GENE_COLUMN = "Gene_Symbol"
SOURCES_COLUMN = "Sources"

# Expected columns in BindingDB dump (verify header names)
UNIPROT_COL = "Target Chain UniProt Entry"  # might vary slightly
LIGAND_NAME_COL = "Ligand Name"

os.makedirs(OUTPUT_DIR, exist_ok=True)
session = requests.Session()

def load_genes_with_sources():
    df = pd.read_csv(UNION_WITH_SOURCES, dtype=str)
    df[GENE_COLUMN] = df[GENE_COLUMN].str.strip()
    df[SOURCES_COLUMN] = df[SOURCES_COLUMN].fillna("")
    return df[[GENE_COLUMN, SOURCES_COLUMN]].drop_duplicates()

def gene_to_uniprot(gene):
    query = f"gene_exact:{gene} AND organism_id:9606"
    url = "https://rest.uniprot.org/uniprotkb/search"
    params = {
        "query": query,
        "fields": "accession",
        "format": "json",
        "limit": 5
    }
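    try:
        # A timeout keeps one stalled request from hanging the whole run.
        resp = session.get(url, params=params, timeout=30)
    except requests.RequestException:
        return []
    if resp.status_code != 200:
        return []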
    hits = resp.json().get("results", [])
    return [r["primaryAccession"] for r in hits if "primaryAccession" in r]

def main():
    genes_df = load_genes_with_sources()
    if not os.path.isfile(BINDINGDB_TSV):
        print(f"BindingDB file missing at {BINDINGDB_TSV}")
        return

    # Map genes to UniProt accessions
    gene2unis = {}
    for _, row in genes_df.iterrows():
        gene = row[GENE_COLUMN]
        unis = gene_to_uniprot(gene)
        gene2unis[gene] = [u.upper() for u in unis]
        time.sleep(0.2)  # rate limit

    # Reverse: UniProt -> genes
    uniprot2genes = {}
    for gene, unis in gene2unis.items():
        for u in unis:
            uniprot2genes.setdefault(u, set()).add(gene)

    # Parse BindingDB in chunks
    mapping = {
        row[GENE_COLUMN]: {"sources": row[SOURCES_COLUMN], "drugs": set()}
        for _, row in genes_df.iterrows()
    }
    chunk_iter = pd.read_csv(BINDINGDB_TSV, sep="\t", dtype=str, low_memory=False, chunksize=200000)
    for chunk in chunk_iter:
        if UNIPROT_COL not in chunk.columns or LIGAND_NAME_COL not in chunk.columns:
            print("Check column names in BindingDB TSV. Available:", list(chunk.columns)[:20])
            return
        chunk["target_up"] = chunk[UNIPROT_COL].fillna("").str.strip().str.upper()
        chunk["ligand"] = chunk[LIGAND_NAME_COL].fillna("").str.strip()
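        # Only rows whose target is one of our UniProt accessions (and that have
        # a ligand name) can contribute, so filter before iterating.
        hits = chunk[chunk["target_up"].isin(list(uniprot2genes)) & (chunk["ligand"] != "")]
        for up, ligand in zip(hits["target_up"], hits["ligand"]):
            for gene in uniprot2genes[up]:
                mapping[gene]["drugs"].add(ligand)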

    # Finalize
    output_mapping = {}
    rows = []
    for gene, info in mapping.items():
        sources = info["sources"]
        drugs = sorted(info["drugs"])
        output_mapping[gene] = {"sources": sources, "drugs": drugs}
        if drugs:
            for d in drugs:
                rows.append({"Gene_Symbol": gene, "Gene_Sources": sources, "Drug": d})
        else:
            rows.append({"Gene_Symbol": gene, "Gene_Sources": sources, "Drug": ""})

    with open(os.path.join(OUTPUT_DIR, "bindingdb_gene_drug_with_sources.json"), "w") as f:
        json.dump(output_mapping, f, indent=2)
    pd.DataFrame(rows).to_csv(os.path.join(OUTPUT_DIR, "bindingdb_gene_drug_with_sources.csv"), index=False)
    print("BindingDB results saved.")

if __name__ == "__main__":
    main()

## **All Databases**

In [None]:
import pandas as pd
from pathlib import Path

# =========================
# CONFIG
# =========================

# Folder where your CSVs live
DATA_DIR = Path("/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/ALL")

# Map each file to a drug source name
SOURCE_FILES = {
    "ChEMBL": "chembl.csv",
    "DrugBank": "drugbank.csv",
    "DrugCentral": "drugcentral.csv",
    "GtoPdb": "gtopdb.csv",
    "TTD": "ttd.csv",
}

# Base drugs you want to look for
TARGET_DRUGS = [
    "Leronlimab",
    "Ampion",
    "Efgartigimod",
    "Allogeneic",
    "Paxlovid",      # brand
    "Nirmatrelvir",  # Paxlovid component
    "Ritonavir",     # Paxlovid component
    "Vortioxetine",
]

# Synonyms / aliases for each base drug
DRUG_SYNONYMS = {
    "Leronlimab": [
        "PRO 140",
        "PRO-140",
        "Vyrologix",
        "Leronlimab (PRO 140)",
    ],
    "Ampion": [
        "Ampion (human serum albumin–derived biologic)",
        "Ampio Ampion",
    ],
    "Efgartigimod": [
        "Efgartigimod alfa",
        "Efgartigimod alfa-fcab",
        "ARGX-113",
        "ARGX113",
        "Vyvgart",
        "Vyvgart IV",
    ],
    # ‘Allogeneic’ is a product type; keep a few likely phrases
    "Allogeneic": [
        "Allogeneic cell therapy",
        "Allogeneic mesenchymal stem cells",
        "Allogeneic MSCs",
    ],
    "Paxlovid": [
        "Nirmatrelvir/Ritonavir",
        "Nirmatrelvir plus Ritonavir",
        "PF-07321332/ritonavir",
        "PF-07321332",
        "PAXLOVID",
    ],
    "Nirmatrelvir": [
        "PF-07321332",
        "Nirmatrelvir (PF-07321332)",
    ],
    "Ritonavir": [
        "Norvir",
        "ABT-538",
        "Ritonavir (ABT-538)",
    ],
    "Vortioxetine": [
        "Trintellix",
        "Brintellix",
        "Lu AA21004",
        "Lu-AA21004",
    ],
}


# =========================
# LOAD & MERGE ALL TABLES (LONG FORMAT, INTERNAL)
# =========================

def load_and_tag_source(data_dir: Path, source_files: dict) -> pd.DataFrame:
    dfs = []

    for source_name, filename in source_files.items():
        fpath = data_dir / filename
        if not fpath.exists():
            print(f"⚠️ File not found: {fpath}")
            continue

        # sep=None lets pandas guess comma vs tab
        df = pd.read_csv(fpath, sep=None, engine="python")

        # Input columns: Gene_Symbol, Gene_Sources, Drug
        rename_map = {
            "Gene_Symbol": "Gene",
            "Gene_Sources": "Gene_Source",
            "Drug": "Drug",
        }
        df = df.rename(columns=rename_map)

        expected_cols = ["Gene", "Gene_Source", "Drug"]
        missing = [c for c in expected_cols if c not in df.columns]
        if missing:
            raise ValueError(f"{filename} is missing columns: {missing}")

        df = df[expected_cols].copy()

        # Clean gene + source
        for col in ["Gene", "Gene_Source"]:
            df[col] = df[col].astype(str).str.strip()

        # Clean drug and turn blanks/nan into NA
        df["Drug"] = df["Drug"].astype(str).str.strip()
        df.loc[df["Drug"].str.lower().isin(["nan", "none", ""]), "Drug"] = pd.NA

        # Tag database source
        df["Drug_Source"] = source_name

        dfs.append(df)

    if not dfs:
        raise RuntimeError("No files were successfully loaded.")

    merged = pd.concat(dfs, ignore_index=True)
    merged = merged.drop_duplicates()

    return merged


# =========================
# BUILD GENE-LEVEL TABLES (ONE ROW PER GENE)
# =========================

def build_gene_level_table(df: pd.DataFrame) -> pd.DataFrame:
    """
    Returns one row per Gene with:
      - Gene_Source: unique gene sources
      - Count: number of unique drugs
      - Drugs: comma-separated unique drugs
      - Drug_Source: comma-separated unique databases
    """

    def unique_join(series):
        vals = {str(x).strip() for x in series.dropna() if str(x).strip() != ""}
        return ", ".join(sorted(vals)) if vals else ""

    gene_level = (
        df
        .groupby("Gene", as_index=False)
        .agg(
            Gene_Source=("Gene_Source", unique_join),
            Count=("Drug", lambda s: s.dropna().nunique()),
            Drugs=("Drug", unique_join),
            Drug_Source=("Drug_Source", unique_join),
        )
    )
    return gene_level


# =========================
# SEARCH FOR TARGET DRUGS (USES LONG FORMAT INTERNALLY)
# =========================

def search_drugs(df: pd.DataFrame, targets: list, synonyms: dict):
    """
    Case-insensitive substring search for each target drug name
    and its synonyms in the Drug column.
    """
    print("\n==================== DRUG SEARCH RESULTS ====================\n")

    for base in targets:
        terms = [base] + synonyms.get(base, [])
        # Build a combined mask for base + all synonyms
        mask = pd.Series(False, index=df.index)

        for term in terms:
            if term.strip() == "":
                continue
            # regex=False: some synonyms contain parentheses, which must match literally
            mask |= df["Drug"].str.contains(term, case=False, na=False, regex=False)

        matches = df[mask].drop_duplicates()

        syn_str = ", ".join(terms[1:]) or "none"
        print(f"--- {base} ---")
        print(f"Synonyms searched: {syn_str}")

        if matches.empty:
            print("No matches found.\n")
        else:
            out_df = (
                matches[["Gene", "Gene_Source", "Drug", "Drug_Source"]]
                .sort_values(["Drug", "Gene", "Drug_Source"])
            )
            print(out_df.to_string(index=False))
            print("\n")


# =========================
# MAIN
# =========================

def main():
    # 1. Load long-format table (one row per Gene–Drug–Source)
    df_long = load_and_tag_source(DATA_DIR, SOURCE_FILES)

    # 2. Build gene-level table (ONE ROW PER GENE)
    gene_level = build_gene_level_table(df_long)

    # ---- Table 1: one row per gene, with drugs in a single column 'Drug'
    master_gene = gene_level[["Gene", "Gene_Source", "Drugs", "Drug_Source"]].rename(
        columns={"Drugs": "Drug"}
    )
    master_path = DATA_DIR / "all_gene_drug_sources_by_gene.csv"
    master_gene.to_csv(master_path, index=False)
    print(f"✅ Saved gene-level master table: {master_path}")
    print(f"Master (by gene) shape: {master_gene.shape}")
    print("First 5 rows of gene-level master table:")
    print(master_gene.head(), "\n")

    # ---- Table 2: same but with Count + Drugs column
    summary_path = DATA_DIR / "gene_drug_counts_by_gene.csv"
    gene_level.to_csv(summary_path, index=False)
    print(f"✅ Saved gene-level summary table: {summary_path}")
    print(f"Summary table shape: {gene_level.shape}")
    print(gene_level.head(), "\n")

    # 3. Search for your specific drugs (including synonyms) using the long-format data
    search_drugs(df_long, TARGET_DRUGS, DRUG_SYNONYMS)


if __name__ == "__main__":
    main()

✅ Saved gene-level master table: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/ALL/all_gene_drug_sources_by_gene.csv
Master (by gene) shape: (1725, 4)
First 5 rows of gene-level master table:
    Gene Gene_Source                                               Drug  \
0    A2M          CT  Anacaulase, Bacitracin, Becaplermin, Cisplatin...   
1  ABCA1          CT  ATP, Glyburide, Probucol, RG-7273, RG7232, bih...   
2  ABCF3          CT                                                      
3   ABL1          CT  1-[4-(PYRIDIN-4-YLOXY)PHENYL]-3-[3-(TRIFLUOROM...   
4   ABL2          CT  ABEMACICLIB, AFATINIB, ALECTINIB, ALPELISIB, A...   

                                  Drug_Source  
0       ChEMBL, DrugBank, DrugCentral, GtoPdb  
1  ChEMBL, DrugBank, DrugCentral, GtoPdb, TTD  
2       ChEMBL, DrugBank, DrugCentral, GtoPdb  
3  ChEMBL, DrugBank, DrugCentral, GtoPdb, TTD  
4  ChEMBL, DrugBank, DrugCentral, GtoPdb, TTD   

✅ Saved g



--- Leronlimab ---
Synonyms searched: PRO 140, PRO-140, Vyrologix, Leronlimab (PRO 140)
Gene Gene_Source       Drug Drug_Source
CCR5          CT Leronlimab    DrugBank
CCR5          CT    PRO 140         TTD
CCR5          CT    PRO-140         TTD
CCR5          CT leronlimab      GtoPdb


--- Ampion ---
Synonyms searched: Ampion (human serum albumin–derived biologic), Ampio Ampion
No matches found.

--- Efgartigimod ---
Synonyms searched: Efgartigimod alfa, Efgartigimod alfa-fcab, ARGX-113, ARGX113, Vyvgart, Vyvgart IV
No matches found.

--- Allogeneic ---
Synonyms searched: Allogeneic cell therapy, Allogeneic mesenchymal stem cells, Allogeneic MSCs
No matches found.

--- Paxlovid ---
Synonyms searched: Nirmatrelvir/Ritonavir, Nirmatrelvir plus Ritonavir, PF-07321332/ritonavir, PF-07321332, PAXLOVID
No matches found.

--- Nirmatrelvir ---
Synonyms searched: PF-07321332, Nirmatrelvir (PF-07321332)
No matches found.

--- Ritonavir ---
Synonyms searched: Norvir, ABT-538, Ritonavir (ABT-53

import pandas as pd
from pathlib import Path

# =========================
# CONFIG
# =========================

# Folder where your CSVs live
DATA_DIR = Path("/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/ALL")

# Map each file to a drug source name
SOURCE_FILES = {
    "ChEMBL": "chembl.csv",
    "DrugBank": "drugbank.csv",
    "DrugCentral": "drugcentral.csv",
    "GtoPdb": "gtopdb.csv",
    "TTD": "ttd.csv",
}

# =========================
# LOAD & MERGE ALL TABLES (LONG FORMAT)
# =========================

def load_and_tag_source(data_dir: Path, source_files: dict) -> pd.DataFrame:
    dfs = []

    for source_name, filename in source_files.items():
        fpath = data_dir / filename
        if not fpath.exists():
            print(f"⚠️ File not found: {fpath}")
            continue

        # sep=None lets pandas guess comma vs tab
        df = pd.read_csv(fpath, sep=None, engine="python")

        # Expected input columns: Gene_Symbol, Gene_Sources, Drug
        rename_map = {
            "Gene_Symbol": "Gene",
            "Gene_Sources": "Gene_Source",
            "Drug": "Drug",
        }
        df = df.rename(columns=rename_map)

        expected_cols = ["Gene", "Gene_Source", "Drug"]
        missing = [c for c in expected_cols if c not in df.columns]
        if missing:
            raise ValueError(f"{filename} is missing columns: {missing}")

        df = df[expected_cols].copy()

        # Clean gene + source
        for col in ["Gene", "Gene_Source"]:
            df[col] = df[col].astype(str).str.strip()

        # Clean drug and turn blanks/nan into NA
        df["Drug"] = df["Drug"].astype(str).str.strip()
        df.loc[df["Drug"].str.lower().isin(["nan", "none", ""]), "Drug"] = pd.NA

        # Tag database source
        df["Drug_Source"] = source_name

        dfs.append(df)

    if not dfs:
        raise RuntimeError("No files were successfully loaded.")

    merged = pd.concat(dfs, ignore_index=True)
    merged = merged.drop_duplicates()

    return merged

# =========================
# BUILD DRUG-LEVEL SUMMARY TABLE (ALL DRUGS)
# =========================

def build_all_drug_gene_table(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build one row per distinct Drug (as it appears in the mapping).

    Columns:
      - Drug: drug name as in the sources
      - Gene_Count: number of unique genes mapped to this drug
      - Genes: comma-separated list of unique genes
      - Gene_Sources: unique gene sources (from Gene_Source)
      - Drug_Sources: unique databases where this Drug appears
    """

    # Drop rows with missing Drug
    df = df.dropna(subset=["Drug"]).copy()

    def unique_join(series):
        vals = {str(x).strip() for x in series.dropna() if str(x).strip() != ""}
        return ", ".join(sorted(vals)) if vals else ""

    drug_table = (
        df
        .groupby("Drug", as_index=False)
        .agg(
            Gene_Count=("Gene", lambda s: s.dropna().nunique()),
            Genes=("Gene", unique_join),
            Gene_Sources=("Gene_Source", unique_join),
            Drug_Sources=("Drug_Source", unique_join),
        )
    )

    # Optional: sort by Gene_Count descending, then Drug name
    drug_table = drug_table.sort_values(
        ["Gene_Count", "Drug"], ascending=[False, True]
    ).reset_index(drop=True)

    return drug_table

# =========================
# MAIN
# =========================

def main():
    # 1. Load long-format table (one row per Gene–Drug–Source)
    df_long = load_and_tag_source(DATA_DIR, SOURCE_FILES)
    print(f"Loaded long-format table with shape: {df_long.shape}")

    # 2. Build drug-level summary table for ALL drugs
    drug_table = build_all_drug_gene_table(df_long)

    # 3. Ensure output folder exists
    out_dir = DATA_DIR / "Final_Tables"
    out_dir.mkdir(parents=True, exist_ok=True)

    # 4. Save to CSV
    out_path = out_dir / "all_drug_gene_counts.csv"
    drug_table.to_csv(out_path, index=False)

    print(f"✅ Saved drug-level gene count table for ALL drugs: {out_path}")
    print(f"Drug table shape: {drug_table.shape}")
    print("First 5 rows of drug table:")
    print(drug_table.head())

if __name__ == "__main__":
    main()

Loaded long-format table with shape: (120437, 4)
✅ Saved drug-level gene count table for ALL drugs: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/LG_Causal_Genes/gene_lists/ALL/Final_Tables/all_drug_gene_counts.csv
Drug table shape: (21863, 5)
First 5 rows of drug table:
          Drug  Gene_Count                                              Genes  \
0     IMATINIB         210  ABL1, ABL2, ACE, ACHE, ACTR2, ACVR1, ACVR1B, A...   
1    SIROLIMUS         206  ABL1, ABL2, ACE, ACHE, ACTR2, ACVR1, ACVR1B, A...   
2  PALBOCICLIB         201  ABL1, ABL2, ACHE, ACTR2, ACVR1, ACVR1B, ACVR2A...   
3    SUNITINIB         194  ABL1, ABL2, ACHE, ACTR2, ACVR1, ACVR1B, ACVR2A...   
4    SORAFENIB         191  ABL1, ABL2, ACTR2, ACVR1, ACVR1B, ACVR2A, ACVR...   

   Gene_Sources Drug_Sources  
0  CT, DCE_KEGG       ChEMBL  
1  CT, DCE_KEGG       ChEMBL  
2            CT       ChEMBL  
3            CT       ChEMBL  
4            CT       ChEMBL  


# **Ground Truth & Cross-Cohorts**

## **Steps**

**1. Starting Point: Existing Trial**

1. We started from a real clinical trial, **NCT04809974**, in which 72 participants tested Niagen (Nicotinamide Riboside) for Long COVID cognitive symptoms.
2. **Cohort characteristics:**
   - Sample size: 72 participants
   - Primary target: Cognitive symptoms (Brain Fog)
   - Age mean: 46 years
   - Gender distribution: 70% female, 30% male
   - Race distribution: 91% White, 3% Black, 5% Other
   - Ethnicity distribution: 91% Non-Hispanic, 9% Hispanic
3. We parsed this trial using PlaNet.

---

**2. Create 3 Drug-Specific Versions of NCT04809974**

1. We wanted to "reuse" the same clinical trial structure but plug in **new drugs** that target Long COVID causal driver genes:
   * **Leronlimab** (1 causal gene target)
   * **Ritonavir** (70 causal gene targets)
   * **Vortioxetine** (11 causal gene targets)

2. For that, we created three modified versions of the trial as JSON files:
   * `NCT04809974_leronlimab.json`
   * `NCT04809974_ritonavir.json`
   * `NCT04809974_vortioxetine.json`

3. Inside each file, the population-defining fields (eligibility criteria and conditions) remain the same as in NCT04809974.
   What changes is how the **intervention / arm descriptions** refer to the drug (swapping Niagen for Leronlimab, Ritonavir, or Vortioxetine), together with the drug-specific fields detailed in Section 2.1.

---

**2.1 Detailed Modifications in Hybrid Trial JSON Files**

To ensure the PlaNet prediction model correctly recognizes each drug (rather than defaulting to Nicotinamide Riboside features embedded throughout the original trial), we performed comprehensive modifications across multiple JSON sections. The table below summarizes what was changed vs. preserved:

**Summary of Field Sources**

| JSON Section | Source | Modification Details |
|--------------|--------|---------------------|
| `identificationModule` | Modified | NCT ID changed to `NCT04809974_{drug}` to indicate synthetic trial |
| `descriptionModule` | **Rewritten** | Drug mechanism + Long COVID cognitive symptom context combined |
| `conditionsModule` | NCT04809974 | Preserved: Long COVID, Brain Fog, Cognitive Impairment, Fatigue |
| `designModule` | Modified | Changed to randomized, double-blind, placebo-controlled |
| `eligibilityModule` | NCT04809974 | Preserved: Ages 18-65, brain fog + ≥2 neuro/physical symptoms, full inclusion/exclusion criteria |
| `armsInterventionsModule` | **Replaced** | Drug-specific arms, doses, routes, durations |
| `outcomesModule` | Adapted | Nicotinamide-specific measures removed; generic Long COVID cognitive outcomes retained |
| `derivedSection.interventionBrowseModule` | **Replaced** | Drug-specific MeSH terms |
| `resultsSection` | **Removed** | No real results exist for hypothetical drug-population combination |
| `hasResults` | Set to `false` | Indicates no results data present |
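
The field-level changes in the table above can be sketched as one transformation over the original trial record. This is a minimal sketch, not the script used for this analysis. It assumes the ClinicalTrials.gov JSON layout in which the modules sit under a top-level `protocolSection` key and the trial ID is stored as `nctId`. The helper name `make_hybrid_trial`, its parameters, and the toy record are hypothetical. The edits to `designModule` and `outcomesModule` are omitted for brevity.

```python
import copy


def make_hybrid_trial(original, drug, description, arms_module, mesh_terms):
    """Return a drug-substituted copy of a trial record (hypothetical helper).

    Assumes the modules live under a top-level "protocolSection" key.
    """
    trial = copy.deepcopy(original)  # never mutate the source trial
    protocol = trial["protocolSection"]

    # Mark the record as synthetic, e.g. NCT04809974 -> NCT04809974_leronlimab
    ident = protocol["identificationModule"]
    ident["nctId"] = f"{ident['nctId']}_{drug.lower()}"

    # Swap in the drug-specific description, arms and MeSH terms
    protocol["descriptionModule"] = description
    protocol["armsInterventionsModule"] = arms_module
    derived = trial.setdefault("derivedSection", {})
    derived["interventionBrowseModule"] = {"meshes": mesh_terms}

    # No real results exist for the hypothetical drug-population combination
    trial.pop("resultsSection", None)
    trial["hasResults"] = False
    return trial


# Toy record for illustration only; real records carry many more fields.
original = {
    "protocolSection": {
        "identificationModule": {"nctId": "NCT04809974"},
        "descriptionModule": {"briefSummary": "Nicotinamide Riboside therapy"},
        "eligibilityModule": {"minimumAge": "18 Years"},
        "armsInterventionsModule": {"interventions": [{"name": "Niagen"}]},
    },
    "resultsSection": {"placeholder": True},
    "hasResults": True,
}

hybrid = make_hybrid_trial(
    original,
    drug="Leronlimab",
    description={"briefSummary": "Leronlimab, a CCR5 antagonist, for Long COVID."},
    arms_module={"interventions": [{"name": "Leronlimab"}]},
    mesh_terms=[{"id": "C420063", "term": "leronlimab"}],
)
print(hybrid["protocolSection"]["identificationModule"]["nctId"])
# NCT04809974_leronlimab
```

Working on a deep copy keeps the original NCT04809974 record intact, so the same source can be reused for every drug.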

---

**2.1.1 Leronlimab (`NCT04809974_leronlimab.json`)**

**Drug Source Trial:** NCT04678830 (CytoDyn Phase 2 Long COVID trial, n=56)

**Causal Driver Gene Targets:** 1 gene → **CCR5**

**Key Modifications:**

| Field | Original (Niagen) | Modified (Leronlimab) |
|-------|-------------------|----------------------|
| **briefTitle** | "...Following Nicotinamide Riboside Therapy" | "...Following Leronlimab Therapy" |
| **briefSummary** | Describes NAD+ precursor supplementation | Describes Leronlimab as CCR5 antagonist for Long COVID |
| **detailedDescription** | Nicotinamide Riboside dosing, NAD+ metabolism | Leronlimab mechanism (humanized IgG4 mAb), 700mg weekly SC |
| **Arms** | Placebo, Niagen | Placebo, 700mg Leronlimab |
| **Interventions** | Nicotinamide Riboside oral capsules | Leronlimab (PRO 140) weekly SC injection |
| **Intervention MeSH** | `D009536` (Nicotinamide) | `C420063` (leronlimab) |

**Leronlimab-Specific Protocol Details:**
- **Dose:** 700 mg weekly (two 350mg vials)
- **Route:** Subcutaneous injection
- **Duration:** 8 weeks
- **Formulation:** 175mg/mL in histidine/glycine/NaCl/sorbitol/polysorbate 20 buffer
- **Mechanism:** Humanized IgG4 monoclonal antibody targeting CCR5

---

**2.1.2 Vortioxetine (`NCT04809974_vortioxetine.json`)**

**Drug Source Trial:** NCT05047952 (Brain and Cognition Discovery Foundation Phase 2 Post-COVID trial, n=149)

**Causal Driver Gene Targets:** 11 genes → **HDAC6, ADRB2, HRH2, HTR1A, HTR1B, HTR1D, HTR2C, HTR3A, HTR6, HTR7**

**Key Modifications:**

| Field | Original (Niagen) | Modified (Vortioxetine) |
|-------|-------------------|------------------------|
| **briefTitle** | "...Following Nicotinamide Riboside Therapy" | "...Following Vortioxetine Therapy" |
| **briefSummary** | Describes NAD+ precursor supplementation | Describes Vortioxetine as pro-cognitive antidepressant |
| **detailedDescription** | Nicotinamide Riboside dosing, NAD+ metabolism | Vortioxetine multimodal mechanism, age-based dosing |
| **Arms** | Placebo, Niagen | Placebo, Vortioxetine |
| **Interventions** | Nicotinamide Riboside oral capsules | Vortioxetine oral tablets |
| **Intervention MeSH** | `D009536` (Nicotinamide) | `D000078784` (Vortioxetine) |
| **Primary Outcomes** | Cognitive measures | DSST cognitive score added as primary |

**Vortioxetine-Specific Protocol Details:**
- **Dose (18-64 years):** 10 mg daily × 2 weeks → 20 mg daily × 6 weeks
- **Dose (65+ years):** 5 mg daily × 2 weeks → 10 mg daily × 6 weeks
- **Route:** Oral
- **Duration:** 8 weeks
- **Mechanism:** Multimodal antidepressant with pro-cognitive, anti-inflammatory properties
- **Trade Names:** Trintellix, Brintellix

---

**2.1.3 Ritonavir (`NCT04809974_ritonavir.json`)**

**Drug Source Trial:** NCT05576662 (Stanford STOP-PASC Paxlovid trial, n=168)

**Note:** The source trial used Nirmatrelvir + Ritonavir (Paxlovid). We extracted **only Ritonavir** for this hybrid trial.

**Causal Driver Gene Targets:** 70 genes → **ACE, ACHE, ADORA1, ADORA2A, ADRA1D, ADRA2B, ADRA2C, ADRB2, AR, AVPR2, BDKRB1, BDKRB2, CA2, CALCR, CASP1, CCKAR, CCR2, CCR4, CCR5, CHRM1, CHRM2, CHRM4, CHRM5, CNR1, CRHR2, DRD1, DRD3, DRD4, EDNRB, EGFR, ERBB2, ESR1, ESR2, FYN, GCGR, GHSR, HDAC6, HRH2, HTR1A, HTR1B, HTR2B, HTR2C, HTR3A, HTR6, HTR7, HTT, INSR, LCK, MAPK1, MAPK14, MAPK3, MC3R, MC5R, MEN1, MMP1, MMP9, NOS1, NR3C1, NTSR1, OPRK1, PGR, PRKACA, PRKCA, PTAFR, SIGMAR1, SLC22A1, TACR1, TACR2, VDR, VIPR1**

**Key Modifications:**

| Field | Original (Niagen) | Modified (Ritonavir) |
|-------|-------------------|---------------------|
| **briefTitle** | "...Following Nicotinamide Riboside Therapy" | "...Following Ritonavir Therapy" |
| **briefSummary** | Describes NAD+ precursor supplementation | Describes Ritonavir as antiviral/immunomodulator |
| **detailedDescription** | Nicotinamide Riboside dosing, NAD+ metabolism | Ritonavir CYP3A4 inhibition, antiviral properties |
| **Arms** | Placebo, Niagen | Placebo, Ritonavir |
| **Interventions** | Nicotinamide Riboside oral capsules | Ritonavir oral capsules q12h |
| **Intervention MeSH** | `D009536` (Nicotinamide) | `D019438` (Ritonavir) |

**Ritonavir-Specific Protocol Details:**
- **Dose:** 100 mg every 12 hours
- **Route:** Oral capsule
- **Duration:** 15 days
- **Mechanism:** HIV protease inhibitor, potent CYP3A4 inhibitor, potential immunomodulatory effects
- **Trade Name:** Norvir

---

**2.1.4 Why These Changes Matter for PlaNet**

The original simple approach of changing only the drug name in `armsInterventionsModule` failed because:

1. **Drug references appear in 9+ locations** throughout the NCT04809974 JSON:
   - `descriptionModule.briefSummary`
   - `descriptionModule.detailedDescription`
   - `conditionsModule.keywords`
   - `armsInterventionsModule`
   - `resultsSection.participantFlowModule.groups`
   - `resultsSection.baselineCharacteristicsModule.groups`
   - `resultsSection.outcomeMeasuresModule`
   - `resultsSection.adverseEventsModule.eventGroups`
   - `derivedSection.interventionBrowseModule.meshes`

2. **PlaNet extracts features from the entire document**, not just the drug name field. With ~90% of text still describing Nicotinamide Riboside therapy, the model predicted Nicotinamide Riboside outcomes regardless of the drug label change.

3. **The hybrid approach solves this by:**
   - Removing the `resultsSection` entirely (eliminates all Niagen result references)
   - Rewriting `descriptionModule` with drug-specific mechanism text
   - Replacing `armsInterventionsModule` with drug-appropriate protocols
   - Updating `derivedSection.interventionBrowseModule` with correct MeSH terms
   - Keeping only the population-relevant fields (eligibility, conditions) from the original
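
Because drug references are scattered across the JSON, it can help to check a hybrid file for leftover mentions of the original drug before parsing it. The sketch below walks a nested JSON structure and reports every string value that still mentions one of the given terms. The function name `find_mentions`, the `$`-style path notation, and the toy record are illustrative assumptions, not part of PlaNet.

```python
def find_mentions(node, terms, path="$"):
    """Yield (path, text) for every string value that mentions any term.

    Matching is a case-insensitive substring test.
    """
    if isinstance(node, dict):
        for key, value in node.items():
            yield from find_mentions(value, terms, f"{path}.{key}")
    elif isinstance(node, list):
        for i, value in enumerate(node):
            yield from find_mentions(value, terms, f"{path}[{i}]")
    elif isinstance(node, str):
        text = node.lower()
        if any(term.lower() in text for term in terms):
            yield path, node


# Toy record: the keyword list still carries the original drug name.
trial = {
    "descriptionModule": {"briefSummary": "Leronlimab for Long COVID"},
    "conditionsModule": {"keywords": ["Brain Fog", "Niagen"]},
}
leftovers = list(find_mentions(trial, ["nicotinamide", "niagen"]))
print(leftovers)
# [('$.conditionsModule.keywords[1]', 'Niagen')]
```

An empty result suggests that no free-text field still describes the original drug, at least for the terms searched.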

---

**3. Parse and Enrich the Three New Trials**

1. We used the PlaNet parsing pipeline (`parse_trial.py`) to:
   * Read each of the three JSON versions of the trial.
   * Extract and normalize all key trial fields (NCT ID, arms, interventions, conditions, outcomes, eligibility, etc.).

2. During parsing, several enrichment steps occur:
   * Drug information is refined using **Medex** (a clinical NLP tool) to detect drug names, formulations, and doses from free text.
   * Eligibility criteria are converted into **UMLS concepts**.
   * Conditions are mapped to disease IDs using **disease ontologies**.
   * Drugs are mapped to **standardized drug identifiers**.
   * Population features are extracted (e.g., age, sex, enrollment, etc.).
   * A **trial-arm–level knowledge graph** representation is built (nodes and edges linking drugs, diseases, and other biomedical entities).

3. For each new drug-trial combination, the pipeline produced trial-level and arm-level objects.

4. Output names (shown for leronlimab; ritonavir and vortioxetine follow the same pattern):
   * `trial_data_NCT04809974_leronlimab.pkl`
   * `parsed_trial_NCT04809974_leronlimab.json`
   * a corresponding summary file

---

**4. Run PlaNet Prediction for All New Trials**

1. Next, we used a PBS script (`PlaNet_Predict_Many`) on the Gadi supercomputer to:
   * Scan the folder **parsing_package/LC_Results/**
   * Automatically find all files that contain the trial representations needed by PlaNet.

2. For each trial (leronlimab, ritonavir, vortioxetine), PlaNet:
   * Loaded the prepared trial-arm representation.
   * Ran the trained PlaNet models for:
     * **Adverse events (AE)** risk per arm.
     * **Overall safety score** per arm.
     * **Efficacy comparison** between arms (for example, the probability that arm 1 is more effective than arm 2).

3. The prediction results were written as JSON files.

4. These result files have a consistent structure:
   * A "meta" section describing the trial and each arm.
   * An "AE" section with predicted probabilities for many adverse events, for each arm.
     * The adverse events are indexed by numeric codes (for example 173, 9, 30, 21, etc.).
   * A "safety" section with safety scores per arm.
   * An "efficacy" section with comparative efficacy metrics between arms.

---

**5. Understanding the AE Codes**

1. In the "AE" block of each results JSON, we have entries like:
   * `trial_1_ae`: 173, 9, 30, 21, … with associated probabilities.
   * `trial_2_ae` and `trial_3_ae` with the same index space.

2. These numbers **are not MedDRA codes**; they are **internal indices** in the PlaNet AE vocabulary used during model training (for example, index 173 might correspond to a specific named adverse event like "Fatigue").

3. That's why the raw result files alone are not interpretable in clinical language — they require a separate mapping from:

   * AE index (for example 173) → AE name (for example "Fatigue").

---

**6. Use the AE Index → AE Name Mapping**

1. PlaNet ships with a mapping file that links these indices to human-readable AE names:
   * File name: `ae1017_idx2aename.pkl`
   * Location: `notebooks/small_data/` within the PlaNet tree.

2. That file contains:
   * Total number of adverse events in the vocabulary (1,017).
   * A dictionary of index → AE name, for example:
     * 0 → Abdominal pain
     * 1 → Abdominal distension
     * … and so on.

3. To interpret the results, we decided to:
   * Read each prediction JSON from `parsing_package/results/`.
   * For every AE index in `trial_1_ae`, `trial_2_ae`, and `trial_3_ae`:
     * Look up the corresponding AE name using the mapping file.
     * Pair the AE code, AE name, and predicted probability.
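
The lookup described above can be sketched as follows. The exact shape of the "AE" block is an assumption: here each arm key (such as `trial_1_ae`) maps to an object of AE index to probability, with indices stored as strings because JSON object keys are strings. The mapping is assumed to be a plain `dict` from integer index to name. The function name `name_adverse_events` and the probabilities are hypothetical; the two AE names are the examples listed above.

```python
def name_adverse_events(ae_block, idx2name):
    """Pair each AE index with its name and predicted probability.

    ae_block: {arm_key: {ae_index: probability}} (assumed layout).
    idx2name: {int: str} mapping from AE index to AE name.
    """
    rows = []
    for arm, probs in ae_block.items():
        for idx, prob in probs.items():
            idx = int(idx)  # JSON object keys arrive as strings
            rows.append({
                "arm": arm,
                "ae_index": idx,
                "ae_name": idx2name.get(idx, f"UNKNOWN_{idx}"),
                "probability": prob,
            })
    # Within each arm, list the most probable adverse events first
    rows.sort(key=lambda r: (r["arm"], -r["probability"]))
    return rows


idx2name = {0: "Abdominal pain", 1: "Abdominal distension"}
ae_block = {"trial_1_ae": {"0": 0.12, "1": 0.30}}  # made-up probabilities
rows = name_adverse_events(ae_block, idx2name)
for row in rows:
    print(row["arm"], row["ae_index"], row["ae_name"], row["probability"])
# trial_1_ae 1 Abdominal distension 0.3
# trial_1_ae 0 Abdominal pain 0.12
```

Indices missing from the mapping are kept with an `UNKNOWN_` label rather than dropped, so a vocabulary mismatch stays visible.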

---

**7. Mapping Predictions into Human-Readable Outputs**

1. We then wrote a mapping script (`map_AE.py`) to:
   * Locate all PlaNet prediction results in `parsing_package/results/` (files starting with "result_" and ending with ".json").
   * For each file:
     * Keep the original meta, safety, and efficacy fields.
     * Replace each AE index with a human-readable AE name, while preserving the predicted probabilities.

2. For every trial and every arm, the script produces structured output such as:
   * Trial arm identifier (for example, `trial_1_ae`).
   * AE index (to keep the original reference).
   * AE name (from the mapping).
   * Probability predicted by PlaNet.

3. The mapped outputs are saved back into `parsing_package/results/` as new files (for example, with a prefix like "mapped_"), giving:
   * A clean, readable representation where you can directly see "Fatigue", "Headache", "Nausea", etc., and their predicted probabilities for each arm and for each new drug.
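
The file-level behaviour of such a mapping script can be sketched as below. This is not the actual `map_AE.py`. It assumes that each result file has a top-level "AE" object keyed by arm, and that the pickle loads as a plain `dict` from integer index to AE name (the real file may wrap the dictionary differently). The function names and the toy files in the demo are hypothetical.

```python
import json
import pickle
import tempfile
from pathlib import Path


def map_result_file(result_path, idx2name):
    """Write mapped_<name>.json next to a result file and return its path."""
    result_path = Path(result_path)
    result = json.loads(result_path.read_text())

    mapped = dict(result)  # meta, safety and efficacy pass through unchanged
    mapped["AE"] = {
        arm: [
            {"ae_index": int(i),
             "ae_name": idx2name.get(int(i), "unknown"),
             "probability": p}
            for i, p in probs.items()
        ]
        for arm, probs in result.get("AE", {}).items()
    }

    out_path = result_path.with_name("mapped_" + result_path.name)
    out_path.write_text(json.dumps(mapped, indent=2))
    return out_path


def map_all(results_dir, mapping_pkl):
    """Map every result_*.json file in results_dir."""
    with open(mapping_pkl, "rb") as fh:
        idx2name = pickle.load(fh)  # assumed: plain dict {int: str}
    paths = sorted(Path(results_dir).glob("result_*.json"))
    return [map_result_file(p, idx2name) for p in paths]


# Demo on throwaway files with made-up probabilities.
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    mapping_pkl = tmp / "toy_idx2aename.pkl"
    with open(mapping_pkl, "wb") as fh:
        pickle.dump({0: "Abdominal pain", 1: "Abdominal distension"}, fh)
    (tmp / "result_demo.json").write_text(json.dumps({
        "meta": {"trial": "demo"},
        "AE": {"trial_1_ae": {"1": 0.3, "0": 0.1}},
        "safety": {},
        "efficacy": {},
    }))
    outputs = map_all(tmp, mapping_pkl)
    output_name = outputs[0].name
    mapped = json.loads(outputs[0].read_text())

print(output_name)  # mapped_result_demo.json
```

Because the glob pattern is `result_*.json`, files already written with the `mapped_` prefix are not picked up again when the script is rerun in the same folder.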

---

**8. Final Picture for the New Drugs on NCT04809974**

Putting it all together, for each of the three drugs (leronlimab, ritonavir, vortioxetine) on the NCT04809974 Long COVID cognitive symptom cohort:

1. We created a **drug-specific JSON version** of the trial and stored it.
2. We **parsed and enriched** that trial into PlaNet's internal representation.
3. We ran **PlaNet prediction** to predict adverse event profiles, safety, and efficacy.
4. We used the **AE index → AE name mapping** to convert numeric AE codes into human-readable adverse event names for every trial arm.
5. We now have:
   * **Realistic trial-arm predictions** for each new drug in a cognitive-focused Long COVID population.
   * **Interpretable adverse event profiles** (names plus probabilities) that you can tabulate and compare across drugs and arms in your paper.

---

**Appendix A: Drug–Causal Gene Target Summary**

| Drug | Causal Gene Targets (n) | Genes |
|------|------------------------|-------|
| **Leronlimab** | 1 | CCR5 |
| **Vortioxetine** | 11 | HDAC6, ADRB2, HRH2, HTR1A, HTR1B, HTR1D, HTR2C, HTR3A, HTR6, HTR7 |
| **Ritonavir** | 70 | ACE, ACHE, ADORA1, ADORA2A, ADRA1D, ADRA2B, ADRA2C, ADRB2, AR, AVPR2, BDKRB1, BDKRB2, CA2, CALCR, CASP1, CCKAR, CCR2, CCR4, CCR5, CHRM1, CHRM2, CHRM4, CHRM5, CNR1, CRHR2, DRD1, DRD3, DRD4, EDNRB, EGFR, ERBB2, ESR1, ESR2, FYN, GCGR, GHSR, HDAC6, HRH2, HTR1A, HTR1B, HTR2B, HTR2C, HTR3A, HTR6, HTR7, HTT, INSR, LCK, MAPK1, MAPK14, MAPK3, MC3R, MC5R, MEN1, MMP1, MMP9, NOS1, NR3C1, NTSR1, OPRK1, PGR, PRKACA, PRKCA, PTAFR, SIGMAR1, SLC22A1, TACR1, TACR2, VDR, VIPR1 |

---

**Appendix B: NCT04809974 Cohort Characteristics**

| Characteristic | Value |
|----------------|-------|
| **Trial ID** | NCT04809974 |
| **Intervention** | Niagen (Nicotinamide Riboside) |
| **Sample Size** | 72 |
| **Primary Target** | Cognitive symptoms (Brain Fog) |
| **Age Mean** | 46 years |
| **Age Range** | 18–65 |
| **Female** | 70% |
| **Male** | 30% |
| **White** | 91% |
| **Black** | 3% |
| **Other Race** | 5% |
| **Non-Hispanic** | 91% |
| **Hispanic** | 9% |

---

**Appendix C: File Locations Summary**

| Stage | Location | Files |
|-------|----------|-------|
| Input JSONs | `parsing_package/Input_ALL/` | `NCT04809974_leronlimab.json`, `NCT04809974_ritonavir.json`, `NCT04809974_vortioxetine.json` |
| Parsed Outputs | `parsing_package/LC_Results/` | `trial_data_*.pkl`, `parsed_trial_*.json` |
| Prediction Results | `parsing_package/results/` | `result_*.json` |
| Mapped Results | `parsing_package/results/` | `mapped_*.json` |
| AE Mapping File | `notebooks/small_data/` | `ae1017_idx2aename.pkl` |

**Multi-Cohort Drug Substitution Workflow**

**1. Overview: Population-Stratified Drug Repurposing**

To generate personalized drug rankings for Long COVID, we applied the PlaNet drug substitution workflow across **three clinically distinct Long COVID cohorts**, each representing a different symptom profile, demographic mix, and comorbidity pattern. For each cohort, we evaluated **1,776 candidate drugs out of 19,172 drugs**: those that target at least 11 Long COVID causal driver genes. Candidates are ranked by their number of causal gene targets, and Vortioxetine (rank 1,776, with 11 targets) marks the inclusion threshold.

This cross-cohort design enables:
- **Population-specific safety and efficacy predictions** that account for cohort demographics, comorbidities, and symptom profiles
- **Identification of drugs with consistent benefit** across heterogeneous Long COVID presentations
- **Detection of population-specific risks** that may contraindicate certain drugs in specific subgroups

---

**2. Long COVID Cohort Characteristics**

| Characteristic | **NCT04809974** | **NCT04880161** | **NCT05576662** |
|----------------|-----------------|-----------------|-----------------|
| **Original Intervention** | Niagen (Nicotinamide Riboside) | Ampion (Inhaled Biologic) | Nirmatrelvir/Ritonavir (Paxlovid) |
| **Sample Size (N)** | 72 | 32 | 168 |
| **Primary Symptom Target** | Cognitive (Brain Fog) | Respiratory | Multi-symptom |
| **Age Mean (years)** | 46 | 52 | 43 |
| **Age Range** | 18–65 | 18+ | 18+ |
| **Female (%)** | 70 | 56 | 60 |
| **Male (%)** | 30 | 44 | 40 |
| **White (%)** | 91 | 75 | 85 |
| **Black (%)** | 3 | 25 | 3 |
| **Asian (%)** | — | — | 20 |
| **Other Race (%)** | 5 | — | — |
| **Hispanic (%)** | 9 | NR | 12 |
| **Non-Hispanic (%)** | 91 | NR | 88 |
| **Vaccinated (%)** | NR | NR | 99 (n=153) |
| **Key Eligibility** | Brain fog + ≥2 neuro/physical symptoms | Respiratory symptoms ≥4 weeks | Multi-symptom Long COVID |
| **Key Exclusions** | CNS disease, major psychiatric illness | Severe COPD, CKD, liver disease, CFS | — |

**Cohort Rationale:**
- **NCT04809974 (Cognitive):** Younger, predominantly female, high proportion White; represents neurological/cognitive Long COVID phenotype
- **NCT04880161 (Respiratory):** Older, more balanced sex distribution, higher Black representation; represents pulmonary Long COVID phenotype with stricter comorbidity exclusions
- **NCT05576662 (Multi-symptom):** Largest cohort, highly vaccinated, diverse symptom burden; represents general Long COVID population

---

**3. Drug Selection Strategy**

**3.1 Causal Gene-Based Drug Retrieval**

Candidate drugs were identified through systematic mapping of Long COVID causal driver genes to approved or investigational compounds using curated drug–target databases (ChEMBL, DrugBank, DrugCentral, GtoPdb, TTD).

**Inclusion Criteria:**
- Drug targets ≥1 Long COVID causal driver gene identified via TWMR, CT, or DCE
- Drug has existing pharmacokinetic and safety data (approved or in clinical development)

**Ranking Strategy:**
- Drugs ranked by **number of causal driver genes targeted** (descending)
- Higher rank = more causal gene targets = stronger mechanistic engagement with Long COVID biology

**3.2 Drug Inclusion Threshold**

| Parameter | Value |
|-----------|-------|
| **Total candidate drugs** | 1,776 |
| **Ranking criterion** | Number of Long COVID causal driver genes targeted |
| **Rank 1 (highest)** | Drug with most causal gene targets |
| **Rank 1,776 (threshold)** | Vortioxetine (11 causal gene targets) |
| **Minimum causal genes for inclusion** | ≥11 (based on Vortioxetine threshold) |
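
The ranking and threshold rule can be sketched in a few lines. The function name `rank_candidates` is hypothetical, and the input is assumed to be a mapping from drug name to its number of causal gene targets (such as the `Gene_Count` column of the drug-level table). Ties are broken alphabetically here, mirroring the sort order of that table; how ties were handled in the actual ranking is not stated.

```python
def rank_candidates(gene_counts, min_targets=11):
    """Rank drugs by causal gene count (descending), keeping those at or above the threshold."""
    kept = [(drug, n) for drug, n in gene_counts.items() if n >= min_targets]
    kept.sort(key=lambda item: (-item[1], item[0]))  # ties broken alphabetically
    return [(rank, drug, n) for rank, (drug, n) in enumerate(kept, start=1)]


# The three reference drugs and their causal gene target counts.
counts = {"Leronlimab": 1, "Ritonavir": 70, "Vortioxetine": 11}
ranked = rank_candidates(counts)
print(ranked)
# [(1, 'Ritonavir', 70), (2, 'Vortioxetine', 11)]
```

With only these three drugs as input, Leronlimab is dropped by the ≥11 threshold and Vortioxetine sits exactly on it.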

**3.3 Reference Drugs from Validation Analysis**

Three drugs from our initial validation analysis span the causal gene target spectrum:

| Drug | Causal Gene Targets (n) | Approximate Rank | Target Examples |
|------|------------------------|------------------|-----------------|
| **Ritonavir** | 70 | High (top tier) | ACE, CCR5, MAPK1, EGFR, HTR1A, DRD1, etc. |
| **Vortioxetine** | 11 | 1,776 (threshold) | HDAC6, ADRB2, HTR1A, HTR1B, HTR2C, HTR3A, HTR6, HTR7, etc. |
| **Leronlimab** | 1 | Below threshold | CCR5 |

**Note:** Leronlimab (1 causal gene target) falls below the inclusion threshold but was retained in validation analyses as a mechanistically focused comparator.

---

**4. Cross-Cohort Drug Substitution Workflow**

**4.1 Workflow Overview**

For each of the **1,776 candidate drugs** × **3 cohorts** = **5,328 drug–cohort combinations**, we:

1. **Created drug-specific trial JSON files** preserving cohort-specific population characteristics
2. **Parsed and enriched** each synthetic trial into PlaNet's knowledge graph representation
3. **Generated predictions** for adverse events (AE), safety scores (S), and comparative efficacy (E)
4. **Mapped AE indices** to human-readable adverse event names
5. **Computed composite priority scores** integrating S, AE, and E predictions
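
The first step, enumerating every drug–cohort combination and naming its synthetic trial, can be sketched as follows. The ID pattern `{original_NCT}_{drug_name}` is taken from this workflow. Lower-casing the name and replacing other characters with underscores is an assumption added so that names containing spaces or slashes still yield usable file names. `synthetic_trial_ids` and the placeholder drug names are hypothetical.

```python
import re
from itertools import product

COHORTS = ["NCT04809974", "NCT04880161", "NCT05576662"]


def synthetic_trial_ids(drugs, cohorts=COHORTS):
    """Yield (cohort, drug, synthetic_id) for every drug-cohort pair."""
    for cohort, drug in product(cohorts, drugs):
        # Assumed sanitization: lower-case, non-alphanumerics collapsed to "_"
        slug = re.sub(r"[^a-z0-9]+", "_", drug.lower()).strip("_")
        yield cohort, drug, f"{cohort}_{slug}"


first = next(synthetic_trial_ids(["Leronlimab"]))
print(first)
# ('NCT04809974', 'Leronlimab', 'NCT04809974_leronlimab')

# Placeholder names standing in for the 1,776 candidates
placeholder_drugs = [f"drug_{i}" for i in range(1776)]
n_combinations = sum(1 for _ in synthetic_trial_ids(placeholder_drugs))
print(n_combinations)
# 5328
```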

**4.2 Cohort-Specific Trial Modifications**

For each cohort, the following fields were **preserved** to maintain population specificity:

| Preserved Field | NCT04809974 | NCT04880161 | NCT05576662 |
|-----------------|-------------|-------------|-------------|
| `eligibilityModule` | Ages 18–65, brain fog + ≥2 symptoms, excludes CNS/psychiatric disease | Ages 18+, respiratory ≥4 weeks, excludes COPD/CKD/CFS | Ages 18+, multi-symptom Long COVID |
| `conditionsModule` | Long COVID, Brain Fog, Cognitive Impairment | Long COVID, Respiratory symptoms, Dyspnea | Long COVID, Fatigue, Multi-organ |
| Population demographics | 70% F, 91% White, mean age 46 | 56% F, 75% White, mean age 52 | 60% F, 85% White, 99% vaccinated, mean age 43 |

For each cohort, the following fields were **modified** for each candidate drug:

| Modified Field | Modification |
|----------------|--------------|
| `identificationModule` | NCT ID changed to `{original_NCT}_{drug_name}` |
| `descriptionModule` | Rewritten with drug-specific mechanism and indication |
| `armsInterventionsModule` | Replaced with drug-appropriate arms, doses, routes, durations |
| `derivedSection.interventionBrowseModule` | Updated with drug-specific MeSH terms |
| `resultsSection` | Removed (no real results for synthetic combinations) |
| `outcomesModule` | Adapted to remove original drug-specific measures |

---

**5. PlaNet Prediction Pipeline**

**5.1 Prediction Outputs**

For each drug–cohort combination, PlaNet generates:

| Output | Description | Scale |
|--------|-------------|-------|
| **Adverse Event Profile** | Predicted probability for each of 1,017 adverse events | 0–1 per AE |
| **Safety Score (S)** | Global safety metric aggregating AE risk | Higher = safer |
| **Efficacy Probability (E)** | Comparative efficacy vs. placebo | 0–1 (probability drug > placebo) |

---

**6. Cross-Cohort Analysis Outputs**

**6.1 Cohort-Specific Drug Rankings**

For each cohort, drugs are ranked by composite priority score, generating:

| Output File | Description |
|-------------|-------------|
| `rankings_NCT04809974_cognitive.csv` | Top drugs for cognitive Long COVID phenotype |
| `rankings_NCT04880161_respiratory.csv` | Top drugs for respiratory Long COVID phenotype |
| `rankings_NCT05576662_multisymptom.csv` | Top drugs for multi-symptom Long COVID phenotype |

**6.2 Cross-Cohort Consistency Analysis**

To identify drugs with robust benefit across populations:

| Analysis | Description |
|----------|-------------|
| **Consensus ranking** | Drugs ranked highly across all 3 cohorts |
| **Cohort-specific candidates** | Drugs with strong signal in one cohort but not others |
| **Safety flags** | Drugs with cohort-specific safety concerns (e.g., respiratory AEs in NCT04880161) |
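
One possible reading of the consensus and cohort-specific categories is a top-k comparison: a drug is a consensus candidate if it is in the top k of every cohort, and cohort-specific if it is in the top k of exactly one. The sketch below implements that reading only. The cutoff `top_k`, the function name `split_by_consistency`, and the placeholder drug names are assumptions, not the criteria used in this analysis.

```python
def split_by_consistency(rankings, top_k):
    """Split drugs into consensus and cohort-specific top-k candidates.

    rankings: {cohort: [drug, ...]} with each list ordered best first.
    Returns (consensus_set, {cohort: cohort_specific_set}).
    """
    top_sets = {cohort: set(drugs[:top_k]) for cohort, drugs in rankings.items()}

    # Consensus: in the top k of every cohort
    consensus = set.intersection(*top_sets.values())

    # Cohort-specific: in this cohort's top k and in no other cohort's top k
    specific = {}
    for cohort, top in top_sets.items():
        others = set().union(*(s for c, s in top_sets.items() if c != cohort))
        specific[cohort] = top - others
    return consensus, specific


# Placeholder rankings, not real predictions.
rankings = {
    "cognitive": ["drug_a", "drug_b", "drug_c"],
    "respiratory": ["drug_a", "drug_c", "drug_d"],
    "multisymptom": ["drug_a", "drug_e", "drug_b"],
}
consensus, specific = split_by_consistency(rankings, top_k=2)
print(sorted(consensus))
# ['drug_a']
print({cohort: sorted(drugs) for cohort, drugs in specific.items()})
# {'cognitive': ['drug_b'], 'respiratory': ['drug_c'], 'multisymptom': ['drug_e']}
```

A rank-aggregation method (for example, mean rank across cohorts) would be a reasonable alternative when a hard cutoff is too coarse.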

**6.3 Population-Specific Safety Considerations**

| Cohort | Key Safety Considerations |
|--------|---------------------------|
| **NCT04809974 (Cognitive)** | CNS adverse events, psychiatric effects, drug–drug interactions with neurological medications |
| **NCT04880161 (Respiratory)** | Pulmonary AEs, contraindications with COPD/CKD exclusions, older population tolerability |
| **NCT05576662 (Multi-symptom)** | Broad AE profile, vaccine interaction potential, polypharmacy risks in multi-organ involvement |

---

**7. Computational Infrastructure**

```
┌─────────────────────────────────────────────────────────────────────────┐
│                        CROSS-COHORT PIPELINE                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐                 │
│  │ NCT04809974 │    │ NCT04880161 │    │ NCT05576662 │                 │
│  │ (Cognitive) │    │(Respiratory)│    │(Multi-sympt)│                 │
│  └──────┬──────┘    └──────┬──────┘    └──────┬──────┘                 │
│         │                  │                  │                         │
│         ▼                  ▼                  ▼                         │
│  ┌────────────────────────────────────────────────────────────────┐    │
│  │              1,776 Candidate Drugs (≥11 causal genes)          │    │
│  └────────────────────────────────────────────────────────────────┘    │
│         │                  │                  │                         │
│         ▼                  ▼                  ▼                         │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐                 │
│  │  Generate   │    │  Generate   │    │  Generate   │                 │
│  │ 1,776 JSONs │    │ 1,776 JSONs │    │ 1,776 JSONs │                 │
│  └──────┬──────┘    └──────┬──────┘    └──────┬──────┘                 │
│         │                  │                  │                         │
│         ▼                  ▼                  ▼                         │
│  ┌────────────────────────────────────────────────────────────────┐    │
│  │                    PlaNet Parsing & Enrichment                  │    │
│  └────────────────────────────────────────────────────────────────┘    │
│         │                  │                  │                         │
│         ▼                  ▼                  ▼                         │
│  ┌────────────────────────────────────────────────────────────────┐    │
│  │              PlaNet Prediction (S, AE, E per arm)               │    │
│  └────────────────────────────────────────────────────────────────┘    │
│         │                  │                  │                         │
│         ▼                  ▼                  ▼                         │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐                 │
│  │  Cognitive  │    │ Respiratory │    │Multi-symptom│                 │
│  │  Rankings   │    │  Rankings   │    │  Rankings   │                 │
│  └──────┬──────┘    └──────┬──────┘    └──────┬──────┘                 │
│         │                  │                  │                         │
│         └──────────────────┼──────────────────┘                         │
│                            ▼                                            │
│  ┌────────────────────────────────────────────────────────────────┐    │
│  │                Cross-Cohort Consensus Analysis                  │    │
│  └────────────────────────────────────────────────────────────────┘    │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
```

---

**8. File Structure**

**Input Files**

| Location | Contents |
|----------|----------|
| `parsing_package/Input_ALL/NCT04809974/` | 1,776 drug-substituted JSONs for cognitive cohort |
| `parsing_package/Input_ALL/NCT04880161/` | 1,776 drug-substituted JSONs for respiratory cohort |
| `parsing_package/Input_ALL/NCT05576662/` | 1,776 drug-substituted JSONs for multi-symptom cohort |

**Output Files**

| Location | Contents |
|----------|----------|
| `parsing_package/results/NCT04809974/` | Raw PlaNet predictions (`result_*.json`) |
| `parsing_package/results/NCT04880161/` | Raw PlaNet predictions (`result_*.json`) |
| `parsing_package/results/NCT05576662/` | Raw PlaNet predictions (`result_*.json`) |
| `parsing_package/results/mapped/` | Human-readable AE mappings (`mapped_*.json`) |
| `analysis/rankings/` | Cohort-specific and consensus drug rankings (`.csv`) |
| `analysis/cross_cohort/` | Cross-cohort comparison tables and figures |
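
Given this layout, the raw predictions for each cohort can be gathered with a glob over the `result_*.json` pattern. The sketch below assumes the directories sit under a common project root; the helper name `collect_result_files` and the file name `result_example.json` are hypothetical, and the demo runs against a temporary directory instead of real outputs.

```python
import tempfile
from pathlib import Path
from typing import Dict, List

COHORTS = ["NCT04809974", "NCT04880161", "NCT05576662"]


def collect_result_files(root: Path) -> Dict[str, List[Path]]:
    """Map each cohort ID to its sorted raw prediction files."""
    results_dir = root / "parsing_package" / "results"
    return {c: sorted((results_dir / c).glob("result_*.json")) for c in COHORTS}


# Demo on a throwaway directory tree that mirrors the table above.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for cohort in COHORTS:
        cohort_dir = root / "parsing_package" / "results" / cohort
        cohort_dir.mkdir(parents=True)
        (cohort_dir / "result_example.json").write_text("{}")
    found = collect_result_files(root)
    counts = {cohort: len(paths) for cohort, paths in found.items()}

print(counts)
```

In a complete run, each cohort would be expected to yield one result file per candidate drug, so comparing `counts` against 1,776 is a quick completeness check.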

**Reference Files**

| Location | Contents |
|----------|----------|
| `notebooks/small_data/ae1017_idx2aename.pkl` | AE index → AE name mapping |
| `data/drug_rankings/all_drugs_by_causal_genes.csv` | 1,776 drugs ranked by causal gene count |
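
The AE index mapping is what turns a 1,017-element probability vector into readable names. The sketch below assumes the pickle deserializes to something indexable by integer position (a `dict` keyed by index, or a list); that structure is an assumption, as are the helper name `top_adverse_events` and the placeholder AE names in the demo.

```python
from typing import List, Mapping, Sequence, Tuple, Union

IndexToName = Union[Mapping[int, str], Sequence[str]]


def top_adverse_events(
    ae_probs: Sequence[float], idx2name: IndexToName, k: int = 5
) -> List[Tuple[str, float]]:
    """Return the k adverse events with the highest predicted probability."""
    ranked = sorted(range(len(ae_probs)), key=lambda i: ae_probs[i], reverse=True)
    return [(idx2name[i], ae_probs[i]) for i in ranked[:k]]


# In practice the mapping would be loaded from the reference file, e.g.:
#   with open("notebooks/small_data/ae1017_idx2aename.pkl", "rb") as f:
#       idx2name = pickle.load(f)
# The toy mapping and probabilities below are placeholders, not real AE names or outputs.
toy_idx2name = {0: "ae_zero", 1: "ae_one", 2: "ae_two", 3: "ae_three"}
toy_probs = [0.1, 0.7, 0.3, 0.9]
top2 = top_adverse_events(toy_probs, toy_idx2name, k=2)
print(top2)  # prints: [('ae_three', 0.9), ('ae_one', 0.7)]
```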

---

**Appendix A: Cohort Summary Table**

| Cohort | NCT ID | N | Symptom Focus | Age | Female % | White % | Vaccinated % |
|--------|--------|---|---------------|-----|----------|---------|--------------|
| Cognitive | NCT04809974 | 72 | Brain Fog | 46 | 70 | 91 | NR |
| Respiratory | NCT04880161 | 32 | Pulmonary | 52 | 56 | 75 | NR |
| Multi-symptom | NCT05576662 | 168 | Multi-organ | 43 | 60 | 85 | 99 |

---

**Appendix B: Drug Selection Threshold**

| Rank | Drug Example | Causal Genes (n) | Included |
|------|--------------|------------------|----------|
| 1 | (Top drug) | ~100+ | ✓ |
| ... | ... | ... | ✓ |
| ~50 | Ritonavir | 70 | ✓ |
| ... | ... | ... | ✓ |
| 1,776 | Vortioxetine | 11 | ✓ (threshold) |
| 1,777+ | Leronlimab, etc. | <11 | ✗ |
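
The inclusion rule in this table is a simple cutoff on causal gene count. The sketch below applies it to CSV rows; the column names `drug` and `causal_gene_count` are assumptions about `all_drugs_by_causal_genes.csv`, and the excluded row is a placeholder with an invented count, since the table only states that excluded drugs have fewer than 11 causal genes.

```python
import csv
import io
from typing import Dict, List

MIN_CAUSAL_GENES = 11  # inclusion threshold from the table above


def select_candidates(
    rows: List[Dict[str, str]], min_genes: int = MIN_CAUSAL_GENES
) -> List[Dict[str, str]]:
    """Keep drugs at or above the threshold, highest gene count first."""
    kept = [r for r in rows if int(r["causal_gene_count"]) >= min_genes]
    return sorted(kept, key=lambda r: int(r["causal_gene_count"]), reverse=True)


# Two rows taken from the table, plus one placeholder row below the threshold.
toy_csv = io.StringIO(
    "drug,causal_gene_count\n"
    "Vortioxetine,11\n"
    "placeholder_excluded_drug,10\n"
    "Ritonavir,70\n"
)
candidates = select_candidates(list(csv.DictReader(toy_csv)))
print([r["drug"] for r in candidates])  # prints: ['Ritonavir', 'Vortioxetine']
```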

---

**Appendix C: Cross-Cohort Comparison Matrix**

|  | NCT04809974 (Cognitive) | NCT04880161 (Respiratory) | NCT05576662 (Multi-symptom) |
|--|-------------------------|---------------------------|----------------------------|
| **NCT04809974** | — | Δ Safety: CNS vs. pulmonary AEs | Δ Population: age, vaccination |
| **NCT04880161** | Δ Demographics: race distribution | — | Δ Exclusions: COPD/CKD |
| **NCT05576662** | Δ Size: 72 vs. 168 | Δ Size: 32 vs. 168 | — |

## **Scripts**

### **Parse**

#### **Python Script**

In [None]:
#!/usr/bin/env python3
"""
parse_trial_counterfactual.py (a.k.a. parse_new_drug.py)

Extended version of parse_trial.py that supports:
1. Standard trial parsing (original functionality)
2. Drug replacement for counterfactual analysis (new functionality)
3. Drug lookup/validation in KG / drug table before processing
4. Flexible drug name matching (case-insensitive, synonyms, brand names)
5. Counterfactual leakage controls:
   - removes results-derived content for CF trials
   - replaces brief_summary with a neutral CF summary by default
   - reduces original-drug narrative leakage into arm_text

Key robustness fixes:
A) TrialGraphBuilder API differences:
   - Some PlaNet versions do NOT expose `arm_labels`.
   - We now resolve arm labels via multiple candidate attributes and
     fall back to parsed_trial["arm_group"] order.

B) Base trial not clearly drug-annotated:
   - `find_experimental_drug()` is more permissive (Drug/Biological/Dietary Supplement).
   - If still not found, we attempt to infer a non-placebo arm and create a minimal
     synthetic base intervention so replacement can proceed.

Usage:
    # Standard parsing
    python parse_new_drug.py NCT04678830

    # Check if drug(s) exist in KG (without running full pipeline)
    python parse_new_drug.py --check-drug "tocilizumab"
    python parse_new_drug.py --check-drug "Remdesivir" "baricitinib" "SomeUnknownDrug"

    # Drug replacement (counterfactual)
    python parse_new_drug.py NCT04678830 --replace-drug "Tocilizumab"

    # Multiple replacements
    python parse_new_drug.py NCT04678830 --replace-drug "Tocilizumab" "Baricitinib" "Remdesivir"

    # From existing parsed JSON
    python parse_new_drug.py --from-parsed parsed_trial_NCT04678830.json --replace-drug "Tocilizumab"

    # Skip drugs not found in KG / drug table
    python parse_new_drug.py NCT04678830 -r "tocilizumab" "unknowndrug" --skip-not-found

    # (Not recommended) keep original summary text in CF mode
    python parse_new_drug.py NCT04678830 -r "tocilizumab" --keep-original-summary
"""

import argparse
import requests
import os
import pickle
import math
import numpy as np
import pandas as pd
import tempfile
import json
import subprocess
import shlex
import pathlib
import re
from typing import Any, List, Dict, Optional, Tuple
from copy import deepcopy
from difflib import SequenceMatcher

from data_parsers.external_tools.medex import medex_input
from data_parsers.external_tools import medex

from data_parsers import DiseaseExtract
from data_parsers import CriteriaOutputParser
from data_parsers import DrugMatcher, get_intervention_drug_ids
from data_parsers import OutcomeMeasureExtract
from data_parsers import UMLSConceptSearcher

from data_parsers import UMLSTFIDFMatcher
from data_parsers.umls_utils import UMLSUtils

from knowledge_graph import KnowledgeGraphBuilder
from knowledge_graph.kg import UnionFind
from knowledge_graph.build_graph import TrialGraphBuilder
from knowledge_graph.node_features import TrialAttributeFeatures


DATA_DIR = "data"
RESULTS_ROOT = os.environ.get("RESULTS_DIR", "LC_Results")


# -------------------------
# Helpers
# -------------------------

def ensure_dir(p: str) -> str:
    pathlib.Path(p).mkdir(parents=True, exist_ok=True)
    return p

def save_pkl(obj: Any, path: str) -> None:
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def load_pkl(path: str) -> Any:
    with open(path, "rb") as f:
        return pickle.load(f)

def _to_jsonable(obj: Any) -> Any:
    """
    Recursively convert objects into JSON-serializable forms.
    """
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, set):
        try:
            return sorted(_to_jsonable(v) for v in obj)
        except Exception:
            return [_to_jsonable(v) for v in obj]
    if isinstance(obj, (bytes, bytearray)):
        try:
            return obj.decode("utf-8", errors="ignore")
        except Exception:
            return str(obj)
    if isinstance(obj, (pathlib.Path,)):
        return str(obj)
    if isinstance(obj, dict):
        return {str(k): _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    return str(obj)

def save_json(obj: Any, path: str) -> None:
    with open(path, "w") as f:
        json.dump(_to_jsonable(obj), f, indent=2)

def load_json(path: str) -> Dict:
    with open(path, "r") as f:
        return json.load(f)


# =========================================================================
# DRUG LOOKUP / KG VALIDATION
# =========================================================================

class DrugKGLookup:
    """
    Handles drug name lookup and validation against the Knowledge Graph / drug table.
    Supports flexible matching: case-insensitive, synonyms, brand names, etc.
    """

    def __init__(self, drug_data_path: str = None, entity2cid_path: str = None):
        self.drug_data_path = drug_data_path or f"{DATA_DIR}/drug_data/drugs_all_03_04_21.pkl"
        self.entity2cid_path = entity2cid_path or f"{DATA_DIR}/kg_data/kg-entity2cid-31_7_21.pkl"

        self.drug_data: Optional[pd.DataFrame] = None
        self.entity2cid = None
        self.kgid_to_entity: Dict[Any, Any] = {}

        # lowercase name -> (drugbank_id, canonical_name)
        self.name_to_drugbank: Dict[str, Tuple[Optional[str], str]] = {}
        # All known drug names (lowercase)
        self.all_drug_names: set[str] = set()

        self._loaded = False

    def load(self) -> None:
        if self._loaded:
            return

        print("[INFO] Loading drug lookup data...")

        with open(self.drug_data_path, "rb") as f:
            self.drug_data = pickle.load(f)

        try:
            with open(self.entity2cid_path, "rb") as f:
                self.entity2cid = pickle.load(f)
            self.kgid_to_entity = {v: k for k, v in self.entity2cid.items()}
        except FileNotFoundError:
            print(
                f"[WARNING] KG entity2cid mapping not found at {self.entity2cid_path}. "
                "KG ID resolution may be limited."
            )
            self.entity2cid = {}
            self.kgid_to_entity = {}

        self._build_name_index()

        self._loaded = True
        print(f"[INFO] Loaded {len(self.name_to_drugbank)} drug name variants")

    def _build_name_index(self) -> None:
        # CASE 1: DataFrame
        if hasattr(self.drug_data, "iterrows"):
            df: pd.DataFrame = self.drug_data
            columns = list(df.columns)

            # Find DrugBank ID column
            id_col = None
            for col in [
                "drugbank-id", "drugbank_id", "id", "DrugBank ID",
                "drugbank", "drugbank-primary-id", "drugbank-id-primary",
                "primary_id"
            ]:
                if col in columns:
                    id_col = col
                    break

            if id_col is None:
                for col in columns:
                    series = df[col]
                    try:
                        non_na = series.dropna().astype(str).head(50)
                    except Exception:
                        continue
                    if any(re.match(r"^DB\d+", v) for v in non_na):
                        id_col = col
                        print(f"[INFO] Heuristically using '{id_col}' as DrugBank ID column")
                        break

            if id_col is None:
                print("[WARNING] Could not find DrugBank ID column in drug dataframe.")
                print(f"[WARNING] Available columns (first 15): {columns[:15]}")
                print("[WARNING] Will build lookup index without DrugBank IDs; KG ID resolution may be limited.")

            # Find name column
            name_col = None
            for col in ["name", "primary_name", "Name"]:
                if col in columns:
                    name_col = col
                    break
            if name_col is None:
                for col in columns:
                    if "name" in col.lower():
                        name_col = col
                        break

            if name_col is None:
                print("[ERROR] Could not find a drug name column in the drug dataframe.")
                print(f"[ERROR] Available columns (first 15): {columns[:15]}")
                return

            # Synonyms column
            syn_col = None
            for cand in ["synonyms", "Synonyms", "synonym"]:
                if cand in columns:
                    syn_col = cand
                    break

            for _, row in df.iterrows():

                drugbank_id: Optional[str] = None
                if id_col is not None:
                    val = row[id_col]
                    if pd.notna(val):
                        drugbank_id = str(val).strip()

                primary_name = row[name_col]
                if isinstance(primary_name, str) and primary_name.strip():
                    primary_name = primary_name.strip()
                    name_lower = primary_name.lower()
                    self.name_to_drugbank[name_lower] = (drugbank_id, primary_name)
                    self.all_drug_names.add(name_lower)

                if syn_col is not None:
                    synonyms_raw = row[syn_col]
                    if synonyms_raw is None:
                        continue
                    if isinstance(synonyms_raw, float) and pd.isna(synonyms_raw):
                        continue

                    syn_names: List[str] = []

                    if isinstance(synonyms_raw, dict):
                        syn_data = synonyms_raw.get("synonym", [])
                        if isinstance(syn_data, (list, tuple, set, np.ndarray)):
                            syn_names = [s for s in syn_data if isinstance(s, str)]
                        elif isinstance(syn_data, str):
                            syn_names = [syn_data]

                    elif isinstance(synonyms_raw, (list, tuple, set, np.ndarray)):
                        syn_names = [s for s in synonyms_raw if isinstance(s, str)]

                    elif isinstance(synonyms_raw, str):
                        syn_names = [synonyms_raw]

                    for syn in syn_names:
                        if syn and syn.strip():
                            syn_lower = syn.lower().strip()
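                            # Fall back to the synonym itself when the row has no usable primary name.
                            canonical = (
                                primary_name
                                if isinstance(primary_name, str) and primary_name.strip()
                                else syn.strip()
                            )
                            self.name_to_drugbank[syn_lower] = (drugbank_id, canonical)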
                            self.all_drug_names.add(syn_lower)

        # CASE 2: dict-style structure
        elif isinstance(self.drug_data, dict):
            for drugbank_id, info in self.drug_data.items():
                if isinstance(info, dict):
                    names = []
                    primary_name = info.get("primary_name") or info.get("name")
                    if primary_name:
                        names.append(primary_name)

                    for key in [
                        "synonyms", "international_brand_names",
                        "product_names", "other_names"
                    ]:
                        vals = info.get(key, [])
                        if isinstance(vals, list):
                            names.extend(vals)

                    canonical = primary_name or drugbank_id
                    for name in names:
                        if name and isinstance(name, str):
                            name_lower = name.lower().strip()
                            self.name_to_drugbank[name_lower] = (drugbank_id, canonical)
                            self.all_drug_names.add(name_lower)

    def _normalize_name(self, name: str) -> str:
        name = name.lower().strip()
        name = re.sub(r"\s*\(.*?\)\s*", " ", name)  # Remove parentheticals
        name = re.sub(r"\s+", " ", name)
        return name.strip()

    def _similarity_score(self, name1: str, name2: str) -> float:
        return SequenceMatcher(None, name1.lower(), name2.lower()).ratio()

    def lookup(self, drug_name: str, fuzzy_threshold: float = 0.85) -> Dict[str, Any]:
        self.load()

        result: Dict[str, Any] = {
            "found": False,
            "input_name": drug_name,
            "matched_name": None,
            "canonical_name": None,
            "drugbank_id": None,
            "kg_id": None,
            "match_type": None,
            "similarity": None,
            "suggestions": [],
        }

        normalized = self._normalize_name(drug_name)

        # Exact
        if normalized in self.name_to_drugbank:
            drugbank_id, canonical = self.name_to_drugbank[normalized]
            result.update({
                "found": True,
                "matched_name": normalized,
                "canonical_name": canonical,
                "drugbank_id": drugbank_id,
                "match_type": "exact",
                "similarity": 1.0,
            })
        else:
            # Fuzzy
            best_match = None
            best_score = 0.0
            for known_name in self.all_drug_names:
                score = self._similarity_score(normalized, known_name)
                if score > best_score:
                    best_score = score
                    best_match = known_name

            if best_match is not None and best_score >= fuzzy_threshold:
                drugbank_id, canonical = self.name_to_drugbank[best_match]
                result.update({
                    "found": True,
                    "matched_name": best_match,
                    "canonical_name": canonical,
                    "drugbank_id": drugbank_id,
                    "match_type": "fuzzy",
                    "similarity": best_score,
                })
            else:
                # Suggestions
                suggestions = []
                for known_name in self.all_drug_names:
                    score = self._similarity_score(normalized, known_name)
                    if score >= 0.5:
                        suggestions.append((known_name, score))
                suggestions.sort(key=lambda x: x[1], reverse=True)
                result["suggestions"] = [s[0] for s in suggestions[:5]]

        # Try KG ID resolution
        if result["drugbank_id"] and self.entity2cid:
            target = result["drugbank_id"]
            for entity_key, kg_id in self.entity2cid.items():
                try:
                    if hasattr(entity_key, "uid") and target in str(entity_key.uid):
                        result["kg_id"] = kg_id
                        break
                    if isinstance(entity_key, str) and target in entity_key:
                        result["kg_id"] = kg_id
                        break
                except Exception:
                    continue

        return result

    def check_drugs(self, drug_names: List[str], verbose: bool = True) -> List[Dict[str, Any]]:
        self.load()
        results: List[Dict[str, Any]] = []

        if verbose:
            print("\n" + "=" * 70)
            print("DRUG KNOWLEDGE GRAPH LOOKUP")
            print("=" * 70)

        for drug_name in drug_names:
            result = self.lookup(drug_name)
            results.append(result)
            if verbose:
                self._print_lookup_result(result)

        if verbose:
            print("=" * 70)
            found_count = sum(1 for r in results if r["found"])
            print(f"SUMMARY: {found_count}/{len(results)} drugs found in KG/drug table")
            print("=" * 70 + "\n")

        return results

    def _print_lookup_result(self, result: Dict[str, Any]) -> None:
        print(f"\n  Input: '{result['input_name']}'")
        if result["found"]:
            status = "✅ FOUND"
            if result["match_type"] == "fuzzy":
                status += f" (fuzzy match, {result['similarity']:.0%} similar)"
            print(f"  Status: {status}")
            print(f"  Canonical Name: {result['canonical_name']}")
            print(f"  DrugBank ID: {result['drugbank_id']}")
            if result["kg_id"]:
                print(f"  KG ID: {result['kg_id']}")
            else:
                print("  KG ID: (will be resolved during processing if possible)")
        else:
            print("  Status: ❌ NOT FOUND")
            if result["suggestions"]:
                print(f"  Did you mean: {', '.join(result['suggestions'][:3])}?")

    def get_canonical_name(self, drug_name: str) -> Optional[str]:
        result = self.lookup(drug_name)
        return result["canonical_name"] if result["found"] else None

    def get_drugbank_id(self, drug_name: str) -> Optional[str]:
        result = self.lookup(drug_name)
        return result["drugbank_id"] if result["found"] else None


_drug_lookup: Optional[DrugKGLookup] = None

def get_drug_lookup() -> DrugKGLookup:
    global _drug_lookup
    if _drug_lookup is None:
        _drug_lookup = DrugKGLookup()
    return _drug_lookup

def check_drug_in_kg(drug_name: str) -> Dict[str, Any]:
    return get_drug_lookup().lookup(drug_name)

def check_drugs_in_kg(drug_names: List[str], verbose: bool = True) -> List[Dict[str, Any]]:
    return get_drug_lookup().check_drugs(drug_names, verbose=verbose)


# -------------------------
# Data loading
# -------------------------

def get_clinical_trial_data(nctid: str) -> Dict[str, Any]:
    """
    Try local JSON first, else ClinicalTrials.gov API v2.
    """
    candidates = []
    if nctid.lower().endswith(".json"):
        candidates.append(nctid if os.path.isabs(nctid) else os.path.join(os.getcwd(), nctid))
    else:
        candidates.append(os.path.join("/app/studies", f"{nctid}.json"))
        candidates.append(os.path.join(os.getcwd(), "LC_Clinical_Trials", f"{nctid}.json"))
        candidates.append(os.path.join(os.getcwd(), f"{nctid}.json"))

    for path in candidates:
        if os.path.isfile(path):
            try:
                with open(path, "r") as f:
                    print(f"[INFO] Loading local JSON for {nctid} from {path}")
                    return json.load(f)
            except Exception as e:
                return {"error": f"Failed to read local JSON {path}: {e}"}

    base_url = "https://clinicaltrials.gov/api/v2/studies"
    request_url = f"{base_url}/{nctid}"
    try:
        print(f"[INFO] Fetching {nctid} from ClinicalTrials.gov API: {request_url}")
        response = requests.get(request_url, timeout=30)
        if response.status_code == 200:
            return response.json()
        return {"error": f"Failed to fetch data. Status: {response.status_code}. Msg: {response.text}"}
    except Exception as e:
        return {"error": str(e)}


# -------------------------
# Parsing / enrichment
# -------------------------

def parse(nctid: str) -> Dict[str, Any]:
    trial_data = get_clinical_trial_data(nctid)
    if "error" in trial_data:
        raise RuntimeError(f"Error fetching data: {trial_data['error']}")

    attributes = {
        "nct_id": "protocolSection.identificationModule.nctId",
        "arm_group": "protocolSection.armsInterventionsModule.armGroups",
        "intervention": "protocolSection.armsInterventionsModule.interventions",
        "condition": "protocolSection.conditionsModule.conditions",
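        # NOTE: despite the key name, this path reads the *condition* MeSH terms
        # (conditionBrowseModule). The original path is kept so that downstream
        # consumers keep receiving the data they expect.
        "intervention_mesh_terms": "derivedSection.conditionBrowseModule.meshes",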
        "event_groups": "resultsSection.adverseEventsModule.eventGroups",
        "primary_outcome": "protocolSection.outcomesModule.primaryOutcomes",
        "secondary_outcome": "protocolSection.outcomesModule.secondaryOutcomes",
        "eligibility_criteria": "protocolSection.eligibilityModule.eligibilityCriteria",
        "brief_summary": "protocolSection.descriptionModule.briefSummary",
        "phase": "protocolSection.designModule.phases",
        "enrollment": "protocolSection.designModule.enrollmentInfo",
        "gender_sex": "protocolSection.eligibilityModule.sex",
        "minimum_age": "protocolSection.eligibilityModule.minimumAge",
        "maximum_age": "protocolSection.eligibilityModule.maximumAge",
    }

    parsed_trial: Dict[str, Any] = {}
    for attribute, path in attributes.items():
        val: Any = trial_data
        for component in path.split("."):
            if not isinstance(val, dict) or component not in val:
                val = None
                break
            val = val[component]
        parsed_trial[attribute] = val

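    # The loop above always sets every key (possibly to None), so setdefault()
    # would be a no-op here. Coerce missing values to empty lists explicitly.
    parsed_trial["arm_group"] = parsed_trial.get("arm_group") or []
    parsed_trial["intervention"] = parsed_trial.get("intervention") or []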

    for arm_group in parsed_trial["arm_group"]:
        if "label" in arm_group:
            arm_group["arm_group_label"] = arm_group.pop("label")

    for intervention in parsed_trial["intervention"]:
        if "type" in intervention:
            intervention["intervention_type"] = intervention.pop("type").title()
        if "name" in intervention:
            intervention["intervention_name"] = intervention.pop("name")
        if "otherNames" in intervention:
            intervention["other_name"] = intervention.pop("otherNames")
        if "armGroupLabels" in intervention:
            intervention["arm_group_label"] = intervention.pop("armGroupLabels")

    # Always drop the temporary "event_groups" key, even when it is None.
    event_groups = parsed_trial.pop("event_groups", None)
    parsed_trial["clinical_results"] = (
        {"reported_events": {"group_list": {"group": event_groups}}}
        if event_groups is not None
        else {}
    )

    return parsed_trial


def run_medex_and_parse_output(parsed_trial: Dict[str, Any]) -> Dict[str, Any]:
    result: Dict[str, Any] = {}
    classpath = (
        "resources/medex/Medex_UIMA_1.3.8/bin:"
        "resources/medex/Medex_UIMA_1.3.8/lib/*"
    )
    args_template = (
        "java -Xmx1024m -cp {0} org.apache.medex.Main "
        "-i {1} -o {2} -b n -f y -d y -t n"
    )

    with tempfile.TemporaryDirectory() as basedir:
        medex_input._generate_medex_inputs(parsed_trial, result)
        input_dir = os.path.join(basedir, "inputs")
        os.makedirs(input_dir, exist_ok=True)
        with open(os.path.join(input_dir, "medex_input.json"), "w") as f:
            json.dump(result, f)

        output_path = os.path.join(basedir, "outputs")
        os.makedirs(os.path.join(output_path, "data"), exist_ok=True)

        args = args_template.format(
            classpath,
            os.path.join(input_dir, "medex_input.json"),
            os.path.join(output_path, "data"),
        )

        print(args)
        try:
            subprocess.run(shlex.split(args), check=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Medex execution failed with error: {e}") from e

        medex_output_parser = medex.MedexOutputParser(base_paths=[output_path])
        medex_output_parser.fill_medex_info(parsed_trial)

    return parsed_trial


def parse_eligiility_criteria(parsed_trial: Dict[str, Any]) -> Dict[str, Any]:
    args_template = (
        "java -Xmx8192m -jar resources/criteria2query.jar  --input {0} --outputDir {1}"
    )

    with tempfile.TemporaryDirectory() as basedir:
        input_dir = os.path.join(basedir, "inputs")
        os.makedirs(input_dir, exist_ok=True)
        with open(os.path.join(input_dir, "crit_input.txt"), "w") as f:
            f.write(parsed_trial.get("eligibility_criteria", "") or "")

        output_path = os.path.join(basedir, "outputs")
        os.makedirs(os.path.join(output_path, "data"), exist_ok=True)

        args = args_template.format(
            os.path.join(input_dir, "crit_input.txt"),
            os.path.join(output_path, "data"),
        )

        print(args)
        try:
            subprocess.run(shlex.split(args), check=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Crit2Query execution failed with error: {e}") from e

        parsed_trial["ec_umls"] = CriteriaOutputParser.parse_crit_output_from_file(
            os.path.join(output_path, "data", "output.json")
        )
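    return parsed_trial


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
parse_eligibility_criteria = parse_eligiility_criteria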


def extract_outcomes(parsed_trial: Dict[str, Any]) -> Dict[str, Any]:
    outcome_extractor = OutcomeMeasureExtract(
        f"{DATA_DIR}/outcome_data/clusters-outcome-measures.txt"
    )
    outcome_extractor.load_phrase_models(f"{DATA_DIR}/outcome_data")
    outcome_extractor.populate_cids(parsed_trial)
    return parsed_trial


def population_extraction(umls_utils: UMLSUtils, parsed_trial: Dict[str, Any]) -> Dict[str, Any]:
    umls_concept_searcher = UMLSConceptSearcher(
        api_key="",
        version="2020AB",
        cache_dir=f"{DATA_DIR}/population_data/umls_search_cache",
    )
    umls_concept_searcher.set_umls_search(False)

    criteria_all = parsed_trial["ec_umls"]
    for category in criteria_all:
        for inclusion in criteria_all[category]:
            for criterion in criteria_all[category][inclusion]:
                criterion.map_concept(umls_concept_searcher)

    umls_utils.cuid2parents = {}
    for category in criteria_all:
        for inclusion in criteria_all[category]:
            for criterion in criteria_all[category][inclusion]:
                if criterion.concept is not None:
                    criterion.parents = umls_utils.parents(criterion.concept["ui"])

    tfidf_matcher = UMLSTFIDFMatcher(
        umls_utils.cuid2concept, f"{DATA_DIR}/population_data", None
    )
    tfidf_matcher.populate_result_single(parsed_trial["ec_umls"])
    return parsed_trial


def _phase_feature_vec(phases: List[str]) -> List[int]:
    v = [0] * 5
    for phase in phases or []:
        if phase in ["EARLY_PHASE1", "PHASE1"]:
            v[1] = 1
        elif phase in ["NA", "N/A"]:  # API v2 reports "NA"; the legacy spelling is also accepted
            v[0] = 1
        elif phase == "PHASE2":
            v[2] = 1
        elif phase == "PHASE3":
            v[3] = 1
        elif phase == "PHASE4":
            v[4] = 1
        else:
            raise RuntimeError(f"Unknown phase: {phase}")
    return v


def _enrollment_feat(enrollment: Any) -> List[float]:
    """Encode enrollment as [log(1 + count), is_anticipated]."""
    if isinstance(enrollment, dict):
        # "count" may be present but null, so fall back to 0 in that case too.
        count = enrollment.get("count") or 0
        is_anticipated = enrollment.get("type") == "ANTICIPATED"
        return [math.log(1 + count), float(is_anticipated)]
    if isinstance(enrollment, float) and np.isnan(enrollment):
        return [0.0, 0.0]
    return [math.log(1 + (enrollment or 0)), 0.0]


def _sex_vec(sex: Any) -> List[int]:
    if sex is None or isinstance(sex, float):
        return [0, 0, 0]
    sex_to_feats = {
        "ALL": [1, 0, 0],
        "MALE": [0, 1, 0],
        "FEMALE": [0, 0, 1],
    }
    return sex_to_feats.get(sex, [0, 0, 0])


def extract_trial_features(extractor: TrialAttributeFeatures, trial_row: Dict[str, Any]) -> np.ndarray:
    data: Dict[str, Any] = {}
    data["phase_vec"] = _phase_feature_vec(trial_row.get("phase") or [])
    data["enrollment_vec"] = _enrollment_feat(trial_row.get("enrollment"))
    data["gender_sex_vec"] = _sex_vec(trial_row.get("gender_sex"))
    data["minimum_age_vec"] = extractor._age_vec(trial_row.get("minimum_age") or 0.0)
    data["maximum_age_vec"] = extractor._age_vec(trial_row.get("maximum_age") or 0.0)

    def merge_vecs(row: Dict[str, Any]) -> np.ndarray:
        feats: List[float] = []
        for attribute in extractor.attributes:
            if attribute == "phase":
                feats.extend(row["phase_vec"])
            elif attribute == "enrollment":
                feats.extend(row["enrollment_vec"])
            elif attribute == "gender":
                feats.extend(row["gender_sex_vec"])
            elif attribute == "age":
                feats.extend(row["minimum_age_vec"])
                feats.extend(row["maximum_age_vec"])
            elif attribute == "age_class":
                feats.extend(row.get("age_vec_2", []))
            else:
                raise RuntimeError(f"Unknown attributes ({attribute}) for features")
        return np.array(feats)

    return merge_vecs(data)


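def get_arm_text(row: Dict[str, Any]) -> Tuple[Dict[Tuple[str, int], str], Dict[str, List[str]]]:
    """
    Assemble the free-text inputs for every arm of a trial.

    Returns:
        arm2text: maps (nct_id, arm index) to one string joining the intervention
            name, conditions, primary outcome measures, arm label and description,
            brief summary, intervention description and eligibility criteria.
        nct2text: maps nct_id to [conditions, outcomes, summary, criteria].
    """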
    arm2text: Dict[Tuple[str, int], str] = {}
    nct2text: Dict[str, List[str]] = {}

    summary = row.get("brief_summary", "") or ""

    disease_text = ""
    for disease in row.get("condition", []) or []:
        disease_text += f"{disease} "

    outcome_text = ""
    if not isinstance(row.get("primary_outcome"), float):
        for pom in row.get("primary_outcome", []) or []:
            outcome_text += (pom.get("measure", "") or "") + " "

    criteria = row.get("eligibility_criteria")
    if isinstance(criteria, float) or criteria is None:
        criteria = ""

    arm2intervention: Dict[str, Tuple[str, str]] = {}
    for intervention in row.get("intervention", []) or []:
        intervention_text = (intervention.get("intervention_name", "") or "") + " "
        intervention_desc = (intervention.get("description", "") or "") + " "
        arm_group_label = intervention.get("arm_group_label", ["default"])
        if not isinstance(arm_group_label, list):
            arm_group_label = [arm_group_label]
        for arm_label in arm_group_label:
            arm_label = (arm_label or "default").lower()
            # Note: a later intervention assigned to the same arm overwrites an earlier one.
            arm2intervention[arm_label] = (intervention_text, intervention_desc)

    arms = row.get("arm_group", [])
    if not isinstance(arms, list) or not arms:
        arms = [{"arm_group_label": "default", "arm_group_type": ""}]

    for idx, arm in enumerate(arms):
        label = arm.get("arm_group_label", "default") or "default"
        arm_text_val = label + " " + (arm.get("description", "") or "")

        if label.lower() in arm2intervention:
            intervention_text, intervention_desc = arm2intervention[label.lower()]
        else:
            intervention_text, intervention_desc = "", ""

        all_text = " ".join(
            [intervention_text, disease_text, outcome_text, arm_text_val, summary, intervention_desc, criteria]
        )

        arm2text[(row["nct_id"], idx)] = all_text
        nct2text[row["nct_id"]] = [disease_text, outcome_text, summary, criteria]

    return arm2text, nct2text


def load_cuid2term() -> Dict[str, Any]:
    basedir = f"{DATA_DIR}/population_data"
    filepath = os.path.join(basedir, "umls_graph_clipper_output.pkl")
    with open(filepath, "rb") as f:
        g_clipper_state = pickle.load(f)
        cuid2term = g_clipper_state["cuid2term"]
    return cuid2term


def build_trial_arms(
    disease_matcher: DiseaseExtract,
    drug_matcher: DrugMatcher,
    umls_utils: UMLSUtils,
    cuid2term: Dict[str, Any],
    parsed_trial: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """
    Build the trial-arm edge representation with robust support for
    multiple PlaNet TrialGraphBuilder API variants.
    """
    entity2cid_path = f"{DATA_DIR}/kg_data/kg-entity2cid-31_7_21.pkl"
    with open(entity2cid_path, "rb") as f:
        entity2cid = pickle.load(f)

    ext_basepath = f"{DATA_DIR}/kg_data/external_data"
    builder = KnowledgeGraphBuilder(
        disease_matcher.mesh_dis_data,
        drug_matcher.drug_data,
        ext_basepath,
        cuid2term,
        umls_utils,
        umls_graph_clip_threshold=10,
        build_ae=False,
    )

    builder.build_external_networks()
    builder._mesh_children()

    # Counterfactual-safe default
    parsed_trial["has_results"] = False

    trial_builder = TrialGraphBuilder(builder, parsed_trial)
    trial_builder.build(use_population=True)

    uf = UnionFind()
    for u, v, data in builder.biokg.graph.edges(data=True):
        if data.get("relation") == "KG-MERGE-SAME":
            uf.union(u, v)

    trial_attribute_featurizer = TrialAttributeFeatures(
        attributes=("age", "gender", "enrollment", "phase")
    )
    trial_attribute_feats = extract_trial_features(trial_attribute_featurizer, parsed_trial)

    arm2text, _ = get_arm_text(parsed_trial)

    # ----------------------------
    # Robust arm label resolution
    # ----------------------------
    arm_items: Optional[List[Tuple[Any, int]]] = None

    # Candidate dict attributes across PlaNet versions
    for cand in ["arm_labels", "arm_label2idx", "arm2idx", "arm_groups"]:
        maybe = getattr(trial_builder, cand, None)
        if isinstance(maybe, dict) and maybe:
            arm_items = list(maybe.items())
            break
        if isinstance(maybe, list) and maybe:
            # If it's a list of labels, map index
            arm_items = [(lbl, i) for i, lbl in enumerate(maybe)]
            break

    # Fallback: derive from parsed_trial arm_groups order
    if not arm_items:
        arms = parsed_trial.get("arm_group") or []
        if isinstance(arms, list) and arms:
            arm_items = []
            for idx, arm in enumerate(arms):
                lbl = (
                    arm.get("arm_group_label")
                    or arm.get("label")
                    or arm.get("name")
                    or f"arm_{idx + 1}"
                )
                arm_items.append((lbl, idx))

    # Ultimate fallback: single default arm
    if not arm_items:
        arm_items = [("default", 0)]

    arm_key_fn = getattr(trial_builder, "arm_key", None)
    if arm_key_fn is None:
        raise RuntimeError(
            "TrialGraphBuilder has no 'arm_key' method in this PlaNet version. "
            "Please check your PlaNet build_graph.py API."
        )

    trial_data: List[Dict[str, Any]] = []

    for arm_label, arm_idx in arm_items:
        trial_arm_data: List[Dict[str, Any]] = []
        try:
            edge_iter = builder.biokg.graph.edges(
                nbunch=[arm_key_fn(arm_idx)],
                data=True,
                keys=True,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to access graph edges for arm_idx={arm_idx}: {e}") from e

        for u, v, k, data in edge_iter:
            # Resolve KG ID robustly
            parent = uf.find_parent(v)
            kg_id = entity2cid.get(parent)
            if kg_id is None:
                # Best-effort: try direct lookup
                kg_id = entity2cid.get(v)

            trial_arm_data.append(
                {
                    "kg_id": kg_id,
                    "relation": data.get("relation"),
                    "key": k,
                    "data": data,
                }
            )

        arm_text_val = arm2text.get((parsed_trial["nct_id"], arm_idx), "")

        trial_data.append(
            {
                "nct_id": parsed_trial["nct_id"],
                "arm_label": arm_label,
                "arm_idx": arm_idx,
                "trial_arm_edges": trial_arm_data,
                "arm_text": arm_text_val,
                "trial_attribute_feats_vec": trial_attribute_feats,
            }
        )

    return trial_data
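

# Illustrative sketch (not part of the pipeline): build_trial_arms() relies on
# a `UnionFind` exposing `union(u, v)` and `find_parent(v)` to collapse nodes
# linked by "KG-MERGE-SAME" edges onto one representative before the
# entity2cid lookup. The class below is a hypothetical stand-in showing the
# behaviour that call pattern assumes; PlaNet's own implementation may differ.
class _SketchUnionFind:
    """Minimal disjoint-set with path compression (assumed interface)."""

    def __init__(self) -> None:
        self._parent = {}

    def find_parent(self, node):
        # Unseen nodes are their own representative.
        self._parent.setdefault(node, node)
        root = node
        while self._parent[root] != root:
            root = self._parent[root]
        # Path compression: point every node on the walked path at the root.
        while self._parent[node] != root:
            self._parent[node], node = root, self._parent[node]
        return root

    def union(self, u, v) -> None:
        root_u, root_v = self.find_parent(u), self.find_parent(v)
        if root_u != root_v:
            self._parent[root_v] = root_u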


# =========================================================================
# DRUG REPLACEMENT / COUNTERFACTUAL FUNCTIONALITY
# =========================================================================

def find_experimental_drug(parsed_trial: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    More permissive experimental intervention detection.

    Pass 1: Prefer Drug/Biological/Dietary Supplement types.
    Pass 2: Any non-placebo intervention with a name.
    """
    interventions = parsed_trial.get("intervention", []) or []

    drugish_types = {"drug", "biological", "dietary supplement"}

    for intervention in interventions:
        name = (intervention.get("intervention_name", "") or "").lower()
        itype = (intervention.get("intervention_type", "") or "").lower()
        if name and "placebo" not in name and itype in drugish_types:
            return intervention

    for intervention in interventions:
        name = (intervention.get("intervention_name", "") or "").lower()
        if name and "placebo" not in name:
            return intervention

    return None


def find_experimental_arm(parsed_trial: Dict[str, Any]) -> Optional[Dict[str, Any]]:
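    """
    Locate a non-placebo arm as a fallback anchor.

    Pass 1: Prefer arms typed as experimental or active comparator.
    Pass 2: Any arm whose label and description do not mention placebo.
    """
    arms = parsed_trial.get("arm_group", []) or []

    for arm in arms:
        arm_type = (arm.get("type", "") or arm.get("arm_group_type", "") or "").upper()
        # Normalise the space-separated spelling ("Active Comparator") to "ACTIVE_COMPARATOR".
        arm_type = arm_type.strip().replace(" ", "_")
        if arm_type in {"EXPERIMENTAL", "ACTIVE_COMPARATOR", "ACTIVE"}: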
            return arm

    for arm in arms:
        label = (arm.get("arm_group_label", "") or "").lower()
        desc = (arm.get("description", "") or "").lower()
        if "placebo" not in label and "placebo" not in desc:
            return arm

    return None


def create_synthetic_medex_output(drug_name: str, drugbank_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Minimal medex_out stub to help DrugMatcher in CF mode without full text parsing.
    Route left generic.
    """
    return {
        "drug_name": drug_name,
        "generic_name": drug_name.lower(),
        "brand_name": "",
        "drugbank_id": drugbank_id,
        "rxnorm_id": None,
        "umls_cui": None,
        "strength": "",
        "dose_form": "",
        "route": "",
        "frequency": "",
        "duration": "",
        "necessity": "",
    }


def sanitize_counterfactual_parsed_trial(
    modified: Dict[str, Any],
    original_nct: str,
    original_drug_name: str,
    canonical_name: str,
    keep_original_summary: bool = False,
) -> Dict[str, Any]:
    """
    Remove/neutralize original-drug + real-results leakage in *parsed* CF trials.
    """
    modified["clinical_results"] = {}
    modified["has_results"] = False

    if keep_original_summary:
        orig = (modified.get("brief_summary") or "").strip()
        modified["brief_summary"] = (
            f"[COUNTERFACTUAL ANALYSIS] Synthetic trial based on {original_nct}. "
            f"Experimental intervention replaced with {canonical_name}. "
            f"{orig}"
        ).strip()
    else:
        modified["brief_summary"] = (
            f"[COUNTERFACTUAL ANALYSIS] Synthetic counterfactual trial reusing the "
            f"original population and outcomes of {original_nct}, with the experimental "
            f"intervention replaced by {canonical_name}. No real results are included."
        )

    # Drop legacy result-bearing keys if present
    for k in ["event_groups", "resultsSection", "hasResults"]:
        modified.pop(k, None)

    # Replace remaining mentions of the original drug in arm labels and descriptions
    for arm in modified.get("arm_group", []) or []:
        for field in ["arm_group_label", "description"]:
            val = arm.get(field)
            if isinstance(val, str) and original_drug_name.lower() in val.lower():
                arm[field] = re.sub(
                    re.escape(original_drug_name),
                    # Callable replacement: the new name is inserted literally,
                    # so backslashes in it are not read as regex escapes.
                    lambda _m: canonical_name,
                    val,
                    flags=re.IGNORECASE,
                )

    return modified
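

# Illustrative sketch (hypothetical helper, not called by the pipeline): the
# sanitizer above swaps drug names with a case-insensitive `re.sub`. The same
# idea in isolation, with the pattern escaped and a callable replacement so
# that both the old and the new name are treated as literal text.
def _sketch_replace_name(text: str, old_name: str, new_name: str) -> str:
    import re  # local import keeps the sketch self-contained

    if not old_name:
        # An empty pattern would match between every character.
        return text
    return re.sub(re.escape(old_name), lambda _m: new_name, text, flags=re.IGNORECASE)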


def replace_drug_in_parsed_trial(
    parsed_trial: Dict,
    new_drug_name: str,
    new_drug_description: Optional[str] = None,
    new_drug_other_names: Optional[List[str]] = None,
    drugbank_id: Optional[str] = None,
    validate_in_kg: bool = True,
    keep_original_summary: bool = False,
) -> Dict:
    """
    Replace the experimental intervention in a parsed trial while keeping the cohort constant.
    Adds counterfactual leakage controls by default.

    Robustness:
    - If no clear experimental drug exists in the base trial, attempt to infer
      a non-placebo arm and create a minimal synthetic base intervention as an anchor.
    """
    modified = deepcopy(parsed_trial)

    original_drug = find_experimental_drug(parsed_trial)

    # ---- Synthetic anchor fallback ----
    if original_drug is None:
        inferred_arm = find_experimental_arm(parsed_trial)
        if inferred_arm:
            inferred_label = (
                inferred_arm.get("arm_group_label")
                or inferred_arm.get("label")
                or "Experimental Arm"
            )

            modified.setdefault("intervention", []).append({
                "intervention_type": "Drug",
                "intervention_name": inferred_label,
                "description": "",
                "other_name": [],
                "arm_group_label": [inferred_label],
            })

            original_drug = modified["intervention"][-1]
            print(
                f"[WARNING] No explicit experimental drug found. "
                f"Using inferred arm '{inferred_label}' as base intervention for replacement."
            )
        else:
            raise ValueError(
                "No experimental drug found in trial to replace, and no non-placebo arm "
                "could be inferred. This base trial may not be suitable for drug counterfactuals."
            )

    original_drug_name = original_drug.get("intervention_name", "Unknown")

    canonical_name = new_drug_name
    resolved_drugbank_id = drugbank_id

    if validate_in_kg:
        lookup_result = check_drug_in_kg(new_drug_name)

        if lookup_result["found"]:
            canonical_name = lookup_result["canonical_name"]
            resolved_drugbank_id = lookup_result["drugbank_id"]

            if lookup_result["match_type"] == "fuzzy":
                print(
                    f"[INFO] Drug name '{new_drug_name}' matched to '{canonical_name}' "
                    f"({lookup_result['similarity']:.0%} similar)"
                )
            else:
                print(f"[INFO] Drug '{new_drug_name}' found in KG as '{canonical_name}'")
        else:
            print(f"[WARNING] Drug '{new_drug_name}' NOT FOUND in KG!")
            if lookup_result.get("suggestions"):
                print(f"[WARNING] Did you mean: {', '.join(lookup_result['suggestions'][:3])}?")
            print("[WARNING] Proceeding anyway, but drug may not have KG relationships.")

    print(f"[INFO] Replacing intervention: '{original_drug_name}' -> '{canonical_name}'")

    original_nct = parsed_trial.get("nct_id", "UNKNOWN")

    modified["_counterfactual"] = {
        "original_drug": original_drug_name,
        "replacement_drug": canonical_name,
        "input_drug_name": new_drug_name,
        "drugbank_id": resolved_drugbank_id,
        "original_nct_id": original_nct,
    }

    safe_drug_suffix = "".join(c if c.isalnum() else "_" for c in canonical_name)[:20]
    modified["nct_id"] = f"{original_nct}_CF_{safe_drug_suffix}"

    default_desc = (
        f"{canonical_name} administered as the experimental intervention in a "
        f"synthetic counterfactual trial using the original {original_nct} "
        f"population and outcomes."
    )

    # ------------------------------------------------------------------
    # 1) Replace the intervention (and keep its arm_group_label in sync)
    # ------------------------------------------------------------------
    replaced = False
    for i, intervention in enumerate(modified.get("intervention", []) or []):
        name = (intervention.get("intervention_name", "") or "").lower()
        itype = (intervention.get("intervention_type", "") or "").lower()

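        # Replace the first non-placebo intervention that is drug-like or untyped.
        # This can differ from `original_drug` when that one was found by the
        # permissive second pass of find_experimental_drug().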
        if "placebo" not in name and itype in {"drug", "biological", "dietary supplement", ""}:
            arm_labels = intervention.get("arm_group_label", [])
            if not isinstance(arm_labels, list):
                arm_labels = [arm_labels]

            new_arm_labels = []
            for lbl in arm_labels:
                if (
                    isinstance(lbl, str)
                    and original_drug_name.lower() in lbl.lower()
                    and "placebo" not in lbl.lower()
                ):
                    new_lbl = re.sub(
                        re.escape(original_drug_name),
                        lambda _m: canonical_name,  # inserted literally, no backslash escapes
                        lbl,
                        flags=re.IGNORECASE,
                    )
                    new_arm_labels.append(new_lbl)
                else:
                    new_arm_labels.append(lbl)

            if not new_arm_labels:
                # If base intervention had no labels, infer from arm
                inferred = find_experimental_arm(modified)
                if inferred:
                    inferred_label = inferred.get("arm_group_label") or "Experimental Arm"
                    new_arm_labels = [inferred_label]

            modified["intervention"][i] = {
                "intervention_type": "Drug",
                "intervention_name": canonical_name,
                "description": new_drug_description or default_desc,
                "other_name": new_drug_other_names or [],
                "arm_group_label": new_arm_labels,
                "medex_out": [
                    create_synthetic_medex_output(canonical_name, resolved_drugbank_id)
                ],
            }
            replaced = True
            break

    # If we somehow didn't replace anything, append a fresh experimental intervention
    if not replaced:
        inferred = find_experimental_arm(modified)
        # Fall back to a default label when no arm is inferred or its label is missing.
        inferred_label = (inferred.get("arm_group_label") if inferred else None) or "Experimental Arm"

        modified.setdefault("intervention", []).append({
            "intervention_type": "Drug",
            "intervention_name": canonical_name,
            "description": new_drug_description or default_desc,
            "other_name": new_drug_other_names or [],
            "arm_group_label": [inferred_label],
            "medex_out": [
                create_synthetic_medex_output(canonical_name, resolved_drugbank_id)
            ],
        })

    # ------------------------------------------------------------------
    # 2) Update arm_group labels and descriptions consistently
    # ------------------------------------------------------------------
    for arm in modified.get("arm_group", []) or []:
        if "interventionNames" in arm:
            new_intervention_names = []
            for nm in arm["interventionNames"]:
                if (
                    isinstance(nm, str)
                    and original_drug_name.lower() in nm.lower()
                    and "placebo" not in nm.lower()
                ):
                    new_intervention_names.append(f"Drug: {canonical_name}")
                else:
                    new_intervention_names.append(nm)
            arm["interventionNames"] = new_intervention_names

        arm_label_val = arm.get("arm_group_label", "")
        if isinstance(arm_label_val, str) and original_drug_name.lower() in arm_label_val.lower():
            arm["arm_group_label"] = re.sub(
                re.escape(original_drug_name),
                lambda _m: canonical_name,  # inserted literally, no backslash escapes
                arm_label_val,
                flags=re.IGNORECASE,
            )

        desc = arm.get("description", "")
        if isinstance(desc, str) and original_drug_name.lower() in desc.lower():
            arm["description"] = re.sub(
                re.escape(original_drug_name),
                lambda _m: canonical_name,  # inserted literally, no backslash escapes
                desc,
                flags=re.IGNORECASE,
            )

    # ------------------------------------------------------------------
    # 3) Sanitise parsed-level leakage
    # ------------------------------------------------------------------
    modified = sanitize_counterfactual_parsed_trial(
        modified=modified,
        original_nct=original_nct,
        original_drug_name=original_drug_name,
        canonical_name=canonical_name,
        keep_original_summary=keep_original_summary,
    )

    return modified
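

# Illustrative sketch (hypothetical helper, not called by the pipeline): how
# the synthetic trial ID above is derived. Non-alphanumeric characters in the
# drug name become "_" and the suffix is cut to 20 characters, so two long
# names sharing a 20-character prefix would map to the same ID (and therefore
# to the same output file names).
def _sketch_counterfactual_id(original_nct: str, canonical_name: str) -> str:
    safe_suffix = "".join(c if c.isalnum() else "_" for c in canonical_name)[:20]
    return f"{original_nct}_CF_{safe_suffix}"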


def create_counterfactual_trials(
    parsed_trial: Dict[str, Any],
    replacement_drugs: List[str],
    validate_in_kg: bool = True,
) -> List[Dict[str, Any]]:
    counterfactuals: List[Dict[str, Any]] = []
    for drug_name in replacement_drugs:
        try:
            modified = replace_drug_in_parsed_trial(
                parsed_trial,
                new_drug_name=drug_name,
                validate_in_kg=validate_in_kg,
            )
            counterfactuals.append(modified)
        except ValueError as e:
            print(f"[WARNING] Could not create counterfactual for {drug_name}: {e}")
    return counterfactuals


def process_counterfactual(
    modified_trial: Dict[str, Any],
    drug_matcher: DrugMatcher,
    disease_matcher: DiseaseExtract,
    umls_utils: UMLSUtils,
    cuid2term: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Process a counterfactual trial through drug matching and graph building.
    Includes extra safety cleanup to prevent results leakage.
    """
    modified_trial["clinical_results"] = {}
    modified_trial["has_results"] = False

    for intervention in modified_trial.get("intervention", []) or []:
        itype = (intervention.get("intervention_type", "") or "").lower()
        if itype in {"drug", "biological", "dietary supplement"}:
            if "medex_out" in intervention:
                print(f"[INFO] Running drug matching for: {intervention.get('intervention_name')}")
                get_intervention_drug_ids(drug_matcher, intervention, modified_trial)

    trial_data = build_trial_arms(
        disease_matcher, drug_matcher, umls_utils, cuid2term, modified_trial
    )

    return modified_trial, trial_data


# =========================================================================
# MAIN FUNCTIONS
# =========================================================================

def load_resources() -> Tuple[DrugMatcher, DiseaseExtract, UMLSUtils, Dict[str, Any]]:
    print("[INFO] Loading resources...")

    drug_matcher = DrugMatcher(
        data_paths={
            "drug_data": f"{DATA_DIR}/drug_data/drugs_all_03_04_21.pkl",
            "pubchem_synonyms": f"{DATA_DIR}/drug_data/pubchem-drugbankid-synonyms.json",
            "rxnorm2drugbank-umls": f"{DATA_DIR}/drug_data/rxnorm2drugbank-umls.pkl",
            "RXNCONSO": f"{DATA_DIR}/drug_data/RXNCONSO.RRF",
        }
    )
    disease_matcher = DiseaseExtract(data_dir=DATA_DIR, data_year=2021)
    umls_utils = UMLSUtils(f"{DATA_DIR}/population_data/umls-install/2020AB")
    umls_utils.load_relations()
    cuid2term = load_cuid2term()

    print("[INFO] Resources loaded successfully")
    return drug_matcher, disease_matcher, umls_utils, cuid2term


def parse_trial_with_resources(
    nct_id: str,
    drug_matcher: DrugMatcher,
    disease_matcher: DiseaseExtract,
    umls_utils: UMLSUtils,
    cuid2term: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    trial = parse(nct_id)
    trial = run_medex_and_parse_output(trial)
    trial = parse_eligiility_criteria(trial)
    trial["mesh_ids"] = disease_matcher.get_disease_ids(trial)

    trial = extract_outcomes(trial)
    trial = population_extraction(umls_utils, trial)

    trial_data = build_trial_arms(disease_matcher, drug_matcher, umls_utils, cuid2term, trial)
    return trial, trial_data


def main(nct_id: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    drug_matcher, disease_matcher, umls_utils, cuid2term = load_resources()
    return parse_trial_with_resources(nct_id, drug_matcher, disease_matcher, umls_utils, cuid2term)


def main_counterfactual(
    nct_id: Optional[str] = None,
    from_parsed: Optional[str] = None,
    replacement_drugs: Optional[List[str]] = None,
    output_dir: Optional[str] = None,
    skip_if_not_found: bool = False,
    keep_original_summary: bool = False,
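) -> List[Dict[str, Any]]:
    """
    Build counterfactual versions of one base trial, one per replacement drug.

    The base trial is parsed from `nct_id`, or loaded from the JSON file at
    `from_parsed` when that is given. Each counterfactual is written to
    `output_dir` as a parsed-trial JSON plus a trial-data pickle, and a summary
    JSON is saved alongside. Returns one result dict per processed drug;
    failures carry an "error" key instead of output paths.
    """
    if not replacement_drugs:
        raise ValueError("No replacement drugs specified")
    if not nct_id and not from_parsed:
        raise ValueError("Either nct_id or from_parsed must be provided")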

    output_dir = output_dir or RESULTS_ROOT
    ensure_dir(output_dir)

    print("\n" + "=" * 70)
    print("STEP 1: Validating replacement drugs in Knowledge Graph/drug table")
    print("=" * 70)

    lookup_results = check_drugs_in_kg(replacement_drugs, verbose=True)

    valid_drugs: List[Dict[str, Any]] = []
    for drug_name, result in zip(replacement_drugs, lookup_results):
        if result["found"]:
            valid_drugs.append({
                "input_name": drug_name,
                "canonical_name": result["canonical_name"],
                "drugbank_id": result["drugbank_id"],
            })
        elif not skip_if_not_found:
            valid_drugs.append({
                "input_name": drug_name,
                "canonical_name": drug_name,
                "drugbank_id": None,
            })
            print(f"[WARNING] Including '{drug_name}' even though not found in KG/drug table")

    if not valid_drugs:
        print("[ERROR] No valid drugs to process!")
        return []

    print(f"\n[INFO] Processing {len(valid_drugs)} drug(s)")

    print("\n" + "=" * 70)
    print("STEP 2: Loading resources")
    print("=" * 70)
    drug_matcher, disease_matcher, umls_utils, cuid2term = load_resources()

    print("\n" + "=" * 70)
    print("STEP 3: Loading/parsing original trial")
    print("=" * 70)

    if from_parsed:
        print(f"[INFO] Loading parsed trial from: {from_parsed}")
        original_trial = load_json(from_parsed)
        original_nct = original_trial.get("nct_id", "UNKNOWN")
        original_trial_data = None
    else:
        print(f"[INFO] Parsing original trial: {nct_id}")
        original_trial, original_trial_data = parse_trial_with_resources(
            nct_id, drug_matcher, disease_matcher, umls_utils, cuid2term
        )
        original_nct = nct_id

        save_json(original_trial, os.path.join(output_dir, f"parsed_trial_{original_nct}.json"))
        if original_trial_data is not None:
            save_pkl(original_trial_data, os.path.join(output_dir, f"trial_data_{original_nct}.pkl"))

    original_drug = find_experimental_drug(original_trial)
    if original_drug:
        print(f"[INFO] Original experimental intervention: {original_drug.get('intervention_name')}")
    else:
        print("[WARNING] No clear original experimental intervention detected in base trial.")

    print("\n" + "=" * 70)
    print("STEP 4: Creating counterfactual trials")
    print("=" * 70)

    results: List[Dict[str, Any]] = []

    for drug_info in valid_drugs:
        drug_name = drug_info["input_name"]
        canonical_name = drug_info["canonical_name"]

        print("\n" + "-" * 60)
        print(f"[INFO] Creating counterfactual: {drug_name}")
        if canonical_name != drug_name:
            print(f"[INFO] Using canonical name: {canonical_name}")
        print("-" * 60)

        try:
            modified_trial = replace_drug_in_parsed_trial(
                original_trial,
                new_drug_name=canonical_name,
                drugbank_id=drug_info["drugbank_id"],
                validate_in_kg=False,  # already done above
                keep_original_summary=keep_original_summary,
            )

            modified_trial, trial_data = process_counterfactual(
                modified_trial, drug_matcher, disease_matcher, umls_utils, cuid2term
            )

            cf_nct = modified_trial["nct_id"]

            parsed_path = os.path.join(output_dir, f"parsed_trial_{cf_nct}.json")
            pkl_path = os.path.join(output_dir, f"trial_data_{cf_nct}.pkl")

            save_json(modified_trial, parsed_path)
            save_pkl(trial_data, pkl_path)

            print(f"[INFO] Saved: {parsed_path}")
            print(f"[INFO] Saved: {pkl_path}")

            # Check drug edges in first arm as a quick sanity check
            drug_edges = []
            if trial_data:
                drug_edges = [
                    e for e in trial_data[0].get("trial_arm_edges", [])
                    if e.get("relation") == "arm_tests_drug"
                ]

            if drug_edges:
                print(f"[INFO] ✅ Drug matched in KG with {len(drug_edges)} edge(s)")
                for edge in drug_edges[:5]:
                    print(f"       -> KG ID: {edge.get('kg_id')}")
            else:
                print(f"[WARNING] ⚠️ Drug '{canonical_name}' has NO drug edges in first arm!")

            results.append({
                "input_drug_name": drug_name,
                "canonical_drug_name": canonical_name,
                "drugbank_id": drug_info["drugbank_id"],
                "counterfactual_nct": cf_nct,
                "parsed_path": parsed_path,
                "pkl_path": pkl_path,
                "drug_matched_first_arm": len(drug_edges) > 0,
                "num_drug_edges_first_arm": len(drug_edges),
                "num_total_edges_first_arm": len(trial_data[0].get("trial_arm_edges", [])) if trial_data else 0,
            })

        except Exception as e:
            print(f"[ERROR] Failed to process counterfactual for {drug_name}: {e}")
            import traceback
            traceback.print_exc()
            results.append({
                "input_drug_name": drug_name,
                "error": str(e),
            })

    summary = {
        "original_trial": original_nct,
        "original_drug": original_drug.get("intervention_name") if original_drug else None,
        "counterfactuals": results,
        "drugs_found_in_kg": sum(1 for r in lookup_results if r["found"]),
        "drugs_processed": len([r for r in results if "error" not in r]),
        "keep_original_summary": keep_original_summary,
    }
    summary_path = os.path.join(output_dir, f"counterfactual_summary_{original_nct}.json")
    save_json(summary, summary_path)

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"Original trial: {original_nct}")
    print(f"Original intervention: {original_drug.get('intervention_name') if original_drug else 'N/A'}")
    print(f"Drugs checked: {len(replacement_drugs)}")
    print(f"Drugs found in KG/drug table: {sum(1 for r in lookup_results if r['found'])}")
    print(f"Counterfactuals created: {len([r for r in results if 'error' not in r])}")
    print(f"Summary saved to: {summary_path}")
    print("=" * 70 + "\n")

    return results


# -------------------------
# CLI
# -------------------------

def _cli() -> None:
    parser = argparse.ArgumentParser(
        description="Process clinical trial data with optional drug replacement for counterfactual analysis.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=r"""
Examples:
  # Check if drugs exist in KG (quick lookup, no processing)
  python parse_new_drug.py --check-drug "tocilizumab" "remdesivir"

  # Standard trial parsing
  python parse_new_drug.py NCT04678830

  # Replace drug for counterfactual analysis
  python parse_new_drug.py NCT04678830 --replace-drug "Tocilizumab"

  # Multiple drug replacements
  python parse_new_drug.py NCT04678830 -r "tocilizumab" "baricitinib" "dexamethasone"

  # From existing parsed JSON
  python parse_new_drug.py --from-parsed parsed_trial_NCT04678830.json -r "Tocilizumab"

  # Skip drugs not found in KG/drug table
  python parse_new_drug.py NCT04678830 -r "tocilizumab" "unknowndrug" --skip-not-found

  # Not recommended: keep original brief summary text in CF mode
  python parse_new_drug.py NCT04678830 -r "tocilizumab" --keep-original-summary
        """,
    )
    parser.add_argument(
        "nctid",
        type=str,
        nargs="?",
        help="NCT ID of the clinical trial (or path to JSON)",
    )
    parser.add_argument(
        "--from-parsed",
        type=str,
        help="Path to existing parsed trial JSON (instead of parsing from NCT ID)",
    )
    parser.add_argument(
        "--replace-drug",
        "-r",
        type=str,
        nargs="+",
        help="Drug name(s) to substitute for counterfactual analysis",
    )
    parser.add_argument(
        "--check-drug",
        "-c",
        type=str,
        nargs="+",
        help="Check if drug name(s) exist in KG/drug table (no processing, just lookup)",
    )
    parser.add_argument(
        "--skip-not-found",
        action="store_true",
        help="Skip drugs not found in KG/drug table (default: process anyway with warning)",
    )
    parser.add_argument(
        "--keep-original-summary",
        action="store_true",
        help="Keep the original brief summary text in counterfactual mode "
             "(NOT recommended; may leak original-drug signal).",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default=None,
        help="Output directory for results",
    )

    args = parser.parse_args()

    # Mode 1: Just check drugs in KG / drug table
    if args.check_drug:
        print("[MODE] Drug KG lookup only")
        results = check_drugs_in_kg(args.check_drug, verbose=True)
        if not all(r["found"] for r in results):
            raise SystemExit(1)
        raise SystemExit(0)

    # Validate other arguments
    if not args.nctid and not args.from_parsed:
        parser.error("Either nctid or --from-parsed must be provided (unless using --check-drug)")

    if args.replace_drug:
        # Mode 2: Counterfactual drug replacement
        print("[MODE] Counterfactual drug replacement")

        results = main_counterfactual(
            nct_id=args.nctid,
            from_parsed=args.from_parsed,
            replacement_drugs=args.replace_drug,
            output_dir=args.output_dir or RESULTS_ROOT,
            skip_if_not_found=args.skip_not_found,
            keep_original_summary=args.keep_original_summary,
        )

        success_count = len([r for r in results if "error" not in r])
        print(f"\n[DONE] Created {success_count}/{len(results)} counterfactual(s)")

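    else:
        # Mode 3: Standard parsing
        if not args.nctid:
            # --from-parsed is only consumed by the counterfactual mode above
            parser.error("Standard parsing requires nctid (--from-parsed is only used with --replace-drug)")
        print("[MODE] Standard trial parsing")
        nct_id_clean = args.nctid.split("/")[-1].replace(".json", "")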

        outroot = ensure_dir(args.output_dir or RESULTS_ROOT)
        outdir = ensure_dir(os.path.join(outroot, nct_id_clean))
        print(f"[INFO] output root: {outroot}")
        print(f"[INFO] output trial dir: {outdir}")

        enriched_trial, trial_data = main(args.nctid)

        arms_pkl_new = os.path.join(outroot, f"trial_data_{nct_id_clean}.pkl")
        save_pkl(trial_data, arms_pkl_new)
        print(f"[INFO] wrote {arms_pkl_new}")

        arms_pkl_compat = os.path.join(outroot, f"{nct_id_clean}_results.pkl")
        save_pkl(trial_data, arms_pkl_compat)
        print(f"[INFO] wrote {arms_pkl_compat}")

        parsed_trial_path = os.path.join(outroot, f"parsed_trial_{nct_id_clean}.json")
        save_json(enriched_trial, parsed_trial_path)
        print(f"[INFO] wrote {parsed_trial_path}")

        summary = {
            "nct_id": nct_id_clean,
            "num_arms": len(trial_data),
            "keys_first_arm": sorted(list(trial_data[0].keys())) if trial_data else [],
        }
        summary_path = os.path.join(outroot, f"{nct_id_clean}_summary.json")
        save_json(summary, summary_path)
        print(f"[INFO] wrote {summary_path}")

        print("[INFO] done.")


if __name__ == "__main__":
    _cli()
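
The CLI above dispatches to one of three modes depending on which arguments are present. A minimal sketch of that precedence follows; it is a reduced replica for illustration, `select_mode` is a hypothetical helper that does not exist in `parse_new_drug.py`, and only the four arguments that affect dispatch are reproduced.

```python
import argparse


def select_mode(argv):
    """Return which mode a given argument list would trigger (illustrative only)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("nctid", nargs="?")
    parser.add_argument("--from-parsed")
    parser.add_argument("--replace-drug", "-r", nargs="+")
    parser.add_argument("--check-drug", "-c", nargs="+")
    args = parser.parse_args(argv)

    # Same precedence as the script: lookup first, then counterfactual, then parsing.
    if args.check_drug:
        return "check"
    if not args.nctid and not args.from_parsed:
        raise ValueError("nctid or --from-parsed is required")
    if args.replace_drug:
        return "counterfactual"
    return "standard"


if __name__ == "__main__":
    print(select_mode(["--check-drug", "tocilizumab", "remdesivir"]))
    print(select_mode(["NCT04678830", "-r", "tocilizumab", "baricitinib"]))
    print(select_mode(["NCT04678830"]))
```

Because `--check-drug` is tested first, it takes effect even when an NCT ID is also supplied, and in that mode the script exits with status 1 if any drug is missing from the KG/drug table.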

#### **PBS Script**

```bash
#!/bin/bash
#PBS -P sq95
#PBS -q normal
#PBS -l ncpus=1
#PBS -l mem=48GB
#PBS -l jobfs=1GB
#PBS -l walltime=30:00:00
#PBS -l wd
#PBS -r y
#PBS -N PlaNet_Parse_NCT04809974

# ==============================================================================
# CONFIGURATION
# ==============================================================================
SIF=/scratch/sq95/sp6154/planet/planet.sif
WORKDIR=/scratch/sq95/sp6154/planet/parsing_package

NCT_ID="NCT04809974" # or "NCT04880161" or "NCT05576662"
STUDY_DIR="$WORKDIR/Input_ALL/$NCT_ID"
NLTK_DATA_DIR="$WORKDIR/nltk_data"

RESULTS_DIR="$WORKDIR/LC_Results/$NCT_ID"
LOGS_DIR="$WORKDIR/logs"

# The single JSON to process (Option 2)
JSON_NAME="${NCT_ID}.json"

# ==============================================================================
# DRUG LIST FILE
# ==============================================================================
DRUG_FILE="$WORKDIR/drugs_to_test.txt"

# ==============================================================================
# SETUP
# ==============================================================================
mkdir -p "$RESULTS_DIR"
mkdir -p "$LOGS_DIR"

module load singularity
module load java/jdk-17.0.2

echo "==== Single-trial job (no array) ===="
echo "NCT ID: $NCT_ID"
echo "Start time: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"

# --- Check drug file exists ---
if [ ! -f "$DRUG_FILE" ]; then
    echo "ERROR: Drug file not found: $DRUG_FILE"
    echo "Create it with one drug name per line."
    exit 1
fi

# --- Read drugs from file into Bash array (skip empty/whitespace-only lines) ---
mapfile -t DRUGS_ARRAY < <(grep -v '^[[:space:]]*$' "$DRUG_FILE")

if [ "${#DRUGS_ARRAY[@]}" -eq 0 ]; then
    echo "ERROR: No drugs found in $DRUG_FILE (file is empty or only whitespace)."
    exit 1
fi

echo "Drugs to test (${#DRUGS_ARRAY[@]}):"
for d in "${DRUGS_ARRAY[@]}"; do
    echo "  - $d"
done
echo "----------------------------------------------------"

# --- Check JSON exists ---
if [ ! -f "$STUDY_DIR/$JSON_NAME" ]; then
    echo "ERROR: Input JSON not found: $STUDY_DIR/$JSON_NAME"
    exit 1
fi

echo "Processing trial: $NCT_ID"
echo "Input JSON: $STUDY_DIR/$JSON_NAME"
echo "Results dir: $RESULTS_DIR"
echo "Logs dir:    $LOGS_DIR"
echo "----------------------------------------------------"

# ==============================================================================
# RUN SINGULARITY + PYTHON FOR *ALL* DRUGS
# ==============================================================================
# IMPORTANT FIX:
#   We pass NCT_ID + all drugs as positional parameters to bash -c,
#   then inside the script use "$1" (NCT ID) and "${@:2}" (all drugs).
# ==============================================================================
singularity exec \
    --bind "$WORKDIR":/app \
    --bind "$STUDY_DIR":/app/studies \
    --bind "$NLTK_DATA_DIR":/app/nltk_data \
    --bind "$RESULTS_DIR":/app/results \
    --bind "$LOGS_DIR":/app/logs \
    --bind /apps/java:/apps/java \
    "$SIF" \
    /bin/bash -c '
        set -e

        export NLTK_DATA=/app/nltk_data
        export JAVA_HOME=/apps/java/jdk-17.0.2
        export PATH="$JAVA_HOME/bin:$PATH"
        export RESULTS_DIR=/app/results

        cd /app

        echo "[INFO] Inside container"
        echo "[INFO] NCT ID: $1"
        echo "[INFO] Drugs: ${@:2}"

        python /app/parse_new_drug.py "$1" \
            --replace-drug "${@:2}" \
            --output-dir /app/results
    ' bash "$NCT_ID" "${DRUGS_ARRAY[@]}" \
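    2>&1 | grep -v 'Network is unreachable' > "$LOGS_DIR/${NCT_ID}.log"
# The pipeline's exit status is grep's; capture the container's status instead
# so that a failed parse is not reported as success.
RUN_STATUS=${PIPESTATUS[0]}

echo "Finished processing $NCT_ID (exit status: $RUN_STATUS)"
echo "Log saved to: $LOGS_DIR/${NCT_ID}.log"
echo "End time: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"
exit "$RUN_STATUS"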
```
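
After the job finishes, the counterfactual mode leaves a `counterfactual_summary_<NCT>.json` file in the output directory. The sketch below shows one way to triage its `counterfactuals` list before running predictions. The keys (`input_drug_name`, `error`, `drug_matched_first_arm`) are the ones written by the parsing script; `split_counterfactuals` is a hypothetical helper and the example entries are made up.

```python
def split_counterfactuals(summary):
    """Group counterfactual entries into usable, unmatched, and failed drug names."""
    usable, unmatched, failed = [], [], []
    for entry in summary.get("counterfactuals", []):
        name = entry.get("input_drug_name")
        if "error" in entry:
            # Processing raised an exception for this drug.
            failed.append(name)
        elif entry.get("drug_matched_first_arm"):
            # The first arm contains at least one drug edge for the replacement.
            usable.append(name)
        else:
            # Processed, but no drug edge was found in the first arm.
            unmatched.append(name)
    return usable, unmatched, failed


if __name__ == "__main__":
    # Made-up summary; in practice load the file with json.load().
    example = {
        "original_trial": "NCT04809974",
        "counterfactuals": [
            {"input_drug_name": "tocilizumab", "drug_matched_first_arm": True},
            {"input_drug_name": "unknowndrug", "drug_matched_first_arm": False},
            {"input_drug_name": "baricitinib", "error": "lookup failed"},
        ],
    }
    print(split_counterfactuals(example))
```

Entries in the "unmatched" group still have a pickle on disk, but the replacement drug contributed no edges to the first arm, so their predictions are worth treating with caution.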

### **Predict**

#### **Python Script**

#!/usr/bin/env python3
"""
predict_all_for_new_clinial_trial.py

Runs AE, safety, and efficacy predictions for clinical trial data using the PlaNet pipeline.
Usage:
    python predict_all_for_new_clinial_trial.py \
        --pklpath path/to/trial.pkl \
        [--output-dir results] \
        [--model-dir data/models] \
        [--device cpu]
"""

import os
import sys
import json
import argparse
import logging
import time
import threading
from contextlib import contextmanager
from copy import deepcopy
import re
import pickle as pkl

# Third-party imports
import torch
from torch.utils.data import TensorDataset
import networkx as nx
import numpy as np
if not hasattr(np, 'bool'):
    np.bool = bool
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# Configure logging
def setup_logging():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    return logging.getLogger(__name__)

log = setup_logging()


def configure_paths():
    """Insert project directories into sys.path for local imports."""
    cwd = os.getcwd()  # should be /app inside container
    for path in [cwd, '/notebooks', '/planet']:
        if os.path.isdir(path) and path not in sys.path:
            sys.path.insert(0, path)
            log.info(f"Added '{path}' to sys.path")


def find_file(filename, search_paths=None):
    """Search for a file in multiple base directories."""
    if search_paths is None:
        search_paths = ['.', '/app', '/planet', '/planet/parsing_package']
    for base in search_paths:
        candidate = os.path.join(base, filename)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"Could not find '{filename}' in {search_paths}")


def load_kg_vocab():
    kgid2x, x2kgid = {}, {}
    relname2etype, etype2relname = {}, {}

    entities_file = find_file('data/graph/entities.dict')
    log.info(f"Using entities file: {entities_file}")
    with open(entities_file) as f:
        for line in f:
            x, kgid = line.split()
            kgid2x[kgid] = int(x)
            x2kgid[int(x)] = kgid

    relations_file = find_file('data/graph/relations.dict')
    log.info(f"Using relations file: {relations_file}")
    with open(relations_file) as f:
        for line in f:
            etype, name = line.split()
            name = name[name.find('rel-name-') + len('rel-name-'):]
            relname2etype[name] = int(etype)
            etype2relname[int(etype)] = name

    return kgid2x, x2kgid, relname2etype, etype2relname


@contextmanager
def suppress_output_with_progress(description):
    """
    Suppresses stdout/stderr while showing a live elapsed-time bar.
    """
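    old_out, old_err = sys.stdout, sys.stderr
    # Flush Python-level buffers before the underlying descriptors are redirected.
    old_out.flush()
    old_err.flush()
    null_fd = os.open(os.devnull, os.O_WRONLY)
    saved_out, saved_err = os.dup(1), os.dup(2)
    # tqdm writes to stderr by default, which is about to point at /dev/null,
    # so give the bar its own handle on the real stderr.
    bar_stream = os.fdopen(os.dup(saved_err), "w")

    pbar = tqdm(
        desc=f"{description}... Time elapsed",
        bar_format="{desc}: {n:.1f}s",
        ncols=40,
        file=bar_stream,
    )
    os.dup2(null_fd, 1)
    os.dup2(null_fd, 2)

    start = time.time()
    running = threading.Event()
    running.set()

    def update():
        while running.is_set():
            pbar.n = time.time() - start
            pbar.refresh()
            time.sleep(0.1)

    t = threading.Thread(target=update, daemon=True)
    t.start()

    try:
        yield
    finally:
        running.clear()
        t.join()
        pbar.close()
        bar_stream.close()
        os.dup2(saved_out, 1)
        os.dup2(saved_err, 2)
        os.close(null_fd)
        os.close(saved_out)
        os.close(saved_err)
        sys.stdout, sys.stderr = old_out, old_err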


def load_kg_utils():
    """Import and verify the knowledge_graph package."""
    try:
        import knowledge_graph  # noqa: F401
        from knowledge_graph import kg  # noqa: F401
        log.info("Imported knowledge_graph successfully")
    except ImportError as e:
        log.error(f"Failed to import knowledge_graph: {e}")
        sys.exit(1)


def get_trial_feature(bert_model, new_trial):
    emb = bert_model._embed(new_trial['arm_text'])
    return np.concatenate([emb, new_trial['trial_attribute_feats_vec']])


def get_new_edges(new_trial, the_x, kgid2x, relname2etype):
    new_edges, new_etypes, seen = [], [], set()
    for edge in new_trial['trial_arm_edges']:
        h, t = the_x, kgid2x[edge['kg_id']]
        r = relname2etype[edge['relation']]
        new_edges += [[h, t], [t, h]]
        new_etypes += [r, r + 26]
        seen.add(r)
    for r in [21, 22, 23, 24, 25]:
        if r not in seen:
            new_edges += [[the_x, 0], [0, the_x]]
            new_etypes += [r, r + 26]
    return torch.tensor(new_edges).t(), torch.tensor(new_etypes)


def add_new_trial_to_dataset(dataset, encoder, new_trial, kgid2x, relname2etype, bert_model):
    df = dataset.df[dataset.df['split'] == 'test'].head(1)
    dataset.df = df
    the_x = df.iloc[0]['x']
    the_kgid = df.iloc[0]['kgid']

    node_feats = deepcopy(dataset.node_feats)
    pos = node_feats[node_feats['node_id'] == the_kgid].index[0]
    node_feats.at[pos, 'emb'] = get_trial_feature(bert_model, new_trial)
    dataset.node_feats = node_feats

    graph = deepcopy(dataset.graph)
    new_edges, new_etypes = get_new_edges(new_trial, the_x, kgid2x, relname2etype)
    mask = graph.data.edge_index.eq(the_x).any(dim=0)
    graph.data.edge_index = torch.cat([graph.data.edge_index[:, ~mask], new_edges], dim=1)
    graph.data.edge_type = torch.cat([graph.data.edge_type[~mask], new_etypes], dim=0)
    dataset.graph = graph

    x = dataset._get_data_x(dataset.df)
    ds = TensorDataset(
        x,
        dataset.task_ys[0][0].repeat(len(x), 1),
        dataset.sample_weight_masks[0][0].repeat(len(x), 1)
    )
    dataset.datasets['test'] = ds
    return dataset, encoder


def add_new_trial_to_efficacy_dataset(dataset, encoder, trial1, trial2, kgid2x, relname2etype, bert_model):
    df = dataset.efficacy_df[dataset.efficacy_df['split'] == 'test'].head(1)
    dataset.efficacy_df = df
    x1, x2 = df.iloc[0]['x1'], df.iloc[0]['x2']
    kg1, kg2 = df.iloc[0]['kgid1'], df.iloc[0]['kgid2']

    node_feats = deepcopy(dataset.node_feats)
    pos1 = node_feats[node_feats['node_id'] == kg1].index[0]
    node_feats.at[pos1, 'emb'] = get_trial_feature(bert_model, trial1)
    pos2 = node_feats[node_feats['node_id'] == kg2].index[0]
    node_feats.at[pos2, 'emb'] = get_trial_feature(bert_model, trial2)
    dataset.node_feats = node_feats

    graph = deepcopy(dataset.graph)
    edges1, types1 = get_new_edges(trial1, x1, kgid2x, relname2etype)
    edges2, types2 = get_new_edges(trial2, x2, kgid2x, relname2etype)
    mask1 = graph.data.edge_index.eq(x1).any(dim=0)
    mask2 = graph.data.edge_index.eq(x2).any(dim=0)
    graph.data.edge_index = torch.cat([
        graph.data.edge_index[:, ~mask1 & ~mask2],
        edges1, edges2
    ], dim=1)
    graph.data.edge_type = torch.cat([
        graph.data.edge_type[~mask1 & ~mask2],
        types1, types2
    ], dim=0)
    dataset.graph = graph

    x = dataset._get_data_x(df)
    ds = TensorDataset(
        x,
        dataset.task_ys[0][0].repeat(len(x), 1),
        dataset.sample_weight_masks[0][0].repeat(len(x), 1)
    )
    dataset.datasets['test'] = ds
    return dataset, encoder


def predict_top_ae(pred, k=5):
    # 1) coerce to 1-d NumPy array
    arr = np.asarray(pred).ravel()
    # 2) if empty, return empty dict
    if arr.size == 0:
        return {}
    # 3) argsort descending
    idx_desc = np.argsort(arr)[::-1]
    # 4) take up to k (or the array’s length, whichever is smaller)
    topk = idx_desc[:min(k, arr.size)]
    # 5) build result dict
    return {int(i): float(arr[i]) for i in topk}


def choose_trial_pair(new_trial_data):
    """
    Choose which two entries to treat as trial_1 and trial_2 for efficacy.

    Preferred behaviour:
      - trial_1: placebo arm
      - trial_2: non-placebo (active) arm

    Fallback:
      - If we can't reliably identify placebo vs active, use indices 0 and 1.
    """
    placebo_idx = None
    active_idx = None

    for i, trial in enumerate(new_trial_data):
        label = str(trial.get('arm_label', '')).lower()
        if 'placebo' in label or 'pbo' in label:
            placebo_idx = i
        else:
            # Take the first non-placebo as the active arm
            if active_idx is None:
                active_idx = i

    if placebo_idx is not None and active_idx is not None:
        log.info(
            f"Selected trial pair for efficacy: placebo index={placebo_idx}, "
            f"active index={active_idx}"
        )
        return placebo_idx, active_idx

    # Fallback: just use first two entries if possible
    if len(new_trial_data) >= 2:
        log.warning(
            "Could not reliably identify placebo/active arms; "
            "defaulting to indices 0 and 1 for efficacy comparison"
        )
        return 0, 1

    raise ValueError("Need at least two trials/arms to compute efficacy")


def run_task(name, model_path, func, *args):
    log.info(f"Starting {name} task")
    try:
        with suppress_output_with_progress(f"{name} prediction"):
            result = func(model_path, *args)
        log.info(f"Completed {name} task")
        return result
    except Exception as e:
        log.error(f"Error in {name} task: {e}", exc_info=True)
        sys.exit(1)


def loader_ae(model_path, new_trial, kgid2x, relname2etype, bert_model, device):
    from utils.demo_utils import load_model_and_data, model_inference, prepare_runner
    (dataset, _), encoder, bert_enc, model, args_, runner = load_model_and_data(model_path, device=device)
    dataset, encoder = add_new_trial_to_dataset(dataset, encoder, new_trial, kgid2x, relname2etype, bert_model)
    args_, runner, encoder = prepare_runner(args_, dataset, encoder, bert_enc, model, device=device)
    _, y_pred, _ = model_inference(runner, mode='test')
    return predict_top_ae(y_pred[0], k=100)


def loader_safety(model_path, new_trial, kgid2x, relname2etype, bert_model, device):
    from utils.demo_utils import load_model_and_data, model_inference, prepare_runner
    (dataset, _), encoder, bert_enc, model, args_, runner = load_model_and_data(model_path, device=device)
    dataset, encoder = add_new_trial_to_dataset(dataset, encoder, new_trial, kgid2x, relname2etype, bert_model)
    args_, runner, encoder = prepare_runner(args_, dataset, encoder, bert_enc, model, device=device)
    _, y_pred, _ = model_inference(runner, mode='test')
    return y_pred[0].item()


def loader_efficacy(model_path, trial1, trial2, kgid2x, relname2etype, bert_model, device):
    from utils.demo_utils import load_model, model_inference, prepare_runner
    (dataset, _), encoder, bert_enc, model, args_ = load_model(model_path)
    dataset, encoder = add_new_trial_to_efficacy_dataset(
        dataset, encoder, trial1, trial2, kgid2x, relname2etype, bert_model
    )
    args_, runner, encoder = prepare_runner(args_, dataset, encoder, bert_enc, model, device=device)
    _, y_pred, _ = model_inference(runner, mode='test')
    return y_pred[0].item()


def main():
    parser = argparse.ArgumentParser(description="Run PlaNet tasks for clinical trial data")
    parser.add_argument('-p', '--pklpath', type=str, required=True, help="Path to the input pickle file")
    parser.add_argument('-o', '--output-dir', type=str, default='results', help="Directory to save results")
    parser.add_argument('-m', '--model-dir', type=str, default='/planet/data/models', help="Base directory for model checkpoints")
    parser.add_argument('-d', '--device', type=str, default='cpu', help="Device for computation (cpu or cuda)")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    configure_paths()
    load_kg_utils()
    kgid2x, x2kgid, relname2etype, etype2relname = load_kg_vocab()

    from utils.text_bert_features import TextBertFeatures
    from gcn_models.utils import set_seed

    set_seed(24)
    bert_model = TextBertFeatures(
        bert_model='/bert_model',
        device=args.device
    )
    log.info("Loaded BERT model")

    with open(args.pklpath, 'rb') as f:
        new_trial_data = pkl.load(f)
    log.info(f"Loaded {len(new_trial_data)} trial(s) from {args.pklpath}")

    # Optional: log the arms we have
    for i, t in enumerate(new_trial_data):
        log.info(
            f"Arm [{i}] label={t.get('arm_label')} "
            f"n_edges={len(t.get('trial_arm_edges', []))}"
        )

    # AE predictions (per arm)
    ae_model_path = find_file(os.path.join(args.model_dir, 'ae_model_shxo9bgw', 'ckpt.pt'))
    AE_preds = [
        run_task('AE', ae_model_path, loader_ae, trial, kgid2x, relname2etype, bert_model, args.device)
        for trial in new_trial_data
    ]

    # Safety predictions (per arm)
    saf_model_path = find_file(os.path.join(args.model_dir, 'safety_model_1xekl810', 'ckpt.pt'))
    safety_preds = [
        run_task('Safety', saf_model_path, loader_safety, trial, kgid2x, relname2etype, bert_model, args.device)
        for trial in new_trial_data
    ]

    # Efficacy predictions (if 2+ trials)
    efficacy_pred = None
    if len(new_trial_data) > 1:
        eff_model_path = find_file(os.path.join(args.model_dir, 'efficacy_model_34l5ms9m', 'ckpt.pt'))

        # Choose which two arms to compare (prefer placebo vs. active)
        idx1, idx2 = choose_trial_pair(new_trial_data)
        trial1 = new_trial_data[idx1]
        trial2 = new_trial_data[idx2]

        log.info(
            f"Efficacy comparison: trial_1='{trial1.get('arm_label')}', "
            f"trial_2='{trial2.get('arm_label')}'"
        )

        efficacy_pred = run_task(
            'Efficacy', eff_model_path, loader_efficacy,
            trial1, trial2,
            kgid2x, relname2etype, bert_model, args.device
        )
    else:
        log.info("Skipping efficacy (only one trial provided)")

    # Build and save results
    result = {'meta': {}, 'AE': {}, 'safety': {}, 'efficacy': {}}
    for i, trial in enumerate(new_trial_data, start=1):
        result['meta'][f'trial_{i}_label'] = trial['arm_label']
        result['meta'][f'trial_{i}_text'] = trial['arm_text']
    for i, pred in enumerate(AE_preds, start=1):
        result['AE'][f'trial_{i}_ae'] = pred
    for i, pred in enumerate(safety_preds, start=1):
        result['safety'][f'trial_{i}_safety'] = pred
    if efficacy_pred is not None:
        result['efficacy']['prob_trial1_gt_trial2'] = efficacy_pred

    out_file = os.path.join(
        args.output_dir,
        f"result_{os.path.basename(args.pklpath).split('.')[0]}.json"
    )
    with open(out_file, 'w') as f:
        json.dump(result, f, indent=2)
    log.info(f"Saved results to {out_file}")

    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
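
The edge construction in `get_new_edges` can be hard to read because of its bare constants. The following torch-free sketch replays the same logic on plain lists. The constant names are hypothetical, and reading `26` as the offset between a relation and its reverse, and node `0` as a placeholder target, is an inference from the code rather than something the source states. The relation name and KG ID in the example are made up.

```python
REVERSE_OFFSET = 26                        # assumed: reverse relation id = forward id + 26
REQUIRED_RELATIONS = (21, 22, 23, 24, 25)  # relation ids the script always pads
PLACEHOLDER_NODE = 0                       # assumed placeholder target for missing relations


def build_edges(arm_edges, trial_node, kgid2x, relname2etype):
    """Return (edges, edge_types) as plain lists, mirroring get_new_edges."""
    edges, etypes, seen = [], [], set()
    for edge in arm_edges:
        tail = kgid2x[edge["kg_id"]]
        rel = relname2etype[edge["relation"]]
        # Each KG link is added in both directions.
        edges += [[trial_node, tail], [tail, trial_node]]
        etypes += [rel, rel + REVERSE_OFFSET]
        seen.add(rel)
    for rel in REQUIRED_RELATIONS:
        if rel not in seen:
            # Relations absent from the arm are linked to the placeholder node.
            edges += [[trial_node, PLACEHOLDER_NODE], [PLACEHOLDER_NODE, trial_node]]
            etypes += [rel, rel + REVERSE_OFFSET]
    return edges, etypes


if __name__ == "__main__":
    edges, etypes = build_edges(
        [{"kg_id": "KG_EXAMPLE", "relation": "example-rel"}],
        trial_node=100,
        kgid2x={"KG_EXAMPLE": 7},
        relname2etype={"example-rel": 21},
    )
    print(edges)
    print(etypes)
```

With one real edge of type 21, the arm ends up with ten directed edges: two for the real link and two for each of the four padded relations.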

#### **PBS Script**

```bash
#!/bin/bash
#PBS -P sq95
#PBS -q normal
#PBS -l ncpus=24
#PBS -l mem=64GB
#PBS -l jobfs=1GB
#PBS -l walltime=48:00:00
#PBS -l wd
#PBS -M pinsy007@mymail.unisa.edu.au
#PBS -m abe
#PBS -N PlaNet_Predict_Many

# ── Paths (host) ────────────────────────────────────────────────────────
PLANET_ROOT=/scratch/sq95/sp6154/planet
NOTEBOOKS_DIR=$PLANET_ROOT/notebooks
PARSE_DIR=$PLANET_ROOT/parsing_package
SIF=$PLANET_ROOT/planet_v2.sif
BERT_MODEL_DIR=$PLANET_ROOT/bert_model

# ── Trial-scoped folders ────────────────────────────────────────────────
TRIAL_ID="NCT04809974" # or another cohort

# Input (host)
INPUT_DIR="$PARSE_DIR/LC_Results/$TRIAL_ID"

# Log + results (host)
LOG_DIR="$PARSE_DIR/logs/$TRIAL_ID"
RESULTS_DIR="$PARSE_DIR/results/$TRIAL_ID"

# ── Load modules & env fixes ────────────────────────────────────────────
module load singularity
export SINGULARITYENV_LD_LIBRARY_PATH="/opt/conda/envs/planet/lib:${LD_LIBRARY_PATH:-}"
export SINGULARITYENV_TRANSFORMERS_CACHE=$BERT_MODEL_DIR
export SINGULARITYENV_HF_HOME=$BERT_MODEL_DIR
export SINGULARITYENV_TRANSFORMERS_OFFLINE=1

# ── Prep output dirs ────────────────────────────────────────────────────
mkdir -p "$RESULTS_DIR" "$LOG_DIR"

# ── Debug info (once) ──────────────────────────────────────────────────
MASTER_LOG="$LOG_DIR/PlaNet_Predict_master.log"
: > "$MASTER_LOG"
echo "[$(date)] Starting multi-trial prediction job (scoped to $TRIAL_ID)" | tee -a "$MASTER_LOG"
echo "  notebooks/data → $(readlink "$NOTEBOOKS_DIR/data" 2>/dev/null)" | tee -a "$MASTER_LOG"
echo "  Host models under parsing_package/data/models:" | tee -a "$MASTER_LOG"
ls -1 "$PARSE_DIR/data/models" 2>&1 | tee -a "$MASTER_LOG"

echo "  Input directory:" | tee -a "$MASTER_LOG"
echo "    $INPUT_DIR" | tee -a "$MASTER_LOG"
echo "  Contents of input directory:" | tee -a "$MASTER_LOG"
ls -1 "$INPUT_DIR" 2>&1 | tee -a "$MASTER_LOG"

# ── Collect PKL files ──────────────────────────────────────────────────
shopt -s nullglob

# Support both old (“*_results.pkl”) and new (“trial_data_*.pkl”) names
PKL_FILES=(
  "$INPUT_DIR"/*_results.pkl
  "$INPUT_DIR"/trial_data_*.pkl
)

if ((${#PKL_FILES[@]} == 0)); then
  echo "[$(date)] ERROR: No *_results.pkl or trial_data_*.pkl files found in $INPUT_DIR" | tee -a "$MASTER_LOG"
  exit 1
fi

echo "[$(date)] Found ${#PKL_FILES[@]} PKL file(s) in $INPUT_DIR" | tee -a "$MASTER_LOG"

# ── Loop over all PKL files ─────────────────────────────────────────────
for INPUT_PKL in "${PKL_FILES[@]}"; do
  BASENAME=$(basename "$INPUT_PKL" .pkl)

  # Derive a nice TRIAL label for logs
  if [[ "$BASENAME" == *_results ]]; then
    # e.g. NCT04678830_results  → NCT04678830
    TRIAL=${BASENAME%_results}
  elif [[ "$BASENAME" == trial_data_* ]]; then
    # e.g. trial_data_NCT04809974_CF_Leronlimab → NCT04809974_CF_Leronlimab
    TRIAL=${BASENAME#trial_data_}
  else
    TRIAL=$BASENAME
  fi

  echo "[$(date)] ===== Trial: $TRIAL (file: $BASENAME.pkl) =====" | tee -a "$MASTER_LOG"

  # Trial-specific log and expected JSON
  LOG="$LOG_DIR/PlaNet_Predict_${BASENAME}.log"
  JSON_OUT="$RESULTS_DIR/result_${BASENAME}.json"

  # Init log for this trial
  : > "$LOG"
  echo "[$(date)] Starting prediction for $INPUT_PKL" | tee -a "$LOG"

  # Run inside container
  echo "[$(date)] Running prediction in Singularity..." | tee -a "$LOG"
  singularity exec \
    --bind "$PLANET_ROOT":/planet \
    --bind "$PARSE_DIR/data":/planet/parsing_package/data \
    --bind "$PARSE_DIR/data":/planet/data \
    --bind "$BERT_MODEL_DIR":/bert_model \
    --pwd /planet/notebooks \
    "$SIF" \
      /opt/conda/envs/planet/bin/python \
        /planet/parsing_package/predict_all_for_new_clinial_trial.py \
          --pklpath "/planet/parsing_package/LC_Results/$TRIAL_ID/${BASENAME}.pkl" \
          --output-dir "/planet/parsing_package/results/$TRIAL_ID" \
    >> "$LOG" 2>&1

  echo "[$(date)] Finished prediction for $TRIAL." | tee -a "$LOG"
  echo "JSON output (expected): $JSON_OUT" | tee -a "$LOG"
  echo "[$(date)] Finished $TRIAL (log: $LOG)" | tee -a "$MASTER_LOG"
done

echo "[$(date)] All trials processed from $INPUT_DIR." | tee -a "$MASTER_LOG"
```
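
Each prediction run writes a `result_<basename>.json` file with `meta`, `AE`, `safety`, and `efficacy` blocks. A minimal sketch for flattening one such file into per-arm rows is given below; `summarize_result` is a hypothetical helper and the numbers in the example are made up. Note that the AE indices are integers in memory but become strings once the result has been through JSON.

```python
import json


def summarize_result(result):
    """Return (rows, efficacy) where rows holds one dict per trial arm."""
    rows = []
    meta = result.get("meta", {})
    ae_block = result.get("AE", {})
    safety = result.get("safety", {})
    for i in range(1, len(safety) + 1):
        ae = ae_block.get(f"trial_{i}_ae", {})
        # Index of the highest-scoring adverse event, if any were predicted.
        top_ae = max(ae, key=ae.get) if ae else None
        rows.append({
            "arm": meta.get(f"trial_{i}_label"),
            "safety": safety[f"trial_{i}_safety"],
            "top_ae_index": int(top_ae) if top_ae is not None else None,
        })
    efficacy = result.get("efficacy", {}).get("prob_trial1_gt_trial2")
    return rows, efficacy


if __name__ == "__main__":
    # Made-up values; the round trip mimics reading the file back from disk.
    raw = {
        "meta": {"trial_1_label": "Placebo", "trial_2_label": "Active"},
        "AE": {"trial_1_ae": {3: 0.2, 12: 0.6}, "trial_2_ae": {3: 0.7, 12: 0.1}},
        "safety": {"trial_1_safety": 0.1, "trial_2_safety": 0.4},
        "efficacy": {"prob_trial1_gt_trial2": 0.3},
    }
    rows, efficacy = summarize_result(json.loads(json.dumps(raw)))
    for row in rows:
        print(row)
    print("efficacy:", efficacy)
```

The `top_ae_index` values are still model output indices; the Map AE step is what turns them into KG IDs and readable labels.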

### **Map AE**

#### **Python Script**

#!/usr/bin/env python

import argparse
import json
import os
import pickle


def load_idx2aename(idx_pkl_path: str):
    """
    Load the AE index → KG ID mapping.

    The original PlaNet file ae1017_idx2aename.pkl is a numpy array
    (or list-like) where position i contains the KG ID for AE index i.
    This function normalizes it to a Python dict: {index: kg_id}.
    """
    if not os.path.isfile(idx_pkl_path):
        raise FileNotFoundError(f"idx2aename pkl not found at: {idx_pkl_path}")

    with open(idx_pkl_path, "rb") as f:
        obj = pickle.load(f)

    # If it's already a dict, just return it
    if isinstance(obj, dict):
        return obj

    # Otherwise assume it's list / numpy array-like and convert
    idx2aename = {}
    try:
        for i, v in enumerate(obj):
            idx2aename[i] = v
    except TypeError as e:
        raise TypeError(
            f"Unsupported type for idx2aename (expected dict or sequence), "
            f"got {type(obj)}"
        ) from e

    return idx2aename


def load_kgid2name(kgid_pkl_path: str):
    """
    Load KG ID → human-readable AE label mapping.

    For example, from:
      /planet/notebooks/small_data/ae_kgid2name.pkl
    or, alternatively, from:
      /planet/parsing_package/data/graph/KG_node2name.pkl
    """
    if not os.path.isfile(kgid_pkl_path):
        raise FileNotFoundError(f"KG ID → name pkl not found at: {kgid_pkl_path}")

    with open(kgid_pkl_path, "rb") as f:
        kgid2name = pickle.load(f)

    if not isinstance(kgid2name, dict):
        raise TypeError(
            f"Expected dict for KG ID → name mapping, got {type(kgid2name)}"
        )

    return kgid2name


def map_ae_for_file(json_path: str, idx2aename, kgid2name, out_path: str = None):
    """
    Map AE indices in a PlaNet result JSON file to:
      - KG node IDs (e.g. KG00099326)
      - human-readable AE labels (symptoms)

    Output: TSV with columns:
      trial_id, trial_label, ae_code, ae_kg_id, ae_label, probability
    """
    # 1) Load JSON
    with open(json_path, "r") as f:
        data = json.load(f)

    meta = data.get("meta", {})
    ae_block = data.get("AE", {})

    if not ae_block:
        print(f"[WARN] No 'AE' block found in {json_path}")
        return

    # 2) Prepare output TSV path
    if out_path is None:
        base = os.path.splitext(os.path.basename(json_path))[0]
        out_dir = os.path.dirname(json_path)
        out_path = os.path.join(out_dir, f"{base}_AE_mapped.tsv")

    # 3) Ensure output directory exists
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    rows = []

    # 4) Loop over each trial_X_ae dict
    for trial_key, ae_dict in ae_block.items():
        # trial_key like "trial_1_ae" → prefix "trial_1" → meta["trial_1_label"]
        prefix = trial_key.replace("_ae", "")
        trial_label_key = f"{prefix}_label"
        trial_label = meta.get(trial_label_key, prefix)

        if not isinstance(ae_dict, dict):
            continue

        # Sort by probability descending, just for readability
        for ae_code_str, prob in sorted(ae_dict.items(),
                                        key=lambda kv: kv[1],
                                        reverse=True):
            # AE code in JSON is a string index
            try:
                idx = int(ae_code_str)
            except ValueError:
                idx = None

            # Step 1: index → KG ID (e.g. "KG00099326")
            if idx is not None:
                kg_id = idx2aename.get(idx, None)
            else:
                kg_id = None

            # Step 2: KG ID → human symptom label
            if kg_id is not None:
                ae_label = kgid2name.get(kg_id, "<UNKNOWN_SYMPTOM>")
            else:
                ae_label = "<INVALID_CODE>"

            rows.append(
                (
                    prefix,           # trial_id
                    trial_label,      # trial_label
                    ae_code_str,      # ae_code (index as string)
                    kg_id,            # ae_kg_id (may be None)
                    ae_label,         # ae_label (symptom text or placeholder)
                    prob,             # probability
                )
            )

    # 5) Write TSV
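    with open(out_path, "w", encoding="utf-8") as out_f:
        out_f.write(
            "trial_id\ttrial_label\tae_code\tae_kg_id\tae_label\tprobability\n"
        )
        for trial_id, trial_label, code, kg_id, label, prob in rows:
            # Leave the KG ID field empty (not the string "None") when unmapped
            kg_field = "" if kg_id is None else kg_id
            out_f.write(
                f"{trial_id}\t{trial_label}\t{code}\t{kg_field}\t{label}\t{prob}\n"
            )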

    print(f"[INFO] Wrote AE mapping table to: {out_path}")
    print(f"[INFO] Total rows: {len(rows)}")


def main():
    parser = argparse.ArgumentParser(
        description=(
            "Map AE indices in PlaNet result JSON to KG IDs and "
            "human-readable AE names (symptoms)."
        )
    )
    parser.add_argument(
        "json_path",
        help="Path to result_*.json file (PlaNet output with 'AE' block).",
    )
    parser.add_argument(
        "--idxpkl",
        default="/planet/notebooks/small_data/ae1017_idx2aename.pkl",
        help=(
            "Path to ae1017_idx2aename.pkl "
            "(index → KG ID; default: /planet/notebooks/small_data/ae1017_idx2aename.pkl)"
        ),
    )
    parser.add_argument(
        "--kgid2name",
        default="/planet/notebooks/small_data/ae_kgid2name.pkl",
        help=(
            "Path to ae_kgid2name.pkl (KG ID → AE label). "
            "You can also point this to /planet/parsing_package/data/graph/KG_node2name.pkl "
            "if needed."
        ),
    )
    parser.add_argument(
        "--out",
        default=None,
        help=(
            "Optional output TSV path. If not set, will use "
            "<json_basename>_AE_mapped.tsv next to the JSON."
        ),
    )
    args = parser.parse_args()

    idx2aename = load_idx2aename(args.idxpkl)
    kgid2name = load_kgid2name(args.kgid2name)

    map_ae_for_file(args.json_path, idx2aename, kgid2name, args.out)


if __name__ == "__main__":
    main()
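
The two-step lookup performed in `map_ae_for_file` (string index → KG ID → symptom label) can be tried in isolation. The sketch below is a minimal illustration, assuming small hand-written dictionaries in place of `ae1017_idx2aename.pkl` and `ae_kgid2name.pkl`; the helper name `resolve_ae` and the index values are hypothetical and not part of `map_AE.py`.

```python
from typing import Dict, Optional, Tuple


def resolve_ae(code: str,
               idx2kg: Dict[int, str],
               kg2name: Dict[str, str]) -> Tuple[Optional[str], str]:
    """Resolve a JSON AE code (string index) to (KG ID, label)."""
    try:
        idx = int(code)
    except ValueError:
        # Non-numeric code: no KG ID can be looked up
        return None, "<INVALID_CODE>"
    kg_id = idx2kg.get(idx)
    if kg_id is None:
        # Index not present in the index -> KG ID table
        return None, "<INVALID_CODE>"
    # KG ID known, but it may still lack a human-readable name
    return kg_id, kg2name.get(kg_id, "<UNKNOWN_SYMPTOM>")


# Toy stand-ins for the two pickles (illustrative values only)
idx2kg = {0: "KG00099268", 3: "KG00099365", 7: "KG00000000"}
kg2name = {
    "KG00099268": "Hospitalisation",
    "KG00099365": "Orthostatic Hypotension",
}

for code in ["0", "3", "7", "99", "abc"]:
    print(code, resolve_ae(code, idx2kg, kg2name))
```

The two placeholders separate the failure modes: `<INVALID_CODE>` means the index itself could not be resolved, while `<UNKNOWN_SYMPTOM>` means the KG ID exists but has no entry in the name table.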

#### **PBS Script**

```bash
#!/bin/bash
#PBS -P sq95
#PBS -q normal
#PBS -l ncpus=1
#PBS -l mem=64GB
#PBS -l jobfs=1GB
#PBS -l walltime=48:00:00
#PBS -l wd
#PBS -M pinsy007@mymail.unisa.edu.au
#PBS -m abe
#PBS -N PlaNet_Map_AE_Only

set -euo pipefail

# ── Trial / cohort ──────────────────────────────────────────────────────
NCT_ID="NCT04809974"

# ── Paths (host) ────────────────────────────────────────────────────────
PLANET_ROOT=/scratch/sq95/sp6154/planet
NOTEBOOKS_DIR=$PLANET_ROOT/notebooks
PARSE_DIR=$PLANET_ROOT/parsing_package
SIF=$PLANET_ROOT/planet_v2.sif
BERT_MODEL_DIR=$PLANET_ROOT/bert_model

# ── Input / output per cohort ───────────────────────────────────────────
RESULTS_DIR="$PARSE_DIR/results/$NCT_ID"
MAP_DIR="$PARSE_DIR/map/$NCT_ID"
LOG_DIR="$PARSE_DIR/logs/$NCT_ID"

# ── Load modules & env fixes ────────────────────────────────────────────
module load singularity
export SINGULARITYENV_LD_LIBRARY_PATH="/opt/conda/envs/planet/lib:${LD_LIBRARY_PATH:-}"
export SINGULARITYENV_TRANSFORMERS_CACHE=$BERT_MODEL_DIR
export SINGULARITYENV_HF_HOME=$BERT_MODEL_DIR
export SINGULARITYENV_TRANSFORMERS_OFFLINE=1

# ── Prep output dirs ────────────────────────────────────────────────────
mkdir -p "$RESULTS_DIR" "$MAP_DIR" "$LOG_DIR"

# ── Debug info (once) ──────────────────────────────────────────────────
MASTER_LOG="$LOG_DIR/PlaNet_MapAE_${NCT_ID}_master.log"
: > "$MASTER_LOG"
echo "[$(date)] Starting AE-mapping job (JSON-only) for $NCT_ID" | tee -a "$MASTER_LOG"
echo "  Input RESULTS_DIR: $RESULTS_DIR" | tee -a "$MASTER_LOG"
echo "  Output MAP_DIR:    $MAP_DIR" | tee -a "$MASTER_LOG"
echo "  Log LOG_DIR:       $LOG_DIR" | tee -a "$MASTER_LOG"
echo "  notebooks/data → $(readlink "$NOTEBOOKS_DIR/data" 2>/dev/null || echo "N/A")" | tee -a "$MASTER_LOG"
echo "  Host models under parsing_package/data/models:" | tee -a "$MASTER_LOG"
ls -1 "$PARSE_DIR/data/models" 2>&1 | tee -a "$MASTER_LOG"

# ── Collect all result_*.json files (from cohort folder) ────────────────
shopt -s nullglob
JSON_FILES=( "$RESULTS_DIR"/result_*.json )

if ((${#JSON_FILES[@]} == 0)); then
  echo "[$(date)] ERROR: No result_*.json files found in $RESULTS_DIR" | tee -a "$MASTER_LOG"
  exit 1
fi

echo "[$(date)] Found ${#JSON_FILES[@]} result_*.json files in $RESULTS_DIR" | tee -a "$MASTER_LOG"

# ── Loop over all JSON files ────────────────────────────────────────────
for JSON_HOST in "${JSON_FILES[@]}"; do
  JSON_FILE_NAME=$(basename "$JSON_HOST")       # e.g. result_trial_data_NCT04809974_CF_Tropisetron.json
  JSON_BASE="${JSON_FILE_NAME%.json}"           # base without .json
  AE_TSV_HOST="$MAP_DIR/${JSON_BASE}_AE_mapped.tsv"

  TRIAL="$JSON_BASE"

  LOG="$LOG_DIR/PlaNet_MapAE_${JSON_BASE}.log"
  : > "$LOG"

  echo "[$(date)] Starting AE mapping for $TRIAL" | tee -a "$LOG"
  echo "[$(date)] JSON file (host): $JSON_HOST" | tee -a "$LOG"
  echo "[$(date)] TSV out (host):   $AE_TSV_HOST" | tee -a "$LOG"

  # ── Run AE mapping inside container ───────────────────────────────────
  singularity exec \
    --bind "$PLANET_ROOT":/planet \
    --bind "$PARSE_DIR/data":/planet/parsing_package/data \
    --bind "$PARSE_DIR/data":/planet/data \
    --bind "$BERT_MODEL_DIR":/bert_model \
    --pwd /planet/parsing_package \
    "$SIF" \
      /opt/conda/envs/planet/bin/python \
        /planet/parsing_package/map_AE.py \
          "/planet/parsing_package/results/$NCT_ID/$JSON_FILE_NAME" \
          --out "/planet/parsing_package/map/$NCT_ID/${JSON_BASE}_AE_mapped.tsv" \
    >> "$LOG" 2>&1

  echo "[$(date)] Finished AE mapping for $TRIAL." | tee -a "$LOG"

  # Sanity check: does the TSV exist on host?
  if [ -f "$AE_TSV_HOST" ]; then
    echo "[$(date)] Confirmed: AE TSV exists at $AE_TSV_HOST" | tee -a "$LOG"
  else
    echo "[$(date)] WARNING: Expected AE TSV not found at $AE_TSV_HOST" | tee -a "$LOG"
  fi

  echo "[$(date)] Current contents of $MAP_DIR (last 20):" | tee -a "$LOG"
  ls -1 "$MAP_DIR" | tail -n 20 | tee -a "$LOG"

  echo "[$(date)] Finished $TRIAL (log: $LOG)" | tee -a "$MASTER_LOG"
done

echo "[$(date)] All JSONs processed for $NCT_ID (AE mapping done)." | tee -a "$MASTER_LOG"
```
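
After the job finishes, the per-file sanity check from the loop can be repeated for the whole cohort in one pass. The following is a minimal sketch, assuming the same naming convention as the script (`<json_basename>_AE_mapped.tsv`); the function `missing_ae_tables` and the demo file names are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import List


def missing_ae_tables(results_dir: Path, map_dir: Path) -> List[str]:
    """List result_*.json files that have no <basename>_AE_mapped.tsv in map_dir."""
    missing = []
    for json_path in sorted(results_dir.glob("result_*.json")):
        expected = map_dir / f"{json_path.stem}_AE_mapped.tsv"
        if not expected.is_file():
            missing.append(json_path.name)
    return missing


# Self-contained demo on a throwaway directory (file names are made up)
with tempfile.TemporaryDirectory() as tmp:
    results_dir = Path(tmp) / "results"
    map_dir = Path(tmp) / "map"
    results_dir.mkdir()
    map_dir.mkdir()
    (results_dir / "result_demo_CF_DrugA.json").write_text("{}")
    (results_dir / "result_demo_CF_DrugB.json").write_text("{}")
    (map_dir / "result_demo_CF_DrugA_AE_mapped.tsv").write_text("")
    demo_missing = missing_ae_tables(results_dir, map_dir)
    print(demo_missing)
```

In real use, `results_dir` and `map_dir` would point at the cohort's `results/<NCT_ID>` and `map/<NCT_ID>` folders.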

### **Score Ranking**

#### **Adverse Events**

##### **Unweighted**

#!/usr/bin/env python3
"""
rank_ae_burden_all_cohorts.py (FIXED)

Compute and rank AE burden for counterfactual drugs across ONE or ALL cohorts.
Now automatically detects which arm is the drug arm based on trial_label.
"""

import os
import glob
import csv
from typing import Dict, List, Optional, Tuple


# -----------------------------------------------------------------------------
# ROOT PATHS (EDIT THESE ONCE)
# -----------------------------------------------------------------------------
MAPPED_ROOT = (
    "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/3_Mapped"
)

RANK_ROOT = (
    "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking"
)

# Labels that indicate a control/placebo arm (case-insensitive)
CONTROL_LABELS = {"placebo", "control", "sham", "standard", "comparator", "soc", "standard of care"}

# Print top N safest per cohort
TOP_PRINT = 20

# If True: run every NCT* folder under MAPPED_ROOT
RUN_ALL_COHORTS = True

# If RUN_ALL_COHORTS is False, set this:
SINGLE_COHORT_ID = "NCT04809974"


def extract_drug_name(filename: str) -> str:
    """
    From:
      result_trial_data_NCT04809974_CF_Maprotiline_AE_mapped.tsv
    return:
      Maprotiline
    """
    base = os.path.splitext(os.path.basename(filename))[0]
    if base.endswith("_AE_mapped"):
        base = base[: -len("_AE_mapped")]
    if "_CF_" in base:
        return base.split("_CF_", 1)[1]
    return base


def get_cohort_ids(mapped_root: str) -> List[str]:
    """Return all cohort folder names like NCT04809974 under mapped_root."""
    if not os.path.isdir(mapped_root):
        return []
    out = []
    for name in sorted(os.listdir(mapped_root)):
        p = os.path.join(mapped_root, name)
        if os.path.isdir(p) and name.startswith("NCT"):
            out.append(name)
    return out


def detect_drug_arm_from_tsv(file_path: str) -> Tuple[Optional[str], Dict[str, str]]:
    """
    Read the TSV and detect which trial_id is the drug arm based on trial_label.
    
    Returns:
        (drug_trial_id, {trial_id: trial_label})
        e.g., ("trial_1", {"trial_1": "active", "trial_2": "control"})
    """
    trial_labels = {}
    
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f, delimiter="\t")
            for row in reader:
                tid = row.get("trial_id", "").strip()
                tlabel = row.get("trial_label", "").strip()
                if tid and tlabel and tid not in trial_labels:
                    trial_labels[tid] = tlabel
                # Stop early once we have both
                if len(trial_labels) >= 2:
                    break
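    except Exception as e:
        # Report unreadable files instead of failing silently
        print(f"[WARN] Could not read arm labels from {os.path.basename(file_path)}: {e}")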
    
    # Determine which is the drug arm
    drug_arm = None
    for tid, label in trial_labels.items():
        if label.lower() not in CONTROL_LABELS:
            drug_arm = tid
            break
    
    # Fallback: if we couldn't identify, prefer trial_2 (legacy behavior)
    if drug_arm is None and trial_labels:
        drug_arm = "trial_2" if "trial_2" in trial_labels else list(trial_labels.keys())[0]
    
    return drug_arm, trial_labels


def rank_one_cohort(cohort_id: str, mapped_root: str, rank_root: str) -> Optional[str]:
    """
    Rank AE burden for a single cohort.
    Returns output CSV path if successful, else None.
    """
    input_folder = os.path.join(mapped_root, cohort_id)
    if not os.path.isdir(input_folder):
        print(f"[ERROR] Cohort folder not found: {input_folder}")
        return None

    pattern = os.path.join(
        input_folder,
        f"result_trial_data_{cohort_id}_CF_*_AE_mapped.tsv"
    )
    files = sorted(glob.glob(pattern))

    if not files:
        print(f"[WARN] No mapped AE TSV files found for {cohort_id}")
        print(f"       Pattern: {pattern}")
        return None

    results: List[Dict[str, object]] = []
    arm_info_logged = False

    for path in files:
        # Detect drug arm for this file
        drug_trial_id, trial_labels = detect_drug_arm_from_tsv(path)
        
        if drug_trial_id is None:
            print(f"[WARN] Could not detect drug arm in {os.path.basename(path)}, skipping.")
            continue
        
        # Log arm detection once per cohort
        if not arm_info_logged:
            print(f"\n[INFO] {cohort_id}: Detected arms: {trial_labels}")
            print(f"[INFO] {cohort_id}: Using '{drug_trial_id}' as drug arm (label: '{trial_labels.get(drug_trial_id, 'unknown')}')")
            arm_info_logged = True

        ae_sum = 0.0
        n_events = 0

        try:
            with open(path, "r", encoding="utf-8") as f:
                reader = csv.DictReader(f, delimiter="\t")
                for row in reader:
                    if row.get("trial_id") != drug_trial_id:
                        continue

                    prob_str = (row.get("probability") or "").strip()
                    if prob_str == "":
                        continue

                    try:
                        p = float(prob_str)
                    except ValueError:
                        continue

                    ae_sum += p
                    n_events += 1
        except Exception as e:
            print(f"[WARN] Could not process {os.path.basename(path)}: {e}")
            continue

        if n_events == 0:
            print(f"[WARN] No {drug_trial_id} AE rows in {os.path.basename(path)}, skipping.")
            continue

        drug_name = extract_drug_name(path)
        results.append(
            {
                "drug": drug_name,
                "file": os.path.basename(path),
                "ae_score": float(ae_sum),
                "n_events": int(n_events),
                "trial_id_used": drug_trial_id,
            }
        )

    if not results:
        print(f"[WARN] No valid AE results found for cohort {cohort_id}.")
        return None

    # Rank from lowest AE burden (safer) to highest
    results.sort(key=lambda x: float(x["ae_score"]))

    # Print top N safest
    top_n = min(TOP_PRINT, len(results))
    print(f"\n[{cohort_id}] Top {top_n} drugs by AE profile (lowest total AE burden):\n")
    print(f"{'Rank':<5} {'Drug':<30} {'AE_score (sum prob)':>22} {'n_events':>10}")
    print("-" * 75)
    for i, r in enumerate(results[:top_n], start=1):
        print(
            f"{i:<5} {str(r['drug']):<30} "
            f"{float(r['ae_score']):>22.6f} {int(r['n_events']):>10d}"
        )

    # Output INSIDE cohort subfolder
    output_folder = os.path.join(rank_root, cohort_id)
    os.makedirs(output_folder, exist_ok=True)

    out_csv = os.path.join(
        output_folder,
        f"{cohort_id}_ranked_drugs_by_AE_score.csv"
    )

    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(
            [
                "rank_by_AE_burden",
                "cohort_id",
                "drug",
                "file",
                "ae_score_sum_probability",
                "n_AE_events",
                "trial_id_used",
            ]
        )
        for i, r in enumerate(results, start=1):
            writer.writerow(
                [
                    i,
                    cohort_id,
                    r["drug"],
                    r["file"],
                    f"{float(r['ae_score']):.6f}",
                    int(r["n_events"]),
                    r["trial_id_used"],
                ]
            )

    print(f"\n[{cohort_id}] Full AE ranking saved to:\n{out_csv}")
    return out_csv


def main(cohorts: Optional[List[str]] = None, run_all: bool = True, single_cohort: str = "NCT04809974"):
    """
    Main entry point.
    
    Args:
        cohorts: List of cohort IDs to process. If None, uses auto-discovery or single_cohort.
        run_all: If True and cohorts is None, process all NCT* folders.
        single_cohort: If run_all is False and cohorts is None, process this cohort only.
    """
    if cohorts is not None:
        cohort_ids = cohorts
    elif run_all:
        cohort_ids = get_cohort_ids(MAPPED_ROOT)
        if not cohort_ids:
            print(f"[ERROR] No NCT* cohort folders found under:\n{MAPPED_ROOT}")
            return []
        print(f"[INFO] Found {len(cohort_ids)} cohort folders under 3_Mapped.")
    else:
        cohort_ids = [single_cohort]

    outputs = []
    for cid in cohort_ids:
        out = rank_one_cohort(cid, MAPPED_ROOT, RANK_ROOT)
        if out:
            outputs.append(out)

    print("\n[INFO] Completed.")
    print(f"  Cohorts processed: {len(cohort_ids)}")
    print(f"  Output files written: {len(outputs)}")
    return outputs


if __name__ == "__main__":
    main(run_all=RUN_ALL_COHORTS, single_cohort=SINGLE_COHORT_ID)

[INFO] Found 2 cohort folders under 3_Mapped.

[INFO] NCT04809974: Detected arms: {'trial_1': 'placebo', 'trial_2': 'atp'}
[INFO] NCT04809974: Using 'trial_2' as drug arm (label: 'atp')

[NCT04809974] Top 20 drugs by AE profile (lowest total AE burden):

Rank  Drug                              AE_score (sum prob)   n_events
---------------------------------------------------------------------------
1     Chlorhexidine                                1.015897        100
2     Cetylpyridinium                              1.037758        100
3     Articaine                                    1.077362        100
4     Benzocaine                                   1.294916        100
5     Glycerin                                     1.310743        100
6     Rocuronium                                   1.343040        100
7     Povidone                                     1.415288        100
8     Lactic_acid                                  1.437760        100
9     Mupirocin               
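
The unweighted score is the plain sum of predicted AE probabilities over the rows of the detected drug arm. The sketch below reproduces that logic on an in-memory TSV, assuming the column layout written by `map_AE.py`; the probabilities, the shortened control-label set, and the function name `unweighted_ae_score` are illustrative only. Arm labels are compared by exact match, so a label such as "placebo arm" would not count as a control.

```python
import csv
import io
from typing import Optional, Tuple

CONTROL_LABELS = {"placebo", "control", "sham"}  # shortened for the example

# Made-up rows in the mapped-TSV layout
TSV_TEXT = (
    "trial_id\ttrial_label\tae_code\tae_kg_id\tae_label\tprobability\n"
    "trial_1\tplacebo\t0\tKG00099268\tHospitalisation\t0.5\n"
    "trial_2\tatp\t0\tKG00099268\tHospitalisation\t0.125\n"
    "trial_2\tatp\t3\tKG00099365\tOrthostatic Hypotension\t0.0625\n"
)


def unweighted_ae_score(tsv_text: str) -> Tuple[Optional[str], float]:
    """Return (drug_arm_id, sum of probabilities over that arm's rows)."""
    rows = list(csv.DictReader(io.StringIO(tsv_text), delimiter="\t"))

    # First label seen per trial_id
    labels = {}
    for r in rows:
        labels.setdefault(r["trial_id"], r["trial_label"])

    # Drug arm = first arm whose label is not an exact control label
    drug_arm = next(
        (tid for tid, lab in labels.items() if lab.lower() not in CONTROL_LABELS),
        None,
    )

    total = sum(float(r["probability"]) for r in rows if r["trial_id"] == drug_arm)
    return drug_arm, total


arm, score = unweighted_ae_score(TSV_TEXT)
print(arm, score)
```

Because every AE counts equally here, a drug with many low-probability mild events can rank worse than one with a single likely severe event, which is the motivation for a weighted variant.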

##### **Weighted**

#!/usr/bin/env python3
"""
compute_ae_weighted_scores_all_cohorts.py (FIXED)

Compute AE burden / risk score per arm for one or many cohorts.
Now includes drug_arm detection for downstream filtering.
"""

from pathlib import Path
from typing import List, Optional, Dict, Tuple
import pandas as pd

# =========================
# CONFIG (EDIT THIS ROOT)
# =========================
PLANET_ROOT = Path(
    "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet"
)

MAPPED_ROOT = PLANET_ROOT / "3_Mapped"
SCORE_ROOT = PLANET_ROOT / "4_Score_Ranking"

AE_MASTER_TSV = MAPPED_ROOT / "AE_master_all_with_severity_LC.tsv"

COHORTS_TO_RUN: List[str] = []  # empty => auto-discover
DEFAULT_SEVERITY_WEIGHT = 0.50
DEFAULT_LC_WEIGHT = 1.00

# Labels that indicate a control/placebo arm (case-insensitive)
CONTROL_LABELS = {"placebo", "control", "sham", "standard", "comparator", "soc", "standard of care"}

PRINT_HEAD_N = 10


def infer_cohorts() -> List[str]:
    if not MAPPED_ROOT.is_dir():
        return []
    return sorted([p.name for p in MAPPED_ROOT.glob("NCT*") if p.is_dir()])


def load_ae_master(path: Path) -> pd.DataFrame:
    if not path.is_file():
        raise FileNotFoundError(f"AE master TSV not found: {path}")

    ae = pd.read_csv(path, sep="\t", dtype=str)
    required = ["ae_kg_id", "severity_weight", "lc_weight"]
    missing = [c for c in required if c not in ae.columns]
    if missing:
        raise ValueError(f"AE master table missing columns: {missing}")

    ae["severity_weight"] = pd.to_numeric(ae["severity_weight"], errors="coerce").fillna(DEFAULT_SEVERITY_WEIGHT)
    ae["lc_weight"] = pd.to_numeric(ae["lc_weight"], errors="coerce").fillna(DEFAULT_LC_WEIGHT)
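    # Keep one row per KG ID so the left merge cannot duplicate AE rows
    return ae[["ae_kg_id", "severity_weight", "lc_weight"]].drop_duplicates(subset="ae_kg_id")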


def load_mapped_files(mapped_dir: Path) -> List[pd.DataFrame]:
    if not mapped_dir.is_dir():
        raise NotADirectoryError(f"Mapped cohort folder not found: {mapped_dir}")

    files = sorted(mapped_dir.glob("*_AE_mapped.tsv"))
    if not files:
        raise FileNotFoundError(f"No *_AE_mapped.tsv files found in: {mapped_dir}")

    dfs: List[pd.DataFrame] = []
    for f in files:
        df = pd.read_csv(f, sep="\t", dtype=str)
        df["source_file"] = f.name
        dfs.append(df)
    return dfs


def detect_drug_arm(df: pd.DataFrame) -> Tuple[Optional[str], Dict[str, str]]:
    """
    Detect which trial_id is the drug arm based on trial_label.
    
    Returns:
        (drug_trial_id, {trial_id: trial_label})
    """
    trial_labels = {}
    for _, row in df[["trial_id", "trial_label"]].drop_duplicates().iterrows():
        tid = str(row.get("trial_id", "")).strip()
        tlabel = str(row.get("trial_label", "")).strip()
        if tid and tlabel:
            trial_labels[tid] = tlabel
    
    # Find the drug arm (not a control label)
    drug_arm = None
    for tid, label in trial_labels.items():
        if label.lower() not in CONTROL_LABELS:
            drug_arm = tid
            break
    
    # Fallback
    if drug_arm is None and trial_labels:
        drug_arm = "trial_2" if "trial_2" in trial_labels else list(trial_labels.keys())[0]
    
    return drug_arm, trial_labels


def merge_with_master(ae_df: pd.DataFrame, master: pd.DataFrame) -> pd.DataFrame:
    if "ae_kg_id" not in ae_df.columns:
        raise ValueError("Expected column 'ae_kg_id' in AE-mapped TSV.")

    merged = ae_df.merge(master, on="ae_kg_id", how="left")

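    # A missing 'probability' column becomes 0.0 for every row. DataFrame.get with a
    # scalar default would return a float, and the .fillna() call would then fail.
    if "probability" not in merged.columns:
        merged["probability"] = 0.0
    merged["probability"] = pd.to_numeric(merged["probability"], errors="coerce").fillna(0.0)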
    merged["severity_weight"] = pd.to_numeric(merged["severity_weight"], errors="coerce").fillna(DEFAULT_SEVERITY_WEIGHT)
    merged["lc_weight"] = pd.to_numeric(merged["lc_weight"], errors="coerce").fillna(DEFAULT_LC_WEIGHT)

    merged["ae_weighted_prob"] = merged["probability"] * merged["severity_weight"] * merged["lc_weight"]
    return merged


def compute_ae_scores(merged_all: pd.DataFrame) -> pd.DataFrame:
    group_cols = ["source_file", "trial_id", "trial_label"]
    out = (
        merged_all.groupby(group_cols, as_index=False)
        .agg(
            n_ae=("ae_code", "nunique"),
            total_prob=("probability", "sum"),
            total_weighted_prob=("ae_weighted_prob", "sum"),
        )
    )
    out["ae_score"] = out["total_weighted_prob"]
    
    # Add is_drug_arm flag
    out["is_drug_arm"] = out["trial_label"].str.lower().apply(lambda x: x not in CONTROL_LABELS)
    
    return out


def run_one_cohort(cohort: str) -> Optional[Path]:
    mapped_dir = MAPPED_ROOT / cohort
    out_dir = SCORE_ROOT / cohort
    out_dir.mkdir(parents=True, exist_ok=True)

    out_tsv = out_dir / f"{cohort}_AE_weighted_scores.tsv"

    try:
        ae_master = load_ae_master(AE_MASTER_TSV)
        mapped_dfs = load_mapped_files(mapped_dir)
    except Exception as e:
        print(f"[WARN] {cohort}: {e}")
        return None

    merged_all = pd.concat([merge_with_master(df, ae_master) for df in mapped_dfs], ignore_index=True)

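    # Detect and log arm structure. merged_all concatenates every mapped file, so
    # each trial_id keeps only the last label seen (for example, the last
    # counterfactual drug); per-file labels are in the score table written below.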
    drug_arm, trial_labels = detect_drug_arm(merged_all)
    print(f"\n[INFO] {cohort}: Detected arms: {trial_labels}")
    print(f"[INFO] {cohort}: Drug arm = '{drug_arm}' (label: '{trial_labels.get(drug_arm, 'unknown')}')")

    scores = compute_ae_scores(merged_all)

    # Save
    scores.to_csv(out_tsv, sep="\t", index=False)
    print(f"[INFO] {cohort}: Wrote AE weighted scores to:\n       {out_tsv}")

    # Print first lines
    print(f"[PREVIEW] {cohort}: First {min(PRINT_HEAD_N, len(scores))} rows:")
    print(scores.head(PRINT_HEAD_N).to_string(index=False))

    return out_tsv


def main(cohorts: Optional[List[str]] = None):
    """
    Main entry point.
    
    Args:
        cohorts: List of cohort IDs to process. If None, uses auto-discovery.
    """
    if cohorts is None:
        cohorts = COHORTS_TO_RUN if COHORTS_TO_RUN else infer_cohorts()
    
    if not cohorts:
        print(f"[ERROR] No cohorts found in: {MAPPED_ROOT}")
        return []

    print(f"[INFO] Cohorts to process: {len(cohorts)}")
    outputs = []
    for c in cohorts:
        out = run_one_cohort(c)
        if out:
            outputs.append(out)
    
    return outputs


if __name__ == "__main__":
    main()

[INFO] Cohorts to process: 2

[INFO] NCT04809974: Detected arms: {'trial_1': 'placebo', 'trial_2': 'beta-sitosterol'}
[INFO] NCT04809974: Drug arm = 'trial_2' (label: 'beta-sitosterol')
[INFO] NCT04809974: Wrote AE weighted scores to:
       /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_AE_weighted_scores.tsv
[PREVIEW] NCT04809974: First 10 rows:
                                               source_file trial_id trial_label  n_ae  total_prob  total_weighted_prob  ae_score  is_drug_arm
               result_trial_data_NCT04809974_AE_mapped.tsv  trial_1     placebo   100    5.621217             2.534269  2.534269        False
               result_trial_data_NCT04809974_AE_mapped.tsv  trial_2      niagen   100    5.621217             2.534269  2.534269         True
        result_trial_data_NCT04809974_CF_ATP_AE_mapped.tsv  trial_1     placebo   100    2.459222             1.092514  1.0
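
The weighted score multiplies each predicted probability by its severity weight and its Long COVID relevance weight, then sums over the arm. The sketch below is a worked example in plain Python rather than pandas, assuming two master-table entries (severe/general and moderate/core_LC) and made-up probabilities; `AE_WEIGHTS` and `weighted_ae_score` are hypothetical names. For simplicity, an AE missing from the table gets both default weights at once.

```python
from typing import Dict, List, Tuple

DEFAULT_SEVERITY_WEIGHT = 0.50
DEFAULT_LC_WEIGHT = 1.00

# (severity_weight, lc_weight) per KG ID
AE_WEIGHTS: Dict[str, Tuple[float, float]] = {
    "KG00099268": (0.75, 1.0),  # Hospitalisation: severe, general
    "KG00099365": (0.50, 1.5),  # Orthostatic Hypotension: moderate, core_LC
}


def weighted_ae_score(rows: List[Tuple[str, float]]) -> float:
    """Sum probability * severity_weight * lc_weight over (kg_id, probability) rows."""
    total = 0.0
    for kg_id, prob in rows:
        sev_w, lc_w = AE_WEIGHTS.get(
            kg_id, (DEFAULT_SEVERITY_WEIGHT, DEFAULT_LC_WEIGHT)
        )
        total += prob * sev_w * lc_w
    return total


# Illustrative probabilities for one arm
rows = [
    ("KG00099268", 0.5),         # 0.5   * 0.75 * 1.0 = 0.375
    ("KG00099365", 0.25),        # 0.25  * 0.50 * 1.5 = 0.1875
    ("KG_NOT_IN_MASTER", 0.125), # 0.125 * 0.50 * 1.0 = 0.0625 (defaults)
]
score = weighted_ae_score(rows)
print(score)
```

The unweighted sum for the same rows is 0.875, so the weighting lowers the total here while raising the relative contribution of the Long COVID-relevant event.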

##### **Master Table**

#!/usr/bin/env python3
"""
Auto-annotate AE master table with:
- CTCAE-inspired severity_category + severity_weight
- Long COVID relevance category (lc_relevance_category) + lc_weight

This is heuristic / rule-based, designed to be:
- transparent
- reproducible
- easy to tweak
"""

import pandas as pd
import re
from pathlib import Path

# =========  CONFIG  ========= #

INPUT_TSV = "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/3_Mapped/AE_master_all.tsv"   # our current master
OUTPUT_TSV = "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/3_Mapped/AE_master_all_with_severity_LC.tsv"

# Severity weights (CTCAE-inspired)
SEVERITY_WEIGHT_MAP = {
    "mild": 0.25,
    "moderate": 0.50,
    "severe": 0.75,
    "life_threatening": 1.00,
    "death": 1.20,
}

# Long COVID relevance weights
LC_WEIGHT_MAP = {
    "core_LC": 1.50,
    "important_comorbidity": 1.25,
    "general": 1.00,
}


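def has(label: str, keywords):
    """
    Case-insensitive keyword check.

    Keywords of up to 3 characters (e.g. "alt", "ast", "dvt") must match as whole
    words, so "ast" does not fire on "asthenia"; longer keywords match as substrings.
    """
    text = label.lower()
    for k in keywords:
        if len(k) <= 3:
            if re.search(rf"\b{re.escape(k)}\b", text):
                return True
        elif k in text:
            return True
    return False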


def classify_severity(label: str) -> str:
    """
    Map AE label to a severity category using CTCAE-inspired logic.
    You can refine keyword lists as you review the output.
    """
    text = label.lower()

    # ---- Death / cardiac arrest / brain death etc. ----
    death_kw = [
        "death", "sudden death", "cardiorespiratory arrest",
        "cardiac arrest", "brain stem infarction", "unattended death",
    ]
    if has(text, death_kw):
        return "death"

    # ---- Life-threatening (Grade 4-like) ----
    life_threatening_kw = [
        "shock", "septicemia", "sepsis", "septic shock",
        "respiratory arrest", "respiratory insufficiency", "respiratory distress",
        "arrest", "cardiogenic shock", "pulmonary embolism", "embolism, pulmonary",
        "acute pulmonary oedema", "ventricular fibrillation", "ventricular arrhythmia",
        "acute heart failure", "acute right ventricular failure",
        "liver failure", "acute liver failure", "failure, acute liver",
        "encephalopathy", "status epilepticus", "myocardial infarction",
        "infarction, myocardial", "cerebral hemorrhage", "brain haemorrhage",
        "intracranial haemorrhage", "subarachnoid haemorrhage",
        "cardiopulmonary insufficiency",
    ]
    if has(text, life_threatening_kw):
        return "life_threatening"

    # ---- Severe (Grade 3-like) ----
    severe_kw = [
        "hospitalisation", "hospitalization", "fracture", "fractured",
        "necrosis", "gangrene", "perforation", "obstruction", "ileus",
        "pneumonia", "aspiration pneumonia", "staphylococcal pneumonia",
        "acute kidney failure", "renal failure", "kidney failure",
        "chronic kidney failure", "pulmonary oedema",
        "stroke", "infarction", "attack, transient ischemic",
        "tamponade", "thrombosis", "embolism", "myocarditis",
        "pulmonary tuberculosis", "tuberculosis", "neutropenia, febrile",
        "neutropenic sepsis", "lymphopenias",
        "malignant", "carcinoma", "lymphoma", "leukemia",
        "gastric ulcer, acute with haemorrhage",
        "haemorrhage", "hemorrhage", "gi bleeding", "gastrointestinal haemorrhage",
        "large intestine perforation", "ulcer, leg", "skin ulcer infected",
        "osteomyelitis", "necrotizing fasciitis",
        "acute hepatitis", "cholangitis", "pancreatitis",
        "dvt", "deep vein thrombosis", "venous thrombosis", "venous thromboses",
    ]
    if has(text, severe_kw):
        return "severe"

    # ---- Moderate (Grade 2-like) ----
    moderate_kw = [
        "severe",  # sometimes appears as 'severe headache etc'
        "fractures", "pneumonia", "bronchitis", "bronchiectasis",
        "syncope", "collapse", "tachycardia", "paroxysmal tachycardia",
        "atrial fibrillation", "supraventricular tachycardia",
        "heart failure", "left ventricular dysfunction",
        "dyspnoea", "dyspnea", "shortness of breath",
        "orthostatic hypotension",
        "hyperglycaemia", "hypoglycaemia", "hypernatraemia",
        "hypokalaemia", "hyperkalaemia", "hypercalcaemia", "hypocalcaemia",
        "diabetes", "ketoacidosis",
        "transaminase", "alt", "ast", "gamma-glutamyltransferase",
        "elevated liver enzymes",
        "encephalitis", "epilepsy", "partial epilepsies",
        "myocardial ischemia", "ischemia, myocardial",
        "thrombocytopaenia", "neutropenia", "leukopenia",
        "cachexia", "malnutrition", "protein-energy",
        "vision, blurring", "retinal detachment",
        "urinary tract infection", "pyelonephritis",
        "peritonitis", "cellulitis", "empyema",
        "pulmonary infiltrates", "interstitial pneumonia",
        "psychotic disorders", "suicidal", "suicide",
        "major depression", "bipolar disorder",
        "myositis", "myalgias"  # moderate not mild in chronic LC context
    ]
    if has(text, moderate_kw):
        return "moderate"

    # ---- Default: Mild (Grade 1-like) ----
    return "mild"


def classify_lc_relevance(label: str) -> str:
    """
    Classify LC relevance:
    - core_LC
    - important_comorbidity
    - general
    """
    text = label.lower()

    # ---- Core LC: cardiopulmonary, vascular, neurocognitive, autonomic, fatigue ----
    core_lc_kw = [
        # vascular / thromboembolic / stroke
        "thrombosis", "embolism", "pulmonary embolism", "embolism, pulmonary",
        "stroke", "infarction", "cerebral infarction", "brain infarction",
        "ischemic", "ischaemic", "attack, transient ischemic",
        "venous thrombosis", "deep vein thrombosis",
        # cardiac / dysautonomia / arrhythmias
        "myocardial infarction", "angina pectoris", "heart failure",
        "cardiomyopathy", "arrhythmia", "tachycardia", "fibrillation",
        "orthostatic hypotension", "syncope", "postural dizziness",
        # pulmonary / respiratory
        "bronchopneumonia", "pneumonia", "interstitial pneumonia",
        "respiratory insufficiency", "respiratory distress",
        "respiratory arrest", "pulmonary congestion",
        "bronchospasm", "dyspnoea", "dyspnea", "shortness of breath",
        # neurocognitive / fatigue
        "confusion", "disorders, cognition", "memory", "impairments memory",
        "encephalopathy", "brain damage", "brain disease",
        "headache", "headache nos", "migraine",
        "asthenia", "fatigue", "lethargy",
        # autonomic / pots-like
        "palpitations", "postural", "orthostatic", "tachycardia, sinus",
        # thromboinflammatory / autoimmune flare
        "vasculitis", "arteritis", "thrombophlebitis",
        # key LC organ damage / sequelae
        "fibrosis, pulmonary", "bronchiectasis", "cardiopulmonary insufficiency",
    ]
    if has(text, core_lc_kw):
        return "core_LC"

    # ---- Important comorbidities: metabolic, renal, hepatic, malignancy ----
    important_comorb_kw = [
        # metabolic / diabetes / obesity / lipids
        "diabetes", "diabetic", "obesity", "dyslipidaemia",
        "hypercholesterolaemia", "hyperlipidaemia", "metabolic",
        # renal
        "kidney failure", "acute kidney failure", "chronic kidney failure",
        "renal failure", "hydronephrosis", "nephrosis",
        # liver
        "hepatitis", "cirrhosis", "liver failure", "liver disease",
        "transaminase", "alanine aminotransferase", "aspartate aminotransferase",
        "gamma-glutamyltransferase", "bilirubin",
        # malignancy / major cancer
        "malignant", "carcinoma", "lymphoma", "leukemia", "neoplasm",
        # serious cardiovascular risk factors
        "hypertensive", "hypertension", "coronary heart disease",
        "ischemic heart disease", "cardiomyopathy",
        # serious infections
        "tuberculosis", "mycobacterium", "opportunistic infection",
        "pneumocystis", "septicemia", "sepsis",
    ]
    if has(text, important_comorb_kw):
        return "important_comorbidity"

    # ---- Default: general ----
    return "general"


def annotate_table(df: pd.DataFrame) -> pd.DataFrame:
    """Apply classification functions and add weight columns."""
    # Ensure required columns exist
    assert "ae_label" in df.columns, "Expected column 'ae_label' not found."

    # Fill NaN labels with empty strings to avoid errors
    df["ae_label"] = df["ae_label"].fillna("")

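    # Severity category: keep curated values, classify only where missing/blank.
    # (Avoids Series.replace("", None), whose behaviour differs across pandas
    # versions and can forward-fill instead of inserting None.)
    if "severity_category" not in df.columns:
        df["severity_category"] = ""
    sev = df["severity_category"].fillna("").astype(str).str.strip()
    df["severity_category"] = sev.where(sev != "", df["ae_label"].map(classify_severity))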

    # Severity weight
    df["severity_weight"] = df["severity_category"].map(SEVERITY_WEIGHT_MAP)

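    # LC relevance category: keep curated values, classify only where missing/blank.
    # (Avoids Series.replace("", None), whose behaviour differs across pandas
    # versions and can forward-fill instead of inserting None.)
    if "lc_relevance_category" not in df.columns:
        df["lc_relevance_category"] = ""
    lc = df["lc_relevance_category"].fillna("").astype(str).str.strip()
    df["lc_relevance_category"] = lc.where(lc != "", df["ae_label"].map(classify_lc_relevance))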

    # LC weight
    df["lc_weight"] = df["lc_relevance_category"].map(LC_WEIGHT_MAP)

    return df


def main():
    in_path = Path(INPUT_TSV)
    if not in_path.is_file():
        raise FileNotFoundError(f"Input TSV not found: {in_path}")

    df = pd.read_csv(in_path, sep="\t", dtype=str)

    # Make sure numeric columns that should stay numeric later can be cast,
    # but we keep everything as str during annotation to avoid surprises.
    df = annotate_table(df)

    # 🔎 Print first lines as a quick preview
    print("[INFO] Preview of first 10 rows after annotation:")
    print(df.head(10).to_string(index=False))

    out_path = Path(OUTPUT_TSV)
    df.to_csv(out_path, sep="\t", index=False)
    print(f"[INFO] Wrote annotated AE master table to: {out_path}")


if __name__ == "__main__":
    main()

[INFO] Preview of first 10 rows after annotation:
ae_index   ae_kg_id                ae_label severity_category  severity_weight lc_relevance_category  lc_weight
       0 KG00099268         Hospitalisation            severe             0.75               general        1.0
       1 KG00099959         Neutropenia NOS          moderate             0.50               general        1.0
       2 KG00100772       thrombocytopaenia          moderate             0.50               general        1.0
       3 KG00099365 Orthostatic Hypotension          moderate             0.50               core_LC        1.5
       4 KG00098817               epistaxis              mild             0.25               general        1.0
       5 KG00099169              haematuria              mild             0.25               general        1.0
       6 KG00108262       Brain haemorrhage  life_threatening             1.00               general        1.0
       7 KG00097787    Aggressive Behaviour           
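
The annotation step above keeps a curated category when one exists, falls back to a keyword classifier otherwise, and then maps the category to a numeric weight. The following is a minimal pure-Python sketch of that flow. The weights are the ones visible in the preview; the keyword lists, the `"moderate"` default, and the name `resolve_severity` are hypothetical stand-ins, not the notebook's actual `classify_severity`.

```python
from typing import Optional

# Weights as they appear in the preview output above.
SEVERITY_WEIGHTS = {"mild": 0.25, "moderate": 0.50, "severe": 0.75, "life_threatening": 1.00}

# Hypothetical toy keyword lists, for illustration only.
TOY_KEYWORDS = {
    "life_threatening": ["haemorrhage"],
    "severe": ["hospitalisation"],
    "mild": ["epistaxis"],
}


def resolve_severity(curated: Optional[str], ae_label: str) -> str:
    """Keep a curated category when present; otherwise classify by keyword."""
    if curated is not None and curated.strip():
        return curated.strip()
    text = ae_label.lower()
    for category, keywords in TOY_KEYWORDS.items():
        if any(k in text for k in keywords):
            return category
    return "moderate"  # hypothetical default when no keyword matches


rows = [
    ("Brain haemorrhage", None),      # missing category -> classified by keyword
    ("epistaxis", ""),                # blank category   -> classified by keyword
    ("Neutropenia NOS", "moderate"),  # curated category -> kept as-is
]
weights = [SEVERITY_WEIGHTS[resolve_severity(curated, label)] for label, curated in rows]
print(weights)
```

Treating `None` and blank strings the same way in one place avoids relying on how a particular pandas version handles replacing `""` with `None`.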

In [4]:
#!/usr/bin/env python3
"""
compute_ae_weighted_scores_all_cohorts.py

Compute an AE burden / risk score per arm (placebo & drug) for one or more cohorts,
using:
- PlaNet AE-mapped TSVs (per drug / cohort): *_AE_mapped.tsv
- AE master table with severity + Long COVID (LC) weights

For each AE row:
    weighted_prob = probability * severity_weight * lc_weight

For each arm (trial_id + trial_label) within each source_file:
    AE_score = sum(weighted_prob)

Higher AE_score = worse AE profile (more / more severe / more LC-relevant AEs).

This version supports:
- Processing ALL cohort folders under a base directory (e.g., 3_Mapped/)
  where cohorts are named like NCT04809974, NCT04880161, NCT05576662, ...
- Writing one output file per cohort INSIDE that cohort’s folder
  OR to a separate scoring folder (configurable)
- Option A preview: prints the first lines (head) of the computed scores per cohort
"""

from __future__ import annotations

import sys
from pathlib import Path
from typing import List, Optional

import pandas as pd

# =============================================================================
# CONFIG (EDIT THESE)
# =============================================================================

MAPPED_BASE_DIR = Path(
    "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/3_Mapped"
)

# AE master table with severity + LC annotations (single shared file)
AE_MASTER_TSV = Path(
    "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/3_Mapped/AE_master_all_with_severity_LC.tsv"
)

# Output mode:
# - If True: write outputs inside each cohort folder: 3_Mapped/<COHORT>/<COHORT>_AE_weighted_scores.tsv
# - If False: write outputs to SCORE_RANKING_DIR as: <SCORE_RANKING_DIR>/<COHORT>_AE_weighted_scores.tsv
#   (as configured below, SCORE_RANKING_DIR sits under 3_Mapped/)
WRITE_OUTPUT_INSIDE_COHORT = True

# Used only when WRITE_OUTPUT_INSIDE_COHORT = False
SCORE_RANKING_DIR = Path(
    "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/3_Mapped/4_Score_Ranking"
)

# If you want to run only a subset, list them here; otherwise leave empty to auto-discover all NCT* folders
COHORTS_TO_RUN: List[str] = []  # e.g. ["NCT04809974", "NCT05576662"]

# Defaults if any AE is missing weights
DEFAULT_SEVERITY_WEIGHT = 0.50
DEFAULT_LC_WEIGHT = 1.00

# Print preview (Option A): number of rows to print per cohort
PREVIEW_N = 10

# Also print top-N worst arms by AE_score per cohort (optional)
PRINT_TOP_WORST_N = 20


# =============================================================================
# HELPERS
# =============================================================================

def discover_cohorts(base_dir: Path) -> List[Path]:
    """Return cohort directories under base_dir that look like NCT* and are directories."""
    if not base_dir.is_dir():
        raise NotADirectoryError(f"MAPPED_BASE_DIR not found: {base_dir}")

    dirs = sorted([p for p in base_dir.iterdir() if p.is_dir() and p.name.upper().startswith("NCT")])
    return dirs


def load_ae_master(path: Path) -> pd.DataFrame:
    """
    Load AE master table with severity + LC weights.

    Required columns:
        ae_kg_id,
        severity_weight,
        lc_weight
    """
    if not path.is_file():
        raise FileNotFoundError(f"AE master TSV not found: {path}")

    ae = pd.read_csv(path, sep="\t", dtype=str)

    required = ["ae_kg_id", "severity_weight", "lc_weight"]
    missing = [c for c in required if c not in ae.columns]
    if missing:
        raise ValueError(f"AE master table missing columns: {missing}")

    # cast weights
    ae["severity_weight"] = pd.to_numeric(ae["severity_weight"], errors="coerce").fillna(DEFAULT_SEVERITY_WEIGHT)
    ae["lc_weight"] = pd.to_numeric(ae["lc_weight"], errors="coerce").fillna(DEFAULT_LC_WEIGHT)

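    # keep only needed columns, one row per ae_kg_id, so the later left merge
    # cannot duplicate AE rows and inflate the summed scores
    ae = ae[["ae_kg_id", "severity_weight", "lc_weight"]].drop_duplicates(subset="ae_kg_id", keep="first").copy()
    return ae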


def load_ae_mapped_files(mapped_dir: Path) -> pd.DataFrame:
    """
    Load all *_AE_mapped.tsv in a cohort directory and concatenate into one df
    with a 'source_file' column.
    """
    if not mapped_dir.is_dir():
        raise NotADirectoryError(f"Cohort folder not found: {mapped_dir}")

    files = sorted(mapped_dir.glob("*_AE_mapped.tsv"))
    if not files:
        raise FileNotFoundError(f"No *_AE_mapped.tsv files found in {mapped_dir}")

    dfs = []
    for f in files:
        df = pd.read_csv(f, sep="\t", dtype=str)
        df["source_file"] = f.name
        dfs.append(df)

    all_df = pd.concat(dfs, ignore_index=True)
    return all_df


def merge_with_master(ae_df: pd.DataFrame, master: pd.DataFrame) -> pd.DataFrame:
    """
    Merge AE-mapped table with master severity/LC weights on ae_kg_id.
    Adds:
      - severity_weight
      - lc_weight
      - ae_weighted_prob
    """
    if "ae_kg_id" not in ae_df.columns:
        raise ValueError("Expected column 'ae_kg_id' in AE-mapped TSV.")
    if "probability" not in ae_df.columns:
        raise ValueError("Expected column 'probability' in AE-mapped TSV.")
    if "trial_id" not in ae_df.columns or "trial_label" not in ae_df.columns:
        raise ValueError("Expected columns 'trial_id' and 'trial_label' in AE-mapped TSV.")
    if "ae_code" not in ae_df.columns:
        raise ValueError("Expected column 'ae_code' in AE-mapped TSV.")
    if "source_file" not in ae_df.columns:
        raise ValueError("Missing 'source_file' column (internal).")

    merged = ae_df.merge(master, on="ae_kg_id", how="left")

    merged["probability"] = pd.to_numeric(merged["probability"], errors="coerce").fillna(0.0)
    merged["severity_weight"] = pd.to_numeric(merged["severity_weight"], errors="coerce").fillna(DEFAULT_SEVERITY_WEIGHT)
    merged["lc_weight"] = pd.to_numeric(merged["lc_weight"], errors="coerce").fillna(DEFAULT_LC_WEIGHT)

    merged["ae_weighted_prob"] = merged["probability"] * merged["severity_weight"] * merged["lc_weight"]
    return merged


def compute_ae_scores(merged_all: pd.DataFrame) -> pd.DataFrame:
    """
    Summarise AE burden per arm (trial_id + trial_label + source_file).

    Outputs columns:
        source_file, trial_id, trial_label,
        n_ae, total_prob, total_weighted_prob, ae_score
    """
    group_cols = ["source_file", "trial_id", "trial_label"]

    grouped = (
        merged_all
        .groupby(group_cols, as_index=False)
        .agg(
            n_ae=("ae_code", "nunique"),
            total_prob=("probability", "sum"),
            total_weighted_prob=("ae_weighted_prob", "sum"),
        )
    )

    grouped["ae_score"] = grouped["total_weighted_prob"]
    return grouped


def output_path_for_cohort(cohort_name: str, cohort_dir: Path) -> Path:
    if WRITE_OUTPUT_INSIDE_COHORT:
        return cohort_dir / f"{cohort_name}_AE_weighted_scores.tsv"
    SCORE_RANKING_DIR.mkdir(parents=True, exist_ok=True)
    return SCORE_RANKING_DIR / f"{cohort_name}_AE_weighted_scores.tsv"


def process_one_cohort(cohort_dir: Path, ae_master: pd.DataFrame) -> Optional[Path]:
    cohort = cohort_dir.name

    try:
        mapped_all = load_ae_mapped_files(cohort_dir)
    except FileNotFoundError as e:
        print(f"[WARN] {cohort}: {e}")
        return None

    merged = merge_with_master(mapped_all, ae_master)
    scores = compute_ae_scores(merged)

    # ---------- OPTION A PREVIEW ----------
    print(f"\n[PREVIEW] {cohort} — first {PREVIEW_N} rows of AE weighted scores:")
    print(scores.head(PREVIEW_N).to_string(index=False))
    print()

    # Save
    out_path = output_path_for_cohort(cohort, cohort_dir)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    scores.to_csv(out_path, sep="\t", index=False)
    print(f"[INFO] {cohort}: wrote AE scores to: {out_path}")

    # Optional: top worst arms
    if PRINT_TOP_WORST_N and PRINT_TOP_WORST_N > 0:
        topn = min(PRINT_TOP_WORST_N, len(scores))
        print(f"\n[INFO] {cohort}: Top {topn} arms by AE_score (higher = worse):")
        print(scores.sort_values("ae_score", ascending=False).head(topn).to_string(index=False))

    return out_path


# =============================================================================
# MAIN
# =============================================================================

def main():
    print("[INFO] Loading AE master table...")
    ae_master = load_ae_master(AE_MASTER_TSV)
    print(f"[INFO] AE master rows: {len(ae_master):,}")

    # Determine cohort folders to process
    if COHORTS_TO_RUN:
        cohort_dirs = []
        for c in COHORTS_TO_RUN:
            p = MAPPED_BASE_DIR / c
            if not p.is_dir():
                print(f"[WARN] Cohort folder not found, skipping: {p}")
                continue
            cohort_dirs.append(p)
    else:
        cohort_dirs = discover_cohorts(MAPPED_BASE_DIR)

    if not cohort_dirs:
        print(f"[ERROR] No cohort folders found under: {MAPPED_BASE_DIR}")
        sys.exit(1)

    print(f"[INFO] Cohorts to process: {len(cohort_dirs)}")
    for p in cohort_dirs:
        print(f"  - {p.name}")

    outputs: List[Path] = []
    for cohort_dir in cohort_dirs:
        out = process_one_cohort(cohort_dir, ae_master)
        if out is not None:
            outputs.append(out)

    print("\n[INFO] Done.")
    print(f"[INFO] Outputs written: {len(outputs)}")
    for o in outputs:
        print(f"  - {o}")


if __name__ == "__main__":
    main()

[INFO] Loading AE master table...
[INFO] AE master rows: 1,017
[INFO] Cohorts to process: 2
  - NCT04809974
  - NCT04880161

[PREVIEW] NCT04809974 — first 10 rows of AE weighted scores:
                                               source_file trial_id trial_label  n_ae  total_prob  total_weighted_prob  ae_score
               result_trial_data_NCT04809974_AE_mapped.tsv  trial_1     placebo   100    5.621217             2.534269  2.534269
               result_trial_data_NCT04809974_AE_mapped.tsv  trial_2      niagen   100    5.621217             2.534269  2.534269
        result_trial_data_NCT04809974_CF_ATP_AE_mapped.tsv  trial_1     placebo   100    2.459222             1.092514  1.092514
        result_trial_data_NCT04809974_CF_ATP_AE_mapped.tsv  trial_2         atp   100    2.459222             1.092514  1.092514
   result_trial_data_NCT04809974_CF_Abacavir_AE_mapped.tsv  trial_1     placebo   100    3.353535             1.367907  1.367907
   result_trial_data_NCT04809974_CF_Abac
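
As a worked example of the scoring rule in the script's docstring (`weighted_prob = probability * severity_weight * lc_weight`, summed per arm), here is a small pure-Python sketch. The rows, AE identifiers, and master-table entries are made up for illustration; only the formula and the two default weights come from the script above.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

DEFAULT_SEVERITY_WEIGHT = 0.50
DEFAULT_LC_WEIGHT = 1.00

# Hypothetical rows: (trial_label, ae_kg_id, probability)
rows: List[Tuple[str, str, float]] = [
    ("placebo", "AE_A", 0.10),
    ("placebo", "AE_B", 0.20),
    ("drug", "AE_A", 0.30),
    ("drug", "AE_C", 0.40),  # AE_C is absent from the master table -> defaults apply
]

# Hypothetical master table: ae_kg_id -> (severity_weight, lc_weight)
master: Dict[str, Tuple[float, float]] = {
    "AE_A": (1.00, 1.5),
    "AE_B": (0.25, 1.0),
}


def ae_scores(rows, master) -> Dict[str, float]:
    """Sum probability * severity_weight * lc_weight per arm."""
    scores: Dict[str, float] = defaultdict(float)
    for arm, ae_id, prob in rows:
        sev, lc = master.get(ae_id, (DEFAULT_SEVERITY_WEIGHT, DEFAULT_LC_WEIGHT))
        scores[arm] += prob * sev * lc
    return dict(scores)


scores = ae_scores(rows, master)
for arm, score in scores.items():
    print(f"{arm}: {score:.3f}")
```

In this toy input the placebo arm sums to about 0.20 (0.15 + 0.05) and the drug arm to about 0.65 (0.45 + 0.20). Because the score is a plain sum, arms with more predicted AE rows tend to score higher, so scores are only comparable between arms evaluated over the same AE set.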

#### **Safety**

In [47]:
#!/usr/bin/env python3
"""
rank_by_safety_all_cohorts.py (FIXED)

Rank counterfactual drugs by safety for one or many cohorts.
Now automatically detects which trial arm is the drug arm based on meta labels.
"""

from pathlib import Path
import json
import csv
from typing import List, Dict, Optional, Tuple

# =========================
# CONFIG (EDIT THIS ROOT)
# =========================
PLANET_ROOT = Path(
    "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet"
)

PREDICTED_ROOT = PLANET_ROOT / "2_Predicted"
SCORE_ROOT = PLANET_ROOT / "4_Score_Ranking"

COHORTS_TO_RUN: List[str] = []  # empty => auto-discover
TOP_PRINT_N = 20

# Labels that indicate a control/placebo arm (case-insensitive)
CONTROL_LABELS = {"placebo", "control", "sham", "standard", "comparator", "soc", "standard of care"}


def extract_drug_name(json_path: Path, cohort: str) -> str:
    """
    From 'result_trial_data_<COHORT>_CF_Abacavir.json' -> 'Abacavir'
    """
    base = json_path.stem
    tag = f"result_trial_data_{cohort}_CF_"
    if base.startswith(tag):
        return base[len(tag):]
    if "_CF_" in base:
        return base.split("_CF_", 1)[1]
    return base


def infer_cohorts() -> List[str]:
    if not PREDICTED_ROOT.is_dir():
        return []
    return sorted([p.name for p in PREDICTED_ROOT.glob("NCT*") if p.is_dir()])


def detect_drug_arm_from_json(data: dict) -> Tuple[str, Dict[str, str]]:
    """
    Detect which trial is the drug arm based on meta labels.
    
    Returns:
        (drug_trial_num, {trial_num: label})
        e.g., ("1", {"1": "active", "2": "control"})
    """
    meta = data.get("meta", {}) or {}
    
    trial_labels = {}
    for key, value in meta.items():
        if key.endswith("_label"):
            # Extract trial number from "trial_1_label" -> "1"
            parts = key.replace("_label", "").split("_")
            if len(parts) >= 2 and parts[-1].isdigit():
                trial_num = parts[-1]
                trial_labels[trial_num] = str(value).strip()
    
    # Find drug arm (not a control label)
    drug_arm = None
    for trial_num, label in trial_labels.items():
        if label.lower() not in CONTROL_LABELS:
            drug_arm = trial_num
            break
    
    # Fallback to trial_2 if we can't determine
    if drug_arm is None:
        drug_arm = "2" if "2" in trial_labels else ("1" if trial_labels else "2")
    
    return drug_arm, trial_labels


def rank_one_cohort(cohort: str) -> Optional[Path]:
    in_dir = PREDICTED_ROOT / cohort
    out_dir = SCORE_ROOT / cohort
    out_dir.mkdir(parents=True, exist_ok=True)

    pattern = f"result_trial_data_{cohort}_CF_*.json"
    files = sorted(in_dir.glob(pattern))

    if not files:
        print(f"[WARN] {cohort}: No JSON files found matching: {in_dir / pattern}")
        return None

    results: List[Dict[str, object]] = []
    arm_info_logged = False

    for path in files:
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            print(f"[WARN] {cohort}: Could not read {path.name}: {e}")
            continue

        # Detect drug arm
        drug_arm_num, trial_labels = detect_drug_arm_from_json(data)
        
        if not arm_info_logged:
            print(f"\n[INFO] {cohort}: Detected arms: {trial_labels}")
            print(f"[INFO] {cohort}: Using trial_{drug_arm_num} as drug arm (label: '{trial_labels.get(drug_arm_num, 'unknown')}')")
            arm_info_logged = True

        # Get safety value for drug arm
        safety_block = data.get("safety", {}) or {}
        safety_key = f"trial_{drug_arm_num}_safety"
        safety_drug = safety_block.get(safety_key, None)

        # Fallback to any available safety value, remembering which key was used
        if safety_drug is None:
            for key in ["trial_2_safety", "trial_1_safety"]:
                if safety_block.get(key) is not None:
                    safety_key = key
                    safety_drug = safety_block[key]
                    print(f"[WARN] {cohort}: {path.name}: drug-arm safety missing, using {key}.")
                    break

        if safety_drug is None:
            print(f"[WARN] {cohort}: No valid safety value in {path.name}, skipping.")
            continue

        try:
            safety_drug = float(safety_drug)
        except (TypeError, ValueError):
            print(f"[WARN] {cohort}: Bad safety value in {path.name}, skipping.")
            continue

        drug_name = extract_drug_name(path, cohort)
        results.append({
            "drug": drug_name,
            "file": path.name,
            "safety_drug": safety_drug,
            # record the arm whose value was actually used (may differ after fallback)
            "trial_used": safety_key[: -len("_safety")],
        })

    if not results:
        print(f"[WARN] {cohort}: No valid safety results found.")
        return None

    # higher = safer
    results.sort(key=lambda x: x["safety_drug"], reverse=True)

    # print top N
    top_n = min(TOP_PRINT_N, len(results))
    print(f"\n[INFO] {cohort}: Top {top_n} drugs by safety (highest to lowest):")
    print(f"{'Rank':<5} {'Drug':<30} {'Safety':>10}")
    print("-" * 55)
    for i, r in enumerate(results[:top_n], start=1):
        print(f"{i:<5} {str(r['drug']):<30} {float(r['safety_drug']):>10.4f}")

    # save
    out_csv = out_dir / f"{cohort}_ranked_drugs_by_safety.csv"
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["rank", "drug", "file", "safety_drug", "trial_used"])
        for i, r in enumerate(results, start=1):
            w.writerow([i, r["drug"], r["file"], f"{float(r['safety_drug']):.6f}", r["trial_used"]])

    print(f"[INFO] {cohort}: Saved full safety ranking to:\n       {out_csv}")
    return out_csv


def main(cohorts: Optional[List[str]] = None):
    """
    Main entry point.
    
    Args:
        cohorts: List of cohort IDs to process. If None, uses auto-discovery.
    """
    if cohorts is None:
        cohorts = COHORTS_TO_RUN if COHORTS_TO_RUN else infer_cohorts()
    
    if not cohorts:
        print(f"[ERROR] No cohorts found in: {PREDICTED_ROOT}")
        return []

    print(f"[INFO] Cohorts to process: {len(cohorts)}")
    outputs = []
    for c in cohorts:
        out = rank_one_cohort(c)
        if out:
            outputs.append(out)
    
    return outputs


if __name__ == "__main__":
    main()

[INFO] Cohorts to process: 2

[INFO] NCT04809974: Detected arms: {'1': 'placebo', '2': 'atp'}
[INFO] NCT04809974: Using trial_2 as drug arm (label: 'atp')

[INFO] NCT04809974: Top 20 drugs by safety (highest to lowest):
Rank  Drug                               Safety
-------------------------------------------------------
1     Trametinib                         0.6821
2     Ceritinib                          0.6131
3     Lenvatinib                         0.5978
4     Pomalidomide                       0.5873
5     Selumetinib                        0.5677
6     Regorafenib                        0.5447
7     Idelalisib                         0.5436
8     Leniolisib                         0.5342
9     Tolvaptan                          0.5322
10    Ruxolitinib                        0.5290
11    Carfilzomib                        0.5249
12    Dabrafenib                         0.5226
13    Pentostatin                        0.5197
14    Dasatinib                          0.5178
15  
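
The arm detection in the script above compares each label to `CONTROL_LABELS` by exact match, so a label such as "Placebo Comparator" would be treated as a drug arm. The sketch below shows one possible, more tolerant variant that also matches individual words. It is not the notebook's method: `is_control_label` and `pick_drug_arm` are hypothetical names, and word-level matching is an assumption that may misfire on drug names containing a control term.

```python
import re
from typing import Dict, Optional

CONTROL_LABELS = {"placebo", "control", "sham", "standard", "comparator", "soc", "standard of care"}


def is_control_label(label: str) -> bool:
    """True if the whole label, or any single word in it, is a known control term."""
    norm = label.strip().lower()
    if norm in CONTROL_LABELS:
        return True
    words = set(re.findall(r"[a-z]+", norm))
    return bool(words & CONTROL_LABELS)


def pick_drug_arm(trial_labels: Dict[str, str]) -> Optional[str]:
    """Return the lowest-numbered arm whose label is not a control label, else None."""
    for num in sorted(trial_labels, key=int):
        if not is_control_label(trial_labels[num]):
            return num
    return None


print(pick_drug_arm({"1": "placebo", "2": "atp"}))
print(pick_drug_arm({"1": "Placebo Comparator", "2": "niagen"}))
print(pick_drug_arm({"1": "placebo", "2": "Sham control"}))
```

Returning `None` when every arm looks like a control lets the caller decide on a fallback explicitly instead of silently defaulting to one arm.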

#### **Efficacy**

In [49]:
#!/usr/bin/env python3
"""
rank_by_efficacy_all_cohorts.py (FIXED)

Rank counterfactual drugs by efficacy for one or many cohorts.
Now automatically detects arm structure and computes P(drug > control) correctly.

The model outputs prob_trial1_gt_trial2 = P(trial_1 > trial_2).
- If trial_1 is drug:  P(drug > control) = prob_trial1_gt_trial2
- If trial_2 is drug:  P(drug > control) = 1 - prob_trial1_gt_trial2
"""

from pathlib import Path
import json
import csv
from typing import List, Dict, Optional, Tuple

# =========================
# CONFIG (EDIT THIS ROOT)
# =========================
PLANET_ROOT = Path(
    "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet"
)

PREDICTED_ROOT = PLANET_ROOT / "2_Predicted"
SCORE_ROOT = PLANET_ROOT / "4_Score_Ranking"

COHORTS_TO_RUN: List[str] = []  # empty => auto-discover
TOP_PRINT_N = 20

# Labels that indicate a control/placebo arm (case-insensitive)
CONTROL_LABELS = {"placebo", "control", "sham", "standard", "comparator", "soc", "standard of care"}


def extract_drug_name(json_path: Path, cohort: str) -> str:
    base = json_path.stem
    tag = f"result_trial_data_{cohort}_CF_"
    if base.startswith(tag):
        return base[len(tag):]
    if "_CF_" in base:
        return base.split("_CF_", 1)[1]
    return base


def infer_cohorts() -> List[str]:
    if not PREDICTED_ROOT.is_dir():
        return []
    return sorted([p.name for p in PREDICTED_ROOT.glob("NCT*") if p.is_dir()])


def detect_drug_arm_from_json(data: dict) -> Tuple[str, Dict[str, str]]:
    """
    Detect which trial is the drug arm based on meta labels.
    
    Returns:
        (drug_trial_num, {trial_num: label})
        e.g., ("1", {"1": "active", "2": "control"})
    """
    meta = data.get("meta", {}) or {}
    
    trial_labels = {}
    for key, value in meta.items():
        if key.endswith("_label"):
            parts = key.replace("_label", "").split("_")
            if len(parts) >= 2 and parts[-1].isdigit():
                trial_num = parts[-1]
                trial_labels[trial_num] = str(value).strip()
    
    # Find drug arm (not a control label)
    drug_arm = None
    for trial_num, label in trial_labels.items():
        if label.lower() not in CONTROL_LABELS:
            drug_arm = trial_num
            break
    
    # Fallback to trial_2 if we can't determine
    if drug_arm is None:
        drug_arm = "2" if "2" in trial_labels else ("1" if trial_labels else "2")
    
    return drug_arm, trial_labels


def rank_one_cohort(cohort: str) -> Optional[Path]:
    in_dir = PREDICTED_ROOT / cohort
    out_dir = SCORE_ROOT / cohort
    out_dir.mkdir(parents=True, exist_ok=True)

    pattern = f"result_trial_data_{cohort}_CF_*.json"
    files = sorted(in_dir.glob(pattern))

    if not files:
        print(f"[WARN] {cohort}: No JSON files found matching: {in_dir / pattern}")
        return None

    results: List[Dict[str, object]] = []
    arm_info_logged = False

    for path in files:
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            print(f"[WARN] {cohort}: Could not read {path.name}: {e}")
            continue

        # Detect drug arm
        drug_arm_num, trial_labels = detect_drug_arm_from_json(data)
        
        if not arm_info_logged:
            print(f"\n[INFO] {cohort}: Detected arms: {trial_labels}")
            print(f"[INFO] {cohort}: Drug arm = trial_{drug_arm_num} (label: '{trial_labels.get(drug_arm_num, 'unknown')}')")
            arm_info_logged = True

        eff_block = data.get("efficacy", {}) or {}
        p_trial1_gt_trial2 = eff_block.get("prob_trial1_gt_trial2", None)
        
        if p_trial1_gt_trial2 is None:
            print(f"[WARN] {cohort}: No efficacy value in {path.name}, skipping.")
            continue

        try:
            p_trial1_gt_trial2 = float(p_trial1_gt_trial2)
        except Exception:
            print(f"[WARN] {cohort}: Bad efficacy value in {path.name}, skipping.")
            continue

        # Compute P(drug > control) based on which arm is the drug
        if drug_arm_num == "1":
            # trial_1 is drug, trial_2 is control
            # P(drug > control) = P(trial_1 > trial_2) = prob_trial1_gt_trial2
            p_drug_gt_control = p_trial1_gt_trial2
        else:
            # trial_2 is drug, trial_1 is control
            # P(drug > control) = P(trial_2 > trial_1) = 1 - P(trial_1 > trial_2)
            p_drug_gt_control = 1.0 - p_trial1_gt_trial2

        drug_name = extract_drug_name(path, cohort)

        results.append(
            {
                "drug": drug_name,
                "file": path.name,
                "P_drug_gt_placebo": p_drug_gt_control,
                "P_placebo_gt_drug": 1.0 - p_drug_gt_control,
                "raw_prob_trial1_gt_trial2": p_trial1_gt_trial2,
                "drug_arm": f"trial_{drug_arm_num}",
            }
        )

    if not results:
        print(f"[WARN] {cohort}: No valid efficacy results found.")
        return None

    results.sort(key=lambda x: x["P_drug_gt_placebo"], reverse=True)

    top_n = min(TOP_PRINT_N, len(results))
    print(f"\n[INFO] {cohort}: Top {top_n} drugs by P(drug > control):")
    print(f"{'Rank':<5} {'Drug':<30} {'P(drug>ctrl)':>14} {'P(ctrl>drug)':>14} {'Drug Arm':>10}")
    print("-" * 85)
    for i, r in enumerate(results[:top_n], start=1):
        print(
            f"{i:<5} {str(r['drug']):<30} "
            f"{float(r['P_drug_gt_placebo']):>14.4f} "
            f"{float(r['P_placebo_gt_drug']):>14.4f} "
            f"{r['drug_arm']:>10}"
        )

    out_csv = out_dir / f"{cohort}_ranked_drugs_by_efficacy.csv"
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["rank", "drug", "file", "P_drug_gt_placebo", "P_placebo_gt_drug", "drug_arm", "raw_prob_trial1_gt_trial2"])
        for i, r in enumerate(results, start=1):
            w.writerow(
                [
                    i,
                    r["drug"],
                    r["file"],
                    f"{float(r['P_drug_gt_placebo']):.6f}",
                    f"{float(r['P_placebo_gt_drug']):.6f}",
                    r["drug_arm"],
                    f"{float(r['raw_prob_trial1_gt_trial2']):.6f}",
                ]
            )

    print(f"[INFO] {cohort}: Saved full efficacy ranking to:\n       {out_csv}")
    return out_csv


def main(cohorts: Optional[List[str]] = None):
    """
    Main entry point.
    
    Args:
        cohorts: List of cohort IDs to process. If None, uses auto-discovery.
    """
    if cohorts is None:
        cohorts = COHORTS_TO_RUN if COHORTS_TO_RUN else infer_cohorts()
    
    if not cohorts:
        print(f"[ERROR] No cohorts found in: {PREDICTED_ROOT}")
        return []

    print(f"[INFO] Cohorts to process: {len(cohorts)}")
    outputs = []
    for c in cohorts:
        out = rank_one_cohort(c)
        if out:
            outputs.append(out)
    
    return outputs


if __name__ == "__main__":
    main()

[INFO] Cohorts to process: 2

[INFO] NCT04809974: Detected arms: {'1': 'placebo', '2': 'atp'}
[INFO] NCT04809974: Drug arm = trial_2 (label: 'atp')

[INFO] NCT04809974: Top 20 drugs by P(drug > control):
Rank  Drug                             P(drug>ctrl)   P(ctrl>drug)   Drug Arm
-------------------------------------------------------------------------------------
1     Minoxidil                              0.5308         0.4692    trial_2
2     Iloprost                               0.5305         0.4695    trial_2
3     Goserelin                              0.5301         0.4699    trial_2
4     Nedocromil                             0.5288         0.4712    trial_2
5     Cetrorelix                             0.5284         0.4716    trial_2
6     Cabergoline                            0.5284         0.4716    trial_2
7     Tipranavir                             0.5283         0.4717    trial_2
8     Roflumilast                            0.5283         0.4717    trial_2
9     Pi
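
The orientation rule from the script's docstring can be checked against the first row of the table above. Here is a minimal sketch, assuming ties have zero probability so that P(trial_2 > trial_1) = 1 - P(trial_1 > trial_2); the function name `p_drug_beats_control` is hypothetical.

```python
def p_drug_beats_control(prob_trial1_gt_trial2: float, drug_arm: str) -> float:
    """Convert the model's P(trial_1 > trial_2) into P(drug > control)."""
    p = float(prob_trial1_gt_trial2)
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"probability out of range: {p}")
    if drug_arm == "1":
        return p          # trial_1 is the drug arm
    if drug_arm == "2":
        return 1.0 - p    # trial_2 is the drug arm, so flip the comparison
    raise ValueError(f"unexpected drug arm: {drug_arm!r}")


# Minoxidil row: the drug arm is trial_2 and P(ctrl > drug) is reported as 0.4692,
# which corresponds to the raw P(trial_1 > trial_2).
minoxidil = p_drug_beats_control(0.4692, "2")
print(f"{minoxidil:.4f}")
```

Raising on an unexpected arm number, rather than treating everything other than `"1"` as trial_2, makes a mislabelled file fail loudly instead of producing a flipped probability.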

#### **Composite Score**

In [60]:
#!/usr/bin/env python3
"""
composite_rank_SAE_all_cohorts.py (FIXED)

Composite ranking of drugs using Safety, Efficacy, and AE scores.
Now automatically detects the drug arm from trial labels.
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Dict
import pandas as pd
import numpy as np


# =============================================================================
# CONFIG (EDIT THESE)
# =============================================================================

SCORE_RANKING_ROOT = Path(
    "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking"
)

COHORTS_TO_RUN: List[str] = []  # empty => auto-discover

# Input filename templates
SAFETY_FILENAME_TEMPLATE = "{cohort}_ranked_drugs_by_safety.csv"
EFFICACY_FILENAME_TEMPLATE = "{cohort}_ranked_drugs_by_efficacy.csv"
AE_FILENAME_TEMPLATE = "{cohort}_AE_weighted_scores.tsv"

# Output filename template
OUTPUT_FILENAME_TEMPLATE = "{cohort}_ranked_drugs_by_composite_SAE.tsv"

# Weights
W_SAFETY = 0.4
W_EFFICACY = 0.4
W_AE = 0.2

# Printing
TOP_PRINT_N = 20
PREVIEW_N = 10

AE_EPSILON = 1e-9

# Labels that indicate a control/placebo arm
CONTROL_LABELS = {"placebo", "control", "sham", "standard", "comparator", "soc", "standard of care"}


# =============================================================================
# HELPERS
# =============================================================================

def min_max_normalize(series: pd.Series) -> pd.Series:
    s = pd.to_numeric(series, errors="coerce")
    if s.isna().all():
        return pd.Series(0.0, index=series.index)
    min_v, max_v = s.min(), s.max()
    if pd.isna(min_v) or pd.isna(max_v) or max_v == min_v:
        return pd.Series(0.0, index=series.index)
    return (s - min_v) / (max_v - min_v)


def add_file_base(df: pd.DataFrame, col: str) -> pd.DataFrame:
    base = (
        df[col]
        .astype(str)
        .str.replace(r"^result_trial_data_", "", regex=True)
        .str.replace(r"\.json$", "", regex=True)
        .str.replace(r"_AE_mapped\.tsv$", "", regex=True)
    )
    out = df.copy()
    out["file_base"] = base
    return out


def infer_cohorts(score_root: Path) -> List[str]:
    return sorted([p.name for p in score_root.glob("NCT*") if p.is_dir()])


def load_safety(path: Path) -> pd.DataFrame:
    if not path.is_file():
        raise FileNotFoundError(f"Safety file not found: {path}")
    df = pd.read_csv(path, sep=None, engine="python")
    df.columns = [c.strip() for c in df.columns]
    required = ["drug", "file", "safety_drug"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Safety file missing columns: {missing}")
    df = add_file_base(df, "file")
    df["safety_score"] = pd.to_numeric(df["safety_drug"], errors="coerce")
    return df[["file_base", "drug", "safety_score"]]


def load_efficacy(path: Path) -> pd.DataFrame:
    if not path.is_file():
        raise FileNotFoundError(f"Efficacy file not found: {path}")
    df = pd.read_csv(path, sep=None, engine="python")
    df.columns = [c.strip() for c in df.columns]
    required = ["drug", "file", "P_drug_gt_placebo"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Efficacy file missing columns: {missing}")
    df = add_file_base(df, "file")
    df["efficacy_score"] = pd.to_numeric(df["P_drug_gt_placebo"], errors="coerce")
    return df[["file_base", "drug", "efficacy_score"]]


def load_ae_weighted_scores(path: Path) -> pd.DataFrame:
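    """
    Load AE scores, keep drug-arm rows only, and average ae_score per file_base.

    Drug-arm rows are selected with the is_drug_arm flag when that column exists;
    otherwise rows whose trial_label is a control label are dropped.
    """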
    if not path.is_file():
        raise FileNotFoundError(f"AE file not found: {path}")

    df = pd.read_csv(path, sep="\t", dtype=str)
    df.columns = [c.strip() for c in df.columns]

    required = ["source_file", "trial_id", "trial_label", "ae_score"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"AE file missing columns: {missing}")

    df = df.copy()
    
    # Use is_drug_arm if available, otherwise detect based on trial_label
    if "is_drug_arm" in df.columns:
        df["is_drug_arm"] = df["is_drug_arm"].astype(str).str.lower().isin(["true", "1", "yes"])
        df = df[df["is_drug_arm"]]
    else:
        # Filter to drug arm: keep rows where trial_label is NOT a control label
        # (strip whitespace first so " placebo " still counts as a control).
        is_control = df["trial_label"].str.strip().str.lower().isin(CONTROL_LABELS)
        df = df[~is_control]

    df = add_file_base(df, "source_file")
    df["ae_score"] = pd.to_numeric(df["ae_score"], errors="coerce")

    out = df.groupby("file_base", as_index=False).agg(ae_score=("ae_score", "mean"))
    return out[["file_base", "ae_score"]]


def build_paths_for_cohort(cohort: str) -> Dict[str, Path]:
    cohort_dir = SCORE_RANKING_ROOT / cohort
    return {
        "cohort_dir": cohort_dir,
        "safety": cohort_dir / SAFETY_FILENAME_TEMPLATE.format(cohort=cohort),
        "efficacy": cohort_dir / EFFICACY_FILENAME_TEMPLATE.format(cohort=cohort),
        "ae": cohort_dir / AE_FILENAME_TEMPLATE.format(cohort=cohort),
        "out": cohort_dir / OUTPUT_FILENAME_TEMPLATE.format(cohort=cohort),
    }


def compute_composite_for_cohort(cohort: str) -> Optional[Path]:
    paths = build_paths_for_cohort(cohort)
    cohort_dir = paths["cohort_dir"]

    if not cohort_dir.is_dir():
        print(f"[WARN] {cohort}: cohort folder not found: {cohort_dir}")
        return None

    missing_inputs = [k for k in ["safety", "efficacy", "ae"] if not paths[k].is_file()]
    if missing_inputs:
        print(f"\n[WARN] {cohort}: Missing inputs:")
        for k in missing_inputs:
            print(f"       - {paths[k].name}")
        return None

    print("\n" + "=" * 88)
    print(f"[INFO] Cohort: {cohort}")

    safety = load_safety(paths["safety"])
    efficacy = load_efficacy(paths["efficacy"])
    ae = load_ae_weighted_scores(paths["ae"])

    print(f"[INFO] Safety rows:   {len(safety)}")
    print(f"[INFO] Efficacy rows: {len(efficacy)}")
    print(f"[INFO] AE rows (drug arm only): {len(ae)}")

    merged = safety.merge(efficacy, on="file_base", how="outer", suffixes=("_safety", "_efficacy"))

    if "drug_safety" in merged.columns and "drug_efficacy" in merged.columns:
        merged["drug"] = merged["drug_safety"].combine_first(merged["drug_efficacy"])
        merged.drop(columns=["drug_safety", "drug_efficacy"], inplace=True)

    merged = merged.merge(ae, on="file_base", how="left")

    for col in ["safety_score", "efficacy_score", "ae_score"]:
        merged[col] = pd.to_numeric(merged.get(col, np.nan), errors="coerce")

    merged["safety_norm"] = min_max_normalize(merged["safety_score"])
    merged["efficacy_norm"] = min_max_normalize(merged["efficacy_score"])
    merged["ae_inverse_raw"] = 1.0 / (merged["ae_score"] + AE_EPSILON)
    merged["ae_inverse_norm"] = min_max_normalize(merged["ae_inverse_raw"])

    merged["composite_score"] = (
        W_SAFETY * merged["safety_norm"]
        + W_EFFICACY * merged["efficacy_norm"]
        + W_AE * merged["ae_inverse_norm"]
    )

    merged_sorted = merged.sort_values(
        ["composite_score", "efficacy_norm", "safety_norm"],
        ascending=[False, False, False],
    ).reset_index(drop=True)
    merged_sorted["rank"] = merged_sorted.index + 1

    cols_order = [
        "rank", "drug", "file_base",
        "safety_score", "efficacy_score", "ae_score",
        "safety_norm", "efficacy_norm", "ae_inverse_norm",
        "composite_score",
    ]
    merged_sorted = merged_sorted[[c for c in cols_order if c in merged_sorted.columns]]

    paths["out"].parent.mkdir(parents=True, exist_ok=True)
    merged_sorted.to_csv(paths["out"], sep="\t", index=False)

    print(f"[INFO] Wrote composite ranking to:\n       {paths['out']}")

    print(f"\n[PREVIEW] First {min(PREVIEW_N, len(merged_sorted))} rows:")
    print(merged_sorted.head(PREVIEW_N).to_string(index=False))

    top_n = min(TOP_PRINT_N, len(merged_sorted))
    print(f"\n[INFO] Top {top_n} drugs by composite_score:")
    print(
        merged_sorted.head(top_n)[
            [c for c in ["rank", "drug", "composite_score", "safety_score", "efficacy_score", "ae_score"] if c in merged_sorted.columns]
        ].to_string(index=False)
    )

    return paths["out"]


def main(cohorts: Optional[List[str]] = None):
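    """
    Main entry point.

    Args:
        cohorts: List of cohort IDs. If None, uses COHORTS_TO_RUN when it is
            non-empty, otherwise auto-discovers NCT* folders under SCORE_RANKING_ROOT.

    Returns:
        Paths of the composite TSVs that were written.
    """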
    if not SCORE_RANKING_ROOT.is_dir():
        raise NotADirectoryError(f"SCORE_RANKING_ROOT not found: {SCORE_RANKING_ROOT}")

    if cohorts is None:
        cohorts = COHORTS_TO_RUN if COHORTS_TO_RUN else infer_cohorts(SCORE_RANKING_ROOT)

    if not cohorts:
        print(f"[ERROR] No cohort folders found under: {SCORE_RANKING_ROOT}")
        return []

    print(f"[INFO] Cohorts to process: {len(cohorts)}")
    outputs = []
    for cohort in cohorts:
        out = compute_composite_for_cohort(cohort)
        if out:
            outputs.append(out)

    print("\n" + "=" * 88)
    print(f"[INFO] Done. Composite TSVs written: {len(outputs)}")
    return outputs


if __name__ == "__main__":
    main()

[INFO] Cohorts to process: 2

[INFO] Cohort: NCT04809974
[INFO] Safety rows:   1625
[INFO] Efficacy rows: 1625
[INFO] AE rows (drug arm only): 1626
[INFO] Wrote composite ranking to:
       /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_ranked_drugs_by_composite_SAE.tsv

[PREVIEW] First 10 rows:
 rank        drug                  file_base  safety_score  efficacy_score  ae_score  safety_norm  efficacy_norm  ae_inverse_norm  composite_score
    1   Goserelin   NCT04809974_CF_Goserelin      0.334783        0.530122  1.713776     0.417590       0.957677         0.218745         0.593856
    2   Minoxidil   NCT04809974_CF_Minoxidil      0.288513        0.530847  1.448009     0.339994       1.000000         0.274695         0.590937
    3  Idelalisib  NCT04809974_CF_Idelalisib      0.543565        0.525138  3.496859     0.767725       0.666725         0.063307         0.586442
    4   Dasati
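
The composite ranking in the cell above combines three per-drug scores: safety and efficacy are min-max normalized, the AE score is inverted as `1 / (ae_score + AE_EPSILON)` and then min-max normalized, and the three normalized values are combined as a weighted sum. The following is a minimal sketch of that arithmetic on toy data. The drug names, scores, weights, and epsilon are illustrative assumptions; the actual values of `W_SAFETY`, `W_EFFICACY`, `W_AE`, and `AE_EPSILON` are defined outside this excerpt.

```python
import pandas as pd

# Hypothetical weights and epsilon, for illustration only.
W_SAFETY, W_EFFICACY, W_AE = 0.4, 0.4, 0.2
AE_EPSILON = 1e-6


def min_max_normalize(series: pd.Series) -> pd.Series:
    """Scale values to [0, 1]; return zeros when the range is degenerate."""
    s = pd.to_numeric(series, errors="coerce")
    lo, hi = s.min(), s.max()
    if pd.isna(lo) or pd.isna(hi) or hi == lo:
        return pd.Series(0.0, index=series.index)
    return (s - lo) / (hi - lo)


# Toy scores (made up): higher safety/efficacy is better, higher AE is worse.
toy = pd.DataFrame({
    "drug": ["A", "B", "C"],
    "safety_score": [0.2, 0.5, 0.8],
    "efficacy_score": [0.50, 0.53, 0.51],
    "ae_score": [1.0, 2.0, 4.0],
})

toy["safety_norm"] = min_max_normalize(toy["safety_score"])
toy["efficacy_norm"] = min_max_normalize(toy["efficacy_score"])
# Invert AE so that fewer adverse events yields a higher contribution.
toy["ae_inverse_norm"] = min_max_normalize(1.0 / (toy["ae_score"] + AE_EPSILON))

toy["composite_score"] = (
    W_SAFETY * toy["safety_norm"]
    + W_EFFICACY * toy["efficacy_norm"]
    + W_AE * toy["ae_inverse_norm"]
)

ranked = toy.sort_values("composite_score", ascending=False).reset_index(drop=True)
ranked["rank"] = ranked.index + 1
print(ranked[["rank", "drug", "composite_score"]].to_string(index=False))
```

Because every component is min-max normalized within the cohort, the composite score is only comparable between drugs of the same cohort, not across cohorts.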

#### **AE/S/E/Comp Ranking**

In [61]:
#!/usr/bin/env python3
"""
rank_avoid_lists.py (FIXED v3)

FIX v3: Extract drug name from source_file column, not trial_label
        (NCT04880161 has trial_label="active"/"control", not drug names)
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd


# =========================
# CONFIG - Control/Placebo Labels
# =========================

CONTROL_LABELS: Set[str] = {
    "placebo", "control", "sham", "comparator", "soc", "vehicle",
    "standard of care", "usual care", "standard care", "no treatment",
}


# =========================
# Helpers
# =========================

def _norm_str(x: object) -> str:
    """Normalize string for comparison."""
    if x is None:
        return ""
    s = str(x).strip().lower()
    s = " ".join(s.split())
    return s


def _is_control_label(label: str) -> bool:
    """Check if label indicates a control/placebo arm (exact match, or a "placebo ", "control ", or "sham " prefix)."""
    norm = _norm_str(label)
    if not norm:
        return False
    if norm in CONTROL_LABELS:
        return True
    for prefix in ["placebo ", "control ", "sham "]:
        if norm.startswith(prefix):
            return True
    return False


def _is_drug_label(label: str) -> bool:
    """Check if label indicates a drug/active arm."""
    return not _is_control_label(label)


def _read_csv_any(path: Path, sep: Optional[str] = None) -> pd.DataFrame:
    """Read a CSV or TSV file, choosing the separator from the extension; return an empty DataFrame if the file is missing or unreadable."""
    if not path.exists():
        return pd.DataFrame()
    if sep is None:
        sep = "\t" if path.suffix.lower() in [".tsv", ".txt"] else ","
    try:
        df = pd.read_csv(path, sep=sep)
        if df.shape[1] == 1 and sep == ",":
            df = pd.read_csv(path, sep="\t")
        return df
    except Exception as e:
        print(f"[ERROR] Could not read {path}: {e}")
        return pd.DataFrame()


def _quantile(series: pd.Series, q: float) -> float:
    """Compute quantile, handling NaN."""
    s = pd.to_numeric(series, errors="coerce").dropna()
    return float(s.quantile(q)) if not s.empty else np.nan


def _rank_worst_first(values: pd.Series, higher_is_worse: bool) -> pd.Series:
    """Rank worst-first (rank 1 = worst)."""
    x = pd.to_numeric(values, errors="coerce")
    return x.rank(ascending=not higher_is_worse, method="min")


def _ensure_dir(d: Path) -> None:
    """Create directory if needed."""
    d.mkdir(parents=True, exist_ok=True)


def extract_drug_name_from_filename(filename: str) -> str:
    """
    Extract drug name from source filename.
    
    From: result_trial_data_NCT04880161_CF_Abacavir_AE_mapped.tsv
    Return: Abacavir
    """
    if pd.isna(filename) or not filename:
        return ""
    
    base = str(filename).strip()
    
    # Remove extension
    if "." in base:
        base = base.rsplit(".", 1)[0]
    
    # Remove _AE_mapped suffix
    if base.endswith("_AE_mapped"):
        base = base[:-len("_AE_mapped")]
    
    # Extract part after _CF_
    if "_CF_" in base:
        return base.split("_CF_", 1)[1]
    
    return base


# =========================
# Loaders
# =========================

def load_ae_weighted_scores(cohort_dir: Path, cohort: str) -> pd.DataFrame:
    """
    Load AE weighted scores and filter to drug arm only.
    
    FIXED v3: Extract drug name from source_file, not trial_label.
    """
    path = cohort_dir / f"{cohort}_AE_weighted_scores.tsv"
    df = _read_csv_any(path, sep="\t")
    if df.empty:
        print(f"[WARN] AE weighted scores file is empty or not found: {path}")
        return df

    # Normalize labels for arm detection
    if "trial_label" in df.columns:
        df["trial_label_norm"] = df["trial_label"].apply(_norm_str)
        df["is_drug_arm"] = df["trial_label_norm"].apply(_is_drug_label)
    else:
        df["trial_label_norm"] = ""
        df["is_drug_arm"] = True

    if "source_file" not in df.columns:
        df["source_file"] = "unknown_source"

    def pick_drug_row(g: pd.DataFrame) -> pd.DataFrame:
        """Select the drug arm row from a group."""
        drug_rows = g[g["is_drug_arm"]]
        if not drug_rows.empty:
            return drug_rows.head(1)
        non_ctrl = g[~g["trial_label_norm"].apply(_is_control_label)]
        if not non_ctrl.empty:
            return non_ctrl.head(1)
        return g.head(1)

    drug_rows = (
        df.groupby("source_file", as_index=False, sort=False)
          .apply(lambda g: pick_drug_row(g).copy())
    )
    drug_rows = drug_rows.reset_index(drop=True)

    # FIXED: Extract drug name from source_file, NOT from trial_label
    drug_rows["drug"] = drug_rows["source_file"].apply(extract_drug_name_from_filename)

    # Ensure numeric columns
    for col in ["ae_score", "total_prob", "total_weighted_prob", "n_ae"]:
        if col in drug_rows.columns:
            drug_rows[col] = pd.to_numeric(drug_rows[col], errors="coerce")

    drug_rows["drug_key"] = drug_rows["drug"].apply(_norm_str)

    print(f"[INFO] Loaded {len(drug_rows)} drug rows from AE weighted scores")
    
    return drug_rows


def load_ranked_safety(cohort_dir: Path, cohort: str) -> pd.DataFrame:
    """Load safety rankings CSV."""
    path = cohort_dir / f"{cohort}_ranked_drugs_by_safety.csv"
    df = _read_csv_any(path, sep=",")
    if df.empty:
        return df

    df["drug_key"] = df["drug"].apply(_norm_str) if "drug" in df.columns else ""

    if "safety_drug" in df.columns:
        df["safety_drug"] = pd.to_numeric(df["safety_drug"], errors="coerce")
    elif "safety_score" in df.columns:
        df["safety_drug"] = pd.to_numeric(df["safety_score"], errors="coerce")
    else:
        df["safety_drug"] = np.nan

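    # Select only columns that exist so a file without a "drug" column does not raise KeyError
    return df[[c for c in ["drug", "drug_key", "safety_drug"] if c in df.columns]].drop_duplicates("drug_key")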


def load_ranked_efficacy(cohort_dir: Path, cohort: str) -> pd.DataFrame:
    """Load efficacy rankings CSV."""
    path = cohort_dir / f"{cohort}_ranked_drugs_by_efficacy.csv"
    df = _read_csv_any(path, sep=",")
    if df.empty:
        return df

    df["drug_key"] = df["drug"].apply(_norm_str) if "drug" in df.columns else ""

    if "P_drug_gt_placebo" in df.columns:
        df["P_drug_gt_placebo"] = pd.to_numeric(df["P_drug_gt_placebo"], errors="coerce")
    elif "efficacy_score" in df.columns:
        df["P_drug_gt_placebo"] = pd.to_numeric(df["efficacy_score"], errors="coerce")
    else:
        df["P_drug_gt_placebo"] = np.nan

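    # Select only columns that exist so a file without a "drug" column does not raise KeyError
    return df[[c for c in ["drug", "drug_key", "P_drug_gt_placebo"] if c in df.columns]].drop_duplicates("drug_key")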


def load_ranked_composite(cohort_dir: Path, cohort: str) -> pd.DataFrame:
    """Load composite SAE rankings TSV."""
    path = cohort_dir / f"{cohort}_ranked_drugs_by_composite_SAE.tsv"
    df = _read_csv_any(path, sep="\t")
    if df.empty:
        return df

    df["drug_key"] = df["drug"].apply(_norm_str) if "drug" in df.columns else ""

    for col in ["composite_score", "safety_score", "efficacy_score", "ae_score"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    keep_cols = [c for c in [
        "drug", "drug_key", "safety_score", "efficacy_score", "ae_score",
        "safety_norm", "efficacy_norm", "ae_inverse_norm", "composite_score",
    ] if c in df.columns]

    return df[keep_cols].drop_duplicates("drug_key")


def load_ground_truth(ground_truth_dir: Path, cohort: str) -> pd.DataFrame:
    """Load ground truth data for the real/original drug."""
    gt_path = ground_truth_dir / f"Ground_Truth_{cohort}.csv"
    
    if not gt_path.exists():
        print(f"[WARN] Ground truth file not found: {gt_path}")
        return pd.DataFrame()
    
    df = _read_csv_any(gt_path, sep=",")
    if df.empty:
        return df
    
    df.columns = [c.strip() for c in df.columns]
    
    # Handle different column names
    col_mapping = {
        "arm_label": ["arm_label", "arm", "label", "trial_label", "group", "arm_name"],
        "safety": ["safety", "safety_score", "safety_drug"],
        "efficacy": ["efficacy", "efficacy_score", "p_drug_gt_placebo"],
        "ae_prob": ["ae_prob", "ae_score", "ae", "total_prob"],
    }
    
    for target, options in col_mapping.items():
        if target not in df.columns:
            for opt in options:
                for col in df.columns:
                    if col.lower() == opt.lower():
                        df[target] = df[col]
                        break
                if target in df.columns:
                    break
    
    for col in ["safety", "efficacy", "ae_prob"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")
    
    if "arm_label" in df.columns:
        df["arm_label_norm"] = df["arm_label"].apply(_norm_str)
        df["is_control"] = df["arm_label_norm"].apply(_is_control_label)
        df_drug = df[~df["is_control"]].copy()
        
        if df_drug.empty:
            df_drug = df[df["arm_label_norm"] != "placebo"].copy()
        
        return df_drug
    
    return df


def summarize_real_from_ground_truth(df_drug: pd.DataFrame) -> Dict[str, float]:
    """Get baseline metrics from ground truth: first non-null safety and efficacy, and the sum of ae_prob."""
    if df_drug.empty:
        return {"safety_real": np.nan, "efficacy_real": np.nan, "ae_totalprob_real": np.nan}

    safety_real = float(df_drug["safety"].dropna().iloc[0]) if "safety" in df_drug.columns and df_drug["safety"].notna().any() else np.nan
    efficacy_real = float(df_drug["efficacy"].dropna().iloc[0]) if "efficacy" in df_drug.columns and df_drug["efficacy"].notna().any() else np.nan
    # min_count=1 keeps an all-NaN ae_prob column as NaN instead of summing it to 0.0
    ae_totalprob_real = float(df_drug["ae_prob"].sum(min_count=1)) if "ae_prob" in df_drug.columns else np.nan

    return {
        "safety_real": safety_real,
        "efficacy_real": efficacy_real,
        "ae_totalprob_real": ae_totalprob_real,
    }


# =========================
# Build master table
# =========================

def build_master_table(
    cohort_dir: Path,
    cohort: str,
    ground_truth_dir: Path,
) -> Tuple[pd.DataFrame, Dict[str, float]]:
    """Build master table with all metrics and avoid rankings."""
    ae_df = load_ae_weighted_scores(cohort_dir, cohort)
    s_df = load_ranked_safety(cohort_dir, cohort)
    e_df = load_ranked_efficacy(cohort_dir, cohort)
    c_df = load_ranked_composite(cohort_dir, cohort)

    print(f"[INFO] {cohort}: Loaded AE={len(ae_df)}, Safety={len(s_df)}, Efficacy={len(e_df)}, Composite={len(c_df)} rows")

    # Union of drug keys
    keys = set()
    for df in [ae_df, s_df, e_df, c_df]:
        if not df.empty and "drug_key" in df.columns:
            keys |= set(df["drug_key"].dropna().astype(str).tolist())

    master = pd.DataFrame({"drug_key": sorted([k for k in keys if k != ""])})
    master["drug"] = master["drug_key"]

    def _merge(master_df: pd.DataFrame, df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
        if df.empty:
            for c in cols:
                if c not in master_df.columns:
                    master_df[c] = np.nan
            return master_df
        tmp = df.copy()
        if "drug" in tmp.columns:
            disp = tmp[["drug_key", "drug"]].drop_duplicates("drug_key")
            master_df = master_df.merge(disp, on="drug_key", how="left", suffixes=("", "_y"))
            master_df["drug"] = master_df["drug_y"].combine_first(master_df["drug"])
            master_df = master_df.drop(columns=["drug_y"], errors="ignore")
        add = tmp[["drug_key"] + [c for c in cols if c in tmp.columns]].drop_duplicates("drug_key")
        return master_df.merge(add, on="drug_key", how="left")

    master = _merge(master, ae_df, ["ae_score", "total_prob", "total_weighted_prob", "n_ae"])
    master = _merge(master, s_df, ["safety_drug"])
    master = _merge(master, e_df, ["P_drug_gt_placebo"])
    master = _merge(master, c_df, [
        "safety_score", "efficacy_score", "ae_score", "safety_norm", "efficacy_norm", "ae_inverse_norm", "composite_score"
    ])

    # Consolidate columns
    if "safety_drug" in master.columns and "safety_score" in master.columns:
        master["safety_used"] = master["safety_drug"].combine_first(master["safety_score"])
    elif "safety_drug" in master.columns:
        master["safety_used"] = master["safety_drug"]
    elif "safety_score" in master.columns:
        master["safety_used"] = master["safety_score"]
    else:
        master["safety_used"] = np.nan

    if "P_drug_gt_placebo" in master.columns and "efficacy_score" in master.columns:
        master["efficacy_used"] = master["P_drug_gt_placebo"].combine_first(master["efficacy_score"])
    elif "P_drug_gt_placebo" in master.columns:
        master["efficacy_used"] = master["P_drug_gt_placebo"]
    elif "efficacy_score" in master.columns:
        master["efficacy_used"] = master["efficacy_score"]
    else:
        master["efficacy_used"] = np.nan

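    # AE: prefer the weighted ae_score from the AE file, fall back to total_prob,
    # then to the composite table's ae_score. Both the AE file and the composite
    # table carry an "ae_score" column, so the merges above can leave them as
    # ae_score_x (AE file) and ae_score_y (composite) instead of a bare "ae_score".
    ae_file_col = "ae_score_x" if "ae_score_x" in master.columns else "ae_score"
    if ae_file_col in master.columns and master[ae_file_col].notna().any():
        master["ae_used"] = master[ae_file_col]
        master["ae_units"] = "weighted_ae_score"
    elif "total_prob" in master.columns and master["total_prob"].notna().any():
        master["ae_used"] = master["total_prob"]
        master["ae_units"] = "unweighted_total_prob"
    elif "ae_score_y" in master.columns and master["ae_score_y"].notna().any():
        master["ae_used"] = master["ae_score_y"]
        master["ae_units"] = "composite_ae_score"
    else:
        master["ae_used"] = np.nan
        master["ae_units"] = "unknown"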

    master["composite_used"] = master.get("composite_score", np.nan)

    # Avoid ranks
    master["avoid_rank_ae"] = _rank_worst_first(master["ae_used"], higher_is_worse=True)
    master["avoid_rank_safety"] = _rank_worst_first(master["safety_used"], higher_is_worse=False)
    master["avoid_rank_efficacy"] = _rank_worst_first(master["efficacy_used"], higher_is_worse=False)
    master["avoid_rank_composite"] = _rank_worst_first(master["composite_used"], higher_is_worse=False)

    # Ground truth baseline
    gt_df = load_ground_truth(ground_truth_dir, cohort)
    baseline = summarize_real_from_ground_truth(gt_df)

    # Delta metrics
    master["delta_safety_vs_real"] = master["safety_used"] - baseline["safety_real"]
    master["delta_efficacy_vs_real"] = master["efficacy_used"] - baseline["efficacy_real"]
    master["delta_ae_vs_real"] = master["ae_used"] - baseline["ae_totalprob_real"]
    
    baseline["ae_real_used"] = baseline["ae_totalprob_real"]
    baseline["ae_real_units"] = "unweighted_total_prob (ground truth)"

    return master, baseline


# =========================
# Cutoff logic
# =========================

def make_primary_avoid_list(master: pd.DataFrame, q: float) -> pd.DataFrame:
    """Primary avoid list: drugs whose composite score is at or below the q-quantile (lowest = worst)."""
    df = master.copy()
    thr = _quantile(df["composite_used"], q)
    df["flag_primary_composite_worst"] = pd.to_numeric(df["composite_used"], errors="coerce") <= thr
    out = df[df["flag_primary_composite_worst"]].copy()
    return out.sort_values(["avoid_rank_composite", "avoid_rank_ae", "avoid_rank_safety", "avoid_rank_efficacy"])


def make_secondary_guard_list(master: pd.DataFrame, q: float, require_at_least: int = 1) -> pd.DataFrame:
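    """Secondary guard list: drugs with at least `require_at_least` extremes (AE at or above the (1 - q)-quantile; safety or efficacy at or below the q-quantile)."""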
    df = master.copy()
    ae_thr = _quantile(df["ae_used"], 1 - q)
    s_thr = _quantile(df["safety_used"], q)
    e_thr = _quantile(df["efficacy_used"], q)

    df["flag_ae_extreme"] = pd.to_numeric(df["ae_used"], errors="coerce") >= ae_thr
    df["flag_safety_extreme"] = pd.to_numeric(df["safety_used"], errors="coerce") <= s_thr
    df["flag_efficacy_extreme"] = pd.to_numeric(df["efficacy_used"], errors="coerce") <= e_thr
    df["n_extremes"] = df[["flag_ae_extreme", "flag_safety_extreme", "flag_efficacy_extreme"]].sum(axis=1)

    out = df[df["n_extremes"] >= require_at_least].copy()
    return out.sort_values(["n_extremes", "avoid_rank_ae", "avoid_rank_safety", "avoid_rank_efficacy"],
                           ascending=[False, True, True, True])


def make_delta_avoid_list(master: pd.DataFrame, q: float) -> pd.DataFrame:
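    """Delta-based avoid list: drugs whose delta versus the ground-truth baseline falls in the worst q tail on at least one metric."""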
    df = master.copy()
    ae_thr = _quantile(df["delta_ae_vs_real"], 1 - q)
    s_thr = _quantile(df["delta_safety_vs_real"], q)
    e_thr = _quantile(df["delta_efficacy_vs_real"], q)

    df["flag_delta_ae_worse"] = pd.to_numeric(df["delta_ae_vs_real"], errors="coerce") >= ae_thr
    df["flag_delta_safety_worse"] = pd.to_numeric(df["delta_safety_vs_real"], errors="coerce") <= s_thr
    df["flag_delta_efficacy_worse"] = pd.to_numeric(df["delta_efficacy_vs_real"], errors="coerce") <= e_thr
    df["n_delta_flags"] = df[["flag_delta_ae_worse", "flag_delta_safety_worse", "flag_delta_efficacy_worse"]].sum(axis=1)

    out = df[df["n_delta_flags"] >= 1].copy()
    return out.sort_values(["n_delta_flags", "delta_ae_vs_real", "delta_safety_vs_real", "delta_efficacy_vs_real"],
                           ascending=[False, False, True, True])


# =========================
# Display functions
# =========================

def display_outputs(cohort: str, outputs: Dict[str, Path], preview_rows: int = 5) -> None:
    """Display formatted output summary."""
    display_order = [("master", "MASTER"), ("primary", "PRIMARY"), ("secondary", "SECONDARY"),
                     ("delta", "DELTA"), ("baseline", "BASELINE")]
    
    for key, label in display_order:
        if key not in outputs:
            continue
        path = outputs[key]
        if not path.exists():
            continue
        
        try:
            df = pd.read_csv(path, sep="\t" if path.suffix.lower() == ".tsv" else ",")
        except Exception:
            continue
        
        print(f"\n{'='*100}")
        print(f"{label}  |  {path}")
        pd.set_option('display.max_columns', 15)
        pd.set_option('display.width', 120)
        print(df.head(preview_rows).to_string())
        print(f"\n[{len(df)} rows x {len(df.columns)} columns]")
    
    print(f"\n{'='*100}")
    print(f"SUMMARY FOR {cohort}")
    print(f"{'='*100}")
    
    for key, label in display_order:
        if key not in outputs:
            continue
        path = outputs[key]
        if path.exists():
            try:
                df = pd.read_csv(path, sep="\t" if path.suffix.lower() == ".tsv" else ",")
                print(f"{label.lower():<10}: {len(df):,} rows  ->  {path}")
            except Exception:
                pass


def display_all_cohorts(results: Dict[str, Dict[str, Path]], preview_rows: int = 5) -> None:
    """Display outputs for all cohorts."""
    for cohort, outputs in results.items():
        print(f"\n{'#'*100}")
        print(f"#  COHORT: {cohort}")
        print(f"{'#'*100}")
        display_outputs(cohort, outputs, preview_rows=preview_rows)


# =========================
# Runner
# =========================

@dataclass
class RunConfig:
    """Configuration for avoid list generation."""
    q_primary_composite: float = 0.10
    q_guard: float = 0.20
    q_sensitivity: List[float] = field(default_factory=lambda: [0.05, 0.10, 0.20])
    q_delta: float = 0.10
    guard_require_at_least: int = 1


def run_one_cohort(
    main_path: Path,
    cohort: str,
    ground_truth_dir: Path,
    out_dir: Optional[Path] = None,
    config: Optional[RunConfig] = None,
    display: bool = True,
) -> Dict[str, Path]:
    """Run one cohort and save outputs."""
    if config is None:
        config = RunConfig()

    cohort_dir = main_path / cohort
    if out_dir is None:
        out_dir = cohort_dir
    _ensure_dir(out_dir)

    print(f"\n{'='*80}")
    print(f"[INFO] Processing cohort: {cohort}")
    print(f"{'='*80}")

    master, baseline = build_master_table(cohort_dir, cohort, ground_truth_dir)

    out_master = out_dir / f"{cohort}_avoid_master_table.csv"
    master.to_csv(out_master, index=False)
    print(f"[INFO] Saved master table: {out_master.name} ({len(master)} drugs)")

    primary = make_primary_avoid_list(master, config.q_primary_composite)
    out_primary = out_dir / f"{cohort}_avoid_primary_composite_q{int(config.q_primary_composite*100):02d}.csv"
    primary.to_csv(out_primary, index=False)
    print(f"[INFO] Primary avoid list (q={config.q_primary_composite}): {len(primary)} drugs")

    secondary = make_secondary_guard_list(master, config.q_guard, require_at_least=config.guard_require_at_least)
    out_secondary = out_dir / f"{cohort}_avoid_secondary_guard_q{int(config.q_guard*100):02d}.csv"
    secondary.to_csv(out_secondary, index=False)
    print(f"[INFO] Secondary guard list (q={config.q_guard}): {len(secondary)} drugs")

    sens_paths: List[Path] = []
    for q in config.q_sensitivity:
        p = make_primary_avoid_list(master, q)
        s = make_secondary_guard_list(master, q, require_at_least=config.guard_require_at_least)
        out_p = out_dir / f"{cohort}_avoid_sensitivity_primary_q{int(q*100):02d}.csv"
        out_s = out_dir / f"{cohort}_avoid_sensitivity_secondary_q{int(q*100):02d}.csv"
        p.to_csv(out_p, index=False)
        s.to_csv(out_s, index=False)
        sens_paths.extend([out_p, out_s])
    print(f"[INFO] Sensitivity lists: {len(sens_paths)} files")

    delta = make_delta_avoid_list(master, config.q_delta)
    out_delta = out_dir / f"{cohort}_avoid_delta_q{int(config.q_delta*100):02d}.csv"
    delta.to_csv(out_delta, index=False)
    print(f"[INFO] Delta avoid list (q={config.q_delta}): {len(delta)} drugs")

    base_df = pd.DataFrame([{
        "cohort": cohort,
        "safety_real": baseline.get("safety_real", np.nan),
        "efficacy_real": baseline.get("efficacy_real", np.nan),
        "ae_real_used": baseline.get("ae_real_used", np.nan),
        "ae_real_units": baseline.get("ae_real_units", ""),
    }])
    out_base = out_dir / f"{cohort}_real_baseline_summary.csv"
    base_df.to_csv(out_base, index=False)

    outputs = {
        "master": out_master,
        "primary": out_primary,
        "secondary": out_secondary,
        "delta": out_delta,
        "baseline": out_base,
    }
    
    if display:
        display_outputs(cohort, outputs, preview_rows=5)
    
    return outputs


def main(
    main_path: Optional[Path] = None,
    cohorts: Optional[List[str]] = None,
    ground_truth_dir: Optional[Path] = None,
    out_dir: Optional[Path] = None,
    config: Optional[RunConfig] = None,
    display: bool = True,
) -> Dict[str, Dict[str, Path]]:
    """Main entry point."""
    if main_path is None:
        main_path = Path(
            "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/"
            "3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking"
        )
    
    if ground_truth_dir is None:
        ground_truth_dir = Path(
            "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/"
            "3_Third_Paper/SM/3_Ground_Truth"
        )
    
    if config is None:
        config = RunConfig()
    
    main_path = Path(main_path)
    ground_truth_dir = Path(ground_truth_dir)
    
    if not ground_truth_dir.exists():
        print(f"[ERROR] Ground truth directory not found: {ground_truth_dir}")
        return {}
    
    if cohorts is None:
        cohorts = sorted([p.name for p in main_path.glob("NCT*") if p.is_dir()])
    
    if not cohorts:
        print(f"[ERROR] No cohorts found under: {main_path}")
        return {}
    
    print(f"[INFO] Cohorts to process: {len(cohorts)}")
    
    results = {}
    for cohort in cohorts:
        cohort_out_dir = out_dir / cohort if out_dir else None
        try:
            outputs = run_one_cohort(
                main_path=main_path,
                cohort=cohort,
                ground_truth_dir=ground_truth_dir,
                out_dir=cohort_out_dir,
                config=config,
                display=display,
            )
            results[cohort] = outputs
        except Exception as e:
            print(f"[ERROR] Failed to process {cohort}: {e}")
            import traceback
            traceback.print_exc()
    
    print(f"\n{'='*80}")
    print(f"[INFO] Done. Processed {len(results)}/{len(cohorts)} cohorts successfully.")
    print(f"{'='*80}")
    
    return results


if __name__ == "__main__":
    main()

[INFO] Cohorts to process: 2

[INFO] Processing cohort: NCT04809974
[INFO] Loaded 1626 drug rows from AE weighted scores
[INFO] NCT04809974: Loaded AE=1626, Safety=1625, Efficacy=1625, Composite=1625 rows
[INFO] Saved master table: NCT04809974_avoid_master_table.csv (1626 drugs)
[INFO] Primary avoid list (q=0.1): 163 drugs
[INFO] Secondary guard list (q=0.2): 850 drugs
[INFO] Sensitivity lists: 6 files
[INFO] Delta avoid list (q=0.1): 456 drugs

MASTER  |  /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_avoid_master_table.csv
        drug_key           drug  ae_score_x  total_prob  total_weighted_prob  n_ae  safety_drug  P_drug_gt_placebo  safety_score  efficacy_score  ae_score_y  safety_norm  efficacy_norm  ae_inverse_norm  composite_score  safety_used  efficacy_used    ae_used               ae_units  composite_used  avoid_rank_ae  avoid_rank_safety  avoid_rank_efficacy  avoid_rank_comp
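
The avoid lists produced by the cell above are quantile cutoffs on the master table. The primary list keeps drugs whose composite score is at or below the `q`-quantile, the AE guard flags drugs at or above the `(1 - q)`-quantile, and worst-first ranks give rank 1 to the worst drug. The sketch below reproduces those cutoffs on toy data; the drug names and scores are made up, and only the column names `composite_used`, `ae_used`, and `avoid_rank_composite` are taken from the notebook.

```python
import pandas as pd

# Toy master table (values are illustrative, not PlaNet output).
toy = pd.DataFrame({
    "drug": [f"drug_{i}" for i in range(10)],
    "composite_used": [0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90],
    "ae_used": [10, 1, 2, 3, 4, 5, 6, 7, 8, 9],
})

q = 0.20

# Worst-first rank: the lowest composite score gets rank 1.
toy["avoid_rank_composite"] = toy["composite_used"].rank(ascending=True, method="min")

# Primary avoid list: composite at or below the q-quantile.
composite_thr = toy["composite_used"].quantile(q)
primary = toy[toy["composite_used"] <= composite_thr].sort_values("avoid_rank_composite")

# AE guard: higher AE is worse, so flag the upper tail instead.
ae_thr = toy["ae_used"].quantile(1 - q)
toy["flag_ae_extreme"] = toy["ae_used"] >= ae_thr
ae_flagged = toy[toy["flag_ae_extreme"]]

print("Primary avoid list:", primary["drug"].tolist())
print("AE extremes:", ae_flagged["drug"].tolist())
```

Note that `drug_0` appears in both lists here, which is the situation the secondary list's `n_extremes` count is designed to surface.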

## **Filtering**
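
The filtering script in this section describes two shortlists that differ only in whether a top rank alone is enough to keep a drug. The decision rule can be sketched as a small function. The flag names `kept_by_top`, `kept_by_hints`, `hard_whitelist`, and `hard_blocked` come from the script's docstring; the function name, the `soft_blocked` argument, the status strings, and the choice to drop ineligible soft-blocked drugs are assumptions made for illustration and may differ from the script itself.

```python
def shortlist_status(
    kept_by_top: bool,
    kept_by_hints: bool,
    hard_whitelist: bool,
    hard_blocked: bool,
    soft_blocked: bool,
    mode: str = "winners",
) -> str:
    """Return "KEEP", "REVIEW", or "DROP" for one drug (hypothetical labels)."""
    if mode not in ("winners", "phenotype"):
        raise ValueError(f"Unknown mode: {mode}")

    # A hard block always wins, even over the whitelist.
    if hard_blocked:
        return "DROP"

    # Phenotype mode ignores drugs that were kept only for being top-ranked.
    eligible = kept_by_hints or hard_whitelist
    if mode == "winners":
        eligible = eligible or kept_by_top
    if not eligible:
        return "DROP"

    # Soft-blocked drugs are not dropped; they are sent for manual review.
    return "REVIEW" if soft_blocked else "KEEP"


# A drug kept only because of its rank survives in winners mode only.
print(shortlist_status(True, False, False, False, False, mode="winners"))
print(shortlist_status(True, False, False, False, False, mode="phenotype"))
```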

In [63]:
#!/usr/bin/env python3
"""
lc_prefilter_and_viability_filter_all_cohorts.py

Per-cohort filtering of a ranked PlaNet composite table to produce TWO LC lists:
1) "Top-ranked winners" shortlist:
   KEEP if (kept_by_top OR kept_by_hints OR hard_whitelist) AND not hard_blocked
   SOFT BLOCK -> REVIEW

2) "Phenotype-bucket LC" shortlist:
   KEEP if (kept_by_hints OR hard_whitelist) AND not hard_blocked
   SOFT BLOCK -> REVIEW
   (This excludes drugs kept only because they were top-ranked.)

Input (per cohort folder):
  - <COHORT>_ranked_drugs_by_composite_SAE.tsv

Outputs (per cohort folder):
  - <COHORT>_ranked_drugs_by_composite_SAE_LC_winners.tsv
  - <COHORT>_ranked_drugs_by_composite_SAE_LC_phenotype.tsv

Python 3.8+ compatible.
"""

from __future__ import annotations

from pathlib import Path
import re
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple, Set, Optional


# =============================================================================
# 0) PATHS / COHORT DISCOVERY
# =============================================================================

SCORE_RANKING_ROOT = Path(
    "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking"
)

# If empty -> auto-discover cohort folders under SCORE_RANKING_ROOT (NCT*)
COHORTS_TO_RUN: List[str] = []  # e.g. ["NCT04809974", "NCT05576662"]

# Input filename template inside each cohort folder
INPUT_FILENAME_TEMPLATE = "{cohort}_ranked_drugs_by_composite_SAE.tsv"

# Output suffixes (we append to INPUT stem)
WINNERS_SUFFIX = "_LC_winners.tsv"
PHENOTYPE_SUFFIX = "_LC_phenotype.tsv"


# =============================================================================
# 1) FAST PREFILTER SETTINGS
# =============================================================================
TOP_N = 300                  # keeps only top N by composite_score before applying rules
ALWAYS_KEEP_TOP_N = 60       # winners-mode: keep top-N ranks (unless HARD BLOCKED)

# Output control:
OUTPUT_INCLUDE_DROPPED_WINNERS = False   # if True, output ALL rows with winners_status labels
OUTPUT_INCLUDE_DROPPED_PHENO = False     # if True, output ALL rows with phenotype_status labels

# Optional threshold gates (set to None to disable) applied after TOP_N
MIN_SAFETY_NORM = None       # e.g. 0.20
MIN_EFFICACY_NORM = None     # e.g. 0.55
MIN_AE_INV_NORM = None       # e.g. 0.15

# Optional Pareto gate (non-dominated) on the *_norm columns (after TOP_N)
USE_PARETO = False
PARETO_COLS = ("safety_norm", "efficacy_norm", "ae_inverse_norm")  # maximize all


# =============================================================================
# 2) RULES: HARD BLOCK / SOFT BLOCK / KEEP HINTS
# =============================================================================
def _rx(pattern: str) -> re.Pattern:
    return re.compile(pattern, flags=re.IGNORECASE | re.VERBOSE)

HARD_BLOCK_PATTERNS: Dict[str, str] = {
    "endocrine_axis": r"""
        \b(
            goserelin|cetrorelix|buserelin|triptorelin|degarelix|
            protirelin|
            anastrozole|letrozole|exemestane|
            dutasteride|finasteride|
            abiraterone|enzalutamide
        )\b
    """,
    "oncology_targeted_cytotoxic": r"""
        \b(
            pentostatin|cladribine|
            plerixafor|
            bexarotene|
            \w*(?:tinib|ciclib|parib|lisib)
        )\b
    """,
    "diagnostic_or_procedural": r"""
        \b(methacholine)\b
    """,
    "withdrawn_or_unacceptable": r"""
        \b(sitaxentan|tolrestat)\b
    """,
    "acute_iv_only": r"""
        \b(
            terlipressin|
            norepinephrine|epinephrine|dopamine|
            rocuronium|succinylcholine|neostigmine|
            propofol|ketamine
        )\b
    """,
    "ophthalmic_topical": r"""
        \b(latanoprost|timolol|dorzolamide)\b
    """,
    "excipient_simple_acid": r"""
        \b(citric\s*acid)\b
    """,
}

SOFT_BLOCK_PATTERNS: Dict[str, str] = {
    "immunosuppression_higher_risk": r"""
        \b(teriflunomide|mycophenolic\s*acid|mycophenolate|tacrolimus|cyclosporine)\b
    """,
    "hepatic_interactions_monitoring": r"""
        \b(bosentan)\b
    """,
}

KEEP_HINTS: Dict[str, str] = {
    "pulmonary_vascular_endothelial": r"""
        \b(iloprost|treprostinil|selexipag|vardenafil|sildenafil|tadalafil|bosentan)\b
    """,
    "respiratory_antiinflammatory": r"""
        \b(roflumilast|montelukast|pranlukast|nedocromil|cromoglicic\s*acid|salbutamol|
           terbutaline|fenoterol|theophylline|aminophylline)\b
    """,
    "neuro_cognitive_psych": r"""
        \b(sertraline|fluoxetine|paroxetine|bupropion|agomelatine|modafinil|
           ramelteon|flumazenil|pimavanserin)\b
    """,
    "bile_acids_metabolic_fibrotic": r"""
        \b(ursodeoxycholic|chenodeoxycholic|cholic\s*acid|deoxycholic\s*acid|
           obeticholic|telmisartan)\b
    """,
    "low_risk_supportive": r"""
        \b(acetylcarnitine|levocarnitine|tyrosine|cystine|methionine|
           glycine\s*betaine|resveratrol|niacin|riboflavin|biotin)\b
    """,
    "autonomic_pots": r"""
        \b(ivabradine|midodrine|fludrocortisone|propranolol|pyridostigmine|droxidopa)\b
    """,
}

HARD_WHITELIST: Set[str] = set([
    # "iloprost",
])

WHITELIST_OVERRIDES_SOFT_BLOCK = True


# =============================================================================
# 3) INTERNAL HELPERS
# =============================================================================
HARD_BLOCK_RX: Dict[str, re.Pattern] = {k: _rx(v) for k, v in HARD_BLOCK_PATTERNS.items()}
SOFT_BLOCK_RX: Dict[str, re.Pattern] = {k: _rx(v) for k, v in SOFT_BLOCK_PATTERNS.items()}
KEEP_RX: Dict[str, re.Pattern] = {k: _rx(v) for k, v in KEEP_HINTS.items()}

SEP = "\t"


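def normalize_name(x: str) -> str:
    """Lower-case and trim a drug name (used for whitelist lookups)."""
    return str(x).strip().lower()


def match_name_from_norm(drug_norm: str) -> str:
    """Turn underscores into spaces so multi-word patterns (e.g. 'citric acid') can match."""
    return drug_norm.replace("_", " ")


def find_reasons(match_name: str) -> Tuple[List[str], List[str], List[str]]:
    """Return the (hard-block, soft-block, keep-hint) rule names whose regex matches the name."""
    hard = [k for k, rx in HARD_BLOCK_RX.items() if rx.search(match_name)]
    soft = [k for k, rx in SOFT_BLOCK_RX.items() if rx.search(match_name)]
    keep = [k for k, rx in KEEP_RX.items() if rx.search(match_name)]
    return hard, soft, keep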


def to_num(df: pd.DataFrame, col: str) -> None:
    if col in df.columns:
        df[col] = pd.to_numeric(df[col], errors="coerce")


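def pareto_front(df: pd.DataFrame, cols: Tuple[str, ...]) -> pd.Series:
    """
    Boolean mask of non-dominated rows, maximising every column in `cols`.

    Row i is dropped if some other row j is >= on all columns and > on at least one.
    Rows containing NaN are never dominated and never dominate, because NaN
    comparisons evaluate to False. Runs in O(n^2), which is fine for TOP_N-sized inputs.
    """
    X = df.loc[:, list(cols)].to_numpy(dtype=float)
    n = X.shape[0]
    keep = np.ones(n, dtype=bool)

    for i in range(n):
        if not keep[i]:
            continue
        for j in range(n):
            if i == j:
                continue
            # Row j dominates row i: drop i and stop scanning.
            if np.all(X[j] >= X[i]) and np.any(X[j] > X[i]):
                keep[i] = False
                break
    return pd.Series(keep, index=df.index)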


def _apply_status_logic(df: pd.DataFrame, mode: str) -> pd.DataFrame:
    out = df.copy()

    if mode not in ("winners", "phenotype"):
        raise ValueError(f"Unknown mode: {mode}")

    if mode == "winners":
        out["lc_keep"] = (
            (out["hard_whitelist"] | out["kept_by_hints"] | out["kept_by_top"])
            & (~out["hard_blocked"])
        )
        status_col = "lc_status"
        out[status_col] = "DROP"
        out.loc[out["lc_keep"], status_col] = "KEEP"
        out.loc[out["lc_keep"] & out["soft_blocked"], status_col] = "REVIEW"
        if WHITELIST_OVERRIDES_SOFT_BLOCK:
            out.loc[out["hard_whitelist"] & (out[status_col] == "REVIEW"), status_col] = "KEEP"
    else:
        out["lc_keep_pheno"] = (
            (out["hard_whitelist"] | out["kept_by_hints"])
            & (~out["hard_blocked"])
        )
        status_col = "lc_status_pheno"
        out[status_col] = "DROP"
        out.loc[out["lc_keep_pheno"], status_col] = "KEEP"
        out.loc[out["lc_keep_pheno"] & out["soft_blocked"], status_col] = "REVIEW"
        if WHITELIST_OVERRIDES_SOFT_BLOCK:
            out.loc[out["hard_whitelist"] & (out[status_col] == "REVIEW"), status_col] = "KEEP"

    return out


def _select_and_sort_output(df: pd.DataFrame, mode: str, include_dropped: bool) -> pd.DataFrame:
    status_col = "lc_status" if mode == "winners" else "lc_status_pheno"

    out = df.copy() if include_dropped else df[df[status_col].isin(["KEEP", "REVIEW"])].copy()

    has_comp = ("composite_score" in out.columns) and out["composite_score"].notna().any()

    if mode == "winners":
        if has_comp:
            out = out.sort_values([status_col, "composite_score", "rank_num"],
                                  ascending=[True, False, True])
        else:
            out = out.sort_values([status_col, "rank_num"], ascending=[True, True])
    else:
        if "keep_hints" in out.columns:
            out["_first_bucket"] = out["keep_hints"].apply(lambda x: str(x).split(";")[0] if str(x) else "")
        else:
            out["_first_bucket"] = ""

        if has_comp:
            out = out.sort_values([status_col, "_first_bucket", "composite_score", "rank_num"],
                                  ascending=[True, True, False, True])
        else:
            out = out.sort_values([status_col, "_first_bucket", "rank_num"],
                                  ascending=[True, True, True])

        out = out.drop(columns=["_first_bucket"], errors="ignore")

    return out


def infer_cohorts(score_root: Path) -> List[str]:
    return sorted([p.name for p in score_root.glob("NCT*") if p.is_dir()])


def build_paths_for_cohort(cohort: str) -> Dict[str, Path]:
    cohort_dir = SCORE_RANKING_ROOT / cohort
    inp = cohort_dir / INPUT_FILENAME_TEMPLATE.format(cohort=cohort)

    winners_out = cohort_dir / (inp.stem + WINNERS_SUFFIX)
    pheno_out = cohort_dir / (inp.stem + PHENOTYPE_SUFFIX)

    return {"cohort_dir": cohort_dir, "input": inp, "winners": winners_out, "pheno": pheno_out}


# =============================================================================
# 4) PER-COHORT RUN
# =============================================================================
def run_for_one_cohort(cohort: str) -> Optional[Tuple[pd.DataFrame, pd.DataFrame]]:
    paths = build_paths_for_cohort(cohort)

    if not paths["cohort_dir"].is_dir():
        print(f"[WARN] {cohort}: cohort folder not found: {paths['cohort_dir']}")
        return None

    if not paths["input"].is_file():
        print(f"[WARN] {cohort}: missing composite input:")
        print(f"       - expected at: {paths['input']}")
        return None

    INPUT_TSV = paths["input"]
    OUTPUT_WINNERS_TSV = paths["winners"]
    OUTPUT_PHENOTYPE_TSV = paths["pheno"]

    # Load
    df = pd.read_csv(INPUT_TSV, sep=SEP, dtype=str)
    df.columns = [c.strip() for c in df.columns]

    if "drug" not in df.columns:
        raise ValueError(f"{cohort}: Expected a 'drug' column. Found columns: {list(df.columns)}")

    # Numeric columns (if present)
    for c in [
        "composite_score",
        "safety_score", "efficacy_score", "ae_score",
        "safety_norm", "efficacy_norm", "ae_inverse_norm",
        "rank_num", "rank"
    ]:
        to_num(df, c)

    # Rank handling
    if "rank_num" in df.columns and df["rank_num"].notna().any():
        pass
    elif "rank" in df.columns and df["rank"].notna().any():
        df["rank_num"] = df["rank"]
    else:
        if "composite_score" not in df.columns:
            raise ValueError(f"{cohort}: Need either rank/rank_num or composite_score to rank rows.")
        df = df.sort_values("composite_score", ascending=False).reset_index(drop=True)
        df["rank_num"] = np.arange(1, len(df) + 1)

    # Prefilter: TOP_N by composite_score if available, else by rank_num
    if "composite_score" in df.columns and df["composite_score"].notna().any():
        df = df.sort_values("composite_score", ascending=False).head(TOP_N).copy()
    else:
        df = df.sort_values("rank_num", ascending=True).head(TOP_N).copy()

    # Optional threshold gates
    if MIN_SAFETY_NORM is not None and "safety_norm" in df.columns:
        df = df[df["safety_norm"] >= MIN_SAFETY_NORM].copy()
    if MIN_EFFICACY_NORM is not None and "efficacy_norm" in df.columns:
        df = df[df["efficacy_norm"] >= MIN_EFFICACY_NORM].copy()
    if MIN_AE_INV_NORM is not None and "ae_inverse_norm" in df.columns:
        df = df[df["ae_inverse_norm"] >= MIN_AE_INV_NORM].copy()

    # Optional Pareto gate
    if USE_PARETO:
        missing = [c for c in PARETO_COLS if c not in df.columns]
        if missing:
            print(f"[WARN] {cohort}: Pareto requested but missing columns: {missing}. Skipping Pareto.")
        else:
            mask = pareto_front(df, PARETO_COLS)
            keep_top = df["rank_num"] <= ALWAYS_KEEP_TOP_N
            df = df[mask | keep_top].copy()

    # Normalize drug names
    df["drug_norm"] = df["drug"].apply(normalize_name)
    df["drug_match"] = df["drug_norm"].apply(match_name_from_norm)

    # Reasons
    reasons = df["drug_match"].apply(find_reasons)
    df["hard_block_reasons"] = reasons.apply(lambda x: ";".join(x[0]) if x[0] else "")
    df["soft_block_reasons"] = reasons.apply(lambda x: ";".join(x[1]) if x[1] else "")
    df["keep_hints"] = reasons.apply(lambda x: ";".join(x[2]) if x[2] else "")

    df["hard_blocked"] = df["hard_block_reasons"].ne("")
    df["soft_blocked"] = df["soft_block_reasons"].ne("")
    df["kept_by_hints"] = df["keep_hints"].ne("")
    df["kept_by_top"] = df["rank_num"].le(ALWAYS_KEEP_TOP_N)
    df["hard_whitelist"] = df["drug_norm"].isin(HARD_WHITELIST)

    # Winners mode
    df_w = _apply_status_logic(df, mode="winners")
    out_w = _select_and_sort_output(df_w, mode="winners", include_dropped=OUTPUT_INCLUDE_DROPPED_WINNERS)
    OUTPUT_WINNERS_TSV.parent.mkdir(parents=True, exist_ok=True)
    out_w.to_csv(OUTPUT_WINNERS_TSV, sep=SEP, index=False)

    # Phenotype mode
    df_p = _apply_status_logic(df, mode="phenotype")
    out_p = _select_and_sort_output(df_p, mode="phenotype", include_dropped=OUTPUT_INCLUDE_DROPPED_PHENO)
    OUTPUT_PHENOTYPE_TSV.parent.mkdir(parents=True, exist_ok=True)
    out_p.to_csv(OUTPUT_PHENOTYPE_TSV, sep=SEP, index=False)

    # Summary
    input_rows = len(pd.read_csv(INPUT_TSV, sep=SEP))
    print("\n" + "=" * 88)
    print(f"[INFO] Cohort: {cohort}")
    print("[INFO] LC shortlist filtering complete (TWO OUTPUTS)")
    print(f"  Input file:                        {INPUT_TSV}")
    print(f"  Input rows loaded:                 {input_rows:,}")
    print(f"  After TOP_N={TOP_N}:               {len(df):,}")
    print(f"  HARD blocked (within TOP_N):       {int(df['hard_blocked'].sum()):,}")
    print(f"  SOFT blocked (within TOP_N):       {int(df['soft_blocked'].sum()):,}")
    print(f"  Kept by hints (within TOP_N):      {int(df['kept_by_hints'].sum()):,}")
    print(f"  Kept by top-{ALWAYS_KEEP_TOP_N}:   {int(df['kept_by_top'].sum()):,}")

    print("\n[INFO] Winners output")
    print(f"  Output file:                       {OUTPUT_WINNERS_TSV}")
    print(f"  KEEP count (within TOP_N):         {int((df_w['lc_status'] == 'KEEP').sum()):,}")
    print(f"  REVIEW count (within TOP_N):       {int((df_w['lc_status'] == 'REVIEW').sum()):,}")
    print(f"  Final output rows:                 {len(out_w):,}")

    print("\n[INFO] Phenotype output")
    print(f"  Output file:                       {OUTPUT_PHENOTYPE_TSV}")
    print(f"  KEEP count (within TOP_N):         {int((df_p['lc_status_pheno'] == 'KEEP').sum()):,}")
    print(f"  REVIEW count (within TOP_N):       {int((df_p['lc_status_pheno'] == 'REVIEW').sum()):,}")
    print(f"  Final output rows:                 {len(out_p):,}")

    # Preview
    cols_preview = [c for c in [
        "rank_num", "drug",
        "lc_status", "lc_status_pheno",
        "keep_hints", "hard_block_reasons", "soft_block_reasons",
        "safety_score", "efficacy_score", "ae_score", "composite_score"
    ] if c in df_w.columns or c in df_p.columns]

    print("\n[INFO] Preview (winners, top 30):")
    _cols_w = [c for c in cols_preview if c in out_w.columns]
    if _cols_w:
        print(out_w[_cols_w].head(30).to_string(index=False))

    print("\n[INFO] Preview (phenotype, top 30):")
    _cols_p = [c for c in cols_preview if c in out_p.columns]
    if _cols_p:
        print(out_p[_cols_p].head(30).to_string(index=False))

    return out_w, out_p


# =============================================================================
# 5) MAIN
# =============================================================================
def main() -> None:
    if not SCORE_RANKING_ROOT.is_dir():
        raise NotADirectoryError(f"SCORE_RANKING_ROOT not found: {SCORE_RANKING_ROOT}")

    cohorts = COHORTS_TO_RUN if COHORTS_TO_RUN else infer_cohorts(SCORE_RANKING_ROOT)

    if not cohorts:
        print(f"[ERROR] No cohorts found under: {SCORE_RANKING_ROOT}")
        return

    print(f"[INFO] Cohorts to process: {len(cohorts)}")
    for c in cohorts:
        print(f"  - {c}")

    ok = 0
    for cohort in cohorts:
        res = run_for_one_cohort(cohort)
        if res is not None:
            ok += 1

    print("\n" + "=" * 88)
    print("[INFO] Done.")
    print(f"[INFO] Cohorts processed successfully: {ok}/{len(cohorts)}")


if __name__ == "__main__":
    main()

[INFO] Cohorts to process: 2
  - NCT04809974
  - NCT04880161

[INFO] Cohort: NCT04809974
[INFO] LC shortlist filtering complete (TWO OUTPUTS)
  Input file:                        /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_ranked_drugs_by_composite_SAE.tsv
  Input rows loaded:                 1,625
  After TOP_N=300:               300
  HARD blocked (within TOP_N):       52
  SOFT blocked (within TOP_N):       6
  Kept by hints (within TOP_N):      42
  Kept by top-60:   60

[INFO] Winners output
  Output file:                       /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_ranked_drugs_by_composite_SAE_LC_winners.tsv
  KEEP count (within TOP_N):         54
  REVIEW count (within TOP_N):       5
  Final output rows:                 59

[INFO] Phenotype output
  Output file:      
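
The KEEP / REVIEW / DROP rules used for the two shortlists can be checked on a handful of rows. The following is a minimal sketch, assuming `WHITELIST_OVERRIDES_SOFT_BLOCK = True` as configured above; `lc_status` is a hypothetical condensed helper (not the notebook's `_apply_status_logic`), and the flag values are hand-set for illustration rather than derived from the regex rules.

```python
import pandas as pd

def lc_status(df: pd.DataFrame, use_top: bool) -> pd.Series:
    """Return KEEP / REVIEW / DROP per row from the boolean flag columns."""
    keep = df["hard_whitelist"] | df["kept_by_hints"]
    if use_top:  # winners mode also keeps top-ranked rows
        keep = keep | df["kept_by_top"]
    keep = keep & ~df["hard_blocked"]

    status = pd.Series("DROP", index=df.index)
    status.loc[keep] = "KEEP"
    # Soft-blocked rows are downgraded to REVIEW unless they are whitelisted.
    status.loc[keep & df["soft_blocked"] & ~df["hard_whitelist"]] = "REVIEW"
    return status

flags = pd.DataFrame({
    "drug":           ["sildenafil", "bosentan", "imatinib", "top_ranked_only"],
    "kept_by_top":    [False, False, True, True],
    "kept_by_hints":  [True, True, False, False],
    "hard_blocked":   [False, False, True, False],
    "soft_blocked":   [False, True, False, False],
    "hard_whitelist": [False, False, False, False],
})
flags["winners"] = lc_status(flags, use_top=True)
flags["phenotype"] = lc_status(flags, use_top=False)
print(flags[["drug", "winners", "phenotype"]].to_string(index=False))
```

In this toy table the hint-matched row is kept in both modes, the row that is both hint-matched and soft-blocked lands in REVIEW, the hard-blocked row is dropped even though it is top-ranked, and the row kept only by rank survives in winners mode but not in phenotype mode.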

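The optional Pareto gate (`USE_PARETO`) keeps rows that no other row beats on every normalised score. As a rough illustration of that dominance test, here is a vectorised NumPy sketch; `pareto_mask` and the score values are hypothetical and are meant only to show which rows survive, assuming all columns are maximised.

```python
import numpy as np

def pareto_mask(X: np.ndarray) -> np.ndarray:
    """True for rows not dominated by any other row (all columns maximised)."""
    # ge[j, i]: row j is >= row i in every column.
    ge = (X[:, None, :] >= X[None, :, :]).all(axis=2)
    # gt[j, i]: row j is strictly > row i in at least one column.
    gt = (X[:, None, :] > X[None, :, :]).any(axis=2)
    dominated = (ge & gt).any(axis=0)
    return ~dominated

# Columns: safety_norm, efficacy_norm, ae_inverse_norm (made-up values).
scores = np.array([
    [0.9, 0.2, 0.5],   # best safety
    [0.4, 0.8, 0.6],   # best efficacy
    [0.3, 0.7, 0.5],   # dominated by the row above
    [0.9, 0.2, 0.5],   # duplicate of the first row; neither dominates the other
])
print(pareto_mask(scores))
```

The pairwise comparison allocates an n-by-n-by-d boolean array, so it suits inputs on the order of `TOP_N` rows rather than the full drug table.
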
## **Wet-lab Enrichment**

In [16]:
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

import pandas as pd
import requests
import warnings

warnings.filterwarnings("ignore")


# =============================================================================
# CONFIG
# =============================================================================

class Config:
    PUBCHEM_BASE = "https://pubchem.ncbi.nlm.nih.gov/rest/pug"
    CHEMBL_BASE = "https://www.ebi.ac.uk/chembl/api/data"
    OPENTARGETS_BASE = "https://api.platform.opentargets.org/api/v4/graphql"
    DAILYMED_BASE = "https://dailymed.nlm.nih.gov/dailymed/services/v2"

    REQUEST_DELAY = 0.5
    MAX_RETRIES = 3
    TIMEOUT = 30


# =============================================================================
# SMALL UTILS
# =============================================================================

def safe_to_float(x: Any) -> Optional[float]:
    """Convert value to float if possible; otherwise return None."""
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    if isinstance(x, str):
        s = x.strip()
        if s == "":
            return None
        s = s.replace(",", "")
        try:
            return float(s)
        except Exception:
            return None
    return None


def normalize_drug_key(x: Any) -> str:
    """Stable key for resume logic."""
    if x is None:
        return ""
    s = str(x).strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s


def ensure_parent_dir(path: str) -> None:
    Path(path).parent.mkdir(parents=True, exist_ok=True)


def append_row_to_csv(csv_path: str, row: Dict[str, Any]) -> None:
    """
    Append one row to a CSV checkpoint. Creates the file (with header) if missing.
    """
    ensure_parent_dir(csv_path)
    df_row = pd.DataFrame([row])
    file_exists = os.path.exists(csv_path) and os.path.getsize(csv_path) > 0
    df_row.to_csv(csv_path, mode="a", header=not file_exists, index=False)


def dedupe_checkpoint(df_ckpt: pd.DataFrame) -> pd.DataFrame:
    """
    Keep the best/latest record per input drug.
    Preference:
      1) non-error over error
      2) later fetched_at_utc
    """
    if df_ckpt.empty:
        return df_ckpt

    # Work on a copy so the caller's frame is not mutated.
    df_ckpt = df_ckpt.copy()

    if "input_drug_key" not in df_ckpt.columns:
        if "input_drug" in df_ckpt.columns:
            df_ckpt["input_drug_key"] = df_ckpt["input_drug"].apply(normalize_drug_key)
        else:
            # No source column to derive the key from; fall back to an empty key.
            df_ckpt["input_drug_key"] = ""

    if "error" not in df_ckpt.columns:
        df_ckpt["error"] = None

    if "fetched_at_utc" not in df_ckpt.columns:
        df_ckpt["fetched_at_utc"] = None

    # Flag rows that carry a non-empty error message.
    df_ckpt["is_error"] = df_ckpt["error"].notna() & (df_ckpt["error"].astype(str).str.strip() != "")

    # Robust sort: fetched_at_utc could be missing / non-ISO; compare as strings,
    # with missing values as "" so they sort before any real timestamp.
    df_ckpt["_t"] = df_ckpt["fetched_at_utc"].fillna("").astype(str)

    # Sort so that the "best" row is last within each group:
    # error rows first (is_error descending), then oldest to newest.
    df_ckpt = df_ckpt.sort_values(by=["input_drug_key", "is_error", "_t"], ascending=[True, False, True])

    out = df_ckpt.groupby("input_drug_key", as_index=False).tail(1).drop(columns=["is_error", "_t"], errors="ignore")
    return out


def save_output(result_df: pd.DataFrame, output_file: str) -> None:
    ensure_parent_dir(output_file)
    if output_file.lower().endswith(".xlsx"):
        result_df.to_excel(output_file, index=False, engine="openpyxl")
    else:
        result_df.to_csv(output_file, index=False)


# =============================================================================
# DATA CLASSES
# =============================================================================

@dataclass
class PharmacologicalProfile:
    drug_name: str

    pubchem_cid: Optional[int] = None
    chembl_id: Optional[str] = None
    opentargets_id: Optional[str] = None

    cas_number: Optional[str] = None
    inchi_key: Optional[str] = None
    smiles: Optional[str] = None

    molecular_weight: Optional[float] = None
    logp: Optional[float] = None
    logd: Optional[float] = None
    psa: Optional[float] = None
    hbd: Optional[int] = None
    hba: Optional[int] = None
    rotatable_bonds: Optional[int] = None
    aqueous_solubility: Optional[str] = None

    ic50_values: List[Dict] = field(default_factory=list)
    ec50_values: List[Dict] = field(default_factory=list)
    ki_values: List[Dict] = field(default_factory=list)
    kd_values: List[Dict] = field(default_factory=list)

    primary_targets: List[str] = field(default_factory=list)
    mechanism_of_action: Optional[str] = None

    approval_status: Optional[str] = None
    approved_indications: List[str] = field(default_factory=list)
    black_box_warnings: List[str] = field(default_factory=list)
    route_of_administration: List[str] = field(default_factory=list)


# =============================================================================
# API BASE
# =============================================================================

class APIFetcher:
    @staticmethod
    def safe_request(
        url: str,
        params: Optional[dict] = None,
        method: str = "GET",
        json_data: Optional[dict] = None,
        headers: Optional[dict] = None
    ) -> Optional[dict]:
        """GET/POST with retries; returns parsed JSON, or None on 404 or repeated failure."""
        for attempt in range(Config.MAX_RETRIES):
            try:
                if method == "GET":
                    resp = requests.get(url, params=params, headers=headers, timeout=Config.TIMEOUT)
                else:
                    resp = requests.post(url, json=json_data, headers=headers, timeout=Config.TIMEOUT)

                if resp.status_code == 200:
                    return resp.json() if resp.text else {}
                if resp.status_code == 404:
                    return None

            except Exception as e:
                if attempt == Config.MAX_RETRIES - 1:
                    print(f"      ⚠️ API error: {str(e)[:80]}")
            # Back off before the next attempt; this covers exceptions as well as
            # other status codes (e.g. 429 or 5xx) that fall through to a retry.
            time.sleep(1)
        return None


# =============================================================================
# PUBCHEM
# =============================================================================

class PubChemFetcher(APIFetcher):
    @staticmethod
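    def get_cid_by_name(drug_name: str) -> Optional[int]:
        # Percent-encode the name so spaces, slashes, etc. stay inside one path segment.
        encoded_name = requests.utils.quote(str(drug_name), safe="")
        url = f"{Config.PUBCHEM_BASE}/compound/name/{encoded_name}/cids/JSON"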
        data = PubChemFetcher.safe_request(url)
        if data and "IdentifierList" in data:
            cids = data["IdentifierList"].get("CID", [])
            return cids[0] if cids else None
        return None

    @staticmethod
    def get_compound_properties(cid: int) -> Dict:
        properties = [
            "MolecularWeight", "XLogP", "TPSA", "HBondDonorCount",
            "HBondAcceptorCount", "RotatableBondCount", "CanonicalSMILES",
            "InChIKey", "MolecularFormula"
        ]
        prop_string = ",".join(properties)
        url = f"{Config.PUBCHEM_BASE}/compound/cid/{cid}/property/{prop_string}/JSON"
        data = PubChemFetcher.safe_request(url)

        if data and "PropertyTable" in data:
            # Guard against an empty "Properties" list, which would raise IndexError.
            props = (data["PropertyTable"].get("Properties") or [{}])[0]
            return {
                "molecular_weight": safe_to_float(props.get("MolecularWeight")),
                "logp": safe_to_float(props.get("XLogP")),
                "psa": safe_to_float(props.get("TPSA")),
                "hbd": props.get("HBondDonorCount"),
                "hba": props.get("HBondAcceptorCount"),
                "rotatable_bonds": props.get("RotatableBondCount"),
                "smiles": props.get("CanonicalSMILES"),
                "inchi_key": props.get("InChIKey"),
                "formula": props.get("MolecularFormula"),
            }
        return {}

    @staticmethod
    def get_bioassay_data(cid: int) -> List[Dict]:
        url = f"{Config.PUBCHEM_BASE}/compound/cid/{cid}/assaysummary/JSON"
        data = PubChemFetcher.safe_request(url)

        assay_results = []
        if data and "Table" in data:
            columns = data["Table"].get("Columns", {}).get("Column", [])
            rows = data["Table"].get("Row", [])
            for row in rows[:50]:
                cells = row.get("Cell", [])
                if len(cells) >= len(columns):
                    result = dict(zip(columns, cells))
                    if result.get("Activity Outcome") == "Active":
                        assay_results.append({
                            "assay_name": result.get("Assay Name", ""),
                            "target": result.get("Target Name", ""),
                            "activity_value": result.get("Activity Value", ""),
                            "activity_unit": result.get("Activity Unit", "")
                        })
        return assay_results


# =============================================================================
# ChEMBL
# =============================================================================

class ChEMBLFetcher(APIFetcher):
    @staticmethod
    def search_molecule(drug_name: str) -> Optional[str]:
        url = f"{Config.CHEMBL_BASE}/molecule/search.json"
        params = {"q": drug_name, "limit": 1}
        data = ChEMBLFetcher.safe_request(url, params=params)
        if data and "molecules" in data and data["molecules"]:
            return data["molecules"][0].get("molecule_chembl_id")
        return None

    @staticmethod
    def get_molecule_details(chembl_id: str) -> Dict:
        url = f"{Config.CHEMBL_BASE}/molecule/{chembl_id}.json"
        data = ChEMBLFetcher.safe_request(url)
        if data:
            props = data.get("molecule_properties", {}) or {}
            return {
                "molecular_weight": safe_to_float(props.get("full_mwt")),
                "alogp": safe_to_float(props.get("alogp")),
                "psa": safe_to_float(props.get("psa")),
                "hbd": props.get("hbd"),
                "hba": props.get("hba"),
                "rotatable_bonds": props.get("rtb"),
                "max_phase": data.get("max_phase"),
                "oral": data.get("oral"),
                "parenteral": data.get("parenteral"),
                "topical": data.get("topical"),
                "black_box_warning": data.get("black_box_warning"),
            }
        return {}

    @staticmethod
    def get_bioactivities(chembl_id: str) -> List[Dict]:
        url = f"{Config.CHEMBL_BASE}/activity.json"
        params = {
            "molecule_chembl_id": chembl_id,
            "limit": 100,
            "standard_type__in": "IC50,Ki,Kd,EC50,IC90"
        }
        data = ChEMBLFetcher.safe_request(url, params=params)

        activities = []
        if data and "activities" in data:
            for act in data["activities"]:
                activities.append({
                    "type": act.get("standard_type"),
                    "value": act.get("standard_value"),
                    "units": act.get("standard_units"),
                    "relation": act.get("standard_relation", "="),
                    "target": act.get("target_pref_name"),
                    "target_chembl_id": act.get("target_chembl_id"),
                    "target_organism": act.get("target_organism"),
                    "assay_type": act.get("assay_type"),
                    "pchembl_value": act.get("pchembl_value"),
                })
        return activities

    @staticmethod
    def get_mechanism(chembl_id: str) -> List[Dict]:
        url = f"{Config.CHEMBL_BASE}/mechanism.json"
        params = {"molecule_chembl_id": chembl_id}
        data = ChEMBLFetcher.safe_request(url, params=params)

        mechanisms = []
        if data and "mechanisms" in data:
            for mech in data["mechanisms"]:
                mechanisms.append({
                    "mechanism": mech.get("mechanism_of_action"),
                    "action_type": mech.get("action_type"),
                    "target": mech.get("target_chembl_id"),
                })
        return mechanisms

    @staticmethod
    def get_drug_indications(chembl_id: str) -> List[Dict]:
        url = f"{Config.CHEMBL_BASE}/drug_indication.json"
        params = {"molecule_chembl_id": chembl_id}
        data = ChEMBLFetcher.safe_request(url, params=params)

        indications = []
        if data and "drug_indications" in data:
            for ind in data["drug_indications"]:
                indications.append({
                    "indication": ind.get("mesh_heading"),
                    "max_phase": ind.get("max_phase_for_ind"),
                })
        return indications


# =============================================================================
# OPEN TARGETS (search by NAME)
# =============================================================================

class OpenTargetsFetcher(APIFetcher):
    @staticmethod
    def search_drug_by_name(drug_name: str) -> Optional[str]:
        query = """
        query searchDrug($drugName: String!) {
          search(queryString: $drugName, entityNames: ["drug"], page: {size: 1, index: 0}) {
            hits { id name }
          }
        }
        """
        data = OpenTargetsFetcher.safe_request(
            Config.OPENTARGETS_BASE,
            method="POST",
            json_data={"query": query, "variables": {"drugName": drug_name}},
            headers={"Content-Type": "application/json"},
        )
        if data and data.get("data"):
            # "search" or "hits" may be null in a GraphQL response; treat both as empty.
            hits = (data["data"].get("search") or {}).get("hits") or []
            return hits[0]["id"] if hits else None
        return None

    @staticmethod
    def get_drug_details_by_id(drug_id: str) -> Dict:
        query = """
        query drugInfo($drugId: String!) {
          drug(chemblId: $drugId) {
            id
            name
            maximumClinicalTrialPhase
            hasBeenWithdrawn
            withdrawnNotice { reason year }
            mechanismsOfAction { rows { mechanismOfAction actionType targetName } }
            indications { rows { disease { name } maxPhaseForIndication } }
            drugWarnings { warningType description }
          }
        }
        """
        data = OpenTargetsFetcher.safe_request(
            Config.OPENTARGETS_BASE,
            method="POST",
            json_data={"query": query, "variables": {"drugId": drug_id}},
            headers={"Content-Type": "application/json"},
        )
        if data and "data" in data and data["data"].get("drug"):
            return data["data"]["drug"]
        return {}


# =============================================================================
# DAILYMED
# =============================================================================

class DailyMedFetcher(APIFetcher):
    @staticmethod
    def search_drug(drug_name: str) -> Optional[str]:
        url = f"{Config.DAILYMED_BASE}/spls.json"
        params = {"drug_name": drug_name, "page": 1, "pagesize": 1}
        data = DailyMedFetcher.safe_request(url, params=params)
        if data and "data" in data and data["data"]:
            return data["data"][0].get("setid")
        return None

    @staticmethod
    def get_label_sections(set_id: str) -> Dict:
        url = f"{Config.DAILYMED_BASE}/spls/{set_id}.json"
        data = DailyMedFetcher.safe_request(url)
        result = {}
        if data and "data" in data:
            drug_data = data["data"]
            result["published_date"] = drug_data.get("published_date")
            result["product_type"] = drug_data.get("product_type")
        return result


# =============================================================================
# AGGREGATOR
# =============================================================================

class DrugDataAggregator:
    def __init__(self, drug_name: str):
        self.raw_name = str(drug_name)
        self.drug_name = self._clean_drug_name(self.raw_name)
        self.profile = PharmacologicalProfile(drug_name=self.drug_name)

    @staticmethod
    def _clean_drug_name(name: str) -> str:
        name = str(name).replace("_", " ")
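        # Require whitespace before the salt suffix so that a bare name such as
        # "Sodium" is not reduced to an empty query string.
        name = re.sub(r"\s+(etabonate|acetate|hydrochloride|sodium|potassium|sulfate)$", "", name, flags=re.I)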
        return name.strip()

    def fetch_all_data(self, verbose: bool = True) -> PharmacologicalProfile:
        if verbose:
            print(f"\n{'='*60}\n🔍 Fetching data for: {self.drug_name}\n{'='*60}")

        self._fetch_pubchem(verbose)
        time.sleep(Config.REQUEST_DELAY)

        self._fetch_chembl(verbose)
        time.sleep(Config.REQUEST_DELAY)

        self._fetch_opentargets_by_name(verbose)
        time.sleep(Config.REQUEST_DELAY)

        self._fetch_dailymed(verbose)
        return self.profile

    def _fetch_pubchem(self, verbose: bool = True):
        if verbose:
            print("   → Querying PubChem...")

        cid = PubChemFetcher.get_cid_by_name(self.drug_name)
        if not cid:
            if verbose:
                print("      ❌ Not found in PubChem")
            return

        self.profile.pubchem_cid = cid
        if verbose:
            print(f"      ✅ Found CID: {cid}")

        props = PubChemFetcher.get_compound_properties(cid)
        if props:
            self.profile.molecular_weight = props.get("molecular_weight")
            self.profile.logp = props.get("logp")
            self.profile.psa = props.get("psa")
            self.profile.hbd = props.get("hbd")
            self.profile.hba = props.get("hba")
            self.profile.rotatable_bonds = props.get("rotatable_bonds")
            self.profile.smiles = props.get("smiles")
            self.profile.inchi_key = props.get("inchi_key")

        assays = PubChemFetcher.get_bioassay_data(cid)
        if assays and verbose:
            print(f"      ✅ Found {len(assays)} active bioassays")

    def _fetch_chembl(self, verbose: bool = True):
        if verbose:
            print("   → Querying ChEMBL...")

        chembl_id = ChEMBLFetcher.search_molecule(self.drug_name)
        if not chembl_id:
            if verbose:
                print("      ❌ Not found in ChEMBL")
            return

        self.profile.chembl_id = chembl_id
        if verbose:
            print(f"      ✅ Found: {chembl_id}")

        details = ChEMBLFetcher.get_molecule_details(chembl_id)
        if details:
            if details.get("max_phase") is not None:
                phases = {0: "Preclinical", 1: "Phase I", 2: "Phase II", 3: "Phase III", 4: "Approved"}
                self.profile.approval_status = phases.get(details["max_phase"], str(details["max_phase"]))

            routes = []
            if details.get("oral"):
                routes.append("Oral")
            if details.get("parenteral"):
                routes.append("Parenteral")
            if details.get("topical"):
                routes.append("Topical")
            self.profile.route_of_administration = routes

            if details.get("black_box_warning"):
                self.profile.black_box_warnings.append("Has black box warning (see FDA label)")

        activities = ChEMBLFetcher.get_bioactivities(chembl_id)
        if activities:
            if verbose:
                print(f"      ✅ Found {len(activities)} bioactivity records")

            for act in activities:
                act_type = (act.get("type") or "").upper()
                entry = {
                    "value": act.get("value"),
                    "units": act.get("units"),
                    "target": act.get("target"),
                    "pchembl": act.get("pchembl_value"),
                }
                if act_type == "IC50":
                    self.profile.ic50_values.append(entry)
                elif act_type == "EC50":
                    self.profile.ec50_values.append(entry)
                elif act_type == "KI":
                    self.profile.ki_values.append(entry)
                elif act_type == "KD":
                    self.profile.kd_values.append(entry)

                tgt = act.get("target")
                if tgt and tgt not in self.profile.primary_targets:
                    self.profile.primary_targets.append(tgt)

        mechanisms = ChEMBLFetcher.get_mechanism(chembl_id)
        if mechanisms:
            moa_list = [m.get("mechanism") for m in mechanisms if m.get("mechanism")]
            self.profile.mechanism_of_action = "; ".join(moa_list[:3]) if moa_list else None

        indications = ChEMBLFetcher.get_drug_indications(chembl_id)
        if indications:
            self.profile.approved_indications = [i.get("indication") for i in indications if i.get("indication")][:10]

    def _fetch_opentargets_by_name(self, verbose: bool = True):
        if verbose:
            print("   → Querying Open Targets (by name)...")

        ot_id = OpenTargetsFetcher.search_drug_by_name(self.drug_name)
        if not ot_id:
            if verbose:
                print("      ❌ Not found in Open Targets by name")
            return

        self.profile.opentargets_id = ot_id
        details = OpenTargetsFetcher.get_drug_details_by_id(ot_id)
        if not details:
            if verbose:
                print("      ⚠️ Found hit, but could not retrieve details")
            return

        if verbose:
            print("      ✅ Found additional data")

        warnings_ = details.get("drugWarnings", []) or []
        for w in warnings_:
            desc = w.get("description")
            if desc:
                self.profile.black_box_warnings.append(desc)

        if not self.profile.approval_status:
            phase = details.get("maximumClinicalTrialPhase")
            if phase is not None:
                phase_map = {0: "Preclinical", 1: "Phase I", 2: "Phase II", 3: "Phase III", 4: "Approved"}
                self.profile.approval_status = phase_map.get(phase, str(phase))

    def _fetch_dailymed(self, verbose: bool = True):
        if verbose:
            print("   → Querying DailyMed...")

        set_id = DailyMedFetcher.search_drug(self.drug_name)
        if not set_id:
            if verbose:
                print("      ❌ Not found in DailyMed")
            return

        if verbose:
            print("      ✅ Found FDA label")
        label = DailyMedFetcher.get_label_sections(set_id)
        if label and not self.profile.approval_status:
            self.profile.approval_status = "FDA Listed"


# =============================================================================
# OUTPUT HELPERS
# =============================================================================

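def summarize_activity_values(values: List[Dict], activity_type: str) -> str:
    """Summarize up to five activity records as 'value units (target)' strings."""
    if not values:
        return "N/A"
    summaries = []
    for v in values[:5]:
        val = v.get("value")
        # `.get(key, default)` does not cover keys that are present with a None value.
        units = v.get("units") or "nM"
        target = v.get("target") or "Unknown target"
        # Keep legitimate zero values; skip only missing or empty ones.
        if val is not None and val != "":
            summaries.append(f"{val} {units} ({str(target)[:30]})")
    return "; ".join(summaries) if summaries else "N/A"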


def calculate_ro5_violations(profile: PharmacologicalProfile) -> int:
    violations = 0
    mw = safe_to_float(profile.molecular_weight)
    logp = safe_to_float(profile.logp)
    hbd = safe_to_float(profile.hbd)
    hba = safe_to_float(profile.hba)

    if mw is not None and mw > 500:
        violations += 1
    if logp is not None and logp > 5:
        violations += 1
    if hbd is not None and hbd > 5:
        violations += 1
    if hba is not None and hba > 10:
        violations += 1
    return violations


def profile_to_dict(profile: PharmacologicalProfile) -> Dict[str, Any]:
    return {
        "drug_name": profile.drug_name,
        "pubchem_cid": profile.pubchem_cid,
        "chembl_id": profile.chembl_id,
        "opentargets_id": profile.opentargets_id,
        "inchi_key": profile.inchi_key,
        "smiles": (profile.smiles[:100] if profile.smiles else None),

        "MW": safe_to_float(profile.molecular_weight),
        "LogP": safe_to_float(profile.logp),
        "PSA": safe_to_float(profile.psa),
        "HBD": profile.hbd,
        "HBA": profile.hba,
        "RotBonds": profile.rotatable_bonds,

        "IC50_summary": summarize_activity_values(profile.ic50_values, "IC50"),
        "EC50_summary": summarize_activity_values(profile.ec50_values, "EC50"),
        "Ki_summary": summarize_activity_values(profile.ki_values, "Ki"),
        "Kd_summary": summarize_activity_values(profile.kd_values, "Kd"),
        "n_IC50_records": len(profile.ic50_values),
        "n_EC50_records": len(profile.ec50_values),
        "n_Ki_records": len(profile.ki_values),

        "primary_targets": ("; ".join(profile.primary_targets[:5]) if profile.primary_targets else "N/A"),
        "n_targets": len(profile.primary_targets),
        "mechanism_of_action": profile.mechanism_of_action or "N/A",

        "approval_status": profile.approval_status or "Unknown",
        "approved_indications": ("; ".join(profile.approved_indications[:3]) if profile.approved_indications else "N/A"),
        "route_of_administration": ("; ".join(profile.route_of_administration) if profile.route_of_administration else "N/A"),
        "black_box_warnings": ("; ".join(profile.black_box_warnings[:2]) if profile.black_box_warnings else "None reported"),

        "Ro5_violations": calculate_ro5_violations(profile),
    }


# =============================================================================
# MAIN PIPELINE (RESUMABLE)
# =============================================================================

def run_pharmacology_enrichment(
    input_file: str,
    output_file: Optional[str] = None,
    sep: Optional[str] = None,
    drug_col: str = "drug",
    top_n: Optional[int] = None,
    verbose: bool = True,
    checkpoint_file: Optional[str] = None,
    save_every: int = 1,
    retry_failed: bool = True,
) -> pd.DataFrame:
    """
    Resumable enrichment:
      - Appends one checkpoint row per drug as soon as that drug finishes.
      - On restart, skips drugs that already completed successfully; if
        retry_failed=False, drugs that previously failed are skipped as well.
      - Rewrites the full merged output (output_file) after every `save_every`
        newly processed drugs.
    """

    print(f"\n📂 Reading input file: {input_file}")
    if sep is None:
        df = pd.read_csv(input_file)
    else:
        df = pd.read_csv(input_file, sep=sep)

    if drug_col not in df.columns:
        raise ValueError(f"Input file must have a '{drug_col}' column")

    all_drugs_raw = df[drug_col].dropna().astype(str).unique().tolist()
    if top_n is not None:
        all_drugs_raw = all_drugs_raw[:top_n]

    # Default checkpoint: alongside output file, or next to input file.
    if checkpoint_file is None:
        if output_file:
            checkpoint_file = str(Path(output_file).with_suffix(".checkpoint.csv"))
        else:
            checkpoint_file = str(Path(input_file).with_suffix(".checkpoint.csv"))

    print(f"📋 Total unique drugs to consider: {len(all_drugs_raw)}")
    print(f"🧷 Checkpoint file: {checkpoint_file}")
    if output_file:
        print(f"📦 Final output file: {output_file}")
    print("=" * 60)

    # Load existing checkpoint (if any)
    ckpt_df = pd.DataFrame()
    if checkpoint_file and os.path.exists(checkpoint_file) and os.path.getsize(checkpoint_file) > 0:
        try:
            ckpt_df = pd.read_csv(checkpoint_file)
            ckpt_df = dedupe_checkpoint(ckpt_df)
            print(f"✅ Loaded checkpoint: {len(ckpt_df)} unique drugs already recorded")
        except Exception as e:
            print(f"⚠️ Could not read checkpoint '{checkpoint_file}': {e}")
            ckpt_df = pd.DataFrame()

    # Decide what is "done" for skipping
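    done_keys: set[str] = set()
    if not ckpt_df.empty and "input_drug" in ckpt_df.columns:
        ckpt_df["input_drug_key"] = ckpt_df["input_drug"].apply(normalize_drug_key)
        if retry_failed and "error" in ckpt_df.columns:
            # An empty error cell is read back from CSV as NaN. Fill it first:
            # astype(str) alone turns NaN into "nan", so every row would look failed.
            ok_mask = ckpt_df["error"].fillna("").astype(str).str.strip() == ""
            done_keys = set(ckpt_df.loc[ok_mask, "input_drug_key"].tolist())
        else:
            done_keys = set(ckpt_df["input_drug_key"].tolist())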

    # Iterate drugs
    newly_processed = 0
    records_in_memory: List[Dict[str, Any]] = []
    if not ckpt_df.empty:
        records_in_memory = ckpt_df.to_dict(orient="records")

    total = len(all_drugs_raw)
    for idx, drug_raw in enumerate(all_drugs_raw, 1):
        key = normalize_drug_key(drug_raw)
        if key in done_keys:
            if verbose:
                print(f"[{idx}/{total}] Skipping (already done): {drug_raw}")
            continue

        print(f"\n[{idx}/{total}] Processing: {drug_raw}")

        row_out: Dict[str, Any] = {
            "input_drug": drug_raw,
            "input_drug_key": key,
            "query_drug": None,
            "fetched_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "error": None,
        }

        try:
            aggregator = DrugDataAggregator(drug_raw)
            row_out["query_drug"] = aggregator.drug_name
            profile = aggregator.fetch_all_data(verbose=verbose)
            row_out.update(profile_to_dict(profile))
        except Exception as e:
            msg = str(e)
            print(f"   ⚠️ Error processing {drug_raw}: {msg}")
            row_out["error"] = msg

        # Save checkpoint immediately (so you can resume after any crash)
        append_row_to_csv(checkpoint_file, row_out)

        # Keep in memory (and de-dupe later)
        records_in_memory.append(row_out)
        newly_processed += 1

        # Periodic full output save
        if output_file and save_every > 0 and (newly_processed % save_every == 0):
            tmp_ckpt = pd.DataFrame(records_in_memory)
            tmp_ckpt = dedupe_checkpoint(tmp_ckpt)

            pharma_df = tmp_ckpt.copy()
            # Merge back to original input on the original input drug string
            if "input_drug" in pharma_df.columns:
                result_df = df.merge(pharma_df, left_on=drug_col, right_on="input_drug", how="left")
                result_df = result_df.drop(columns=["input_drug_key"], errors="ignore")
            else:
                # Fallback
                result_df = df.copy()

            print(f"\n{'='*60}\n💾 Saving intermediate results to: {output_file}")
            save_output(result_df, output_file)
            print("✅ Intermediate save complete")

        time.sleep(Config.REQUEST_DELAY)

    # Final output
    final_ckpt = pd.DataFrame(records_in_memory)
    if not final_ckpt.empty:
        final_ckpt = dedupe_checkpoint(final_ckpt)

    if not final_ckpt.empty and "input_drug" in final_ckpt.columns:
        result_df = df.merge(final_ckpt, left_on=drug_col, right_on="input_drug", how="left")
        result_df = result_df.drop(columns=["input_drug_key"], errors="ignore")
    else:
        result_df = df.copy()

    if output_file:
        print(f"\n{'='*60}\n💾 Saving FINAL results to: {output_file}")
        save_output(result_df, output_file)
        print("✅ Final save complete")

    return result_df


def display_summary(result_df: pd.DataFrame, output_file: Optional[str] = None) -> None:
    print("\n" + "=" * 60)
    print("📊 FINAL SUMMARY")
    print("=" * 60)
    print(f"Total rows: {len(result_df)}")
    if "chembl_id" in result_df.columns:
        print(f"Drugs with ChEMBL data: {result_df['chembl_id'].notna().sum()}")
    if "pubchem_cid" in result_df.columns:
        print(f"Drugs with PubChem data: {result_df['pubchem_cid'].notna().sum()}")
    if "opentargets_id" in result_df.columns:
        print(f"Drugs with Open Targets hit: {result_df['opentargets_id'].notna().sum()}")
    if "error" in result_df.columns:
        # Missing errors are NaN/None; fill them so they are not counted as "nan"/"None".
        n_err = (result_df["error"].fillna("").astype(str).str.strip() != "").sum()
        print(f"Rows with error: {n_err}")
    if output_file:
        print(f"\n📁 Output saved to: {output_file}")

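The resume behaviour of `run_pharmacology_enrichment` can be illustrated in isolation. The sketch below is not the notebook's own implementation: `append_checkpoint_row`, `load_done_keys`, the two-column schema, and the drug names `drug_a`/`drug_b` are hypothetical stand-ins, and it uses only the standard-library `csv` module. It shows the core idea: append one row per drug, then on restart skip rows whose `error` field is empty, optionally skipping failed rows too.

```python
import csv
import os
import tempfile

FIELDS = ["input_drug", "error"]  # hypothetical minimal checkpoint schema


def append_checkpoint_row(path, row):
    """Append one row, writing the header only when the file is new or empty."""
    is_new = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        if is_new:
            writer.writeheader()
        # Store None as an empty cell so "no error" is unambiguous on reload.
        writer.writerow({k: ("" if row.get(k) is None else row[k]) for k in FIELDS})


def load_done_keys(path, retry_failed=True):
    """Return normalized drug keys that can be skipped on restart."""
    if not os.path.exists(path):
        return set()
    done = set()
    with open(path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            failed = (row.get("error") or "").strip() != ""
            if failed and retry_failed:
                continue  # leave failed drugs out so they are attempted again
            done.add(row["input_drug"].strip().lower())
    return done


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "demo.checkpoint.csv")
    append_checkpoint_row(ckpt, {"input_drug": "drug_a", "error": None})
    append_checkpoint_row(ckpt, {"input_drug": "drug_b", "error": "timeout"})
    skip_if_retrying = load_done_keys(ckpt, retry_failed=True)
    skip_if_not_retrying = load_done_keys(ckpt, retry_failed=False)

print(sorted(skip_if_retrying))
print(sorted(skip_if_not_retrying))
```

With `retry_failed=True` only the successful drug is skipped, so the failed one is attempted again on the next run. With `retry_failed=False` both are skipped.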
#### **Cutoff Enrichment**

##### **NCT04809974**

In [17]:
from pathlib import Path

# Folder containing the avoid_* files
BASE_DIR = Path("/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974")

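# Loop over all input files that start with: NCT04809974_avoid_
# Outputs and checkpoints from earlier runs share this prefix, so keep only
# CSV/TSV inputs and skip anything already marked "_enriched_pharmacology".
avoid_files = sorted(
    p for p in BASE_DIR.glob("NCT04809974_avoid_*")
    if p.suffix.lower() in {".csv", ".tsv"} and "_enriched_pharmacology" not in p.name
)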

if not avoid_files:
    raise FileNotFoundError(f"No files found matching 'NCT04809974_avoid_*' in: {BASE_DIR}")

for input_path in avoid_files:
    # output + checkpoint names derived from the input filename
    stem = input_path.stem  # e.g., NCT04809974_avoid_primary_composite_q10
    output_xlsx = BASE_DIR / f"{stem}_enriched_pharmacology.xlsx"
    checkpoint_csv = BASE_DIR / f"{stem}_enriched_pharmacology.checkpoint.csv"

    print("\n" + "=" * 100)
    print(f"[RUN] Input:       {input_path.name}")
    print(f"      Output:      {output_xlsx.name}")
    print(f"      Checkpoint:  {checkpoint_csv.name}")

    result = run_pharmacology_enrichment(
        input_file=str(input_path),
        output_file=str(output_xlsx),
        sep="\t" if input_path.suffix.lower() == ".tsv" else ",",  # auto sep by extension
        drug_col="drug",
        top_n=None,        # processes ALL unique drugs
        verbose=True,
        checkpoint_file=str(checkpoint_csv),
        save_every=1,      # write full Excel every 1 drug
        retry_failed=True  # retry previously failed drugs on re-run
    )

    display_summary(result, str(output_xlsx))


[RUN] Input:       NCT04809974_avoid_delta_q10.csv
      Output:      NCT04809974_avoid_delta_q10_enriched_pharmacology.xlsx
      Checkpoint:  NCT04809974_avoid_delta_q10_enriched_pharmacology.checkpoint.csv

📂 Reading input file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_avoid_delta_q10.csv
📋 Total unique drugs to consider: 457
🧷 Checkpoint file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_avoid_delta_q10_enriched_pharmacology.checkpoint.csv
📦 Final output file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_avoid_delta_q10_enriched_pharmacology.xlsx

[1/457] Processing: Regorafenib

🔍 Fetching data for: Regorafenib
   → Querying PubChem...
      ✅ Found CID: 11167602
      ✅ F

##### **NCT04880161**

In [65]:
from pathlib import Path

# Folder containing the avoid_* files
BASE_DIR = Path("/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161")

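# Loop over all input files that start with: NCT04880161_avoid_
# Outputs and checkpoints from earlier runs share this prefix, so keep only
# CSV/TSV inputs and skip anything already marked "_enriched_pharmacology".
avoid_files = sorted(
    p for p in BASE_DIR.glob("NCT04880161_avoid_*")
    if p.suffix.lower() in {".csv", ".tsv"} and "_enriched_pharmacology" not in p.name
)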

if not avoid_files:
    raise FileNotFoundError(f"No files found matching 'NCT04880161_avoid_*' in: {BASE_DIR}")

for input_path in avoid_files:
    # output + checkpoint names derived from the input filename
    stem = input_path.stem  # e.g., NCT04880161_avoid_primary_composite_q10
    output_xlsx = BASE_DIR / f"{stem}_enriched_pharmacology.xlsx"
    checkpoint_csv = BASE_DIR / f"{stem}_enriched_pharmacology.checkpoint.csv"

    print("\n" + "=" * 100)
    print(f"[RUN] Input:       {input_path.name}")
    print(f"      Output:      {output_xlsx.name}")
    print(f"      Checkpoint:  {checkpoint_csv.name}")

    result = run_pharmacology_enrichment(
        input_file=str(input_path),
        output_file=str(output_xlsx),
        sep="\t" if input_path.suffix.lower() == ".tsv" else ",",  # auto sep by extension
        drug_col="drug",
        top_n=None,        # processes ALL unique drugs
        verbose=True,
        checkpoint_file=str(checkpoint_csv),
        save_every=1,      # write full Excel every 1 drug
        retry_failed=True  # retry previously failed drugs on re-run
    )

    display_summary(result, str(output_xlsx))


[RUN] Input:       NCT04880161_avoid_delta_q10.csv
      Output:      NCT04880161_avoid_delta_q10_enriched_pharmacology.xlsx
      Checkpoint:  NCT04880161_avoid_delta_q10_enriched_pharmacology.checkpoint.csv

📂 Reading input file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_avoid_delta_q10.csv
📋 Total unique drugs to consider: 412
🧷 Checkpoint file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_avoid_delta_q10_enriched_pharmacology.checkpoint.csv
📦 Final output file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_avoid_delta_q10_enriched_pharmacology.xlsx

[1/412] Processing: Selexipag

🔍 Fetching data for: Selexipag
   → Querying PubChem...
      ✅ Found CID: 9913767
      ✅ Found 

### **LC Enrichment**

#### **NCT04809974**

In [None]:
# ============================================================================
# Call Function
# ============================================================================
if __name__ == "__main__":
    result = run_pharmacology_enrichment(
        input_file="/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_ranked_drugs_by_composite_SAE_LC_phenotype.tsv",
        output_file="/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_enriched_pharmacology.xlsx",
        sep="\t",
        drug_col="drug",
        top_n=None,          # ✅ processes ALL unique drugs
        verbose=True,
        checkpoint_file="/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_enriched_pharmacology.checkpoint.csv",
        save_every=1,       # write full Excel every 1 drug
        retry_failed=True    # retry previously failed drugs on re-run
    )
    display_summary(
        result,
        "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_enriched_pharmacology.xlsx"
    )


📂 Reading input file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_ranked_drugs_by_composite_SAE_LC_phenotype.tsv
📋 Total unique drugs to consider: 42
🧷 Checkpoint file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_enriched_pharmacology.checkpoint.csv
📦 Final output file: /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04809974/NCT04809974_enriched_pharmacology.xlsx
✅ Loaded checkpoint: 1 unique drugs already recorded

[1/42] Processing: Propranolol

🔍 Fetching data for: Propranolol
   → Querying PubChem...
      ✅ Found CID: 4946
      ✅ Found 5 active bioassays
   → Querying ChEMBL...
      ✅ Found: CHEMBL1671
      ✅ Found 22 bioactivity records
   → Querying Open Targets (by name)...
      ⚠️ Found hi

#### **NCT04880161**

In [None]:
# ============================================================================
# Call Function
# ============================================================================
if __name__ == "__main__":
    result = run_pharmacology_enrichment(
        input_file="/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_ranked_drugs_by_composite_SAE_LC_phenotype.tsv",
        output_file="/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_enriched_pharmacology.xlsx",
        sep="\t",
        drug_col="drug",
        top_n=None,          # ✅ processes ALL unique drugs
        verbose=True,
        checkpoint_file="/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_enriched_pharmacology.checkpoint.csv",
        save_every=1,       # write full Excel every 1 drug
        retry_failed=True    # retry previously failed drugs on re-run
    )
    display_summary(
        result,
        "/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_enriched_pharmacology.xlsx"
    )


📂 Reading input file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_ranked_drugs_by_composite_SAE_LC_phenotype.tsv
📋 Total unique drugs to consider: 12
🧷 Checkpoint file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_enriched_pharmacology.checkpoint.csv
📦 Final output file: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/4_Score_Ranking/NCT04880161/NCT04880161_enriched_pharmacology.xlsx
✅ Loaded checkpoint: 1 unique drugs already recorded

[1/12] Processing: Ivabradine

🔍 Fetching data for: Ivabradine
   → Querying PubChem...
      ✅ Found CID: 132999
      ✅ Found 9 active bioassays
   → Querying ChEMBL...
      ⚠️ API error: HTTPSConnectionPool(host='www.ebi.ac.uk', port=443): Read timed out. (read timeo
      ❌ Not found in ChEMBL
   → Querying Open Targets (by

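In the log above, a read timeout from `www.ebi.ac.uk` is followed by "Not found in ChEMBL", so a transient network failure ends up looking the same as a genuine miss. One common way to separate the two is to retry transient failures with exponential backoff before giving up. The sketch below is a generic illustration under stated assumptions: `TransientError`, `call_with_retries`, and `flaky_fetch` are hypothetical names, and this is not how the notebook's `safe_request` is known to behave.

```python
import time


class TransientError(Exception):
    """Hypothetical marker for failures worth retrying (e.g. a read timeout)."""


def call_with_retries(func, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call func(); on TransientError wait base_delay * 2**attempt, then retry."""
    for attempt in range(max_attempts):
        try:
            return func()
        except TransientError:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the failure instead of hiding it
            sleep(base_delay * (2 ** attempt))


# Demo with a fake fetcher that times out twice and then succeeds.
calls = {"n": 0}
waits = []  # records the requested delays instead of actually sleeping


def flaky_fetch():
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("read timed out")
    return {"status": "ok"}


result = call_with_retries(flaky_fetch, sleep=waits.append)
print(result, waits)
```

Re-raising after the last attempt lets the caller record the drug as failed in the checkpoint, so a later run with `retry_failed=True` can pick it up again.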
## **File Conversion**

### **JSON/HTML**

In [None]:
# Transform JSON files in a directory into pretty-printed HTML files.
import json
from pathlib import Path

# ========= CONFIG =========
DATA_DIR = Path(r"/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/Ground_Truth")

# ========= PER-FILE HTML GENERATION =========
for path in sorted(DATA_DIR.glob("*.json")):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Pretty-print JSON
    pretty = json.dumps(data, indent=2)

    # Escape HTML special chars
    pretty_escaped = (
        pretty.replace("&", "&amp;")
              .replace("<", "&lt;")
              .replace(">", "&gt;")
    )

    # Build HTML for this file
    html_parts = [
        "<!doctype html>",
        "<html>",
        "<head>",
        "  <meta charset='utf-8'>",
        f"  <title>{path.name} – LC Results JSON</title>",
        "  <style>",
        "    body { font-family: monospace; white-space: pre-wrap; }",
        "    h1 { margin-bottom: 0.5em; }",
        "    pre { background: #f8f8f8; padding: 10px; border-radius: 4px; }",
        "  </style>",
        "</head>",
        "<body>",
        f"  <h1>{path.name}</h1>",
        "  <pre>",
        pretty_escaped,
        "  </pre>",
        "</body>",
        "</html>",
    ]

    # Output HTML path: same name, .html extension
    output_html = path.with_suffix(".html")

    # Write this file's HTML
    with open(output_html, "w", encoding="utf-8") as f:
        f.write("\n".join(html_parts))

    print(f"HTML written to: {output_html}")

HTML written to: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/Ground_Truth/result_trial_data_NCT04809974.html
HTML written to: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/Ground_Truth/result_trial_data_NCT04880161.html
HTML written to: /mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/Ground_Truth/result_trial_data_NCT05576662.html

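The cell above escapes `&`, `<`, and `>` by hand and inserts the file name into the page unescaped. As an alternative sketch, the standard library's `html.escape` can handle both the JSON body and the title. The function name `json_to_html` and the sample inputs are hypothetical and chosen only to exercise the escaping.

```python
import html
import json


def json_to_html(title, data):
    """Render a JSON-serializable object as a minimal standalone HTML page."""
    pretty = json.dumps(data, indent=2, ensure_ascii=False)
    safe_title = html.escape(title)
    return "\n".join([
        "<!doctype html>",
        "<html>",
        "<head>",
        "  <meta charset='utf-8'>",
        f"  <title>{safe_title}</title>",
        "</head>",
        "<body>",
        f"  <h1>{safe_title}</h1>",
        # quote=False keeps double quotes readable; they are safe inside <pre>.
        f"<pre>{html.escape(pretty, quote=False)}</pre>",
        "</body>",
        "</html>",
    ])


page = json_to_html("a<b>.json", {"note": "x < y & z"})
print(page)
```

Building the page in a function also makes it straightforward to call once per file inside the existing `DATA_DIR.glob("*.json")` loop.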

In [None]:
# Load a single JSON result file and pretty-print its contents.
import os
import json
import pprint

# Folder that contains the result_trial_data_*.json files
folder = '/mnt/c/Users/pinsy007/OneDrive - University of South Australia/3_Third_Paper/6_CG_Drugs_PlaNet/Ground_Truth'
filename = 'result_trial_data_NCT04809974.json'
file_path = os.path.join(folder, filename)

if not os.path.isfile(file_path):
    print(f"File not found: {filename}")
else:
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        pprint.pprint(data)
    except Exception as e:
        print(f"[ERROR] Could not read '{filename}': {e}")

{'AE': {'trial_1_ae': {'1': 0.029755525290966034,
                       '101': 0.06167344003915787,
                       '102': 0.030780911445617676,
                       '104': 0.06344840675592422,
                       '11': 0.028401242569088936,
                       '114': 0.0713878720998764,
                       '116': 0.03615580126643181,
                       '123': 0.13929274678230286,
                       '126': 0.04959367215633392,
                       '130': 0.03860737383365631,
                       '131': 0.0441211499273777,
                       '132': 0.024857692420482635,
                       '138': 0.029055727645754814,
                       '139': 0.02797342836856842,
                       '14': 0.08303157240152359,
                       '142': 0.023525943979620934,
                       '144': 0.036463163793087006,
                       '147': 0.060583263635635376,
                       '15': 0.03232036530971527,
                       '151': 

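The printed structure nests scores as `AE -> arm -> key -> value`. Assuming the inner keys are adverse-event codes and the values are predicted probabilities (an interpretation, not something the output itself confirms), a small helper can flatten the dictionary and list the highest-scoring entries. `top_ae_rows` is a hypothetical name, and the sample uses four of the values shown in the output above.

```python
def top_ae_rows(result, top_k=3):
    """Flatten {'AE': {arm: {code: prob}}} into (arm, code, prob), highest first."""
    rows = []
    for arm, scores in result.get("AE", {}).items():
        for code, prob in scores.items():
            rows.append((arm, code, prob))
    rows.sort(key=lambda r: r[2], reverse=True)
    return rows[:top_k]


# Subset of the values printed above for 'trial_1_ae'.
sample = {
    "AE": {
        "trial_1_ae": {
            "1": 0.029755525290966034,
            "114": 0.0713878720998764,
            "123": 0.13929274678230286,
            "14": 0.08303157240152359,
        }
    }
}

top = top_ae_rows(sample, top_k=2)
for arm, code, prob in top:
    print(f"{arm}\t{code}\t{prob:.4f}")
```

The same `(arm, code, prob)` rows are a natural starting point for a long-format table, since each row carries one score.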
### **CSV**

In [None]:
import json
import csv
import os
import glob

# ====== YOUR PATHS ======
JSON_FOLDER = "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/5_Ground_Truth/Prediction"
MAP_FOLDER = "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/5_Ground_Truth/map/results"
OUTPUT_CSV = "/Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/5_Ground_Truth/Prediction/all_trials_AE_long_with_map.csv"
# ========================


def get_id_from_filename(filename: str) -> str:
    """
    From:  result_NCT04678830_results.json  -> NCT04678830
    From:  result_NCT04809974_CF_Leronlimab_results.json -> NCT04809974_CF_Leronlimab
    From:  counterfactual_summary_NCT04809974.json -> counterfactual_summary_NCT04809974
    """
    base = os.path.basename(filename)
    if base.startswith("result_"):
        base = base[len("result_"):]
    if base.endswith("_results.json"):
        base = base[: -len("_results.json")]
    elif base.endswith(".json"):
        base = base[:-5]
    return base


def load_mapping_for_trial(trial_id: str, map_folder: str):
    """
    Load TSV mapping for a given trial (by using the trial_id in the filename).
    Returns a dict keyed by (trial_id, ae_code_str) -> dict with ae_kg_id, ae_label, map_prob.
    """
    pattern = os.path.join(map_folder, f"*{trial_id}*.tsv")
    tsv_files = glob.glob(pattern)

    mapping = {}

    if not tsv_files:
        print(f"[WARN] No mapping TSV found for trial '{trial_id}' with pattern {pattern}")
        return mapping

    for tsv in tsv_files:
        try:
            with open(tsv, "r", encoding="utf-8") as f:
                header = f.readline().rstrip("\n").split("\t")
                col_idx = {name: i for i, name in enumerate(header)}
                for line in f:
                    line = line.rstrip("\n")
                    if not line:
                        continue
                    parts = line.split("\t")
                    if len(parts) < len(header):
                        continue

                    # Index directly so that a missing column raises a KeyError
                    # naming the column; the handler below reports it as a warning.
                    t_id = parts[col_idx["trial_id"]]
                    ae_code = parts[col_idx["ae_code"]]
                    ae_kg_id = parts[col_idx["ae_kg_id"]]
                    ae_label = parts[col_idx["ae_label"]]
                    map_prob = parts[col_idx["probability"]]

                    mapping[(t_id, ae_code)] = {
                        "ae_kg_id": ae_kg_id,
                        "ae_label": ae_label,
                        "map_prob": map_prob,
                    }
        except Exception as e:
            print(f"[WARN] Could not read TSV mapping file {tsv}: {e}")

    return mapping


def extract_AE_rows_from_json(data: dict, source_name: str, map_dict: dict):
    """
    For one JSON, return a list of rows with columns:
    nct_id, arm, arm_label, safety, efficacy, ae_code, ae_kg_id, ae_label, ae_prob
    """
    rows = []

    # If the JSON doesn't have AE / safety / efficacy, skip it
    if not isinstance(data, dict):
        return rows
    if "AE" not in data or "safety" not in data or "efficacy" not in data:
        return rows

    nct_id = get_id_from_filename(source_name)

    meta = data.get("meta", {}) or {}
    ae = data.get("AE", {}) or {}
    safety = data.get("safety", {}) or {}
    efficacy_block = data.get("efficacy", {}) or {}

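    # efficacy = P(drug > placebo) = 1 - prob_trial1_gt_trial2
    # (this reading treats trial_2 as the drug arm and trial_1 as the comparator;
    # the same value is attached to the rows of both arms)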
    p_1_gt_2 = efficacy_block.get("prob_trial1_gt_trial2")
    efficacy_val = 1.0 - float(p_1_gt_2) if p_1_gt_2 is not None else None

    arms = [
        {
            "arm_key": "trial_1",
            "ae_key": "trial_1_ae",
            "safety_key": "trial_1_safety",
            "label_key": "trial_1_label",
        },
        {
            "arm_key": "trial_2",
            "ae_key": "trial_2_ae",
            "safety_key": "trial_2_safety",
            "label_key": "trial_2_label",
        },
    ]

    for arm in arms:
        arm_name = arm["arm_key"]                  # "trial_1" or "trial_2"
        ae_dict = ae.get(arm["ae_key"], {}) or {}  # e.g. AE["trial_2_ae"]
        safety_val = safety.get(arm["safety_key"])
        arm_label = meta.get(arm["label_key"], arm_name)

        if not isinstance(ae_dict, dict) or not ae_dict:
            continue

        for ae_code, ae_prob in ae_dict.items():
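            # The lookup matches only when the TSV `trial_id` column holds the arm
            # key ("trial_1" / "trial_2"); otherwise ae_kg_id and ae_label stay None.
            key = (arm_name, str(ae_code))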
            map_info = map_dict.get(key, {})
            ae_kg_id = map_info.get("ae_kg_id")
            ae_label = map_info.get("ae_label")

            rows.append({
                "nct_id": nct_id,
                "arm": arm_name,
                "arm_label": arm_label,
                "safety": safety_val,
                "efficacy": efficacy_val,
                "ae_code": ae_code,
                "ae_kg_id": ae_kg_id,
                "ae_label": ae_label,
                "ae_prob": ae_prob,
            })

    return rows


def convert_many_json_to_AE_long_with_map(json_folder: str, map_folder: str, output_csv_path: str):
    # Use all .json files in the folder
    pattern = os.path.join(json_folder, "*.json")
    json_files = sorted(glob.glob(pattern))

    if not json_files:
        raise FileNotFoundError(f"No .json files found in {json_folder}")

    all_rows = []

    for jf in json_files:
        trial_id = get_id_from_filename(jf)
        map_dict = load_mapping_for_trial(trial_id, map_folder)

        try:
            with open(jf, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            print(f"[WARN] Could not read {jf}: {e}")
            continue

        if isinstance(data, list):
            if not data:
                print(f"[WARN] Empty list in {jf}, skipping.")
                continue
            data = data[0]

        if not isinstance(data, dict):
            print(f"[WARN] Skipping {jf}: JSON root is not an object.")
            continue

        rows = extract_AE_rows_from_json(data, source_name=jf, map_dict=map_dict)
        if not rows:
            # likely a non-AE JSON such as a summary file
            continue

        all_rows.extend(rows)

    if not all_rows:
        raise ValueError("No AE rows found – check JSON and mapping TSV structure/patterns.")

    os.makedirs(os.path.dirname(output_csv_path) or ".", exist_ok=True)

    # Output column order
    fieldnames = [
        "nct_id",
        "arm",
        "arm_label",
        "safety",
        "efficacy",
        "ae_code",
        "ae_kg_id",
        "ae_label",
        "ae_prob",
    ]

    with open(output_csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(all_rows)

    print(f"✅ Wrote {len(all_rows)} AE rows (with mapping) from {len(json_files)} JSON files to {output_csv_path}")


if __name__ == "__main__":
    convert_many_json_to_AE_long_with_map(JSON_FOLDER, MAP_FOLDER, OUTPUT_CSV)


✅ Wrote 800 AE rows (with mapping) from 4 JSON files to /Users/sindypinero/Library/CloudStorage/OneDrive-UniversityofSouthAustralia/3_Third_Paper/5_Ground_Truth/Prediction/all_trials_AE_long_with_map.csv
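
The long-format CSV written above has one row per (trial, arm, adverse event). As a minimal sketch of how it could be consumed before plotting, the snippet below ranks adverse events by predicted probability within each arm. The helper `top_ae_per_arm`, the `SAMPLE_CSV` rows, and the placeholder ID `NCT00000000` are illustrative assumptions and not part of the original pipeline; only the column names are taken from the `fieldnames` list above.

```python
import csv
import io
from collections import defaultdict

# Hypothetical sample in the same long format as all_trials_AE_long_with_map.csv.
# All values below are made up for illustration only.
SAMPLE_CSV = """nct_id,arm,arm_label,safety,efficacy,ae_code,ae_kg_id,ae_label,ae_prob
NCT00000000,trial_1,Placebo,0.40,0.60,1,,,0.03
NCT00000000,trial_1,Placebo,0.40,0.60,123,,,0.14
NCT00000000,trial_1,Placebo,0.40,0.60,14,,,0.08
NCT00000000,trial_2,Drug,0.50,0.60,1,,,0.05
NCT00000000,trial_2,Drug,0.50,0.60,123,,,0.02
"""


def top_ae_per_arm(csv_file, k=2):
    """Return {(nct_id, arm): [(ae_code, ae_prob), ...]} keeping the k highest ae_prob."""
    grouped = defaultdict(list)
    for row in csv.DictReader(csv_file):
        grouped[(row["nct_id"], row["arm"])].append(
            (row["ae_code"], float(row["ae_prob"]))
        )
    # Sort each arm's AEs by probability, highest first, and keep the top k.
    return {
        key: sorted(pairs, key=lambda p: p[1], reverse=True)[:k]
        for key, pairs in grouped.items()
    }


top = top_ae_per_arm(io.StringIO(SAMPLE_CSV), k=2)
for (nct_id, arm), pairs in top.items():
    print(nct_id, arm, pairs)
```

To run this against the real output, pass `open(OUTPUT_CSV, newline="", encoding="utf-8")` instead of the `io.StringIO` sample.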


## **Plots**