**Imports and Initial Setup**

In [1]:
import pandas as pd
import torch
from neo4j import GraphDatabase
from torch_geometric.data import HeteroData
from torch_geometric.nn import HANConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
from tqdm import tqdm
import pickle
from itertools import product
import matplotlib.pyplot as plt

# Check CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Device name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

# Neo4j AuraDB connection
URI = "neo4j+s://b09f418b.databases.neo4j.io"
USERNAME = "neo4j"
PASSWORD = "Y9-UEMVWae0ISwDFsKFAtLczklxpSgOKZfKRyyI-mDY"
driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))

Using device: cuda
Device name: NVIDIA GeForce RTX 3050 Laptop GPU


**Fetch Graph Data with Relevant Relationships**

In [2]:
def fetch_graph_data():
    with driver.session() as session:
        # Fetch all nodes
        nodes_query = "MATCH (n) RETURN elementId(n) AS id, labels(n) AS labels, properties(n) AS props"
        nodes_result = session.run(nodes_query)
        nodes = [(record["id"], record["labels"][0], record["props"]) for record in nodes_result]

        # Fetch only relevant relationships
        rels_query = """
        MATCH (a)-[r:TREATS|INVOLVES|ASSOCIATED_WITH]->(b)
        RETURN elementId(a) AS source, elementId(b) AS target, type(r) AS type
        """
        rels_result = session.run(rels_query)
        relationships = [(record["source"], record["target"], record["type"]) for record in rels_result]
    return nodes, relationships

nodes, relationships = fetch_graph_data()
nodes_df = pd.DataFrame(nodes, columns=["id", "label", "props"])
rels_df = pd.DataFrame(relationships, columns=["source", "target", "type"])

In [3]:
nodes_df.head()

Unnamed: 0,id,label,props
0,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:0,Drug,"{'name': 'ACETAMINOPHEN', 'chembl_id': 'CHEMBL..."
1,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:1,Drug,"{'name': 'NITRIC OXIDE', 'chembl_id': 'CHEMBL1..."
2,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:2,Drug,"{'name': 'DECITABINE', 'chembl_id': 'CHEMBL120..."
3,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:3,Drug,"{'name': 'PRASUGREL', 'chembl_id': 'CHEMBL1201..."
4,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:4,Drug,"{'name': 'LUSPATERCEPT', 'chembl_id': 'CHEMBL3..."


In [4]:
rels_df.head()

Unnamed: 0,source,target,type
0,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:289,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:1423,TREATS
1,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:290,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:1489,TREATS
2,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:292,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:1413,TREATS
3,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:292,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:1466,TREATS
4,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:292,4:4e0830d9-7413-4b1e-83ba-e7d830b0a219:1513,TREATS


**Enhanced Feature Extraction**

In [5]:
def extract_features(props_dict, node_id, rels_df):
    # Handle SMILES for small molecules or fallback for biologics
    if "smiles" in props_dict and props_dict["smiles"]:
        mol = Chem.MolFromSmiles(props_dict["smiles"])
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=512) if mol else np.zeros(512)
        fp_tensor = torch.tensor(np.array(fp), dtype=torch.float32)
    else:
        # Fallback for biologics with sequence or random noise
        fp_tensor = torch.zeros(512, dtype=torch.float32)
        if "sequence" in props_dict and props_dict["sequence"]:
            seq_len = min(len(props_dict["sequence"]), 512)
            fp_tensor[:seq_len] = torch.rand(seq_len)  # Add random noise based on sequence length

    # Calculate degree based on relevant relationships only
    degree = len(rels_df[(rels_df["source"] == node_id) & (rels_df["type"].isin(["TREATS", "INVOLVES", "ASSOCIATED_WITH"]))]) + \
             len(rels_df[(rels_df["target"] == node_id) & (rels_df["type"].isin(["TREATS", "INVOLVES", "ASSOCIATED_WITH"]))])
    extra_features = [float(degree)]
    for key in ["trial_count", "molecular_weight", "value"]:
        extra_features.append(float(props_dict.get(key, 0.0)))
    extra_tensor = torch.tensor(extra_features, dtype=torch.float32)

    features = torch.cat([fp_tensor, extra_tensor], dim=0)
    # Normalize features
    features = (features - features.mean()) / (features.std() + 1e-8) if features.std() > 0 else features
    return features.to(device)

**Build HeteroData and Edge Indices**

In [6]:
# Initialize HeteroData
data = HeteroData()
type_to_idx = {ntype: {row["id"]: idx for idx, (_, row) in enumerate(nodes_df[nodes_df["label"] == ntype].iterrows())} 
               for ntype in nodes_df["label"].unique()}

for node_type in nodes_df["label"].unique():
    type_nodes = nodes_df[nodes_df["label"] == node_type]
    features = torch.stack([extract_features(row["props"], row["id"], rels_df) for _, row in type_nodes.iterrows()])
    data[node_type].x = features

