In [20]:
import pandas as pd
import os 
import gzip 
import torch
import numpy as np 
from rdkit import Chem
from rdkit.Chem import AllChem 

In [21]:
sider_file_path = 'data/meddra_all_se.tsv'
sider_names_file_path = 'data/drug_names.tsv'
sider_df = pd.read_csv(sider_file_path,sep='\t',header=None, compression=None)
drug_names_df = pd.read_csv(sider_names_file_path,sep='\t',header=None,compression=None)
drug_names_df.columns = ['STITCH_flat', 'Drug_Name']
sider_df.columns = [
    'STITCH_compound_flat',  # Example: CID100000085
    'STITCH_compound_stereo',  # Example: CID000010917
    'UMLS_concept_id',         # Example: C0000729
    'MedDRA_type',             # e.g., LLT
    'MedDRA_concept_id',       # Example: C0000729
    'LLT_preferred_term'       # e.g., "Abdominal cramps"
]
sider_df.dropna(subset=['MedDRA_type', 'MedDRA_concept_id'], inplace=True)

In [None]:
ctd_dir='data'
ctd_chem_disease_file = os.path.join(ctd_dir, 'CTD_chemicals_diseases.csv.gz')
ctd_chem_gene_file = os.path.join(ctd_dir, 'CTD_chem_gene_ixns.csv.gz')
ctd_chemicals_file = os.path.join(ctd_dir, 'CTD_chemicals.csv.gz')
ctd_genes_file = os.path.join(ctd_dir, 'CTD_genes.csv.gz')
ctd_chem_disease_df = pd.read_csv(
        ctd_chem_disease_file,
        comment='#', 
        compression='gzip'
    )
ctd_chem_gene_df = pd.read_csv(
    ctd_chem_gene_file,
    comment='#',
    compression='gzip'
)
ctd_chemicals_df = pd.read_csv(
    ctd_chemicals_file,
    sep='\t',
    comment='#',
    compression='gzip'
)
ctd_genes_df = pd.read_csv(
    ctd_genes_file,
    sep='\t',
    comment='#',
    compression='gzip'
)
ctd_chem_gene_df.columns = [
    "ChemicalName",       # e.g., 10074-G5
    "ChemicalID",         # e.g., C534883
    "CasRN",              # Unnamed or CAS Registry Number
    "GeneSymbol",         # e.g., AR
    "GeneID",             # e.g., 367
    "GeneForms",          # e.g., protein
    "Organism",           # e.g., Homo sapiens
    "OrganismID",         # e.g., 9606
    "Interaction",        # Natural language interaction
    "InteractionActions", # Parsed actions e.g., decreases^reaction|increases^expression
    "PubMedIDs"           # Supporting publication IDs
]
ctd_chem_disease_df.columns = [
    "ChemicalName", "ChemicalID", "CasRN", "DiseaseName", "DiseaseID", 
    "DirectEvidence", "InferenceGeneSymbol", "InferenceScore", 
    "OmimIDs", "PubMedIDs"
]

In [None]:
ctd_chem_disease_df.drop(columns=['CasRN', 'DirectEvidence', 'OmimIDs'],inplace=True)
ctd_chem_gene_df.drop(
    columns=['CasRN', 'GeneForms', 'Organism', 'OrganismID'],inplace=True
)
ctd_chem_disease_df['InferenceScore'].fillna(ctd_chem_disease_df['InferenceScore'].mean(), inplace=True)
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
ctd_chem_disease_df['InferenceScore'] = scaler.fit_transform(ctd_chem_disease_df[['InferenceScore']])


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  ctd_chem_disease_df['InferenceScore'].fillna(ctd_chem_disease_df['InferenceScore'].mean(), inplace=True)


In [None]:
all_drug_ids = pd.concat([
    sider_df['STITCH_compound_flat'], 
    ctd_chem_gene_df['ChemicalID'], 
    ctd_chem_disease_df['ChemicalID']
]).dropna().unique()
all_se_ids = sider_df['MedDRA_concept_id'].dropna().unique()
all_disease_ids = ctd_chem_disease_df['DiseaseID'].dropna().unique()
all_gene_ids = ctd_chem_gene_df['GeneID'].dropna().unique()

print(f"Original counts:")
print(f"  Drugs: {len(all_drug_ids)}")
print(f"  Side Effects: {len(all_se_ids)}")
print(f"  Diseases: {len(all_disease_ids)}")
print(f"  Genes: {len(all_gene_ids)}")

