In [1]:
import os
import glob
import cudf
import dask_cudf
import hydra
import cupy as cp

### Loading BioBridge-PrimeKG Multimodal Data

In [2]:
# Compose the project's Hydra configuration with the default options of the
# multimodal subgraph extraction tool, and keep only that tool's section.
with hydra.initialize(
    version_base=None,
    config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs",
):
    cfg = hydra.compose(
        config_name="config",
        overrides=["tools/multimodal_subgraph_extraction=default"],
    ).tools.multimodal_subgraph_extraction
# Show the resolved tool configuration
cfg

{'_target_': 'talk2knowledgegraphs.tools.multimodal_subgraph_extraction', 'ollama_embeddings': ['nomic-embed-text'], 'temperature': 0.1, 'streaming': False, 'topk': 5, 'topk_e': 5, 'cost_e': 0.5, 'c_const': 0.01, 'root': -1, 'num_clusters': 1, 'pruning': 'gw', 'verbosity_level': 0, 'node_id_column': 'node_id', 'node_attr_column': 'node_attr', 'edge_src_column': 'edge_src', 'edge_attr_column': 'edge_attr', 'edge_dst_column': 'edge_dst', 'node_colors_dict': {'gene/protein': '#6a79f7', 'molecular_function': '#82cafc', 'cellular_component': '#3f9b0b', 'biological_process': '#c5c9c7', 'drug': '#c4a661', 'disease': '#80013f'}, 'biobridge': {'source': '/mnt/blockstorage/biobridge_multimodal/', 'node_type': ['gene/protein', 'molecular_function', 'cellular_component', 'biological_process', 'drug', 'disease']}}

In [3]:
# Override the BioBridge data directory taken from the Hydra config
# (the config value carries a trailing slash; this one does not).
cfg.biobridge.source = "/mnt/blockstorage/biobridge_multimodal"
# Alternative location, presumably the small fixture set under tests/files (TODO confirm); uncomment to use it:
# cfg.biobridge.source = "../../../../AIAgents4Pharma/aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal"

In [4]:
# Load every parquet file of the dataset into GPU dataframes.
# Resulting layout: graph_dict[element][stage], where
#   - element is "nodes" or "edges"
#   - stage is "enrichment" or "embedding"
# Every entry is a single concatenated cudf DataFrame, except
# graph_dict["edges"]["embedding"], which is a list of DataFrames (one per
# group of files) because it is too large to read in one go.
graph_dict = {}
for element in ("nodes", "edges"):
    stage_frames = {}
    graph_dict[element] = stage_frames
    for stage in ("enrichment", "embedding"):
        print(element, stage)
        # All parquet files stored for this element/stage combination
        pattern = os.path.join(cfg.biobridge.source, element, stage, '*.parquet.gzip')
        file_list = glob.glob(pattern)
        print(file_list)

        if (element, stage) != ("edges", "embedding"):
            # Nodes (both stages) and edges enrichment are small enough to be
            # read completely and concatenated into one dataframe.
            frames = [cudf.read_parquet(path) for path in file_list]
            stage_frames[stage] = cudf.concat(frames, ignore_index=True)
            continue

        # Edges embedding: read only the two required columns and keep one
        # concatenated dataframe per group of `chunk_size` files.
        chunk_size = 5
        stage_frames[stage] = []
        for start in range(0, len(file_list), chunk_size):
            parts = [
                cudf.read_parquet(path, columns=["triplet_index", "edge_emb"])
                for path in file_list[start:start + chunk_size]
            ]
            stage_frames[stage].append(cudf.concat(parts, ignore_index=True))

