In [3]:
# --- 1. Load Libraries and Enriched Knowledge Graph ---

import networkx as nx
import pandas as pd
from pathlib import Path
import torch
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected, RandomLinkSplit

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")

# --- Configuration ---
PROJECT_ROOT = Path('.').resolve().parent
DATA_DIR = PROJECT_ROOT / 'data'
GRAPH_PATH = DATA_DIR / 'knowledge_graph.graphml'

# --- Load the Graph ---
print(f"\nLoading enriched knowledge graph from: {GRAPH_PATH}")
G = nx.read_graphml(GRAPH_PATH)
print("Graph loaded successfully.")

# --- Verification ---
print("\n--- Verifying Graph Enrichment ---")
print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")

print("\nVerifying attributes of drug DB00316 (Acetaminophen):")
acetaminophen_node = G.nodes['DB00316']
for attr, value in acetaminophen_node.items():
    print(f"  - {attr}: {value}")

PyTorch version: 2.4.1
PyTorch Geometric version: 2.6.1

Loading enriched knowledge graph from: C:\Users\Sheetal\PharmacoGraph-Agent\data\knowledge_graph.graphml
Graph loaded successfully.

--- Verifying Graph Enrichment ---
Number of nodes: 4388
Number of edges: 318666

Verifying attributes of drug DB00316 (Acetaminophen):
  - type: drug
  - mol_weight: 151.165
  - logp: 1.3505999999999998
  - h_bond_donors: 2
  - h_bond_acceptors: 2
  - tpsa: 49.33


In [8]:
# --- 2. Convert NetworkX Graph to PyG HeteroData Object (CORRECTED VERSION) ---
import numpy as np

print("--- Preparing data for PyTorch Geometric (Corrected Bipartite Method) ---")
data = HeteroData()

# --- Node Processing (Separate for each type) ---
# We need to create separate mappings for drug and reaction nodes.
# The indices for each type MUST start from 0.

# Process drug nodes
drug_nodes = [node for node, data_dict in G.nodes(data=True) if data_dict['type'] == 'drug']
drug_map = {node_id: i for i, node_id in enumerate(drug_nodes)}
drug_features = []
for drug_id in drug_nodes:
    data_dict = G.nodes[drug_id]
    features = [
        float(data_dict.get('mol_weight', 0)), float(data_dict.get('logp', 0)),
        float(data_dict.get('h_bond_donors', 0)), float(data_dict.get('h_bond_acceptors', 0)),
        float(data_dict.get('tpsa', 0))
    ]
    drug_features.append(features)
data['drug'].x = torch.tensor(drug_features, dtype=torch.float)

# Process reaction nodes
reaction_nodes = [node for node, data_dict in G.nodes(data=True) if data_dict['type'] == 'reaction']
reaction_map = {node_id: i for i, node_id in enumerate(reaction_nodes)}
# For reaction nodes, we use an identity matrix as features
data['reaction'].x = torch.eye(len(reaction_nodes))

print("Node features processed and mapped separately.")
print(f"  - Drug mapping: {len(drug_map)} nodes, indices 0-{len(drug_map)-1}")
print(f"  - Reaction mapping: {len(reaction_map)} nodes, indices 0-{len(reaction_map)-1}")

# --- Edge Processing (Map to new indices) ---
# Create the edge_index using the new, type-specific mappings.
source_nodes = []
target_nodes = []
for u, v in G.edges():
    # Ensure consistent ordering: source is always a drug, target is always a reaction
    if G.nodes[u]['type'] == 'drug':
        source_id, target_id = u, v
    else:
        source_id, target_id = v, u
    
    # Append the new, correct indices
    source_nodes.append(drug_map[source_id])
    target_nodes.append(reaction_map[target_id])

edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
data['drug', 'causes', 'reaction'].edge_index = edge_index
print("\nEdge index created with type-specific indices.")

# --- Add reverse edges ---
data['reaction', 'rev_causes', 'drug'].edge_index = edge_index.flip([0])
print("Reverse edges added.")

print("\n--- PyG Data Object Created Successfully ---")
print(data)

--- Preparing data for PyTorch Geometric (Corrected Bipartite Method) ---
Node features processed and mapped separately.
  - Drug mapping: 917 nodes, indices 0-916
  - Reaction mapping: 3471 nodes, indices 0-3470

Edge index created with type-specific indices.
Reverse edges added.

