# PrimeKG Enrichment

In this tutorial, we will explain how to perform multimodal enrichment of PrimeKG.

Prior information about the PrimeKG can be found in the following repositories:
- https://github.com/mims-harvard/PrimeKG
- https://github.com/mims-harvard/TDC/

Note that we are leveraging the PrimeKG provided in Harvard Dataverse, which is publicly available in the following link:

https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/IXA7BM

By the time we are writing this tutorial, the latest version of PrimeKG (`kg.csv`) is `2.1`.

First of all, we need to import necessary libraries as follows:

In [1]:
# Import necessary libraries
import sys
import os
import torch
import networkx as nx
import pandas as pd
from tqdm import tqdm
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.datasets.primekg import PrimeKG
# from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.uniprot_proteins import EnrichmentWithUniProt
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.huggingface import EmbeddingWithHuggingFace

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Import necessary libraries
import time
import json
import zlib
import requests
from requests.adapters import HTTPAdapter, Retry
from urllib.parse import urlparse, parse_qs

# Define variables to perform UniProt ID mapping
# Adopted from https://www.uniprot.org/help/id_mapping
API_URL = "https://rest.uniprot.org"
POLLING_INTERVAL = 5
retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504])
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retries))

def submit_id_mapping(from_db, to_db, ids) -> str:
    """
    Function to submit a job to perform ID mapping.

    Args:
        from_db (str): The source database.
        to_db (str): The target database.
        ids (list): The list of IDs to map.

    Returns:
        str: The job ID.
    """
    request = requests.post(f"{API_URL}/idmapping/run",
                            data={"from": from_db,
                                  "to": to_db,
                                  "ids": ",".join(ids)},)
    try:
        request.raise_for_status()
    except requests.HTTPError:
        print(request.json())
        raise

    return request.json()["jobId"]

def check_id_mapping_results_ready(job_id):
    """
    Function to check if the ID mapping results are ready.

    Args:
        job_id (str): The job ID.

    Returns:
        bool: True if the results are ready, False otherwise.
    """
    while True:
        request = session.get(f"{API_URL}/idmapping/status/{job_id}")

        try:
            request.raise_for_status()
        except requests.HTTPError:
            print(request.json())
            raise

        j = request.json()
        if "jobStatus" in j:
            if j["jobStatus"] in ("NEW", "RUNNING"):
                print(f"Retrying in {POLLING_INTERVAL}s")
                time.sleep(POLLING_INTERVAL)
            else:
                raise Exception(j["jobStatus"])
        else:
            return bool(j["results"] or j["failedIds"])

def get_id_mapping_results_link(job_id):
    """
    Function to get the link to the ID mapping results.

    Args:
        job_id (str): The job ID.

    Returns:
        str: The link to the ID mapping results.
    """
    url = f"{API_URL}/idmapping/details/{job_id}"
    request = requests.Session().get(url)

    try:
        request.raise_for_status()
    except requests.HTTPError:
        print(request.json())
        raise

    return request.json()["redirectURL"]

def decode_results(response, file_format, compressed):
    """
    Function to decode the ID mapping results.

    Args:
        response (requests.Response): The response object.
        file_format (str): The file format of the results.
        compressed (bool): Whether the results are compressed.

    Returns:
        str: The ID mapping results
    """

    if compressed:
        decompressed = zlib.decompress(response.content, 16 + zlib.MAX_WBITS)
        if file_format == "json":
            j = json.loads(decompressed.decode("utf-8"))
            return j
        elif file_format == "tsv":
            return [line for line in decompressed.decode("utf-8").split("\n") if line]
        elif file_format == "xlsx":
            return [decompressed]
        elif file_format == "xml":
            return [decompressed.decode("utf-8")]
        else:
            return decompressed.decode("utf-8")
    elif file_format == "json":
        return response.json()
    elif file_format == "tsv":
        return [line for line in response.text.split("\n") if line]
    elif file_format == "xlsx":
        return [response.content]
    elif file_format == "xml":
        return [response.text]
    return response.text

def get_id_mapping_results_stream(url):
    """
    Function to get the ID mapping results from a stream.

    Args:
        url (str): The URL to the ID mapping results.

    Returns:
        str: The ID mapping results.
    """
    if "/stream/" not in url:
        url = url.replace("/results/", "/results/stream/")

    request = session.get(url)

    try:
        request.raise_for_status()
    except requests.HTTPError:
        print(request.json())
        raise

    parsed = urlparse(url)
    query = parse_qs(parsed.query)
    file_format = query["format"][0] if "format" in query else "json"
    compressed = (
        query["compressed"][0].lower() == "true" if "compressed" in query else False
    )
    return decode_results(request, file_format, compressed)

