AUTHORS: Alejandro and Selina primary, contributions listed by section
1) Selina primary with significant modularization by Alejandro, one chunk to save a dataset by Alex
2) Alejandro and Selina
3) Alejandro sole
4) Selina sole

In [1]:
import pandas as pd
import numpy as np
import os 
import torch_geometric.transforms as T

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from SIMP_LLM.DRKG_loading   import  get_triplets, read_tsv,filter_drkg,map_drkg_relationships,filter_interaction_subset,print_head
from SIMP_LLM.DRKG_translate import  load_lookups
from SIMP_LLM.DRKG_entity_processing import get_unique_entities, get_entity_lookup, convert_entitynames, flip_headtail
from SIMP_LLM.raredisease_loading import get_orphan_data

# 1) Load Data

AUTHORS: First pass by Selina, first 2 cells in this section modularized by Alejandro, rest of section by Selina except one cell by Alex to save the data (drkg_translated) to a CSV

In [3]:
### 1) Read: This section reads DRKG and a glossary (used to map entities from codes to words)
DATA_DIR           = os.path.join("data")
verbose            =  True 
triplets,drkg_df   =  get_triplets(drkg_file = os.path.join(DATA_DIR  ,'drkg.tsv'),             verbose=verbose)  # Read triplets (head, relationship, tail) as a list and as a dataframe
relation_glossary  =  read_tsv(relation_file = os.path.join(DATA_DIR  ,'relation_glossary.tsv'),verbose=verbose)  # Read relationship mapping (relation code -> description)


### 2) Filter & Map Interactions: This section returns a list of interactions (e.g. DRUGBANK::treats::Compound:Disease)
# 2.1: First we filter the interactions to only Compound-Disease
# 2.2: Then we map the codes -> text (this will be used to further filter interactions based on text), e.g. Hetionet::CpD::Compound:Disease -> palliation
# 2.3: We use natural text to filter interactions based on terms such as "treat" (but we return the original interaction name)



# TODO: modularize this in create_dataframe
drkg_rx_dx_relations        = filter_drkg(data_frame = drkg_df ,  filter_column = 1 ,  filter_term = r'.*?Compound:Disease', verbose = verbose) # 2.1 Filter only Compound-Disease interactions (column 1 holds the relation code)
drkg_rx_dx_relations_mapped = map_drkg_relationships(drkg_rx_dx_relations,relation_glossary,verbose=verbose)                                    # 2.2 Map codes to text

### 2.3 Filter drug interaction types to only include: treat, inhibit or alleviate interactions ###
# The match is done on the text column 'Interaction-type'; the values returned come from 'Relation-name'.
drkg_rx_dx_relation_subset =  filter_interaction_subset(df                  = drkg_rx_dx_relations_mapped,
                                                        filter_colunm_name = 'Interaction-type' ,
                                                        regex_string       =  'treat|inhibit|alleviate',
                                                        return_colunm_name =  'Relation-name')

# 3) Use the filtered interactions to get the filtered DRKG
drkg_df_filtered = drkg_df[drkg_df[1].isin(drkg_rx_dx_relation_subset)] # 3.1 Keep only DRKG rows whose relation (column 1) is in the selected Compound-Disease subset
print_head(df=drkg_df_filtered)



###

rx_dx_triplets   = drkg_df_filtered.values.tolist()                     # 3.2 Convert filtered DRKG to a list of [head, relation, tail] rows


 Triplets:

