In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GATConv
from torch_geometric.data import HeteroData
import numpy as np
import pandas as pd
import os

In [14]:
input_data = {
    "user": {
        "id": "user_001",
        "name": "John Doe",
        "age": 45,
        "gender": "male",
        "blood_type": "O+",
        "genetic_conditions": ["G6PD deficiency"],
        "allergies": ["penicillin", "sulfa drugs"],
        "medical_history": ["hypertension", "asthma"]
    },
    "symptoms": {
        "reported": ["shortness of breath", "chest pain"],
        "duration_days": 3
    },
    "diagnosis": {
        "predicted_conditions": ["coronary artery disease"],
        "confidence_scores": {
            "coronary artery disease": 0.88
        }
    },
    "timestamp": "2025-04-18T10:20:00Z"
}

drug='aspirin'

In [5]:
sider_file_path = 'data/meddra_all_se.tsv'
sider_names_file_path = 'data/drug_names.tsv'
sider_df = pd.read_csv(sider_file_path,sep='\t',header=None, compression=None)
drug_names_df = pd.read_csv(sider_names_file_path,sep='\t',header=None,compression=None)
drug_names_df.columns = ['STITCH_flat', 'Drug_Name']
sider_df.columns = [
    'STITCH_compound_flat',  # Example: CID100000085
    'STITCH_compound_stereo',  # Example: CID000010917
    'UMLS_concept_id',         # Example: C0000729
    'MedDRA_type',             # e.g., LLT
    'MedDRA_concept_id',       # Example: C0000729
    'LLT_preferred_term'       # e.g., "Abdominal cramps"
]
sider_df.dropna(subset=['MedDRA_type', 'MedDRA_concept_id'], inplace=True)
ctd_dir='data'
ctd_chem_disease_file = os.path.join(ctd_dir, 'CTD_chemicals_diseases.csv.gz')
ctd_chem_gene_file = os.path.join(ctd_dir, 'CTD_chem_gene_ixns.csv.gz')
ctd_chemicals_file = os.path.join(ctd_dir, 'CTD_chemicals.csv.gz')
ctd_genes_file = os.path.join(ctd_dir, 'CTD_genes.csv.gz')
ctd_chem_disease_df = pd.read_csv(
        ctd_chem_disease_file,
        comment='#', 
        compression='gzip'
    )
ctd_chem_gene_df = pd.read_csv(
    ctd_chem_gene_file,
    comment='#',
    compression='gzip'
)
ctd_chemicals_df = pd.read_csv(
    ctd_chemicals_file,
    sep='\t',
    comment='#',
    compression='gzip'
)
ctd_genes_df = pd.read_csv(
    ctd_genes_file,
    sep='\t',
    comment='#',
    compression='gzip'
)
ctd_chem_gene_df.columns = [
    "ChemicalName",       # e.g., 10074-G5
    "ChemicalID",         # e.g., C534883
    "CasRN",              # Unnamed or CAS Registry Number
    "GeneSymbol",         # e.g., AR
    "GeneID",             # e.g., 367
    "GeneForms",          # e.g., protein
    "Organism",           # e.g., Homo sapiens
    "OrganismID",         # e.g., 9606
    "Interaction",        # Natural language interaction
    "InteractionActions", # Parsed actions e.g., decreases^reaction|increases^expression
    "PubMedIDs"           # Supporting publication IDs
]
ctd_chem_disease_df.columns = [
    "ChemicalName", "ChemicalID", "CasRN", "DiseaseName", "DiseaseID", 
    "DirectEvidence", "InferenceGeneSymbol", "InferenceScore", 
    "OmimIDs", "PubMedIDs"
]
ctd_chem_disease_df.drop(columns=['CasRN', 'DirectEvidence', 'OmimIDs'],inplace=True)
ctd_chem_gene_df.drop(
    columns=['CasRN', 'GeneForms', 'Organism', 'OrganismID'],inplace=True
)
ctd_chem_disease_df['InferenceScore'].fillna(ctd_chem_disease_df['InferenceScore'].mean(), inplace=True)
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
ctd_chem_disease_df['InferenceScore'] = scaler.fit_transform(ctd_chem_disease_df[['InferenceScore']])
all_drug_ids = pd.concat([
    sider_df['STITCH_compound_flat'], 
    ctd_chem_gene_df['ChemicalID'], 
    ctd_chem_disease_df['ChemicalID']
]).dropna().unique()
all_se_ids = sider_df['MedDRA_concept_id'].dropna().unique()
all_disease_ids = ctd_chem_disease_df['DiseaseID'].dropna().unique()
all_gene_ids = ctd_chem_gene_df['GeneID'].dropna().unique()
import numpy as np