Original counts:
  Drugs: 19275
  Side Effects: 6060
  Diseases: 7281
  Genes: 56367


In [None]:
import numpy as np

sampling_fraction = 0.08  # Set to e.g. 0.1 for 10% sampling
num_drugs = len(all_drug_ids)
num_sampled_drugs = int(num_drugs * sampling_fraction)

sampled_drug_ids = np.random.choice(all_drug_ids, size=num_sampled_drugs, replace=False)

print(f"\nSubsampling drugs:")
print(f"  Original drugs: {num_drugs}")
print(f"  Sampled drugs: {num_sampled_drugs}")

# Create mappings for sampled drugs and full mappings for other node types
drug_mapping = {id: i for i, id in enumerate(sampled_drug_ids)}
se_mapping = {id: i for i, id in enumerate(all_se_ids)}
disease_mapping = {id: i for i, id in enumerate(all_disease_ids)}
gene_mapping = {id: i for i, id in enumerate(all_gene_ids)}

print(f"\nNumber of nodes after subsampling:")
print(f"  Drugs: {len(drug_mapping)}")
print(f"  Side Effects: {len(se_mapping)}")
print(f"  Diseases: {len(disease_mapping)}")
print(f"  Genes: {len(gene_mapping)}")


Subsampling drugs:
  Original drugs: 19275
  Sampled drugs: 1542

Number of nodes after subsampling:
  Drugs: 1542
  Side Effects: 6060
  Diseases: 7281
  Genes: 56367


In [None]:
from sklearn.calibration import LabelEncoder

drug_encoder = LabelEncoder()
se_encoder = LabelEncoder()
disease_encoder = LabelEncoder()
gene_encoder = LabelEncoder()
drug_encoder.fit(sampled_drug_ids)
encoded_drug_features = drug_encoder.transform(sampled_drug_ids)

# Encode side effects (you can use all SEs here since subsampling is only on drugs)
encoded_se_features = se_encoder.fit_transform(all_se_ids)

# Encode diseases and genes (same reasoning applies as side effects)
encoded_disease_features = disease_encoder.fit_transform(all_disease_ids)
encoded_gene_features = gene_encoder.fit_transform(all_gene_ids)

# Step 4: Check encoded features for drugs, side effects, diseases, and genes
print(f"\nEncoded drug features (sample): {encoded_drug_features[:5]}")
print(f"Encoded side effect features (sample): {encoded_se_features[:5]}")
print(f"Encoded disease features (sample): {encoded_disease_features[:5]}")
print(f"Encoded gene features (sample): {encoded_gene_features[:5]}")


Encoded drug features (sample): [1304   10 1393 1261  144]
Encoded side effect features (sample): [   1    4 4270   58   66]
Encoded disease features (sample): [2458 2391 2496 4440 2610]
Encoded gene features (sample): [ 277 1480 2938 3193 1702]


In [None]:
sider_df_sampled = sider_df[sider_df['STITCH_compound_flat'].isin(sampled_drug_ids)].copy()
ctd_chem_disease_df_sampled = ctd_chem_disease_df[ctd_chem_disease_df['ChemicalID'].isin(sampled_drug_ids)].copy()
ctd_chem_gene_df_sampled = ctd_chem_gene_df[ctd_chem_gene_df['ChemicalID'].isin(sampled_drug_ids)].copy()

print(f"\nFiltered edges counts:")
print(f"  SIDER drug-side effect edges: {len(sider_df_sampled)}")
print(f"  CTD chemical-disease edges: {len(ctd_chem_disease_df_sampled)}")
print(f"  CTD chemical-gene edges: {len(ctd_chem_gene_df_sampled)}")


Filtered edges counts:
  SIDER drug-side effect edges: 25969
  CTD chemical-disease edges: 683718
  CTD chemical-gene edges: 188412


In [None]:
drug_mapping = {drug_id: idx for idx, drug_id in enumerate(drug_encoder.classes_)}
se_mapping = {se_id: idx for idx, se_id in enumerate(se_encoder.classes_)}
disease_mapping = {disease_id: idx for idx, disease_id in enumerate(disease_encoder.classes_)}
gene_mapping = {gene_id: idx for idx, gene_id in enumerate(gene_encoder.classes_)}

# Drug -> Side Effect edges
sider_edges = []
for _, row in sider_df_sampled.iterrows():
    drug_id = row['STITCH_compound_flat']
    se_id = row['MedDRA_concept_id']
    if drug_id in drug_mapping and se_id in se_mapping:
        sider_edges.append((drug_mapping[drug_id], se_mapping[se_id]))

