In [1]:
import os
import glob
import cudf
import dask_cudf
import hydra
import cupy as cp

In [None]:
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

os.environ["OPENAI_API_KEY"] = "xxx"

emb_model = OpenAIEmbeddings(model='text-embedding-ada-002')

In [3]:
# Load hydra configuration
with hydra.initialize(version_base=None, config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs"):
    cfg = hydra.compose(
        config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
    )
    cfg = cfg.tools.multimodal_subgraph_extraction
cfg

{'_target_': 'talk2knowledgegraphs.tools.multimodal_subgraph_extraction', 'ollama_embeddings': ['nomic-embed-text'], 'temperature': 0.1, 'streaming': False, 'topk': 5, 'topk_e': 5, 'cost_e': 0.5, 'c_const': 0.01, 'root': -1, 'num_clusters': 1, 'pruning': 'gw', 'verbosity_level': 0, 'node_id_column': 'node_id', 'node_attr_column': 'node_attr', 'edge_src_column': 'edge_src', 'edge_attr_column': 'edge_attr', 'edge_dst_column': 'edge_dst', 'node_colors_dict': {'gene/protein': '#6a79f7', 'molecular_function': '#82cafc', 'cellular_component': '#3f9b0b', 'biological_process': '#c5c9c7', 'drug': '#c4a661', 'disease': '#80013f'}, 'biobridge': {'source': '/mnt/blockstorage/biobridge_multimodal/', 'node_type': ['gene/protein', 'molecular_function', 'cellular_component', 'biological_process', 'drug', 'disease']}}

In [4]:
cfg.biobridge.source = "/mnt/blockstorage/biobridge_multimodal"
# cfg.biobridge.source = "../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal"

In [5]:
# Loop over nodes and edges
graph_dict = {}
for element in ["nodes", "edges"]:
    # Make an empty dictionary for each folder
    graph_dict[element] = {}
    for stage in ["enrichment", "embedding"]:
        print(element, stage)
        # Create the file pattern for the current subfolder
        file_list = glob.glob(os.path.join(cfg.biobridge.source,
                                           element,
                                           stage, '*.parquet.gzip'))
        print(file_list)
        # Read and concatenate all dataframes in the folder
        # Except the edges embedding, which is too large to read in one go
        # We are using a chunk size to read the edges embedding in smaller parts instead
        if element == "edges" and stage == "embedding":
            # For edges embedding, only read two columns: triplet_index and edge_emb
            # graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in file_list[:2]], ignore_index=True)
            # Loop by chunks
            # file_list = file_list[:2]
            chunk_size = 5
            graph_dict[element][stage] = []
            for i in range(0, len(file_list), chunk_size):
                chunk_files = file_list[i:i+chunk_size]
                chunk_df = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in chunk_files], ignore_index=True)
                graph_dict[element][stage].append(chunk_df)
        else:
            # For nodes and edges enrichment, read and concatenate all dataframes in the folder
            # This includes the nodes embedding, which is small enough to read in one go
            graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f) for f in file_list], ignore_index=True)