sampling_fraction = 0.08  # Set to e.g. 0.1 for 10% sampling
num_drugs = len(all_drug_ids)
num_sampled_drugs = int(num_drugs * sampling_fraction)

sampled_drug_ids = np.random.choice(all_drug_ids, size=num_sampled_drugs, replace=False)

# print(f"\nSubsampling drugs:")
# print(f"  Original drugs: {num_drugs}")
# print(f"  Sampled drugs: {num_sampled_drugs}")

# Create mappings for sampled drugs and full mappings for other node types
drug_mapping = {id: i for i, id in enumerate(sampled_drug_ids)}
se_mapping = {id: i for i, id in enumerate(all_se_ids)}
disease_mapping = {id: i for i, id in enumerate(all_disease_ids)}
gene_mapping = {id: i for i, id in enumerate(all_gene_ids)}

from sklearn.calibration import LabelEncoder

drug_encoder = LabelEncoder()
se_encoder = LabelEncoder()
disease_encoder = LabelEncoder()
gene_encoder = LabelEncoder()
drug_encoder.fit(sampled_drug_ids)
encoded_drug_features = drug_encoder.transform(sampled_drug_ids)

# Encode side effects (you can use all SEs here since subsampling is only on drugs)
encoded_se_features = se_encoder.fit_transform(all_se_ids)

# Encode diseases and genes (same reasoning applies as side effects)
encoded_disease_features = disease_encoder.fit_transform(all_disease_ids)
encoded_gene_features = gene_encoder.fit_transform(all_gene_ids)
sider_df_sampled = sider_df[sider_df['STITCH_compound_flat'].isin(sampled_drug_ids)].copy()
ctd_chem_disease_df_sampled = ctd_chem_disease_df[ctd_chem_disease_df['ChemicalID'].isin(sampled_drug_ids)].copy()
ctd_chem_gene_df_sampled = ctd_chem_gene_df[ctd_chem_gene_df['ChemicalID'].isin(sampled_drug_ids)].copy()
drug_mapping = {drug_id: idx for idx, drug_id in enumerate(drug_encoder.classes_)}
se_mapping = {se_id: idx for idx, se_id in enumerate(se_encoder.classes_)}
disease_mapping = {disease_id: idx for idx, disease_id in enumerate(disease_encoder.classes_)}
gene_mapping = {gene_id: idx for idx, gene_id in enumerate(gene_encoder.classes_)}

# Drug -> Side Effect edges
sider_edges = []
for _, row in sider_df_sampled.iterrows():
    drug_id = row['STITCH_compound_flat']
    se_id = row['MedDRA_concept_id']
    if drug_id in drug_mapping and se_id in se_mapping:
        sider_edges.append((drug_mapping[drug_id], se_mapping[se_id]))

sider_edge_index = torch.tensor(sider_edges, dtype=torch.long).t().contiguous()
print(f"Drug-Side Effect edges: {sider_edge_index.size(1)}")

# Drug -> Disease edges
ctd_chem_disease_edges = []
for _, row in ctd_chem_disease_df_sampled.iterrows():
    chem_id = row['ChemicalID']
    disease_id = row['DiseaseID']
    if chem_id in drug_mapping and disease_id in disease_mapping:
        ctd_chem_disease_edges.append((drug_mapping[chem_id], disease_mapping[disease_id]))

ctd_chem_disease_edge_index = torch.tensor(ctd_chem_disease_edges, dtype=torch.long).t().contiguous()
print(f"Drug-Disease edges: {ctd_chem_disease_edge_index.size(1)}")

# Drug -> Gene edges
ctd_chem_gene_edges = []
for _, row in ctd_chem_gene_df_sampled.iterrows():
    chem_id = row['ChemicalID']
    gene_id = row['GeneID']
    if chem_id in drug_mapping and gene_id in gene_mapping:
        ctd_chem_gene_edges.append((drug_mapping[chem_id], gene_mapping[gene_id]))

ctd_chem_gene_edge_index = torch.tensor(ctd_chem_gene_edges, dtype=torch.long).t().contiguous()
print(f"Drug-Gene edges: {ctd_chem_gene_edge_index.size(1)}")
import random
from sklearn.calibration import LabelEncoder
num_drug_features = 100
num_se_features = 80
num_disease_features = 90
num_gene_features = 150

x_drug = torch.randn(len(drug_mapping), num_drug_features)
x_side_effect = torch.randn(len(se_mapping), num_se_features)
x_disease = torch.randn(len(disease_mapping), num_disease_features)
x_gene = torch.randn(len(gene_mapping), num_gene_features)

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  ctd_chem_disease_df['InferenceScore'].fillna(ctd_chem_disease_df['InferenceScore'].mean(), inplace=True)


