DRKG

Adapted from: https://github.com/gnn4dr/DRKG/blob/master/drkg_with_dgl/loading_drkg_in_dgl.ipynb

https://medium.com/@pytorch_geometric/link-prediction-on-heterogeneous-graphs-with-pyg-6d5c29677c70

In [2]:
import pandas as pd
import numpy as np
import os 

In [3]:
from SIMP_LLM.DRKG_loading   import  get_triplets, read_tsv,filter_drkg,map_drkg_relationships,filter_interaction_subset,print_head
from SIMP_LLM.DRKG_translate import  load_lookups

  from .autonotebook import tqdm as notebook_tqdm


# 1) Load Data

In [4]:
### 1) Read: This section reads DRKG and a glossary (used to map entities from codes to words)
DATA_DIR           = os.path.join("data")
verbose            =  True 
triplets,drkg_df   =  get_triplets(drkg_file = os.path.join(DATA_DIR  ,'drkg.tsv'),             verbose=verbose)  # Read triplets (head,relationship,tail)
relation_glossary  =  read_tsv(relation_file = os.path.join(DATA_DIR  ,'relation_glossary.tsv'),verbose=verbose)  # Read relationship mapping  


### 2) Filter & Map Interactions: This section returns a list of interactions e.g. DRUGBANK::treats::Compound:Disease )
# 2.1: First  we filter the interactions to only Compound-Disease
# 2.2: Then   we map the codes -> text  (this will be use to further filter interactions based on text) e.g.  Hetionet::CpD::Compound:Disease -> palliation
# 2.3: We use natural text to fitler  interactions based on terms such as "treat" (but we return the orignal interaction name )



# modularize this in create_dataframe
drkg_rx_dx_relations        = filter_drkg(data_frame = drkg_df ,  filter_column = 1 ,  filter_term = r'.*?Compound:Disease', verbose = verbose) # 2.1 Filter only Compound-Disease Interactions
drkg_rx_dx_relations_mapped = map_drkg_relationships(drkg_rx_dx_relations,relation_glossary,verbose=verbose)                                    # 2.2 Map codes to text 

### 2.3 Filter Drug interactions Interaction types to only include: treat inhibit or alleviate interactions  ###
drkg_rx_dx_relation_subset =  filter_interaction_subset(df                  = drkg_rx_dx_relations_mapped,
                                                        filter_colunm_name = 'Interaction-type' ,
                                                        regex_string       =  'treat|inhibit|alleviate',
                                                        return_colunm_name =  'Relation-name')

# 3) Use Filter Interactions to get Gilter DRKG 
drkg_df_filtered = drkg_df[drkg_df[1].isin(drkg_rx_dx_relation_subset)] # 3.1 Filter DRKG  to only  Compund-Disease 
print_head(df=drkg_df_filtered)



###

rx_dx_triplets   = drkg_df_filtered.values.tolist()                     # 3.2 Convert filtered DRKG to list


 Triplets:

