In this tutorial, we will prepare Milvus database for storing and searching nodes and edges of a graph.

In particular, we are using PrimeKG multimodal data from the BioBridge project.

In [1]:
# Load necessary libraries
import os
import glob
import hydra
import cudf
import cupy as cp
from pymilvus import (
    db,
    connections,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
    utility,
    MilvusClient
)
from tqdm import tqdm

### Loading BioBridge-PrimeKG Multimodal Data

First, we need to get the path to the directory containing the parquet files of nodes and edges.

For nodes and edges, we have a separate folder that contains its enrichment and embeddings.

In [2]:
# Load hydra configuration
with hydra.initialize(version_base=None, config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs"):
    cfg = hydra.compose(
        config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
    )
    cfg = cfg.tools.multimodal_subgraph_extraction
cfg

{'_target_': 'talk2knowledgegraphs.tools.multimodal_subgraph_extraction', 'ollama_embeddings': ['nomic-embed-text'], 'temperature': 0.1, 'streaming': False, 'topk': 5, 'topk_e': 5, 'cost_e': 0.5, 'c_const': 0.01, 'root': -1, 'num_clusters': 1, 'pruning': 'gw', 'verbosity_level': 0, 'node_id_column': 'node_id', 'node_attr_column': 'node_attr', 'edge_src_column': 'edge_src', 'edge_attr_column': 'edge_attr', 'edge_dst_column': 'edge_dst', 'node_colors_dict': {'gene/protein': '#6a79f7', 'molecular_function': '#82cafc', 'cellular_component': '#3f9b0b', 'biological_process': '#c5c9c7', 'drug': '#c4a661', 'disease': '#80013f'}, 'biobridge': {'source': '/mnt/blockstorage/biobridge_multimodal/', 'node_type': ['gene/protein', 'molecular_function', 'cellular_component', 'biological_process', 'drug', 'disease']}}

In [3]:
# You can set the source directory for biobridge data here
cfg.biobridge.source = "/mnt/blockstorage/biobridge_multimodal"

In [4]:
%%time

# Loop over nodes and edges
graph_dict = {}
for element in ["nodes", "edges"]:
    # Make an empty dictionary for each folder
    graph_dict[element] = {}
    for stage in ["enrichment", "embedding"]:
        print(element, stage)
        # Create the file pattern for the current subfolder
        file_list = glob.glob(os.path.join(cfg.biobridge.source,
                                           element,
                                           stage, '*.parquet.gzip'))
        print(file_list)
        # Read and concatenate all dataframes in the folder
        # Except the edges embedding, which is too large to read in one go
        # We are using a chunk size to read the edges embedding in smaller parts instead
        if element == "edges" and stage == "embedding":
            # For edges embedding, only read two columns: triplet_index and edge_emb
            # graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in file_list[:2]], ignore_index=True)
            # Loop by chunks
            # file_list = file_list[:2]
            chunk_size = 5
            graph_dict[element][stage] = []
            for i in range(0, len(file_list), chunk_size):
                chunk_files = file_list[i:i+chunk_size]
                chunk_df = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in chunk_files], ignore_index=True)
                graph_dict[element][stage].append(chunk_df)
        else:
            # For nodes and edges enrichment, read and concatenate all dataframes in the folder
            # This includes the nodes embedding, which is small enough to read in one go
            graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f) for f in file_list], ignore_index=True)

nodes enrichment
['/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/molecular_function.parquet.gzip']
nodes embedding
['/mnt/blockstorage/biobridge_multimodal/nodes/embedding/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_mu

In [5]:
# Get nodes enrichment and embedding dataframes
nodes_enrichment_df = graph_dict['nodes']['enrichment']
nodes_embedding_df = graph_dict['nodes']['embedding']

# Get edges enrichment and embedding dataframes
edges_enrichment_df = graph_dict['edges']['enrichment']
edges_embedding_df = graph_dict['edges']['embedding'] # !!consisted of a list of dataframes!!

In [6]:
# Merge nodes enrichment and embedding dataframes
merged_nodes_df = nodes_enrichment_df.merge(
    nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]],
    on="node_id",
    how="left"
)
# del nodes_enrichment_df, nodes_embedding_df  # Free memory

### Setup Milvus Database

In [7]:
# Configuration for Milvus
milvus_host = "localhost"
milvus_port = "19530"
milvus_uri = "http://localhost:19530"
milvus_token = "root:Milvus"
milvus_user = "root"
milvus_password = "Milvus"
milvus_database = "t2kg_primekg"