# Build edge indices for relevant relationships
edge_dict = {}
for _, row in rels_df.iterrows():
    src_type = nodes_df[nodes_df["id"] == row["source"]]["label"].iloc[0]
    tgt_type = nodes_df[nodes_df["id"] == row["target"]]["label"].iloc[0]
    edge_type = (src_type, row["type"], tgt_type)
    src_idx = type_to_idx[src_type][row["source"]]
    tgt_idx = type_to_idx[tgt_type][row["target"]]
    if edge_type not in edge_dict:
        edge_dict[edge_type] = [[], []]
    edge_dict[edge_type][0].append(src_idx)
    edge_dict[edge_type][1].append(tgt_idx)

for edge_type, (src_indices, tgt_indices) in edge_dict.items():
    if src_indices and tgt_indices:
        data[edge_type].edge_index = torch.tensor([src_indices, tgt_indices], dtype=torch.long).to(device)
    else:
        data[edge_type].edge_index = torch.zeros((2, 0), dtype=torch.long).to(device)

metadata = (list(data.node_types), list(data.edge_types))
data.metadata = lambda: metadata

**Prepare Training Data**

In [7]:
# Prepare drug-disease pairs
positive_pairs = rels_df[rels_df["type"] == "TREATS"]
drug_nodes = nodes_df[nodes_df["label"] == "Drug"]["id"].tolist()
disease_nodes = nodes_df[nodes_df["label"] == "Disease"]["id"].tolist()
all_pairs = [(d, dis) for d, dis in product(drug_nodes, disease_nodes)]
labels = [1 if len(positive_pairs[(positive_pairs["source"] == d) & (positive_pairs["target"] == dis)]) > 0 else 0 
          for d, dis in all_pairs]
pos_pairs = [p for p, l in zip(all_pairs, labels) if l == 1]
neg_pairs = [p for p, l in zip(all_pairs, labels) if l == 0]
neg_pairs_sampled = np.random.choice(len(neg_pairs), len(pos_pairs) * 2, replace=False) if len(neg_pairs) > 0 else []
selected_pairs = pos_pairs + [neg_pairs[i] for i in neg_pairs_sampled]
selected_labels = [1] * len(pos_pairs) + [0] * len(neg_pairs_sampled)

edge_index = torch.tensor(
    [[type_to_idx["Drug"][d] for d, _ in selected_pairs], 
     [type_to_idx["Disease"][dis] for _, dis in selected_pairs]], 
    dtype=torch.long).to(device)
labels = torch.tensor(selected_labels, dtype=torch.float).to(device)

# Split data
train_idx, temp_idx = train_test_split(range(len(selected_pairs)), test_size=0.4, random_state=42)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

data["Drug", "TREATS", "Disease"].train_edge_index = edge_index[:, train_idx]
data["Drug", "TREATS", "Disease"].train_edge_label = labels[train_idx]
data["Drug", "TREATS", "Disease"].val_edge_index = edge_index[:, val_idx]
data["Drug", "TREATS", "Disease"].val_edge_label = labels[val_idx]
data["Drug", "TREATS", "Disease"].test_edge_index = edge_index[:, test_idx]
data["Drug", "TREATS", "Disease"].test_edge_label = labels[test_idx]

**Define HAN Model**

In [8]:
class EnhancedHAN(torch.nn.Module):
    def __init__(self, in_channels_dict, hidden_channels, out_channels, metadata):
        super(EnhancedHAN, self).__init__()
        self.han1 = HANConv(in_channels_dict, hidden_channels, metadata=metadata, heads=4, dropout=0.2)
        self.linear_drug = torch.nn.Linear(hidden_channels, out_channels)
        self.linear_disease = torch.nn.Linear(hidden_channels, out_channels)
        self.residual = torch.nn.Linear(list(in_channels_dict.values())[0], out_channels)
        self.node_types = list(in_channels_dict.keys())

    def forward(self, x_dict, edge_index_dict, return_attention_weights=False):
        device = next(self.parameters()).device
        x_in = torch.cat([x_dict[nt] for nt in x_dict], dim=0)

        # Single HAN layer
        x = self.han1(x_dict, edge_index_dict)
        x = {k: torch.relu(v) if v is not None else torch.zeros(x_dict[k].shape[0], self.han1.out_channels, device=device)
             for k, v in x.items()}

        # Residual and final embeddings
        res = torch.relu(self.residual(x_in))[:len(x["Drug"]) + len(x["Disease"])]
        drug_emb = self.linear_drug(x["Drug"]) + res[:len(x["Drug"])]
        disease_emb = self.linear_disease(x["Disease"]) + res[len(x["Drug"]):]
        return drug_emb, disease_emb, None

**Model Initialization and Training Setup**

In [9]:
# Initialize model
in_channels_dict = {nt: data[nt].x.shape[1] for nt in data.node_types}
model = EnhancedHAN(in_channels_dict, hidden_channels=256, out_channels=128, metadata=metadata).to(device)

# Print model summary
print("Model Architecture:")
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params}")

# Training setup
pos_weight = torch.tensor([len(neg_pairs_sampled) / len(pos_pairs) if len(pos_pairs) > 0 else 1.0], dtype=torch.float).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-3)
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=50)