sider_edge_index = torch.tensor(sider_edges, dtype=torch.long).t().contiguous()
print(f"Drug-Side Effect edges: {sider_edge_index.size(1)}")

# Drug -> Disease edges
ctd_chem_disease_edges = []
for _, row in ctd_chem_disease_df_sampled.iterrows():
    chem_id = row['ChemicalID']
    disease_id = row['DiseaseID']
    if chem_id in drug_mapping and disease_id in disease_mapping:
        ctd_chem_disease_edges.append((drug_mapping[chem_id], disease_mapping[disease_id]))

ctd_chem_disease_edge_index = torch.tensor(ctd_chem_disease_edges, dtype=torch.long).t().contiguous()
print(f"Drug-Disease edges: {ctd_chem_disease_edge_index.size(1)}")

# Drug -> Gene edges
ctd_chem_gene_edges = []
for _, row in ctd_chem_gene_df_sampled.iterrows():
    chem_id = row['ChemicalID']
    gene_id = row['GeneID']
    if chem_id in drug_mapping and gene_id in gene_mapping:
        ctd_chem_gene_edges.append((drug_mapping[chem_id], gene_mapping[gene_id]))

ctd_chem_gene_edge_index = torch.tensor(ctd_chem_gene_edges, dtype=torch.long).t().contiguous()
print(f"Drug-Gene edges: {ctd_chem_gene_edge_index.size(1)}")


Drug-Side Effect edges: 25969
Drug-Disease edges: 683718
Drug-Gene edges: 188412


In [None]:
import random
from sklearn.calibration import LabelEncoder
num_drug_features = 100
num_se_features = 80
num_disease_features = 90
num_gene_features = 150

x_drug = torch.randn(len(drug_mapping), num_drug_features)
x_side_effect = torch.randn(len(se_mapping), num_se_features)
x_disease = torch.randn(len(disease_mapping), num_disease_features)
x_gene = torch.randn(len(gene_mapping), num_gene_features)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GATConv, Linear

class HeterogeneousGNNModel(torch.nn.Module):
    def __init__(self, hidden_channels, dropout_prob=0.5):
        super().__init__()

        self.edge_types = [
            ('drug', 'causes', 'side_effect'),
            ('side_effect', 'is_caused_by', 'drug'),
            ('drug', 'associates', 'disease'),
            ('disease', 'is_associated_with_drug', 'drug'),
            ('drug', 'interacts', 'gene'),
            ('gene', 'is_interacted_by', 'drug'),
        ]

        self.conv1 = HeteroConv({
            edge_type: GATConv((-1, -1), hidden_channels, add_self_loops=False)
            for edge_type in self.edge_types
        }, aggr='sum')

        self.dropout1 = nn.Dropout(dropout_prob)

        self.conv2 = HeteroConv({
            edge_type: GATConv((hidden_channels, hidden_channels), hidden_channels, add_self_loops=False)
            for edge_type in self.edge_types
        }, aggr='sum')

        self.dropout2 = nn.Dropout(dropout_prob)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = {key: self.dropout1(x) for key, x in x_dict.items()}

        x_dict = self.conv2(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = {key: self.dropout2(x) for key, x in x_dict.items()}

        return x_dict


# --- 10. Link Prediction Head ---

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, 1)

    def forward(self, z_drug, z_side_effect):
        z = torch.cat([z_drug, z_side_effect], dim=-1) # Example: Concatenation
        return self.lin(z)


# --- 11. Helper function for Link Prediction Metrics ---
from sklearn.metrics import roc_auc_score, average_precision_score

def get_link_prediction_metrics(pos_pred, neg_pred):
    pos_prob = torch.sigmoid(pos_pred)
    neg_prob = torch.sigmoid(neg_pred)

    probs = torch.cat([pos_prob, neg_prob], dim=0).detach().cpu().numpy()
    ground_truth = torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)], dim=0).detach().cpu().numpy()

    auc_score = roc_auc_score(ground_truth, probs)
    auprc_score = average_precision_score(ground_truth, probs)

    return auc_score, auprc_score

In [None]:
from torch_geometric.data import HeteroData
data = HeteroData()

data['drug'].x = x_drug
data['side_effect'].x = x_side_effect
data['disease'].x = x_disease
data['gene'].x = x_gene

