In [1]:
import pandas as pd
import numpy as np
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

def load_dataset(dataset_name ):
    base_path =  Path('/bozdagpool/tp0626.unt.ad.unt.edu/Top_DTI_Work/Top_DTI')

    ############################# LOAD LLM Embeddigs ###############################################################
    
    llm_embeddings_path = base_path / f"embeddings/llm/biosnap_random"
    # drug
    smile_names = np.load(llm_embeddings_path / f"drug/biosnap_random_smiles.npy", allow_pickle=True)
    can_smile_names = np.load(llm_embeddings_path / f"drug/biosnap_random_canonical_smiles.npy", allow_pickle=True)
    drug_embeddings = np.load(llm_embeddings_path / f"drug/biosnap_random_molecule_embeddings.npy", allow_pickle=True)
    #gene
    sequences_names = np.load(llm_embeddings_path / f"target/biosnap_random_target_sequences.npy", allow_pickle=True)
    gene_embeddings = np.load(llm_embeddings_path / f"target/biosnap_random_sequence_embeddings.npy", allow_pickle=True)

    
    can_drugs_llm = pd.DataFrame({
        'smiles': smile_names,
        'can_smiles': can_smile_names,
        'drug_llm_embeddings': drug_embeddings.tolist()  
    })

    drugs_llm = can_drugs_llm.dropna().reset_index(drop=True)

    genes_llm = pd.DataFrame({
        'sequences': sequences_names,
        'protein_llm_embeddings': gene_embeddings.tolist()  
    })

    ############################# LOAD  Structural Features ###############################################################
   
    topological_embeddings_path = base_path / f"structure/s_embeddings/{dataset_name}"
    
    # protein contact_map 
    protein_indexes = np.load(topological_embeddings_path / f"protein_index.npy")
    protein_embeddings = np.load(topological_embeddings_path / f"protein_contact_embeddings.npy")
    protein_sturct_embedding_df = pd.DataFrame({"protein_index":  protein_indexes, "protein_struc_embeddings": list(protein_embeddings)})


    # Drug moleculer images
    drug_indexes = np.load(topological_embeddings_path / f"mol_names.npy")
    drug_struc_embeddings = np.load(topological_embeddings_path / f"mol_image_embeddings.npy")
    drug_struct_embedding_df = pd.DataFrame({"drug_index": drug_indexes, "drug_struc_embeddings": list(drug_struc_embeddings)})


    ############################# LOAD TRAIN VALIDATION AND TEST DATA ###############################################################   
   
    data_path = base_path / "datasets"
    
    task_paths = {
        "biosnap_random": data_path / "biosnap/random",  
        "biosnap_unseen_target": data_path / "biosnap/unseen_target", 
        "biosnap_unseen_drug": data_path / "biosnap/unseen_drug", 
        "human_random": data_path / "human/random",
        "human_cold": data_path / "human/cold"
    }




    #############'Biosnap_UNSEEN_TARGET' #############
    dataset_path = data_path /"biosnap" / "unseen_target" 
  

    
    train_file = dataset_path / 'train.csv'
    valid_file = dataset_path / 'val.csv'
    test_file = dataset_path / 'test.csv'

    
    train = pd.read_csv(train_file)
    valid = pd.read_csv(valid_file)
    test = pd.read_csv(test_file)

    train['Set'] = "Train"
    valid['Set'] = "Valid"
    test['Set'] = "Test"

    full_dataset = pd.concat([train, valid, test], ignore_index=True)
   

    ############################# MERGE EMBEDDINGS ############################################################### 
    #Gene
    genes_index = genes_llm.reset_index()
    genes_index['index'] = genes_index['index'].astype(int)
    protein_sturct_embedding_df['protein_index'] = protein_sturct_embedding_df['protein_index'].astype(int)
    genes = pd.merge(genes_index,  protein_sturct_embedding_df, left_on = 'index', right_on ='protein_index')
    genes.drop(columns=['index', 'protein_index'], inplace = True)
    #Drug
    drugs_index = drugs_llm.reset_index()
    drugs_index['index'] = drugs_index['index'].astype(int)
    drug_struct_embedding_df['drug_index'] = drug_struct_embedding_df['drug_index'].astype(int)
    drugs = pd.merge(drugs_index, drug_struct_embedding_df, left_on = 'index', right_on ='drug_index')
    drugs.drop(columns=['index', 'drug_index', 'can_smiles'], inplace = True)
    

    # Merge with drug and gene embeddings
    drug_data = pd.merge(full_dataset, drugs, left_on='SMILES', right_on='smiles', how='inner')   
    drug_data = drug_data.drop(columns=['smiles'])
    
    # Merge drug data with gene data
    target_data = pd.merge(drug_data, genes, left_on='Protein', right_on='sequences', how='inner')
    target_data = target_data.drop(columns=['sequences'])
   


    return drugs, genes, target_data