--- PyG Data Object Created Successfully ---
HeteroData(
  drug={ x=[917, 5] },
  reaction={ x=[3471, 3471] },
  (drug, causes, reaction)={ edge_index=[2, 318666] },
  (reaction, rev_causes, drug)={ edge_index=[2, 318666] }
)


In [9]:
# --- 3. Split Data into Training, Validation, and Test Sets ---

print("--- Splitting links for training, validation, and testing ---")
# We will use the RandomLinkSplit transform for this.
# It will split the 'causes' edges into three sets.
# It also automatically adds "negative" edges (links that don't exist)
# for the model to learn from, which is a crucial step.

transform = RandomLinkSplit(
    is_undirected=True,          # Our graph is undirected
    num_val=0.1,                 # Hold out 10% of edges for validation
    num_test=0.1,                # Hold out 10% of edges for testing
    neg_sampling_ratio=1.0,      # For each positive edge, create one negative edge
    edge_types=[('drug', 'causes', 'reaction')],
    rev_edge_types=[('reaction', 'rev_causes', 'drug')], # Needed for undirected graphs
)

train_data, val_data, test_data = transform(data)

print("\n--- Data Splitting Complete ---")
print("Training Data Sample:")
print(train_data)
print("\nValidation Data Sample:")
print(val_data)
print("\nTest Data Sample:")
print(test_data)

--- Splitting links for training, validation, and testing ---

--- Data Splitting Complete ---
Training Data Sample:
HeteroData(
  drug={ x=[917, 5] },
  reaction={ x=[3471, 3471] },
  (drug, causes, reaction)={
    edge_index=[2, 254934],
    edge_label=[509868],
    edge_label_index=[2, 509868],
  },
  (reaction, rev_causes, drug)={ edge_index=[2, 254934] }
)

Validation Data Sample:
HeteroData(
  drug={ x=[917, 5] },
  reaction={ x=[3471, 3471] },
  (drug, causes, reaction)={
    edge_index=[2, 254934],
    edge_label=[63732],
    edge_label_index=[2, 63732],
  },
  (reaction, rev_causes, drug)={ edge_index=[2, 254934] }
)

Test Data Sample:
HeteroData(
  drug={ x=[917, 5] },
  reaction={ x=[3471, 3471] },
  (drug, causes, reaction)={
    edge_index=[2, 286800],
    edge_label=[63732],
    edge_label_index=[2, 63732],
  },
  (reaction, rev_causes, drug)={ edge_index=[2, 286800] }
)


In [10]:
# --- 4. Define the Graph Neural Network (GNN) Model ---
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # We define our GNN layers
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        # This defines how data flows through the layers
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x

class Decoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # A simple linear layer to predict the link probability
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        # Get the embeddings for the drug and reaction nodes in our "questions"
        row, col = edge_label_index
        z = torch.cat([z_dict['drug'][row], z_dict['reaction'][col]], dim=-1)
        
        # Pass them through the linear layers
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1) # Flatten the output

class Model(torch.nn.Module):
    def __init__(self, hidden_channels, data):
        super().__init__()
        # Define the encoders for each node type
        self.drug_lin = torch.nn.Linear(data['drug'].x.shape[1], hidden_channels)
        self.reaction_lin = torch.nn.Linear(data['reaction'].x.shape[1], hidden_channels)
        
        # Instantiate the GNN, making it heterogeneous
        self.gnn = GNN(hidden_channels)
        self.gnn = to_hetero(self.gnn, data.metadata(), aggr='sum')
        
        # Instantiate the decoder
        self.decoder = Decoder(hidden_channels)

    # Replace the old forward method with this one
    def forward(self, data):
        # Project the initial node features from the input data object
        x_dict = {
          'drug': self.drug_lin(data['drug'].x),
          'reaction': self.reaction_lin(data['reaction'].x),
        }
        
        # Get the final node embeddings from the GNN
        z_dict = self.gnn(x_dict, data.edge_index_dict)
        
        # Use the decoder to get link predictions for the specific links we're interested in
        return self.decoder(z_dict, data['drug', 'causes', 'reaction'].edge_label_index)

# --- Instantiate the model ---
model = Model(hidden_channels=64, data=data) # 64 is a common size for the hidden layer

print("--- GNN Model Architecture Defined and Instantiated ---")
print(model)

