In [1]:
import pandas as pd
import numpy as np
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

def load_dataset(sub_dataset, dataset_name):
    base_path =  Path('Path_to_base_directory')

    ############################# LOAD LLM Embeddigs ###############################################################
    
    llm_embeddings_path = base_path / f"embeddings/LLM/{sub_dataset}"
    # drug
    smile_names = np.load(llm_embeddings_path / f"drug/{sub_dataset}_smiles.npy", allow_pickle=True)
    can_smile_names = np.load(llm_embeddings_path / f"drug/{sub_dataset}_canonical_smiles.npy", allow_pickle=True)
    drug_embeddings = np.load(llm_embeddings_path / f"drug/{sub_dataset}_molecule_embeddings.npy", allow_pickle=True)
    #gene
    sequences_names = np.load(llm_embeddings_path / f"target/{sub_dataset}_target_sequences.npy", allow_pickle=True)
    gene_embeddings = np.load(llm_embeddings_path / f"target/{sub_dataset}_sequence_embeddings.npy", allow_pickle=True)

    
    can_drugs_llm = pd.DataFrame({
        'smiles': smile_names,
        'can_smiles': can_smile_names,
        'drug_llm_embeddings': drug_embeddings.tolist()  
    })

    drugs_llm = can_drugs_llm.dropna().reset_index(drop=True)

    genes_llm = pd.DataFrame({
        'sequences': sequences_names,
        'protein_llm_embeddings': gene_embeddings.tolist()  
    })

    ############################# LOAD  Structural Features ###############################################################
   
    topological_embeddings_path = base_path / f"embeddings/TDA/{dataset_name}"
    
    # protein contact_map 
    protein_indexes = np.load(topological_embeddings_path / f"02_protein_index.npy")
    protein_embeddings = np.load(topological_embeddings_path / f"02_protein_contact_embeddings.npy")
    protein_sturct_embedding_df = pd.DataFrame({"protein_index":  protein_indexes, "protein_struc_embeddings": list(protein_embeddings)})


    # Drug moleculer images
    drug_indexes = np.load(topological_embeddings_path / f"01_mol_names.npy")
    drug_struc_embeddings = np.load(topological_embeddings_path / f"01_mol_image_embeddings.npy")
    drug_struct_embedding_df = pd.DataFrame({"drug_index": drug_indexes, "drug_struc_embeddings": list(drug_struc_embeddings)})


    ############################# LOAD TRAIN VALIDATION AND TEST DATA ###############################################################   
   
    data_path = base_path / "data"
    
    task_paths = {
        "biosnap_random": data_path / "biosnap/random",
        "human_random": data_path / "human/random",
        "human_cold": data_path / "human/cold"
    }

    # Validate dataset name
    if sub_dataset not in task_paths:
        raise ValueError(f"Dataset name {sub_dataset} not available. Available datasets are {list(task_paths.keys())}.")

    dataset_path = task_paths[sub_dataset]
    
    train_file = dataset_path / 'train.csv'
    valid_file = dataset_path / 'val.csv'
    test_file = dataset_path / 'test.csv'

    
    train = pd.read_csv(train_file)
    valid = pd.read_csv(valid_file)
    test = pd.read_csv(test_file)

    train['Set'] = "Train"
    valid['Set'] = "Valid"
    test['Set'] = "Test"

    full_dataset = pd.concat([train, valid, test], ignore_index=True)

    ############################# MERGE EMBEDDINGS ############################################################### 
    #Gene
    genes_index = genes_llm.reset_index()
    genes_index['index'] = genes_index['index'].astype(int)
    protein_sturct_embedding_df['protein_index'] = protein_sturct_embedding_df['protein_index'].astype(int)
    genes = pd.merge(genes_index,  protein_sturct_embedding_df, left_on = 'index', right_on ='protein_index')
    genes.drop(columns=['index', 'protein_index'], inplace = True)
    #Drug
    drugs_index = drugs_llm.reset_index()
    drugs_index['index'] = drugs_index['index'].astype(int)
    drug_struct_embedding_df['drug_index'] = drug_struct_embedding_df['drug_index'].astype(int)
    drugs = pd.merge(drugs_index, drug_struct_embedding_df, left_on = 'index', right_on ='drug_index')
    drugs.drop(columns=['index', 'drug_index', 'can_smiles'], inplace = True)
    

    # Merge with drug and gene embeddings
    drug_data = pd.merge(full_dataset, drugs, left_on='SMILES', right_on='smiles', how='inner')   
    drug_data = drug_data.drop(columns=['smiles'])
    
    # Merge drug data with gene data
    target_data = pd.merge(drug_data, genes, left_on='Protein', right_on='sequences', how='inner')
    target_data = target_data.drop(columns=['sequences'])
   


    return drugs, genes, target_data