In [8]:
# Connect to Milvus
connections.connect(
    alias="default",
    host=milvus_host,
    port=milvus_port,
    user=milvus_user,
    password=milvus_password
)

In [9]:
# Check if the database exists, create if it doesn't
if milvus_database not in db.list_database():
    db.create_database(milvus_database)

# Switch to the desired database
db.using_database(milvus_database)

In [10]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

In [11]:
# A helper function to chunk the data into smaller parts
# Utility: chunk generator
def chunked(data_list, chunk_size):
    for i in range(0, len(data_list), chunk_size):
        yield data_list[i:i + chunk_size]

#### Building Node Collection (Description Embedding)

In [14]:
%%time

# Configuration for Milvus collection
node_coll_name = f"{milvus_database}_nodes"

# Define schema for the collection
# Leave out the feat and feat_emb fields for now
desc_emb_dim = len(merged_nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])
node_fields = [
    FieldSchema(name="node_index", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="node_name", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="node_type", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="desc", dtype=DataType.VARCHAR, max_length=40960),
    FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),
]
schema  = CollectionSchema(fields=node_fields, description=f"Schema for collection {node_coll_name}")

# Create collection if it doesn't exist
if not utility.has_collection(node_coll_name):
    collection = Collection(name=node_coll_name, schema=schema)
else:
    collection = Collection(name=node_coll_name)

# Create indexes
collection.create_index(
    field_name="node_index",
    index_params={"index_type": "GPU_CAGRA"}, # STL_SORT
    index_name="node_index_index"
)

collection.create_index(
    field_name="desc_emb",
    index_params={"index_type": "GPU_CAGRA", "metric_type": "COSINE"}, # AUTOINDEX
    index_name="desc_emb_index"
)

# Prepare data for insertion
data = [
    merged_nodes_df["node_index"].to_arrow().to_pylist(),
    merged_nodes_df["node_id"].to_arrow().to_pylist(),
    merged_nodes_df["node_name"].to_arrow().to_pylist(),
    merged_nodes_df["node_type"].to_arrow().to_pylist(),
    merged_nodes_df["desc"].to_arrow().to_pylist(),
    cp.asarray(merged_nodes_df["desc_emb"].list.leaves).astype(cp.float32)
        .reshape(merged_nodes_df.shape[0], -1)
        .tolist(),
]

# Insert data in batches
batch_size = 500
total = len(data[0])
for i in tqdm(range(0, total, batch_size)):
    batch = [
        col[i:i+batch_size] for col in data
    ]
    collection.insert(batch)

# Flush to persist data
collection.flush()

# Get collection stats
print(collection.num_entities)

KeyboardInterrupt: 

In [12]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes
84981