# Lists to store metrics
train_losses = []
val_aucs = []
epochs_tracked = []

Model Architecture:
EnhancedHAN(
  (han1): HANConv(256, heads=4)
  (linear_drug): Linear(in_features=256, out_features=128, bias=True)
  (linear_disease): Linear(in_features=256, out_features=128, bias=True)
  (residual): Linear(in_features=516, out_features=128, bias=True)
)
Model parameters: 1920128


**Training and Evaluation Functions**

In [10]:
def train():
    model.train()
    optimizer.zero_grad()
    drug_emb, disease_emb, _ = model(data.x_dict, data.edge_index_dict)
    edge_index = data["Drug", "TREATS", "Disease"].train_edge_index
    pred = (drug_emb[edge_index[0]] * disease_emb[edge_index[1]]).sum(dim=1)
    loss = loss_fn(pred, data["Drug", "TREATS", "Disease"].train_edge_label)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item(), pred

def evaluate(split="val"):
    model.eval()
    with torch.no_grad():
        drug_emb, disease_emb, _ = model(data.x_dict, data.edge_index_dict)
        edge_index = data["Drug", "TREATS", "Disease"][f"{split}_edge_index"]
        pred = (drug_emb[edge_index[0]] * disease_emb[edge_index[1]]).sum(dim=1)
        preds = torch.sigmoid(pred).cpu().numpy()
        labels = data["Drug", "TREATS", "Disease"][f"{split}_edge_label"].cpu().numpy()
        auc = roc_auc_score(labels, preds) if len(np.unique(labels)) > 1 else 0.5
        if split == "test":
            return auc, preds, labels
        return auc

**Training Loop and Visualization**

In [11]:
# Training loop
best_val_auc = 0
patience = 100
early_stop_counter = 0
for epoch in tqdm(range(500), desc="Training"):
    loss, _ = train()
    train_losses.append(loss)
    epochs_tracked.append(epoch)
    
    if epoch % 50 == 0:
        val_auc = evaluate("val")
        val_aucs.append(val_auc)
        print(f"Epoch {epoch}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}")
        scheduler.step(val_auc)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save({'model_state_dict': model.state_dict(), 'metadata': metadata, 'in_channels_dict': in_channels_dict}, 
                       "han_drug_repurposing_retrained.pth")
            early_stop_counter = 0
        else:
            early_stop_counter += 50
            if early_stop_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

# Plot Training Loss
plt.figure(figsize=(10, 6))
plt.plot(epochs_tracked, train_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.legend()
plt.grid(True)
plt.savefig('training_loss_retrained.png')
plt.close()

# Plot Validation AUC
plt.figure(figsize=(10, 6))
plt.plot([e for e in range(0, len(val_aucs) * 50, 50)], val_aucs, label='Validation AUC', marker='o')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.title('Validation AUC Over Time')
plt.legend()
plt.grid(True)
plt.savefig('validation_auc_retrained.png')
plt.close()

# Final evaluation with ROC curve
test_auc, test_preds, test_labels = evaluate("test")
print(f"Test AUC: {test_auc:.4f}")

# Plot ROC Curve
fpr, tpr, _ = roc_curve(test_labels, test_preds)
plt.figure(figsize=(10, 6))
plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {test_auc:.4f})')
plt.plot([0, 1], [0, 1], 'k--', label='Random Guess')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve on Test Set')
plt.legend()
plt.grid(True)
plt.savefig('roc_curve_retrained.png')
plt.close()

Training:   1%|          | 6/500 [00:00<00:54,  8.99it/s]

Epoch 0, Loss: 0.8384, Val AUC: 0.1577


Training:  11%|█         | 55/500 [00:01<00:09, 47.27it/s]

Epoch 50, Loss: 0.5368, Val AUC: 0.8494


Training:  22%|██▏       | 110/500 [00:02<00:06, 57.30it/s]

Epoch 100, Loss: 0.4950, Val AUC: 0.8738


Training:  32%|███▏      | 161/500 [00:03<00:05, 58.98it/s]

Epoch 150, Loss: 0.5108, Val AUC: 0.9107


Training:  42%|████▏     | 209/500 [00:04<00:05, 58.20it/s]

Epoch 200, Loss: 0.4820, Val AUC: 0.9143


Training:  51%|█████▏    | 257/500 [00:05<00:04, 59.95it/s]

Epoch 250, Loss: 0.4479, Val AUC: 0.8845


Training:  60%|██████    | 300/500 [00:05<00:03, 51.60it/s]


Epoch 300, Loss: 0.4764, Val AUC: 0.8958
Early stopping at epoch 300
Test AUC: 0.8862


In [12]:
# Save data
nodes_df.to_csv("nodes_df_retrained.csv", index=False)
rels_df.to_csv("rels_df_retrained.csv", index=False)
with open('type_to_idx_retrained.pkl', 'wb') as f:
    pickle.dump(type_to_idx, f)

driver.close()