In [1]:
import os
import glob
import cudf
import dask_cudf
import hydra

# import sys
# sys.path.append('../../..')
# from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama

In [3]:
# Load hydra configuration
with hydra.initialize(version_base=None, config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs"):
    cfg = hydra.compose(
        config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
    )
    cfg = cfg.tools.multimodal_subgraph_extraction
cfg

{'_target_': 'talk2knowledgegraphs.tools.multimodal_subgraph_extraction', 'ollama_embeddings': ['nomic-embed-text'], 'temperature': 0.1, 'streaming': False, 'topk': 5, 'topk_e': 5, 'cost_e': 0.5, 'c_const': 0.01, 'root': -1, 'num_clusters': 1, 'pruning': 'gw', 'verbosity_level': 0, 'node_id_column': 'node_id', 'node_attr_column': 'node_attr', 'edge_src_column': 'edge_src', 'edge_attr_column': 'edge_attr', 'edge_dst_column': 'edge_dst', 'node_colors_dict': {'gene/protein': '#6a79f7', 'molecular_function': '#82cafc', 'cellular_component': '#3f9b0b', 'biological_process': '#c5c9c7', 'drug': '#c4a661', 'disease': '#80013f'}, 'biobridge': {'source': 'aiagents4pharma/talk2knowledgegraphs/tests/files/ibd_biobridge_multimodal/', 'node_type': ['gene/protein', 'molecular_function', 'cellular_component', 'biological_process', 'drug', 'disease']}}

In [4]:
# cfg.biobridge.source = "/mnt/blockstorage/biobridge_multimodal"
cfg.biobridge.source = "../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal"

In [5]:
# Loop over nodes and edges
graph_dict = {}
for element in ["nodes", "edges"]:
    # Make an empty dictionary for each folder
    graph_dict[element] = {}
    for stage in ["enrichment", "embedding"]:
        print(element, stage)
        # Create the file pattern for the current subfolder
        file_list = glob.glob(os.path.join(cfg.biobridge.source, 
                                           element, 
                                           stage, '*.parquet.gzip'))
        print(file_list)
        # if element != "edges" and stage == "embedding":
        # Read and concatenate all dataframes in the folder
        # graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f) for f in file_list], ignore_index=True)

nodes enrichment
['../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/molecular_function.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/disease.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/drug.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/biological_process.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/cellular_component.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/gene_protein.parquet.gzip']
nodes embedding
['../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimod

In [6]:
# Loop over nodes and edges
chunk_size = 1
graph_dict = {}
for element in ["nodes", "edges"]:
    # Make an empty dictionary for each folder
    graph_dict[element] = {}
    for stage in ["enrichment", "embedding"]:
        print(element, stage)
        # Create the file pattern for the current subfolder
        file_list = glob.glob(os.path.join(cfg.biobridge.source, 
                                           element, 
                                           stage, '*.parquet.gzip'))
        print(file_list)
        # Read and concatenate all dataframes in the folder
        # Except the edges embedding, which is too large to read in one go
        # We are using a chunk size to read the edges embedding in smaller parts instead
        if element == "edges" and stage == "embedding":
            # For edges embedding, only read two columns: triplet_index and edge_emb
            # graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in file_list[:2]], ignore_index=True)
            # Loop by chunks
            file_list = file_list[:2]
            graph_dict[element][stage] = []
            for i in range(0, len(file_list), chunk_size):
                chunk_files = file_list[i:i+chunk_size]
                chunk_df = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in chunk_files], ignore_index=True)
                graph_dict[element][stage].append(chunk_df)
        else:
            # For nodes and edges enrichment, read and concatenate all dataframes in the folder
            # This includes the nodes embedding, which is small enough to read in one go
            graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f) for f in file_list], ignore_index=True)