### Load PrimeKG

The `PrimeKG` dataset allows to load the data from the Harvard Dataverse server if the data is not available locally. 

Otherwise, the data is loaded from the local directory as defined in the `local_dir`.

In [3]:
# Define primekg data by providing a local directory where the data is stored
primekg_data = PrimeKG(local_dir="../../../../data/primekg/")

To load the dataframes of nodes and edges from PrimeKG, we just need to invoke a method as follows.

In [4]:
# Invoke a method to load the data
primekg_data.load_data()

# Get primekg_nodes and primekg_edges
primekg_nodes = primekg_data.get_nodes()
primekg_edges = primekg_data.get_edges()

Loading nodes of PrimeKG dataset ...
../../../../data/primekg/primekg_nodes.tsv.gz already exists. Loading the data from the local directory.
Loading edges of PrimeKG dataset ...
../../../../data/primekg/primekg_edges.tsv.gz already exists. Loading the data from the local directory.


### Check PrimeKG Dataframes

As mentioned before, the primekg_nodes and primekg_edges are the dataframes of nodes and edges, respectively.

We can further analyze the dataframes to extract the information we need.

For instance, we can construct a graph from the nodes and edges dataframes using the networkx library.

#### PrimeKG Nodes

`primekg_nodes` is a dataframe of nodes, which has the following columns:
- `node_index`: the index of the node
- `node`: the node name
- `node_id`: the id of the node (currently set as node name itself, for visualization purposes)
- `node_uid`: the unique identifier of the node (source name + unique id)
- `node_type`: the type of the node

We can check a sample of the primekg nodes to see the list of nodes in the PrimeKG dataset as follows.

In [5]:
# Check a sample of the primekg nodes
primekg_nodes.head()

Unnamed: 0,node_index,node_name,node_source,node_id,node_type
0,0,PHYHIP,NCBI,9796,gene/protein
1,1,GPANK1,NCBI,7918,gene/protein
2,2,ZRSR2,NCBI,8233,gene/protein
3,3,NRF1,NCBI,4899,gene/protein
4,4,PI4KA,NCBI,5297,gene/protein


The current version of PrimeKG has about 130K of nodes in total as we can observe in the following cell.

In [6]:
# Check dimensions of the primekg nodes
primekg_nodes.shape

(129375, 5)

 We can breakdown the statistics of the primekg nodes by their types as follows.

In [7]:
# Show node types and their counts
primekg_nodes['node_type'].value_counts()

node_type
biological_process    28642
gene/protein          27671
disease               17080
effect/phenotype      15311
anatomy               14035
molecular_function    11169
drug                   7957
cellular_component     4176
pathway                2516
exposure                818
Name: count, dtype: int64

PrimeKG was built using various sources, as we can observe from their unique node sources as follows.

In [8]:
# Show source of the primekg nodes
primekg_nodes['node_source'].value_counts()

node_source
GO               43987
NCBI             27671
MONDO            15813
HPO              15311
UBERON           14035
DrugBank          7957
REACTOME          2516
MONDO_grouped     1267
CTD                818
Name: count, dtype: int64

In [9]:
primekg_nodes[primekg_nodes['node_source'] == 'NCBI'].head(10)
primekg_edges.head()

Unnamed: 0,head_index,head_name,head_source,head_id,head_type,tail_index,tail_name,tail_source,tail_id,tail_type,display_relation,relation
0,0,PHYHIP,NCBI,9796,gene/protein,8889,KIF15,NCBI,56992,gene/protein,ppi,protein_protein
1,1,GPANK1,NCBI,7918,gene/protein,2798,PNMA1,NCBI,9240,gene/protein,ppi,protein_protein
2,2,ZRSR2,NCBI,8233,gene/protein,5646,TTC33,NCBI,23548,gene/protein,ppi,protein_protein
3,3,NRF1,NCBI,4899,gene/protein,11592,MAN1B1,NCBI,11253,gene/protein,ppi,protein_protein
4,4,PI4KA,NCBI,5297,gene/protein,2122,RGS20,NCBI,8601,gene/protein,ppi,protein_protein


### Create a directed graph using the egdes

In [10]:
kg = nx.DiGraph()

## Make a KG using the edgelist
G = nx.from_pandas_edgelist(
    primekg_edges,
    source="head_name",
    target="tail_name",
    edge_key="relation",
    # edge_attr=["edge_id", "edge_type", "feature_value", "feature_id"],
    create_using=nx.DiGraph(),
)
kg = nx.compose(G, kg)