Drug-Side Effect edges: 25299
Drug-Disease edges: 814646
Drug-Gene edges: 325339


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GATConv
from torch_geometric.data import HeteroData
import numpy as np


# === 1. Define GNN Model ===
class HeterogeneousGNNModel(nn.Module):
    def __init__(self, hidden_channels, dropout_prob=0.5):
        super().__init__()

        self.edge_types = [
            ('drug', 'causes', 'side_effect'),
            ('side_effect', 'is_caused_by', 'drug'),
            ('drug', 'associates', 'disease'),
            ('disease', 'is_associated_with_drug', 'drug'),
            ('drug', 'interacts', 'gene'),
            ('gene', 'is_interacted_by', 'drug'),
        ]

        self.conv1 = HeteroConv({
            edge_type: GATConv((-1, -1), hidden_channels, add_self_loops=False)
            for edge_type in self.edge_types
        }, aggr='sum')

        self.dropout1 = nn.Dropout(dropout_prob)

        self.conv2 = HeteroConv({
            edge_type: GATConv((hidden_channels, hidden_channels), hidden_channels, add_self_loops=False)
            for edge_type in self.edge_types
        }, aggr='sum')

        self.dropout2 = nn.Dropout(dropout_prob)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = {key: self.dropout1(x) for key, x in x_dict.items()}

        x_dict = self.conv2(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = {key: self.dropout2(x) for key, x in x_dict.items()}

        return x_dict


# === 2. Define Link Predictor ===
class LinkPredictor(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels * 2, 1)

    def forward(self, z_drug, z_se):
        z = torch.cat([z_drug, z_se], dim=-1)
        return self.lin(z)
def get_link_prediction_metrics(pos_pred, neg_pred):
    pos_prob = torch.sigmoid(pos_pred)
    neg_prob = torch.sigmoid(neg_pred)

    probs = torch.cat([pos_prob, neg_prob], dim=0).detach().cpu().numpy()
    ground_truth = torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)], dim=0).detach().cpu().numpy()

    auc_score = roc_auc_score(ground_truth, probs)
    auprc_score = average_precision_score(ground_truth, probs)

    return auc_score, auprc_score

# === 3. Load the models and weights ===
hidden_channels = 64
gnn_model = HeterogeneousGNNModel(hidden_channels=hidden_channels)
link_predictor = LinkPredictor(in_channels=hidden_channels)

checkpoint = torch.load("adr_gnn_model.pth", map_location=torch.device("cpu"))
gnn_model.load_state_dict(checkpoint['gnn_state_dict'])
link_predictor.load_state_dict(checkpoint['link_predictor_state_dict'])

gnn_model.eval()
link_predictor.eval()
from torch_geometric.data import HeteroData
data = HeteroData()

data['drug'].x = x_drug
data['side_effect'].x = x_side_effect
data['disease'].x = x_disease
data['gene'].x = x_gene

data['drug', 'causes', 'side_effect'].edge_index = sider_edge_index
data['drug', 'associates', 'disease'].edge_index = ctd_chem_disease_edge_index
data['drug', 'interacts', 'gene'].edge_index = ctd_chem_gene_edge_index
data['side_effect', 'is_caused_by', 'drug'].edge_index = sider_edge_index.flip(0)
data['disease', 'is_associated_with_drug', 'drug'].edge_index = ctd_chem_disease_edge_index.flip(0)
data['gene', 'is_interacted_by', 'drug'].edge_index = ctd_chem_gene_edge_index.flip(0)
device='cpu'
data.to(device)
target_edge_index = data['drug', 'causes', 'side_effect'].edge_index

# Split positive edges into train, val, test
num_target_edges = target_edge_index.size(1)
perm = np.random.permutation(num_target_edges)
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

num_train_pos = int(num_target_edges * train_ratio)
num_val_pos = int(num_target_edges * val_ratio)
num_test_pos = num_target_edges - num_train_pos - num_val_pos

train_pos_indices = perm[:num_train_pos]
val_pos_indices = perm[num_train_pos:num_train_pos + num_val_pos]
test_pos_indices = perm[num_train_pos + num_val_pos:]

# Extract the actual edge index tensors for positive edges
data['drug', 'causes', 'side_effect'].edge_index_train_pos = target_edge_index[:, train_pos_indices].to(device)
data['drug', 'causes', 'side_effect'].edge_index_val_pos = target_edge_index[:, val_pos_indices].to(device)
data['drug', 'causes', 'side_effect'].edge_index_test_pos = target_edge_index[:, test_pos_indices].to(device)