In [16]:
dataset_name = 'biosnap'
sub_dataset ='biosnap_random'

In [17]:
drugs, genes, df = load_dataset(dataset_name)

In [18]:
drugs

Unnamed: 0,smiles,drug_llm_embeddings,drug_struc_embeddings
0,OP(O)(=O)C(Cl)(Cl)P(O)(O)=O,"[0.598218560218811, 0.6313725113868713, -0.219...","[40.0, 24.0, 24.0, 19.0, 18.0, 18.0, 18.0, 18...."
1,NC1=NC(=O)N(C=N1)[C@H]1C[C@H](O)[C@@H](CO)O1,"[0.21274812519550323, 0.09204189479351044, -0....","[31.0, 10.0, 9.0, 9.0, 10.0, 10.0, 19.0, 19.0,..."
2,OCCCCCCCCNCO,"[1.0989850759506226, 0.13781903684139252, 0.21...","[26.0, 2.0, 1.0, 1.0, 1.0, 1.0, 4.0, 3.0, 4.0,..."
3,C[C@H](OP(O)(O)=O)[C@@H](N)C(O)=O,"[0.2674747705459595, 0.5900732278823853, -0.20...","[4.0, 2.0, 1.0, 1.0, 4.0, 10.0, 19.0, 17.0, 17..."
4,CCO,"[-0.6757968664169312, -1.3385236263275146, -1....","[18.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
...,...,...,...
4500,CCOC(=O)C1=CC2=CC=C(O)C=C2OC1=O,"[0.48438143730163574, -0.019473664462566376, 1...","[57.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0,..."
4501,[H][C@](O)(CN)CNC1=C2N=CC=CC2=C(C=C1)[N+]([O-])=O,"[0.17140626907348633, 0.22953173518180847, 0.6...","[53.0, 7.0, 7.0, 7.0, 7.0, 7.0, 22.0, 20.0, 20..."
4502,CC1NC2=CC(Cl)=C(C=C2C(=O)N1C1=CC=CC=C1C)S(N)(=...,"[-0.0021631531417369843, 0.5341809391975403, 0...","[72.0, 11.0, 10.0, 9.0, 8.0, 8.0, 16.0, 16.0, ..."
4503,CCC1=NN(C2=C1C(=O)NC(CC1=CC(O)=C(O)C=C1)=N2)C1...,"[0.3830249607563019, 0.43456849455833435, 0.90...","[80.0, 20.0, 14.0, 13.0, 11.0, 11.0, 21.0, 18...."


In [19]:
genes

Unnamed: 0,sequences,protein_llm_embeddings,protein_struc_embeddings
0,MGDHAWSFLKDFLAGGVAAAVSKTAVAPIERVKLLLQVQHASKQIS...,"[0.040794070810079575, 0.1398317515850067, -0....","[218.0, 218.0, 1020.0, 356.0, 356.0, 128.0, 12..."
1,MVLDLDLFRVDKGGDPALIRETQEKRFKDPGLVDQLVKADSEWRRC...,"[0.07856228947639465, 0.09228259325027466, 0.0...","[9.0, 9.0, 960.0, 290.0, 290.0, 387.0, 387.0, ..."
2,MGNLKSVAQEPGPPCGLGLGLGLGLCGKQGPATPAPEPSRAPASLL...,"[0.030257243663072586, 0.09058675915002823, 0....","[48.0, 48.0, 2009.0, 2009.0, 143.0, 143.0, 143..."
3,MGNAAAAKKGSEQESVKEFLAKAKEDFLKKWESPAQNTAHLDQFER...,"[0.07570360600948334, 0.11278703063726425, 0.0...","[110.0, 110.0, 1074.0, 532.0, 532.0, 532.0, 12..."
4,MVNENTRMYIPEENHQGSNYGSPRPAHANMNANAAAGLAPEHIPTP...,"[0.07552585750818253, 0.09334281831979752, 0.0...","[61.0, 61.0, 2032.0, 2032.0, 154.0, 154.0, 154..."
...,...,...,...
2176,MFSMRIVCLVLSVVGTAWTADSGEGDFLAEGGGVRGPRVVERHQSA...,"[0.028062790632247925, 0.04501698911190033, 0....","[88.0, 587.0, 272.0, 149.0, 149.0, 97.0, 88.0,..."
2177,MAETVADTRRLITKPQNLNDAYGPPSNFLEIDVSNPQTVGVGRGRF...,"[0.03317379578948021, 0.05994739383459091, 0.0...","[30.0, 30.0, 148.0, 148.0, 130.0, 130.0, 130.0..."
2178,MAEGAAGREDPAPPDAAGGEDDPRVGPDAAGDCVTAASGGRMRDRR...,"[0.002608957001939416, 0.01860393024981022, 0....","[61.0, 61.0, 1882.0, 263.0, 263.0, 263.0, 129...."
2179,MKMLTRLQVLTLALFSKGFLLSLGDHNFLRREIKIEGDLVLGGLFP...,"[0.032084159553050995, 0.12270920723676682, 0....","[61.0, 61.0, 1728.0, 391.0, 391.0, 391.0, 143...."


In [20]:
df

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,DrugBank ID,Gene,Y,SMILES,Protein,Set,drug_llm_embeddings,drug_struc_embeddings,protein_llm_embeddings,protein_struc_embeddings
0,0,1,2,DB00850,P21728,1,OCCN1CCN(CCCN2C3=CC=CC=C3SC3=C2C=C(Cl)C=C3)CC1,MRTLNTSAMDGTGLVVERDFSVRILTACFLSLLILSTLLGNTLVCA...,Train,"[0.23164264857769012, 0.4718208312988281, 0.76...","[70.0, 23.0, 11.0, 11.0, 11.0, 11.0, 11.0, 17....","[0.007180177606642246, 0.07590547949075699, 0....","[96.0, 1606.0, 450.0, 450.0, 162.0, 144.0, 114..."
1,1,2,3,DB00629,P08913,1,NC(N)=NN=CC1=C(Cl)C=CC=C1Cl,MGSLQPDAGNASWNGTEAPGGGARATPYSLQVTLTLVCLAGLLMLL...,Train,"[0.3290030062198639, 0.6285668611526489, 0.908...","[45.0, 10.0, 8.0, 6.0, 6.0, 6.0, 28.0, 27.0, 3...","[0.005615565925836563, 0.06402911245822906, 0....","[89.0, 1211.0, 598.0, 598.0, 164.0, 144.0, 126..."
2,2,4,5,DB00755,P48443,1,C\C(\C=C\C1=C(C)CCCC1(C)C)=C/C=C/C(/C)=C/C(O)=O,MYGNYSHFMKFPAGYGGSPGHTGSTSMSPSAALSTGKPMDSHPSYT...,Train,"[0.8176257014274597, 0.12078941613435745, 0.60...","[65.0, 10.0, 7.0, 7.0, 6.0, 5.0, 5.0, 5.0, 5.0...","[0.010951308533549309, 0.08242924511432648, 0....","[66.0, 66.0, 492.0, 491.0, 491.0, 336.0, 336.0..."
3,3,7,8,DB01136,P08588,1,COC1=CC=CC=C1OCCNCC(O)COC1=CC=CC2=C1C1=CC=CC=C1N2,MGAGVLVLGASEPGNLSSAAPLPDGAATAARLLVPASPPASLLPPA...,Train,"[0.19162507355213165, 0.035214636474847794, 0....","[77.0, 20.0, 17.0, 15.0, 14.0, 13.0, 18.0, 15....","[-0.016087913885712624, 0.06223440170288086, 0...","[96.0, 1497.0, 443.0, 443.0, 172.0, 172.0, 133..."
4,4,9,10,DB05265,P63000,0,[H][C@@]12CCC3=CC(C(C)C)=C(C=C3[C@@]1(C)CCC[C@...,MQAIKCVVVGDGAVGKTCLLISYTTNAFPGEYIPTVFDNYSANVMV...,Train,"[0.05757851526141167, 0.14686033129692078, 0.0...","[76.0, 11.0, 10.0, 8.0, 6.0, 7.0, 7.0, 6.0, 9....","[0.08401910215616226, 0.20102143287658691, 0.0...","[280.0, 280.0, 400.0, 400.0, 236.0, 236.0, 236..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
27459,5335,27543,27574,DB00186,P28472,1,OC1N=C(C2=CC=CC=C2Cl)C2=C(NC1=O)C=CC(Cl)=C2,MWGLAGGRLFGIFSAPVLVAVVCCAQSVNDPGNMSFVKETVDKLLK...,Test,"[0.015055220574140549, 0.6198558211326599, 0.7...","[95.0, 15.0, 11.0, 10.0, 9.0, 8.0, 20.0, 17.0,...","[0.05267154425382614, 0.11373468488454819, 0.0...","[27.0, 1350.0, 523.0, 523.0, 237.0, 203.0, 177..."
27460,5336,27547,27578,DB07242,P09238,0,[H][C@@]1(C#N)C2=C(N(C)C3=C(Cl)C(Cl)=CC=C23)C(...,MMHLAFLVLLCLPVCSAYPLSGAAKEEDSNKDLAQQYLEKYYNLEK...,Test,"[-0.2760624289512634, 0.39136213064193726, 0.2...","[65.0, 14.0, 11.0, 9.0, 10.0, 10.0, 27.0, 21.0...","[0.01586500182747841, 0.03571433946490288, 0.0...","[94.0, 94.0, 908.0, 558.0, 558.0, 396.0, 396.0..."
27461,5337,27548,27579,DB01006,A8MPY1,0,N#CC1=CC=C(C=C1)C(N1C=NC=N1)C1=CC=C(C=C1)C#N,MVLAFQLVSFTYIWIILKPNVCAASNIKMTHQRCSSSMKQTCKQET...,Test,"[0.6828752756118774, 0.40455248951911926, 0.51...","[84.0, 17.0, 15.0, 14.0, 14.0, 14.0, 30.0, 29....","[0.028850406408309937, 0.08950702100992203, -0...","[28.0, 1361.0, 466.0, 264.0, 264.0, 130.0, 110..."
27462,5338,27550,27581,DB07321,P09467,1,COC1=CC2=C(OC(NS(=O)(=O)C3=CC(Cl)=CC=C3Cl)=N2)...,MADQAPFDTDVNTLTRFVMEEGRKARGTGELTQLLNSLCTAVKAIS...,Test,"[-0.12703728675842285, 0.6404566764831543, 0.3...","[84.0, 23.0, 15.0, 15.0, 15.0, 15.0, 21.0, 20....","[0.06200770288705826, 0.09202540665864944, 0.0...","[50.0, 50.0, 50.0, 1020.0, 1020.0, 1020.0, 380..."


In [21]:
df.isna().sum()

Unnamed: 0.2                0
Unnamed: 0                  0
Unnamed: 0.1                0
DrugBank ID                 0
Gene                        0
Y                           0
SMILES                      0
Protein                     0
Set                         0
drug_llm_embeddings         0
drug_struc_embeddings       0
protein_llm_embeddings      0
protein_struc_embeddings    0
dtype: int64

In [22]:
import torch
import random
import numpy as np
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
import pandas as pd
import random
random.seed(42)

# unique drugs and genes
unique_drug_id = drugs['smiles'].unique()
unique_gene_id = genes['sequences'].unique()

# their ids
drug_id_map = {drug: idx for idx, drug in enumerate(unique_drug_id)}
gene_id_map = {gene: idx for idx, gene in enumerate(unique_gene_id)}

# and embeddings 
drug_llm_embeddings = np.array([np.array(emb) for emb in drugs['drug_llm_embeddings']])
drug_struct_embeddings = np.array([np.array(emb) for emb in drugs['drug_struc_embeddings']])
gene_llm_embeddings = np.array([np.array(emb) for emb in genes['protein_llm_embeddings']])
protein_struct_embeddings = np.array([np.array(emb) for emb in genes['protein_struc_embeddings']])


data = HeteroData()

data["drug"].node_id = torch.arange(len(unique_drug_id))
data["gene"].node_id = torch.arange(len(unique_gene_id))

# node features
data["drug"].xl = torch.tensor(drug_llm_embeddings, dtype=torch.float)
data["drug"].xs = torch.tensor(drug_struct_embeddings, dtype=torch.float)
data["gene"].xl = torch.tensor(gene_llm_embeddings, dtype=torch.float)
data["gene"].xs = torch.tensor(protein_struct_embeddings, dtype=torch.float)



edge_indices, edge_labels, edge_splits = [], [], []

for _, row in df.iterrows():
    drug_idx = drug_id_map[row["SMILES"]]
    gene_idx = gene_id_map[row["Protein"]]
    edge_indices.append([drug_idx, gene_idx])
    edge_labels.append(row["Y"])
    edge_splits.append(row["Set"])

# create tensors
edge_indices = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
edge_labels = torch.tensor(edge_labels, dtype=torch.float)

# and add them to data
data["drug", "interacts_with", "gene"].edge_index = edge_indices
data["drug", "interacts_with", "gene"].edge_label = edge_labels

# Split data according to Set Labels
edge_splits = np.array(edge_splits)
train_mask = torch.tensor(edge_splits == "Train", dtype=torch.bool)
valid_mask = torch.tensor(edge_splits == "Valid", dtype=torch.bool)
test_mask = torch.tensor(edge_splits == "Test", dtype=torch.bool)

data["drug", "interacts_with", "gene"].train_mask = train_mask
data["drug", "interacts_with", "gene"].valid_mask = valid_mask
data["drug", "interacts_with", "gene"].test_mask = test_mask
data = T.ToUndirected()(data)


In [23]:
data

HeteroData(
  drug={
    node_id=[4505],
    xl=[4505, 768],
    xs=[4505, 1200],
  },
  gene={
    node_id=[2181],
    xl=[2181, 1024],
    xs=[2181, 1200],
  },
  (drug, interacts_with, gene)={
    edge_index=[2, 27464],
    edge_label=[27464],
    train_mask=[27464],
    valid_mask=[27464],
    test_mask=[27464],
  },
  (gene, rev_interacts_with, drug)={
    edge_index=[2, 27464],
    edge_label=[27464],
    train_mask=[27464],
    valid_mask=[27464],
    test_mask=[27464],
  }
)

In [24]:
# Full dataset
print("===== Full Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", data['drug', 'interacts_with', 'gene'].num_edges)

# Train
train_mask = data["drug", "interacts_with", "gene"].train_mask
print("\n===== Training Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", train_mask.sum().item())

# Validation
valid_mask = data["drug", "interacts_with", "gene"].valid_mask
print("\n===== Validation Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", valid_mask.sum().item())

# Test 
test_mask = data["drug", "interacts_with", "gene"].test_mask
print("\n===== Test Dataset =====")
print("Number of drug nodes:", data['drug'].num_nodes)
print("Number of gene nodes:", data['gene'].num_nodes)
print("Number of edges:", test_mask.sum().item())


===== Full Dataset =====
Number of drug nodes: 4505
Number of gene nodes: 2181
Number of edges: 27464

===== Training Dataset =====
Number of drug nodes: 4505
Number of gene nodes: 2181
Number of edges: 19361

===== Validation Dataset =====
Number of drug nodes: 4505
Number of gene nodes: 2181
Number of edges: 2768

===== Test Dataset =====
Number of drug nodes: 4505
Number of gene nodes: 2181
Number of edges: 5335


In [25]:
from torch_geometric.loader import DataLoader 
from torch.utils.data import Dataset

class EdgeDataset(Dataset):
    def __init__(self, data: HeteroData, mask: torch.Tensor):
        self.edge_index = data["drug", "interacts_with", "gene"].edge_index[:, mask]
        self.edge_label = data["drug", "interacts_with", "gene"].edge_label[mask]
        self.num_edges = self.edge_index.size(1)

    def __len__(self):
        return self.num_edges

    def __getitem__(self, idx):
        return self.edge_index[:, idx], self.edge_label[idx]


# Custom collate function to prevent automatic batching of edge_index
def custom_collate_fn(batch):
    edge_indices = [item[0] for item in batch]  # Extract edge_index from batch
    edge_labels = [item[1] for item in batch]  # Extract edge_label from batch
    
    # Concatenate edge indices and labels
    edge_index = torch.cat(edge_indices, dim=1)  # Shape: [2, batch_size]
    edge_label = torch.cat(edge_labels, dim=0)   # Shape: [batch_size]

    return edge_index, edge_label


train_dataset = EdgeDataset(data, train_mask)
valid_dataset = EdgeDataset(data, valid_mask)
test_dataset = EdgeDataset(data, test_mask)


train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=custom_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, collate_fn=custom_collate_fn)




In [26]:
import torch
from collections import Counter

# Function to count positive and negative edges
def count_edges(edge_labels):
    counter = Counter(edge_labels.tolist())
    return counter[1.0], counter[0.0]  # Positive, Negative


train_labels = train_dataset.edge_label
valid_labels = valid_dataset.edge_label
test_labels = test_dataset.edge_label


print("Total number of nodes:")
print(f"Drugs: {data['drug'].xs.size(0)}")
print(f"Genes: {data['gene'].xs.size(0)}")

print("$"*52)

train_positive, train_negative = count_edges(train_labels)
valid_positive, valid_negative = count_edges(valid_labels)
test_positive, test_negative = count_edges(test_labels)



print(f"Train - Total: {train_labels.size(0)}, Positive: {train_positive}, Negative: {train_negative}")
print(f"Valid - Total: {valid_labels.size(0)}, Positive: {valid_positive}, Negative: {valid_negative}")
print(f"Test  - Total: {test_labels.size(0)}, Positive: {test_positive}, Negative: {test_negative}")


Total number of nodes:
Drugs: 4505
Genes: 2181
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
Train - Total: 19361, Positive: 9872, Negative: 9489
Valid - Total: 2768, Positive: 1382, Negative: 1386
Test  - Total: 5335, Positive: 2576, Negative: 2759


In [27]:
drug_l_input_size= data['drug']['xl'].size(1)
gene_l_input_size= data['gene']['xl'].size(1)
drug_s_input_size= data['drug']['xs'].size(1)
gene_s_input_size= data['gene']['xs'].size(1)

print(f"Drug LLM Input Size: {drug_l_input_size}")
print(f"Gene LLM Input Size: {gene_l_input_size}")
print(f"Drug Structure Input Size: {drug_s_input_size}")
print(f"Gene Structure Input Size: {gene_s_input_size}")


Drug LLM Input Size: 768
Gene LLM Input Size: 1024
Drug Structure Input Size: 1200
Gene Structure Input Size: 1200


# The below model takes structure and llm embeddings

In [28]:
import torch
from torch_geometric.nn import SAGEConv, to_hetero, BatchNorm
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, recall_score, confusion_matrix
import os


torch.manual_seed(42)

# it takes drug and gene llm and structure features and assigns weighs to them durung traning
class FeatureFusion(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.fc = torch.nn.Linear(2 * hidden_channels, hidden_channels)

    def forward(self, llm_features, struct_features):
        concat_features = torch.cat([llm_features, struct_features], dim=-1)
        alpha = torch.sigmoid(self.fc(concat_features)) # values are between 0 and 1 because of Sigmoid function
        return alpha * llm_features + (1 - alpha) * struct_features # relative importance of each feature


class Classifier(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.fc1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.fc2 = torch.nn.Linear(hidden_channels, 1)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, drug_emb, gene_emb, edge_index):
        if edge_index.shape[0] != 2:
            edge_index = edge_index.t()

        src, dst = edge_index
        drug_emb_src = drug_emb[src]
        gene_emb_dst = gene_emb[dst]

        edge_emb = torch.cat([drug_emb_src, gene_emb_dst], dim=-1)
        edge_emb = F.relu(self.fc1(edge_emb))
        edge_emb = self.dropout(edge_emb)
        return self.fc2(edge_emb).squeeze(-1)



class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.bn1 = BatchNorm(hidden_channels)
        self.dropout = torch.nn.Dropout(0.5)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.bn2 = BatchNorm(hidden_channels)
        

    def forward(self, x, edge_index):
        x = F.relu(self.bn1(self.conv1(x, edge_index)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.conv2(x, edge_index)))
        return x



# Main Model
class Model(torch.nn.Module):
    def __init__(self, hidden_channels, drug_l_input_size, gene_l_input_size, drug_s_input_size, gene_s_input_size):
        super().__init__()

        # Drug 
        self.drug_llm_lin = torch.nn.Sequential(
            torch.nn.Linear(drug_l_input_size, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )
        self.drug_struct_lin = torch.nn.Sequential(
            torch.nn.Linear(drug_s_input_size, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )

        # Gene 
        self.gene_llm_lin = torch.nn.Sequential(
            torch.nn.Linear(gene_l_input_size, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )
        self.gene_struct_lin = torch.nn.Sequential(
            torch.nn.Linear(gene_s_input_size, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, hidden_channels),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU()
        )

        self.feature_fusion = FeatureFusion(hidden_channels)
        self.gnn = GNN(hidden_channels)
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        self.classifier = Classifier(hidden_channels)

    def forward(self, data, edge_index, edge_label):
        drug_xl = data["drug"].xl.to(device).float()
        drug_xs = data["drug"].xs.to(device).float()
        gene_xl = data["gene"].xl.to(device).float()
        gene_xs = data["gene"].xs.to(device).float()

        # Feature normalization
        drug_xs = (drug_xs - drug_xs.mean(dim=0)) / (drug_xs.std(dim=0, unbiased=False) + 1e-8)
        gene_xs = (gene_xs - gene_xs.mean(dim=0)) / (gene_xs.std(dim=0, unbiased=False) + 1e-8)

        # Feature Fussion
        drug_llm = self.drug_llm_lin(drug_xl)
        drug_struct = self.drug_struct_lin(drug_xs)
        drug_x = self.feature_fusion(drug_llm, drug_struct)

        gene_llm = self.gene_llm_lin(gene_xl)
        gene_struct = self.gene_struct_lin(gene_xs)
        gene_x = self.feature_fusion(gene_llm, gene_struct)

        # Pass through GNN
        x_dict = {"drug": drug_x, "gene": gene_x}
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        drug_emb = x_dict["drug"]
        gene_emb = x_dict["gene"]
        pred = self.classifier(drug_emb, gene_emb, edge_index)

        return pred, edge_label


In [29]:
import os
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, recall_score, confusion_matrix
import numpy as np
import gc
import random

# Training function

def train_and_validate(model, data, train_loader, val_loader, device, patience=5, max_epochs=100, dataset_name="default", run_id=1):
    model_save_dir = f"models/{dataset_name}"
    os.makedirs(model_save_dir, exist_ok=True)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data.edge_index_dict = {key: edge_index.to(device) for key, edge_index in data.edge_index_dict.items()}
    model = model.to(device)

    # Calculate positive weight for the loss function
    all_labels = torch.cat([labels for _, labels in train_loader], dim=0)
    negative_to_positive_ratio = len(all_labels[all_labels == 0]) / len(all_labels[all_labels == 1])  # Negatives/Positives
    pos_weight=torch.tensor(negative_to_positive_ratio).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay = 1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 4, eta_min=1e-6)
       
    best_auroc = 0.0
    best_val_loss = float("inf")
    epochs_no_improve = 0

    for epoch in range(1, max_epochs + 1):
        model.train()
        total_train_loss, total_train_examples = 0, 0

        for edge_index, edge_label in train_loader:
            optimizer.zero_grad()
            edge_index, edge_label = edge_index.to(device), edge_label.to(device)

            pred, ground_truth = model(data, edge_index, edge_label)

            loss = F.binary_cross_entropy_with_logits(pred, ground_truth, pos_weight)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item() * pred.numel()
            total_train_examples += pred.numel()

        avg_train_loss = total_train_loss / total_train_examples

        # Validation
        model.eval()
        total_val_loss, total_val_examples = 0, 0
        preds, ground_truths = [], []

        with torch.no_grad():
            for edge_index, edge_label in val_loader:
                edge_index, edge_label = edge_index.to(device), edge_label.to(device)
                pred, ground_truth = model(data, edge_index, edge_label)
                loss = F.binary_cross_entropy_with_logits(pred, ground_truth)

                total_val_loss += loss.item() * pred.numel()
                total_val_examples += pred.numel()

                preds.append(pred.cpu())
                ground_truths.append(ground_truth.cpu())

        avg_val_loss = total_val_loss / total_val_examples
        preds = torch.cat(preds, dim=0).sigmoid().numpy()
        ground_truths = torch.cat(ground_truths, dim=0).numpy()

        # Metrics
        auroc = roc_auc_score(ground_truths, preds)

        if auroc > best_auroc:
            best_auroc = auroc
            best_epoch = epoch
            model_save_path = os.path.join(model_save_dir, f"{dataset_name}_2best_model_run_{run_number+1}.pt")
            torch.save(model.state_dict(), model_save_path)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping
        if epochs_no_improve >= patience:
            break

    return best_auroc

# Evaluate the model on the test set

def evaluate_test_set(model, data, data_loader, device):
    model.eval()
    preds, ground_truths = [], []

    with torch.no_grad():
        for edge_index, edge_label in data_loader:
            edge_index, edge_label = edge_index.to(device), edge_label.to(device)
            pred, _ = model(data, edge_index, edge_label)
            preds.append(pred.cpu())
            ground_truths.append(edge_label.cpu())

    preds = torch.cat(preds, dim=0).sigmoid().numpy()  
    ground_truths = torch.cat(ground_truths, dim=0).numpy()

    binary_preds = (preds > 0.6).astype(int)

    auroc = roc_auc_score(ground_truths, preds)
    aupr = average_precision_score(ground_truths, preds)
    accuracy = accuracy_score(ground_truths, binary_preds)
    sensitivity = recall_score(ground_truths, binary_preds)
    tn, fp, fn, tp = confusion_matrix(ground_truths, binary_preds).ravel()
    specificity = tn / (tn + fp)

    return auroc, aupr, accuracy, sensitivity, specificity

hidden_channels = 512

model = Model(hidden_channels, drug_l_input_size, gene_l_input_size, drug_s_input_size, gene_s_input_size)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

auroc_scores, aupr_scores = [], []
accuracy_scores, sensitivity_scores, specificity_scores = [], [], []

for run_number in range(5):
    # Clear memory and set random seed
    gc.collect()
    torch.cuda.empty_cache()
    seed= random.randint(0, 2**32 - 1)  
    torch.manual_seed(seed)


    #Train the model on the train and validation set
    model = Model(hidden_channels, drug_l_input_size, gene_l_input_size, drug_s_input_size, gene_s_input_size)
    train_and_validate(
        model=model,
        data=data,
        train_loader=train_loader,
        val_loader=valid_loader,
        device=device,
        dataset_name=sub_dataset,
        run_id=run_number
    )

    
    model_path = f"models/{sub_dataset}/{sub_dataset}_2best_model_run_{run_number+1}.pt"
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    auroc, aupr, accuracy, sensitivity, specificity = evaluate_test_set(model, data, test_loader, device)

    auroc_scores.append(auroc)
    aupr_scores.append(aupr)
    accuracy_scores.append(accuracy)
    sensitivity_scores.append(sensitivity)
    specificity_scores.append(specificity)

    print(f"Run {run_number+1}: AUROC = {auroc:.4f}, AUPRC = {aupr:.4f}, Accuracy = {accuracy:.4f}, Sensitivity = {sensitivity:.4f}, Specificity = {specificity:.4f}")

# Calculate mean and standard deviation
print("\nFinal Results:")
print(f"AUROC: Mean = {np.mean(auroc_scores):.4f}, Std = {np.std(auroc_scores):.4f}")
print(f"AUPRC: Mean = {np.mean(aupr_scores):.4f}, Std = {np.std(aupr_scores):.4f}")
print(f"Accuracy: Mean = {np.mean(accuracy_scores):.4f}, Std = {np.std(accuracy_scores):.4f}")
print(f"Sensitivity: Mean = {np.mean(sensitivity_scores):.4f}, Std = {np.std(sensitivity_scores):.4f}")
print(f"Specificity: Mean = {np.mean(specificity_scores):.4f}, Std = {np.std(specificity_scores):.4f}")


Run 1: AUROC = 0.9072, AUPRC = 0.9093, Accuracy = 0.8300, Sensitivity = 0.7880, Specificity = 0.8692
Run 2: AUROC = 0.9141, AUPRC = 0.9136, Accuracy = 0.8382, Sensitivity = 0.8078, Specificity = 0.8666
Run 3: AUROC = 0.9076, AUPRC = 0.9060, Accuracy = 0.8219, Sensitivity = 0.7818, Specificity = 0.8594
Run 4: AUROC = 0.9099, AUPRC = 0.9093, Accuracy = 0.8285, Sensitivity = 0.7849, Specificity = 0.8692
Run 5: AUROC = 0.9039, AUPRC = 0.9050, Accuracy = 0.8217, Sensitivity = 0.7745, Specificity = 0.8659

Final Results:
AUROC: Mean = 0.9085, Std = 0.0034
AUPRC: Mean = 0.9086, Std = 0.0030
Accuracy: Mean = 0.8281, Std = 0.0061
Sensitivity: Mean = 0.7874, Std = 0.0112
Specificity: Mean = 0.8660, Std = 0.0036
