In this tutorial, we will prepare a Milvus database for storing and searching the nodes and edges of a graph.

In particular, we are using PrimeKG multimodal data from the BioBridge project.

In [1]:
# Load necessary libraries
import os
import glob
import hydra
import cudf
import cupy as cp
import numpy as np
from pymilvus import (
    db,
    connections,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
    utility,
    MilvusClient
)
from tqdm import tqdm
import time
import pickle

### Loading IBD BioBridge-PrimeKG Multimodal Data

First, we load the graph from a pickle file that bundles the nodes and edges of the BioBridge-PrimeKG subset.

For both nodes and edges, the pickle also contains their enrichments (descriptions and features) and the corresponding embeddings.

In [2]:
# Load the pickled graph (nodes, edges, enrichments and embeddings).
# The path can be overridden with the BIOBRIDGE_GRAPH_PKL environment variable;
# the default is relative to the notebook's working directory.
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
graph_pkl_path = os.environ.get(
    "BIOBRIDGE_GRAPH_PKL",
    '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_pyg_graph.pkl',
)
with open(graph_pkl_path, 'rb') as f:
    graph = pickle.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def normalize_matrix(m, axis=1):
    """
    L2-normalize a matrix along ``axis`` using CuPy.

    With the default ``axis=1`` every row is scaled to unit length, so the
    inner product of two normalized rows equals their cosine similarity.
    Slices whose norm is zero (all-zero vectors) are returned unchanged
    instead of being turned into NaN by a 0/0 division.

    Parameters:
    m (cupy.ndarray): 2D matrix to normalize.
    axis (int): Axis along which norms are computed (1 = rows, 0 = columns).

    Returns:
    cupy.ndarray: Normalized matrix with the same shape as ``m``.
    """
    norms = cp.linalg.norm(m, axis=axis, keepdims=True)
    # Replace zero norms by 1 so all-zero vectors stay zero instead of becoming NaN
    # (NaN embeddings would poison downstream inner-product search).
    # `norms` is a freshly allocated array, so modifying it in place is safe.
    norms[norms == 0] = 1
    return m / norms

def normalize_vector(v):
    """
    L2-normalize a single vector using CuPy.

    A zero vector has no direction; it is returned as-is (all zeros) instead
    of producing NaN values from a 0/0 division.

    Parameters:
    v (array-like): Vector to normalize; converted with ``cp.asarray``.

    Returns:
    cupy.ndarray: Unit-length vector (or the zero vector unchanged).
    """
    v = cp.asarray(v)
    norm = cp.linalg.norm(v)
    # Divide by 1 for a zero vector so the result stays zero (and floating point)
    return v / (norm if norm != 0 else 1)

In [4]:
# Stack the description embeddings into an (N x D) CuPy matrix and L2-normalize
# each row, so that inner-product search behaves like cosine similarity.
graph_desc_x_cp = cp.asarray(graph['desc_x'].tolist())
graph_desc_x_normalized = normalize_matrix(graph_desc_x_cp, axis=1)

# Node-feature embeddings are normalized one vector at a time
# (presumably because their length may differ across node types - TODO confirm).
graph_x_normalized = [normalize_vector(v).tolist() for v in graph['x']]

# Assemble the node table, then expose the positional row number as `node_index`.
node_columns = {
    'node_id': graph['node_id'],
    'node_name': graph['node_name'],
    'node_type': graph['node_type'],
    'desc': graph['desc'],
    'desc_emb': graph_desc_x_normalized.tolist(),
    'feat': graph['enriched_node'],
    'feat_emb': graph_x_normalized,
}
nodes_df = (
    cudf.DataFrame(node_columns)
    .reset_index()
    .rename(columns={'index': 'node_index'})
)
nodes_df.head(3)

Unnamed: 0,node_index,node_id,node_name,node_type,desc,desc_emb,feat,feat_emb
0,0,SMAD3_(144),SMAD3,gene/protein,SMAD3 belongs to gene/protein node. SMAD3 is S...,"[0.02974936784975063, 0.05350021171537046, -0....",MSSILPFTPPIVKRLLGWKKGEQNGQEEKWCEKAVKSLVKKLKKTG...,"[-0.0010794274069904548, -0.0028632148270051, ..."
1,1,IL10RB_(179),IL10RB,gene/protein,IL10RB belongs to gene/protein node. IL10RB is...,"[0.02842173040130417, 0.01986006372730412, -0....",MAWSLGSWLGGCLLVSALGMVPPPENVRMNSVNFKNILQWESPAFA...,"[-0.007157766077247574, 0.006195289622587354, ..."
2,2,GNA12_(192),GNA12,gene/protein,GNA12 belongs to gene/protein node. GNA12 is G...,"[0.003668847841835145, 0.051380571197126614, -...",MSGVVRTLSRCLLPAEAGGARERRAGSGARDAEREARRRSRDIDAL...,"[-0.001562959383761, -0.01338132129666802, -0...."


In [5]:
# Shape of the node table: (number of nodes, number of columns)
nodes_df.shape

(2991, 8)

In [14]:
# Spot-check a handful of nodes by their node_index
spot_check_indices = [14, 51, 743, 749, 1590, 2531, 2756, 2903]
nodes_df.loc[nodes_df["node_index"].isin(spot_check_indices)]

Unnamed: 0,node_index,node_id,node_name,node_type,desc,desc_emb,feat,feat_emb
14,14,IL6_(1567),IL6,gene/protein,IL6 belongs to gene/protein node. IL6 is inter...,"[0.04196309454625084, 0.029793425939964834, -0...",MNSFSTSAFGPVAFSLGLLLVLPAAFPAPVPPGEDSKDVAAPHRQP...,"[0.004271645940645016, 0.0007511508522558354, ..."
51,51,ICAM1_(4968),ICAM1,gene/protein,ICAM1 belongs to gene/protein node. ICAM1 is i...,"[0.034276778713674426, 0.03315937767925836, -0...",MAPSSPRPALPALLVLLGALFPGPGNAQTSVSPSKVILPRGGSVLV...,"[0.00507795906312394, -0.0005075638537899229, ..."
743,743,Atiprimod_(17591),Atiprimod,drug,Atiprimod belongs to drug node. Investigat...,"[0.006464077617364229, 0.005259734950775257, -...",CCCC1(CCC)CCC2(CCN(CCCN(CC)CC)C2)CC1,"[0.0030587854375391218, -0.03262633027902116, ..."
749,749,Dilmapimod_(17602),Dilmapimod,drug,Dilmapimod belongs to drug node. Dilmapimod ha...,"[0.03729557040719118, 0.002621309687478571, -0...",CC1=CC(F)=CC=C1C1=C2C=CC(=O)N(C2=NC(NC(CO)CO)=...,"[-0.02190920141208959, -0.043726178638557156, ..."
1590,1590,positive regulation of leukocyte adhesion to v...,positive regulation of leukocyte adhesion to v...,biological_process,positive regulation of leukocyte adhesion to v...,"[0.03933759383036314, 0.02968755905723463, -0....",Any process that activates or increases the fr...,"[0.053003649343574925, 0.039327081545075464, -..."
2531,2531,cellular response to interleukin-6_(109554),cellular response to interleukin-6,biological_process,cellular response to interleukin-6 belongs to ...,"[0.0544803034084408, 0.01141595255719802, -0.1...",Any process that results in a change in state ...,"[0.060826784805779056, 0.03196426937244094, -0..."
2756,2756,interleukin-6 receptor binding_(117413),interleukin-6 receptor binding,molecular_function,interleukin-6 receptor binding belongs to mole...,"[0.0003433517070531295, -0.0031268953771521973...",Binding to an interleukin-6 receptor.,"[0.0030954857584721035, 0.006912673604171493, ..."
2903,2903,interleukin-6 receptor complex_(124639),interleukin-6 receptor complex,cellular_component,interleukin-6 receptor complex belongs to cell...,"[-0.0008519993583764449, 0.01178766854270901, ...",A hexameric protein complex consisting of two ...,"[-0.024135240203797953, 0.05303054385611603, -..."


In [6]:
# Look up a single node by its node_index
nodes_df.loc[nodes_df["node_index"] == 5]

Unnamed: 0,node_index,node_id,node_name,node_type,desc,desc_emb,feat,feat_emb
5,5,IL7R_(625),IL7R,gene/protein,IL7R belongs to gene/protein node. IL7R is int...,"[0.04506901097086558, 0.008911124460235972, -0...",MTILGTTFGMVFSLLQVVSGESGYAQNGDLEDAELDDYSFSCYSQL...,"[0.005123700132302352, 0.005596505431503881, 0..."


In [7]:
# Convert the list of edge embeddings to a 2D CuPy array (M x D)
graph_edge_attr_cp = cp.asarray(graph['edge_attr'].tolist())

# L2-normalize each row so inner-product search equals cosine similarity
graph_edge_attr_normalized = normalize_matrix(graph_edge_attr_cp, axis=1)

# Convert the graph edges to a cudf DataFrame
edges_df = cudf.DataFrame({
    'triplet_index': graph['triplet_index'],
    'head_id': graph['head_id'],
    'head_name': graph['head_name'],
    'tail_id': graph['tail_id'],
    'tail_name': graph['tail_name'],
    'display_relation': graph['display_relation'],
    'edge_type': graph['edge_type'],
    'edge_type_str': ['|'.join(e) for e in graph['edge_type']],
    'feat': graph['enriched_edge'],
    'edge_emb': graph_edge_attr_normalized.tolist(),
})

# Attach the integer node_index of each edge's head and tail via id lookups.
# node_id must be unique, otherwise the merges would silently duplicate edges.
assert nodes_df['node_id'].is_unique, "node_id values must be unique"
node_lookup = nodes_df[['node_index', 'node_id']]
edges_df = (
    edges_df
    .merge(
        node_lookup.rename(columns={'node_index': 'head_index', 'node_id': 'head_id'}),
        on='head_id',
        how='left',
    )
    .merge(
        node_lookup.rename(columns={'node_index': 'tail_index', 'node_id': 'tail_id'}),
        on='tail_id',
        how='left',
    )
)
# A left merge leaves nulls for endpoints missing from nodes_df; those would later
# fail to insert into the non-nullable INT64 head_index/tail_index Milvus fields.
assert edges_df['head_index'].notna().all(), "some head_id values have no matching node"
assert edges_df['tail_index'].notna().all(), "some tail_id values have no matching node"
edges_df.head(3)

Unnamed: 0,triplet_index,head_id,head_name,tail_id,tail_name,display_relation,edge_type,edge_type_str,feat,edge_emb,head_index,tail_index
0,8602,cytokine-mediated signaling pathway_(47242),cytokine-mediated signaling pathway,IL10RB_(179),IL10RB,interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,cytokine-mediated signaling pathway (biologica...,"[0.016838406414606846, 0.019238545922865967, -...",1455,1
1,8603,cytokine-mediated signaling pathway_(47242),cytokine-mediated signaling pathway,IL12B_(6168),IL12B,interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,cytokine-mediated signaling pathway (biologica...,"[0.018197947379867397, 0.03141968316046658, -0...",1455,59
2,8604,cytokine-mediated signaling pathway_(47242),cytokine-mediated signaling pathway,IRF5_(3646),IRF5,interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,cytokine-mediated signaling pathway (biologica...,"[0.018029207941198132, 0.019414354880667273, -...",1455,46


In [8]:
# Shape of the edge table: (number of edges, number of columns)
edges_df.shape

(11272, 12)

In [16]:
# Look up one edge by its triplet index (ICAM1 -> cellular response to interleukin-6)
edges_df.loc[edges_df["triplet_index"] == 4752]

Unnamed: 0,triplet_index,head_id,head_name,tail_id,tail_name,display_relation,edge_type,edge_type_str,feat,edge_emb,head_index,tail_index
10483,4752,ICAM1_(4968),ICAM1,cellular response to interleukin-6_(109554),cellular response to interleukin-6,interacts with,"[gene/protein, interacts with, biological_proc...",gene/protein|interacts with|biological_process,ICAM1 (gene/protein) has a direct relationship...,"[0.0445736235330466, 0.02968640366378288, -0.1...",51,2531


In [15]:
# The reverse-direction edge (cellular response to interleukin-6 -> ICAM1)
edges_df.loc[edges_df["triplet_index"] == 10388]

Unnamed: 0,triplet_index,head_id,head_name,tail_id,tail_name,display_relation,edge_type,edge_type_str,feat,edge_emb,head_index,tail_index
3858,10388,cellular response to interleukin-6_(109554),cellular response to interleukin-6,ICAM1_(4968),ICAM1,interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,cellular response to interleukin-6 (biological...,"[0.057259282165135456, 0.028288940877212582, -...",2531,51


### Setup Milvus Database

In [8]:
# Configuration for Milvus.
# Every setting can be overridden through environment variables so credentials
# do not need to be edited into (and committed with) the notebook; the defaults
# reproduce the previous local-development settings.
milvus_host = os.environ.get("MILVUS_HOST", "localhost")
milvus_port = os.environ.get("MILVUS_PORT", "19530")
milvus_uri = f"http://{milvus_host}:{milvus_port}"
milvus_user = os.environ.get("MILVUS_USER", "root")
milvus_password = os.environ.get("MILVUS_PASSWORD", "Milvus")
# Token in "user:password" form, derived so it cannot drift from the fields above
milvus_token = f"{milvus_user}:{milvus_password}"
milvus_database = os.environ.get("MILVUS_DATABASE", "t2kg_primekg")

In [9]:
# Open the "default" pymilvus connection with the configured host and credentials
connection_kwargs = dict(
    host=milvus_host,
    port=milvus_port,
    user=milvus_user,
    password=milvus_password,
)
connections.connect(alias="default", **connection_kwargs)

In [10]:
# Create the target database on first use, then make it the active one
existing_databases = db.list_database()
if milvus_database not in existing_databases:
    db.create_database(milvus_database)
db.using_database(milvus_database)

In [11]:
# Print every collection in the active database with its entity count
# (the collection is not loaded here; num_entities works without load()).
for coll_name in utility.list_collections():
    print(f"Collection: {coll_name}")
    collection = Collection(name=coll_name)
    print(collection.num_entities)

In [12]:
# Utility: split a sequence into fixed-size chunks
def chunked(data_list, chunk_size):
    """
    Yield consecutive slices of ``data_list`` with at most ``chunk_size`` items.

    The last slice may be shorter when ``len(data_list)`` is not a multiple of
    ``chunk_size``; an empty input yields nothing.

    Raises:
    ValueError: if ``chunk_size`` is less than 1 (raised when the generator is
    first advanced).
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")
    for start in range(0, len(data_list), chunk_size):
        yield data_list[start:start + chunk_size]

#### Building Node Collection (Description Embedding)

In [13]:
%%time

# Configuration for Milvus collection
node_coll_name = f"{milvus_database}_nodes"

# Define schema for the collection
# Leave out the feat and feat_emb fields for now
desc_emb_dim = len(nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])
node_fields = [
    FieldSchema(name="node_index", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="node_name", dtype=DataType.VARCHAR, max_length=1024,
                enable_analyzer=True, enable_match=True),
    FieldSchema(name="node_type", dtype=DataType.VARCHAR, max_length=1024,
                enable_analyzer=True, enable_match=True),
    FieldSchema(name="desc", dtype=DataType.VARCHAR, max_length=40960,
                enable_analyzer=True, enable_match=True),
    FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),
]
schema  = CollectionSchema(fields=node_fields, description=f"Schema for collection {node_coll_name}")

# Create collection if it doesn't exist
if not utility.has_collection(node_coll_name):
    collection = Collection(name=node_coll_name, schema=schema)
else:
    collection = Collection(name=node_coll_name)

# Create indexes
collection.create_index(
    field_name="node_index",
    index_params={"index_type": "STL_SORT"}, # STL_SORT
    index_name="node_index_index"
)
# Create index for node_name, node_type, desc fields (inverted)
collection.create_index(
    field_name="node_name",
    index_params={"index_type": "INVERTED"},
    index_name="node_name_index"
)
collection.create_index(
    field_name="node_type",
    index_params={"index_type": "INVERTED"},
    index_name="node_type_index"
)
collection.create_index(
    field_name="desc",
    index_params={"index_type": "INVERTED"},
    index_name="desc_index"
)
collection.create_index(
    field_name="desc_emb",
    index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"}, # AUTOINDEX
    index_name="desc_emb_index"
)

# Prepare data for insertion
data = [
    nodes_df["node_index"].to_arrow().to_pylist(),
    nodes_df["node_id"].to_arrow().to_pylist(),
    nodes_df["node_name"].to_arrow().to_pylist(),
    nodes_df["node_type"].to_arrow().to_pylist(),
    nodes_df["desc"].to_arrow().to_pylist(),
    cp.asarray(nodes_df["desc_emb"].list.leaves).astype(cp.float32)
        .reshape(nodes_df.shape[0], -1)
        .tolist(),
]

# Insert data in batches
batch_size = 500
total = len(data[0])
for i in tqdm(range(0, total, batch_size)):
    batch = [
        col[i:i+batch_size] for col in data
    ]
    collection.insert(batch)

# Flush to persist data
collection.flush()

# Get collection stats
print(collection.num_entities)

100%|██████████| 6/6 [00:01<00:00,  3.10it/s]


2991
CPU times: user 281 ms, sys: 40.6 ms, total: 321 ms
Wall time: 5.49 s


In [14]:
# Print every collection in the active database with its entity count
# (the collection is not loaded here; num_entities works without load()).
for coll_name in utility.list_collections():
    print(f"Collection: {coll_name}")
    collection = Collection(name=coll_name)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes
2991


In [15]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index == 0"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ['{\'node_index\': 0, \'node_id\': \'SMAD3_(144)\', \'node_name\': \'SMAD3\', \'node_type\': \'gene/protein\', \'desc\': "SMAD3 belongs to gene/protein node. SMAD3 is SMAD family member 3. The SMAD family of proteins are a group of intracellular signal transducer proteins similar to the gene products of the Drosophila gene \'mothers against decapentaplegic\' (Mad) and the C. elegans gene Sma. The SMAD3 protein functions in the transforming growth factor-beta signaling pathway, and transmits signals from the cell surface to the nucleus, regulating gene activity and cell proliferation. This protein forms a complex with other SMAD proteins and binds DNA, functioning both as a transcription factor and tumor suppressor. Mutations in this gene are associated with aneurysms-osteoarthritis syndrome and Loeys-Dietz Syndrome 3. [provided by RefSeq, May 2022].", \'desc_emb\': [0.029749367, 0.053500213, -0.17067125, -0.025804106, 0.0078263, -0.009656536, 0.022171432, 0.011523555, 0.03228597,

In [16]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index in [0, 1]"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ['{\'node_index\': 0, \'node_id\': \'SMAD3_(144)\', \'node_name\': \'SMAD3\', \'node_type\': \'gene/protein\', \'desc\': "SMAD3 belongs to gene/protein node. SMAD3 is SMAD family member 3. The SMAD family of proteins are a group of intracellular signal transducer proteins similar to the gene products of the Drosophila gene \'mothers against decapentaplegic\' (Mad) and the C. elegans gene Sma. The SMAD3 protein functions in the transforming growth factor-beta signaling pathway, and transmits signals from the cell surface to the nucleus, regulating gene activity and cell proliferation. This protein forms a complex with other SMAD proteins and binds DNA, functioning both as a transcription factor and tumor suppressor. Mutations in this gene are associated with aneurysms-osteoarthritis syndrome and Loeys-Dietz Syndrome 3. [provided by RefSeq, May 2022].", \'desc_emb\': [0.029749367, 0.053500213, -0.17067125, -0.025804106, 0.0078263, -0.009656536, 0.022171432, 0.011523555, 0.03228597,

In [17]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = nodes_df["desc_emb"].iloc[0]
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="desc_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"]
)

results

CPU times: user 7.76 ms, sys: 383 μs, total: 8.14 ms
Wall time: 189 ms


data: [[{'node_index': 0, 'distance': 0.9999999403953552, 'entity': {'node_name': 'SMAD3', 'node_id': 'SMAD3_(144)'}}, {'node_index': 1009, 'distance': 0.8087351322174072, 'entity': {'node_name': 'regulation of SMAD protein signal transduction', 'node_id': 'regulation of SMAD protein signal transduction_(41433)'}}, {'node_index': 7, 'distance': 0.7976452112197876, 'entity': {'node_name': 'STAT3', 'node_id': 'STAT3_(729)'}}, {'node_index': 2152, 'distance': 0.7935216426849365, 'entity': {'node_name': 'positive regulation of SMAD protein signal transduction', 'node_id': 'positive regulation of SMAD protein signal transduction_(101088)'}}, {'node_index': 2182, 'distance': 0.7918790578842163, 'entity': {'node_name': 'SMAD protein signal transduction', 'node_id': 'SMAD protein signal transduction_(101792)'}}, {'node_index': 31, 'distance': 0.768486499786377, 'entity': {'node_name': 'TGFB1', 'node_id': 'TGFB1_(2889)'}}, {'node_index': 2150, 'distance': 0.7662344574928284, 'entity': {'node_na

In [18]:
# Show node 0 (the query node of the search above) to compare with the top hits
nodes_df.loc[0]

Unnamed: 0,node_index,node_id,node_name,node_type,desc,desc_emb,feat,feat_emb
0,0,SMAD3_(144),SMAD3,gene/protein,SMAD3 belongs to gene/protein node. SMAD3 is S...,"[0.02974936784975063, 0.05350021171537046, -0....",MSSILPFTPPIVKRLLGWKKGEQNGQEEKWCEKAVKSLVKKLKKTG...,"[-0.0010794274069904548, -0.0028632148270051, ..."


#### Building Node Collection (Node Type-specific Embedding)

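Note that the node information in PrimeKG differs by node type (e.g. gene/protein features are protein sequences, while drug features are SMILES strings), so we build a separate collection for each node type.

The node type (with `/` replaced by `_`) is used as the suffix of each collection name, e.g. `t2kg_primekg_nodes_gene_protein`.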

In [21]:
%%time

# Loop over group enrichment nodes by node_type
for node_type, nodes_df_ in tqdm(nodes_df.groupby('node_type')):
    print(f"Processing node type: {node_type}")

    # Milvus collection name for this node_type
    node_coll_name = f"{milvus_database}_nodes_{node_type.replace('/', '_')}"

    # Define collection schema
    desc_emb_dim = len(nodes_df_.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])
    feat_emb_dim = len(nodes_df_.iloc[0]['feat_emb'].to_arrow().to_pylist()[0])
    node_fields = [
        FieldSchema(name="node_index", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name="node_name", dtype=DataType.VARCHAR, max_length=1024,
                    enable_analyzer=True, enable_match=True),
        FieldSchema(name="node_type", dtype=DataType.VARCHAR, max_length=1024,
                    enable_analyzer=True, enable_match=True),
        FieldSchema(name="desc", dtype=DataType.VARCHAR, max_length=40960,
                    enable_analyzer=True, enable_match=True),
        FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),        
        FieldSchema(name="feat", dtype=DataType.VARCHAR, max_length=40960,
                    enable_analyzer=True, enable_match=True),
        FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=feat_emb_dim),
    ]
    schema = CollectionSchema(fields=node_fields, description=f"schema for collection {node_coll_name}")

    # Create collection if not exists
    if not utility.has_collection(node_coll_name):
        collection = Collection(name=node_coll_name, schema=schema)
    else:
        collection = Collection(name=node_coll_name)

    # Create index for node_index field (scalar)
    collection.create_index(
        field_name="node_index",
        index_params={"index_type": "STL_SORT"},
        index_name="node_index_index"
    )
    # Create index for node_name, node_type, desc fields (inverted)
    collection.create_index(
        field_name="node_name",
        index_params={"index_type": "INVERTED"},
        index_name="node_name_index"
    )
    collection.create_index(
        field_name="node_type",
        index_params={"index_type": "INVERTED"},
        index_name="node_type_index"
    )
    collection.create_index(
        field_name="desc",
        index_params={"index_type": "INVERTED"},
        index_name="desc_index"
    )
    collection.create_index(
        field_name="desc_emb",
        index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"}, # AUTOINDEX
        index_name="desc_emb_index"
    )
    # Create index for feat_emb (vector)
    collection.create_index(
        field_name="feat_emb",
        index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"}, # AUTOINDEX
        index_name="feat_emb_index"
    )

    # Prepare data for insertion
    # Columns must be lists of values in order matching schema fields
    data = [
        nodes_df_["node_index"].to_arrow().to_pylist(),
        nodes_df_["node_id"].to_arrow().to_pylist(),
        nodes_df_["node_name"].to_arrow().to_pylist(),
        nodes_df_["node_type"].to_arrow().to_pylist(),
        nodes_df_["desc"].to_arrow().to_pylist(),
        cp.asarray(nodes_df_["desc_emb"].list.leaves).astype(cp.float32)
            .reshape(nodes_df_.shape[0], -1)
            .tolist(),
        nodes_df_["feat"].to_arrow().to_pylist(),
        cp.asarray(nodes_df_["feat_emb"].list.leaves).astype(cp.float32)
            .reshape(nodes_df_.shape[0], -1)
            .tolist(),
    ]

    # Batch insert data in chunks
    batch_size = 500
    total_rows = len(data[0])
    for i in tqdm(range(0, total_rows, batch_size)):
        batch = [col[i:i + batch_size] for col in data]
        collection.insert(batch)

    # Flush the collection to ensure data is persisted
    collection.flush()

    # Print collection stats (number of entities and segment info)
    stats = collection.num_entities
    print(f"Collection {node_coll_name} stats:")
    print(stats)

  0%|          | 0/6 [00:00<?, ?it/s]

Processing node type: biological_process


100%|██████████| 4/4 [00:02<00:00,  1.80it/s]
 17%|█▋        | 1/6 [00:06<00:33,  6.80s/it]

Collection t2kg_primekg_nodes_biological_process stats:
1615
Processing node type: cellular_component


100%|██████████| 1/1 [00:00<00:00,  2.72it/s]
 33%|███▎      | 2/6 [00:11<00:21,  5.38s/it]

Collection t2kg_primekg_nodes_cellular_component stats:
202
Processing node type: disease


100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
 50%|█████     | 3/6 [00:15<00:14,  4.92s/it]

Collection t2kg_primekg_nodes_disease stats:
7
Processing node type: drug


100%|██████████| 2/2 [00:01<00:00,  1.99it/s]
 67%|██████▋   | 4/6 [00:20<00:09,  4.96s/it]

Collection t2kg_primekg_nodes_drug stats:
748
Processing node type: gene/protein


100%|██████████| 1/1 [00:00<00:00,  2.84it/s]
 83%|████████▎ | 5/6 [00:24<00:04,  4.76s/it]

Collection t2kg_primekg_nodes_gene_protein stats:
102
Processing node type: molecular_function


100%|██████████| 1/1 [00:00<00:00,  2.69it/s]
100%|██████████| 6/6 [00:29<00:00,  4.89s/it]

Collection t2kg_primekg_nodes_molecular_function stats:
317
CPU times: user 716 ms, sys: 204 ms, total: 920 ms
Wall time: 29.4 s





In [22]:
# Print every collection in the active database with its entity count
# (the collection is not loaded here; num_entities works without load()).
for coll_name in utility.list_collections():
    print(f"Collection: {coll_name}")
    collection = Collection(name=coll_name)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes_biological_process
1615
Collection: t2kg_primekg_nodes_gene_protein
102
Collection: t2kg_primekg_nodes_cellular_component
202
Collection: t2kg_primekg_nodes_disease
7
Collection: t2kg_primekg_nodes_drug
748
Collection: t2kg_primekg_nodes_molecular_function
317
Collection: t2kg_primekg_nodes
2991


In [23]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection('t2kg_primekg_nodes_gene_protein')

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = nodes_df["feat_emb"].iloc[0]
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"]
)
results

CPU times: user 7.6 ms, sys: 16.6 ms, total: 24.2 ms
Wall time: 1.45 s


data: [[{'node_index': 0, 'distance': 0.9999998211860657, 'entity': {'node_id': 'SMAD3_(144)', 'node_name': 'SMAD3'}}, {'node_index': 3, 'distance': 0.988673210144043, 'entity': {'node_id': 'HNF4A_(279)', 'node_name': 'HNF4A'}}, {'node_index': 850, 'distance': 0.9875491857528687, 'entity': {'node_id': 'DENND1B_(34887)', 'node_name': 'DENND1B'}}, {'node_index': 30, 'distance': 0.9863460659980774, 'entity': {'node_id': 'PRDM1_(2874)', 'node_name': 'PRDM1'}}, {'node_index': 15, 'distance': 0.9861670136451721, 'entity': {'node_id': 'JAK2_(1618)', 'node_name': 'JAK2'}}, {'node_index': 71, 'distance': 0.9860329627990723, 'entity': {'node_id': 'KIF21B_(8564)', 'node_name': 'KIF21B'}}, {'node_index': 9, 'distance': 0.9856594800949097, 'entity': {'node_id': 'PPARG_(989)', 'node_name': 'PPARG'}}, {'node_index': 62, 'distance': 0.9852690696716309, 'entity': {'node_id': 'TYK2_(6428)', 'node_name': 'TYK2'}}, {'node_index': 66, 'distance': 0.9844554662704468, 'entity': {'node_id': 'ADCY7_(7359)', 'n

In [24]:
# Check the ground truth for the search: node 0 (SMAD3) is the query node,
# so it should appear as the top hit with a similarity of ~1.0
nodes_df.loc[0]

Unnamed: 0,node_index,node_id,node_name,node_type,desc,desc_emb,feat,feat_emb
0,0,SMAD3_(144),SMAD3,gene/protein,SMAD3 belongs to gene/protein node. SMAD3 is S...,"[0.02974936784975063, 0.05350021171537046, -0....",MSSILPFTPPIVKRLLGWKKGEQNGQEEKWCEKAVKSLVKKLKKTG...,"[-0.0010794274069904548, -0.0028632148270051, ..."


In [25]:
# Node indices (primary keys) of the search hits, best match first
list(map(lambda hit: hit['node_index'], results[0]))

[0, 3, 850, 30, 15, 71, 9, 62, 66, 11]

In [26]:
# Similarity scores of the hits (IP on L2-normalized vectors = cosine similarity)
[hit['distance'] for hit in results[0]]

[0.9999998211860657,
 0.988673210144043,
 0.9875491857528687,
 0.9863460659980774,
 0.9861670136451721,
 0.9860329627990723,
 0.9856594800949097,
 0.9852690696716309,
 0.9844554662704468,
 0.984216034412384]

#### Building Edge Collection

Subsequently, we also build the edge collection in Milvus.

Note that PrimeKG has a large number of edge records, so once again we insert the data in batches to avoid memory issues.

In [28]:
%%time

# Define collection name
edge_coll_name = f"{milvus_database}_edges"

# Define schema
edge_fields = [
    FieldSchema(name="triplet_index", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="head_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="head_index", dtype=DataType.INT64),
    FieldSchema(name="tail_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="tail_index", dtype=DataType.INT64),
    FieldSchema(name="edge_type", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="display_relation", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="feat", dtype=DataType.VARCHAR, max_length=40960),
    FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=768), #1536 for text-embedding-ada-002, 768 for ollama nomic-embed-text
]
edge_schema = CollectionSchema(fields=edge_fields, description="Schema for edges collection")

# Create collection if not exists
if not utility.has_collection(edge_coll_name):
    collection = Collection(name=edge_coll_name, schema=edge_schema)
else:
    collection = Collection(name=edge_coll_name)

# Create indexes
collection.create_index(field_name="triplet_index", index_params={"index_type": "STL_SORT"}, index_name="triplet_index_index")
collection.create_index(field_name="head_index", index_params={"index_type": "STL_SORT"}, index_name="head_index_index")
collection.create_index(field_name="tail_index", index_params={"index_type": "STL_SORT"}, index_name="tail_index_index")
collection.create_index(field_name="feat_emb", index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"}, index_name="feat_emb_index") # AUTOINDEX

# Iterate over chunked edges embedding df
for edges_df_ in tqdm([edges_df]):

    # Prepare data fields in column-wise format
    data = [
        edges_df_["triplet_index"].to_arrow().to_pylist(),
        edges_df_["head_id"].to_arrow().to_pylist(),
        edges_df_["head_index"].to_arrow().to_pylist(),
        edges_df_["tail_id"].to_arrow().to_pylist(),
        edges_df_["tail_index"].to_arrow().to_pylist(),
        edges_df_["edge_type_str"].to_arrow().to_pylist(),
        edges_df_["display_relation"].to_arrow().to_pylist(),
        edges_df_["feat"].to_arrow().to_pylist(),
        cp.asarray(edges_df_["edge_emb"].list.leaves).astype(cp.float32)
            .reshape(edges_df_.shape[0], -1)
            .tolist(),
    ]

    # Insert in chunks
    batch_size = 500
    for i in tqdm(range(0, len(data[0]), batch_size)):
        batch_data = [d[i:i+batch_size] for d in data]
        collection.insert(batch_data)

    # Flush to ensure persistence
    collection.flush()

    # Print collection stats
    print(collection.num_entities)
    
    time.sleep(5)  # Sleep to avoid overwhelming the server


100%|██████████| 23/23 [00:05<00:00,  4.45it/s]


11272


100%|██████████| 1/1 [00:11<00:00, 11.11s/it]

CPU times: user 971 ms, sys: 180 ms, total: 1.15 s
Wall time: 13.4 s





In [29]:
# Print every collection in the active database with its entity count
# (the collection is not loaded here; num_entities works without load()).
for coll_name in utility.list_collections():
    print(f"Collection: {coll_name}")
    collection = Collection(name=coll_name)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes_cellular_component
202
Collection: t2kg_primekg_nodes_disease
7
Collection: t2kg_primekg_nodes_drug
748
Collection: t2kg_primekg_nodes_molecular_function
317
Collection: t2kg_primekg_edges
11272
Collection: t2kg_primekg_nodes
2991
Collection: t2kg_primekg_nodes_biological_process
1615
Collection: t2kg_primekg_nodes_gene_protein
102


In [30]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection('t2kg_primekg_edges')

# Load the collection into memory before query
collection.load()

# Query by expr on triplet_index
expr = "triplet_index == 0"
output_fields = ["triplet_index", "head_id", "tail_id", "edge_type", "feat", "feat_emb"]

results = collection.query(expr, output_fields=output_fields)
results

CPU times: user 10.2 ms, sys: 6.34 ms, total: 16.6 ms
Wall time: 1.52 s


data: ["{'triplet_index': 0, 'head_id': 'Rose bengal_(14118)', 'tail_id': 'LTF_(3233)', 'edge_type': 'drug|carrier|gene/protein', 'feat': 'Rose bengal (drug) has a direct relationship of drug_protein:carrier with LTF (gene/protein).', 'feat_emb': [0.071049586, 0.0060329223, -0.17035195, 0.0013526214, 0.048036188, 0.003639546, -0.075369276, 0.008644691, -0.03623718, -0.03908307, -0.02776647, 0.05557716, 0.14341861, -0.005525504, 0.0033363877, -0.0004059333, 0.025640398, -0.036554623, -0.003938051, 0.03537183, -0.010803912, 0.019870885, 0.027924007, -0.015905585, 0.04759002, 0.0665622, 0.029466823, -0.0487658, -0.021191811, -0.003891353, 0.060805637, -0.03784243, -0.009008949, -0.0034025165, 0.0024893545, -0.005160498, 0.023153596, 0.039008953, -0.011503011, 0.012791931, 0.06716217, 0.054385673, 0.00804711, 0.015538905, 0.03946401, 0.048474014, 0.013319631, 0.025589347, 0.054585613, -0.04032546, 0.010447572, -0.012857375, -0.016545152, -0.009699244, 0.06947826, 0.0377996, -0.005236178, 0

In [31]:
# Inspect the queried edge record (it is the anchor for the similarity search below)
fields_to_show = ("triplet_index", "head_id", "tail_id", "edge_type", "feat")
tuple(results[0][field] for field in fields_to_show)

(0,
 'Rose bengal_(14118)',
 'LTF_(3233)',
 'drug|carrier|gene/protein',
 'Rose bengal (drug) has a direct relationship of drug_protein:carrier with LTF (gene/protein).')

In [32]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection('t2kg_primekg_edges')

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = np.array(results[0]['feat_emb']).tolist() # merged_edges_df["edge_emb"].iloc[0]
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["head_id", "tail_id", "edge_type", "feat"]
)
results

CPU times: user 1.86 ms, sys: 471 μs, total: 2.34 ms
Wall time: 5.87 ms


data: [[{'triplet_index': 0, 'distance': 1.0000001192092896, 'entity': {'head_id': 'Rose bengal_(14118)', 'tail_id': 'LTF_(3233)', 'edge_type': 'drug|carrier|gene/protein', 'feat': 'Rose bengal (drug) has a direct relationship of drug_protein:carrier with LTF (gene/protein).'}}, {'triplet_index': 5636, 'distance': 0.9807485938072205, 'entity': {'head_id': 'LTF_(3233)', 'tail_id': 'Rose bengal_(14118)', 'edge_type': 'gene/protein|carrier|drug', 'feat': 'LTF (gene/protein) has a direct relationship of drug_protein:carrier with Rose bengal (drug).'}}, {'triplet_index': 5832, 'distance': 0.8167988061904907, 'entity': {'head_id': 'LTF_(3233)', 'tail_id': '3h-Indole-5,6-Diol_(18278)', 'edge_type': 'gene/protein|target|drug', 'feat': 'LTF (gene/protein) has a direct relationship of drug_protein:target with 3h-Indole-5,6-Diol (drug).'}}, {'triplet_index': 5834, 'distance': 0.8110494613647461, 'entity': {'head_id': 'LTF_(3233)', 'tail_id': 'Nitrilotriacetic acid_(18279)', 'edge_type': 'gene/pro

In [33]:
# Triplet (edge) indices of the search hits, best match first
list(map(lambda hit: hit['triplet_index'], results[0]))

[0, 5636, 5832, 5834, 5835, 5833, 3, 5836, 5837, 196]

In [34]:
# Similarity scores of the edge hits (IP on L2-normalized vectors = cosine similarity)
[hit['distance'] for hit in results[0]]

[1.0000001192092896,
 0.9807485938072205,
 0.8167988061904907,
 0.8110494613647461,
 0.8035300374031067,
 0.8020656704902649,
 0.7981401085853577,
 0.7976876497268677,
 0.7929842472076416,
 0.7929696440696716]