# PrimeKG Enrichment and Embedding

In this tutorial, we will explain how to perform multimodal enrichment and embedding of PrimeKG nodes.

We will consider the following node types:
1. Drugs (PubChem/DrugBank/CTD) - TEXT and SMILES
2. Proteins (NCBI/Gene) - TEXT and amino-acid sequence
3. Pathways (Reactome) - TEXT
4. Phenotypes (HPO) - TEXT
5. Protein function (GO) - TEXT
6. Disease (MONDO) - TEXT
7. Anatomy (UBERON) - TEXT

More information about PrimeKG can be found in the following repositories:
- https://github.com/mims-harvard/PrimeKG
- https://github.com/mims-harvard/TDC/

Note that we use the PrimeKG release hosted on Harvard Dataverse, which is publicly available at the following link:

https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/IXA7BM

At the time of writing, the latest version of PrimeKG (`kg.csv`) is `2.1`.

First, we import the necessary libraries:

In [1]:
# Import necessary libraries
import sys
import torch
import networkx as nx
from tqdm import tqdm
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.datasets.primekg import PrimeKG
from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.uniprot_proteins import EnrichmentWithUniProt
from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.ols_terms import EnrichmentWithOLS
from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.reactome_pathways import EnrichmentWithReactome
from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.pubchem_strings import EnrichmentWithPubChem
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.huggingface import EmbeddingWithHuggingFace
from aiagents4pharma.talk2knowledgegraphs.utils.pubchem_utils import external_id2pubchem_cid



### Check device availability

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

### Load the BiomedBERT model

In [3]:
# Use Microsoft's BiomedBERT to embed the textual descriptions of the nodes
biobert_model = EmbeddingWithHuggingFace(
    model_name='microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract',
    model_cache_dir="../../../../data/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract/",
    truncation=False,
    device=device,
)

### Load PrimeKG

If the data is not available locally, the `PrimeKG` class downloads it from the Harvard Dataverse server.

Otherwise, the data is loaded from the local directory specified by `local_dir`.
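The download-or-reuse behaviour can be sketched as a small caching helper. This is a minimal sketch of the pattern, not the actual `PrimeKG` implementation: `load_or_fetch` and its `fetch` callback are hypothetical, and a real loader would download the files from Dataverse instead of calling a local function.

```python
import os
import tempfile
from typing import Callable, List


def load_or_fetch(local_path: str, fetch: Callable[[], bytes]) -> bytes:
    """Return the file at local_path, downloading it first if it is missing.

    `fetch` is a hypothetical stand-in for the remote download step.
    """
    if not os.path.exists(local_path):
        # Create the target directory and cache the downloaded bytes
        os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
        with open(local_path, "wb") as fh:
            fh.write(fetch())
    # From here on the data is always read from the local copy
    with open(local_path, "rb") as fh:
        return fh.read()


# Usage: the second call finds the cached file and does not fetch again
fetch_calls: List[int] = []


def fake_fetch() -> bytes:
    fetch_calls.append(1)
    return b"node_index\tnode_name\n0\tPHYHIP\n"


path = os.path.join(tempfile.mkdtemp(), "primekg", "primekg_nodes.tsv")
first = load_or_fetch(path, fake_fetch)
second = load_or_fetch(path, fake_fetch)
print(len(fetch_calls))  # 1
```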

In [4]:
# Define primekg data by providing a local directory where the data is stored
primekg_data = PrimeKG(local_dir="../../../../data/primekg/")

To load the node and edge dataframes of PrimeKG, we call the `load_data()` method as follows.

In [5]:
# Invoke a method to load the data
primekg_data.load_data()

# Get primekg_nodes and primekg_edges
primekg_nodes = primekg_data.get_nodes()
primekg_edges = primekg_data.get_edges()

Loading nodes of PrimeKG dataset ...
../../../../data/primekg/primekg_nodes.tsv.gz already exists. Loading the data from the local directory.
Loading edges of PrimeKG dataset ...
../../../../data/primekg/primekg_edges.tsv.gz already exists. Loading the data from the local directory.


### Check PrimeKG Dataframes

`primekg_nodes` and `primekg_edges` are the dataframes of nodes and edges, respectively.

We can further analyze the dataframes to extract the information we need.

For instance, we can construct a graph from the nodes and edges dataframes using the networkx library.

#### PrimeKG Nodes

`primekg_nodes` is a dataframe of nodes, which has the following columns:
- `node_index`: the index of the node
- `node_name`: the name of the node
- `node_source`: the source database of the node (e.g., NCBI, GO, DrugBank)
- `node_id`: the identifier of the node in its source database
- `node_type`: the type of the node

We can inspect the first few rows of `primekg_nodes` as follows.

In [8]:
# Check a sample of the primekg nodes
primekg_nodes.head()

|   | node_index | node_name | node_source | node_id | node_type |
|---|---|---|---|---|---|
| 0 | 0 | PHYHIP | NCBI | 9796 | gene/protein |
| 1 | 1 | GPANK1 | NCBI | 7918 | gene/protein |
| 2 | 2 | ZRSR2 | NCBI | 8233 | gene/protein |
| 3 | 3 | NRF1 | NCBI | 4899 | gene/protein |
| 4 | 4 | PI4KA | NCBI | 5297 | gene/protein |


The current version of PrimeKG has about 130K nodes in total, as the following cell shows.

In [9]:
# Check dimensions of the primekg nodes
primekg_nodes.shape

(129375, 5)

We can break down the node counts by type as follows.

In [10]:
# Show node types and their counts
primekg_nodes['node_type'].value_counts()

node_type
biological_process    28642
gene/protein          27671
disease               17080
effect/phenotype      15311
anatomy               14035
molecular_function    11169
drug                   7957
cellular_component     4176
pathway                2516
exposure                818
Name: count, dtype: int64

PrimeKG integrates multiple sources, as shown by the counts of its unique node sources.

In [11]:
# Show source of the primekg nodes
primekg_nodes['node_source'].value_counts()

node_source
GO               43987
NCBI             27671
MONDO            15813
HPO              15311
UBERON           14035
DrugBank          7957
REACTOME          2516
MONDO_grouped     1267
CTD                818
Name: count, dtype: int64
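The two breakdowns are consistent with each other: the GO source covers the three GO node types, and the disease nodes come from MONDO plus MONDO_grouped. A quick arithmetic check on the counts printed above:

```python
# Counts copied from the two value_counts() outputs above
type_counts = {
    "biological_process": 28642, "gene/protein": 27671, "disease": 17080,
    "effect/phenotype": 15311, "anatomy": 14035, "molecular_function": 11169,
    "drug": 7957, "cellular_component": 4176, "pathway": 2516, "exposure": 818,
}
source_counts = {
    "GO": 43987, "NCBI": 27671, "MONDO": 15813, "HPO": 15311, "UBERON": 14035,
    "DrugBank": 7957, "REACTOME": 2516, "MONDO_grouped": 1267, "CTD": 818,
}

# GO terms are split into three node types
go_total = (type_counts["biological_process"]
            + type_counts["molecular_function"]
            + type_counts["cellular_component"])
# Disease nodes come from individual and grouped MONDO terms
disease_total = source_counts["MONDO"] + source_counts["MONDO_grouped"]
print(go_total, disease_total)  # 43987 17080

# Both breakdowns add up to the total number of nodes
print(sum(type_counts.values()), sum(source_counts.values()))  # 129375 129375
```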

In [12]:
# Inspect the CTD nodes (all of type 'exposure')
primekg_nodes[primekg_nodes['node_source'] == 'CTD']

|   | node_index | node_name | node_source | node_id | node_type |
|---|---|---|---|---|---|
| 61677 | 61677 | 1-hydroxyphenanthrene | CTD | C092102 | exposure |
| 61678 | 61678 | 1-hydroxypyrene | CTD | C033146 | exposure |
| 61679 | 61679 | 1-naphthol | CTD | C029350 | exposure |
| 61680 | 61680 | 2,2',3',4,4',5-hexachlorobiphenyl | CTD | C029790 | exposure |
| 61681 | 61681 | 2,2',3,5,5',6-hexachlorobiphenyl | CTD | C066675 | exposure |
| ... | ... | ... | ... | ... | ... |
| 127593 | 127593 | Heptanes | CTD | D006536 | exposure |
| 127594 | 127594 | octane | CTD | C026728 | exposure |
| 127595 | 127595 | pseudocumene | CTD | C010313 | exposure |
| 127596 | 127596 | pentane | CTD | C033353 | exposure |


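As a quick check, we enrich a single PubChem CID (24667). `enrich_documents` returns a tuple of two lists: the descriptions and the SMILES strings.

In [13]: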
test = EnrichmentWithPubChem()

In [14]:
test.enrich_documents(['24667'])

INFO:aiagents4pharma.talk2knowledgegraphs.utils.pubchem_utils:Load Hydra configuration for PubChem CID description.


(["Butylated Hydroxyanisole can cause cancer according to The World Health Organization's International Agency for Research on Cancer (IARC)."],
 [None])

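Node names are not unique in PrimeKG: the same name can appear in several sources, for instance as a DrugBank drug and as a CTD exposure. The following cell selects all rows whose name occurs more than once and shows a few examples.

In [107]:
# Keep every row whose node_name occurs more than once
dup = primekg_nodes[primekg_nodes.node_name.duplicated(keep=False)].sort_values('node_name')
dup[dup.node_name.isin(['rRNA processing', 'Horner syndrome', 'parallel fiber', 'Acetaminophen'])]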

|   | node_index | node_name | node_source | node_id | node_type |
|---|---|---|---|---|---|
| 61867 | 61867 | Acetaminophen | CTD | D000082 | exposure |
| 14154 | 14154 | Acetaminophen | DrugBank | DB00316 | drug |
| 83876 | 83876 | Horner syndrome | MONDO | 1294 | disease |
| 33753 | 33753 | Horner syndrome | HPO | 2277 | effect/phenotype |
| 75256 | 75256 | parallel fiber | UBERON | 2002218 | anatomy |
| 125914 | 125914 | parallel fiber | GO | 1990032 | cellular_component |
| 62887 | 62887 | rRNA processing | REACTOME | R-HSA-72312 | pathway |
| 45528 | 45528 | rRNA processing | GO | 6364 | biological_process |
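The selection above keeps every row whose `node_name` occurs more than once, which is what pandas expresses with `duplicated(keep=False)`. The same idea in plain Python, on a few rows from the output above plus one unique name for contrast:

```python
from collections import Counter

# (node_name, node_source, node_type) rows taken from the output above,
# plus one unique name for contrast
rows = [
    ("Acetaminophen", "CTD", "exposure"),
    ("Acetaminophen", "DrugBank", "drug"),
    ("Horner syndrome", "MONDO", "disease"),
    ("Horner syndrome", "HPO", "effect/phenotype"),
    ("PHYHIP", "NCBI", "gene/protein"),
]

# Count how often each name occurs
name_counts = Counter(name for name, _, _ in rows)
# Keep all rows whose name is shared with at least one other row
duplicated_rows = sorted(row for row in rows if name_counts[row[0]] > 1)
for row in duplicated_rows:
    print(*row)
```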


In [None]:
# Print every duplicated name together with its source and type
for _, row in dup.iterrows():
    print(row.node_name, row.node_source, row.node_type)

4-hydroxyphenylacetic aciduria HPO effect/phenotype
4-hydroxyphenylacetic aciduria MONDO disease
Acetaminophen CTD exposure
Acetaminophen DrugBank drug
Ammonia DrugBank drug
Ammonia CTD exposure
Amoxicillin CTD exposure
Amoxicillin DrugBank drug
Androstenedione CTD exposure
Androstenedione DrugBank drug
Attachment and Entry REACTOME pathway
Attachment and Entry REACTOME pathway
Axenfeld anomaly MONDO disease
Axenfeld anomaly HPO effect/phenotype
Barrett esophagus HPO effect/phenotype
Barrett esophagus MONDO disease
Budd-Chiari syndrome MONDO disease
Budd-Chiari syndrome HPO effect/phenotype
Burkitt lymphoma HPO effect/phenotype
Burkitt lymphoma MONDO disease
Caffeine DrugBank drug
Caffeine CTD exposure
Calcifediol DrugBank drug
Calcifediol CTD exposure
Calcium CTD exposure
Calcium DrugBank drug
Chloroform CTD exposure
Chloroform DrugBank drug
Chromium CTD exposure
Chromium DrugBank drug
Ciprofloxacin DrugBank drug
Ciprofloxacin CTD exposure
Clarithromycin DrugBank drug
Clarithromycin C

In [93]:

# Iterating over a DataFrame yields its column names
for column in dup:
    print(column)

node_index
node_name
node_source
node_id
node_type


### Create a directed graph using the edges

In [68]:
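# Build a directed graph from the edge list, keyed by node name.
# Note: `edge_key` only takes effect when `create_using` is a multigraph;
# with a plain DiGraph, parallel edges between the same two nodes are merged.
G = nx.from_pandas_edgelist(
    primekg_edges,
    source="head_name",
    target="tail_name",
    edge_key="relation",
    # edge_attr=["edge_id", "edge_type", "feature_value", "feature_id"],
    create_using=nx.DiGraph(),
)
# `kg` is the working graph that we keep enriching and recomposing below
kg = nx.compose(G, nx.DiGraph())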

In [70]:
primekg_nodes.shape, kg.number_of_nodes()

((129375, 5), 129262)

In [69]:
primekg_edges.shape, kg.number_of_edges()

((8100498, 12), 8098805)
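Both counts are lower than the dataframe shapes. This is most likely because the graph is keyed by `node_name`: rows that share a name (such as the duplicates listed earlier) collapse into a single node, and a plain `nx.DiGraph` stores at most one edge per ordered pair of nodes, so edges between the same two names (for example with different relations) are merged. Nodes that never appear in an edge are also absent, since the graph is built from the edge list only. The stdlib sketch below reproduces the collapsing effect; the node rows come from the duplicate-name output, while the edges are hypothetical:

```python
# Node rows from the duplicate-name output: (node_name, node_source, node_id)
nodes = [
    ("Acetaminophen", "DrugBank", "DB00316"),
    ("Acetaminophen", "CTD", "D000082"),
    ("rRNA processing", "REACTOME", "R-HSA-72312"),
    ("rRNA processing", "GO", "6364"),
    ("PHYHIP", "NCBI", "9796"),
]

# Keyed by name, as in the graph above: later rows overwrite earlier ones
by_name = {}
for name, source, node_id in nodes:
    by_name[name] = {"node_source": source, "node_id": node_id}

# Keyed by (source, id): every row remains a distinct node
by_uid = {(source, node_id): name for name, source, node_id in nodes}

print(len(nodes), len(by_name), len(by_uid))  # 5 3 5
print(by_name["Acetaminophen"]["node_source"])  # CTD

# A simple directed graph keeps one edge per ordered (head, tail) pair;
# these edges and relation labels are hypothetical
edges = [("A", "B", "relation_1"), ("A", "B", "relation_2"), ("B", "A", "relation_1")]
edge_pairs = {(head, tail) for head, tail, _ in edges}
print(len(edges), len(edge_pairs))  # 3 2
```

Keying nodes by `(node_source, node_id)` would avoid the collision; the notebook keeps names as keys, which makes the graph easier to read.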

### Add additional node attributes (e.g. source, id and type)

In [None]:
# Start by extracting node information
df_nodes = primekg_nodes[['node_name', 'node_source', 'node_id', 'node_type']]
# Set the node_name as index (for duplicate names, later rows overwrite earlier ones in the graph)
df_nodes = df_nodes.set_index('node_name')
# Add the additional attributes to graph
G.add_nodes_from((n, dict(d)) for n, d in df_nodes.iterrows())
# Recompose the graph
kg = nx.compose(G, kg)

In [None]:
df_nodes.shape

(129375, 3)
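Every enrichment step in this notebook follows the same pattern: add attributes to `G`, then call `kg = nx.compose(G, kg)`. According to the NetworkX documentation, attributes from the second argument of `compose` take precedence, so a value already stored in `kg` wins over a new value placed in `G` for the same key. The stdlib sketch below mimics that merge rule for node attributes only; `compose_node_attrs` is a hypothetical simplification, since NetworkX also merges edges and graph attributes.

```python
from typing import Dict

NodeAttrs = Dict[str, Dict[str, str]]


def compose_node_attrs(g: NodeAttrs, h: NodeAttrs) -> NodeAttrs:
    """Merge two {node: attributes} mappings; attributes from h override g."""
    result: NodeAttrs = {}
    for graph in (g, h):
        for node, attrs in graph.items():
            # Update instead of replace, so keys present in only one graph survive
            result.setdefault(node, {}).update(attrs)
    return result


G = {"DDT": {"node_source": "CTD", "description": "new text"}}
kg = {"DDT": {"node_source": "CTD"}, "PHYHIP": {"node_source": "NCBI"}}
kg = compose_node_attrs(G, kg)
print(kg["DDT"]["description"])  # new text (kg had no description yet)

G["DDT"]["description"] = "updated text"
kg = compose_node_attrs(G, kg)
print(kg["DDT"]["description"])  # new text (the value already in kg wins)
```

In practice this means the pattern works for adding new attributes, but re-running an enrichment with different results will not overwrite values that are already in `kg`.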

# CTD enrichment
We will map CTD IDs to their corresponding PubChem CIDs, and extract their descriptions and SMILES representations using the `EnrichmentWithPubChem` class.
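The CTD enrichment is a two-stage lookup in which some identifiers may not map to a PubChem CID. A minimal sketch of that control flow, with hypothetical in-memory tables and placeholder IDs standing in for the PubChem services (the real cells call `external_id2pubchem_cid` and `EnrichmentWithPubChem.enrich_documents`):

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical lookup tables standing in for the remote PubChem services.
# The IDs and text are placeholders, not real CTD/PubChem records.
CTD_TO_CID: Dict[str, Optional[str]] = {"CTD_EXAMPLE_1": "1001", "CTD_EXAMPLE_2": None}
CID_TO_DESCRIPTION: Dict[str, str] = {"1001": "placeholder description"}


def enrich_ctd_ids(ctd_ids: List[str]) -> List[Tuple[str, Optional[str], Optional[str]]]:
    """Map CTD IDs to PubChem CIDs, then look up descriptions for mapped CIDs only."""
    results = []
    for ctd_id in ctd_ids:
        # Unknown IDs and IDs without a mapping both yield None
        cid = CTD_TO_CID.get(ctd_id)
        # Skip the second lookup entirely when there is no CID
        description = CID_TO_DESCRIPTION.get(cid) if cid else None
        results.append((ctd_id, cid, description))
    return results


for row in enrich_ctd_ids(["CTD_EXAMPLE_1", "CTD_EXAMPLE_2", "CTD_UNKNOWN"]):
    print(row)
```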

In [17]:
from dataclasses import dataclass
from typing import Optional

# Create a dataclass to hold the node attributes
@dataclass
class PubChemAttr:
    """Dataclass to hold the PubChem attributes of a CTD node."""
    pubchem_cid: Optional[str]
    name: str
    # Optional fields, filled in during enrichment (None until then)
    description: Optional[str] = None
    smiles: Optional[str] = None


## Iterate over every CTD node and fetch its description and SMILES representation

In [None]:
list_pubchem_attrs = []
pubchem_obj = EnrichmentWithPubChem()
# Collect all CTD nodes and map their CTD IDs to PubChem CIDs
for node in tqdm(kg.nodes):
    node_attr = kg.nodes[node]
    if node_attr.get('node_source') != 'CTD':
        continue
    # Convert the CTD ID into a PubChem CID (may be None if no mapping exists)
    pubchem_cid = external_id2pubchem_cid('Comparative Toxicogenomics Database', node_attr.get('node_id'))
    # Create a PubChemAttr object; description and SMILES are fetched in a later cell
    pubchem_attr = PubChemAttr(
        pubchem_cid=pubchem_cid,
        name=node
    )
    list_pubchem_attrs.append(pubchem_attr)


ValueError: identifier/cid cannot be None

In [None]:
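import pandas as pd
from dataclasses import asdict

# Save a preview (first five rows) of the CTD-to-PubChem mapping to a CSV file
list_pubchem_attrs_df = pd.DataFrame([asdict(attr) for attr in list_pubchem_attrs]).head()
list_pubchem_attrs_df.to_csv('../../../../data/primekg/pubchem_attrs.csv', index=False)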

In [50]:
# Enrich PubChem attr
for pubchem_attr in list_pubchem_attrs:
    if pubchem_attr.pubchem_cid is None or pubchem_attr.pubchem_cid == '':
        # If the PubChem CID is not available, skip the enrichment
        continue
    # Fetch descriptions and SMILES representation
    description, smiles = pubchem_obj.enrich_documents([pubchem_attr.pubchem_cid])
    # Store the description and SMILES string in the corresponding PubChem attributes
    pubchem_attr.description = description[0]
    pubchem_attr.smiles = smiles[0]


## Add descriptions to the CTD nodes and recompose the graph

In [51]:
for pubchem_attr in list_pubchem_attrs:
    node = pubchem_attr.name
    description = pubchem_attr.description
    # print (f"node: {node}, description: {description}")
    G.add_nodes_from([(node, {'description': description})])

# Recompose the graph
kg = nx.compose(G, kg)

In [52]:
list_pubchem_attrs

[PubChemAttr(pubchem_cid=3036, name='DDT', description="DDT (Dichlorodiphenyltrichloroethane) can cause cancer according to an independent committee of scientific and health experts. It can cause developmental toxicity, female reproductive toxicity and male reproductive toxicity according to The National Institute for Occupational Safety and Health (NIOSH) and The Environmental Protection Agency (EPA).DDT is a chlorophenylethane that is 1,1,1-trichloro-2,2-diphenylethane substituted by additional chloro substituents at positions 4 of the phenyl substituents. It is a commonly used organochlorine insecticide. It has a role as a bridged diphenyl acaricide, a carcinogenic agent, a persistent organic pollutant and an endocrine disruptor. It is an organochlorine insecticide, a benzenoid aromatic compound, a member of monochlorobenzenes and a chlorophenylethane. It is functionally related to a 1,1,1-trichloro-2,2-diphenylethane and a 4,4'-dichlorodiphenylmethane.", smiles=None),
 PubChemAttr(

## SMILES enrichment and embedding over PrimeKG using NIM/MOLMIM

### Dataclass to store drug data

In [53]:

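from dataclasses import dataclass
from typing import List, Optional

@dataclass
class DrugData:
    """Dataclass to hold the attributes of a drug node."""
    name: str
    drugbank_id: str
    pubchem_cid: Optional[str] = None
    smiles: Optional[str] = None
    embed_smiles: Optional[List[float]] = None

# Dictionary mapping DrugBank IDs to DrugData objects
dic_drug_data = {}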

Before loading the drug data into `dic_drug_data`, let us inspect the drug (DrugBank) nodes in PrimeKG.

In [58]:
primekg_nodes[primekg_nodes.node_source == 'DrugBank']

|   | node_index | node_name | node_source | node_id | node_type |
|---|---|---|---|---|---|
| 14012 | 14012 | Copper | DrugBank | DB09130 | drug |
| 14013 | 14013 | Oxygen | DrugBank | DB09140 | drug |
| 14014 | 14014 | Flunisolide | DrugBank | DB00180 | drug |
| 14015 | 14015 | Alclometasone | DrugBank | DB00240 | drug |
| 14016 | 14016 | Medrysone | DrugBank | DB00253 | drug |
| ... | ... | ... | ... | ... | ... |
| 39893 | 39893 | Cathine | DrugBank | DB01486 | drug |
| 39894 | 39894 | Sulfur hexafluoride | DrugBank | DB11104 | drug |
| 39895 | 39895 | Butoconazole | DrugBank | DB00639 | drug |
| 39896 | 39896 | Gadoversetamide | DrugBank | DB00538 | drug |


In [54]:
count = 0
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_source') != 'DrugBank':
        continue
    count += 1

100%|██████████| 129262/129262 [00:00<00:00, 1713096.24it/s]


In [55]:
count

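7894

This is slightly fewer than the 7957 DrugBank rows in `primekg_nodes`. A likely reason is that the graph is keyed by node name: when a drug shares its name with a node from another source (e.g. Acetaminophen in both DrugBank and CTD), only one set of attributes survives in `kg`. The next cell therefore reads the drugs directly from `primekg_nodes`.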

In [None]:
# Select the DrugBank nodes directly from the dataframe
drugbank_nodes = primekg_nodes[primekg_nodes['node_source'] == 'DrugBank']
for _, row in tqdm(drugbank_nodes.iterrows(), total=drugbank_nodes.shape[0]):
    # Skip monoclonal antibodies (names ending in "mab"), which have no small-molecule SMILES
    if row['node_name'].endswith('mab'):
        continue
    drugbank_id = row['node_id']
    dic_drug_data[drugbank_id] = DrugData(name=row['node_name'], drugbank_id=drugbank_id)
# Print the number of drug names mapped to DrugBank IDs
print(f"Number of drugs mapped to DrugBank IDs: {len(dic_drug_data)}")
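The `endswith('mab')` check relies on the naming convention for monoclonal antibodies, which are biologics without a small-molecule SMILES string. It is a heuristic: it can miss antibodies named differently, and it would also drop any other drug whose name happens to end in "mab". A minimal standalone version of the same filter:

```python
from typing import List


def is_monoclonal_antibody(drug_name: str) -> bool:
    """Heuristic used above: monoclonal antibody names end in 'mab'."""
    return drug_name.endswith("mab")


def small_molecule_candidates(names: List[str]) -> List[str]:
    """Keep the names that the heuristic does not flag as antibodies."""
    return [name for name in names if not is_monoclonal_antibody(name)]


print(small_molecule_candidates(["Flunisolide", "Trastuzumab", "Rituximab", "Medrysone"]))
# ['Flunisolide', 'Medrysone']
```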

### SMILES enrichment over drug data using PubChem

Since the drugs in PrimeKG are identified by DrugBank IDs, we first convert each DrugBank ID into its corresponding PubChem CID, and then use the CID to retrieve the SMILES string.

**NOTE**: The loop below processes every drug. To test it on the first few drugs only, uncomment the `if count == 2: break` lines.
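As the earlier call `test.enrich_documents(['24667'])` shows, `enrich_documents` returns a pair of lists, `(descriptions, smiles)`, rather than the SMILES strings alone. The sketch below uses a hypothetical stand-in class with that same return shape (the CID keys and SMILES are placeholders) to show the unpacking and the handling of missing values:

```python
from typing import Dict, List, Optional, Tuple


class FakePubChemEnricher:
    """Hypothetical stand-in that mimics the (descriptions, smiles) return shape."""

    def __init__(self, records: Dict[str, Tuple[str, str]]):
        self.records = records

    def enrich_documents(self, cids: List[str]) -> Tuple[List[Optional[str]], List[Optional[str]]]:
        # One entry per requested CID; unknown CIDs give None in both lists
        descriptions = [self.records[c][0] if c in self.records else None for c in cids]
        smiles = [self.records[c][1] if c in self.records else None for c in cids]
        return descriptions, smiles


enricher = FakePubChemEnricher({"CID_A": ("placeholder description", "CCO")})
# Unpack the pair instead of treating the whole return value as SMILES
descriptions, smiles = enricher.enrich_documents(["CID_A", "CID_MISSING"])
print(smiles)  # ['CCO', None]
```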

In [None]:
from aiagents4pharma.talk2knowledgegraphs.utils import pubchem_utils

enrichment = EnrichmentWithPubChem()

# Get their SMILES strings
for count, (drugbank_id, drug_data) in enumerate(dic_drug_data.items()):
    # Get the PubChem CID from the DrugBank ID using the pubchem_utils helper
    drug_data.pubchem_cid = pubchem_utils.drugbank_id2pubchem_cid(drugbank_id)
    print(f"DrugBank ID: {drugbank_id}, PubChem CID: {drug_data.pubchem_cid}")
    # Get the SMILES string from the PubChem CID
    if drug_data.pubchem_cid:
        # enrich_documents returns a (descriptions, smiles) tuple of lists
        _, smiles = enrichment.enrich_documents([drug_data.pubchem_cid])
        drug_data.smiles = smiles[0]
        print(f"DrugBank ID: {drugbank_id}, SMILES: {smiles[0]}")
    # Uncomment to process only the first few drugs
    # if count == 2:
    #     break

### Embedding SMILES strings using NVIDIA's optimized MOLMIM

We will use the `EmbeddingWithMOLMIM` class to get the embeddings.
This class requires a `base_url` at initialization, so NIM/MOLMIM must be running locally or on a remote machine.
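For orientation, a raw request to a MolMIM NIM embedding endpoint can be sketched with the standard library. This assumes that the service accepts a JSON body of the form `{"sequences": [...]}` at the `base_url` used in the next cell and answers with a JSON object containing an `embeddings` field; check the NIM documentation for your deployment before relying on it. The sketch only builds the request, so it runs without a server:

```python
import json
import urllib.request
from typing import List


def build_molmim_request(base_url: str, smiles: List[str]) -> urllib.request.Request:
    """Build (but do not send) a POST request asking for SMILES embeddings.

    Assumption: the endpoint expects {"sequences": [...]} as its JSON body.
    """
    body = json.dumps({"sequences": smiles}).encode("utf-8")
    return urllib.request.Request(
        base_url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_molmim_request("http://localhost:8000/embedding", ["CCO", "c1ccccc1"])
print(req.get_method(), req.full_url)  # POST http://localhost:8000/embedding
# With a running service (assumed response shape):
# with urllib.request.urlopen(req) as resp:
#     embeddings = json.loads(resp.read())["embeddings"]
```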

In [None]:
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings import nim_molmim

# Collect the drugs that have a SMILES string, in a fixed (sorted) order
drug_ids_with_smiles = [d for d in sorted(dic_drug_data) if dic_drug_data[d].smiles]
smiles_list = [dic_drug_data[d].smiles for d in drug_ids_with_smiles]

# Embed the SMILES strings
# Define the base URL for the embedding service
base_url = "http://localhost:8000/embedding"
embedding = nim_molmim.EmbeddingWithMOLMIM(base_url=base_url)
smiles_embedding = embedding.embed_documents(smiles_list)

# Store each embedding in the DrugData object it was computed for
for drugbank_id, embed in zip(drug_ids_with_smiles, smiles_embedding):
    dic_drug_data[drugbank_id].embed_smiles = embed
    print(f"DrugBank ID: {drugbank_id}, SMILES: {dic_drug_data[drugbank_id].smiles}, Embedding: {embed}")

## Generate embeddings of the CTD descriptions

In [None]:
# Embed every node that has a textual description
for node in tqdm(kg.nodes):
    desc = kg.nodes[node].get('description')
    if desc is None:
        continue
    outputs = biobert_model.embed_documents([desc])
    G.add_nodes_from([(node, {'description_embedding': outputs})])
    # Optionally free GPU memory between calls:
    # torch.cuda.synchronize()
    # torch.cuda.empty_cache()

# Recompose the graph
kg = nx.compose(G, kg)
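The embedding loops in this notebook send one description per call. Since `embed_documents` takes a list, batching several descriptions per call may reduce overhead, assuming the model handles the batch size in memory. A generic sketch, where `embed_fn` is a hypothetical stand-in for `biobert_model.embed_documents`:

```python
from typing import Callable, Dict, List, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> List[Sequence[T]]:
    """Split items into consecutive chunks of at most batch_size elements."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def embed_in_batches(texts: Dict[str, str],
                     embed_fn: Callable[[List[str]], List[List[float]]],
                     batch_size: int = 32) -> Dict[str, List[float]]:
    """Embed {node: description} pairs in batches and return {node: vector}."""
    nodes = list(texts)
    embeddings: Dict[str, List[float]] = {}
    for chunk in batched(nodes, batch_size):
        vectors = embed_fn([texts[node] for node in chunk])
        # zip keeps every vector aligned with the node it was computed for
        embeddings.update(zip(chunk, vectors))
    return embeddings


batch_sizes_seen: List[int] = []


def fake_embed(docs: List[str]) -> List[List[float]]:
    """Hypothetical embedder: records the batch size, returns [len(doc)]."""
    batch_sizes_seen.append(len(docs))
    return [[float(len(doc))] for doc in docs]


result = embed_in_batches({"DDT": "abc", "Ammonia": "de", "Caffeine": "f"}, fake_embed, batch_size=2)
print(result)            # {'DDT': [3.0], 'Ammonia': [2.0], 'Caffeine': [1.0]}
print(batch_sizes_seen)  # [2, 1]
```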

## Display the results in a DF

In [None]:
import pandas as pd

# Collect all nodes that have a description into a dataframe
dic = {'node': [],
       'node_source': [],
       'node_id': [],
       'description': [],
       'description_embedding': []}
for node in tqdm(kg.nodes):
    node_attr = kg.nodes[node]
    if node_attr.get('description') is None:
        continue
    dic['node'].append(node)
    dic['node_source'].append(node_attr.get('node_source'))
    dic['node_id'].append(node_attr.get('node_id'))
    dic['description'].append(node_attr.get('description'))
    dic['description_embedding'].append(node_attr.get('description_embedding'))

df = pd.DataFrame(dic)
df

# Reactome pathway enrichment
We will use the Reactome API to extract textual descriptions of pathways using the `EnrichmentWithReactome` class.
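For orientation, the Reactome ContentService exposes pathway records by stable identifier. The sketch below assumes the endpoint `https://reactome.org/ContentService/data/query/{id}` and a response containing a `summation` list of objects with a `text` field; `EnrichmentWithReactome` may use a different endpoint internally. The parsing is shown on a hand-written sample, so no network access is needed:

```python
from typing import Any, Dict, Optional

# Assumed ContentService endpoint for a single database object
REACTOME_QUERY_URL = "https://reactome.org/ContentService/data/query/{}"


def reactome_query_url(stable_id: str) -> str:
    """Return the (assumed) ContentService URL for a Reactome stable ID."""
    return REACTOME_QUERY_URL.format(stable_id)


def summation_text(record: Dict[str, Any]) -> Optional[str]:
    """Join the 'text' fields of the (assumed) 'summation' list, or return None."""
    texts = [s.get("text", "") for s in record.get("summation", [])]
    return " ".join(t for t in texts if t) or None


# Hand-written sample with the assumed shape; the text is a placeholder
sample = {"stId": "R-HSA-72312", "summation": [{"text": "Placeholder summary."}]}
print(reactome_query_url("R-HSA-72312"))
print(summation_text(sample))  # Placeholder summary.
```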

In [None]:
from dataclasses import dataclass
from typing import Optional

# Create a dataclass to hold the node attributes
@dataclass
class ReactomeAttr:
    """Dataclass to hold the attributes of a node."""
    pathway_id: str
    name: str
    # Optional field (None if not provided)
    description: Optional[str] = None


## Iterate over every pathway and fetch its description

In [None]:
list_reactome_attrs = []
reactome_obj = EnrichmentWithReactome()
# Collect all Reactome pathway nodes from the graph
for node in tqdm(kg.nodes):
    node_attr = kg.nodes[node]
    if node_attr.get('node_source') != 'REACTOME':
        continue
    # Create a ReactomeAttr object
    reactome_attr = ReactomeAttr(
        pathway_id=node_attr.get('node_id'),
        name=node,
        description=node_attr.get('description')
    )
    list_reactome_attrs.append(reactome_attr)
# Fetch the description of each pathway
for reactome_attr in list_reactome_attrs:
    description = reactome_obj.enrich_documents([reactome_attr.pathway_id])
    # Add the description to the corresponding Reactome attributes
    reactome_attr.description = description[0]
print(list_reactome_attrs)

## Add descriptions to the Reactome nodes and recompose the graph

In [None]:
for reactome_attr in list_reactome_attrs:
    node = reactome_attr.name
    description = reactome_attr.description
    # print (f"node: {node}, description: {description}")
    G.add_nodes_from([(node, {'description': description})])

# Recompose the graph
kg = nx.compose(G, kg)

## Generate embeddings of the Reactome pathway descriptions

In [None]:
# Embed every node that has a textual description
for node in tqdm(kg.nodes):
    desc = kg.nodes[node].get('description')
    if desc is None:
        continue
    outputs = biobert_model.embed_documents([desc])
    G.add_nodes_from([(node, {'description_embedding': outputs})])
    # Optionally free GPU memory between calls:
    # torch.cuda.synchronize()
    # torch.cuda.empty_cache()

# Recompose the graph
kg = nx.compose(G, kg)

## Display the results in a DF

In [None]:
import pandas as pd

# Collect all nodes that have a description into a dataframe
dic = {'node': [],
       'node_source': [],
       'node_id': [],
       'description': [],
       'description_embedding': []}
for node in tqdm(kg.nodes):
    node_attr = kg.nodes[node]
    if node_attr.get('description') is None:
        continue
    dic['node'].append(node)
    dic['node_source'].append(node_attr.get('node_source'))
    dic['node_id'].append(node_attr.get('node_id'))
    dic['description'].append(node_attr.get('description'))
    dic['description_embedding'].append(node_attr.get('description_embedding'))

df = pd.DataFrame(dic)
df

# OLS term enrichment

OLS is the Ontology Lookup Service provided by EMBL-EBI. We will use its API to extract textual descriptions of the terms from the following sources using the `EnrichmentWithOLS` class:
1. GO
2. HPO
3. UBERON
4. MONDO
5. MONDO_grouped

In [None]:
from dataclasses import dataclass
from typing import Optional

# Create a dataclass to hold the node attributes
@dataclass
class OLSAttr:
    """Dataclass to hold the attributes of a node."""
    term_id: str
    name: str
    label: Optional[str] = None
    # Optional field (None if not provided)
    description: Optional[str] = None


In [None]:
# Define a dictionary to store DB name and its OLS code
dic_ols = {
    'GO': 'GO',
    'HPO': 'HP',
    'UBERON': 'UBERON',
    'MONDO': 'MONDO',
}
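The OLS cells convert PrimeKG IDs into OLS term IDs of the form `<prefix>_<7-digit code>`, left-padding the code with zeros. A standalone version of that formatting (the helper name `to_ols_term_id` is ours), checked against IDs from the duplicate-name output earlier:

```python
dic_ols = {"GO": "GO", "HPO": "HP", "UBERON": "UBERON", "MONDO": "MONDO"}


def to_ols_term_id(source: str, node_id: str) -> str:
    """Format a PrimeKG node_id as an OLS term ID, e.g. ('HPO', '2277') -> 'HP_0002277'."""
    return f"{dic_ols[source]}_{int(node_id):07d}"


print(to_ols_term_id("GO", "6364"))         # GO_0006364
print(to_ols_term_id("HPO", "2277"))        # HP_0002277
print(to_ols_term_id("UBERON", "2002218"))  # UBERON_2002218
```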

## Iterate over each OLS source and store the results

In [None]:
list_ols_attrs = []
term_ids = []
ols_obj = EnrichmentWithOLS()
for source in ['GO', 'MONDO', 'HPO', 'UBERON']:
    count = 0
    for node in tqdm(kg.nodes):
        if kg.nodes[node].get('node_source') != source:
            continue
        count += 1
        # Get the node attributes
        node_attr = kg.nodes[node]
        # OLS term IDs contain a 7-digit integer code, so left-pad
        # the PrimeKG ID with zeros (e.g. HP_0002277)
        term_id = f"{dic_ols[source]}_{int(node_attr.get('node_id')):07d}"
        term_ids.append(term_id)
        # Create an OLSAttr object
        ols_attr = OLSAttr(
            term_id=term_id,
            name=node,
            label=node,
            description=node_attr.get('description')
        )
        list_ols_attrs.append(ols_attr)
        # For the sake of space and time, enrich only the first two nodes
        # of each source; remove these lines to process all of them
        if count == 2:
            break
# Fetch descriptions for all collected term IDs
descriptions = ols_obj.enrich_documents(term_ids)
# Add the descriptions to the corresponding OLS attributes
for ols_attr, description in zip(list_ols_attrs, descriptions):
    ols_attr.description = description

## Repeat the same for MONDO_grouped, except that the descriptions of all IDs in a group are concatenated
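A `MONDO_grouped` node ID packs several MONDO codes joined by `_`. The sketch below splits such an ID and joins the per-term descriptions with newlines, as the next cell does; `fake_describe` is a hypothetical stand-in for `EnrichmentWithOLS.enrich_documents`, and the grouped ID is made up for illustration. Unlike the cell, it skips terms without a description, since joining a `None` would raise a `TypeError`.

```python
from typing import Callable, List, Optional


def grouped_term_ids(grouped_id: str) -> List[str]:
    """Split a grouped ID such as '1294_5678' into zero-padded MONDO term IDs."""
    return [f"MONDO_{int(code):07d}" for code in grouped_id.split("_")]


def grouped_description(grouped_id: str,
                        describe: Callable[[List[str]], List[Optional[str]]]) -> str:
    """Concatenate the descriptions of all terms in a group, one per line."""
    descriptions = describe(grouped_term_ids(grouped_id))
    # Skip terms for which no description was returned
    return "\n".join(d for d in descriptions if d)


def fake_describe(term_ids: List[str]) -> List[Optional[str]]:
    """Hypothetical stand-in for the OLS lookup; only the first term has a description."""
    return [f"description of {term_ids[0]}"] + [None] * (len(term_ids) - 1)


print(grouped_term_ids("1294_5678"))  # ['MONDO_0001294', 'MONDO_0005678']
print(grouped_description("1294_5678", fake_describe))  # description of MONDO_0001294
```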

In [None]:
ols_obj = EnrichmentWithOLS()
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_source') != 'MONDO_grouped':
        continue
    # Get the node attributes
    node_attr = kg.nodes[node]
    # A MONDO_grouped node_id contains multiple codes separated by '_';
    # each code is left-padded with zeros to 7 digits (e.g. MONDO_0001294)
    codes = node_attr.get('node_id').split('_')
    term_ids = [f"MONDO_{int(code):07d}" for code in codes]
    # Create an OLSAttr object
    ols_attr = OLSAttr(
        term_id=node_attr.get('node_id'),
        name=node,
        label=node,
        description=node_attr.get('description')
    )
    # Fetch the descriptions of all codes in the group and concatenate
    # them, skipping codes without a description
    descriptions = ols_obj.enrich_documents(term_ids)
    ols_attr.description = '\n'.join(d for d in descriptions if d)
    list_ols_attrs.append(ols_attr)
print(list_ols_attrs)

## Add descriptions to the OLS nodes and recompose the graph

In [None]:
for ols_attr in list_ols_attrs:
    node = ols_attr.name
    description = ols_attr.description
    # print (f"node: {node}, description: {description}")
    G.add_nodes_from([(node, {'description': description})])

# Recompose the graph
kg = nx.compose(G, kg)

## Generate embeddings for all the nodes with textual descriptions

In [None]:
# Embed every node that has a textual description
for node in tqdm(kg.nodes):
    desc = kg.nodes[node].get('description')
    if desc is None:
        continue
    outputs = biobert_model.embed_documents([desc])
    G.add_nodes_from([(node, {'description_embedding': outputs})])
    # Optionally free GPU memory between calls:
    # torch.cuda.synchronize()
    # torch.cuda.empty_cache()

# Recompose the graph
kg = nx.compose(G, kg)

## Display the results in a DF

In [None]:
import pandas as pd

# Collect all nodes that have a description into a dataframe
dic = {'node': [],
       'node_source': [],
       'node_id': [],
       'description': [],
       'description_embedding': []}
for node in tqdm(kg.nodes):
    node_attr = kg.nodes[node]
    if node_attr.get('description') is None:
        continue
    dic['node'].append(node)
    dic['node_source'].append(node_attr.get('node_source'))
    dic['node_id'].append(node_attr.get('node_id'))
    dic['description'].append(node_attr.get('description'))
    dic['description_embedding'].append(node_attr.get('description_embedding'))

df = pd.DataFrame(dic)
df

# Protein enrichments
Now, we will enrich the protein nodes with their descriptions and sequences.
We will query the UniProt API to get the description and sequence. For this, we
first need to get all the gene node IDs.

In [None]:
from dataclasses import dataclass
from typing import Optional

# Create a dataclass to hold the node attributes
@dataclass
class GeneAttr:
    """Dataclass to hold the attributes of a gene node."""
    id: str
    name: str
    # Optional fields (None if not provided)
    description: Optional[str] = None
    sequence: Optional[str] = None


### Get gene node IDs

In [None]:
# Extract all gene IDs from the graph
dic_gene_ids = {}
for n in tqdm(kg.nodes):
    # Keep only gene/protein nodes from NCBI
    if kg.nodes[n].get('node_type') != 'gene/protein' or kg.nodes[n].get('node_source') != 'NCBI':
        continue
    # Get the node attributes
    node_attr = kg.nodes[n]
    # Create a GeneAttr object
    gene_attr = GeneAttr(
        id=node_attr.get('node_id'),
        name=n,
        description=node_attr.get('description'),
        sequence=node_attr.get('sequence')
    )
    # Add the gene_attr object to the dictionary
    dic_gene_ids[node_attr.get('node_id')] = gene_attr
# Check the number of gene IDs
len(dic_gene_ids)

### Submit a job to UniProt to map the Gene ID to its description and sequence

Here we show two ways to get the description and sequence of a gene:
1. Most biomedical graphs provide gene names, which can be used to extract the sequence and description with the `EnrichmentWithUniProt` class in the T2KG utils.
2. Some graphs, like PrimeKG, also provide gene IDs, which can be used to extract the sequence and description with the snippet below (adapted from the UniProt documentation).
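`check_id_mapping_results_ready` in the next cell polls UniProt while the job is in the `NEW` or `RUNNING` state, with no upper bound on the number of attempts. A bounded variant of the same polling pattern can be sketched as follows; `get_status` is a hypothetical callable standing in for the status request, and `max_attempts`/`interval` are illustrative defaults:

```python
import time
from typing import Callable


def poll_until_done(get_status: Callable[[], str], max_attempts: int = 10,
                    interval: float = 5.0) -> str:
    """Call get_status until it returns a status other than NEW or RUNNING.

    Raises TimeoutError if the job is still pending after max_attempts calls.
    """
    for _ in range(max_attempts):
        status = get_status()
        if status not in ("NEW", "RUNNING"):
            return status
        time.sleep(interval)
    raise TimeoutError(f"job still pending after {max_attempts} attempts")


# Usage with a scripted sequence of statuses (no network access)
statuses = iter(["NEW", "RUNNING", "FINISHED"])
print(poll_until_done(lambda: next(statuses), interval=0))  # FINISHED
```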

In [None]:
import time
import requests
from requests.adapters import HTTPAdapter, Retry
from urllib.parse import urlparse, parse_qs, urlencode

# Define variables to perform UniProt ID mapping
# Adopted from https://www.uniprot.org/help/id_mapping
API_URL = "https://rest.uniprot.org"
POLLING_INTERVAL = 5
retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504])
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retries))

def submit_id_mapping(from_db, to_db, ids) -> str:
    """
    Function to submit a job to perform ID mapping.

    Args:
        from_db (str): The source database.
        to_db (str): The target database.
        ids (list): The list of IDs to map.

    Returns:
        str: The job ID.
    """
    request = requests.post(f"{API_URL}/idmapping/run",
                            data={"from": from_db,
                                  "to": to_db,
                                  "ids": ",".join(ids)},)
    try:
        request.raise_for_status()
    except requests.HTTPError:
        print(request.json())
        raise

    return request.json()["jobId"]

def check_id_mapping_results_ready(job_id):
    """
    Function to check if the ID mapping results are ready.

    Args:
        job_id (str): The job ID.

    Returns:
        bool: True if the results are ready, False otherwise.
    """
    while True:
        request = session.get(f"{API_URL}/idmapping/status/{job_id}")

        try:
            request.raise_for_status()
        except requests.HTTPError:
            print(request.json())
            raise

        j = request.json()
        if "jobStatus" in j:
            if j["jobStatus"] in ("NEW", "RUNNING"):
                print(f"Retrying in {POLLING_INTERVAL}s")
                time.sleep(POLLING_INTERVAL)
            else:
                raise Exception(j["jobStatus"])
        else:
            return bool(j["results"] or j["failedIds"])

def get_id_mapping_results_link(job_id):
    """
    Function to get the link to the ID mapping results.

    Args:
        job_id (str): The job ID.

    Returns:
        str: The link to the ID mapping results.
    """
    url = f"{API_URL}/idmapping/details/{job_id}"
    # Reuse the module-level session so the retry policy applies
    request = session.get(url)

    try:
        request.raise_for_status()
    except requests.HTTPError:
        print(request.json())
        raise

    return request.json()["redirectURL"]

def decode_results(response, file_format, compressed):
    """
    Function to decode the ID mapping results.

    Args:
        response (requests.Response): The response object.
        file_format (str): The file format of the results.
        compressed (bool): Whether the results are compressed.

    Returns:
        The decoded results: a dict for JSON, a list of lines for TSV,
        a one-item list for XLSX or XML, otherwise a string.
    """

    if compressed:
        decompressed = zlib.decompress(response.content, 16 + zlib.MAX_WBITS)
        if file_format == "json":
            j = json.loads(decompressed.decode("utf-8"))
            return j
        elif file_format == "tsv":
            return [line for line in decompressed.decode("utf-8").split("\n") if line]
        elif file_format == "xlsx":
            return [decompressed]
        elif file_format == "xml":
            return [decompressed.decode("utf-8")]
        else:
            return decompressed.decode("utf-8")
    elif file_format == "json":
        return response.json()
    elif file_format == "tsv":
        return [line for line in response.text.split("\n") if line]
    elif file_format == "xlsx":
        return [response.content]
    elif file_format == "xml":
        return [response.text]
    return response.text

def get_id_mapping_results_stream(url):
    """
    Function to get the ID mapping results from a stream.

    Args:
        url (str): The URL to the ID mapping results.

    Returns:
        The decoded ID mapping results (see decode_results).
    """
    if "/stream/" not in url:
        url = url.replace("/results/", "/results/stream/")

    request = session.get(url)

    try:
        request.raise_for_status()
    except requests.HTTPError:
        print(request.json())
        raise

    parsed = urlparse(url)
    query = parse_qs(parsed.query)
    file_format = query["format"][0] if "format" in query else "json"
    compressed = (
        query["compressed"][0].lower() == "true" if "compressed" in query else False
    )
    return decode_results(request, file_format, compressed)
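
Before decoding, the streaming helper does two things: it rewrites the results URL to the `/results/stream/` endpoint, and it reads `format` and `compressed` from the query string (defaulting to JSON, uncompressed). The sketch below reproduces only that URL handling so it can be checked offline; `stream_settings` is a hypothetical helper, and the job ID `JOB123` and URL layout are illustrative placeholders rather than a real UniProt response.

```python
from urllib.parse import urlparse, parse_qs

def stream_settings(url: str):
    """Mirror the URL handling of get_id_mapping_results_stream (no network calls)."""
    # Switch to the streaming endpoint unless the URL already uses it
    if "/stream/" not in url:
        url = url.replace("/results/", "/results/stream/")
    query = parse_qs(urlparse(url).query)
    # Same defaults as the notebook: JSON, not compressed
    file_format = query["format"][0] if "format" in query else "json"
    compressed = query["compressed"][0].lower() == "true" if "compressed" in query else False
    return url, file_format, compressed

example = "https://rest.uniprot.org/idmapping/uniprotkb/results/JOB123?format=tsv&compressed=true"
print(stream_settings(example))
```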

For the sake of time, we will use only the **first 5 gene IDs**.

In [None]:
# Take the first 5 gene IDs
inputs = list(dic_gene_ids.keys())[:5]
# Submit the job to perform ID mapping
job_id = submit_id_mapping(
    from_db="GeneID", to_db="UniProtKB", ids=inputs
)
# Print the job ID
print(f"Job ID: {job_id}")
# Wait for the job to finish (polls while it is NEW or RUNNING)
status = check_id_mapping_results_ready(job_id)
# True if the job returned results or failed IDs
print(f"Job status: {status}")
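
`check_id_mapping_results_ready` keeps polling for as long as the job is `NEW` or `RUNNING`. If you prefer a hard upper bound, a variant with a timeout can be sketched as follows. `poll_until_done` is a hypothetical helper, not part of UniProt's snippet: it takes any zero-argument callable returning the parsed status JSON (e.g. `lambda: session.get(f"{API_URL}/idmapping/status/{job_id}").json()`), so it can be exercised without network access. Treating an explicit `FINISHED` status as success is a defensive assumption layered on top of the original logic.

```python
import time
from typing import Callable

def poll_until_done(get_status: Callable[[], dict],
                    interval: float = 5.0,
                    timeout: float = 300.0) -> dict:
    """Poll get_status() until the job is no longer NEW/RUNNING.

    Returns the last payload once it no longer reports a pending status;
    raises RuntimeError on any other status and TimeoutError after `timeout` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        payload = get_status()
        status = payload.get("jobStatus")
        # As in the notebook, a payload without "jobStatus" is treated as finished
        if status is None or status == "FINISHED":
            return payload
        if status not in ("NEW", "RUNNING"):
            raise RuntimeError(f"ID mapping job ended with status {status}")
        if time.monotonic() + interval > deadline:
            raise TimeoutError(f"Job still {status} after {timeout} s")
        time.sleep(interval)

# Offline usage with canned responses instead of real HTTP calls
responses = iter([{"jobStatus": "RUNNING"}, {"results": [], "failedIds": ["0"]}])
print(poll_until_done(lambda: next(responses), interval=0.0))
```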

Check that the job has finished and retrieve the ID mapping results:

In [None]:
if check_id_mapping_results_ready(job_id):
    link = get_id_mapping_results_link(job_id)
    mapping_results = get_id_mapping_results_stream(link)
    print(mapping_results)

### Store the mapping results in a dictionary
The key is the gene ID and the value is a nested dictionary with the keys `description` and `sequence`. Only reviewed (Swiss-Prot) entries are kept.

In [None]:
dic_gene_id_to_descp_seq = {}
for result in mapping_results['results']:
    entry = result['to']
    # Keep only reviewed (Swiss-Prot) entries
    if entry['entryType'] != 'UniProtKB reviewed (Swiss-Prot)':
        continue
    # Collect all FUNCTION comment texts; entries without one get None
    # (avoids reusing a description left over from a previous entry)
    texts = [text['value']
             for comment in entry.get('comments', [])
             if comment['commentType'] == 'FUNCTION'
             for text in comment['texts']]
    dic_gene_id_to_descp_seq[result['from']] = {
        'description': ' '.join(texts) if texts else None,
        'sequence': entry['sequence']['value'],
    }

# Display the contents of the dictionary
for gene_id, descp_seq in dic_gene_id_to_descp_seq.items():
    print(f"Gene ID: {gene_id}")
    print(f"Description: {descp_seq['description']}")
    print(f"Sequence: {descp_seq['sequence']}")
    print()
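
To see which fields the cell above relies on, here is a minimal hand-written payload that mirrors only the keys it reads (`from`, `to.entryType`, `to.comments[].commentType`, `to.comments[].texts[].value`, `to.sequence.value`). The payload is illustrative, not a real UniProt response, and `extract_descp_seq` is a hypothetical helper applying the same filtering offline.

```python
def extract_descp_seq(results: list) -> dict:
    """Map gene ID -> {'description', 'sequence'} for reviewed (Swiss-Prot) entries."""
    out = {}
    for result in results:
        entry = result["to"]
        if entry.get("entryType") != "UniProtKB reviewed (Swiss-Prot)":
            continue  # drop unreviewed entries
        texts = [text["value"]
                 for comment in entry.get("comments", [])
                 if comment.get("commentType") == "FUNCTION"
                 for text in comment.get("texts", [])]
        out[result["from"]] = {
            "description": " ".join(texts) if texts else None,
            "sequence": entry["sequence"]["value"],
        }
    return out

# Hand-written toy payload (not real UniProt data)
toy_results = [
    {"from": "1", "to": {"entryType": "UniProtKB reviewed (Swiss-Prot)",
                         "comments": [{"commentType": "FUNCTION",
                                       "texts": [{"value": "Toy function text."}]}],
                         "sequence": {"value": "MKT"}}},
    {"from": "1", "to": {"entryType": "UniProtKB unreviewed (TrEMBL)",
                         "sequence": {"value": "MKA"}}},
    {"from": "2", "to": {"entryType": "UniProtKB reviewed (Swiss-Prot)",
                         "sequence": {"value": "MVL"}}},
]
print(extract_descp_seq(toy_results))
```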


### Alternative: query by gene name

Most biomedical graphs provide gene names, so you can also query sequences and descriptions by gene name using the `EnrichmentWithUniProt` class.

In [None]:
# Create a single instance of the EnrichmentWithUniProt class
enrich_uniprot = EnrichmentWithUniProt()
for gene_id in inputs:
    # Get the gene name of the gene ID
    gene_name = dic_gene_ids[gene_id].name
    # enrich_documents takes a list of gene names and returns a list of
    # descriptions and a list of sequences (one item per input name)
    descriptions, sequences = enrich_uniprot.enrich_documents([gene_name])
    description, sequence = descriptions[0], sequences[0]
    # Create the entry if the ID mapping above returned no reviewed entry
    dic_gene_id_to_descp_seq.setdefault(gene_id, {})
    dic_gene_id_to_descp_seq[gene_id]['description'] = description
    dic_gene_id_to_descp_seq[gene_id]['sequence'] = sequence
    print(f"Gene name: {gene_name}\nDescription: {description}\nSequence: {sequence}")

### Map the description and sequence from the dictionary to their corresponding nodes in the graph

In [None]:
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    gene_id = kg.nodes[node].get('node_id')
    # Skip the genes/proteins for which no description/sequence was retrieved
    if gene_id not in dic_gene_id_to_descp_seq:
        continue
    description = dic_gene_id_to_descp_seq[gene_id]['description']
    sequence = dic_gene_id_to_descp_seq[gene_id]['sequence']
    print(f"node: {node}, gene ID: {gene_id}, description: {description}, sequence: {sequence}")
    # Update the node attributes in place so the new values override any existing ones
    kg.nodes[node].update({'description': description, 'sequence': sequence})
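
When merging graphs with `nx.compose(G, H)`, NetworkX gives precedence to the attributes of `H` on conflicting keys, so `nx.compose(G, kg)` keeps whatever value `kg` already stores (for example a stale `description`) instead of the freshly enriched one. Updating `kg` in place, e.g. with `kg.nodes[node].update(...)` or `nx.set_node_attributes`, sidesteps this. The toy graphs below (hypothetical node and text, named so they do not overwrite the real `kg` and `G`) illustrate the difference.

```python
import networkx as nx

# Toy graphs (hypothetical): toy_kg holds a stale description, toy_G the new one
toy_kg = nx.Graph()
toy_kg.add_node("TP53", node_type="gene/protein", description=None)
toy_G = nx.Graph()
toy_G.add_node("TP53", description="Toy description.")

# The SECOND argument's attributes win on conflicting keys
stale = nx.compose(toy_G, toy_kg)
fresh = nx.compose(toy_kg, toy_G)
print(stale.nodes["TP53"])  # description stays None
print(fresh.nodes["TP53"])  # description is updated, node_type is kept

# In-place alternative: no second graph and no merge step
nx.set_node_attributes(toy_kg, {"TP53": {"description": "Toy description."}})
print(toy_kg.nodes["TP53"])
```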

### Check device availability

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

### Load the ESM2 model

In [None]:
emb_model = EmbeddingWithHuggingFace(model_name='facebook/esm2_t6_8M_UR50D',
                                     model_cache_dir="../../../../data/facebook/esm2_t6_8M_UR50D/",
                                     truncation=False,
                                     device=device)

### Generate sequence embedding and add it to the graph as new attribute "sequence_embedding"

In [None]:
# Embed one sequence at a time to keep GPU memory usage low
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    seq = kg.nodes[node].get('sequence')
    if seq is None:
        continue
    outputs = emb_model.embed_documents([seq])
    # Store the embedding of the single input sequence
    kg.nodes[node]['sequence_embedding'] = outputs[0]
    # Free cached GPU memory (only meaningful when running on CUDA)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
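
The loop stores `outputs[0]` as a single vector per protein even though sequences vary in length, so the per-residue hidden states have to be pooled somewhere. A common choice is mean pooling over the non-padding positions, sketched below on a toy tensor. Whether `EmbeddingWithHuggingFace` pools exactly this way is an assumption here, and `masked_mean_pool` is a hypothetical helper. Since the loop embeds one sequence at a time, there is no padding, and a masked mean equals a plain mean over the tokens.

```python
import torch

def masked_mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over positions where mask == 1.

    hidden: [batch, seq_len, dim] hidden states; mask: [batch, seq_len] attention mask.
    """
    weights = mask.unsqueeze(-1).to(hidden.dtype)   # [batch, seq_len, 1]
    summed = (hidden * weights).sum(dim=1)          # [batch, dim]
    counts = weights.sum(dim=1).clamp(min=1.0)      # [batch, 1], guards empty rows
    return summed / counts

# Toy batch: the third position is padding and must not affect the mean
toy_hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
toy_mask = torch.tensor([[1, 1, 0]])
print(masked_mean_pool(toy_hidden, toy_mask))
```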

### Protein description embedding
Load Microsoft's BiomedBERT model (formerly PubMedBERT) to embed the textual descriptions.

In [None]:
# Using Microsoft's BiomedBERT (formerly PubMedBERT)
emb_model = EmbeddingWithHuggingFace(model_name='microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract',
                                     model_cache_dir="../../../../data/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract/",
                                     truncation=False,
                                     device=device)

### Generate description embedding and add it to the graph as new attribute "description_embedding"

In [None]:
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    desc = kg.nodes[node].get('description')
    if desc is None:
        continue
    outputs = emb_model.embed_documents([desc])
    # Store the embedding of the single input description (same layout as sequence_embedding)
    kg.nodes[node]['description_embedding'] = outputs[0]
    # Free cached GPU memory (only meaningful when running on CUDA)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

### Collect the enriched gene/protein nodes in a DataFrame

In [None]:
import pandas as pd
dic = {'gene':[],
       'description':[],
       'sequence':[],
       'description_embedding':[],
       'sequence_embedding':[]}
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    # Keep only the nodes enriched with a description
    if kg.nodes[node].get('description') is None:
        continue
    dic['gene'].append(node)
    dic['description'].append(kg.nodes[node].get('description'))
    dic['sequence'].append(kg.nodes[node].get('sequence'))
    dic['description_embedding'].append(kg.nodes[node].get('description_embedding'))
    dic['sequence_embedding'].append(kg.nodes[node].get('sequence_embedding'))

df = pd.DataFrame(dic)
df
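
For downstream use such as similarity search, an embedding column is easier to handle as a single matrix. A minimal sketch, assuming each stored embedding is a tensor of shape `[dim]` or `[1, dim]` (flattened either way) and missing values are `None`; `stack_embeddings` and `cosine_similarity_matrix` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F

def stack_embeddings(column) -> torch.Tensor:
    """Stack non-missing embeddings into an [n, dim] float matrix."""
    vectors = [torch.as_tensor(v, dtype=torch.float32).flatten()
               for v in column if v is not None]
    if not vectors:
        return torch.empty(0, 0)
    return torch.stack(vectors)

def cosine_similarity_matrix(matrix: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of `matrix`."""
    normed = F.normalize(matrix, dim=1)
    return normed @ normed.T

# Toy column mixing a [dim] tensor, a missing value and a [1, dim] tensor
toy_column = [torch.tensor([1.0, 0.0]), None, torch.tensor([[0.0, 2.0]])]
toy_matrix = stack_embeddings(toy_column)
print(toy_matrix.shape)
print(cosine_similarity_matrix(toy_matrix))
```

With the DataFrame above, the same call would be `stack_embeddings(df['sequence_embedding'])`.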