nodes enrichment
['../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/molecular_function.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/disease.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/drug.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/biological_process.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/cellular_component.parquet.gzip', '../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/gene_protein.parquet.gzip']
nodes embedding
['../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimod

In [8]:
len(graph_dict["edges"]["embedding"])

2

In [6]:
graph_dict["edges"]["embedding"] = []
graph_dict["edges"]["embedding"].append(cudf.read_parquet(file_list[0], columns=["triplet_index", "edge_emb"]))
graph_dict["edges"]["embedding"].append(cudf.read_parquet(file_list[1], columns=["triplet_index", "edge_emb"]))

In [None]:
# Make a simple dataframe
feat_col = cudf.DataFrame(
    {
        "triplet_index": graph_dict["edges"]["enrichment"].triplet_index,
        "feat_emb": None
    }
)

# Loop over a set of embeddings chunks
for i, emb_df in enumerate(graph_dict["edges"]["embedding"]):
    # Merge the embeddings into the feature column dataframe
    feat_col = feat_col.merge(
        emb_df,
        on="triplet_index",
        how="left",
    )

    # Fill missing embeddings with edge embeddings
    mask = (feat_col['feat_emb'].isna()) & (feat_col['triplet_index'].isin(emb_df.triplet_index))
    feat_col.loc[mask, 'feat_emb'] = feat_col.loc[mask, 'edge_emb']

    # Drop the edge_emb column
    feat_col = feat_col.drop(columns=['edge_emb'])

# Merge the feature column with the edges enrichment dataframe
graph_dict["edges"]["enrichment"] = graph_dict["edges"]["enrichment"].merge(
    feat_col,
    on="triplet_index",
    how="left",
)



In [None]:

# feat_col = cudf.DataFrame(
#     {
#         "triplet_index": graph_dict["edges"]["enrichment"].triplet_index,
#         "feat_emb": None
#     }
# )
feat_col = feat_col.merge(
    graph_dict["edges"]["embedding"][0],
    on="triplet_index",
    how="left"
)
mask = (feat_col['feat_emb'].isna()) & (feat_col['triplet_index'].isin(graph_dict["edges"]["embedding"][0].triplet_index))
feat_col.loc[mask, 'feat_emb'] = feat_col.loc[mask, 'edge_emb']
# feat_col.drop(columns=['edge_emb'], inplace=True)
# feat_col
# graph_dict["edges"]["enrichment"]
# graph_dict["edges"]["enrichment"][graph_dict["edges"]["enrichment"].triplet_index.isin(ordered_emb_df.triplet_index)].sort_values("triplet_index")["feat_emb"] = ordered_emb_df.edge_emb

In [25]:
feat_col[feat_col.triplet_index.isin(graph_dict["edges"]["embedding"][0].triplet_index[:10])]

Unnamed: 0,triplet_index,feat_emb
1253488,1250001,"[-0.019220604, 0.005907096, 0.003145554, -0.02..."
1253489,1250002,"[-0.03603969, -0.000844474, 0.01573266, -0.033..."
1253490,1250003,"[-0.004261625, 0.00337462, 0.005385386, -0.038..."
1253491,1250004,"[-0.012451958, -0.001863445, 0.015094904, -0.0..."
1253492,1250005,"[-0.018032864, -0.000848302, 0.007660525, -0.0..."
1253493,1250007,"[-0.013187784, -0.01249827, 0.009673273, -0.02..."
1253494,1250008,"[-0.014623958, 0.004236569, 0.021803234, -0.02..."
1253495,1250009,"[-0.024278272, 0.001424474, 0.01011484, -0.027..."
1253500,1250000,"[-0.022698749, -0.001357529, 0.006256367, -0.0..."
1253501,1250006,"[-0.012184068, -0.002441528, 0.016878538, -0.0..."


In [15]:
feat_col.drop(columns=['edge_emb'], inplace=True)

In [16]:
graph_dict["edges"]["embedding"][0].triplet_index[:10]

0    1250000
1    1250001
2    1250002
3    1250003
4    1250004
5    1250005
6    1250006
7    1250007
8    1250008
9    1250009
Name: triplet_index, dtype: int64

In [11]:
feat_col[feat_col.triplet_index.isin(graph_dict["edges"]["embedding"][1].triplet_index[:10])]

Unnamed: 0,triplet_index,feat_emb
1349568,1350000,"[-0.005971643, -0.024854947, 0.020295516, -0.0..."
1349569,1350001,"[-0.007006694, -0.019030526, 0.004950598, -0.0..."
1349570,1350002,"[-0.016136318, -0.030427722, 0.01814166, -0.02..."
1349571,1350003,"[-0.017075982, -0.02360424, 0.016603898, -0.02..."
1349572,1350004,"[-0.01116664, -0.03029234, 0.006278739, -0.030..."
1349573,1350005,"[-0.002655919, -0.017228309, 0.02556113, -0.03..."
1349574,1350006,"[-0.015132139, -0.023252474, -0.00335064, -0.0..."
1349575,1350007,"[-0.006733189, -0.031606145, 0.030228442, -0.0..."
1349576,1350008,"[-0.01785563, -0.029522736, 0.014774763, -0.01..."
1349577,1350009,"[-0.010561401, -0.021869097, 0.012813612, -0.0..."


In [27]:
graph_dict["edges"]["embedding"][0].triplet_index[:10]

0    1250000
1    1250001
2    1250002
3    1250003
4    1250004
5    1250005
6    1250006
7    1250007
8    1250008
9    1250009
Name: triplet_index, dtype: int64

In [28]:
feat_col[feat_col.triplet_index.isin(graph_dict["edges"]["embedding"][0].triplet_index[:10])].sort_values("triplet_index")

Unnamed: 0,triplet_index,feat_emb
1253500,1250000,"[-0.022698749, -0.001357529, 0.006256367, -0.0..."
1253488,1250001,"[-0.019220604, 0.005907096, 0.003145554, -0.02..."
1253489,1250002,"[-0.03603969, -0.000844474, 0.01573266, -0.033..."
1253490,1250003,"[-0.004261625, 0.00337462, 0.005385386, -0.038..."
1253491,1250004,"[-0.012451958, -0.001863445, 0.015094904, -0.0..."
1253492,1250005,"[-0.018032864, -0.000848302, 0.007660525, -0.0..."
1253501,1250006,"[-0.012184068, -0.002441528, 0.016878538, -0.0..."
1253493,1250007,"[-0.013187784, -0.01249827, 0.009673273, -0.02..."
1253494,1250008,"[-0.014623958, 0.004236569, 0.021803234, -0.02..."
1253495,1250009,"[-0.024278272, 0.001424474, 0.01011484, -0.027..."


In [18]:

# feat_col = cudf.DataFrame(
#     {
#         "triplet_index": graph_dict["edges"]["enrichment"].triplet_index,
#         "feat_emb": None
#     }
# )
feat_col = feat_col.merge(
    graph_dict["edges"]["embedding"][1],
    on="triplet_index",
    how="left"
)

# graph_dict["edges"]["enrichment"]
# graph_dict["edges"]["enrichment"][graph_dict["edges"]["enrichment"].triplet_index.isin(ordered_emb_df.triplet_index)].sort_values("triplet_index")["feat_emb"] = ordered_emb_df.edge_emb

In [19]:
graph_dict["edges"]["embedding"][1].triplet_index[:10]

0    1350000
1    1350001
2    1350002
3    1350003
4    1350004
5    1350005
6    1350006
7    1350007
8    1350008
9    1350009
Name: triplet_index, dtype: int64

In [12]:
feat_col[feat_col.triplet_index.isin(graph_dict["edges"]["embedding"][1].triplet_index[:10])]

Unnamed: 0,triplet_index,feat_emb
1349568,1350000,"[-0.005971643, -0.024854947, 0.020295516, -0.0..."
1349569,1350001,"[-0.007006694, -0.019030526, 0.004950598, -0.0..."
1349570,1350002,"[-0.016136318, -0.030427722, 0.01814166, -0.02..."
1349571,1350003,"[-0.017075982, -0.02360424, 0.016603898, -0.02..."
1349572,1350004,"[-0.01116664, -0.03029234, 0.006278739, -0.030..."
1349573,1350005,"[-0.002655919, -0.017228309, 0.02556113, -0.03..."
1349574,1350006,"[-0.015132139, -0.023252474, -0.00335064, -0.0..."
1349575,1350007,"[-0.006733189, -0.031606145, 0.030228442, -0.0..."
1349576,1350008,"[-0.01785563, -0.029522736, 0.014774763, -0.01..."
1349577,1350009,"[-0.010561401, -0.021869097, 0.012813612, -0.0..."


In [20]:
feat_col[feat_col.triplet_index.isin(graph_dict["edges"]["embedding"][1].triplet_index[:10])]

Unnamed: 0,triplet_index,feat_emb,edge_emb
1349744,1350000,,"[-0.0059716427, -0.024854947, 0.020295516, -0...."
1349745,1350001,,"[-0.0070066936, -0.019030526, 0.0049505983, -0..."
1349746,1350002,,"[-0.016136318, -0.030427722, 0.01814166, -0.02..."
1349747,1350003,,"[-0.017075982, -0.02360424, 0.016603898, -0.02..."
1349748,1350004,,"[-0.01116664, -0.03029234, 0.0062787393, -0.03..."
1349749,1350005,,"[-0.0026559194, -0.017228309, 0.02556113, -0.0..."
1349750,1350006,,"[-0.015132139, -0.023252474, -0.0033506402, -0..."
1349751,1350007,,"[-0.0067331893, -0.031606145, 0.030228442, -0...."
1349752,1350008,,"[-0.01785563, -0.029522736, 0.014774763, -0.01..."
1349753,1350009,,"[-0.010561401, -0.021869097, 0.012813612, -0.0..."


In [21]:
mask = (feat_col['feat_emb'].isna()) & (feat_col['triplet_index'].isin(graph_dict["edges"]["embedding"][1].triplet_index))
feat_col.loc[mask, 'feat_emb'] = feat_col.loc[mask, 'edge_emb']
feat_col

Unnamed: 0,triplet_index,feat_emb,edge_emb
0,7024,,
1,7025,,
2,7026,,
3,7027,,
4,7028,,
...,...,...,...
3904605,3903053,,
3904606,3903054,,
3904607,3903055,,
3904608,3903048,,


In [22]:
feat_col[feat_col.triplet_index.isin(graph_dict["edges"]["embedding"][1].triplet_index[:10])]

Unnamed: 0,triplet_index,feat_emb,edge_emb
1349744,1350000,"[-0.005971643, -0.024854947, 0.020295516, -0.0...","[-0.0059716427, -0.024854947, 0.020295516, -0...."
1349745,1350001,"[-0.007006694, -0.019030526, 0.004950598, -0.0...","[-0.0070066936, -0.019030526, 0.0049505983, -0..."
1349746,1350002,"[-0.016136318, -0.030427722, 0.01814166, -0.02...","[-0.016136318, -0.030427722, 0.01814166, -0.02..."
1349747,1350003,"[-0.017075982, -0.02360424, 0.016603898, -0.02...","[-0.017075982, -0.02360424, 0.016603898, -0.02..."
1349748,1350004,"[-0.01116664, -0.03029234, 0.006278739, -0.030...","[-0.01116664, -0.03029234, 0.0062787393, -0.03..."
1349749,1350005,"[-0.002655919, -0.017228309, 0.02556113, -0.03...","[-0.0026559194, -0.017228309, 0.02556113, -0.0..."
1349750,1350006,"[-0.015132139, -0.023252474, -0.00335064, -0.0...","[-0.015132139, -0.023252474, -0.0033506402, -0..."
1349751,1350007,"[-0.006733189, -0.031606145, 0.030228442, -0.0...","[-0.0067331893, -0.031606145, 0.030228442, -0...."
1349752,1350008,"[-0.01785563, -0.029522736, 0.014774763, -0.01...","[-0.01785563, -0.029522736, 0.014774763, -0.01..."
1349753,1350009,"[-0.010561401, -0.021869097, 0.012813612, -0.0...","[-0.010561401, -0.021869097, 0.012813612, -0.0..."


In [18]:
graph_dict["nodes"]["embedding"]

Unnamed: 0,node_id,desc_emb,feat_emb
0,methyltransferase activity_(53517),"[-0.015243685, -0.024490524, -0.013795467, -0....","[-0.017935446, 0.008318249, 0.0036064482, -0.0..."
1,"catalytic activity, acting on a tRNA_(53518)","[-0.012256229, -0.02646109, -0.0012187355, -0....","[-0.03426383, 0.0023465257, 0.0041918517, -0.0..."
2,"catalytic activity, acting on DNA_(53519)","[-0.023627607, -0.008710639, -0.007649029, -0....","[-0.022934051, -0.00065733766, 0.003557835, -0..."
3,"catalytic activity, acting on a protein_(53520)","[-0.018698066, -0.013873419, -0.009751656, -0....","[-0.017968671, -0.0006349569, 0.008043058, -0...."
4,"catalytic activity, acting on RNA_(53521)","[-0.02121846, -0.020333799, 0.0017944543, -0.0...","[-0.049570102, 0.011327438, 0.021395529, 0.017..."
...,...,...,...
84976,MAB21L4_(83734),"[-0.018029643, -0.0047510546, -0.016251702, -0...","[-0.021339817, 0.0035598348, -0.0067422567, -0..."
84977,PRR23D2_(83735),"[-0.006412564, 0.0011807027, -0.013452657, -0....","[-0.03778587, -0.008879421, -0.003650357, -0.0..."
84978,C8orf86_(83740),"[-0.009083541, -0.012571621, -0.017519673, -0....","[-0.0379925, -0.016725836, -0.0025223098, -0.0..."
84979,CRACDL_(83746),"[-0.012079031, 0.0054659494, -0.013663166, -0....","[-0.02916107, -0.018353855, 0.01120821, -0.007..."


In [16]:
cudf.read_parquet(file_list[0], columns=["triplet_index", "edge_emb"])

Unnamed: 0,triplet_index,edge_emb
0,1250000,"[-0.022698749, -0.0013575292, 0.006256367, -0...."
1,1250001,"[-0.019220604, 0.0059070964, 0.0031455539, -0...."
2,1250002,"[-0.03603969, -0.00084447366, 0.01573266, -0.0..."
3,1250003,"[-0.0042616245, 0.0033746199, 0.0053853863, -0..."
4,1250004,"[-0.012451958, -0.0018634446, 0.015094904, -0...."
...,...,...
49995,1299995,"[-0.037339132, -0.006667702, -0.012439743, -0...."
49996,1299996,"[-0.021750897, 0.005384018, 0.0067468057, -0.0..."
49997,1299997,"[-0.020974606, 0.0056373496, 0.017132657, -0.0..."
49998,1299998,"[-0.016745826, -0.012181237, -0.009406022, -0...."


In [5]:
graph_dict['edges']['enrichment']

Unnamed: 0,triplet_index,primekg_head_index,primekg_tail_index,head_id,tail_id,display_relation,edge_type,edge_type_str,head_index,tail_index,feat
0,0,0,8889,PHYHIP_(0),KIF15_(8889),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,0,8816,PHYHIP (gene/protein) has a direct relationshi...
1,1,1,2798,GPANK1_(1),PNMA1_(2798),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,1,2787,GPANK1 (gene/protein) has a direct relationshi...
2,2,2,5646,ZRSR2_(2),TTC33_(5646),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,2,5610,ZRSR2 (gene/protein) has a direct relationship...
3,3,3,11592,NRF1_(3),MAN1B1_(11592),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,3,11467,NRF1 (gene/protein) has a direct relationship ...
4,4,4,2122,PI4KA_(4),RGS20_(2122),ppi,"[gene/protein, ppi, gene/protein]",gene/protein|ppi|gene/protein,4,2117,PI4KA (gene/protein) has a direct relationship...
...,...,...,...,...,...,...,...,...,...,...,...
3904605,3904605,52855,34572,B cell receptor transport into membrane raft_(...,CD24_(34572),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,45323,27800,B cell receptor transport into membrane raft (...
3904606,3904606,113352,34572,chemokine receptor transport out of membrane r...,CD24_(34572),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,71241,27800,chemokine receptor transport out of membrane r...
3904607,3904607,42264,57675,negative regulation of cytoskeleton organizati...,IQCJ-SCHIP1_(57675),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,35166,49927,negative regulation of cytoskeleton organizati...
3904608,3904608,109904,58770,mesendoderm migration_(109904),APELA_(58770),interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,67928,50777,mesendoderm migration (biological_process) has...


In [None]:
# Loop over nodes and edges
graph_dict = {}
chunk_size = 10  # Adjust based on your GPU memory

for element in ["nodes", "edges"]:
    graph_dict[element] = {}
    for stage in ["enrichment", "embedding"]:
        print(element, stage)
        file_list = glob.glob(os.path.join(cfg.biobridge.source, element, stage, '*.parquet.gzip'))
        print(f"{len(file_list)} files")

        # Loop by chunks
        graph_dict[element][stage] = []
        for i in range(0, len(file_list), chunk_size):
            chunk_files = file_list[i:i+chunk_size]
            chunk_df = cudf.concat([cudf.read_parquet(f) for f in chunk_files], ignore_index=True)
            graph_dict[element][stage].append(chunk_df)

In [None]:
# Concate dataframes of enrichment
graph_dict['nodes']['enrichment'] = cudf.concat(graph_dict['nodes']['enrichment'])
graph_dict['edges']['enrichment'] = cudf.concat(graph_dict['edges']['enrichment'])

In [None]:
# Merge nodes embedding into nodes enrichment
graph_dict["nodes"] = graph_dict["nodes"]["enrichment"].merge(
    graph_dict["nodes"]["embedding"][0],
    how="left",
    on="node_id"
)

In [None]:
# Check head
graph_dict["nodes"].head(5)

In [None]:
graph_dict["edges"]["enrichment"].drop(columns="feat_emb", inplace=True)

In [None]:
# Initialize 'feat_emb' column as nullable object
graph_dict["edges"]["enrichment"]["feat_emb"] = None

# Loop over each embedding chunk
for i in range(len(graph_dict['edges']['embedding'])):
    emb_df = graph_dict['edges']['embedding'][i]
    
    # Merge with enrichment on 'triplet_index'
    merged = graph_dict["edges"]["enrichment"].merge(
        emb_df[["triplet_index", "edge_emb"]],
        on="triplet_index",
        how="left",
    )

    # Update feat_emb only where it's still missing
    mask = merged["feat_emb"].isnull()
    merged.loc[mask, "feat_emb"] = merged.loc[mask, "edge_emb"]

    # # Drop the temporary 'edge_emb' column
    merged.drop(columns=["edge_emb"], inplace=True)

    # # Store updated enrichment back
    graph_dict["edges"]["enrichment"] = merged

    # # Free memory from processed embedding
    graph_dict['edges']['embedding'][i] = None
    del emb_df  # For GC

# Merge edges embedding into edges enrichment
graph_dict["edges"] = graph_dict["edges"]["enrichment"]