[['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::2157'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::5264'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::2158'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::3309'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::28912'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::811'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::2159'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::821'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::5627'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::5624']]

 data/drkg.tsv  Dataframe:

+----+------------+--------------------------------+-------------+
|    | 0          | 1                              | 2           |
|----+------------+--------------------------------+-------------|
|  0 | Gene::2157 | bioarx::HumGenHumGen:Gene:Gene | Gene::2157  |
|  1 | Gene::2157 | bioarx::HumGenHumGen:Gene:Gene

In [5]:
# 4) Load Data frames for translation
hetionet_df, gene_df, drugbank_df, omim_df, mesh_dict, chebi_df, chembl_df = load_lookups(data_path=DATA_DIR,verbose=verbose)


 data/hetionet-v1.0-nodes.tsv  Dataframe:

+----+-------------------------+---------------------------+---------+
|    | id                      | name                      | kind    |
|----+-------------------------+---------------------------+---------|
|  0 | Anatomy::UBERON:0000002 | uterine cervix            | Anatomy |
|  1 | Anatomy::UBERON:0000004 | nose                      | Anatomy |
|  2 | Anatomy::UBERON:0000006 | islet of Langerhans       | Anatomy |
|  3 | Anatomy::UBERON:0000007 | pituitary gland           | Anatomy |
|  4 | Anatomy::UBERON:0000010 | peripheral nervous system | Anatomy |
+----+-------------------------+---------------------------+---------+

 Sample of Hetionet Data Types (Before processing):

+-------+----------------------------------+-------------------------------------------+---------------------+
|       | id                               | name                                      | kind                |
|-------+--------------------------------

In [6]:
# Make dictionaries
relation_glossary_relation_dict = pd.Series(relation_glossary['Interaction-type'].values, index=relation_glossary['Relation-name']).to_dict()

node_df   = pd.concat([hetionet_df[['name', 'id']], 
                       gene_df.rename(columns = {"symbol":"name", "GeneID":"id"}),
                       drugbank_df.rename(columns = {"Common name":"name", "DrugBank ID":"id"}),
                       omim_df.rename(columns = {"MIM Number":"id"}),
                       chebi_df.rename(columns = {"NAME":"name", "CHEBI_ACCESSION":"id"}),
                       chembl_df.rename(columns = {"pref_name":"name", "chembl_id":"id"})
                       ], ignore_index=True, axis=0).drop_duplicates() 
node_dict = pd.Series(node_df['name'].values, index=node_df['id']).to_dict() | mesh_dict # Convert node df to dict and merge with MeSH dictionary


# Create and use convert_entitynames function
def convert_entitynames(df, col, node_dict):
  df_update = df.copy()
  df_update[col] = df_update[col].str.replace(r'.*?MESH:', "MESH::", regex=True) # Remove MeSH labeling
  df_update[col] = df_update[col].map(node_dict).fillna(df_update[col])    # Translate dictionary
  df_update[col] = df_update[col].str.replace("Gene::", "Gene ID ") # For remaining uncoverted Gene IDs, remove "::"
  df_update[col] = df_update[col].str.replace("Disease::", "") # For remaining diseases (appears to be just SARS-COVID related names), remove label
  return df_update

df_med    = drkg_df.copy()
df_med[1] = df_med[1].map(relation_glossary_relation_dict).fillna(df_med[1])

df_med = convert_entitynames(df_med, 0, node_dict)
df_med = convert_entitynames(df_med, 2, node_dict)

print_head(df_med) 


+----+--------------+-------------+---------------+
|    | 0            | 1           | 2             |
|----+--------------+-------------+---------------|
|  0 | Gene ID 2157 | interaction | Gene ID 2157  |
|  1 | Gene ID 2157 | interaction | Gene ID 5264  |
|  2 | Gene ID 2157 | interaction | Gene ID 2158  |
|  3 | Gene ID 2157 | interaction | Gene ID 3309  |
|  4 | Gene ID 2157 | interaction | Gene ID 28912 |
+----+--------------+-------------+---------------+


In [7]:
# Filter DRKG in natural language to drug-treats-disease relationships
# rx_dx        = df_med[df_med.iloc[:,1] ==   'Compound treats the disease']
rx_dx        =  df_med.loc[drkg_df_filtered.index]
rx_dx_subset =  rx_dx[0:2000]
rx_dx_subset

Unnamed: 0,0,1,2
1518268,Dornase alfa,Compound treats the disease,Cystic Fibrosis
1518269,Denileukin diftitox,Compound treats the disease,MESH::C063419
1518270,Etanercept,Compound treats the disease,"Spondylitis, Ankylosing"
1518271,Etanercept,Compound treats the disease,Graft vs Host Disease
1518272,Etanercept,Compound treats the disease,Hidradenitis Suppurativa
...,...,...,...
1520263,Hydrocortisone,Compound treats the disease,Choroiditis
1520264,Hydrocortisone,Compound treats the disease,"Adrenal Hyperplasia, Congenital"
1520265,Hydrocortisone,Compound treats the disease,Fanconi Anemia
1520266,Hydrocortisone,Compound treats the disease,Keratitis


In [8]:
# Remove codes that are not relevant or have no conversion
drkg_translated = df_med.copy()

# Remove taxonomy
remove_tax = drkg_translated[drkg_translated[2].str.contains('Tax::')]
drkg_translated = drkg_translated.drop(remove_tax.index)

# Check for any untranslated terms
remove_untranslated = drkg_translated[(drkg_translated[0].str.contains('::'))|(drkg_translated[2].str.contains('::'))]
drkg_translated = drkg_translated.drop(remove_untranslated.index)

# Summarize 
print('Total number of pairs ' + str(drkg_df.shape[0]))
print('Dropped taxonomy pairs ' + str(len(remove_tax.index)))
print('Dropped untranslated pairs ' + str(len(remove_untranslated.index)))
drkg_translated

Total number of pairs 5874261
Dropped taxonomy pairs 14663
Dropped untranslated pairs 62779


Unnamed: 0,0,1,2
0,Gene ID 2157,interaction,Gene ID 2157
1,Gene ID 2157,interaction,Gene ID 5264
2,Gene ID 2157,interaction,Gene ID 2158
3,Gene ID 2157,interaction,Gene ID 3309
4,Gene ID 2157,interaction,Gene ID 28912
...,...,...,...
5874256,Gene ID 29099,reaction,Gene ID 1643
5874257,Gene ID 51645,reaction,Gene ID 3183
5874258,Gene ID 865,catalysis,Gene ID 983
5874259,Gene ID 1066,binding,Gene ID 7365


In [9]:
# Check unique values of untranslated
drkg_test1 = np.unique(remove_untranslated[0][remove_untranslated[0].str.contains('::')])
drkg_test2 = np.unique(remove_untranslated[2][remove_untranslated[2].str.contains('::')])

drkg_mesh_list = drkg_test1.tolist() +  drkg_test2.tolist()
drkg_mesh_unique = pd.DataFrame(pd.unique(drkg_mesh_list))
drkg_mesh_unique

Unnamed: 0,0
0,Compound::Bioarxivdrug:0
1,Compound::Bioarxivdrug:1
2,Compound::Bioarxivdrug:10
3,Compound::Bioarxivdrug:11
4,Compound::Bioarxivdrug:2
...,...
14416,MESH::C580539
14417,MESH::C585640
14418,MESH::D000071
14419,MESH::D018290


In [10]:
drkg_untranslated = drkg_mesh_unique.copy()
drkg_untranslated = drkg_untranslated[drkg_untranslated[0].str.startswith('MESH::')]
print(len(drkg_untranslated))
drkg_untranslated

7751


Unnamed: 0,0
1429,MESH::C000020
1430,MESH::C000050
1431,MESH::C000121
1432,MESH::C000154
1433,MESH::C000188
...,...
14416,MESH::C580539
14417,MESH::C585640
14418,MESH::D000071
14419,MESH::D018290


# 3) BioLinkBERT embedding

In [11]:
rx_dx_subset

Unnamed: 0,0,1,2
1518268,Dornase alfa,Compound treats the disease,Cystic Fibrosis
1518269,Denileukin diftitox,Compound treats the disease,MESH::C063419
1518270,Etanercept,Compound treats the disease,"Spondylitis, Ankylosing"
1518271,Etanercept,Compound treats the disease,Graft vs Host Disease
1518272,Etanercept,Compound treats the disease,Hidradenitis Suppurativa
...,...,...,...
1520263,Hydrocortisone,Compound treats the disease,Choroiditis
1520264,Hydrocortisone,Compound treats the disease,"Adrenal Hyperplasia, Congenital"
1520265,Hydrocortisone,Compound treats the disease,Fanconi Anemia
1520266,Hydrocortisone,Compound treats the disease,Keratitis


In [12]:
from torch_geometric.data import HeteroData
from SIMP_LLM.llm_encode import EntityEncoder
from SIMP_LLM.dataloader_mappings import create_mapping,create_edges


## Example of loading data without anything to encode
device    = "cpu"
Encoder  = EntityEncoder(device = device )


### DX RX Relationship ###
rx_X,rx_mapping = create_mapping(rx_dx_subset[0].to_list(),encoder= Encoder ,device=device) # Maps drugs to indices
dx_X,dx_mapping = create_mapping(rx_dx_subset[2].to_list(),encoder= Encoder ,device=device) # Maps drugs to indices
## As of now this only encodes 'Compound treats the disease', but  the idea is that this is used to encode every entity 
relationship_X,relationship_mapping = create_mapping(rx_dx_subset[1].to_list(),encoder= Encoder ,device=device)  

print(f"Unique Drugs:   {len(rx_mapping)} Matrix shape: {rx_X.shape}")
print(f"Unique Disases: {len(dx_mapping)} Matrix shape: {dx_X.shape }")
relationship_feature = relationship_X[relationship_mapping['Compound treats the disease'],:].reshape(1,-1)


## NEED To add other reationships

Unique Drugs:   499 Matrix shape: torch.Size([499, 768])
Unique Disases: 810 Matrix shape: torch.Size([810, 768])


### Build HeteroData Object

In [150]:
### Create PyG Hetero Graph:
import torch_geometric.transforms as T

data = HeteroData()
data['compounds'].x = rx_X
data['disease'].x   = dx_X

#data['compounds2'].x = rx_X
#data['disease2'].x   = dx_X
#print(data)

Edge_index,edge_attribute = create_edges(df             =  rx_dx_subset,
                                          src_index_col  = 0, 
                                          src_mapping    = rx_mapping , 
                                          dst_index_col  = 2, 
                                          dst_mapping    = dx_mapping ,
                                          edge_attr      = relationship_feature)

data['compounds', 'treats', 'disease'].edge_index = Edge_index
data = T.ToUndirected()(data)


print(data)

HeteroData(
  [1mcompounds[0m={ x=[499, 768] },
  [1mdisease[0m={ x=[810, 768] },
  [1m(compounds, treats, disease)[0m={ edge_index=[2, 2000] },
  [1m(disease, rev_treats, compounds)[0m={ edge_index=[2, 2000] }
)


In [151]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric
from torch_geometric.nn import SAGEConv, to_hetero
from   torch.utils.data      import Dataset, DataLoader
from   torch_geometric.data  import Data
from   torch_geometric.utils import negative_sampling

from torch_geometric.nn import SAGEConv, to_hetero

import torch_geometric.transforms as T


class GNNStack(torch.nn.Module):
    def __init__(self, input_dim:int, hidden_dim:int, output_dim:int, layers:int, dropout:float=0.3, return_embedding=False):
        """
            A stack of GraphSAGE Module 
            input_dim        <int>:   Input dimension
            hidden_dim       <int>:   Hidden dimension
            output_dim       <int>:   Output dimension
            layers           <int>:   Number of layers
            dropout          <float>: Dropout rate
            return_embedding <bool>:  Whether to return the return_embeddingedding of the input graph
        """
        
        super(GNNStack, self).__init__()
        graphSage_conv               = pyg.nn.SAGEConv
        self.dropout                 = dropout
        self.layers                  = layers
        self.return_embedding        = return_embedding

        ### Initalize the layers ###
        self.convs                   = nn.ModuleList()                      # ModuleList to hold the layers
        for l in range(self.layers):
            if l == 0:
                ### First layer  maps from input_dim to hidden_dim ###
                self.convs.append(graphSage_conv(input_dim, hidden_dim))
            else:
                ### All other layers map from hidden_dim to hidden_dim ###
                self.convs.append(graphSage_conv(hidden_dim, hidden_dim))

        # post-message-passing processing MLP
        self.post_mp = nn.Sequential(
                                     nn.Linear(hidden_dim, hidden_dim), 
                                     nn.Dropout(self.dropout),
                                     nn.Linear(hidden_dim, output_dim))

    def forward(self, x, edge_index,device="cpu"):
        for i in range(self.layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        # Return final layer of return_embeddingeddings if specified
        if self.return_embedding:
            return x

        # Else return class probabilities
        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)
    


class LinkPredictorMLP(nn.Module):
    def __init__(self, in_channels:int, hidden_channels:int, out_channels:int, n_layers:int,dropout_probabilty:float=0.3):
        """
        Args:
            in_channels (int):     Number of input features.
            hidden_channels (int): Number of hidden features.
            out_channels (int):    Number of output features.
            n_layers (int):        Number of MLP layers.
            dropout (float):       Dropout probability.
            """
        super(LinkPredictorMLP, self).__init__()
        self.dropout_probabilty    = dropout_probabilty  # dropout probability
        self.mlp_layers            = nn.ModuleList()     # ModuleList: is a list of modules
        self.non_linearity         = F.relu              # non-linearity
        
        for i in range(n_layers - 1):                                 
            if i == 0:
                self.mlp_layers.append(nn.Linear(in_channels, hidden_channels))          # input layer (in_channels, hidden_channels)
            else:
                self.mlp_layers.append(nn.Linear(hidden_channels, hidden_channels))      # hidden layers (hidden_channels, hidden_channels)

        self.mlp_layers.append(nn.Linear(hidden_channels, out_channels))                 # output layer (hidden_channels, out_channels)


    def reset_parameters(self):
        for mlp_layer in self.mlp_layers:
            mlp_layer.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j                                                     # element-wise multiplication
        for mlp_layer in self.mlp_layers[:-1]:                            # iterate over all layers except the last one
            x = mlp_layer(x)                                              # apply linear transformation
            x = self.non_linearity(x)                                     # Apply non linear activation function
            x = F.dropout(x, p=self.dropout_probabilty,training=self.training)      # Apply dropout
        x = self.mlp_layers[-1](x)                                        # apply linear transformation to the last layer
        x = torch.sigmoid(x)                                              # apply sigmoid activation function to get the probability
        return x
    
### We will use This function to save our best model during trainnig ###
def save_torch_model(model,epoch,PATH:str,optimizer):
    print(f"Saving Model in Path {PATH}")
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer':optimizer,      
                }, PATH)

In [152]:
def train(model, link_predictor, data, optimizer,device:str="cuda"):
    """
    Runs offline training for model, link_predictor and node embeddings given the message
    edges and supervision edges.
    :param model: Torch Graph model used for updating node embeddings based on message passing 
        (If None, no embbeding is performed) 
    :param link_predictor: Torch model used for predicting whether edge exists or not
    :param emb: (N, d) Initial node embeddings for all N nodes in graph
    :param edge_index: (2, E) Edge index for all edges in the graph

    :param optimizer: Torch Optimizer to update model parameters
    :return: Average supervision loss over all positive (and correspondingly sampled negative) edges
    """
    if model != None:
        model.train()
    link_predictor.train()
    train_losses = []

    optimizer.zero_grad()                                 # Reset Gradients
    #edge_index     = torch.tensor(edge_index).T           # Reshape edge index     (2,|E|)
    #x              = x.squeeze(dim=1)                     # Reshape Feature matrix (|N|,D)
    #x , edge_index = x.to(device) , edge_index.to(device) # Move data to devices


    ### Step 1: Get Embeddings:
    # Run message passing on the inital node embeddings to get updated embeddings

    ### This model has the option of only running link predictor without graphsage, for that case the node embedding
    ### is equal to the original embedding (X)
    if model !=  None:
        node_emb   = model(data.x_dict, data.edge_index_dict) # Embed Bert Embeddigns with graphsage (N, d) 

    else:
        node_emb = x                     # Else (None) use Bert Embedddings
    # Predict the class probabilities on the batch of positive edges using link_predictor
    #print(node_emb[edge_index[0]].shape)
    
  
    edge_index     = data['compounds', 'treats', 'disease'].edge_label_index
    pred           = link_predictor(node_emb["compounds"][edge_index[0]], node_emb["disease"][edge_index[0]])   
    ground_truth   = data['compounds', 'treats', 'disease'].edge_label.to(device)

    loss = F.binary_cross_entropy_with_logits(pred, ground_truth.unsqueeze(1))
    # Backpropagate and update parameters
    loss.backward()
    optimizer.step()

    train_losses.append(loss.item())
    return sum(train_losses) / len(train_losses)


def evaluate(model, predictor,dataset,device="cpu",threshold=0.5,ppi_:int=0.9,verbose:bool=False,best_accuracy=0,show_extra_metrics:bool=False):
    if model != None:
        model.eval()
    possitive_acc  = 0 
    negative_acc   = 0
    batches        = 0
    if show_extra_metrics:
        yhat_total     = []
        y_total        = []
        
    for x, edge_index in dataset:                              # Get X and Index from Dataste

        edge_index      = torch.tensor(edge_index).T           # Reshape edge index     (2,|E|)
        number_of_edges =  edge_index.shape[1]                 # Retrive number of edges 
        permutations    =  torch.randperm(number_of_edges)     # Create Permutations for edge index
        edge_index      = edge_index[:,permutations]           # Run permutation
        limit           = int(ppi_*number_of_edges)            # get limit  (based on ppis to embed)
        ppi_index_embed = edge_index[:,0:limit]                # PPI to embed with GraphSage 
        ppi_index_infer = edge_index[:,limit:]                 # PPI to make inference
        
        x                    = x.squeeze(dim=1)                          # Reshape Feature matrix (|N|,D)
        x ,ppi_index_embed   = x.to(device) , ppi_index_embed.to(device) # Move data to devices
        if model !=  None:
            node_emb             = model(x,ppi_index_embed)              # Get all node embeddings
        else:
            node_emb = x                                                 # Else (None) use Bert Embedddings
            
        if verbose:
            print(f" {limit} Positive Protein Interactions were used to Embed a graph with {number_of_edges} ppi's")
        
        del ppi_index_embed 
        with torch.no_grad():
            ### Positive PPI ###
            positive_pairs_embeddings = node_emb[ppi_index_infer[0]].to(device), node_emb[ppi_index_infer[1]].to(device)
            predictions               = predictor(positive_pairs_embeddings[0], positive_pairs_embeddings[1]) 
            y                         = torch.ones_like(input=predictions)
            predictions,y             = predictions.cpu(),y.cpu()
            possitive_acc            += accuracy_score(predictions > threshold  ,y)
            if show_extra_metrics:
                yhat_total.extend(predictions.tolist())
                y_total.extend(y.tolist())
                
            else:
                del y, predictions , positive_pairs_embeddings,ppi_index_infer

            ### Negative PPI ##
            edge_index  =  edge_index.to(device)
            neg_edge    = negative_sampling(edge_index       = edge_index,        # Possitve PPI's
                                         num_nodes        = x.shape[0],           # Total number of nodes in graph
                                         num_neg_samples  = edge_index.shape[1],  # Same Number of edges as in positive example
                                         method           = 'dense',              # Method for edge generation
                                         force_undirected = True)                 # Our graph is undirected

            negative_pairs_embeddings = node_emb[neg_edge[0]].to(device), node_emb[neg_edge[1]].to(device)
            predictions               = predictor(negative_pairs_embeddings[0], negative_pairs_embeddings[1])   
            y                         = torch.zeros_like(input=predictions)
            predictions,y             = predictions.cpu(),y.cpu()
            negative_acc             += accuracy_score(predictions > threshold,y)
            if show_extra_metrics:
                yhat_total.extend(predictions.tolist())
                y_total.extend(y.tolist())
                
            else:
                del y,  predictions  ,negative_pairs_embeddings 
            batches +=1

    negative_acc  = negative_acc/batches
    possitive_acc = possitive_acc/batches
    total_acc     = 0.5*possitive_acc  + 0.5*negative_acc
    if show_extra_metrics == False:
        print(f"Sensitivity (poss_acc):{possitive_acc:.4f} Specificity (negative_acc):{negative_acc:.4f} accuracy:{total_acc:.4f}")
    
    elif show_extra_metrics == True:
        
        fig, ax = plt.subplots(1, 2,figsize=(10,2))
        fpr, tpr, thresholds = metrics.roc_curve( y_total, yhat_total)
        
        sens      =  tpr
        spec      =  1 - fpr
        j         = sens + spec -1
        opt_index = np.where(j == np.max(j))[0][0]
        op_point  = thresholds[opt_index]
        
        print(f"Youdens  index: {op_point:.4f} Sensitivity: {round(sens[opt_index],4)} Specificity: {round(spec[opt_index],4)}")
       
        ax[0].set_title("ROC Curve")
        ax[1].set_title("Confussion Matrix")
        if model == None:
            ax[0].plot(fpr,tpr,label="MLP") 
        else:
            ax[0].plot(fpr,tpr,label="GraphSage+MLP") 
        ax[0].plot([0, 1], [0, 1], 'k--')
        ax[0].set_ylabel('True Positive Rate')
        ax[0].set_xlabel('False Positive Rate')
        ax[0].legend()
       
    
        cfm = metrics.confusion_matrix(y_total, np.array(yhat_total)> op_point)
        
        cmn = cfm.astype('float') / cfm.sum(axis=1)[:, np.newaxis] # Normalise
        disp = ConfusionMatrixDisplay(cmn)
        disp.plot(ax=ax[1])
        
        plt.show()
     
    return total_acc 

In [153]:
data.metadata

<bound method HeteroData.metadata of HeteroData(
  [1mcompounds[0m={ x=[499, 768] },
  [1mdisease[0m={ x=[810, 768] },
  [1m(compounds, treats, disease)[0m={ edge_index=[2, 2000] },
  [1m(disease, rev_treats, compounds)[0m={ edge_index=[2, 2000] }
)>

In [154]:
epochs        = 500
hidden_dim    = 524      # 256 
dropout       = 0.7
num_layers    = 3
learning_rate = 1e-4
node_emb_dim  = 768
device        = "cpu"

HomoGNN         = GNNStack(node_emb_dim, hidden_dim, hidden_dim, num_layers, dropout, return_embedding=True).to(device) # the graph neural network that takes all the node embeddings as inputs to message pass and agregate
HeteroGNN       = to_hetero(HomoGNN   , data.metadata(), aggr='sum')
link_predictor  = LinkPredictorMLP(hidden_dim, hidden_dim, 1, num_layers , dropout).to(device) # the MLP that takes embeddings of a pair of nodes and predicts the existence of an edge between them
#optimizer      = torch.optim.AdamW(list(model.parameters()) + list(link_predictor.parameters() ), lr=learning_rate, weight_decay=1e-4)
optimizer       = torch.optim.Adam(list(HeteroGNN.parameters()) + list(link_predictor.parameters() ), lr=learning_rate)

print(HeteroGNN )
print(link_predictor)
print(f"Models Loaded to {device}")


GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (compounds__treats__disease): SAGEConv(768, 524, aggr=mean)
      (disease__rev_treats__compounds): SAGEConv(768, 524, aggr=mean)
    )
    (1-2): 2 x ModuleDict(
      (compounds__treats__disease): SAGEConv(524, 524, aggr=mean)
      (disease__rev_treats__compounds): SAGEConv(524, 524, aggr=mean)
    )
  )
  (post_mp): ModuleList(
    (0): ModuleDict(
      (compounds): Linear(in_features=524, out_features=524, bias=True)
      (disease): Linear(in_features=524, out_features=524, bias=True)
    )
    (1): ModuleDict(
      (compounds): Dropout(p=0.7, inplace=False)
      (disease): Dropout(p=0.7, inplace=False)
    )
    (2): ModuleDict(
      (compounds): Linear(in_features=524, out_features=524, bias=True)
      (disease): Linear(in_features=524, out_features=524, bias=True)
    )
  )
)



def forward(self, x, edge_index, device = 'cpu'):
    x_dict = torch_geometric_nn_to_hetero_transformer_get_dict(x);  x = None
    x_

In [155]:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.5,
    neg_sampling_ratio=3,
    add_negative_train_samples=True,
    edge_types=('compounds', 'treats', 'disease'),
    rev_edge_types=('compounds', 'rev_treats', 'disease'), 
)

train_data, val_data, test_data = transform(data)


print(f"Train Data:\n{train_data}")
print(f"Validation Data:\n{val_data}")
print(f"Test Data:\n{test_data}")

Train Data:
HeteroData(
  [1mcompounds[0m={ x=[499, 768] },
  [1mdisease[0m={ x=[810, 768] },
  [1m(compounds, treats, disease)[0m={
    edge_index=[2, 800],
    edge_label=[3200],
    edge_label_index=[2, 3200]
  },
  [1m(disease, rev_treats, compounds)[0m={ edge_index=[2, 2000] },
  [1m(compounds, rev_treats, disease)[0m={}
)
Validation Data:
HeteroData(
  [1mcompounds[0m={ x=[499, 768] },
  [1mdisease[0m={ x=[810, 768] },
  [1m(compounds, treats, disease)[0m={
    edge_index=[2, 1600],
    edge_label=[800],
    edge_label_index=[2, 800]
  },
  [1m(disease, rev_treats, compounds)[0m={ edge_index=[2, 2000] },
  [1m(compounds, rev_treats, disease)[0m={}
)
Test Data:
HeteroData(
  [1mcompounds[0m={ x=[499, 768] },
  [1mdisease[0m={ x=[810, 768] },
  [1m(compounds, treats, disease)[0m={
    edge_index=[2, 1800],
    edge_label=[800],
    edge_label_index=[2, 800]
  },
  [1m(disease, rev_treats, compounds)[0m={ edge_index=[2, 2000] },
  [1m(compounds, rev_trea

In [156]:
#val = val_data.to("cuda")

In [157]:
node_emb    = HeteroGNN( val_data.x_dict,  val_data.edge_index_dict)


In [None]:
train_loss                      = []
train_accuracy                  = []
show_metrics_every              = 20
best_accuracy                   = 0 
best_graphsage_model_path       = ""
best_link_predictor_model_path  = ""

for epoch in range(1,epochs):
    
    ### TRAIN ####
    loss = train(model =  HeteroGNN, link_predictor = link_predictor, data = val_data, optimizer =  optimizer,device="cpu")
    train_loss.append(loss)
    print(f"Epoch {epoch}: loss: {round(loss, 5)}")
    




Epoch 1: loss: 0.84755
Epoch 2: loss: 0.84638
Epoch 3: loss: 0.84612
Epoch 4: loss: 0.84475
Epoch 5: loss: 0.84357
Epoch 6: loss: 0.84228
Epoch 7: loss: 0.84044
Epoch 8: loss: 0.83859
Epoch 9: loss: 0.83694
Epoch 10: loss: 0.83306
Epoch 11: loss: 0.82913
Epoch 12: loss: 0.82392
Epoch 13: loss: 0.81768
Epoch 14: loss: 0.80855
Epoch 15: loss: 0.7965
Epoch 16: loss: 0.78448
Epoch 17: loss: 0.76982
Epoch 18: loss: 0.75282
Epoch 19: loss: 0.73616
Epoch 20: loss: 0.71981
Epoch 21: loss: 0.70881
Epoch 22: loss: 0.70251
Epoch 23: loss: 0.69938
Epoch 24: loss: 0.69692
Epoch 25: loss: 0.69536
Epoch 26: loss: 0.69475
Epoch 27: loss: 0.69412
Epoch 28: loss: 0.69354
Epoch 29: loss: 0.69334
Epoch 30: loss: 0.6932
Epoch 31: loss: 0.69324
Epoch 32: loss: 0.69322
Epoch 33: loss: 0.69316
Epoch 34: loss: 0.69314
Epoch 35: loss: 0.69316
Epoch 36: loss: 0.69315
Epoch 37: loss: 0.69315
Epoch 38: loss: 0.69313
Epoch 39: loss: 0.69315
Epoch 40: loss: 0.69315
Epoch 41: loss: 0.69315
Epoch 42: loss: 0.69315
Epo

In [None]:
train_loss                      = []
train_accuracy                  = []
show_metrics_every              = 20
best_accuracy                   = 0 
best_graphsage_model_path       = ""
best_link_predictor_model_path  = ""

for epoch in range(1,epochs):
    
    ### TRAIN ####
    loss = train(model, link_predictor,train_dataset, optimizer,device) # Get Loss
    train_loss.append(loss)
    print(f"Epoch {epoch}: loss: {round(loss, 5)}")
    
    ### EVALUATE ###
    if (epoch % 20 == 0) or (epoch ==1):
        accuracy = evaluate(model, link_predictor ,test_dataset,device=device,best_accuracy=best_accuracy,show_extra_metrics=True)
        
    else:
        accuracy = evaluate(model, link_predictor ,test_dataset,device=device,best_accuracy=best_accuracy)
    
    train_accuracy.append(accuracy)
    ### SAVE ###
    if best_accuracy < accuracy:
        if os.path.exists(best_graphsage_model_path):
            
            os.remove(best_graphsage_model_path)
            
        if os.path.exists(best_link_predictor_model_path):
            os.remove(best_link_predictor_model_path)
        print(f"Replacing models: {best_graphsage_model_path }  {best_link_predictor_model_path}")
            
        best_accuracy  = accuracy
        best_graphsage_model_path      = f"GraphSage_epoch_{epoch}.pt"
        best_link_predictor_model_path =  f"link_predictor_epoch_{epoch}.pt"
        print(f"with: Best models at {best_graphsage_model_path }  {best_link_predictor_model_path}")
        save_torch_model(model,         epoch=epoch,PATH=best_graphsage_model_path ,     optimizer=optimizer)
        save_torch_model(link_predictor,epoch=epoch,PATH=best_link_predictor_model_path, optimizer=optimizer)

        
#### Load Best Models ####

print(f"Loading best models:  {best_graphsage_model_path }  {best_link_predictor_model_path}")
checkpoint = torch.load(best_graphsage_model_path)
model.load_state_dict(checkpoint['model_state_dict'])

checkpoint = torch.load(best_link_predictor_model_path)
link_predictor.load_state_dict(checkpoint['model_state_dict'])

del checkpoint