### Add additional node attributes (e.g. source, id and type)

In [11]:
# Start by extracting slicing the df to include only thge head nodes
df_head_nodes = primekg_edges[['head_name', 'head_source', 'head_id', 'head_type']]
# Rename the columns
df_head_nodes = df_head_nodes.rename(columns={
    'head_name': 'node_name',
    'head_source': 'node_source',
    'head_id': 'node_id',
    'head_type': 'node_type'
})
# Set the node_name as index
df_head_nodes = df_head_nodes.set_index('node_name')
# Add the additional attributes to graph
G.add_nodes_from((n, dict(d)) for n, d in df_head_nodes.iterrows())
# Recompose the graph
kg = nx.compose(G, kg)

# Protein enrichments
Now, we will encrich the protein nodes with their description and sequence.
We will query the UniProt via API to get the descp and sequence. For this, we
will first need to get all the node IDs.

### Get node IDs

In [12]:
# Extract all gene IDs from the graph
gene_ids = set()
for n in tqdm(kg.nodes):
    if kg.nodes[n].get('node_type') != 'gene/protein' and kg.nodes[n].get('node_source') != 'NCBI':
        continue
    gene_ids.add(kg.nodes[n].get('node_id'))
len(list(gene_ids))

100%|██████████| 129262/129262 [00:00<00:00, 965957.84it/s]


27609

27609

### Submit a job to UniProt to map the Gene ID to its description and sequence
For the sake of time, we will use only the **first 10 nodes**

In [13]:
inputs = list(gene_ids)[:10]

job_id = submit_id_mapping(
    from_db="GeneID", to_db="UniProtKB", ids=inputs
)
print (f"Job ID: {job_id}")
# Check the status of the job
status = check_id_mapping_results_ready(job_id)
print (f"Job status: {status}")

Job ID: pya125zGEs
Retrying in 5s
Job status: True


### Check and get the ID mapping results

In [14]:
if check_id_mapping_results_ready(job_id):
    link = get_id_mapping_results_link(job_id)
    mapping_results = get_id_mapping_results_stream(link)
    print(mapping_results)