nodes enrichment
['/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/molecular_function.parquet.gzip']
nodes embedding
['/mnt/blockstorage/biobridge_multimodal/nodes/embedding/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_mu

In [6]:
graph = graph_dict

# Get nodes enrichment and embedding dataframes
nodes_enrichment_df = graph['nodes']['enrichment']
nodes_embedding_df = graph['nodes']['embedding']
# Get edges enrichment and embedding dataframes
edges_enrichment_df = graph['edges']['enrichment']
edges_embedding_df = graph['edges']['embedding'] # consisted of a list of dataframes

In [23]:
node_idxs = nodes_enrichment_df.node_id.to_pandas().tolist()
edges_enrichment_df[(edges_enrichment_df.head_id.isin(node_idxs)) & (edges_enrichment_df.head_id.isin(node_idxs))]

Unnamed: 0,triplet_index,primekg_head_index,primekg_tail_index,head_id,tail_id,display_relation,edge_type,edge_type_str,head_index,tail_index,feat
0,0,0,8889,PHYHIP_(0),KIF15_(8889),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,0,8816,PHYHIP (gene/protein) has a direct relationshi...
1,1,1,2798,GPANK1_(1),PNMA1_(2798),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,1,2787,GPANK1 (gene/protein) has a direct relationshi...
2,2,2,5646,ZRSR2_(2),TTC33_(5646),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,2,5610,ZRSR2 (gene/protein) has a direct relationship...
3,3,3,11592,NRF1_(3),MAN1B1_(11592),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,3,11467,NRF1 (gene/protein) has a direct relationship ...
4,4,4,2122,PI4KA_(4),RGS20_(2122),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,4,2117,PI4KA (gene/protein) has a direct relationship...
...,...,...,...,...,...,...,...,...,...,...,...
3904605,3904605,52855,34572,B cell receptor transport into membrane raft_(...,CD24_(34572),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,45323,27800,B cell receptor transport into membrane raft (...
3904606,3904606,113352,34572,chemokine receptor transport out of membrane r...,CD24_(34572),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,71241,27800,chemokine receptor transport out of membrane r...
3904607,3904607,42264,57675,negative regulation of cytoskeleton organizati...,IQCJ-SCHIP1_(57675),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,35166,49927,negative regulation of cytoskeleton organizati...
3904608,3904608,109904,58770,mesendoderm migration_(109904),APELA_(58770),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,67928,50777,mesendoderm migration (biological_process) has...


[13814,
 13815,
 13816,
 13817,
 13818,
 13819,
 13820,
 13821,
 13822,
 13823,
 13824,
 13825,
 13826,
 13827,
 13828,
 13829,
 13830,
 13831,
 13832,
 13833,
 13834,
 13835,
 13836,
 13837,
 13838,
 13839,
 13840,
 13841,
 13842,
 13843,
 13844,
 13845,
 13846,
 13847,
 13848,
 13849,
 13850,
 13851,
 13852,
 13853,
 13854,
 13855,
 13856,
 13857,
 13858,
 13859,
 13860,
 13861,
 13862,
 13863,
 13864,
 13865,
 13866,
 13867,
 13868,
 13869,
 13870,
 13871,
 13872,
 13873,
 13874,
 13875,
 13876,
 13877,
 13878,
 13879,
 13880,
 13881,
 13882,
 13883,
 13884,
 13885,
 13886,
 13887,
 13888,
 13889,
 13890,
 13891,
 13892,
 13893,
 13894,
 13895,
 13896,
 13897,
 13898,
 13899,
 13900,
 13901,
 13902,
 13903,
 13904,
 13905,
 13906,
 13907,
 13908,
 13909,
 13910,
 13911,
 13912,
 13913,
 13914,
 13915,
 13916,
 13917,
 13918,
 13919,
 13920,
 13921,
 13922,
 13923,
 13924,
 13925,
 13926,
 13927,
 13928,
 13929,
 13930,
 13931,
 13932,
 13933,
 13934,
 13935,
 13936,
 13937,
 13938,


In [30]:
len(graph['nodes']['embedding'][mask])

18797

In [14]:
mask = graph['nodes']['enrichment'].node_type == 'gene/protein'
graph['nodes']['enrichment'][mask]

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat
10770,0,0,PHYHIP_(0),PHYHIP,gene/protein,PHYHIP belongs to gene/protein node. PHYHIP is...,MELLSTPHSIEINNITCDSFRISWAMEDSDLERVTHYFIDLNKKEN...
10771,1,1,GPANK1_(1),GPANK1,gene/protein,GPANK1 belongs to gene/protein node. GPANK1 is...,MSRPLLITFTPATDPSDLWKDGQQQPQPEKPESTLDGAAARAFYEA...
10772,2,2,ZRSR2_(2),ZRSR2,gene/protein,ZRSR2 belongs to gene/protein node. ZRSR2 is z...,MAAPEKMTFPEKPSHKKYRAALKKEKRKKRRQELARLRDSGLSQKE...
10773,3,3,NRF1_(3),NRF1,gene/protein,NRF1 belongs to gene/protein node. NRF1 is nuc...,MEEHGVTQTEHMATIEAHAVAQQVQQVHVATYTEHSMLSADEDSPS...
10774,4,4,PI4KA_(4),PI4KA,gene/protein,PI4KA belongs to gene/protein node. PI4KA is p...,MAAAPARGGGGGGGGGGGCSGSGSSASRGFYFNTVLSLARSLAVQR...
...,...,...,...,...,...,...,...
29562,52445,83734,MAB21L4_(83734),MAB21L4,gene/protein,MAB21L4 belongs to gene/protein node. MAB21L4 ...,MPAPALPTSAMAVQVPLWHHYLQAIRSREAPRAQDFQRAENVLLTV...
29563,52446,83735,PRR23D2_(83735),PRR23D2,gene/protein,PRR23D2 belongs to gene/protein node. PRR23D2 ...,MYGYRRLRSPRDSQTEPQNDNEGETSLATTQMNPPKRRQVEQGPST...
29564,52447,83740,C8orf86_(83740),C8orf86,gene/protein,C8orf86 belongs to gene/protein node.,MRPLGKGLLPAEELIRSNLGVGRSLRDCLSQSGKLAEELGSKRLKP...
29565,52448,83746,CRACDL_(83746),CRACDL,gene/protein,CRACDL belongs to gene/protein node. CRACDL is...,MISTRVMDIKLREAAEGLGEDSTGKKKSKFKTFKKFFGKKKRKESP...


In [20]:
mask

0        False
1        False
2        False
3        False
4        False
         ...  
84976    False
84977    False
84978    False
84979    False
84980    False
Name: node_type, Length: 84981, dtype: bool

In [25]:
graph['nodes']['embedding'][mask].shape[0]

18797

In [7]:
nodes_enrichment_df.shape, nodes_embedding_df.shape

((84981, 7), (84981, 3))

In [8]:
edges_enrichment_df.shape, len(edges_embedding_df)

((3904610, 11), 16)

In [9]:
e_prompt = cp.array(edges_embedding_df[0].loc[0, "edge_emb"])
e_prompt

array([-0.02087522, -0.00298006, -0.00025966, ..., -0.00838831,
       -0.0213233 , -0.02149462])

In [9]:
e_prompt.shape

(1536,)

In [27]:
from cuvs.distance import pairwise_distance
def _compute_sim_scores(features_a: cp.ndarray,
                        features_b:  cp.ndarray,
                        metric: str="cosine"):
    """
    Compute the similarity scores between two sets of features using the specified metric.

    Args:
        features_a: The first set of features.
        features_b: The second set of features.
        metric: The metric to use for computing the similarity scores.

    Returns:
        The similarity scores between the two sets of features.
    """
    scores = pairwise_distance(features_a, features_b, metric=metric)
    scores = 1 - cp.asarray(scores).ravel()
    return scores

In [14]:
nodes_enrichment_df.columns, nodes_embedding_df.columns, edges_enrichment_df.columns

(Index(['node_index', 'primekg_node_index', 'node_id', 'node_name', 'node_type',
        'desc', 'feat'],
       dtype='object'),
 Index(['node_id', 'desc_emb', 'feat_emb'], dtype='object'),
 Index(['triplet_index', 'primekg_head_index', 'primekg_tail_index', 'head_id',
        'tail_id', 'display_relation', 'edge_type', 'edge_type_str',
        'head_index', 'tail_index', 'feat'],
       dtype='object'))

In [11]:
edges_enrichment_df.columns

Index(['triplet_index', 'primekg_head_index', 'primekg_tail_index', 'head_id',
       'tail_id', 'display_relation', 'edge_type', 'edge_type_str',
       'head_index', 'tail_index', 'feat'],
      dtype='object')

In [12]:
edges_embedding_df[0]

Unnamed: 0,triplet_index,edge_emb
0,3050000,"[-0.02087522, -0.002980056, -0.00025966353, -0..."
1,3050001,"[-0.01760313, 0.014069217, 0.019861642, -0.011..."
2,3050002,"[-0.019058425, 0.013303685, 0.019948881, -0.00..."
3,3050003,"[-0.019800434, 0.010459852, 0.014790365, -0.00..."
4,3050004,"[-0.022970848, 0.012780121, 0.0066903764, -0.0..."
...,...,...
249995,1599995,"[-0.0012519549, -0.0014191568, 0.022983257, -0..."
249996,1599996,"[-0.011657391, -0.0136824455, 0.017119711, -0...."
249997,1599997,"[-0.010922454, -0.0057538515, 0.010422692, -0...."
249998,1599998,"[-0.0056640725, -0.00014858306, 0.016633065, -..."


In [12]:
# Define the data path
DATA_PATH = "../../../aiagents4pharma/talk2knowledgegraphs/tests/files"

# Define the agent state
state = {
    "selections": {
        "gene/protein": [],
        "molecular_function": [],
        "cellular_component": [],
        "biological_process": [],
        "drug": [],
        "disease": []
    },
    "uploaded_files": [],
    "topk_nodes": 3,
    "topk_edges": 3,
    "dic_source_graph": [
        {
            "name": "BioBridge",
            "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
            "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
        }
    ],
}

# Update state
state["uploaded_files"] = [
    {
        "file_name": "multimodal-analysis_single_gene.xlsx",
        "file_path": f"{DATA_PATH}/multimodal-analysis_single_gene.xlsx",
        "file_type": "multimodal",
        "uploaded_by": "VPEUser",
        "uploaded_timestamp": "2025-05-12 00:00:00",
    }
]

In [13]:
import pandas as pd

multimodal_df = cudf.DataFrame({"name": [], "node_type": []})
query_df = cudf.DataFrame({"node_id": [],
                            "node_type": [],
                            "feat_emb": [],
                            "desc_emb": [],
                            "use_description": []})

# Loop over the uploaded files and find multimodal files
# logger.log(logging.INFO, "Looping over uploaded files")
for i in range(len(state["uploaded_files"])):
    # Check if multimodal file is uploaded
    if state["uploaded_files"][i]["file_type"] == "multimodal":
        # Read the Excel file
        multimodal_df = pd.read_excel(state["uploaded_files"][i]["file_path"],
                                        sheet_name=None)

In [14]:
# Check if the multimodal_df is empty
# logger.log(logging.INFO, "Checking if multimodal_df is empty")
if len(multimodal_df) > 0:
    # Prepare multimodal_df
    # logger.log(logging.INFO, "Preparing multimodal_df")
    # Merge all obtained dataframes into a single dataframe
    multimodal_df = cudf.from_pandas(pd.concat(multimodal_df).reset_index())
    multimodal_df.drop(columns=["level_1"], inplace=True)
    multimodal_df.rename(columns={"level_0": "q_node_type",
                                    "name": "q_node_name"}, inplace=True)
    # Since an excel sheet name could not contain a `/`,
    # but the node type can be 'gene/protein' as exists in the PrimeKG
    multimodal_df["q_node_type"] = multimodal_df["q_node_type"].str.replace('-', '/')

    # Make and process a query dataframe by merging the graph_df and multimodal_df
    # Merging with enrichment dataframe
    # logger.log(logging.INFO, "Merging with enrichment dataframe")
    query_df = graph['nodes']['enrichment'][
        ['node_id', 'node_name', 'node_type', 'feat', 'desc']
    ].merge(multimodal_df, how='cross')
    # Post-process the query dataframe
    # logger.log(logging.INFO, "Lowering case for node names (q_node_name)")
    query_df['q_node_name'] = query_df['q_node_name'].str.lower()
    # logger.log(logging.INFO, "Lowering case for node names (node_name)")
    query_df['node_name'] = query_df['node_name'].str.lower()
    # Get the mask for filtering based on the query
    # logger.log(logging.INFO, "Filtering based on the query")
    mask = (
        query_df['node_name'].str.contains(query_df['q_node_name']) &
        (query_df['node_type'] == query_df['q_node_type'])
    )
    query_df = query_df[mask]
    # Merging the final query dataframe with embedding dataframe
    # logger.log(logging.INFO, "Merging with embedding dataframe")
    query_df = query_df.merge(
        graph['nodes']['embedding'][['node_id', 'feat_emb', 'desc_emb']],
        how='left',
        on='node_id')
    # Post-process the query dataframe
    query_df = query_df[['node_id',
                            'node_type',
                            'feat',
                            'feat_emb',
                            'desc',
                            'desc_emb']].reset_index(drop=True)
    query_df['use_description'] = False # set to False for modal-specific embeddings

In [25]:
prompt = {}
prompt["text"] = "What is the function of the gene/protein BRCA1?"
prompt["emb"] = [emb_model.embed_query(prompt["text"])]
# emb_ = emb_model.embed_query("Hello world")
# cp.asarray(emb_).reshape(1, -1).astype('float32')

In [20]:
query_df.columns

Index(['node_id', 'node_type', 'feat', 'feat_emb', 'desc', 'desc_emb',
       'use_description'],
      dtype='object')

In [22]:
query_df

Unnamed: 0,node_id,node_type,feat,feat_emb,desc,desc_emb,use_description
0,RPS8_(11),gene/protein,MGISRDNWHKRRKTGGKRKPYHKKRKYELGRPAANTKIGPRRIHTV...,"[0.05257261171936989, -0.06346512585878372, -0...",RPS8 belongs to gene/protein node. RPS8 is rib...,"[-0.017070618, -0.0182947, -0.017336722, -0.01...",False
1,FOS_(22),gene/protein,MMFSGFNADYEASSSRCSSASPAGDSLSYYHSPADSFSSMGSPVNA...,"[-0.02010018192231655, -0.0288854893296957, 0....",FOS belongs to gene/protein node. FOS is Fos p...,"[-0.01809412, -0.011164032, -0.021542521, -0.0...",False
2,CALM1_(61),gene/protein,MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTE...,"[0.015500295907258987, 0.12202692776918411, -0...",CALM1 belongs to gene/protein node. CALM1 is c...,"[-0.023683596, -0.008636789, -0.020697404, -0....",False
3,SETD7_(103),gene/protein,MDSDDEMVEEAVEGHLDDDGLPHGFCTVTYSSTDRFEGNFVHGEKN...,"[-0.09826023876667023, -0.06921421736478806, 0...",SETD7 belongs to gene/protein node. SETD7 is S...,"[-0.028181268, -0.0034074313, -0.021452002, -0...",False
4,HNRNPUL1_(159),gene/protein,MDVRRLKVNELREELQRRGLDTRGLKAELAERLQAALEAEEPDDER...,"[-0.07706909626722336, 0.004895209334790707, -...",HNRNPUL1 belongs to gene/protein node. HNRNPUL...,"[-0.027910667, -0.0030972639, -0.020933, -0.01...",False
...,...,...,...,...,...,...,...
406,PRPSAP2_(5003),gene/protein,MFCVTPPELETKMNITKGGLVLFSANSNSSCMELSKKIAERLGVEM...,"[-0.07528620213270187, 0.011673472821712494, 0...",PRPSAP2 belongs to gene/protein node. PRPSAP2 ...,"[-0.007135198, -0.0068572033, 0.0063343085, -0...",False
407,ANXA5_(5004),gene/protein,MAQVLRGTVTDFPGFDERADAETLRKAMKGLGTDEESILTLLTSRS...,"[0.0765886902809143, 0.10536942631006241, 0.02...",ANXA5 belongs to gene/protein node. ANXA5 is a...,"[-0.037901636, 0.0032147528, -0.007965701, -0....",False
408,LTBP4_(5024),gene/protein,MPRPGTSGRRPLLLVLLLPLFAAATSAASPSPSPSQVVEVPGVPSR...,"[-0.018929054960608482, -0.058839354664087296,...",LTBP4 belongs to gene/protein node. LTBP4 is l...,"[-0.03300405, -0.008081024, -0.00728992, -0.03...",False
409,GYPC_(5074),gene/protein,MWSTRSPNSTAWPLSLEPDPGMASASTTMHTTTIAEPDPGMSGWPD...,"[0.0609404481947422, -0.03754159063100815, -0....",GYPC belongs to gene/protein node. GYPC is gly...,"[-0.002514031, 0.008406238, 0.017207902, -0.03...",False


In [26]:
cudf.DataFrame({
        'node_id': 'user_prompt',
        'node_type': 'prompt',
        'feat': prompt["text"],
        'feat_emb': prompt["emb"],
        'desc': prompt["text"],
        'desc_emb': prompt["emb"],
        'use_description': True # set to True for user prompt embedding
    })

Unnamed: 0,node_id,node_type,feat,feat_emb,desc,desc_emb,use_description
0,user_prompt,prompt,What is the function of the gene/protein BRCA1?,"[-0.028648853302001953, -0.0061271535232663155...",What is the function of the gene/protein BRCA1?,"[-0.028648853302001953, -0.0061271535232663155...",True


In [27]:
query_df = cudf.concat([
    query_df,
    cudf.DataFrame({
        'node_id': 'user_prompt',
        'node_type': 'prompt',
        'feat': prompt["text"],
        'feat_emb': prompt["emb"],
        'desc': prompt["text"],
        'desc_emb': prompt["emb"],
        'use_description': True # set to True for user prompt embedding
    })
]).reset_index(drop=True)
query_df

Unnamed: 0,node_id,node_type,feat,feat_emb,desc,desc_emb,use_description
0,RPS8_(11),gene/protein,MGISRDNWHKRRKTGGKRKPYHKKRKYELGRPAANTKIGPRRIHTV...,"[0.05257261171936989, -0.06346512585878372, -0...",RPS8 belongs to gene/protein node. RPS8 is rib...,"[-0.017070618, -0.0182947, -0.017336722, -0.01...",False
1,FOS_(22),gene/protein,MMFSGFNADYEASSSRCSSASPAGDSLSYYHSPADSFSSMGSPVNA...,"[-0.02010018192231655, -0.0288854893296957, 0....",FOS belongs to gene/protein node. FOS is Fos p...,"[-0.01809412, -0.011164032, -0.021542521, -0.0...",False
2,CALM1_(61),gene/protein,MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTE...,"[0.015500295907258987, 0.12202692776918411, -0...",CALM1 belongs to gene/protein node. CALM1 is c...,"[-0.023683596, -0.008636789, -0.020697404, -0....",False
3,SETD7_(103),gene/protein,MDSDDEMVEEAVEGHLDDDGLPHGFCTVTYSSTDRFEGNFVHGEKN...,"[-0.09826023876667023, -0.06921421736478806, 0...",SETD7 belongs to gene/protein node. SETD7 is S...,"[-0.028181268, -0.0034074313, -0.021452002, -0...",False
4,HNRNPUL1_(159),gene/protein,MDVRRLKVNELREELQRRGLDTRGLKAELAERLQAALEAEEPDDER...,"[-0.07706909626722336, 0.004895209334790707, -...",HNRNPUL1 belongs to gene/protein node. HNRNPUL...,"[-0.027910667, -0.0030972639, -0.020933, -0.01...",False
...,...,...,...,...,...,...,...
407,ANXA5_(5004),gene/protein,MAQVLRGTVTDFPGFDERADAETLRKAMKGLGTDEESILTLLTSRS...,"[0.0765886902809143, 0.10536942631006241, 0.02...",ANXA5 belongs to gene/protein node. ANXA5 is a...,"[-0.037901636, 0.0032147528, -0.007965701, -0....",False
408,LTBP4_(5024),gene/protein,MPRPGTSGRRPLLLVLLLPLFAAATSAASPSPSPSQVVEVPGVPSR...,"[-0.018929054960608482, -0.058839354664087296,...",LTBP4 belongs to gene/protein node. LTBP4 is l...,"[-0.03300405, -0.008081024, -0.00728992, -0.03...",False
409,GYPC_(5074),gene/protein,MWSTRSPNSTAWPLSLEPDPGMASASTTMHTTTIAEPDPGMSGWPD...,"[0.0609404481947422, -0.03754159063100815, -0....",GYPC belongs to gene/protein node. GYPC is gly...,"[-0.002514031, 0.008406238, 0.017207902, -0.03...",False
410,CD74_(4694),gene/protein,MHRRRSRSCREDQKPVMDDQRDLISNNEQLPMLGRRPGAPESKCSR...,"[0.04463711008429527, -0.0819774940609932, 0.0...",CD74 belongs to gene/protein node. CD74 is CD7...,"[-0.023425357, -0.007765312, -0.015478855, -0....",False


In [36]:
for i in query_df['feat_emb'].to_pandas().tolist():
    print(len(i))

2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560
2560


In [33]:
query_emb = cp.asarray(graph['nodes']['embedding']['desc_emb'][0]).reshape(1, -1).astype(cp.float32)
# ===
sim_scores = cudf.Series(cp.zeros(len(graph['nodes']['embedding']), dtype=cp.float32))
sim_scores[:] = _compute_sim_scores(
                graph['nodes']['embedding']['desc_emb'].list.leaves.to_cupy().reshape(
                    -1, len(graph['nodes']['embedding']["desc_emb"][0])
                ).astype(cp.float32),
                query_emb
            )  # shape [N, 1]

In [37]:
graph['nodes']['embedding']

Unnamed: 0,node_id,desc_emb,feat_emb
0,Copper_(14012),"[-0.008437113, -0.0066760327, 0.017289385, -0....","[1.9533191919326782, 0.6867725253105164, 0.456..."
1,Oxygen_(14013),"[0.0011131193, 0.00035166115, -0.011181605, -0...","[0.21793906390666962, 0.13471993803977966, 0.0..."
2,Flunisolide_(14014),"[-0.034678344, 0.0014620015, -0.00069094024, -...","[-0.12004103511571884, -0.36696380376815796, -..."
3,Alclometasone_(14015),"[-0.00836069, 0.0018004619, 0.0005102673, -0.0...","[-0.24334891140460968, -0.6172693967819214, 0...."
4,Medrysone_(14016),"[-0.022494577, 0.00894156, -0.004813894, -0.02...","[-0.1530350148677826, -0.33404994010925293, -0..."
...,...,...,...
84976,structural constituent of tooth enamel_(124215),"[0.0027853793, 0.010976661, -0.0056875315, -0....","[-0.011537055, 0.008254736, 0.01594828, -0.017..."
84977,D-glucuronate transmembrane transporter activi...,"[-0.0019518828, -0.024248326, 0.006703199, -0....","[-0.011491481, -0.013787123, 0.016679898, -0.0..."
84978,histone propionyltransferase activity_(124217),"[-0.010533179, -0.032219134, -0.0032646793, -0...","[-0.0020738854, -0.010127607, 0.021126406, -0...."
84979,sirohydrochlorin ferrochelatase activity_(124219),"[-0.0029098308, -0.0076754037, -0.0111139845, ...","[-0.005530768, 0.011203765, -0.0023338485, -0...."


In [39]:
mask = graph['nodes']['enrichment'].node_type == 'gene/protein'
graph['nodes']['enrichment'][mask].head(5)

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat
10770,0,0,PHYHIP_(0),PHYHIP,gene/protein,PHYHIP belongs to gene/protein node. PHYHIP is...,MELLSTPHSIEINNITCDSFRISWAMEDSDLERVTHYFIDLNKKEN...
10771,1,1,GPANK1_(1),GPANK1,gene/protein,GPANK1 belongs to gene/protein node. GPANK1 is...,MSRPLLITFTPATDPSDLWKDGQQQPQPEKPESTLDGAAARAFYEA...
10772,2,2,ZRSR2_(2),ZRSR2,gene/protein,ZRSR2 belongs to gene/protein node. ZRSR2 is z...,MAAPEKMTFPEKPSHKKYRAALKKEKRKKRRQELARLRDSGLSQKE...
10773,3,3,NRF1_(3),NRF1,gene/protein,NRF1 belongs to gene/protein node. NRF1 is nuc...,MEEHGVTQTEHMATIEAHAVAQQVQQVHVATYTEHSMLSADEDSPS...
10774,4,4,PI4KA_(4),PI4KA,gene/protein,PI4KA belongs to gene/protein node. PI4KA is p...,MAAAPARGGGGGGGGGGGCSGSGSSASRGFYFNTVLSLARSLAVQR...


In [40]:
graph['nodes']['embedding'][mask].head(5)

Unnamed: 0,node_id,desc_emb,feat_emb
10770,PHYHIP_(0),"[-0.038923346, -0.022871112, -0.012125405, -0....","[0.04029838368296623, -0.018344514071941376, 0..."
10771,GPANK1_(1),"[-0.025375651, 0.012858219, 0.008264126, -0.00...","[-0.049913737922906876, -0.04380067065358162, ..."
10772,ZRSR2_(2),"[-0.032085866, 0.0071205534, -0.017097335, -0....","[0.035360466688871384, -0.09613325446844101, 0..."
10773,NRF1_(3),"[-0.030888347, -0.024794728, -0.020263912, -0....","[-0.052261918783187866, -0.022747397422790527,..."
10774,PI4KA_(4),"[-0.029845022, -0.023542346, -0.01622012, -0.0...","[0.005174526944756508, -0.049968406558036804, ..."


In [34]:
sim_scores

0        0.999999
1        0.804206
2        0.762189
3        0.776730
4        0.762661
           ...   
84976    0.777796
84977    0.809125
84978    0.777640
84979    0.797852
84980    0.779340
Length: 84981, dtype: float32

In [21]:
graph['nodes']['embedding']['desc_emb'].list.leaves.to_cupy()

array([-0.00843711, -0.00667603,  0.01728939, ..., -0.02651415,
       -0.02254812, -0.04615015])

In [None]:
graph['edges']['embedding'][0]['edge_emb'].list.leaves.to_cupy()

(384000000,)

In [None]:
graph_nodes['embedding']['desc_emb'].list.leaves.to_cupy()

In [42]:
sim_dict = {}
sim_dict['index'] = []
sim_dict['score'] = []
for i, cur_edges_embedding_df in enumerate(edges_embedding_df):
    print(f"Processing edges chunk {i+1}/{len(edges_embedding_df)}")
    # Convert the edge embedding to a cupy array
    cur_triplet_index = cur_edges_embedding_df['triplet_index'].to_cupy()
    # cur_arr = cp.asarray(cur_edges_embedding_df["edge_emb"].explode()).reshape(cur_edges_embedding_df.shape[0], -1)
    cur_arr = cur_edges_embedding_df['edge_emb'].list.leaves.to_cupy().reshape(cur_edges_embedding_df.shape[0], -1).astype(cp.float32)

    # Calculate the similarity score
    # sim_score.append(cp.dot(cur_arr, e_prompt) / (cp.linalg.norm(cur_arr, axis=1) * cp.linalg.norm(e_prompt)))
    sim_dict['score'].append(_compute_sim_scores(cur_arr, e_prompt.reshape(1, -1).astype(cp.float32)))

    # Append the triplet indices
    sim_dict['index'].append(cur_triplet_index)

# Create a DataFrame with the results
sim_scores_df = cudf.DataFrame({
    'triplet_index': cp.concatenate(sim_dict['index']),
    'sim_score': cp.concatenate(sim_dict['score'])
})
sim_scores_df = sim_scores_df.sort_values("triplet_index").reset_index(drop=True)

Processing edges chunk 1/16
Processing edges chunk 2/16
Processing edges chunk 3/16
Processing edges chunk 4/16
Processing edges chunk 5/16
Processing edges chunk 6/16
Processing edges chunk 7/16
Processing edges chunk 8/16
Processing edges chunk 9/16
Processing edges chunk 10/16
Processing edges chunk 11/16
Processing edges chunk 12/16
Processing edges chunk 13/16
Processing edges chunk 14/16
Processing edges chunk 15/16
Processing edges chunk 16/16


In [54]:
graph["nodes"]["embedding"]

Unnamed: 0,node_id,desc_emb,feat_emb
0,Copper_(14012),"[-0.008437113, -0.0066760327, 0.017289385, -0....","[1.9533191919326782, 0.6867725253105164, 0.456..."
1,Oxygen_(14013),"[0.0011131193, 0.00035166115, -0.011181605, -0...","[0.21793906390666962, 0.13471993803977966, 0.0..."
2,Flunisolide_(14014),"[-0.034678344, 0.0014620015, -0.00069094024, -...","[-0.12004103511571884, -0.36696380376815796, -..."
3,Alclometasone_(14015),"[-0.00836069, 0.0018004619, 0.0005102673, -0.0...","[-0.24334891140460968, -0.6172693967819214, 0...."
4,Medrysone_(14016),"[-0.022494577, 0.00894156, -0.004813894, -0.02...","[-0.1530350148677826, -0.33404994010925293, -0..."
...,...,...,...
84976,structural constituent of tooth enamel_(124215),"[0.0027853793, 0.010976661, -0.0056875315, -0....","[-0.011537055, 0.008254736, 0.01594828, -0.017..."
84977,D-glucuronate transmembrane transporter activi...,"[-0.0019518828, -0.024248326, 0.006703199, -0....","[-0.011491481, -0.013787123, 0.016679898, -0.0..."
84978,histone propionyltransferase activity_(124217),"[-0.010533179, -0.032219134, -0.0032646793, -0...","[-0.0020738854, -0.010127607, 0.021126406, -0...."
84979,sirohydrochlorin ferrochelatase activity_(124219),"[-0.0029098308, -0.0076754037, -0.0111139845, ...","[-0.005530768, 0.011203765, -0.0023338485, -0...."


In [53]:
graph["nodes"]["enrichment"].sort_values("node_index", ignore_index=True)

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat
0,0,0,PHYHIP_(0),PHYHIP,gene/protein,PHYHIP belongs to gene/protein node. PHYHIP is...,MELLSTPHSIEINNITCDSFRISWAMEDSDLERVTHYFIDLNKKEN...
1,1,1,GPANK1_(1),GPANK1,gene/protein,GPANK1 belongs to gene/protein node. GPANK1 is...,MSRPLLITFTPATDPSDLWKDGQQQPQPEKPESTLDGAAARAFYEA...
2,2,2,ZRSR2_(2),ZRSR2,gene/protein,ZRSR2 belongs to gene/protein node. ZRSR2 is z...,MAAPEKMTFPEKPSHKKYRAALKKEKRKKRRQELARLRDSGLSQKE...
3,3,3,NRF1_(3),NRF1,gene/protein,NRF1 belongs to gene/protein node. NRF1 is nuc...,MEEHGVTQTEHMATIEAHAVAQQVQQVHVATYTEHSMLSADEDSPS...
4,4,4,PI4KA_(4),PI4KA,gene/protein,PI4KA belongs to gene/protein node. PI4KA is p...,MAAAPARGGGGGGGGGGGCSGSGSSASRGFYFNTVLSLARSLAVQR...
...,...,...,...,...,...,...,...
84976,84976,127430,host cell rough endoplasmic reticulum membrane...,host cell rough endoplasmic reticulum membrane,cellular_component,host cell rough endoplasmic reticulum membrane...,The lipid bilayer surrounding the host cell ro...
84977,84977,127431,collagen type VII anchoring fibril_(127431),collagen type VII anchoring fibril,cellular_component,collagen type VII anchoring fibril belongs to ...,An antiparallel dimer of two collagen VII trim...
84978,84978,127432,cofilin-actin rod_(127432),cofilin-actin rod,cellular_component,cofilin-actin rod belongs to cellular_componen...,"A cellular structure consisting of parallel, h..."
84979,84979,127433,"condensed chromosome, centromeric region_(127433)","condensed chromosome, centromeric region",cellular_component,"condensed chromosome, centromeric region belon...",The region of a condensed chromosome that incl...


In [49]:
graph["edges"]["enrichment"].sort_values("triplet_index", ignore_index=True)

Unnamed: 0,triplet_index,primekg_head_index,primekg_tail_index,head_id,tail_id,display_relation,edge_type,edge_type_str,head_index,tail_index,feat
0,0,0,8889,PHYHIP_(0),KIF15_(8889),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,0,8816,PHYHIP (gene/protein) has a direct relationshi...
1,1,1,2798,GPANK1_(1),PNMA1_(2798),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,1,2787,GPANK1 (gene/protein) has a direct relationshi...
2,2,2,5646,ZRSR2_(2),TTC33_(5646),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,2,5610,ZRSR2 (gene/protein) has a direct relationship...
3,3,3,11592,NRF1_(3),MAN1B1_(11592),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,3,11467,NRF1 (gene/protein) has a direct relationship ...
4,4,4,2122,PI4KA_(4),RGS20_(2122),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,4,2117,PI4KA (gene/protein) has a direct relationship...
...,...,...,...,...,...,...,...,...,...,...,...
3904605,3904605,52855,34572,B cell receptor transport into membrane raft_(...,CD24_(34572),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,45323,27800,B cell receptor transport into membrane raft (...
3904606,3904606,113352,34572,chemokine receptor transport out of membrane r...,CD24_(34572),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,71241,27800,chemokine receptor transport out of membrane r...
3904607,3904607,42264,57675,negative regulation of cytoskeleton organizati...,IQCJ-SCHIP1_(57675),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,35166,49927,negative regulation of cytoskeleton organizati...
3904608,3904608,109904,58770,mesendoderm migration_(109904),APELA_(58770),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,67928,50777,mesendoderm migration (biological_process) has...


In [46]:
sim_dict['score'][0]

array([0.9999999, 0.9467669, 0.9474174, ..., 0.7786807, 0.7878412,
       0.7995058], dtype=float32)

In [43]:
sim_scores_df['sim_score']

0          0.842766
1          0.845283
2          0.853746
3          0.852537
4          0.841947
             ...   
3904605    0.866746
3904606    0.856712
3904607    0.854869
3904608    0.838710
3904609    0.869283
Name: sim_score, Length: 3904610, dtype: float32

In [14]:
sim_scores_df['sim_score']

0          0.842767
1          0.845284
2          0.853746
3          0.852537
4          0.841947
             ...   
3904605    0.866746
3904606    0.856713
3904607    0.854869
3904608    0.838710
3904609    0.869283
Name: sim_score, Length: 3904610, dtype: float64

In [84]:
cudf.Series(cp.zeros(3904610, dtype=cp.float32))

0          0.0
1          0.0
2          0.0
3          0.0
4          0.0
          ... 
3904605    0.0
3904606    0.0
3904607    0.0
3904608    0.0
3904609    0.0
Length: 3904610, dtype: float32

In [74]:
sim_scores_df

Unnamed: 0,triplet_index,sim_score
0,0,0.842767
1,1,0.845284
2,2,0.853746
3,3,0.852537
4,4,0.841947
...,...,...
3904605,3904605,0.866746
3904606,3904606,0.856713
3904607,3904607,0.854869
3904608,3904608,0.838710


In [65]:
triplet_indices

[array([3050000, 3050001, 3050002, ..., 3099997, 3099998, 3099999]),
 array([2250000, 2250001, 2250002, ..., 2299997, 2299998, 2299999]),
 array([1150000, 1150001, 1150002, ..., 1199997, 1199998, 1199999]),
 array([600000, 600001, 600002, ..., 649997, 649998, 649999]),
 array([1550000, 1550001, 1550002, ..., 1599997, 1599998, 1599999]),
 array([1750000, 1750001, 1750002, ..., 1799997, 1799998, 1799999]),
 array([3100000, 3100001, 3100002, ..., 3149997, 3149998, 3149999]),
 array([1600000, 1600001, 1600002, ..., 1649997, 1649998, 1649999]),
 array([2750000, 2750001, 2750002, ..., 2799997, 2799998, 2799999]),
 array([2950000, 2950001, 2950002, ..., 2999997, 2999998, 2999999]),
 array([3700000, 3700001, 3700002, ..., 3749997, 3749998, 3749999]),
 array([500000, 500001, 500002, ..., 549997, 549998, 549999]),
 array([800000, 800001, 800002, ..., 849997, 849998, 849999]),
 array([1050000, 1050001, 1050002, ..., 1099997, 1099998, 1099999]),
 array([300000, 300001, 300002, ..., 349997, 349998,

In [59]:
a = cp.concatenate(sim_score)
a[cp.argsort(-a)[:3]]

array([1.        , 0.9927882 , 0.99180061])

In [43]:
cur_arr

array([[-0.02087522, -0.00298006, -0.00025966, ..., -0.00838831,
        -0.0213233 , -0.02149462],
       [-0.01760313,  0.01406922,  0.01986164, ..., -0.01814783,
        -0.02109718, -0.02145589],
       [-0.01905843,  0.01330369,  0.01994888, ..., -0.01749016,
        -0.02174308, -0.0217165 ],
       ...,
       [-0.03222188, -0.00610927,  0.00940235, ...,  0.00802325,
        -0.010137  , -0.03085568],
       [-0.02284207, -0.00256165,  0.01529403, ...,  0.00218062,
        -0.00395876, -0.01533361],
       [-0.01814055, -0.01461722,  0.010439  , ..., -0.00747888,
        -0.00141784, -0.02034099]])

In [24]:
cp.array(edges_embedding_df[0]["edge_emb"].to_arrow().to_pylist())

array([[-0.02087522, -0.00298006, -0.00025966, ..., -0.00838831,
        -0.0213233 , -0.02149462],
       [-0.01760313,  0.01406922,  0.01986164, ..., -0.01814783,
        -0.02109718, -0.02145589],
       [-0.01905843,  0.01330369,  0.01994888, ..., -0.01749016,
        -0.02174308, -0.0217165 ],
       ...,
       [-0.03222188, -0.00610927,  0.00940235, ...,  0.00802325,
        -0.010137  , -0.03085568],
       [-0.02284207, -0.00256165,  0.01529403, ...,  0.00218062,
        -0.00395876, -0.01533361],
       [-0.01814055, -0.01461722,  0.010439  , ..., -0.00747888,
        -0.00141784, -0.02034099]])

In [None]:
from torch_geometric.data import TensorAttr
import torch
from cugraph_pyg.data import GraphStore, TensorDictFeatureStore

In [None]:
# Initialize FeatureStore and mapper
graph_store = GraphStore()
feature_store = TensorDictFeatureStore()
mapper = {}


In [None]:
# Loop over group enrichment nodes by type
for nt, nodes_df in nodes_enrichment_df.groupby('node_type'):
    # print(f"Node type: {nt}")
    # node_count = len(nodes_df)
    # print(f"Number of nodes: {node_count}")

    # Get node_ids
    emb_df = nodes_embedding_df[nodes_embedding_df['node_id'].isin(nodes_df['node_id'])]

    # Temporarily fix mismatched embedding
    # Use desc_emb instead
    emb_df['feat_emb'] = emb_df['desc_emb']

    # Sort both by node_id for alignment
    nodes_df_sorted = nodes_df.sort_values('node_id')
    emb_df_sorted = emb_df.sort_values('node_id')

    # Ensure sorted node_ids match
    assert cudf.Series.equals(nodes_df_sorted['node_id'].reset_index(drop=True),
                              emb_df_sorted['node_id'].reset_index(drop=True)), \
                                f"Node ID mismatch in {nt} after sorting"

    # Get node_index as torch tensor directly
    node_index_tensor = torch.tensor(nodes_df_sorted["node_index"].to_numpy(),
                                     dtype=torch.int64)
    feature_store[TensorAttr(group_name=nt, attr_name="node_index")] = node_index_tensor

    # # Construct mapper for node_index
    node_index_list = nodes_df_sorted["node_index"].to_numpy().tolist()
    mapper[nt] = {
        'to_node_index': dict(enumerate(node_index_list)),
        'from_node_index': {v: i for i, v in enumerate(node_index_list)}
    }

    # Convert embeddings as tensors and add to FeatureStore
    for attr_name in ["desc_emb", "feat_emb"]:
        emb_tensor = torch.tensor(emb_df_sorted[attr_name].to_arrow().to_pylist(),
                                  dtype=torch.float32)
        feature_store[TensorAttr(group_name=nt, attr_name=attr_name)] = emb_tensor

In [None]:
# Loop over edge types
for edge_type_str in edges_enrichment_df['edge_type_str'].unique().to_arrow().to_pylist():
    if edge_type_str == 'gene/protein|ppi|gene/protein':
        # print(f"Processing edge type: {edge_type_str}")
        src_type, rel_type, tgt_type = edge_type_str.split('|')

        # Filter edges for this edge_type_str once
        filtered_df = edges_enrichment_df[
            edges_enrichment_df['edge_type_str'] == edge_type_str
        ][['triplet_index', 'head_index', 'tail_index']]

        # Convert mapper dicts to cudf Series for vectorized mapping
        src_map = cudf.Series(mapper[src_type]['from_node_index'])
        tgt_map = cudf.Series(mapper[tgt_type]['from_node_index'])

        # Vectorized mapping of head_index and tail_index using replace (works like dict lookup)
        mapped_head = filtered_df['head_index'].replace(src_map).astype('int64')
        mapped_tail = filtered_df['tail_index'].replace(tgt_map).astype('int64')

        # Check if mapping was successful
        if mapped_head.isnull().any() or mapped_tail.isnull().any():
            raise ValueError(f"Mapping failure for edge type {edge_type_str}")

        # Edge index
        edge_index = torch.tensor(
            cudf.concat([mapped_head, mapped_tail], axis=1).to_pandas().values.T,
            dtype=torch.long
        ).contiguous()

        # Store edge index in the GraphStore
        graph_store[(src_type, rel_type, tgt_type), "coo"] = edge_index

        # Add triplet index to the FeatureStore
        triplet_index_tensor = torch.tensor(filtered_df['triplet_index'].to_numpy(),
                                            dtype=torch.long).unsqueeze(0)
        feature_store[TensorAttr(group_name=(src_type, rel_type, tgt_type),
                                 attr_name='triplet_index')] = triplet_index_tensor

        # Store edge embeddings in the FeatureStore
        # edge_emb_tensor = []
        # for i, edge_df in edges_embedding_df:
        #     edge_emb_df = edges_embedding_df[edges_embedding_df['edge_type_str'] == edge_type_str]

        #     # Convert edge embeddings to torch tensor
        #     edge_emb_tensor.append(torch.tensor(edge_emb_df['edge_emb'].to_arrow().to_pylist(),
        #                                         dtype=torch.float32).unsqueeze(0))
        # feature_store[TensorAttr(group_name=(src_type, rel_type, tgt_type),
        #                         attr_name='edge_emb')] = edge_emb_tensor

In [None]:
# Merge nodes embedding into nodes enrichment
graph_dict["nodes"] = graph_dict["nodes"]["enrichment"].merge(
    graph_dict["nodes"]["embedding"],
    how="left",
    on="node_id"
)

In [None]:
# Check head
graph_dict["nodes"].head(5)

In [None]:
# graph_dict["edges"]["enrichment"] = graph_dict["edges"]["enrichment"].merge(
#     cudf.DataFrame({
#         "triplet_index": graph_dict["edges"]["enrichment"].triplet_index,
#         "feat_emb": None
#     }),
#     how="left",
#     on="triplet_index"
# )
graph_dict["edges"]["enrichment"].drop(columns=["feat_emb_x", "feat_emb_y", "feat_emb"], inplace=True)
graph_dict["edges"]["enrichment"]

In [None]:
graph_dict["edges"]["embedding"][0].loc[0, "edge_emb"]

In [None]:
[[]] * graph_dict["edges"]["enrichment"].shape[0]

In [None]:
# Make a merged dataframe with additional column
graph_dict["edges"]["enrichment"] = graph_dict["edges"]["enrichment"].merge(
    cudf.DataFrame({
        "triplet_index": graph_dict["edges"]["enrichment"].triplet_index,
        "feat_emb": [[]] * graph_dict["edges"]["enrichment"].shape[0]
    }),
    how="left",
    on="triplet_index"
)
graph_dict["edges"]["enrichment"]

In [None]:
graph_dict["edges"]["enrichment"]['feat_emb'].list.len() == 0

In [None]:
graph_dict["edges"]["enrichment"]['feat_emb']

In [None]:
graph_dict["edges"]["enrichment"].drop(columns=["feat_emb_x", "feat_emb_y"], inplace=True)

In [None]:
# Make a merged dataframe with additional column
graph_dict["edges"]["enrichment"] = graph_dict["edges"]["enrichment"].merge(
    cudf.DataFrame({
        "triplet_index": graph_dict["edges"]["enrichment"].triplet_index,
        "feat_emb": [[]] * graph_dict["edges"]["enrichment"].shape[0]
    }),
    how="left",
    on="triplet_index"
)
# graph_dict["edges"]["enrichment"]

for i, emb_df in enumerate(graph_dict["edges"]["embedding"]):
    # Merge the embeddings into the feature column dataframe
    enrichment_df = graph_dict["edges"]["enrichment"].merge(
        emb_df,
        on="triplet_index",
        how="left",
    )

    # Create mask for rows where feat_emb is empty list and triplet_index exists in current emb_df
    mask = (
        enrichment_df['feat_emb'].list.len() == 0
    ) & (
        enrichment_df['triplet_index'].isin(emb_df['triplet_index'])
    )

    # Assign edge_emb to feat_emb only where mask is True
    enrichment_df.loc[mask, 'feat_emb'] = enrichment_df.loc[mask, 'edge_emb']

    # Drop the edge_emb column
    enrichment_df.drop(columns=['edge_emb'], inplace=True)

    # Update enrichment in graph_dict
    graph_dict["edges"]["enrichment"] = enrichment_df

    # Clean up
    del emb_df
    graph_dict["edges"]["embedding"][i] = None


# Merge the feature column with the edges enrichment dataframe
# graph_dict["edges"]["enrichment"] = graph_dict["edges"]["enrichment"].merge(
#     feat_col,
#     on="triplet_index",
#     how="left",
# )

# Store the edges enrichment
# graph_dict["edges"] = graph_dict["edges"]["enrichment"]

In [None]:
import numpy as np
np.arange(10) + 1250000

In [None]:
graph_dict["edges"]["enrichment"][graph_dict["edges"]["enrichment"].triplet_index.isin(np.arange(10) + 1350000)]

In [None]:
graph_dict["edges"]["enrichment"][graph_dict["edges"]["enrichment"].triplet_index.isin(np.arange(10) + 1350000)].iloc[0].feat_emb.to_arrow().to_pylist()[0]

In [None]:
file_list

In [None]:
# # Make a simple dataframe
# feat_col = cudf.DataFrame(
#     {
#         "triplet_index": graph_dict["edges"]["enrichment"].triplet_index,
#         "feat_emb": None
#     }
# )

# # Loop over a set of embeddings chunks
# for i, emb_df in enumerate(graph_dict["edges"]["embedding"]):
#     # Merge the embeddings into the feature column dataframe
#     feat_col = feat_col.merge(
#         emb_df,
#         on="triplet_index",
#         how="left",
#     )

#     # Fill missing embeddings with edge embeddings
#     mask = (feat_col['feat_emb'].isna()) & (feat_col['triplet_index'].isin(emb_df.triplet_index))
#     feat_col.loc[mask, 'feat_emb'] = feat_col.loc[mask, 'edge_emb']

#     # Drop the edge_emb column
#     feat_col = feat_col.drop(columns=['edge_emb'])

# # Merge the feature column with the edges enrichment dataframe
# graph_dict["edges"]["enrichment"] = graph_dict["edges"]["enrichment"].merge(
#     feat_col,
#     on="triplet_index",
#     how="left",
# )

# # Store the edges enrichment
# graph_dict["edges"] = graph_dict["edges"]["enrichment"]