--- GNN Model Architecture Defined and Instantiated ---
Model(
  (drug_lin): Linear(in_features=5, out_features=64, bias=True)
  (reaction_lin): Linear(in_features=3471, out_features=64, bias=True)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (drug__causes__reaction): SAGEConv(64, 64, aggr=mean)
      (reaction__rev_causes__drug): SAGEConv(64, 64, aggr=mean)
    )
    (conv2): ModuleDict(
      (drug__causes__reaction): SAGEConv(64, 64, aggr=mean)
      (reaction__rev_causes__drug): SAGEConv(64, 64, aggr=mean)
    )
  )
  (decoder): Decoder(
    (lin1): Linear(in_features=128, out_features=64, bias=True)
    (lin2): Linear(in_features=64, out_features=1, bias=True)
  )
)


In [11]:
# --- 5. Train the GNN Model (FINAL, SELF-CONTAINED VERSION) ---

# --- All necessary imports for this cell ---
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import time
import os

# --- Setup for Training ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move all data splits to the selected device (e.g., CPU or GPU)
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)

# Re-instantiate the model to ensure it's fresh, then move to device
model = Model(hidden_channels=64, data=data).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

# --- Training Function ---
def train():
    model.train()
    optimizer.zero_grad()
    
    # Pass the entire training data object to the model
    pred = model(train_data)
    
    # Use the correct key to access the edge labels
    target = train_data['drug', 'causes', 'reaction'].edge_label
    loss = criterion(pred, target)
    
    loss.backward()
    optimizer.step()
    return float(loss)

# --- Testing Function ---
@torch.no_grad()
def test(data_split):
    model.eval()
    
    # Pass the entire data split object to the model
    pred = model(data_split)
    
    # Use the correct key to access the edge labels
    target = data_split['drug', 'causes', 'reaction'].edge_label.float()
    
    auc = roc_auc_score(target.cpu().numpy(), pred.sigmoid().cpu().numpy())
    return auc

# --- The Training Loop ---
start_time = time.time()
print("\n--- Starting Model Training ---")
best_val_auc = 0
best_model_path = os.path.join(DATA_DIR, 'best_gnn_model.pt')
patience = 10
patience_counter = 0

for epoch in range(1, 201):
    loss = train()
    
    if epoch % 10 == 0:
        train_auc = test(train_data)
        val_auc = test(val_data)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train AUC: {train_auc:.4f}, Val AUC: {val_auc:.4f}')

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save(model.state_dict(), best_model_path)
            patience_counter = 0
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"\n--- Early stopping at epoch {epoch}. ---")
            break

end_time = time.time()
print(f"\n--- Training Complete in {end_time - start_time:.2f} seconds ---")

# --- Final Evaluation on the Test Set ---
print(f"\nLoading best model from: {best_model_path}")
model.load_state_dict(torch.load(best_model_path))
test_auc = test(test_data)
print(f'\nFinal Test AUC: {test_auc:.4f}')

Using device: cpu

--- Starting Model Training ---
Epoch: 010, Loss: 0.7113, Train AUC: 0.5087, Val AUC: 0.5060
Epoch: 020, Loss: 0.6918, Train AUC: 0.5828, Val AUC: 0.5736
Epoch: 030, Loss: 0.6863, Train AUC: 0.5920, Val AUC: 0.5852
Epoch: 040, Loss: 0.6691, Train AUC: 0.6284, Val AUC: 0.6211
Epoch: 050, Loss: 0.7432, Train AUC: 0.5567, Val AUC: 0.5568
Epoch: 060, Loss: 0.6734, Train AUC: 0.6386, Val AUC: 0.6356
Epoch: 070, Loss: 0.7011, Train AUC: 0.6262, Val AUC: 0.6194
Epoch: 080, Loss: 0.6819, Train AUC: 0.6050, Val AUC: 0.5999
Epoch: 090, Loss: 0.6747, Train AUC: 0.6265, Val AUC: 0.6227
Epoch: 100, Loss: 0.6609, Train AUC: 0.6870, Val AUC: 0.6808
Epoch: 110, Loss: 0.6977, Train AUC: 0.4306, Val AUC: 0.4317
Epoch: 120, Loss: 0.6720, Train AUC: 0.6298, Val AUC: 0.6262
Epoch: 130, Loss: 0.6041, Train AUC: 0.7623, Val AUC: 0.7573
Epoch: 140, Loss: 0.5289, Train AUC: 0.8340, Val AUC: 0.8281
Epoch: 150, Loss: 0.4400, Train AUC: 0.8898, Val AUC: 0.8850
Epoch: 160, Loss: 0.7217, Train AU

  model.load_state_dict(torch.load(best_model_path))



Final Test AUC: 0.9275