# Generate negative samples for validation and test sets
# Ensure these negative samples are NOT present in the original positive edges
num_drug_nodes = data['drug'].num_nodes
num_se_nodes = data['side_effect'].num_nodes

data['drug', 'causes', 'side_effect'].edge_index_val_neg = torch.randint(
    0, num_drug_nodes, (2, num_val_pos), dtype=torch.long, device=device) # Example: same number as val_pos

data['drug', 'causes', 'side_effect'].edge_index_test_neg = torch.randint(
    0, num_drug_nodes, (2, num_test_pos), dtype=torch.long, device=device) # Example: same number as test_pos



In [9]:
hidden_channels = 64
gnn_model = HeterogeneousGNNModel(hidden_channels=hidden_channels)
link_predictor = LinkPredictor(in_channels=hidden_channels)

checkpoint = torch.load("adr_gnn_model.pth", map_location=torch.device("cpu"))
gnn_model.load_state_dict(checkpoint['gnn_state_dict'])
link_predictor.load_state_dict(checkpoint['link_predictor_state_dict'])

gnn_model.eval()
link_predictor.eval()

LinkPredictor(
  (lin): Linear(in_features=128, out_features=1, bias=True)
)

In [10]:
with torch.no_grad():
    out_dict = gnn_model(data.x_dict, data.edge_index_dict)


In [12]:
drug_idx = 11
side_effect_idx = 22

drug_embed = out_dict['drug'][drug_idx].unsqueeze(0)           # (1, hidden)
se_embed = out_dict['side_effect'][side_effect_idx].unsqueeze(0)

pred_score = link_predictor(drug_embed, se_embed)
prob = torch.sigmoid(pred_score)

print(f"Predicted probability of side effect for drug {drug_idx} and side effect {side_effect_idx}: {prob.item():.4f}")

Predicted probability of side effect for drug 11 and side effect 22: 0.0001


In [16]:
import json
def preprocess_input(input_data, drug):
    # Parse the input_data from the string to a dictionary
    data = input_data
    
    # Extract user info, symptoms, diagnosis
    user_data = data['user']
    symptoms = data['symptoms']
    diagnosis = data['diagnosis']
    
    # Prepare features for the GNN model
    user_features = {
        "user_id": user_data['id'],
        "age": user_data['age'],
        "gender": user_data['gender'],
        "blood_type": user_data['blood_type'],
        "genetic_conditions": user_data['genetic_conditions'],
        "allergies": user_data['allergies'],
        "medical_history": user_data['medical_history'],
    }
    
    symptoms_data = symptoms['reported']
    diagnosis_data = diagnosis['predicted_conditions']
    
    # Combine user features, symptoms, and diagnosis to create the model input
    model_input = {
        "user_features": user_features,
        "symptoms": symptoms_data,
        "diagnosis": diagnosis_data,
        "drug": drug
    }
    
    return model_input

def predict_adr(model_input):
    # Your GNN model prediction logic here
    # This is a placeholder and should be replaced by the actual GNN-based ADR prediction
    # Assuming a function that processes the input and returns the prediction
    print(f"Predicting ADR for drug: {model_input['drug']} with user data:")
    print(model_input['user_features'])
    
    # Mock response (replace with actual model logic)
    adr_predictions = {
        "drug": model_input['drug'],
        "predicted_adr": ["gastrointestinal upset", "bleeding risk"],
        "severity": "moderate",
        "precautions": ["Avoid alcohol", "Monitor blood pressure"]
    }
    
    return adr_predictions

# Main function to process and predict ADR
def predict_adverse_reactions(input_data, drug):
    # Preprocess the input data
    model_input = preprocess_input(input_data, drug)
    
    # Get the ADR predictions using the model
    adr_predictions = predict_adr(model_input)
    
    return adr_predictions

# Test the integration with example input
adr_result = predict_adverse_reactions(input_data, drug)
print(json.dumps(adr_result, indent=4))

Predicting ADR for drug: aspirin with user data:
{'user_id': 'user_001', 'age': 45, 'gender': 'male', 'blood_type': 'O+', 'genetic_conditions': ['G6PD deficiency'], 'allergies': ['penicillin', 'sulfa drugs'], 'medical_history': ['hypertension', 'asthma']}
{
    "drug": "aspirin",
    "predicted_adr": [
        "gastrointestinal upset",
        "bleeding risk"
    ],
    "severity": "moderate",
    "precautions": [
        "Avoid alcohol",
        "Monitor blood pressure"
    ]
}