{'results': [{'from': '8821', 'to': {'entryType': 'UniProtKB reviewed (Swiss-Prot)', 'primaryAccession': 'O15327', 'secondaryAccessions': ['Q2TAI2', 'Q5XLE7', 'Q6IN59', 'Q6PJB4'], 'uniProtkbId': 'INP4B_HUMAN', 'entryAudit': {'firstPublicDate': '2005-08-30', 'lastAnnotationUpdateDate': '2025-04-09', 'lastSequenceUpdateDate': '2008-12-16', 'entryVersion': 176, 'sequenceVersion': 4}, 'annotationScore': 5.0, 'organism': {'scientificName': 'Homo sapiens', 'commonName': 'Human', 'taxonId': 9606, 'lineage': ['Eukaryota', 'Metazoa', 'Chordata', 'Craniata', 'Vertebrata', 'Euteleostomi', 'Mammalia', 'Eutheria', 'Euarchontoglires', 'Primates', 'Haplorrhini', 'Catarrhini', 'Hominidae', 'Homo']}, 'proteinExistence': '1: Evidence at protein level', 'proteinDescription': {'recommendedName': {'fullName': {'value': 'Inositol polyphosphate 4-phosphatase type II'}}, 'alternativeNames': [{'fullName': {'value': 'Type II inositol 3,4-bisphosphate 4-phosphatase'}, 'ecNumbers': [{'evidences': [{'evidenceCode'

### Store the mapping results in a dictionary
Key is the gene ID and value is a nested dictionary with keys "description" and "sequence"

In [15]:
dic_gene_id_to_descp_seq = {}
for result in mapping_results['results']:
    # print(result['to'])
    if result['to']['entryType'] == 'UniProtKB reviewed (Swiss-Prot)':
        # print (result['from'], result['to'])
        dic_gene_id_to_descp_seq[result['from']] = {}
        for comment in result['to']['comments']:
            if comment['commentType'] == 'FUNCTION':
                for text in comment['texts']:
                    # print (text['value'])
                    description = text['value']
        dic_gene_id_to_descp_seq[result['from']]['description'] = description
        dic_gene_id_to_descp_seq[result['from']]['sequence'] = result['to']['sequence']['value']

# Display the contents of the dictionary
for gene_id, descp_seq in dic_gene_id_to_descp_seq.items():
    print(f"Gene ID: {gene_id}")
    print(f"Description: {descp_seq['description']}")
    print(f"Sequence: {descp_seq['sequence']}")
    print()
        

Gene ID: 8821
Description: Catalyzes the hydrolysis of the 4-position phosphate of phosphatidylinositol 3,4-bisphosphate, inositol 1,3,4-trisphosphate and inositol 3,4-trisphosphate (PubMed:24070612, PubMed:24591580). Plays a role in the late stages of macropinocytosis by dephosphorylating phosphatidylinositol 3,4-bisphosphate in membrane ruffles (PubMed:24591580). The lipid phosphatase activity is critical for tumor suppressor function. Antagonizes the PI3K-AKT/PKB signaling pathway by dephosphorylating phosphoinositides and thereby modulating cell cycle progression and cell survival (PubMed:19647222, PubMed:24070612)
Sequence: MEIKEEGASEEGQHFLPTAQANDPGDCQFTSIQKTPNEPQLEFILACKDLVAPVRDRKLNTLVQISVIHPVEQSLTRYSSTEIVEGTRDPLFLTGVTFPSEYPIYEETKIKLTVYDVKDKSHDTVRTSVLPEHKDPPPEVGRSFLGYASFKVGELLKSKEQLLVLSLRTSDGGKVVGTIEVSVVKMGEIEDGEADHITTDVQGQKCALVCECTAPESVSGKDNLPFLNSVLKNPVCKLYRFPTSDNKWMRIREQMSESILSFHIPKELISLHIKEDLCRNQEIKELGELSPHWDNLRKNVLTHCDQMVNMYQDILTELSKETGSSFKSSSSKGEKTLEFVPINLHLQRMQVHSPHLKDALYDV

### Map the description and sequence from the dictionary to their corresponding nodes in the graph

In [16]:
from tqdm import tqdm
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    gene_id = kg.nodes[node].get('node_id')
    # Ignore the genes/proteins without description
    if gene_id not in dic_gene_id_to_descp_seq:
        continue
    description = dic_gene_id_to_descp_seq[gene_id]['description']
    sequence = dic_gene_id_to_descp_seq[gene_id]['sequence']
    print (f"node: {node}, gene ID: {gene_id}, description: {description}, sequence: {sequence}")
    G.add_nodes_from([(node, {'description': description, 'sequence': sequence})])

# Recompose the graph
kg = nx.compose(G, kg)

100%|██████████| 129262/129262 [00:00<00:00, 1187953.42it/s]

node: PABPC1, gene ID: 26986, description: (Microbial infection) Positively regulates the replication of dengue virus (DENV), sequence: MNPSAPSYPMASLYVGDLHPDVTEAMLYEKFSPAGPILSIRVCRDMITRRSLGYAYVNFQQPADAERALDTMNFDVIKGKPVRIMWSQRDPSLRKSGVGNIFIKNLDKSIDNKALYDTFSAFGNILSCKVVCDENGSKGYGFVHFETQEAAERAIEKMNGMLLNDRKVFVGRFKSRKEREAELGARAKEFTNVYIKNFGEDMDDERLKDLFGKFGPALSVKVMTDESGKSKGFGFVSFERHEDAQKAVDEMNGKELNGKQIYVGRAQKKVERQTELKRKFEQMKQDRITRYQGVNLYVKNLDDGIDDERLRKEFSPFGTITSAKVMMEGGRSKGFGFVCFSSPEEATKAVTEMNGRIVATKPLYVALAQRKEERQAHLTNQYMQRMASVRAVPNPVINPYQPAPPSGYFMAAIPQTQNRAAYYPPSQIAQLRPSPRWTAQGARPHPFQNMPGAIRPAAPRPPFSTMRPASSQVPRVMSTQRVANTSTQTMGPRPAAAAAAATPAVRTVPQYKYAAGVRNPQQHLNAQPQVTMQQPAVHVQGQEPLTASMLASAPPQEQKQMLGERLFPLIQAMHPTLAGKITGMLLEIDNSELLHMLESPESLRSKVDEAVAVLQAHQAKEAAQKAVNSATGVPTV
node: GTPBP3, gene ID: 84705, description: GTPase component of the GTPBP3-MTO1 complex that catalyzes the 5-taurinomethyluridine (taum(5)U) modification at the 34th wobble position (U34) of mitochondrial tRNAs (mt-tRNAs), which




### Check device availability

In [17]:

device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

### Load the ESM2 model

In [18]:
emb_model = EmbeddingWithHuggingFace(model_name='facebook/esm2_t6_8M_UR50D',
                                     model_cache_dir="../../../../data/facebook/esm2_t6_8M_UR50D/",
                                     truncation=False,
                                     device=device)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Generate sequence embedding and add it to the graph as new attribute "sequence_embedding"

In [19]:
# Embeddings using 1 sample at a time
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    gene_id = kg.nodes[node].get('node_id')
    if kg.nodes[node].get('sequence') is None:
        continue
    seq = kg.nodes[node].get('sequence')
    # print (node, seq)
    outputs = emb_model.embed_documents([seq])
    G.add_nodes_from([(node, {'sequence_embedding': outputs[0]})])
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Recompose the graph
kg = nx.compose(G, kg)

100%|██████████| 129262/129262 [00:05<00:00, 21864.09it/s]


### Load the BioBERT model

In [20]:
# Using MSFT's BioBERT
emb_model = EmbeddingWithHuggingFace(model_name='microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract',
                                     model_cache_dir="../../../../data/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract/",
                                     truncation=False,
                                     device=device)

### Generate description embedding and add it to the graph as new attribute "description_embedding"

In [21]:
for i, node in tqdm(enumerate(kg.nodes)):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    gene_id = kg.nodes[node].get('node_id')
    if kg.nodes[node].get('description') is None:
        continue
    desc = kg.nodes[node].get('description')
    outputs = emb_model.embed_documents([desc])
    G.add_nodes_from([(node, {'description_embedding': outputs})])
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Recompose the graph
kg = nx.compose(G, kg)

129262it [00:00, 372240.20it/s]


### Put together all the results so far in a df

In [22]:
import pandas as pd
dic = {'gene':[],
       'description':[],
       'sequence':[],
       'description_embedding':[],
       'sequence_embedding':[]}
for node in tqdm(kg.nodes):
    if kg.nodes[node].get('node_type') != 'gene/protein':
        continue
    gene_id = kg.nodes[node].get('node_id')
    if kg.nodes[node].get('description') is None:
        continue
    dic['gene'].append(node)
    dic['description'].append(kg.nodes[node].get('description'))
    dic['sequence'].append(kg.nodes[node].get('sequence'))
    dic['description_embedding'].append(kg.nodes[node].get('description_embedding'))
    dic['sequence_embedding'].append(kg.nodes[node].get('sequence_embedding'))
    # print (node, kg.nodes[node].get('description'), kg.nodes[node].get('sequence'), kg.nodes[node].get('description_embedding'))

df = pd.DataFrame(dic)
df

100%|██████████| 129262/129262 [00:00<00:00, 1170116.51it/s]


Unnamed: 0,gene,description,sequence,description_embedding,sequence_embedding
0,PABPC1,(Microbial infection) Positively regulates the...,MNPSAPSYPMASLYVGDLHPDVTEAMLYEKFSPAGPILSIRVCRDM...,"[[tensor(-0.1369), tensor(-0.0554), tensor(0.0...","[tensor(0.0133), tensor(-0.0413), tensor(0.192..."
1,GTPBP3,GTPase component of the GTPBP3-MTO1 complex th...,MWRGLWTLAAQAARGPRRLCTRRSSGAPAPGSGATIFALSSGQGRC...,"[[tensor(-0.0017), tensor(0.2555), tensor(0.45...","[tensor(-0.1717), tensor(-0.2150), tensor(-0.0..."
2,SHOX2,May be a growth regulator and have a role in s...,MEELTAFVSKSFDQKVKEKKEAITYREVLESGPLRGAKEPTGCTEA...,"[[tensor(0.0310), tensor(0.3633), tensor(0.744...","[tensor(-0.0969), tensor(-0.2075), tensor(0.01..."
3,ALDH16A1,May be a growth regulator and have a role in s...,MAATRAGPRAREIFTSLEYGPVPESHACALAWLDTQDRCLGHYVNG...,"[[tensor(0.0310), tensor(0.3633), tensor(0.744...","[tensor(-0.2420), tensor(-0.0727), tensor(0.10..."
4,GMPR,Catalyzes the irreversible NADPH-dependent dea...,MPRIDADLKLDFKDVLLRPKRSSLKSRAEVDLERTFTFRNSKQTYS...,"[[tensor(0.1503), tensor(-0.1371), tensor(0.33...","[tensor(0.0235), tensor(0.0571), tensor(0.0156..."
5,INPP4B,Catalyzes the hydrolysis of the 4-position pho...,MEIKEEGASEEGQHFLPTAQANDPGDCQFTSIQKTPNEPQLEFILA...,"[[tensor(0.1104), tensor(0.2020), tensor(0.234...","[tensor(-0.0336), tensor(-0.1216), tensor(0.02..."