nodes enrichment
['/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/molecular_function.parquet.gzip']
nodes embedding
['/mnt/blockstorage/biobridge_multimodal/nodes/embedding/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_mu

In [5]:
# Convenience handles for the node tables (each one is a single dataframe)
nodes_enrichment_df, nodes_embedding_df = (
    graph_dict['nodes']['enrichment'],
    graph_dict['nodes']['embedding'],
)
# Convenience handles for the edge tables; note that the edge embeddings are
# a list of dataframes (one per chunk of files), not a single dataframe
edges_enrichment_df, edges_embedding_df = (
    graph_dict['edges']['enrichment'],
    graph_dict['edges']['embedding'],
)

### Setup Milvus Database

In [6]:
import numpy as np

from collections import defaultdict
from scipy.sparse import csr_matrix
from pymilvus import MilvusClient
from pymilvus import CollectionSchema, FieldSchema, DataType
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from tqdm import tqdm

In [7]:
# Connection settings for Milvus.
# Each value can be overridden through an environment variable so that real
# credentials never have to be written into (or committed with) the notebook.
# The fallbacks are the values previously hard-coded here, so behaviour is
# unchanged when no variable is set.
milvus_uri = os.environ.get("MILVUS_URI", "http://localhost:19530")
# "user:password" token; set MILVUS_TOKEN for any non-local deployment
milvus_token = os.environ.get("MILVUS_TOKEN", "root:Milvus")
# Database that holds the node collections built in this notebook
milvus_database = os.environ.get("MILVUS_DATABASE", "t2kg_primekg")

In [9]:
# Connect to the Milvus server with the configured URI and token
milvus_client = MilvusClient(uri=milvus_uri, token=milvus_token)
# Create the target database only when the server does not list it yet
known_databases = milvus_client.list_databases()
if milvus_database not in known_databases:
    milvus_client.create_database(db_name=milvus_database)
# Switch the client to the target database (pre-existing or just created)
milvus_client.using_database(milvus_database)

#### Building Node Collection

In [113]:
# Attach the description and feature embeddings to the enriched node records.
# The left join keeps every enrichment row and matches on "node_id".
embedding_columns = nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]]
merged_nodes_df = nodes_enrichment_df.merge(embedding_columns, on="node_id", how="left")
# del nodes_enrichment_df, nodes_embedding_df  # Free memory

# Utility: chunk generator
def chunked(data_list, chunk_size):
    """Yield consecutive slices of ``data_list`` holding at most ``chunk_size`` items.

    The final slice is shorter when ``len(data_list)`` is not a multiple of
    ``chunk_size``; an empty input yields nothing.
    """
    offsets = range(0, len(data_list), chunk_size)
    yield from (data_list[offset:offset + chunk_size] for offset in offsets)

# Build one Milvus collection per node type and load that type's nodes into it.
# Every collection stores the node metadata plus two vector fields
# (desc_emb and feat_emb), each with an AUTOINDEX / COSINE index.
for node_type, nodes_df in merged_nodes_df.groupby('node_type'):
    print(f"Processing node type: {node_type}")

    # Collection name; "/" in the node type is replaced by "_"
    node_coll_name = f"{milvus_database}_node_{node_type.replace('/', '_')}"

    # Vector dimensions are read from the first row of this group
    desc_emb_dim = len(nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])
    feat_emb_dim = len(nodes_df.iloc[0]['feat_emb'].to_arrow().to_pylist()[0])

    # Define schema for the collection (node_index is the primary key)
    node_fields = [
        FieldSchema(name="node_index", dtype=DataType.INT64, is_primary=True),
        FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="node_name", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="node_type", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="desc", dtype=DataType.VARCHAR, max_length=40960),
        FieldSchema(name="feat", dtype=DataType.VARCHAR, max_length=40960),
        FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),
        FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=feat_emb_dim),
    ]
    node_schema = CollectionSchema(fields=node_fields, description=f"{node_type} nodes")

    # Index both vector fields so they can be searched by cosine similarity
    index_params = milvus_client.prepare_index_params()
    index_params.add_index(field_name="desc_emb", index_type="AUTOINDEX",
                           index_name="desc_emb_index", metric_type="COSINE")
    index_params.add_index(field_name="feat_emb", index_type="AUTOINDEX",
                           index_name="feat_emb_index", metric_type="COSINE")

    # Create the collection if it does not exist
    if node_coll_name not in milvus_client.list_collections():
        milvus_client.create_collection(collection_name=node_coll_name,
                                        schema=node_schema,
                                        index_params=index_params)

    # Build the list of records on the host.
    # The list columns are flattened on the GPU, cast to float32 and reshaped
    # to (number of rows, embedding dimension) before conversion to Python lists.
    # NOTE(review): this assumes every node has both embeddings; a row left
    # without a match by the left merge would break the reshape - confirm.
    data = [
        {
            "node_index": idx,
            "node_id": nid,
            "node_name": nname,
            "node_type": ntype,
            "desc": desc,
            "feat": feat,
            "desc_emb": desc_emb,
            "feat_emb": feat_emb,
        }
        for idx, nid, nname, ntype, desc, feat, desc_emb, feat_emb in zip(
            nodes_df["node_index"].to_arrow().to_pylist(),
            nodes_df["node_id"].to_arrow().to_pylist(),
            nodes_df["node_name"].to_arrow().to_pylist(),
            nodes_df["node_type"].to_arrow().to_pylist(),
            nodes_df["desc"].to_arrow().to_pylist(),
            nodes_df["feat"].to_arrow().to_pylist(),
            nodes_df["desc_emb"].list.leaves.to_cupy().astype(cp.float32).reshape(nodes_df.shape[0], -1).tolist(),
            nodes_df["feat_emb"].list.leaves.to_cupy().astype(cp.float32).reshape(nodes_df.shape[0], -1).tolist()
        )
    ]

    # Write the records in batches to keep each request small.
    # upsert (rather than insert) keeps this cell idempotent: the collection is
    # only created when missing, so re-running the cell against an existing
    # collection would otherwise add every node a second time, because insert
    # does not deduplicate primary keys. upsert replaces rows with the same
    # node_index instead.
    batch_size = 500
    for i, batch in enumerate(chunked(data, batch_size)):
        print(f"Inserting batch {i + 1}/{(len(data) - 1) // batch_size + 1} ...")
        milvus_client.upsert(collection_name=node_coll_name, data=batch)

    # Flush the collection to ensure data is written
    milvus_client.flush(collection_name=node_coll_name)

    # Printout collection status
    print(milvus_client.get_collection_stats(collection_name=node_coll_name))

