# MVP: Drug-Disease Adverse Outcome Prediction

**Scope**: 1-2 disease areas, 10-20 drugs with known adverse outcomes, binary outcome prediction

**Approach**:
- **GDi**: Disease → Associated Genes (PrimeKG) → Pathways (PrimeKG)
- **GDr**: Drugs → Target Genes (PrimeKG) → Pathways (PrimeKG)
- **Features**: Shared genes, shared pathways, pathway overlap, graph distance
- **Baseline**: Simple graph-structure features (node degree, betweenness and closeness centrality) computed on the PrimeKG-derived subgraph
- **Models**: Logistic Regression & Random Forest (no GNNs yet)

**Reference**: See `explore_primekg.ipynb` for PrimeKG dataset details


## 1. Setup & Dependencies


In [None]:
# Install dependencies (quiet mode).
# NOTE(review): `node2vec` is installed here but is not imported anywhere in the
# visible notebook — confirm whether it is still needed.
!pip install pandas numpy networkx scikit-learn matplotlib seaborn tqdm requests node2vec -q


In [None]:
import pandas as pd
import numpy as np
import networkx as nx
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy's global RNG so calls to np.random.* in this notebook are repeatable
np.random.seed(42)

# Display settings: show every column, cap printed rows at 100
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)

print("Dependencies loaded successfully!")


## 2. Download PrimeKG

PrimeKG is available from Harvard Dataverse. See `explore_primekg.ipynb` for details on the dataset.


In [None]:
import os
import requests

# PrimeKG download URL from Harvard Dataverse
PRIMEKG_URL = "https://dataverse.harvard.edu/api/access/datafile/6180620"
# Local file the dataset is saved to and later read from
DATA_PATH = "kg.csv"

def download_primekg(url, filepath, chunk_size=8192):
    """Download the PrimeKG dataset to ``filepath`` with a progress bar.

    The download is skipped when ``filepath`` already exists. Data is streamed
    into a temporary ``<filepath>.part`` file that is moved into place only
    after the transfer has finished, so an interrupted download never leaves a
    truncated file that a later run would mistake for the complete dataset.

    Args:
        url: HTTP(S) location of the dataset.
        filepath: Destination path for the downloaded file.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    if os.path.exists(filepath):
        print(f"Dataset already exists at {filepath}")
        return

    print(f"Downloading PrimeKG from {url}...")
    print("Note: This is ~1GB and may take several minutes.")

    partial_path = f"{filepath}.part"
    try:
        # (connect, read) timeouts so a stalled server cannot hang the notebook
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly instead of saving an HTML error page as the dataset
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as f, tqdm(
                total=total_size or None,  # unknown length -> open-ended bar
                unit='B', unit_scale=True, desc="Downloading",
            ) as pbar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

        # Publish the file only once it is complete
        os.replace(partial_path, filepath)
    except BaseException:
        # Includes KeyboardInterrupt: drop the partial file, then re-raise
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    print(f"Download complete! Saved to {filepath}")

# Download the dataset (the function returns early when DATA_PATH already exists)
download_primekg(PRIMEKG_URL, DATA_PATH)


## 3. Load PrimeKG Data


In [None]:
# Read the whole edge list into one DataFrame
print("Loading PrimeKG dataset...")
df = pd.read_csv(DATA_PATH, low_memory=False)
n_edges, n_columns = df.shape
print(f"Dataset loaded: {n_edges:,} edges, {n_columns} columns")
print(f"Columns: {df.columns.tolist()}")
# Last expression of the cell: shows the first rows in the notebook output
df.head()


## 4. Select Disease Areas & Drugs with Adverse Outcomes

We'll focus on diseases with contraindications (known adverse drug outcomes).


In [None]:
# Keep every edge that links a drug and a disease, in either orientation
x_kind, y_kind = df['x_type'], df['y_type']
drug_then_disease = (x_kind == 'drug') & (y_kind == 'disease')
disease_then_drug = (x_kind == 'disease') & (y_kind == 'drug')
drug_disease_mask = drug_then_disease | disease_then_drug
drug_disease_df = df[drug_disease_mask].copy()

# Normalize direction: always drug -> disease
def normalize_drug_disease(row):
    """Return the edge in ``row`` as a drug -> disease record.

    ``row`` is indexed by ``x_type``, ``x_id``, ``x_name``, ``y_id``,
    ``y_name`` and ``relation``. When ``x_type`` is ``'drug'`` the x side is
    taken as the drug; otherwise the y side is.
    """
    drug_side, disease_side = ('x', 'y') if row['x_type'] == 'drug' else ('y', 'x')
    return pd.Series({
        'drug_id': row[f'{drug_side}_id'],
        'drug_name': row[f'{drug_side}_name'],
        'disease_id': row[f'{disease_side}_id'],
        'disease_name': row[f'{disease_side}_name'],
        'relation': row['relation'],
    })

# Re-express every selected edge as drug -> disease. Edges are selected in both
# orientations, so an association stored once per direction would show up twice
# after normalisation. Duplicates would double the contraindication counts and
# put identical rows into both the train and the test split, so each
# (drug, disease, relation) triple is kept once.
drug_disease_normalized = (
    drug_disease_df.apply(normalize_drug_disease, axis=1)
    .drop_duplicates(subset=['drug_id', 'disease_id', 'relation'])
)

print(f"Total drug-disease edges: {len(drug_disease_normalized):,}")
print(f"\nRelationship types:")
print(drug_disease_normalized['relation'].value_counts())


In [None]:
# Select the disease areas and drugs the MVP works with, ranked by the number
# of contraindication rows (used here as the adverse-outcome signal)
is_contraindication = drug_disease_normalized['relation'] == 'contraindication'
contraindications = drug_disease_normalized[is_contraindication]

# Rank diseases by contraindication rows
disease_contra_counts = (
    contraindications
    .groupby(['disease_id', 'disease_name'])
    .size()
    .reset_index(name='contra_count')
    .sort_values('contra_count', ascending=False)
)

print("Top diseases with contraindications (adverse outcomes):")
print(disease_contra_counts.head(10))

# Keep the two highest-ranked diseases
selected_diseases = disease_contra_counts.head(2)
selected_disease_ids = selected_diseases['disease_id'].tolist()
selected_disease_names = selected_diseases['disease_name'].tolist()

print(f"\nSelected disease areas:")
for disease_key, disease_label in zip(selected_disease_ids, selected_disease_names):
    print(f"  - {disease_label} (ID: {disease_key})")

# Rank drugs by contraindication rows within the selected diseases
in_selected_disease = contraindications['disease_id'].isin(selected_disease_ids)
selected_contraindications = contraindications[in_selected_disease]
drug_contra_counts = (
    selected_contraindications
    .groupby(['drug_id', 'drug_name'])
    .size()
    .reset_index(name='contra_count')
    .sort_values('contra_count', ascending=False)
)

print(f"\nDrugs with contraindications for selected diseases: {len(drug_contra_counts)}")

# Keep at most 20 drugs
n_drugs = min(20, len(drug_contra_counts))
selected_drugs = drug_contra_counts.head(n_drugs)
selected_drug_ids = selected_drugs['drug_id'].tolist()
selected_drug_names = selected_drugs['drug_name'].tolist()

# Print only the first ten drugs, then a count of the remainder
preview_limit = 10
print(f"\nSelected {len(selected_drug_ids)} drugs:")
for drug_key, drug_label in zip(selected_drug_ids[:preview_limit], selected_drug_names[:preview_limit]):
    print(f"  - {drug_label} (ID: {drug_key})")
hidden_count = len(selected_drug_ids) - preview_limit
if hidden_count > 0:
    print(f"  ... and {hidden_count} more")


## 5. Build GDi and GDr Mappings

**GDi**: Disease → Genes → Pathways  
**GDr**: Drug → Target Genes → Pathways


In [None]:
# Build GDi: Disease → Genes → Pathways
print("Building GDi: Disease → Genes → Pathways")
print("=" * 60)

# Step 1: Disease → Genes (disease_protein relation), either orientation
is_disease_protein = df['relation'] == 'disease_protein'
disease_then_gene = (df['x_type'] == 'disease') & (df['y_type'] == 'gene/protein')
gene_then_disease = (df['x_type'] == 'gene/protein') & (df['y_type'] == 'disease')
disease_gene_edges = df[(disease_then_gene | gene_then_disease) & is_disease_protein].copy()

# Normalize to disease -> gene
def normalize_disease_gene(row):
    """Return the edge in ``row`` as a disease -> gene record.

    The x side is taken as the disease when ``row['x_type']`` is
    ``'disease'``; otherwise the y side is.
    """
    if row['x_type'] == 'disease':
        disease_side, gene_side = 'x', 'y'
    else:
        disease_side, gene_side = 'y', 'x'

    record = {
        'disease_id': row[disease_side + '_id'],
        'disease_name': row[disease_side + '_name'],
        'gene_id': row[gene_side + '_id'],
        'gene_name': row[gene_side + '_name'],
    }
    return pd.Series(record)

# One normalised disease -> gene record per selected edge
disease_gene_df = disease_gene_edges.apply(normalize_disease_gene, axis=1)

# Step 2: Genes → Pathways (pathway_protein relation), either orientation
is_pathway_protein = df['relation'] == 'pathway_protein'
gene_then_pathway = (df['x_type'] == 'gene/protein') & (df['y_type'] == 'pathway')
pathway_then_gene = (df['x_type'] == 'pathway') & (df['y_type'] == 'gene/protein')
gene_pathway_edges = df[(gene_then_pathway | pathway_then_gene) & is_pathway_protein].copy()

def normalize_gene_pathway(row):
    """Return the edge in ``row`` as a gene -> pathway record.

    The x side is taken as the gene when ``row['x_type']`` is
    ``'gene/protein'``; otherwise the y side is.
    """
    gene_is_x = row['x_type'] == 'gene/protein'
    gene_side = 'x' if gene_is_x else 'y'
    pathway_side = 'y' if gene_is_x else 'x'

    return pd.Series({
        'gene_id': row[f'{gene_side}_id'],
        'gene_name': row[f'{gene_side}_name'],
        'pathway_id': row[f'{pathway_side}_id'],
        'pathway_name': row[f'{pathway_side}_name'],
    })

gene_pathway_df = gene_pathway_edges.apply(normalize_gene_pathway, axis=1)

# GDi lookup tables, each mapping an ID to a set of IDs:
#   gdi_disease_genes    disease -> genes
#   gdi_gene_pathways    gene    -> pathways
#   gdi_disease_pathways disease -> pathways reached through its genes
gdi_disease_genes = defaultdict(set)
gdi_gene_pathways = defaultdict(set)
gdi_disease_pathways = defaultdict(set)

# Disease -> Genes
for edge in disease_gene_df.itertuples(index=False):
    gdi_disease_genes[edge.disease_id].add(edge.gene_id)

# Gene -> Pathways
for edge in gene_pathway_df.itertuples(index=False):
    gdi_gene_pathways[edge.gene_id].add(edge.pathway_id)

# Disease -> Pathways (via genes). Only genes that already have a pathway entry
# are looked up, so the defaultdicts do not gain empty entries.
for disease_key, gene_keys in gdi_disease_genes.items():
    for gene_key in gene_keys & gdi_gene_pathways.keys():
        gdi_disease_pathways[disease_key] |= gdi_gene_pathways[gene_key]

print(f"Diseases with gene associations: {len(gdi_disease_genes)}")
print(f"Genes with pathway associations: {len(gdi_gene_pathways)}")
print(f"Diseases with pathway associations: {len(gdi_disease_pathways)}")

# Check selected diseases
for disease_id in selected_disease_ids:
    name_matches = selected_diseases.loc[selected_diseases['disease_id'] == disease_id, 'disease_name']
    disease_name = name_matches.iloc[0]
    n_genes = len(gdi_disease_genes.get(disease_id, set()))
    n_pathways = len(gdi_disease_pathways.get(disease_id, set()))
    print(f"\n{disease_name}:")
    print(f"  Genes: {n_genes}, Pathways: {n_pathways}")


In [None]:
# Build GDr: Drug → Target Genes → Pathways
print("Building GDr: Drug → Target Genes → Pathways")
print("=" * 60)

# Step 1: Drug → Target Genes (drug_protein relation), either orientation
is_drug_protein = df['relation'] == 'drug_protein'
drug_then_gene = (df['x_type'] == 'drug') & (df['y_type'] == 'gene/protein')
gene_then_drug = (df['x_type'] == 'gene/protein') & (df['y_type'] == 'drug')
drug_gene_edges = df[(drug_then_gene | gene_then_drug) & is_drug_protein].copy()

def normalize_drug_gene(row):
    """Return the edge in ``row`` as a drug -> gene record.

    The x side is taken as the drug when ``row['x_type']`` is ``'drug'``;
    otherwise the y side is.
    """
    sides = {True: ('x', 'y'), False: ('y', 'x')}
    drug_side, gene_side = sides[bool(row['x_type'] == 'drug')]

    fields = {}
    fields['drug_id'] = row[drug_side + '_id']
    fields['drug_name'] = row[drug_side + '_name']
    fields['gene_id'] = row[gene_side + '_id']
    fields['gene_name'] = row[gene_side + '_name']
    return pd.Series(fields)

drug_gene_df = drug_gene_edges.apply(normalize_drug_gene, axis=1)

# GDr lookup tables, each mapping a drug ID to a set of IDs:
#   gdr_drug_genes    drug -> target genes
#   gdr_drug_pathways drug -> pathways reached through its target genes
gdr_drug_genes = defaultdict(set)
gdr_drug_pathways = defaultdict(set)

# Drug -> Genes
for edge in drug_gene_df.itertuples(index=False):
    gdr_drug_genes[edge.drug_id].add(edge.gene_id)

# Drug -> Pathways (via genes). The gene -> pathway table built for GDi is
# reused; only genes already present in it are looked up so that it does not
# gain empty entries.
for drug_key, gene_keys in gdr_drug_genes.items():
    for gene_key in gene_keys & gdi_gene_pathways.keys():
        gdr_drug_pathways[drug_key] |= gdi_gene_pathways[gene_key]

print(f"Drugs with gene targets: {len(gdr_drug_genes)}")
print(f"Drugs with pathway associations: {len(gdr_drug_pathways)}")

# Check selected drugs (first five only)
for drug_id in selected_drug_ids[:5]:
    name_matches = selected_drugs.loc[selected_drugs['drug_id'] == drug_id, 'drug_name']
    drug_name = name_matches.iloc[0]
    n_genes = len(gdr_drug_genes.get(drug_id, set()))
    n_pathways = len(gdr_drug_pathways.get(drug_id, set()))
    print(f"\n{drug_name}:")
    print(f"  Target Genes: {n_genes}, Pathways: {n_pathways}")


## 6. Build NetworkX Graph for Distance Features


In [None]:
# Build a small undirected graph over the selected diseases and drugs plus the
# genes and pathways they reach; it is used for shortest-path distances.
# NOTE(review): nodes are keyed by raw ID only. If a gene, disease, drug or
# pathway ID can coincide across entity types, those nodes would be merged —
# confirm the ID spaces are disjoint.
print("Building NetworkX graph for distance computation...")

G = nx.Graph()

# Display name and type for every node, keyed by node ID
node_info = {}

# Disease nodes
for disease_id in selected_disease_ids:
    G.add_node(disease_id, node_type='disease')
    name_matches = selected_diseases.loc[selected_diseases['disease_id'] == disease_id, 'disease_name']
    disease_name = name_matches.iloc[0]
    node_info[disease_id] = {'name': disease_name, 'type': 'disease'}

# Drug nodes
for drug_id in selected_drug_ids:
    G.add_node(drug_id, node_type='drug')
    name_matches = selected_drugs.loc[selected_drugs['drug_id'] == drug_id, 'drug_name']
    drug_name = name_matches.iloc[0]
    node_info[drug_id] = {'name': drug_name, 'type': 'drug'}

# Gene nodes: every gene linked to a selected disease or drug
all_genes = set()
for disease_id in selected_disease_ids:
    all_genes |= gdi_disease_genes.get(disease_id, set())
for drug_id in selected_drug_ids:
    all_genes |= gdr_drug_genes.get(drug_id, set())

G.add_nodes_from(all_genes, node_type='gene')
node_info.update(
    (gene_id, {'name': f'Gene_{gene_id}', 'type': 'gene'}) for gene_id in all_genes
)

# Pathway nodes: every pathway reachable from a selected disease or drug
all_pathways = set()
for disease_id in selected_disease_ids:
    all_pathways |= gdi_disease_pathways.get(disease_id, set())
for drug_id in selected_drug_ids:
    all_pathways |= gdr_drug_pathways.get(drug_id, set())

G.add_nodes_from(all_pathways, node_type='pathway')
node_info.update(
    (pathway_id, {'name': f'Pathway_{pathway_id}', 'type': 'pathway'}) for pathway_id in all_pathways
)

# Edges: disease-gene, drug-gene, gene-pathway
for disease_id in selected_disease_ids:
    G.add_edges_from(
        ((disease_id, gene_id) for gene_id in gdi_disease_genes.get(disease_id, set())),
        relation='disease_gene',
    )

for drug_id in selected_drug_ids:
    G.add_edges_from(
        ((drug_id, gene_id) for gene_id in gdr_drug_genes.get(drug_id, set())),
        relation='drug_gene',
    )

for gene_id in all_genes:
    G.add_edges_from(
        (
            (gene_id, pathway_id)
            for pathway_id in gdi_gene_pathways.get(gene_id, set())
            if pathway_id in all_pathways
        ),
        relation='gene_pathway',
    )

print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
print(f"  Diseases: {len(selected_disease_ids)}")
print(f"  Drugs: {len(selected_drug_ids)}")
print(f"  Genes: {len(all_genes)}")
print(f"  Pathways: {len(all_pathways)}")


## 7. Extract Interaction Features

For each drug-disease pair, extract:
1. Shared genes count
2. Shared pathways count
3. Pathway coverage overlap (Jaccard similarity)
4. Graph distance between drug and disease nodes


In [None]:
def compute_pathway_overlap(disease_pathways, drug_pathways):
    """Return the Jaccard similarity of two pathway sets.

    The result is ``len(intersection) / len(union)``, or ``0.0`` when both
    sets are empty.
    """
    combined = disease_pathways | drug_pathways
    if not combined:
        return 0.0
    common = disease_pathways & drug_pathways
    return len(common) / len(combined)

def compute_graph_distance(G, drug_id, disease_id):
    """Return the shortest-path length between a drug and a disease in ``G``.

    Returns ``-1`` when either node is missing from ``G`` or when no path
    connects the two nodes.
    """
    both_present = drug_id in G and disease_id in G
    if not both_present:
        return -1  # Not in graph

    try:
        return nx.shortest_path_length(G, drug_id, disease_id)
    except nx.NetworkXNoPath:
        return -1  # No path exists

def extract_features(drug_id, disease_id, gdi_disease_genes, gdi_disease_pathways,
                     gdr_drug_genes, gdr_drug_pathways, G):
    """Build the interaction-feature dict for one drug-disease pair.

    Args:
        drug_id: ID of the drug.
        disease_id: ID of the disease.
        gdi_disease_genes: Mapping disease ID -> set of gene IDs.
        gdi_disease_pathways: Mapping disease ID -> set of pathway IDs.
        gdr_drug_genes: Mapping drug ID -> set of gene IDs.
        gdr_drug_pathways: Mapping drug ID -> set of pathway IDs.
        G: Graph passed through to ``compute_graph_distance``.

    Returns:
        Dict with the shared gene/pathway counts, the pathway Jaccard overlap,
        the graph distance (``-1`` when unavailable) and the four set sizes.
        IDs missing from a mapping are treated as having an empty set.
    """
    genes_of_disease = gdi_disease_genes.get(disease_id, set())
    pathways_of_disease = gdi_disease_pathways.get(disease_id, set())
    genes_of_drug = gdr_drug_genes.get(drug_id, set())
    pathways_of_drug = gdr_drug_pathways.get(drug_id, set())

    # Keys are inserted in the order the feature columns are expected
    return {
        'n_shared_genes': len(genes_of_disease & genes_of_drug),
        'n_shared_pathways': len(pathways_of_disease & pathways_of_drug),
        'pathway_overlap': compute_pathway_overlap(pathways_of_disease, pathways_of_drug),
        'graph_distance': compute_graph_distance(G, drug_id, disease_id),
        'n_disease_genes': len(genes_of_disease),
        'n_disease_pathways': len(pathways_of_disease),
        'n_drug_genes': len(genes_of_drug),
        'n_drug_pathways': len(pathways_of_drug),
    }


In [None]:
# Create dataset: all drug-disease pairs with labels and features
print("Creating dataset with features...")

# A pair is identified by these two columns. Each pair is kept once: repeated
# rows for the same pair would inflate the dataset and could place identical
# examples in both the train and the test split.
pair_cols = ['drug_id', 'disease_id']

# Positive examples: contraindications (adverse outcomes)
positive_pairs = contraindications[
    (contraindications['disease_id'].isin(selected_disease_ids)) &
    (contraindications['drug_id'].isin(selected_drug_ids))
].drop_duplicates(subset=pair_cols)

print(f"Positive examples (contraindications): {len(positive_pairs)}")

# Negative examples: indications (therapeutic uses) or random pairs
indications = drug_disease_normalized[drug_disease_normalized['relation'] == 'indication']
negative_pairs = indications[
    (indications['disease_id'].isin(selected_disease_ids)) &
    (indications['drug_id'].isin(selected_drug_ids))
].drop_duplicates(subset=pair_cols)

# If not enough negative examples, add random pairs
# NOTE(review): random pairs are labelled "safe" only because no
# contraindication or indication edge was selected for them — confirm this
# assumption is acceptable for the MVP.
if len(negative_pairs) < len(positive_pairs):
    # Pairs that are already labelled must not be drawn again
    all_pairs = set(zip(positive_pairs['drug_id'], positive_pairs['disease_id']))
    all_pairs.update(zip(negative_pairs['drug_id'], negative_pairs['disease_id']))

    needed = len(positive_pairs) - len(negative_pairs)
    random_pairs = []
    attempts = 0
    # The attempt cap ends the loop when the unused pairs run out
    while len(random_pairs) < needed and attempts < 10000:
        drug_id = np.random.choice(selected_drug_ids)
        disease_id = np.random.choice(selected_disease_ids)
        if (drug_id, disease_id) not in all_pairs:
            random_pairs.append((drug_id, disease_id))
            all_pairs.add((drug_id, disease_id))
        attempts += 1

    if random_pairs:
        random_df = pd.DataFrame(random_pairs, columns=['drug_id', 'disease_id'])
        random_df['drug_name'] = random_df['drug_id'].map(dict(zip(selected_drugs['drug_id'], selected_drugs['drug_name'])))
        random_df['disease_name'] = random_df['disease_id'].map(dict(zip(selected_diseases['disease_id'], selected_diseases['disease_name'])))
        random_df['relation'] = 'none'
        negative_pairs = pd.concat([negative_pairs, random_df], ignore_index=True)

print(f"Negative examples: {len(negative_pairs)}")


def _featurize_pairs(pairs, label, desc):
    """Return one feature dict per row of ``pairs``, tagged with ``label``."""
    rows = []
    for _, row in tqdm(pairs.iterrows(), total=len(pairs), desc=desc):
        features = extract_features(
            row['drug_id'], row['disease_id'],
            gdi_disease_genes, gdi_disease_pathways,
            gdr_drug_genes, gdr_drug_pathways, G
        )
        features['drug_id'] = row['drug_id']
        features['drug_name'] = row['drug_name']
        features['disease_id'] = row['disease_id']
        features['disease_name'] = row['disease_name']
        features['label'] = label
        rows.append(features)
    return rows


# Label 1 = adverse outcome, label 0 = no adverse outcome
dataset = (
    _featurize_pairs(positive_pairs, 1, "Positive pairs")
    + _featurize_pairs(negative_pairs, 0, "Negative pairs")
)

df_dataset = pd.DataFrame(dataset)

print(f"\nDataset created: {len(df_dataset)} examples")
print(f"  Positive (adverse): {df_dataset['label'].sum()}")
print(f"  Negative (safe): {len(df_dataset) - df_dataset['label'].sum()}")
print(f"\nFeature statistics:")
print(df_dataset.describe())


## 8. Baseline: Simple PrimeKG Embeddings

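Create simple node-level baseline features from graph structure: the degree, betweenness centrality and closeness centrality of the drug and disease nodes. Learned embeddings such as node2vec are not used in this MVP.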


In [None]:
# Simple baseline: describe each pair by the degree and centrality of its
# drug node and disease node in G
print("Creating baseline embeddings...")

# Compute node degrees
degrees = dict(G.degree())

# Centrality is computed only below 10,000 nodes; above that, zeros are used
if G.number_of_nodes() < 10000:
    print("Computing centrality measures...")
    betweenness = nx.betweenness_centrality(G)
    closeness = nx.closeness_centrality(G)
else:
    print("Graph too large, using degree only...")
    betweenness = dict.fromkeys(G.nodes(), 0.0)
    closeness = dict.fromkeys(G.nodes(), 0.0)

# One feature row per dataset row, in the same order as df_dataset
baseline_features = []
for pair in df_dataset.itertuples(index=False):
    drug_id, disease_id = pair.drug_id, pair.disease_id

    # Nodes missing from the graph fall back to 0 / 0.0
    drug_degree = degrees.get(drug_id, 0)
    disease_degree = degrees.get(disease_id, 0)

    baseline_features.append({
        'drug_degree': drug_degree,
        'disease_degree': disease_degree,
        'drug_betweenness': betweenness.get(drug_id, 0.0),
        'disease_betweenness': betweenness.get(disease_id, 0.0),
        'drug_closeness': closeness.get(drug_id, 0.0),
        'disease_closeness': closeness.get(disease_id, 0.0),
        'degree_sum': drug_degree + disease_degree,
        'degree_diff': abs(drug_degree - disease_degree),
    })

df_baseline = pd.DataFrame(baseline_features)
print(f"Baseline features created: {df_baseline.shape[1]} features")
print(df_baseline.describe())


## 9. Train Models

Train Logistic Regression and Random Forest models on:
1. Interaction features only
2. Baseline embeddings only
3. Combined features


In [None]:
# Prepare feature sets
feature_cols = [
    'n_shared_genes', 'n_shared_pathways', 'pathway_overlap', 'graph_distance',
    'n_disease_genes', 'n_disease_pathways', 'n_drug_genes', 'n_drug_pathways'
]

baseline_cols = list(df_baseline.columns)

# graph_distance == -1 means "no path / node missing". Replace it with a value
# one larger than the largest observed distance so that unreachable pairs rank
# as the farthest apart.
df_dataset['graph_distance'] = df_dataset['graph_distance'].replace(-1, df_dataset['graph_distance'].max() + 1)

X_interaction = df_dataset[feature_cols].values
X_baseline = df_baseline.values
X_combined = np.hstack([X_interaction, X_baseline])

y = df_dataset['label'].values

# Train-test split. All three feature matrices are split in ONE call so their
# rows are guaranteed to stay aligned with each other and with y, instead of
# relying on three separate calls happening to use identical arguments.
(
    X_interaction_train, X_interaction_test,
    X_baseline_train, X_baseline_test,
    X_combined_train, X_combined_test,
    y_train, y_test,
) = train_test_split(
    X_interaction, X_baseline, X_combined, y,
    test_size=0.2, random_state=42, stratify=y
)

print(f"Training set: {len(y_train)} examples")
print(f"Test set: {len(y_test)} examples")
print(f"  Positive in train: {y_train.sum()}, test: {y_test.sum()}")


In [None]:
# Train one Logistic Regression per feature set and record its test metrics
print("Training Logistic Regression models...")
print("=" * 60)

# Fitted estimators and their metric dicts, keyed by model name
models = {}
results = {}


def _score_predictions(y_true, y_pred, y_proba):
    """Return the reported test metrics for one model, in a fixed key order."""
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        'roc_auc': roc_auc_score(y_true, y_proba)
    }


# 1. Interaction features only
lr_interaction = LogisticRegression(random_state=42, max_iter=1000)
lr_interaction.fit(X_interaction_train, y_train)
y_pred_interaction = lr_interaction.predict(X_interaction_test)
# Column 1 of predict_proba is the probability of the positive class
y_proba_interaction = lr_interaction.predict_proba(X_interaction_test)[:, 1]

models['LR_Interaction'] = lr_interaction
results['LR_Interaction'] = _score_predictions(y_test, y_pred_interaction, y_proba_interaction)

# 2. Baseline embeddings only
lr_baseline = LogisticRegression(random_state=42, max_iter=1000)
lr_baseline.fit(X_baseline_train, y_train)
y_pred_baseline = lr_baseline.predict(X_baseline_test)
y_proba_baseline = lr_baseline.predict_proba(X_baseline_test)[:, 1]

models['LR_Baseline'] = lr_baseline
results['LR_Baseline'] = _score_predictions(y_test, y_pred_baseline, y_proba_baseline)

# 3. Combined features
lr_combined = LogisticRegression(random_state=42, max_iter=1000)
lr_combined.fit(X_combined_train, y_train)
y_pred_combined = lr_combined.predict(X_combined_test)
y_proba_combined = lr_combined.predict_proba(X_combined_test)[:, 1]

models['LR_Combined'] = lr_combined
results['LR_Combined'] = _score_predictions(y_test, y_pred_combined, y_proba_combined)

# Print results (one row per model)
print("\nLogistic Regression Results:")
results_df = pd.DataFrame(results).T
print(results_df.round(4))


In [None]:
# Train one Random Forest per feature set and add its test metrics to `results`
print("\nTraining Random Forest models...")
print("=" * 60)


def _score_predictions(y_true, y_pred, y_proba):
    """Return the reported test metrics for one model, in a fixed key order."""
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        'roc_auc': roc_auc_score(y_true, y_proba)
    }


# 1. Interaction features only
rf_interaction = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
rf_interaction.fit(X_interaction_train, y_train)
y_pred_interaction = rf_interaction.predict(X_interaction_test)
# Column 1 of predict_proba is the probability of the positive class
y_proba_interaction = rf_interaction.predict_proba(X_interaction_test)[:, 1]

models['RF_Interaction'] = rf_interaction
results['RF_Interaction'] = _score_predictions(y_test, y_pred_interaction, y_proba_interaction)

# 2. Baseline embeddings only
rf_baseline = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
rf_baseline.fit(X_baseline_train, y_train)
y_pred_baseline = rf_baseline.predict(X_baseline_test)
y_proba_baseline = rf_baseline.predict_proba(X_baseline_test)[:, 1]

models['RF_Baseline'] = rf_baseline
results['RF_Baseline'] = _score_predictions(y_test, y_pred_baseline, y_proba_baseline)

# 3. Combined features
rf_combined = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
rf_combined.fit(X_combined_train, y_train)
y_pred_combined = rf_combined.predict(X_combined_test)
y_proba_combined = rf_combined.predict_proba(X_combined_test)[:, 1]

models['RF_Combined'] = rf_combined
results['RF_Combined'] = _score_predictions(y_test, y_pred_combined, y_proba_combined)

# Print results: `results` is shared with the Logistic Regression cell, so the
# table lists every model recorded so far
print("\nRandom Forest Results:")
results_df = pd.DataFrame(results).T
print(results_df.round(4))


## 10. Model Evaluation & Visualization


In [None]:
# Visualize results: three metric bar charts plus one feature-importance panel
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Panels 1-3: (axis, results_df column, bar colour, title, y-axis label)
metric_panels = [
    (axes[0, 0], 'accuracy', 'steelblue', 'Model Accuracy Comparison', 'Accuracy'),
    (axes[0, 1], 'roc_auc', 'coral', 'ROC-AUC Comparison', 'ROC-AUC'),
    (axes[1, 0], 'f1', 'mediumseagreen', 'F1 Score Comparison', 'F1 Score'),
]
for ax, metric, bar_color, panel_title, axis_label in metric_panels:
    results_df[metric].plot(kind='bar', ax=ax, color=bar_color)
    ax.set_title(panel_title, fontsize=12, fontweight='bold')
    ax.set_ylabel(axis_label)
    ax.set_ylim([0, 1])
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)

# Panel 4: feature importances of the model with the highest ROC-AUC. They are
# drawn only when that model is a Random Forest; otherwise a note is shown.
ax = axes[1, 1]
best_model_name = results_df['roc_auc'].idxmax()
best_model = models[best_model_name]

if 'RF' not in best_model_name:
    ax.text(0.5, 0.5, 'Feature importance\navailable for\nRandom Forest only',
            ha='center', va='center', fontsize=12, transform=ax.transAxes)
    ax.set_title(f'Best Model: {best_model_name}', fontsize=12, fontweight='bold')
else:
    # Feature names must match the columns the chosen model was trained on
    if 'Combined' in best_model_name:
        all_feature_names = feature_cols + baseline_cols
    elif 'Interaction' in best_model_name:
        all_feature_names = feature_cols
    else:
        all_feature_names = baseline_cols
    importances = best_model.feature_importances_

    # Get top 10 features, most important first
    top_indices = np.argsort(importances)[-10:][::-1]
    top_features = [all_feature_names[i] for i in top_indices]
    top_importances = importances[top_indices]

    bar_positions = range(len(top_features))
    ax.barh(bar_positions, top_importances, color='gold')
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(top_features)
    ax.set_xlabel('Importance')
    ax.set_title(f'Top 10 Features: {best_model_name}', fontsize=12, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nBest Model: {best_model_name}")
print(f"  ROC-AUC: {results_df.loc[best_model_name, 'roc_auc']:.4f}")
print(f"  F1 Score: {results_df.loc[best_model_name, 'f1']:.4f}")


In [None]:
# Confusion matrices for three fixed models, one heatmap per model
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

model_names = ['LR_Combined', 'RF_Interaction', 'RF_Combined']

# Test matrix per feature set, checked in this order against the model name;
# names matching neither key use the baseline test matrix
test_sets_by_tag = {'Combined': X_combined_test, 'Interaction': X_interaction_test}

for idx, model_name in enumerate(model_names):
    # A model that was never trained leaves its panel empty
    if model_name not in models:
        continue

    panel = axes[idx]
    model = models[model_name]
    X_test = next(
        (matrix for tag, matrix in test_sets_by_tag.items() if tag in model_name),
        X_baseline_test,
    )

    y_pred = model.predict(X_test)
    cm = confusion_matrix(y_test, y_pred)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=panel,
                xticklabels=['Safe', 'Adverse'], yticklabels=['Safe', 'Adverse'])
    panel.set_title(f'{model_name}\nAccuracy: {results_df.loc[model_name, "accuracy"]:.3f}',
                    fontsize=11, fontweight='bold')
    panel.set_ylabel('True Label')
    panel.set_xlabel('Predicted Label')

plt.tight_layout()
plt.show()


## 11. Summary & Next Steps

### Key Findings:
1. **Feature Performance**: Which feature set works best?
2. **Model Comparison**: LR vs RF performance
3. **Important Features**: What drives predictions?

### Next Steps for Full Model:
1. **GNN Integration**: Use graph neural networks to capture complex interactions
2. **More Disease Areas**: Expand to additional disease areas
3. **External Data**: Integrate DisGeNET and DrugBank for richer gene mappings
4. **Feature Engineering**: Add temporal, dosage, and patient-specific features
5. **Evaluation**: Cross-validation, external validation sets


## 12. Final Summary Report

The cell below prints the dataset size, the number of features in each feature set, and the best model by ROC-AUC.


In [None]:
# Final summary: dataset size, feature counts and the best model by ROC-AUC
banner = "=" * 60
print(banner)
print("MVP SUMMARY")
print(banner)

n_examples = len(df_dataset)
n_positive = df_dataset['label'].sum()
print(f"\nDataset:")
print(f"  Disease areas: {len(selected_disease_ids)}")
print(f"  Drugs: {len(selected_drug_ids)}")
print(f"  Total examples: {n_examples}")
print(f"    Positive (adverse): {n_positive}")
print(f"    Negative (safe): {n_examples - n_positive}")

n_interaction = len(feature_cols)
n_baseline = len(baseline_cols)
print(f"\nFeatures:")
print(f"  Interaction features: {n_interaction}")
print(f"  Baseline features: {n_baseline}")
print(f"  Combined: {n_interaction + n_baseline}")

# Row of results_df with the highest ROC-AUC
best_model_name = results_df['roc_auc'].idxmax()
best_row = results_df.loc[best_model_name]
print(f"\nBest Model: {best_model_name}")
print(f"  ROC-AUC: {best_row['roc_auc']:.4f}")
print(f"  Accuracy: {best_row['accuracy']:.4f}")
print(f"  F1 Score: {best_row['f1']:.4f}")

print("\n" + banner)