In [13]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index == 13814"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ["{'desc': 'Copper belongs to drug node. Copper is a transition metal and a trace element in the body. It is important to the function of many enzymes including cytochrome c oxidase, monoamine oxidase and superoxide dismutase. Copper is commonly used in contraceptive intrauterine devices (IUD). Copper is absorbed from the gut via high affinity copper uptake protein and likely through low affinity copper uptake protein and natural resistance-associated macrophage protein-2. It is believed that copper is reduced to the Cu1+ form prior to transport. Once inside the enterocyte, it is bound to copper transport protein ATOX1 which shuttles the ion to copper transporting ATPase-1 on the golgi membrane which take up copper into the golgi apparatus. Once copper has been secreted by enterocytes into the systemic circulation it remain largely bound by ceruloplasmin (65-90%), albumin (18%), and alpha 2-macroglobulin (12%).  Copper is nearly entirely bound by ceruloplasmin (65-90%), plasma al

In [14]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index in [13814, 13815]"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ["{'node_type': 'drug', 'desc': 'Copper belongs to drug node. Copper is a transition metal and a trace element in the body. It is important to the function of many enzymes including cytochrome c oxidase, monoamine oxidase and superoxide dismutase. Copper is commonly used in contraceptive intrauterine devices (IUD). Copper is absorbed from the gut via high affinity copper uptake protein and likely through low affinity copper uptake protein and natural resistance-associated macrophage protein-2. It is believed that copper is reduced to the Cu1+ form prior to transport. Once inside the enterocyte, it is bound to copper transport protein ATOX1 which shuttles the ion to copper transporting ATPase-1 on the golgi membrane which take up copper into the golgi apparatus. Once copper has been secreted by enterocytes into the systemic circulation it remain largely bound by ceruloplasmin (65-90%), albumin (18%), and alpha 2-macroglobulin (12%).  Copper is nearly entirely bound by ceruloplasmi

In [15]:
merged_nodes_df.loc[2]

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
2,6544,6588,IREB2_(6588),IREB2,gene/protein,IREB2 belongs to gene/protein node. IREB2 is i...,MDAPKAGYAFEYLIETLNDSSHKKFFDVSKLGTKYDVLPYSIRVLL...,"[-0.05323749, -0.01769744, -0.016430508, -0.02...","[0.022037705406546593, -0.0630035325884819, 0...."


In [16]:
%%time

vector_to_search = merged_nodes_df["desc_emb"].loc[2]


# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = merged_nodes_df["desc_emb"].iloc[2]  # or .loc[2], both work
search_params = {"metric_type": "COSINE"}
results = collection.search(
    data=[vector_to_search],
    anns_field="desc_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"]
)

results

CPU times: user 18.6 ms, sys: 2.42 ms, total: 21.1 ms
Wall time: 28.8 ms


data: [[{'node_index': 6544, 'distance': 1.0, 'entity': {'node_id': 'IREB2_(6588)', 'node_name': 'IREB2'}}, {'node_index': 28031, 'distance': 0.9045131206512451, 'entity': {'node_id': 'IRF2BP2_(34878)', 'node_name': 'IRF2BP2'}}, {'node_index': 5767, 'distance': 0.898109495639801, 'entity': {'node_id': 'IRF2BP1_(5804)', 'node_name': 'IRF2BP1'}}, {'node_index': 2070, 'distance': 0.8949854373931885, 'entity': {'node_id': 'IRF8_(2075)', 'node_name': 'IRF8'}}, {'node_index': 10175, 'distance': 0.892490565776825, 'entity': {'node_id': 'IGF2BP3_(10268)', 'node_name': 'IGF2BP3'}}, {'node_index': 10643, 'distance': 0.8924646377563477, 'entity': {'node_id': 'IGF2BP2_(10746)', 'node_name': 'IGF2BP2'}}, {'node_index': 6243, 'distance': 0.891655445098877, 'entity': {'node_id': 'IRF2BPL_(6285)', 'node_name': 'IRF2BPL'}}, {'node_index': 7827, 'distance': 0.8912501931190491, 'entity': {'node_id': 'RREB1_(7886)', 'node_name': 'RREB1'}}, {'node_index': 6573, 'distance': 0.8909592628479004, 'entity': {'n

#### Building Node Collection (Node Type-specific Embedding)

Note that nodes information of the PrimeKG data is different for each node type, 
we are going to build a separate collection for each node type.

We will use the node type as the collection name.

In [17]:
%%time

# Loop over group enrichment nodes by node_type
for node_type, nodes_df in tqdm(merged_nodes_df.groupby('node_type')):
    print(f"Processing node type: {node_type}")

    # Milvus collection name for this node_type
    node_coll_name = f"{milvus_database}_nodes_{node_type.replace('/', '_')}"

    # Define collection schema
    feat_emb_dim = len(nodes_df.iloc[0]['feat_emb'].to_arrow().to_pylist()[0])
    node_fields = [
        FieldSchema(name="node_index", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="node_name", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="node_type", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="feat", dtype=DataType.VARCHAR, max_length=40960),
        FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=feat_emb_dim),
    ]
    schema = CollectionSchema(fields=node_fields, description=f"schema for collection {node_coll_name}")

    # Create collection if not exists
    if not utility.has_collection(node_coll_name):
        collection = Collection(name=node_coll_name, schema=schema)
    else:
        collection = Collection(name=node_coll_name)

    # Create index for node_index field (scalar)
    collection.create_index(
        field_name="node_index",
        index_params={"index_type": "STL_SORT"},
        index_name="node_index_index"
    )

    # Create index for feat_emb (vector)
    collection.create_index(
        field_name="feat_emb",
        index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"},
        index_name="feat_emb_index"
    )

    # Prepare data for insertion
    # Columns must be lists of values in order matching schema fields
    data = [
        nodes_df["node_index"].to_arrow().to_pylist(),
        nodes_df["node_id"].to_arrow().to_pylist(),
        nodes_df["node_name"].to_arrow().to_pylist(),
        nodes_df["node_type"].to_arrow().to_pylist(),
        nodes_df["feat"].to_arrow().to_pylist(),
        cp.asarray(nodes_df["feat_emb"].list.leaves).astype(cp.float32)
            .reshape(nodes_df.shape[0], -1)
            .tolist(),
    ]

    # Batch insert data in chunks
    batch_size = 500
    total_rows = len(data[0])
    for i in tqdm(range(0, total_rows, batch_size)):
        batch = [col[i:i + batch_size] for col in data]
        collection.insert(batch)

    # Flush the collection to ensure data is persisted
    collection.flush()

    # Print collection stats (number of entities and segment info)
    stats = collection.num_entities
    print(f"Collection {node_coll_name} stats:")
    print(stats)

  0%|          | 0/6 [00:00<?, ?it/s]

Processing node type: biological_process


100%|██████████| 55/55 [00:36<00:00,  1.49it/s]
 17%|█▋        | 1/6 [00:42<03:32, 42.55s/it]

Collection t2kg_primekg_nodes_biological_process stats:
27409
Processing node type: cellular_component


100%|██████████| 9/9 [00:05<00:00,  1.53it/s]
 33%|███▎      | 2/6 [00:51<01:31, 22.76s/it]

Collection t2kg_primekg_nodes_cellular_component stats:
4011
Processing node type: disease


100%|██████████| 35/35 [00:23<00:00,  1.52it/s]
 50%|█████     | 3/6 [01:17<01:12, 24.27s/it]

Collection t2kg_primekg_nodes_disease stats:
17054
Processing node type: drug


100%|██████████| 14/14 [00:03<00:00,  4.18it/s]
 67%|██████▋   | 4/6 [01:23<00:34, 17.15s/it]

Collection t2kg_primekg_nodes_drug stats:
6759
Processing node type: gene/protein


100%|██████████| 38/38 [00:33<00:00,  1.14it/s]
 83%|████████▎ | 5/6 [02:00<00:24, 24.36s/it]

Collection t2kg_primekg_nodes_gene_protein stats:
18797
Processing node type: molecular_function


100%|██████████| 22/22 [00:15<00:00,  1.46it/s]
100%|██████████| 6/6 [02:19<00:00, 23.31s/it]

Collection t2kg_primekg_nodes_molecular_function stats:
10951
CPU times: user 29.4 s, sys: 3.11 s, total: 32.5 s
Wall time: 2min 20s





In [18]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes_molecular_function
10951
Collection: t2kg_primekg_nodes
84981
Collection: t2kg_primekg_nodes_cellular_component
4011
Collection: t2kg_primekg_nodes_disease
17054
Collection: t2kg_primekg_nodes_gene_protein
18797
Collection: t2kg_primekg_nodes_biological_process
27409
Collection: t2kg_primekg_nodes_drug
6759


In [19]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection('t2kg_primekg_nodes_gene_protein')

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = merged_nodes_df["feat_emb"].iloc[2]  # or .loc[2], both work
search_params = {"metric_type": "COSINE"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"]
)

results

CPU times: user 19 ms, sys: 786 μs, total: 19.8 ms
Wall time: 2.44 s


data: [[{'node_index': 6544, 'distance': 0.9999999403953552, 'entity': {'node_id': 'IREB2_(6588)', 'node_name': 'IREB2'}}, {'node_index': 698, 'distance': 0.9973249435424805, 'entity': {'node_id': 'ACO1_(698)', 'node_name': 'ACO1'}}, {'node_index': 4447, 'distance': 0.9917730689048767, 'entity': {'node_id': 'CPS1_(4473)', 'node_name': 'CPS1'}}, {'node_index': 1068, 'distance': 0.9909224510192871, 'entity': {'node_id': 'MTHFD1_(1070)', 'node_name': 'MTHFD1'}}, {'node_index': 2010, 'distance': 0.990104615688324, 'entity': {'node_id': 'CAD_(2015)', 'node_name': 'CAD'}}, {'node_index': 7587, 'distance': 0.9892357587814331, 'entity': {'node_id': 'MTR_(7641)', 'node_name': 'MTR'}}, {'node_index': 5069, 'distance': 0.9892129898071289, 'entity': {'node_id': 'ATP6V1A_(5099)', 'node_name': 'ATP6V1A'}}, {'node_index': 3111, 'distance': 0.9891620874404907, 'entity': {'node_id': 'AHCYL1_(3125)', 'node_name': 'AHCYL1'}}, {'node_index': 27259, 'distance': 0.9891443252563477, 'entity': {'node_id': 'SB

In [20]:
# Check the ground truth for the search
merged_nodes_df.loc[2]

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
2,6544,6588,IREB2_(6588),IREB2,gene/protein,IREB2 belongs to gene/protein node. IREB2 is i...,MDAPKAGYAFEYLIETLNDSSHKKFFDVSKLGTKYDVLPYSIRVLL...,"[-0.05323749, -0.01769744, -0.016430508, -0.02...","[0.022037705406546593, -0.0630035325884819, 0...."


In [21]:
# Get node indices from the results
[n['node_index'] for n in results[0]]

[6544, 698, 4447, 1068, 2010, 7587, 5069, 3111, 27259, 8547]

In [22]:
# Get the cosine similarity scores
[n['distance'] for n in results[0]]

[0.9999999403953552,
 0.9973249435424805,
 0.9917730689048767,
 0.9909224510192871,
 0.990104615688324,
 0.9892357587814331,
 0.9892129898071289,
 0.9891620874404907,
 0.9891443252563477,
 0.989010751247406]

#### Building Edge Collection

Subsquently, we are also building the edges collection in Milvus.

Note that the edges information of PrimeKG has massive records, so once again we are chunking the data to avoid memory issues.

In [23]:
%%time

# Define collection name
edge_coll_name = f"{milvus_database}_edges"

# Define schema
edge_fields = [
    FieldSchema(name="triplet_index", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="head_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="head_index", dtype=DataType.INT64),
    FieldSchema(name="tail_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="tail_index", dtype=DataType.INT64),
    FieldSchema(name="edge_type", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="display_relation", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="feat", dtype=DataType.VARCHAR, max_length=40960),
    FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=1536),
]
edge_schema = CollectionSchema(fields=edge_fields, description="Schema for edges collection")

# Create collection if not exists
if not utility.has_collection(edge_coll_name):
    collection = Collection(name=edge_coll_name, schema=edge_schema)
else:
    collection = Collection(name=edge_coll_name)

# Create indexes
collection.create_index(field_name="triplet_index", index_params={"index_type": "STL_SORT"}, index_name="triplet_index_index")
collection.create_index(field_name="head_index", index_params={"index_type": "STL_SORT"}, index_name="head_index_index")
collection.create_index(field_name="tail_index", index_params={"index_type": "STL_SORT"}, index_name="tail_index_index")
collection.create_index(field_name="feat_emb", index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}, index_name="feat_emb_index")

# Iterate over chunked edges embedding df
for edges_df in tqdm(edges_embedding_df):
    # Merge enrichment with embedding
    merged_edges_df = edges_enrichment_df.merge(
        edges_df[["triplet_index", "edge_emb"]],
        on="triplet_index",
        how="inner"
    )

    # Prepare data fields in column-wise format
    data = [
        merged_edges_df["triplet_index"].to_arrow().to_pylist(),
        merged_edges_df["head_id"].to_arrow().to_pylist(),
        merged_edges_df["head_index"].to_arrow().to_pylist(),
        merged_edges_df["tail_id"].to_arrow().to_pylist(),
        merged_edges_df["tail_index"].to_arrow().to_pylist(),
        merged_edges_df["edge_type_str"].to_arrow().to_pylist(),
        merged_edges_df["display_relation"].to_arrow().to_pylist(),
        merged_edges_df["feat"].to_arrow().to_pylist(),
        cp.asarray(merged_edges_df["edge_emb"].list.leaves).astype(cp.float32)
            .reshape(merged_edges_df.shape[0], -1)
            .tolist(),
    ]

    # Insert in chunks
    batch_size = 500
    for i in tqdm(range(0, len(data[0]), batch_size)):
        batch_data = [d[i:i+batch_size] for d in data]
        collection.insert(batch_data)

    # Flush to ensure persistence
    collection.flush()

    # Print collection stats
    print(collection.num_entities)


100%|██████████| 500/500 [05:37<00:00,  1.48it/s]
  6%|▋         | 1/16 [05:52<1:28:11, 352.75s/it]

250000


100%|██████████| 500/500 [05:40<00:00,  1.47it/s]
 12%|█▎        | 2/16 [11:53<1:23:26, 357.63s/it]

500000


100%|██████████| 500/500 [05:44<00:00,  1.45it/s]
 19%|█▉        | 3/16 [17:57<1:18:04, 360.33s/it]

750000


2025-06-27 07:30:30,244 [ERROR][handler]: RPC error: [batch_insert], <MilvusException: (code=<bound method _MultiThreadedRendezvous.code of <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:19530: Failed to connect to remote host: connect: Connection refused (111)"
	debug_error_string = "UNKNOWN:Error received from peer  {grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:19530: Failed to connect to remote host: connect: Connection refused (111)", grpc_status:14, created_time:"2025-06-27T07:30:30.244005098+00:00"}"
>>, message=[batch_insert] Retry run out of 75 retry times, message=failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:19530: Failed to connect to remote host: connect: Connection refused (111))>, <Time:{'RPC start': '2025-06-27 07:26:49.476497', 'RPC error': '2025-06-27 07:30:30.244465'}> (decorator

MilvusException: <MilvusException: (code=<bound method _MultiThreadedRendezvous.code of <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:19530: Failed to connect to remote host: connect: Connection refused (111)"
	debug_error_string = "UNKNOWN:Error received from peer  {grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:19530: Failed to connect to remote host: connect: Connection refused (111)", grpc_status:14, created_time:"2025-06-27T07:30:30.244005098+00:00"}"
>>, message=[batch_insert] Retry run out of 75 retry times, message=failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:19530: Failed to connect to remote host: connect: Connection refused (111))>

In [None]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

['t2kg_primekg_nodes_disease',
 't2kg_primekg_nodes_drug',
 't2kg_primekg_nodes_molecular_function',
 't2kg_primekg_edges',
 't2kg_primekg_nodes',
 't2kg_primekg_nodes_biological_process',
 't2kg_primekg_nodes_cellular_component',
 't2kg_primekg_nodes_gene_protein']

In [None]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection('t2kg_primekg_edges')

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = merged_edges_df["edge_emb"].iloc[0]  # or .loc[0], both work
search_params = {"metric_type": "COSINE"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["head_id", "tail_id", "edge_type", "feat"]
)

CPU times: user 9.91 ms, sys: 1.88 ms, total: 11.8 ms
Wall time: 1.18 s


data: [[{'triplet_index': 1214608, 'distance': 0.9999998807907104, 'entity': {'tail_id': 'Carbutamide_(15801)', 'edge_type': 'drug|synergistic interaction|drug', 'feat': 'Linagliptin (drug) has a direct relationship of drug_drug:synergistic interaction with Carbutamide (drug).', 'head_id': 'Linagliptin_(15598)'}}, {'triplet_index': 1563071, 'distance': 0.9767932891845703, 'entity': {'tail_id': 'Carmegliptin_(21218)', 'edge_type': 'drug|synergistic interaction|drug', 'feat': 'Linagliptin (drug) has a direct relationship of drug_drug:synergistic interaction with Carmegliptin (drug).', 'head_id': 'Linagliptin_(15598)'}}, {'triplet_index': 1214400, 'distance': 0.9743505120277405, 'entity': {'tail_id': 'Carbutamide_(15801)', 'edge_type': 'drug|synergistic interaction|drug', 'feat': 'Sitagliptin (drug) has a direct relationship of drug_drug:synergistic interaction with Carbutamide (drug).', 'head_id': 'Sitagliptin_(15501)'}}, {'triplet_index': 2532955, 'distance': 0.9738765954971313, 'entity

In [89]:
# Check the ground truth for the search
merged_edges_df.loc[2]

Unnamed: 0,triplet_index,primekg_head_index,primekg_tail_index,head_id,tail_id,display_relation,edge_type,edge_type_str,head_index,tail_index,feat,edge_emb
2,1214610,14408,15801,Enzalutamide_(14408),Carbutamide_(15801),synergistic interaction,"[drug, synergistic interaction, drug]",drug|synergistic interaction|drug,14197,15480,Enzalutamide (drug) has a direct relationship ...,"[-0.021125574, -0.022837382, 0.009667073, -0.0..."


In [82]:
# Get node indices from the results
[n['triplet_index'] for n in results[0]]

[1214608,
 1563071,
 1214400,
 2532955,
 1390209,
 2145087,
 1176360,
 2048083,
 696273,
 755881]

In [81]:
# Get the cosine similarity scores
[n['distance'] for n in results[0]]

[0.9999998807907104,
 0.9767932891845703,
 0.9743505120277405,
 0.9738765954971313,
 0.9737454652786255,
 0.9732384085655212,
 0.9717762470245361,
 0.9714594483375549,
 0.970791220664978,
 0.9696710109710693]

In [87]:
milvus_client.list_collections()

['t2kg_primekg_nodes_cellular_component',
 't2kg_primekg_nodes_disease',
 't2kg_primekg_nodes_drug',
 't2kg_primekg_edges',
 't2kg_primekg_nodes',
 't2kg_primekg_nodes_biological_process',
 't2kg_primekg_nodes_gene_protein',
 't2kg_primekg_nodes_molecular_function']