Processing node type: biological_process
Inserting batch 1/55 ...
Inserting batch 2/55 ...
Inserting batch 3/55 ...
Inserting batch 4/55 ...
Inserting batch 5/55 ...
Inserting batch 6/55 ...
Inserting batch 7/55 ...
Inserting batch 8/55 ...
Inserting batch 9/55 ...
Inserting batch 10/55 ...
Inserting batch 11/55 ...
Inserting batch 12/55 ...
Inserting batch 13/55 ...
Inserting batch 14/55 ...
Inserting batch 15/55 ...
Inserting batch 16/55 ...
Inserting batch 17/55 ...
Inserting batch 18/55 ...
Inserting batch 19/55 ...
Inserting batch 20/55 ...
Inserting batch 21/55 ...
Inserting batch 22/55 ...
Inserting batch 23/55 ...
Inserting batch 24/55 ...
Inserting batch 25/55 ...
Inserting batch 26/55 ...
Inserting batch 27/55 ...
Inserting batch 28/55 ...
Inserting batch 29/55 ...
Inserting batch 30/55 ...
Inserting batch 31/55 ...
Inserting batch 32/55 ...
Inserting batch 33/55 ...
Inserting batch 34/55 ...
Inserting batch 35/55 ...
Inserting batch 36/55 ...
Inserting batch 37/55 ...
Insert

In [114]:
# List the collections present in the database the client is currently using
milvus_client.list_collections()

['t2kg_primekg_node_biological_process',
 't2kg_primekg_node_cellular_component',
 't2kg_primekg_node_disease',
 't2kg_primekg_node_drug',
 't2kg_primekg_node_gene_protein',
 't2kg_primekg_node_molecular_function']

In [116]:
# Report the statistics (e.g. row count) of every collection in the database
for collection_name in milvus_client.list_collections():
    print(f"Collection: {collection_name}")
    collection_stats = milvus_client.get_collection_stats(collection_name=collection_name)
    print(collection_stats)

Collection: t2kg_primekg_node_biological_process
{'row_count': 27409}
Collection: t2kg_primekg_node_cellular_component
{'row_count': 4011}
Collection: t2kg_primekg_node_disease
{'row_count': 17054}
Collection: t2kg_primekg_node_drug
{'row_count': 6759}
Collection: t2kg_primekg_node_gene_protein
{'row_count': 18797}
Collection: t2kg_primekg_node_molecular_function
{'row_count': 10951}