data['drug', 'causes', 'side_effect'].edge_index = sider_edge_index
data['drug', 'associates', 'disease'].edge_index = ctd_chem_disease_edge_index
data['drug', 'interacts', 'gene'].edge_index = ctd_chem_gene_edge_index
data['side_effect', 'is_caused_by', 'drug'].edge_index = sider_edge_index.flip(0)
data['disease', 'is_associated_with_drug', 'drug'].edge_index = ctd_chem_disease_edge_index.flip(0)
data['gene', 'is_interacted_by', 'drug'].edge_index = ctd_chem_gene_edge_index.flip(0)
device='cpu'
data.to(device)
target_edge_index = data['drug', 'causes', 'side_effect'].edge_index

# Split positive edges into train, val, test
num_target_edges = target_edge_index.size(1)
perm = np.random.permutation(num_target_edges)
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

num_train_pos = int(num_target_edges * train_ratio)
num_val_pos = int(num_target_edges * val_ratio)
num_test_pos = num_target_edges - num_train_pos - num_val_pos

train_pos_indices = perm[:num_train_pos]
val_pos_indices = perm[num_train_pos:num_train_pos + num_val_pos]
test_pos_indices = perm[num_train_pos + num_val_pos:]

# Extract the actual edge index tensors for positive edges
data['drug', 'causes', 'side_effect'].edge_index_train_pos = target_edge_index[:, train_pos_indices].to(device)
data['drug', 'causes', 'side_effect'].edge_index_val_pos = target_edge_index[:, val_pos_indices].to(device)
data['drug', 'causes', 'side_effect'].edge_index_test_pos = target_edge_index[:, test_pos_indices].to(device)

# Generate negative samples for validation and test sets
# Ensure these negative samples are NOT present in the original positive edges
num_drug_nodes = data['drug'].num_nodes
num_se_nodes = data['side_effect'].num_nodes

data['drug', 'causes', 'side_effect'].edge_index_val_neg = torch.randint(
    0, num_drug_nodes, (2, num_val_pos), dtype=torch.long, device=device) # Example: same number as val_pos

data['drug', 'causes', 'side_effect'].edge_index_test_neg = torch.randint(
    0, num_drug_nodes, (2, num_test_pos), dtype=torch.long, device=device) # Example: same number as test_pos



In [None]:
hidden_channels = 64
gnn_model = HeterogeneousGNNModel(hidden_channels=hidden_channels, dropout_prob=0.5)
link_predictor = LinkPredictor(in_channels=hidden_channels * 2) # in_channels matches how you combine embeddings

gnn_model.to(device)
link_predictor.to(device)

weight_decay_rate = 5e-4
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(link_predictor.parameters()), lr=0.01, weight_decay=weight_decay_rate)

criterion = torch.nn.BCEWithLogitsLoss()


def train_link_prediction():
    gnn_model.train()
    link_predictor.train()
    optimizer.zero_grad()

    # Get learned node embeddings
    z_dict = gnn_model(data.x_dict, data.edge_index_dict)
    z_drug = z_dict['drug']
    z_side_effect = z_dict['side_effect']

    train_pos_edge_index = data['drug', 'causes', 'side_effect'].edge_index_train_pos

    num_train_pos = train_pos_edge_index.size(1)
    num_drug_nodes_in_graph = data['drug'].num_nodes
    num_se_nodes_in_graph = data['side_effect'].num_nodes

    train_neg_edge_index = torch.randint(
        0, num_drug_nodes_in_graph, (2, num_train_pos), dtype=torch.long, device=device)
    train_neg_edge_index[1, :] = torch.randint(
        0, num_se_nodes_in_graph, (1, num_train_pos), dtype=torch.long, device=device)

    # --- Get Embeddings for Sampled Edges ---
    pos_drug_emb = z_drug[train_pos_edge_index[0]]
    pos_se_emb = z_side_effect[train_pos_edge_index[1]]

    neg_drug_emb = z_drug[train_neg_edge_index[0]]
    neg_se_emb = z_side_effect[train_neg_edge_index[1]]

    # --- Calculate Link Prediction Scores ---
    pos_scores = link_predictor(pos_drug_emb, pos_se_emb)
    neg_scores = link_predictor(neg_drug_emb, neg_se_emb)

    # --- Calculate Loss ---
    scores = torch.cat([pos_scores, neg_scores], dim=0)
    ground_truth = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)], dim=0)

    loss = criterion(scores.squeeze(), ground_truth.squeeze())

    loss.backward()
    optimizer.step()

    return loss.item()