[['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::2157'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::5264'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::2158'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::3309'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::28912'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::811'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::2159'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::821'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::5627'], ['Gene::2157', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene::5624']]

 data/drkg.tsv  Dataframe:

+----+------------+--------------------------------+-------------+
|    | 0          | 1                              | 2           |
|----+------------+--------------------------------+-------------|
|  0 | Gene::2157 | bioarx::HumGenHumGen:Gene:Gene | Gene::2157  |
|  1 | Gene::2157 | bioarx::HumGenHumGen:Gene:Gene

In [4]:
# 4) Load the lookup tables (Hetionet, gene, DrugBank, OMIM, MeSH, ChEBI, ChEMBL) used below to translate DRKG codes into names.
# All of them are dataframes except mesh_dict, which is merged as a dictionary in the translation cell.
hetionet_df, gene_df, drugbank_df, omim_df, mesh_dict, chebi_df, chembl_df = load_lookups(data_path=DATA_DIR,verbose=verbose)


 data/hetionet-v1.0-nodes.tsv  Dataframe:

+----+-------------------------+---------------------------+---------+
|    | id                      | name                      | kind    |
|----+-------------------------+---------------------------+---------|
|  0 | Anatomy::UBERON:0000002 | uterine cervix            | Anatomy |
|  1 | Anatomy::UBERON:0000004 | nose                      | Anatomy |
|  2 | Anatomy::UBERON:0000006 | islet of Langerhans       | Anatomy |
|  3 | Anatomy::UBERON:0000007 | pituitary gland           | Anatomy |
|  4 | Anatomy::UBERON:0000010 | peripheral nervous system | Anatomy |
+----+-------------------------+---------------------------+---------+

 Sample of Hetionet Data Types (Before processing):

+-------+----------------------------------+-------------------------------------------+---------------------+
|       | id                               | name                                      | kind                |
|-------+--------------------------------

In [5]:
# Load orphan disease names and codes (28 Nov 2022 version of the Orphadata XML export)
orphan_names, orphan_codes = get_orphan_data(os.path.join(DATA_DIR, 'en_product1-Orphadata.xml'), verbose=verbose)

# Get orphan disease MeSH codes
# .copy() so that adding the 'id' column below does not write into a view of orphan_codes
orphan_codes_mesh = orphan_codes[orphan_codes['code_source']=='MeSH'].copy()
# Prefix the raw code with 'MESH::' to build the lookup key used in the code dictionary
orphan_codes_mesh['id'] = 'MESH::'+orphan_codes_mesh['code']


 Long-form orphan disease data (before processing):

+----+-----------+------------------------------------------------------------------------+
|    | cols      | data                                                                   |
|----+-----------+------------------------------------------------------------------------|
|  0 | Orphacode | 166024                                                                 |
|  1 | Name      | Multiple epiphyseal dysplasia, Al-Gazali type                          |
|  2 | Synonym   | Multiple epiphyseal dysplasia-macrocephaly-distinctive facies syndrome |
|  3 | Source    | ICD-10                                                                 |
|  4 | Reference | Q77.3                                                                  |
+----+-----------+------------------------------------------------------------------------+

 Long-form orphan disease data (after processing):

+----+-------------------------+---------------------------------

In [6]:
# Make dictionaries for codes
# Every lookup table is renamed to a common ('name', 'id') schema and stacked into one frame.
code_df   = pd.concat([hetionet_df[['name', 'id']], 
                       gene_df.rename(columns = {"description":"name", "GeneID":"id"}),
                       drugbank_df.rename(columns = {"Common name":"name", "DrugBank ID":"id"}),
                       omim_df.rename(columns = {"MIM Number":"id"}),
                       chebi_df.rename(columns = {"NAME":"name", "CHEBI_ACCESSION":"id"}),
                       chembl_df.rename(columns = {"pref_name":"name", "chembl_id":"id"}),
                       orphan_codes_mesh.rename(columns = {"Name":"name"})
                       ], ignore_index=True, axis=0).drop_duplicates() 
# Convert node df to dict (id -> name) and merge with MeSH dictionary.
# With the `|` dict union, entries from mesh_dict win when the same id appears on both sides.
code_dict = pd.Series(code_df['name'].values, index=code_df['id']).to_dict() | mesh_dict # Convert node df to dict and merge with MeSH dictionary

# Get unique DRKG entities (taken from the head and tail columns, 0 and 2)
drkg_entities = get_unique_entities(drkg_df, [0,2])

# Build the entity lookup: matched entities go to drkg_entity_df, the rest to drkg_unmatched
drkg_entity_df, drkg_unmatched = get_entity_lookup(drkg_entities, code_dict)

# Create final node dictionary (DRKG id -> translated name)
node_dict = pd.Series(drkg_entity_df['name'].values, index=drkg_entity_df['drkg_id']).to_dict() 

# Initialize translated DRKG and manually clean heads/tails for one case where they were flipped
drkg_translated    = drkg_df.copy()
drkg_translated = flip_headtail(drkg_translated, 'Gene:Compound')

# Map DRKG to translated entity names (column 0 = head, column 2 = tail)
drkg_translated = convert_entitynames(drkg_translated, 0, node_dict)
drkg_translated = convert_entitynames(drkg_translated, 2, node_dict)
# Drop rows with any missing value; presumably these are triples whose head or tail had no translation -- TODO confirm
drkg_translated = drkg_translated.dropna()
drkg_translated.columns = ["head_entity", "relationshipId", "tail_entity"]
print_head(drkg_translated) 

# Summarize percentage translated
print("Number of unique DRKG entities: ", len(drkg_entities)) # should be 97238
print("Number of translated entities: ", drkg_entity_df.shape[0])
print("Number of untranslated entities: ", drkg_unmatched.shape[0])
pct_entity_translated = drkg_entity_df.shape[0]/len(drkg_entities)
print('Percentage of entities translated: ', round(pct_entity_translated*100,1), '%')

# Relationship-level coverage: rows that survived the dropna above vs. all DRKG rows
print('Total DRKG relationships: ', drkg_df.shape[0])
print('Translated DRKG relationships: ', drkg_translated.shape[0])
pct_translated = drkg_translated.shape[0]/drkg_df.shape[0]
print('Percentage of relationships fully translated: ', round(pct_translated*100,1), '%')

+----+------------------------------+--------------------------------+------------------------------------------------------+
|    | head_entity                  | relationshipId                 | tail_entity                                          |
|----+------------------------------+--------------------------------+------------------------------------------------------|
|  0 | coagulation factor VIII (F8) | bioarx::HumGenHumGen:Gene:Gene | coagulation factor VIII (F8)                         |
|  1 | coagulation factor VIII (F8) | bioarx::HumGenHumGen:Gene:Gene | phytanoyl-CoA 2-hydroxylase (PHYH)                   |
|  2 | coagulation factor VIII (F8) | bioarx::HumGenHumGen:Gene:Gene | coagulation factor IX (F9)                           |
|  3 | coagulation factor VIII (F8) | bioarx::HumGenHumGen:Gene:Gene | heat shock protein family A (Hsp70) member 5 (HSPA5) |
|  4 | coagulation factor VIII (F8) | bioarx::HumGenHumGen:Gene:Gene | immunoglobulin kappa variable 3-20 (IGKV3-20)  

In [7]:
# drkg_translated.to_csv(os.path.join(DATA_DIR, "drkg_translated.csv"))

In [8]:
# Update relation glossary: derive head/tail entity types and a cleaned relation name/label per DRKG relation
relation_df = relation_glossary.copy().rename(columns={'Relation-name':'drkg_id'})
# drkg_id looks like 'SOURCE::RELATION::Head:Tail'; the third '::' field holds the entity types
relation_df[['head_entity','tail_entity']] = relation_df['drkg_id'].str.split('::', expand=True)[2].str.split(':', expand=True) # Set head and tail nodes

# Manually fix head and tail nodes for DGIDB relations, which reverse compound-gene interactions
relation_df.loc[relation_df['drkg_id'].str.contains('Gene:Compound'),'head_entity'] = 'Compound'
relation_df.loc[relation_df['drkg_id'].str.contains('Gene:Compound'),'tail_entity'] = 'Gene'

# Fix bioarx entries without the second "::" delimiter (e.g. 'bioarx::HumGenHumGen:Gene:Gene'):
# splitting on a single ':' puts the head type in field 3 and the tail type in field 4
bioarx_ht = relation_df['drkg_id'].str.split(':', expand=True)[[3,4]]
relation_df['head_entity'] = np.where(relation_df['head_entity'].isna(), bioarx_ht[3], relation_df['head_entity'])
relation_df['tail_entity'] = np.where(relation_df['tail_entity'].isna(), bioarx_ht[4], relation_df['tail_entity'])

# Add mapped relation group labels
# Each inner list is a set of interaction types treated as synonyms of one umbrella relation.
relation_groups = [['activation', 'agonism', 'agonism, activation', 'activates, stimulates'],
    ['antagonism', 'blocking', 'antagonism, blocking'],
    ['binding', 'binding, ligand (esp. receptors)'],
    ['blocking', 'channel blocking'],
    ['inhibition', 'inhibits cell growth (esp. cancers)', 'inhibits'],
    ['enzyme', 'enzyme activity'],
    ['upregulation', 'increases expression/production'],
    ['downregulation', 'decreases expression/production'],
    ['Compound treats the disease', 'treatment/therapy (including investigatory)', 'treatment']]

relation_df['relation_name'] = relation_df['Interaction-type']

# Within each synonym group and each pair of connected entity types, rename every relation to the
# first interaction type that appears for that pair. Masks are built on relation_df itself so the
# boolean indexers share its index and no implicit index alignment is needed.
# 'blocking' is listed in two groups; the later group is processed last and therefore wins.
for grp in relation_groups:
    in_group = relation_df['Interaction-type'].isin(grp)
    for entities in relation_df.loc[in_group, 'Connected entity-types'].unique():
        same_entities = in_group & (relation_df['Connected entity-types'] == entities)
        umbrella_name = relation_df.loc[same_entities, 'Interaction-type'].unique()[0]
        relation_df.loc[same_entities, 'relation_name'] = umbrella_name

# Remove special characters from relation names (for embedding) and create relation label with underscores instead of spaces (for edge name in heterograph)
relation_df['relation_name'] = relation_df['relation_name'].str.replace(',|/', ' or', regex=True)
# Literal replacement: without regex=False, pandas versions that default to regex=True treat the '.'
# as a wildcard and would also rewrite any name that merely contains 'esp' plus one more character.
relation_df['relation_name'] = relation_df['relation_name'].str.replace('esp.', 'especially', regex=False)
# Raw string so the backslashes reach the regex engine without invalid-escape warnings
relation_df['relation_name'] = relation_df['relation_name'].str.replace(r'\(|\)|-|\.', '', regex=True)
relation_df['relation_label'] = relation_df['relation_name'].str.replace(' ', '_')

# Check if any relationship names still have non-alphanumeric characters other than spaces
error_relation_names = relation_df['relation_name'][relation_df['relation_name'].str.replace(' ', '').str.contains(r"[^a-zA-Z0-9]+", regex=True)].drop_duplicates()
if len(error_relation_names):
    print('Warning: The following relation names contain special characters, which can interfere with PyG/GraphSage')
    print(error_relation_names)
    
relation_df


Unnamed: 0,drkg_id,Data-source,Connected entity-types,Interaction-type,Description,Reference for the description,head_entity,tail_entity,relation_name,relation_label
0,DGIDB::ACTIVATOR::Gene:Compound,DGIDB,Compound:Gene,activation,An activator interaction is when a drug activa...,http://www.dgidb.org/getting_started,Compound,Gene,activation,activation
1,DGIDB::AGONIST::Gene:Compound,DGIDB,Compound:Gene,agonism,An agonist interaction occurs when a drug bind...,http://www.dgidb.org/getting_started,Compound,Gene,activation,activation
2,DGIDB::ALLOSTERIC MODULATOR::Gene:Compound,DGIDB,Compound:Gene,allosteric modulation,An allosteric modulator interaction occurs whe...,http://www.dgidb.org/getting_started,Compound,Gene,allosteric modulation,allosteric_modulation
3,DGIDB::ANTAGONIST::Gene:Compound,DGIDB,Compound:Gene,antagonism,An antagonist interaction occurs when a drug b...,http://www.dgidb.org/getting_started,Compound,Gene,antagonism,antagonism
4,DGIDB::ANTIBODY::Gene:Compound,DGIDB,Compound:Gene,antibody,An antibody interaction occurs when an antibod...,http://www.dgidb.org/getting_started,Compound,Gene,antibody,antibody
...,...,...,...,...,...,...,...,...,...,...
102,bioarx::Covid2_acc_host_gene::Disease:Gene,BIBLIOGRAPHY,Disease:Gene,interaction,"Interactions between 27 viral proteins, and ...",,Disease,Gene,interaction,interaction
103,bioarx::DrugHumGen:Compound:Gene,BIBLIOGRAPHY,Compound:Gene,interaction,,,Compound,Gene,interaction,interaction
104,bioarx::DrugVirGen:Compound:Gene,BIBLIOGRAPHY,Compound:Gene,interaction,,,Compound,Gene,interaction,interaction
105,bioarx::HumGenHumGen:Gene:Gene,BIBLIOGRAPHY,Gene:Gene,interaction,Protein-protein interaction,,Gene,Gene,interaction,interaction


In [9]:
# Relation lookup restricted to the columns needed downstream; the entity columns are renamed
# to *_type so they do not clash with the entity-name columns of drkg_translated.
relation_data_to_merge = relation_df[
    ["drkg_id", "relation_label", "head_entity", "tail_entity"]
].rename(columns={"head_entity": "head_entity_type", "tail_entity": "tail_entity_type"})

# Upper-case the label and both entity-type columns
uppercase_cols = ["relation_label", "head_entity_type", "tail_entity_type"]
relation_data_to_merge[uppercase_cols] = relation_data_to_merge[uppercase_cols].apply(
    lambda column: column.str.upper()
)

# Left-join keeps every translated triple; relationshipId is dropped because drkg_id carries the same key
drkg_translated_with_relation_labels = drkg_translated.merge(
    relation_data_to_merge,
    how="left",
    left_on="relationshipId",
    right_on="drkg_id",
).drop(columns=["relationshipId"])
print_head(drkg_translated_with_relation_labels)

+----+------------------------------+------------------------------------------------------+--------------------------------+------------------+--------------------+--------------------+
|    | head_entity                  | tail_entity                                          | drkg_id                        | relation_label   | head_entity_type   | tail_entity_type   |
|----+------------------------------+------------------------------------------------------+--------------------------------+------------------+--------------------+--------------------|
|  0 | coagulation factor VIII (F8) | coagulation factor VIII (F8)                         | bioarx::HumGenHumGen:Gene:Gene | INTERACTION      | GENE               | GENE               |
|  1 | coagulation factor VIII (F8) | phytanoyl-CoA 2-hydroxylase (PHYH)                   | bioarx::HumGenHumGen:Gene:Gene | INTERACTION      | GENE               | GENE               |
|  2 | coagulation factor VIII (F8) | coagulation factor IX (F9) 

In [10]:
# Persist the labelled triples to DATA_DIR. The dataframe index is written too (to_csv default),
# so it shows up as an unnamed first column when the file is read back.
drkg_translated_with_relation_labels.to_csv(os.path.join(DATA_DIR, "drkg_translated_with_relation_labels.csv"))

In [11]:
# Persist the translated-entity lookup as a pickle in DATA_DIR
drkg_entity_df.to_pickle(os.path.join(DATA_DIR, 'drkg_entity_df.pkl'))

In [13]:
# Build the unique (head type, relation label, tail type) categories for relations present in the translated DRKG.
# Relation ids must be matched against the 'relationshipId' column; matching against 'tail_entity'
# (translated entity names) never finds a relation id and leaves the result empty.
triplet_cats = relation_df[relation_df['drkg_id'].isin(drkg_translated['relationshipId'])].copy()
triplet_cats = triplet_cats[['head_entity', 'relation_label', 'tail_entity']].drop_duplicates()
# Saving is currently disabled; uncomment to write the categories to disk
# triplet_cats.to_csv("triplets.csv")


# 2) BioLinkBERT embedding

AUTHORS: Alejandro and Selina (equal)

In [14]:
from torch_geometric.data import HeteroData
from SIMP_LLM.llm_encode import EntityEncoder
from SIMP_LLM.dataloader_mappings import create_mapping, create_edges, embed_entities, embed_edges

### Set variables and load data

In [16]:
## Set variables
device   = "cpu"
Encoder  = EntityEncoder(device = device )
run_full_sample = 0   # falsy -> build a small test sample below; truthy -> use the full DRKG
Sample = 5            # maximum number of triples kept per relation in the test sample

if run_full_sample:
    # Run full DRKG
    entity_df = drkg_entity_df.copy()
    hrt_data = drkg_translated.copy()
    relation_lookup = relation_df.copy()
else:
    # Change interaction list inputs as needed
    interaction_list = ['activation', 'agonism', 'agonism, activation', 'Compound treats the disease', 'treats', 'inhibition'] 

    # Create relationship subset for testing based on interaction and umbrella relation names from DRKG relations
    relation_name_list = relation_df.loc[relation_df['Interaction-type'].isin(interaction_list), 'relation_name']
    test_relation_df = relation_df[relation_df['relation_name'].isin(relation_name_list)].copy()
    # Entity types that actually occur in the selected test relations. This must come from
    # test_relation_df: using the full relation_df would allow every entity type and make the
    # entity-type filter below a no-op.
    relation_entity_list = pd.concat([test_relation_df['head_entity'], test_relation_df['tail_entity']], ignore_index=True, axis=0).unique()
    print_head(test_relation_df)

    # Create test sample of DRKG relationships filtering to these relations
    test_hrt_df = drkg_translated[drkg_translated['relationshipId'].isin(test_relation_df['drkg_id'])]
    # Keep at most `Sample` triples per relation
    test_hrt_df = test_hrt_df.groupby('relationshipId').head(Sample).reset_index(drop=True)

    # Get dataframe of entities, careful to only use types represented in test_relation_df
    # (entities are matched by name, so a name shared with another entity type is removed by the type filter)
    test_unique_entities = get_unique_entities(test_hrt_df, columns=['head_entity','tail_entity'])
    test_entity_df = drkg_entity_df[drkg_entity_df['name'].isin(test_unique_entities)]
    test_entity_df = test_entity_df[test_entity_df['entity_type'].isin(relation_entity_list)]
    print_head(test_hrt_df)
    print_head(test_entity_df)


    entity_df = test_entity_df.copy()
    hrt_data = test_hrt_df.copy()
    relation_lookup = relation_df.copy()

+----+------------------------------------+---------------+--------------------------+-----------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------+---------------+---------------+-----------------------------+-----------------------------+
|    | drkg_id                            | Data-source   | Connected entity-types   | Interaction-type            | Description                                                                                                                                                                                            | Reference for the description        | head_entity   | tail_entity   | relation_name               | relation_label              |
|----+------------------------------------+---------------+--------------------------+----------------

In [18]:
# Quick sanity checks on the selected sample.
# NOTE(review): only the last expression of a cell is displayed, so the result of the first line
# (triples whose relation id mentions 'Disease') is computed and then discarded.
hrt_data[hrt_data['relationshipId'].str.contains('Disease')]
entity_df[entity_df['name']=='Cystic Fibrosis']

Unnamed: 0,drkg_id,drkg_dict_id,name,entity_type,ontology_code,ontology_name,code
43514,Disease::MESH:D003550,MESH::D003550,Cystic Fibrosis,Disease,MESH:D003550,MESH,D003550


### Build HeteroData Object

In [11]:
# Initialize heterograph object
data = HeteroData()

# Embed entities, add to graph, and save embedding mapping dictionary of dictionaries
# (mapping_dict is indexed by entity type below)
# NOTE(review): the saved output of this cell is a RuntimeError about a missing parent directory
# 'data2/ckpts/entity'; presumably a checkpoint/cache path used by the embedding helpers -- confirm
# and create that directory before running.
mapping_dict = embed_entities(entity_df, data, Encoder, device) 

# Embed relationships, add to graph, and save relation embeddings/mapping dictionary
relation_X, relation_mapping = embed_edges(hrt_data, relation_lookup, data, mapping_dict, Encoder, device)

# Print summary
# (the undirected conversion is left disabled here)
#data = T.ToUndirected()(data)

print(data)
# One line per entity type: number of mapped entities and the shape of its node feature matrix
for ent_type in entity_df['entity_type'].unique():
    print(f"Unique {ent_type}s: {len(mapping_dict[ent_type])} \t Matrix shape: {data[ent_type].x.shape }")
    # print(mapping_dict[ent_type]) # Prints whole dictionary so delete/uncomment if using all entities

RuntimeError: Parent directory data2/ckpts/entity does not exist.

In [None]:
# Apply the ToUndirected transform; the displayed result gains 'rev_*' edge types for the
# relations that connect two different node types.
# NOTE(review): this rebinds `data`, so re-running the cell applies the transform to an already-transformed graph.
data = T.ToUndirected()(data)
data


HeteroData(
  [1mCompound[0m={ x=[29, 768] },
  [1mDisease[0m={ x=[14, 768] },
  [1mGene[0m={ x=[51, 768] },
  [1m(Compound, activation, Gene)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  },
  [1m(Compound, inhibition, Gene)[0m={
    edge_index=[2, 10],
    edge_label=[10, 768]
  },
  [1m(Compound, Compound_treats_the_disease, Disease)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  },
  [1m(Gene, activates_or_stimulates, Gene)[0m={
    edge_index=[2, 20],
    edge_label=[20, 768]
  },
  [1m(Gene, inhibition, Gene)[0m={
    edge_index=[2, 10],
    edge_label=[10, 768]
  },
  [1m(Gene, rev_activation, Compound)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  },
  [1m(Gene, rev_inhibition, Compound)[0m={
    edge_index=[2, 10],
    edge_label=[10, 768]
  },
  [1m(Disease, rev_Compound_treats_the_disease, Compound)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  }
)

## 3) GRAPH SAGE
AUTHORS: Alejandro

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric
from torch_geometric.nn import SAGEConv, to_hetero
from   torch.utils.data      import Dataset, DataLoader
from   torch_geometric.data  import Data
from   torch_geometric.utils import negative_sampling

from torch_geometric.nn import SAGEConv, to_hetero




class GNNStack(torch.nn.Module):
    def __init__(self, input_dim:int, hidden_dim:int, output_dim:int, layers:int, dropout:float=0.3, return_embedding=False):
        """
            Stack of GraphSAGE convolutions followed by a two-layer post-processing MLP.
            input_dim        <int>:   Size of the input node features
            hidden_dim       <int>:   Size of every hidden representation
            output_dim       <int>:   Size of the final output
            layers           <int>:   Number of GraphSAGE convolution layers
            dropout          <float>: Dropout rate (used after each convolution and inside the MLP)
            return_embedding <bool>:  If True, forward returns the raw output of the MLP;
                                      otherwise it returns log-softmax class scores
        """
        super().__init__()
        self.dropout          = dropout
        self.layers           = layers
        self.return_embedding = return_embedding

        # Convolution stack: the first layer consumes input_dim, every later one hidden_dim
        conv_cls   = pyg.nn.SAGEConv
        self.convs = nn.ModuleList()
        for depth in range(layers):
            in_dim = input_dim if depth == 0 else hidden_dim
            self.convs.append(conv_cls(in_dim, hidden_dim))

        # MLP applied to the node representations after message passing
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, edge_index):
        """Run the convolution stack (conv -> ReLU -> dropout per layer) and then the MLP."""
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.post_mp(x)

        if not self.return_embedding:
            # Class log-probabilities over the last dimension of the MLP output
            out = F.log_softmax(out, dim=1)
        return out

    def loss(self, pred, label):
        """Negative log-likelihood loss; `pred` is expected to hold log-probabilities."""
        return F.nll_loss(pred, label)
    


class LinkPredictorMLP(nn.Module):
    def __init__(self, in_channels:int, hidden_channels:int, out_channels:int, n_layers:int,dropout_probabilty:float=0.3):
        """
        MLP that scores a pair of node embeddings (combined by element-wise product).

        Args:
            in_channels (int):          Number of input features.
            hidden_channels (int):      Number of hidden features.
            out_channels (int):         Number of output features.
            n_layers (int):             Number of linear layers; values below 1 are treated as 1.
            dropout_probabilty (float): Dropout probability applied after every hidden layer.
        """
        super(LinkPredictorMLP, self).__init__()
        self.dropout_probabilty    = dropout_probabilty  # dropout probability
        self.mlp_layers            = nn.ModuleList()     # holds the linear layers in order
        self.non_linearity         = F.relu              # non-linearity

        # Layer widths: in -> hidden (n_layers - 1 times) -> out.
        # Deriving every layer from this list makes the single-layer case map in_channels directly
        # to out_channels, instead of building a hidden->out layer that cannot consume the input.
        layer_dims = [in_channels] + [hidden_channels] * (max(n_layers, 1) - 1) + [out_channels]
        for dim_in, dim_out in zip(layer_dims[:-1], layer_dims[1:]):
            self.mlp_layers.append(nn.Linear(dim_in, dim_out))


    def reset_parameters(self):
        """Re-initialise the parameters of every linear layer."""
        for mlp_layer in self.mlp_layers:
            mlp_layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return sigmoid scores for the pairs (x_i[k], x_j[k])."""
        x = x_i * x_j                                                     # element-wise multiplication
        for mlp_layer in self.mlp_layers[:-1]:                            # iterate over all layers except the last one
            x = mlp_layer(x)                                              # apply linear transformation
            x = self.non_linearity(x)                                     # apply non-linear activation function
            x = F.dropout(x, p=self.dropout_probabilty,training=self.training)      # apply dropout (active only in training mode)
        x = self.mlp_layers[-1](x)                                        # apply linear transformation of the last layer
        x = torch.sigmoid(x)                                              # squash to (0, 1) to get a probability
        return x
    
### We will use this function to save our best model during training ###
def save_torch_model(model,epoch,PATH:str,optimizer):
    """
    Save a training checkpoint to PATH with torch.save.

    The checkpoint is a dict with keys 'epoch', 'model_state_dict' and 'optimizer'.
    Note that 'optimizer' stores the optimizer object itself, not its state_dict.

    Args:
        model:     Module whose state_dict is saved.
        epoch:     Epoch number stored alongside the weights.
        PATH (str): Destination file; missing parent directories are created.
        optimizer: Optimizer stored under the 'optimizer' key.
    """
    import os  # local import keeps the helper self-contained

    # torch.save raises RuntimeError when the parent directory does not exist, so create it first
    parent_dir = os.path.dirname(PATH)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print(f"Saving Model in Path {PATH}")
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer':optimizer,      
                }, PATH)

In [None]:
# Training and model hyperparameters
epochs        = 500
hidden_dim    = 524      # 256  (NOTE(review): 256 looks like a previous value; 524 may be a typo for 512 -- confirm)
dropout       = 0.7
num_layers    = 3
learning_rate = 1e-4
node_emb_dim  = 768      # matches the width of the node feature matrices shown in the HeteroData summary ([*, 768])



HomoGNN         = GNNStack(node_emb_dim, hidden_dim, hidden_dim, num_layers, dropout, return_embedding=True).to(device) # the graph neural network that takes all the node embeddings as inputs to message pass and aggregate
# Convert the homogeneous GNN into a heterogeneous one using the node/edge types of `data`
HeteroGNN       = to_hetero(HomoGNN   , data.metadata(), aggr='sum')
link_predictor  = LinkPredictorMLP(hidden_dim, hidden_dim, 1, num_layers , dropout).to(device) # the MLP that takes embeddings of a pair of nodes and predicts the existence of an edge between them
#optimizer      = torch.optim.AdamW(list(model.parameters()) + list(link_predictor.parameters() ), lr=learning_rate, weight_decay=1e-4)
# A single optimizer updates the GNN and the link predictor together
optimizer       = torch.optim.Adam(list(HeteroGNN.parameters()) + list(link_predictor.parameters() ), lr=learning_rate)

print(HeteroGNN )
print(link_predictor)
print(f"Models Loaded to {device}")

GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (Compound__activation__Gene): SAGEConv(768, 524, aggr=mean)
      (Compound__inhibition__Gene): SAGEConv(768, 524, aggr=mean)
      (Compound__Compound_treats_the_disease__Disease): SAGEConv(768, 524, aggr=mean)
      (Gene__activates_or_stimulates__Gene): SAGEConv(768, 524, aggr=mean)
      (Gene__inhibition__Gene): SAGEConv(768, 524, aggr=mean)
      (Gene__rev_activation__Compound): SAGEConv(768, 524, aggr=mean)
      (Gene__rev_inhibition__Compound): SAGEConv(768, 524, aggr=mean)
      (Disease__rev_Compound_treats_the_disease__Compound): SAGEConv(768, 524, aggr=mean)
    )
    (1-2): 2 x ModuleDict(
      (Compound__activation__Gene): SAGEConv(524, 524, aggr=mean)
      (Compound__inhibition__Gene): SAGEConv(524, 524, aggr=mean)
      (Compound__Compound_treats_the_disease__Disease): SAGEConv(524, 524, aggr=mean)
      (Gene__activates_or_stimulates__Gene): SAGEConv(524, 524, aggr=mean)
      (Gene__inhibition__Gene): 

In [None]:
# Show the node and edge stores of the graph that is fed to the model
print(data)

HeteroData(
  [1mCompound[0m={ x=[29, 768] },
  [1mDisease[0m={ x=[14, 768] },
  [1mGene[0m={ x=[51, 768] },
  [1m(Compound, activation, Gene)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  },
  [1m(Compound, inhibition, Gene)[0m={
    edge_index=[2, 10],
    edge_label=[10, 768]
  },
  [1m(Compound, Compound_treats_the_disease, Disease)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  },
  [1m(Gene, activates_or_stimulates, Gene)[0m={
    edge_index=[2, 20],
    edge_label=[20, 768]
  },
  [1m(Gene, inhibition, Gene)[0m={
    edge_index=[2, 10],
    edge_label=[10, 768]
  },
  [1m(Gene, rev_activation, Compound)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  },
  [1m(Gene, rev_inhibition, Compound)[0m={
    edge_index=[2, 10],
    edge_label=[10, 768]
  },
  [1m(Disease, rev_Compound_treats_the_disease, Compound)[0m={
    edge_index=[2, 15],
    edge_label=[15, 768]
  }
)


In [None]:
# Forward pass: one embedding tensor per node type, returned as a dict keyed by node type
node_emb   = HeteroGNN(data.x_dict, data.edge_index_dict)
# Existing Compound -> Disease "treats" edges serve as the positive examples (row 0 = Compound index, row 1 = Disease index)
edge_index = data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_index 
pos_pred   = link_predictor(node_emb["Compound"][edge_index[0]], node_emb["Disease"][edge_index[1]])   # (num_edges, 1): one sigmoid score per edge, since the predictor was built with out_channels=1



In [None]:
# Inspect the per-node-type embeddings produced by the forward pass
node_emb

{'Compound': tensor([[ 0.5875, -1.0616, -0.7926,  ..., -0.0278, -0.0979,  0.5913],
         [ 0.2394,  0.2247,  1.2243,  ...,  0.8381,  1.1482, -0.6650],
         [-0.6849, -1.8316, -0.0687,  ...,  0.7773, -0.1265,  1.0705],
         ...,
         [ 0.4461,  0.2052, -0.9088,  ...,  0.0704, -0.5500,  0.4950],
         [-0.8308,  0.2126,  1.3543,  ..., -0.1609,  0.9807,  0.3113],
         [-0.0263,  0.0214,  0.8781,  ...,  0.2766, -0.5748,  0.3060]],
        grad_fn=<AddmmBackward0>),
 'Disease': tensor([[ 1.1654,  1.1531, -0.0988,  ..., -0.4830, -0.7136, -0.3199],
         [-0.0638,  0.4340, -1.4309,  ..., -0.4483, -1.2654,  0.1086],
         [ 0.6284, -0.2473,  0.5736,  ..., -0.1155, -0.5782, -1.0763],
         ...,
         [ 0.0311,  0.0726,  1.3918,  ..., -0.5468, -0.1218, -0.0914],
         [-0.1779,  0.6497,  0.6867,  ..., -0.5519,  0.2511, -0.0758],
         [ 0.7770,  0.5642,  0.9319,  ...,  0.3275,  0.5305,  0.4870]],
        grad_fn=<AddmmBackward0>),
 'Gene': tensor([[-0.1290

## 4) Identify rare diseases in DRKG

AUTHORS: Selina

In [9]:
from SIMP_LLM.raredisease_loading import get_drkg_entity_ontologies, read_and_process_doid, create_orphanet_regex, merge_regex, find_drkg_rarediseases, check_raredisease_multiple_codes, read_and_process_rep_drugs

### Download and process data

In [10]:
# Stack drkg_entity_df and drkg_unmatched into a single table of all DRKG entities
drkg_all_entities = pd.concat(
    [drkg_entity_df, drkg_unmatched],
    axis=0,
    ignore_index=True,
)

# Diagnostics only: which code ontologies DRKG uses for condition-like entities
# (DRKG carries DOID codes, which Orphanet's mapped code sources do not include)
if verbose:
    condition_list = ['Disease', 'Symptom', 'Side Effect']
    drkg_ontology_counts = get_drkg_entity_ontologies(drkg_all_entities, condition_list)
    print('DRKG code counts:')
    print(drkg_ontology_counts)
    print('\nOrphanet code types: ', orphan_codes['code_source'].unique())

# Read and process the DOID (Disease Ontology) table from the data directory
doid_path = os.path.join(DATA_DIR, 'DOID.csv')
doid_df = read_and_process_doid(relation_file=doid_path, verbose=verbose)

# Build regexes from the Orphanet mapped codes so they can be matched against
# the cross-reference codes listed in DOID
orphan_codes_match = create_orphanet_regex(orphan_codes, verbose=verbose)

DRKG code counts:
   matched ontology_name  count
0        0          MESH   1002
1        1          DOID    127
2        1          MESH   4284
3        1          OMIM     78
4        1      UMLS CUI   5701

Orphanet code types:  ['ICD-10' 'OMIM' 'UMLS' 'MeSH' 'ICD-11' 'GARD' 'MedDRA']

 DOID Dataframe (After processing):

+-----+--------------+--------------------------+------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+------------------------------------------------------------------------------------------------------------------+----------------------+------------------------------------------------+---------------------------------------------+
|     | id           | Preferred Label          | Synonyms                                       | Definitions                                            

  df_raw = pd.read_csv(relation_file)


### Code and regex matching between Orphacodes and DRKG

In [11]:
# Link Orphacodes to DOIDs by regex against DOID's cross-reference column - takes ~3min to run
orphacode_doid_regex = merge_regex(orphan_codes_match, 'regex', doid_df, 'database_cross_reference')

# Identify which DRKG entities correspond to Orphanet rare diseases
matched_rarediseases = find_drkg_rarediseases(
    drkg_all_entities,
    orphan_codes,
    orphacode_doid_regex,
    verbose=verbose,
)

# Diagnostics: rare diseases that map to more than one code
if verbose:
    multiple_orphacode, multiple_drkg = check_raredisease_multiple_codes(matched_rarediseases)
    for multi_code_df in (multiple_orphacode, multiple_drkg):
        print_head(multi_code_df)

# Orphanet entries with no DRKG match (they could still be in DRKG under a different
# name or code - need to check that embeddings are separate)
orphacode_is_matched = orphan_codes['Orphacode'].isin(matched_rarediseases['Orphacode'])
unmatched_rarediseases = orphan_codes[~orphacode_is_matched]
if verbose:
    print('Number of unmatched Orphacodes: ', len(unmatched_rarediseases['Orphacode'].unique()))
    print_head(unmatched_rarediseases)

# Raw DRKG triples in which a matched rare disease is the head (column 0) or the tail (column 2)
matched_drkg_ids = matched_rarediseases['drkg_id']
raredisease_heads = drkg_df[drkg_df[0].isin(matched_drkg_ids)]
raredisease_tails = drkg_df[drkg_df[2].isin(matched_drkg_ids)]

if verbose:
    print('Rare disease heads in untranslated DRKG: ', len(raredisease_heads))
    print_head(raredisease_heads)
    print('\n Rare disease tails in untranslated DRKG: ', len(raredisease_tails))
    print_head(raredisease_tails)

# Select the same rows from the translated table by row label.
# NOTE(review): assumes drkg_translated shares drkg_df's row index - TODO confirm
raredisease_index = [*raredisease_heads.index.tolist(), *raredisease_tails.index.tolist()]
in_raredisease_triple = drkg_translated.index.isin(raredisease_index)
drkg_raredisease = drkg_translated.loc[in_raredisease_triple]
if verbose:
    print_head(drkg_raredisease)


 DRKG-Orphacode matches:

+----+-----------------------+----------------+------------------------------------------------------------------------+---------------+-----------------+-----------------+---------+-----------+-------------+------------------------------------------------------------------------+---------------+----------------------------------------------------+------------------------------+-----------------------------------+---------------------+--------------+----------+------+----------+---------+
|    | drkg_id               | drkg_dict_id   | name                                                                   | entity_type   | ontology_code   | ontology_name   | code    |   matched |   Orphacode | Name                                                                   | code_source   | Disordermappingrelation                            |   Disordermappingicdrelation | Disordermappingvalidationstatus   | code_source_upper   | match_type   |   code_x |   id |   code

### Check drug repurposing hub (DRH) database for examples

In [12]:
# Load the Drug Repurposing Hub (DRH) file from the data directory. The output below shows the
# resulting frame has columns pert_iname, clinical_phase, moa, target, disease_area, indication.
rep_drugs_df       =  read_and_process_rep_drugs(os.path.join(DATA_DIR, 'repurposing_drugs_20200324.txt'), verbose=verbose)     # NOTE(review): any processing beyond reading the file happens inside the helper - confirm there


 data/repurposing_drugs_20200324.txt  Drug Repurposing Dataframe:

+----+------------------------------+------------------+---------------------------------+----------------------------------------------------------------------------------------------+----------------------+---------------------+
|    | pert_iname                   | clinical_phase   | moa                             | target                                                                                       | disease_area         | indication          |
|----+------------------------------+------------------+---------------------------------+----------------------------------------------------------------------------------------------+----------------------+---------------------|
|  0 | (R)-(-)-apomorphine          | Launched         | dopamine receptor agonist       | ADRA2A|ADRA2B|ADRA2C|CALY|DRD1|DRD2|DRD3|DRD4|DRD5|HTR1A|HTR1B|HTR1D|HTR2A|HTR2B|HTR2C|HTR5A | neurology/psychiatry | Parkinson's Disease |
|  1 | (

In [13]:
# Keep only DRH rows that list an indication
drug_with_indication = rep_drugs_df[rep_drugs_df['indication'].notna()]

# Lower-cased name lookups so the membership tests are case-insensitive
rare_disease_names = matched_rarediseases['name'].str.lower()
drkg_entity_names = drkg_entity_df['name'].str.lower()

# DRH rows whose indication is a matched DRKG rare disease AND whose drug name is a DRKG entity
indication_is_rare_disease = drug_with_indication['indication'].str.lower().isin(rare_disease_names)
drug_is_in_drkg = drug_with_indication['pert_iname'].str.lower().isin(drkg_entity_names)
drh_in_drkg = drug_with_indication[indication_is_rare_disease & drug_is_in_drkg]

In [19]:
# Goal: DRH drug -> rare-disease indications that have no corresponding (drug, disease)
# triple in DRKG. Done as two left-merge / keep-the-unmatched passes.

# Columns contributed by the DRKG side of the merge plus the merge-key columns pandas adds
_MERGE_ARTIFACT_COLUMNS = ['head_entity', 'relationshipId', 'tail_entity', 'key_0', 'key_1']


def _left_merge_on_drug_disease(drh_df, drkg_df):
    """Left-merge DRH rows onto DRKG triples by lower-cased (drug, disease) name pair.

    DRH supplies ('pert_iname', 'indication'); DRKG supplies ('head_entity', 'tail_entity').
    DRH rows without a DRKG counterpart come back with NaN in the DRKG columns.
    """
    drh_keys = [drh_df['pert_iname'].str.lower(), drh_df['indication'].str.lower()]
    drkg_keys = [drkg_df['head_entity'].str.lower(), drkg_df['tail_entity'].str.lower()]
    return drh_df.merge(drkg_df, how='left', left_on=drh_keys, right_on=drkg_keys)


# Pass 1: compare against the rare-disease subset of DRKG
drh_matches = _left_merge_on_drug_disease(drh_in_drkg, drkg_raredisease)
print(drh_matches.shape)

# Keep the rows that found no DRKG triple, and strip the merge columns
pass1_has_no_match = drh_matches['head_entity'].isna()
drh_unmatched_try1 = drh_matches[pass1_has_no_match].drop(columns=_MERGE_ARTIFACT_COLUMNS)
print(drh_unmatched_try1.shape)

# Pass 2: rename 'cancer' to 'neoplasms' in the indication and retry,
# this time against the full translated DRKG
drh_unmatched_try1['indication'] = drh_unmatched_try1['indication'].str.replace('cancer', 'neoplasms')

drh_matches_try2 = _left_merge_on_drug_disease(drh_unmatched_try1, drkg_translated)
print(drh_matches_try2.shape)

# Again keep only the rows without a DRKG triple, then de-duplicate
pass2_has_no_match = drh_matches_try2['head_entity'].isna()
drh_unmatched = (
    drh_matches_try2[pass2_has_no_match]
    .drop(columns=_MERGE_ARTIFACT_COLUMNS)
    .drop_duplicates()
)
print(drh_unmatched.shape)


(114, 11)
(31, 6)
(35, 11)
(20, 6)


In [36]:
# Is prostate cancer among the matched rare diseases? Search the 'Name' column, case-insensitively.
name_mentions_prostate_cancer = matched_rarediseases['Name'].str.lower().str.contains('prostate cancer')
matched_rarediseases[name_mentions_prostate_cancer]

Unnamed: 0,drkg_id,drkg_dict_id,name,entity_type,ontology_code,ontology_name,code,matched,Orphacode,Name,code_source,Disordermappingrelation,Disordermappingicdrelation,Disordermappingvalidationstatus,code_source_upper,match_type,code_x,id,code_y,key_0
1187,Disease::DOID:10283,Disease::DOID:10283,prostate cancer,Disease,DOID:10283,DOID,,1,1331,Familial prostate cancer,MeSH,E (Exact mapping: the two concepts are equival...,,Validated,,DOID regex,10283,DOID:10283,C537243,


In [32]:
# Build the held-out test links: DRH (drug, rare disease) pairs that are missing from DRKG,
# annotated with the DRKG names and ids of both endpoints.

# Attach the DRKG disease name/id to each unmatched DRH row. The DRKG names get the same
# 'cancer' -> 'neoplasms' substitution that was applied to the DRH indications earlier.
drkg_test = drh_unmatched.merge(
    matched_rarediseases[['name', 'drkg_id']],
    how='left',
    left_on=[drh_unmatched['indication'].str.lower()],
    right_on=[matched_rarediseases['name'].str.lower().str.replace('cancer', 'neoplasms')],
)
drkg_test = (
    drkg_test
    .drop(columns=['key_0'])
    .drop_duplicates()
    .rename(columns={'name': 'rare_disease', 'drkg_id': 'disease_drkg_id'})
)

# Attach the DRKG compound name/id by case-insensitive drug name
drkg_test = drkg_test.merge(
    drkg_entity_df[['name', 'drkg_id']],
    how='left',
    left_on=[drkg_test['pert_iname'].str.lower()],
    right_on=[drkg_entity_df['name'].str.lower()],
)
drkg_test = drkg_test[['name', 'drkg_id', 'rare_disease', 'disease_drkg_id']].rename(
    columns={'name': 'compound', 'drkg_id': 'compound_drkg_id'}
)

# Keep only rows whose DRKG id contains 'Disease'. The id comes from a left merge and is
# NaN for rows that found no DRKG disease; na=False drops those rows instead of putting
# NaN into the boolean mask, which would make the indexing raise a ValueError.
drkg_test = drkg_test[drkg_test['disease_drkg_id'].str.contains('Disease', na=False)]
print(len(drkg_test))
drkg_test

19


Unnamed: 0,compound,compound_drkg_id,rare_disease,disease_drkg_id
0,Alpelisib,Compound::DB12015,breast cancer,Disease::DOID:1612
1,Ambenonium,Compound::DB01122,Myasthenia Gravis,Disease::MESH:D009157
4,Apalutamide,Compound::DB11901,prostate cancer,Disease::DOID:10283
5,Artesunate,Compound::DB09274,Malaria,Disease::MESH:D008288
6,Artesunate,Compound::DB09274,malaria,Disease::DOID:12365
8,Darolutamide,Compound::DB12941,prostate cancer,Disease::DOID:10283
9,Delamanid,Compound::DB11637,Tuberculosis,Disease::MESH:D014376
11,Didox,Compound::DB12948,breast cancer,Disease::DOID:1612
12,Enzalutamide,Compound::DB08899,prostate cancer,Disease::DOID:10283
15,Leucovorin,Compound::DB00650,Osteosarcoma,Disease::MESH:D012516


In [33]:
# Spot check: Compound:Disease triples in DRKG whose head is apalutamide.
# The filtered frame is not displayed because it is not the cell's last expression.
is_compound_disease = drkg_translated['relationshipId'].str.contains('Compound:Disease')
test_rd_drkg = drkg_translated[is_compound_disease]
test_rd_drkg[test_rd_drkg['head_entity'].str.lower().eq('apalutamide')]

# Count distinct (compound, rare disease) links, ignoring letter case
drkg_test_ct = (
    drkg_test
    .assign(
        compound=drkg_test['compound'].str.lower(),
        rare_disease=drkg_test['rare_disease'].str.lower(),
    )
    .drop_duplicates(subset=['compound', 'rare_disease'])
)
print(len(drkg_test_ct))
drkg_test_ct

17


Unnamed: 0,compound,compound_drkg_id,rare_disease,disease_drkg_id
0,alpelisib,Compound::DB12015,breast cancer,Disease::DOID:1612
1,ambenonium,Compound::DB01122,myasthenia gravis,Disease::MESH:D009157
4,apalutamide,Compound::DB11901,prostate cancer,Disease::DOID:10283
5,artesunate,Compound::DB09274,malaria,Disease::MESH:D008288
8,darolutamide,Compound::DB12941,prostate cancer,Disease::DOID:10283
9,delamanid,Compound::DB11637,tuberculosis,Disease::MESH:D014376
11,didox,Compound::DB12948,breast cancer,Disease::DOID:1612
12,enzalutamide,Compound::DB08899,prostate cancer,Disease::DOID:10283
15,leucovorin,Compound::DB00650,osteosarcoma,Disease::MESH:D012516
16,leuprolide,Compound::DB00007,prostate cancer,Disease::DOID:10283


In [34]:
# Write both versions of the test links (all rows, and case-insensitive distinct pairs)
# as CSV files in the current working directory
for links_df, csv_name in (
    (drkg_test, "test_raredisease_links.csv"),
    (drkg_test_ct, "test_raredisease_links_unique.csv"),
):
    links_df.to_csv(csv_name)