In [2]:
dataset_name = 'human'
sub_dataset ='human_cold'

In [3]:
drugs, genes, df = load_dataset(sub_dataset, dataset_name)

In [4]:
drugs

Unnamed: 0,smiles,drug_llm_embeddings,drug_struc_embeddings
0,CC[C@@]1(C[C@@H]2C3=CC(=C(C=C3CCN2C[C@H]1CC(C)...,"[-0.2966340184211731, 0.2303713858127594, 0.70...","[49.0, 13.0, 11.0, 9.0, 9.0, 9.0, 11.0, 11.0, ..."
1,CCCC(=O)C1=CN=CC=C1,"[0.5686405897140503, 0.632786214351654, 0.0171...","[1.0, 2.0, 4.0, 4.0, 5.0, 6.0, 11.0, 11.0, 16...."
2,C[C@H](C[C@@H](C(=O)O)N)C(=O)O,"[-0.8690953254699707, 0.9624016880989075, 0.01...","[64.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,..."
3,CC(C)(CC1=CC=C(C=C1)Cl)N,"[-0.15941157937049866, 0.9059374332427979, 0.1...","[55.0, 11.0, 10.0, 9.0, 9.0, 8.0, 7.0, 9.0, 9...."
4,C[C@H]1[C@@H]([C@H]([C@H]([C@@H](O1)O[C@@H]2[C...,"[0.6342077255249023, -0.3015095293521881, 0.49...","[49.0, 9.0, 6.0, 5.0, 6.0, 7.0, 17.0, 20.0, 22..."
...,...,...,...
1808,C1=CC2=C(C=C1F)C(=CN2)CC(C(=O)O)N,"[-0.13452543318271637, 1.001132845878601, 0.95...","[44.0, 6.0, 6.0, 3.0, 3.0, 3.0, 19.0, 18.0, 16..."
1809,CC(C(=O)OC)N,"[0.48688942193984985, 0.3028434216976166, 0.10...","[12.0, 3.0, 3.0, 2.0, 1.0, 1.0, 4.0, 5.0, 5.0,..."
1810,C[C@@H]1CC2([C@@H]3[C@@](O3)(C(O2)O)C)O[C@@H]4...,"[0.38685643672943115, -0.21707198023796082, 0....","[25.0, 1.0, 2.0, 2.0, 4.0, 13.0, 21.0, 23.0, 2..."
1811,C1CNCCC1C2=CC(=NN2)C3=CC=C(C=C3)Cl,"[-0.19216525554656982, 0.7620844841003418, 0.2...","[69.0, 12.0, 9.0, 8.0, 8.0, 8.0, 8.0, 11.0, 10..."


In [6]:
genes

Unnamed: 0,sequences,protein_llm_embeddings,protein_struc_embeddings
0,MSPLNQSAEGLPQEASNRSLNATETSEAWDPRTLQALKISLAVVLS...,"[-0.012558391317725182, 0.07437873631715775, 0...","[62.0, 62.0, 813.0, 375.0, 375.0, 250.0, 250.0..."
1,MAGAGPKRRALAAPAAEEKEEAREKMLAAKSADGSAPAGEGEGVTL...,"[0.0245230533182621, 0.1540171205997467, 0.041...","[26.0, 26.0, 962.0, 354.0, 354.0, 230.0, 230.0..."
2,MKLKLKNVFLAYFLVSIAGLLYALVQLGQPCDCLPPLRAAAEQLRQ...,"[0.025935091078281403, 0.044150467962026596, 0...","[88.0, 308.0, 171.0, 93.0, 24.0, 23.0, 9.0, 16..."
3,METTPLNSQKQLSACEDGEDCQENGVLQKVVPTPGDKVESGQISNG...,"[0.029940910637378693, 0.12786006927490234, 0....","[79.0, 2015.0, 246.0, 246.0, 98.0, 98.0, 61.0,..."
4,MADERKDEAKAPHWTSAPLTEASAHSHPPEIKDQGGAGEGLVRSAN...,"[0.010107791982591152, -0.07885642349720001, 0...","[56.0, 56.0, 524.0, 524.0, 324.0, 324.0, 324.0..."
...,...,...,...
1498,MEMEKEFEQIDKSGSWAAIYQDIRHEASDFPCRVAKLPKNKNRNRY...,"[0.0005158429266884923, 0.025331098586320877, ...","[45.0, 1521.0, 429.0, 429.0, 217.0, 217.0, 160..."
1499,MGDVEKGKKIFIMKCSQCHTVEKGGKHKTGPNLHGLFGRKTGQAPG...,"[0.054099857807159424, 0.163589745759964, 0.01...","[95.0, 95.0, 1270.0, 792.0, 792.0, 792.0, 150...."
1500,MSSAAEPPPPPPPESAPSKPAASIASGGSNSSNKGGPEGVAAQAVA...,"[0.06637243181467056, 0.08882562816143036, 0.0...","[70.0, 70.0, 1473.0, 566.0, 566.0, 182.0, 182...."
1501,MGAASGRRGPGLLLPLPLLLLLPPQPALALDPGLQPGNFSADEAGA...,"[0.06093117594718933, 0.10112055391073227, 0.0...","[87.0, 1428.0, 417.0, 191.0, 191.0, 86.0, 60.0..."


In [7]:
df

Unnamed: 0,SMILES,Protein,Y,Set,drug_llm_embeddings,drug_struc_embeddings,protein_llm_embeddings,protein_struc_embeddings
0,CC[C@@]1(C[C@@H]2C3=CC(=C(C=C3CCN2C[C@H]1CC(C)...,MSPLNQSAEGLPQEASNRSLNATETSEAWDPRTLQALKISLAVVLS...,0,Train,"[-0.2966340184211731, 0.2303713858127594, 0.70...","[49.0, 13.0, 11.0, 9.0, 9.0, 9.0, 11.0, 11.0, ...","[-0.012558391317725182, 0.07437873631715775, 0...","[62.0, 62.0, 813.0, 375.0, 375.0, 250.0, 250.0..."
1,CCCC(=O)C1=CN=CC=C1,MAGAGPKRRALAAPAAEEKEEAREKMLAAKSADGSAPAGEGEGVTL...,0,Train,"[0.5686405897140503, 0.632786214351654, 0.0171...","[1.0, 2.0, 4.0, 4.0, 5.0, 6.0, 11.0, 11.0, 16....","[0.0245230533182621, 0.1540171205997467, 0.041...","[26.0, 26.0, 962.0, 354.0, 354.0, 230.0, 230.0..."
2,C[C@H](C[C@@H](C(=O)O)N)C(=O)O,MKLKLKNVFLAYFLVSIAGLLYALVQLGQPCDCLPPLRAAAEQLRQ...,0,Train,"[-0.8690953254699707, 0.9624016880989075, 0.01...","[64.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.025935091078281403, 0.044150467962026596, 0...","[88.0, 308.0, 171.0, 93.0, 24.0, 23.0, 9.0, 16..."
3,CC(C)(CC1=CC=C(C=C1)Cl)N,METTPLNSQKQLSACEDGEDCQENGVLQKVVPTPGDKVESGQISNG...,1,Train,"[-0.15941157937049866, 0.9059374332427979, 0.1...","[55.0, 11.0, 10.0, 9.0, 9.0, 8.0, 7.0, 9.0, 9....","[0.029940910637378693, 0.12786006927490234, 0....","[79.0, 2015.0, 246.0, 246.0, 98.0, 98.0, 61.0,..."
4,C[C@H]1[C@@H]([C@H]([C@H]([C@@H](O1)O[C@@H]2[C...,MADERKDEAKAPHWTSAPLTEASAHSHPPEIKDQGGAGEGLVRSAN...,0,Train,"[0.6342077255249023, -0.3015095293521881, 0.49...","[49.0, 9.0, 6.0, 5.0, 6.0, 7.0, 17.0, 20.0, 22...","[0.010107791982591152, -0.07885642349720001, 0...","[56.0, 56.0, 524.0, 524.0, 324.0, 324.0, 324.0..."
...,...,...,...,...,...,...,...,...
3914,C1=CC2=C(C=C1F)C(=CN2)CC(C(=O)O)N,MGAASGRRGPGLLLPLPLLLLLPPQPALALDPGLQPGNFSADEAGA...,0,Test,"[-0.13452543318271637, 1.001132845878601, 0.95...","[44.0, 6.0, 6.0, 3.0, 3.0, 3.0, 19.0, 18.0, 16...","[0.06093117594718933, 0.10112055391073227, 0.0...","[87.0, 1428.0, 417.0, 191.0, 191.0, 86.0, 60.0..."
3915,CC(C(=O)OC)N,MTEQAISFAKDFLAGGIAAAISKTAVAPIERVKLLLQVQHASKQIA...,0,Test,"[0.48688942193984985, 0.3028434216976166, 0.10...","[12.0, 3.0, 3.0, 2.0, 1.0, 1.0, 4.0, 5.0, 5.0,...","[0.054344408214092255, 0.1416536420583725, -0....","[71.0, 71.0, 1294.0, 1294.0, 477.0, 477.0, 477..."
3916,C[C@@H]1CC2([C@@H]3[C@@](O3)(C(O2)O)C)O[C@@H]4...,MLFSALLLEVIWILAADGGQHWTYEGPHGQDHWPASYPECGNNAQS...,0,Test,"[0.38685643672943115, -0.21707198023796082, 0....","[25.0, 1.0, 2.0, 2.0, 4.0, 13.0, 21.0, 23.0, 2...","[0.029615368694067, 0.06942982971668243, -0.01...","[22.0, 22.0, 22.0, 654.0, 654.0, 654.0, 488.0,..."
3917,C1CNCCC1C2=CC(=NN2)C3=CC=C(C=C3)Cl,MAHVRGLQLPGCLALAALCSLVHSQHVFLAPQQARSLLQRVRRANT...,1,Test,"[-0.19216525554656982, 0.7620844841003418, 0.2...","[69.0, 12.0, 9.0, 8.0, 8.0, 8.0, 8.0, 11.0, 10...","[0.012352170422673225, 0.04928702116012573, 0....","[52.0, 957.0, 310.0, 310.0, 251.0, 251.0, 191...."


In [8]:
df.isna().sum()

SMILES                      0
Protein                     0
Y                           0
Set                         0
drug_llm_embeddings         0
drug_struc_embeddings       0
protein_llm_embeddings      0
protein_struc_embeddings    0
dtype: int64

In [9]:
import torch
import random
import numpy as np
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
import pandas as pd
import random


# unique drugs and genes
unique_drug_id = drugs['smiles'].unique()
unique_gene_id = genes['sequences'].unique()

# their ids
drug_id_map = {drug: idx for idx, drug in enumerate(unique_drug_id)}
gene_id_map = {gene: idx for idx, gene in enumerate(unique_gene_id)}

# and embeddings 
drug_llm_embeddings = np.array([np.array(emb) for emb in drugs['drug_llm_embeddings']])
drug_struct_embeddings = np.array([np.array(emb) for emb in drugs['drug_struc_embeddings']])
gene_llm_embeddings = np.array([np.array(emb) for emb in genes['protein_llm_embeddings']])
protein_struct_embeddings = np.array([np.array(emb) for emb in genes['protein_struc_embeddings']])


data = HeteroData()

data["drug"].node_id = torch.arange(len(unique_drug_id))
data["gene"].node_id = torch.arange(len(unique_gene_id))

# node features
data["drug"].xl = torch.tensor(drug_llm_embeddings, dtype=torch.float)
data["drug"].xs = torch.tensor(drug_struct_embeddings, dtype=torch.float)
data["gene"].xl = torch.tensor(gene_llm_embeddings, dtype=torch.float)
data["gene"].xs = torch.tensor(protein_struct_embeddings, dtype=torch.float)



edge_indices, edge_labels, edge_splits = [], [], []

for _, row in df.iterrows():
    drug_idx = drug_id_map[row["SMILES"]]
    gene_idx = gene_id_map[row["Protein"]]
    edge_indices.append([drug_idx, gene_idx])
    edge_labels.append(row["Y"])
    edge_splits.append(row["Set"])

# create tensors
edge_indices = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
edge_labels = torch.tensor(edge_labels, dtype=torch.float)

# and add them to data
data["drug", "interacts_with", "gene"].edge_index = edge_indices
data["drug", "interacts_with", "gene"].edge_label = edge_labels

# Split data according to Set Labels
edge_splits = np.array(edge_splits)
train_mask = torch.tensor(edge_splits == "Train", dtype=torch.bool)
valid_mask = torch.tensor(edge_splits == "Valid", dtype=torch.bool)
test_mask = torch.tensor(edge_splits == "Test", dtype=torch.bool)

data["drug", "interacts_with", "gene"].train_mask = train_mask
data["drug", "interacts_with", "gene"].valid_mask = valid_mask
data["drug", "interacts_with", "gene"].test_mask = test_mask
data = T.ToUndirected()(data)


In [10]:
data

HeteroData(
  drug={
    node_id=[1813],
    xl=[1813, 768],
    xs=[1813, 1200],
  },
  gene={
    node_id=[1503],
    xl=[1503, 1024],
    xs=[1503, 1200],
  },
  (drug, interacts_with, gene)={
    edge_index=[2, 3919],
    edge_label=[3919],
    train_mask=[3919],
    valid_mask=[3919],
    test_mask=[3919],
  },
  (gene, rev_interacts_with, drug)={
    edge_index=[2, 3919],
    edge_label=[3919],
    train_mask=[3919],
    valid_mask=[3919],
    test_mask=[3919],
  }
)

In [11]:
# Full dataset
print("===== Full Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", data['drug', 'interacts_with', 'gene'].num_edges)

# Train
train_mask = data["drug", "interacts_with", "gene"].train_mask
print("\n===== Training Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", train_mask.sum().item())

# Validation
valid_mask = data["drug", "interacts_with", "gene"].valid_mask
print("\n===== Validation Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", valid_mask.sum().item())

# Test 
test_mask = data["drug", "interacts_with", "gene"].test_mask
print("\n===== Test Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", test_mask.sum().item())


===== Full Dataset =====
Number of drug nodes: 1813
Number of gene nodes: 1503
Number of edges: 3919

===== Training Dataset =====
Number of drug nodes: 1813
Number of gene nodes: 1503
Number of edges: 3453

===== Validation Dataset =====
Number of drug nodes: 1813
Number of gene nodes: 1503
Number of edges: 155

===== Test Dataset =====
Number of drug nodes: 1813
Number of gene nodes: 1503
Number of edges: 311


In [12]:
from torch_geometric.loader import DataLoader 
from torch.utils.data import Dataset

class EdgeDataset(Dataset):
    def __init__(self, data: HeteroData, mask: torch.Tensor):
        self.edge_index = data["drug", "interacts_with", "gene"].edge_index[:, mask]
        self.edge_label = data["drug", "interacts_with", "gene"].edge_label[mask]
        self.num_edges = self.edge_index.size(1)

    def __len__(self):
        return self.num_edges

    def __getitem__(self, idx):
        return self.edge_index[:, idx], self.edge_label[idx]


# Custom collate function to prevent automatic batching of edge_index
def custom_collate_fn(batch):
    edge_indices = [item[0] for item in batch]  # Extract edge_index from batch
    edge_labels = [item[1] for item in batch]  # Extract edge_label from batch
    
    # Concatenate edge indices and labels
    edge_index = torch.cat(edge_indices, dim=1)  # Shape: [2, batch_size]
    edge_label = torch.cat(edge_labels, dim=0)   # Shape: [batch_size]

    return edge_index, edge_label


train_dataset = EdgeDataset(data, train_mask)
valid_dataset = EdgeDataset(data, valid_mask)
test_dataset = EdgeDataset(data, test_mask)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)




In [13]:
import torch
from collections import Counter

# Function to count positive and negative edges
def count_edges(edge_labels):
    counter = Counter(edge_labels.tolist())
    return counter[1.0], counter[0.0]  # Positive, Negative


train_labels = train_dataset.edge_label
valid_labels = valid_dataset.edge_label
test_labels = test_dataset.edge_label


print("Total number of nodes:")
print(f"Drugs: {data['drug'].xs.size(0)}")
print(f"Genes: {data['gene'].xs.size(0)}")

print("$"*52)

train_positive, train_negative = count_edges(train_labels)
valid_positive, valid_negative = count_edges(valid_labels)
test_positive, test_negative = count_edges(test_labels)



print(f"Train - Total: {train_labels.size(0)}, Positive: {train_positive}, Negative: {train_negative}")
print(f"Valid - Total: {valid_labels.size(0)}, Positive: {valid_positive}, Negative: {valid_negative}")
print(f"Test  - Total: {test_labels.size(0)}, Positive: {test_positive}, Negative: {test_negative}")


Total number of nodes:
Drugs: 1813
Genes: 1503
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
Train - Total: 3453, Positive: 1815, Negative: 1638
Valid - Total: 155, Positive: 59, Negative: 96
Test  - Total: 311, Positive: 121, Negative: 190


In [14]:
drug_l_input_size= data['drug']['xl'].size(1)
gene_l_input_size= data['gene']['xl'].size(1)
drug_s_input_size= data['drug']['xs'].size(1)
gene_s_input_size= data['gene']['xs'].size(1)

print(f"Drug LLM Input Size: {drug_l_input_size}")
print(f"Gene LLM Input Size: {gene_l_input_size}")
print(f"Drug Structure Input Size: {drug_s_input_size}")
print(f"Gene Structure Input Size: {gene_s_input_size}")


Drug LLM Input Size: 768
Gene LLM Input Size: 1024
Drug Structure Input Size: 1200
Gene Structure Input Size: 1200


# The below model takes structure and llm embeddings

In [15]:
import torch
from torch_geometric.nn import SAGEConv, to_hetero, BatchNorm
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, recall_score, confusion_matrix
import os



# it takes drug and gene llm and structure features and assigns weighs to them durung traning
class FeatureFusion(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.fc = torch.nn.Linear(2 * hidden_channels, hidden_channels)

    def forward(self, llm_features, struct_features):
        concat_features = torch.cat([llm_features, struct_features], dim=-1)
        alpha = torch.sigmoid(self.fc(concat_features)) # values are between 0 and 1
        return alpha * llm_features + (1 - alpha) * struct_features # relative importance of each feature


class Classifier(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.fc1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.fc2 = torch.nn.Linear(hidden_channels, 1)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, drug_emb, gene_emb, edge_index):
        if edge_index.shape[0] != 2:
            edge_index = edge_index.t()

        src, dst = edge_index
        drug_emb_src = drug_emb[src]
        gene_emb_dst = gene_emb[dst]

        edge_emb = torch.cat([drug_emb_src, gene_emb_dst], dim=-1)
        edge_emb = F.relu(self.fc1(edge_emb))
        edge_emb = self.dropout(edge_emb)
        return self.fc2(edge_emb).squeeze(-1)



class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.bn1 = BatchNorm(hidden_channels)
        self.dropout = torch.nn.Dropout(0.5)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.bn2 = BatchNorm(hidden_channels)
        

    def forward(self, x, edge_index):
        x = F.relu(self.bn1(self.conv1(x, edge_index)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.conv2(x, edge_index)))
        return x



# Main Model
class Model(torch.nn.Module):
    def __init__(self, hidden_channels, drug_l_input_size, gene_l_input_size, drug_s_input_size, gene_s_input_size):
        super().__init__()

        # Drug 
        self.drug_llm_lin = torch.nn.Sequential(
            torch.nn.Linear(drug_l_input_size, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )
        self.drug_struct_lin = torch.nn.Sequential(
            torch.nn.Linear(drug_s_input_size, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )

        # Gene 
        self.gene_llm_lin = torch.nn.Sequential(
            torch.nn.Linear(gene_l_input_size, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )
        self.gene_struct_lin = torch.nn.Sequential(
            torch.nn.Linear(gene_s_input_size, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )

        self.feature_fusion = FeatureFusion(hidden_channels)
        self.gnn = GNN(hidden_channels)
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        self.classifier = Classifier(hidden_channels)

    def forward(self, data, edge_index, edge_label):
        drug_xl = data["drug"].xl.to(device).float()
        drug_xs = data["drug"].xs.to(device).float()
        gene_xl = data["gene"].xl.to(device).float()
        gene_xs = data["gene"].xs.to(device).float()

        # Feature normalization
        drug_xs = (drug_xs - drug_xs.mean(dim=0)) / (drug_xs.std(dim=0, unbiased=False) + 1e-8)
        gene_xs = (gene_xs - gene_xs.mean(dim=0)) / (gene_xs.std(dim=0, unbiased=False) + 1e-8)

        # Feature Fussion
        drug_llm = self.drug_llm_lin(drug_xl)
        drug_struct = self.drug_struct_lin(drug_xs)
        drug_x = self.feature_fusion(drug_llm, drug_struct)

        gene_llm = self.gene_llm_lin(gene_xl)
        gene_struct = self.gene_struct_lin(gene_xs)
        gene_x = self.feature_fusion(gene_llm, gene_struct)

        # Pass through GNN
        x_dict = {"drug": drug_x, "gene": gene_x}
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        drug_emb = x_dict["drug"]
        gene_emb = x_dict["gene"]
        pred = self.classifier(drug_emb, gene_emb, edge_index)

        return pred, edge_label


In [41]:
import os
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, recall_score, confusion_matrix
import numpy as np
import gc

hidden_channels = 128

model = Model(hidden_channels, drug_l_input_size, gene_l_input_size, drug_s_input_size, gene_s_input_size)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_and_validate(model, data, train_loader, val_loader, device, patience=5, max_epochs=100, dataset_name="default", run_number =1):
    model_save_dir = f"models/{dataset_name}"
    os.makedirs(model_save_dir, exist_ok=True)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data.edge_index_dict = {key: edge_index.to(device) for key, edge_index in data.edge_index_dict.items()}
    model = model.to(device)

    # Calculate positive weight for the loss function
    all_labels = torch.cat([labels for _, labels in train_loader], dim=0)
    negative_to_positive_ratio = len(all_labels[all_labels == 0]) / len(all_labels[all_labels == 1])  # Negatives/Positives
    pos_weight=torch.tensor(negative_to_positive_ratio).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay = 1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 4, eta_min=1e-6)
       
    best_auroc = 0.0
    best_val_loss = float("inf")
    epochs_no_improve = 0

    for epoch in range(1, max_epochs + 1):
        model.train()
        total_train_loss, total_train_examples = 0, 0

        for edge_index, edge_label in train_loader:
            optimizer.zero_grad()
            edge_index, edge_label = edge_index.to(device), edge_label.to(device)

            pred, ground_truth = model(data, edge_index, edge_label)

            loss = F.binary_cross_entropy_with_logits(pred, ground_truth, pos_weight)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item() * pred.numel()
            total_train_examples += pred.numel()

        avg_train_loss = total_train_loss / total_train_examples

        # Validation
        model.eval()
        total_val_loss, total_val_examples = 0, 0
        preds, ground_truths = [], []

        with torch.no_grad():
            for edge_index, edge_label in val_loader:
                edge_index, edge_label = edge_index.to(device), edge_label.to(device)
                pred, ground_truth = model(data, edge_index, edge_label)
                loss = F.binary_cross_entropy_with_logits(pred, ground_truth)

                total_val_loss += loss.item() * pred.numel()
                total_val_examples += pred.numel()

                preds.append(pred.cpu())
                ground_truths.append(ground_truth.cpu())

        avg_val_loss = total_val_loss / total_val_examples
        preds = torch.cat(preds, dim=0).sigmoid().numpy()
        ground_truths = torch.cat(ground_truths, dim=0).numpy()

        # Metrics
        auroc = roc_auc_score(ground_truths, preds)

        if auroc > best_auroc:
            best_auroc = auroc
            best_epoch = epoch
            model_save_path = os.path.join(model_save_dir, f"{dataset_name}_best_model_run_{run_number+1}.pt")
            torch.save(model.state_dict(), model_save_path)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping
        if epochs_no_improve >= patience:
            break

    return best_auroc

# Evaluate the model on the test set

def evaluate_test_set(model, data, data_loader, device):
    model.eval()
    preds, ground_truths = [], []

    with torch.no_grad():
        for edge_index, edge_label in data_loader:
            edge_index, edge_label = edge_index.to(device), edge_label.to(device)
            pred, _ = model(data, edge_index, edge_label)
            preds.append(pred.cpu())
            ground_truths.append(edge_label.cpu())

    preds = torch.cat(preds, dim=0).sigmoid().numpy()  
    ground_truths = torch.cat(ground_truths, dim=0).numpy()

    auroc = roc_auc_score(ground_truths, preds)
    aupr = average_precision_score(ground_truths, preds)

    return auroc, aupr



auroc_scores, aupr_scores = [], []

for run_number in range(5):
    # Clear memory and set random seed
    gc.collect()
    torch.cuda.empty_cache()
    seed = random.randint(0, 2**32 - 1)  
    torch.manual_seed(seed)

    #Train the model on the train and validation set
    model = Model(hidden_channels, drug_l_input_size, gene_l_input_size, drug_s_input_size, gene_s_input_size)
    train_and_validate(
        model=model,
        data=data,
        train_loader=train_loader,
        val_loader=valid_loader,
        device=device,
        dataset_name=sub_dataset, 
        run_number = run_number
    )

    model_path = f"models/{sub_dataset}/{sub_dataset}_best_model_run_{run_number+1}.pt"
    model.load_state_dict(torch.load(model_path, map_location=device))
    auroc, aupr = evaluate_test_set(model, data, test_loader, device)

    auroc_scores.append(auroc)
    aupr_scores.append(aupr)

    print(f"Run {run_number+1}: AUROC = {auroc:.4f}, AUPRC = {aupr:.4f}")


print("\nFinal Results:")
print(f"AUROC: Mean = {np.mean(auroc_scores):.4f}, Std = {np.std(auroc_scores):.4f}")
print(f"AUPRC: Mean = {np.mean(aupr_scores):.4f}, Std = {np.std(aupr_scores):.4f}")



Run 1: AUROC = 0.9009, AUPRC = 0.8367
Run 2: AUROC = 0.9000, AUPRC = 0.8407
Run 3: AUROC = 0.8905, AUPRC = 0.8354
Run 4: AUROC = 0.9091, AUPRC = 0.8683
Run 5: AUROC = 0.8893, AUPRC = 0.8018

Final Results:
AUROC: Mean = 0.8980, Std = 0.0073
AUPRC: Mean = 0.8366, Std = 0.0211