In [118]:
# Display the merged node table (enrichment columns plus desc_emb and feat_emb)
merged_nodes_df

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
0,4990,5019,GAS2L1_(5019),GAS2L1,gene/protein,GAS2L1 belongs to gene/protein node. GAS2L1 is...,MADPVAGIAGSAAKSVRPFRSSEAYVEAMKEDLAEWLNALYGLGLP...,"[-0.032280046, 0.0018808923, -0.018179184, -0....","[0.005598554387688637, -0.013612139038741589, ..."
1,4991,5020,BAHCC1_(5020),BAHCC1,gene/protein,BAHCC1 belongs to gene/protein node. BAHCC1 is...,MDGRDFAPPPHLLSERGSLGHRSAAAAARLAPAGPAAQPPAHFQPG...,"[-0.03259, -0.010202489, 0.0033706415, -0.0145...","[0.006144949235022068, -0.037945706397295, 0.0..."
2,4992,5021,RBM5_(5021),RBM5,gene/protein,RBM5 belongs to gene/protein node. RBM5 is RNA...,MGSDKRVSRTERSGRYGSIIDRDDRDERESRSRRRDSDYKRSSDDR...,"[-0.050645936, -0.012549261, -0.017889792, -0....","[-0.09598235785961151, -0.05643630400300026, 0..."
3,4993,5022,ITGAM_(5022),ITGAM,gene/protein,ITGAM belongs to gene/protein node. ITGAM is i...,MALRVLLLTALTLCHGFNLDTENAMTFQENARGFGQSVVQLQGSRV...,"[-0.028032856, -0.014978595, -0.034716014, -0....","[0.06611106544733047, 0.0018108593067154288, -..."
4,4994,5023,PTGDR_(5023),PTGDR,gene/protein,PTGDR belongs to gene/protein node. PTGDR is p...,MKSPFYRCQNTTSVEKGNSAVMGGVLFSTGLLGNLLALGLLARSGL...,"[-0.030831868, -0.0094997315, -0.033455282, -0...","[0.013212235644459724, -0.04957687109708786, 0..."
...,...,...,...,...,...,...,...,...,...
84976,71907,114044,DNA dealkylation involved in DNA repair_(114044),DNA dealkylation involved in DNA repair,biological_process,DNA dealkylation involved in DNA repair belong...,"The repair of alkylation damage, e.g. the remo...","[-0.0136713805, 0.0054800464, -0.003755911, -0...","[-0.02218257, -0.0037454518, 0.0019188847, -0...."
84977,71914,114053,establishment of cell polarity involved in mes...,establishment of cell polarity involved in mes...,biological_process,establishment of cell polarity involved in mes...,The specification and formation of anisotropic...,"[-0.015679583, 0.00986015, -0.0077391095, 0.00...","[-0.030488482, 0.015335836, -0.0061991066, -0...."
84978,71910,114049,positive regulation of lactation by mesenchyma...,positive regulation of lactation by mesenchyma...,biological_process,positive regulation of lactation by mesenchyma...,"The process that increases the rate, frequency...","[-0.02807092, 0.0059592845, -0.01573714, -0.01...","[-0.04906116, 0.0010414534, -0.012413609, -0.0..."
84979,71911,114050,positive regulation of dentin-containing tooth...,positive regulation of dentin-containing tooth...,biological_process,positive regulation of dentin-containing tooth...,Any process that initiates the formation of a ...,"[-0.020066187, 0.012386619, 0.0020440302, -0.0...","[-0.027301803, 0.006703797, 0.002295426, -0.01..."


In [130]:
# Inspect the description embedding of the row with index label 0
merged_nodes_df["desc_emb"].loc[0]