def evaluate_link_prediction(pos_edge_index, neg_edge_index):
     gnn_model.eval()
     link_predictor.eval()
     with torch.no_grad():
         # Get learned node embeddings from the full graph structure
         z_dict = gnn_model(data.x_dict, data.edge_index_dict)
         z_drug = z_dict['drug']
         z_side_effect = z_dict['side_effect']

         # --- Get Embeddings for Evaluation Edges ---
         eval_pos_drug_emb = z_drug[pos_edge_index[0]]
         eval_pos_se_emb = z_side_effect[pos_edge_index[1]]

         eval_neg_drug_emb = z_drug[neg_edge_index[0]]
         eval_neg_se_emb = z_side_effect[neg_edge_index[1]]

         # --- Calculate Link Prediction Scores for Evaluation Edges ---
         eval_pos_scores = link_predictor(eval_pos_drug_emb, eval_pos_se_emb)
         eval_neg_scores = link_predictor(eval_neg_drug_emb, eval_neg_se_emb)

         # --- Calculate Evaluation Loss ---
         eval_scores = torch.cat([eval_pos_scores, eval_neg_scores], dim=0)
         eval_ground_truth = torch.cat([torch.ones_like(eval_pos_scores), torch.zeros_like(eval_neg_scores)], dim=0)
         eval_loss = criterion(eval_scores.squeeze(), eval_ground_truth.squeeze())

         # --- Calculate Evaluation Metrics (AUC, AUPRC) ---
         eval_auc, eval_auprc = get_link_prediction_metrics(eval_pos_scores, eval_neg_scores)

     return eval_loss.item(), eval_auc, eval_auprc



In [None]:
epochs = 500 
best_val_metric = 0.0 
patience = 50 
epochs_without_improvement = 0
for epoch in range(1, epochs + 1):
    train_loss = train_link_prediction()
    val_loss, val_auc, val_auprc = evaluate_link_prediction(
        data['drug', 'causes', 'side_effect'].edge_index_val_pos,
        data['drug', 'causes', 'side_effect'].edge_index_val_neg
    )

    current_val_metric = val_auc 

    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}, Val AUPRC: {val_auprc:.4f}')

    # --- Early Stopping Logic ---
    if current_val_metric > best_val_metric:
        best_val_metric = current_val_metric
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print(f"Early stopping triggered after {epoch} epochs due to no improvement in validation metric.")
        break

# --- Final evaluation on test set ---
print("\nEvaluating on test set...")
test_loss, test_auc, test_auprc = evaluate_link_prediction(
    data['drug', 'causes', 'side_effect'].edge_index_test_pos,
    data['drug', 'causes', 'side_effect'].edge_index_test_neg
)
print(f'Test Loss: {test_loss:.4f}, Test AUC: {test_auc:.4f}, Test AUPRC: {test_auprc:.4f}')

Starting Link Prediction Training...
Epoch: 001, Train Loss: 0.7351, Val Loss: 0.4202, Val AUC: 0.9780, Val AUPRC: 0.9562
Epoch: 002, Train Loss: 0.3883, Val Loss: 0.1951, Val AUC: 0.9735, Val AUPRC: 0.9339
Epoch: 003, Train Loss: 0.1755, Val Loss: 0.1329, Val AUC: 0.9719, Val AUPRC: 0.9296
Epoch: 004, Train Loss: 0.1400, Val Loss: 0.1420, Val AUC: 0.9711, Val AUPRC: 0.9283
Epoch: 005, Train Loss: 0.1696, Val Loss: 0.1324, Val AUC: 0.9709, Val AUPRC: 0.9278
Epoch: 006, Train Loss: 0.1542, Val Loss: 0.1254, Val AUC: 0.9803, Val AUPRC: 0.9583
Epoch: 007, Train Loss: 0.1274, Val Loss: 0.1439, Val AUC: 0.9806, Val AUPRC: 0.9579
Epoch: 008, Train Loss: 0.1236, Val Loss: 0.1561, Val AUC: 0.9812, Val AUPRC: 0.9600
Epoch: 009, Train Loss: 0.1246, Val Loss: 0.1379, Val AUC: 0.9829, Val AUPRC: 0.9615
Epoch: 010, Train Loss: 0.1143, Val Loss: 0.1192, Val AUC: 0.9846, Val AUPRC: 0.9689
Epoch: 011, Train Loss: 0.0976, Val Loss: 0.1093, Val AUC: 0.9843, Val AUPRC: 0.9673
Epoch: 012, Train Loss: 0.09