[-0.032280046,
 0.0018808923,
 -0.018179184,
 -0.037051413,
 -0.02556281,
 0.017912626,
 -0.011135416,
 0.028921427,
 -0.0036285063,
 -0.016939694,
 0.014847221,
 -0.005830933,
 0.00070221093,
 0.009176223,
 0.0039650346,
 0.0129013555,
 0.026495758,
 -0.011888439,
 9.553359e-06,
 0.016859727,
 0.0061374735,
 0.020644834,
 -0.02327042,
 -0.033879388,
 -0.00048230146,
 0.011688521,
 0.012474865,
 -0.022683995,
 0.0042282594,
 -0.0072770044,
 -0.0012703104,
 -0.019565279,
 -0.0013061289,
 0.002778856,
 -0.00041607872,
 -0.022923896,
 0.009196214,
 0.0017292881,
 0.020244999,
 -0.009316165,
 0.02067149,
 0.009649361,
 -0.0005472747,
 -0.0097959675,
 0.008216618,
 0.0032886462,
 0.009622705,
 -0.0076102004,
 -0.027162151,
 -0.0038950632,
 -0.0041016447,
 0.027148824,
 -0.027908511,
 -0.006693911,
 -0.00838988,
 -0.028494935,
 0.012368241,
 0.0025706084,
 0.0022124227,
 -0.0028821467,
 -0.023723567,
 -0.023123814,
 -0.024923073,
 0.0074502663,
 -0.007863429,
 0.008030028,
 -0.007996708,
 0.

In [153]:
%%time

vector_to_search = merged_nodes_df["desc_emb"].loc[2]

# Vector similarity search in Milvus by defining a particular collection (node_type)
results = milvus_client.search(
    collection_name='t2kg_primekg_node_gene_protein',
    data=[vector_to_search],
    anns_field="desc_emb",
    output_fields=["node_id", "node_name"],
    search_params={"metric_type": "COSINE"},
    limit=10,
)


CPU times: user 8.59 ms, sys: 3.65 ms, total: 12.2 ms
Wall time: 26.3 ms


In [154]:
# Ground truth for the search: the row (index label 2) whose desc_emb was used
# as the query vector; it is expected to come back as the best hit
merged_nodes_df.loc[2]

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
2,4992,5021,RBM5_(5021),RBM5,gene/protein,RBM5 belongs to gene/protein node. RBM5 is RNA...,MGSDKRVSRTERSGRYGSIIDRDDRDERESRSRRRDSDYKRSSDDR...,"[-0.050645936, -0.012549261, -0.017889792, -0....","[-0.09598235785961151, -0.05643630400300026, 0..."


In [155]:
# Display the raw search response: one list of hits per query vector, each hit
# carrying the primary key (node_index), the similarity score ("distance") and
# the requested output fields
results

data: [[{'node_index': 4992, 'distance': 1.0, 'entity': {'node_id': 'RBM5_(5021)', 'node_name': 'RBM5'}}, {'node_index': 9760, 'distance': 0.9463825225830078, 'entity': {'node_id': 'RBM22_(9845)', 'node_name': 'RBM22'}}, {'node_index': 9691, 'distance': 0.9457356929779053, 'entity': {'node_id': 'RBM25_(9776)', 'node_name': 'RBM25'}}, {'node_index': 943, 'distance': 0.9427413940429688, 'entity': {'node_id': 'RBM6_(943)', 'node_name': 'RBM6'}}, {'node_index': 5240, 'distance': 0.9424200057983398, 'entity': {'node_id': 'RBM28_(5270)', 'node_name': 'RBM28'}}, {'node_index': 2848, 'distance': 0.9393284916877747, 'entity': {'node_id': 'RBM10_(2860)', 'node_name': 'RBM10'}}, {'node_index': 8265, 'distance': 0.938931941986084, 'entity': {'node_id': 'RBM17_(8330)', 'node_name': 'RBM17'}}, {'node_index': 9632, 'distance': 0.9384500980377197, 'entity': {'node_id': 'RBM19_(9716)', 'node_name': 'RBM19'}}, {'node_index': 6722, 'distance': 0.9366649389266968, 'entity': {'node_id': 'RBM42_(6767)', 'no

In [156]:
# Primary keys (node_index) of the hits returned for the first (only) query vector
[hit['node_index'] for hit in results[0]]

[4992, 9760, 9691, 943, 5240, 2848, 8265, 9632, 6722, 27886]

In [157]:
# Cosine similarity of each hit to the query vector, in the order returned
[hit['distance'] for hit in results[0]]

[1.0,
 0.9463825225830078,
 0.9457356929779053,
 0.9427413940429688,
 0.9424200057983398,
 0.9393284916877747,
 0.938931941986084,
 0.9384500980377197,
 0.9366649389266968,
 0.9360105991363525]