In [1]:
import torch
if torch.cuda.is_available():
    import cudf as df
    import cupy as py
else:
    import pandas as df
    import numpy as py
import pandas as pd
from pymilvus import (
    db,
    connections,
    utility,
    Collection,
)
import os
import glob
import cudf
import dask_cudf
import hydra
import logging
import sys
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning as MultimodalPCSTPruningNew
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.multimodal_pcst import MultimodalPCSTPruning as MultimodalPCSTPruningOld

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration for Milvus
milvus_host = "localhost"
milvus_port = "19530"
milvus_uri = "http://localhost:19530"
milvus_token = "root:Milvus"
milvus_user = "root"
milvus_password = "Milvus"
milvus_database = "t2kg_primekg"

In [3]:
# Connect to Milvus
connections.connect(
    alias="default",
    host=milvus_host,
    port=milvus_port,
    user=milvus_user,
    password=milvus_password
)

In [4]:
db.using_database(milvus_database)

utility.list_collections()

['t2kg_primekg_nodes_biological_process',
 't2kg_primekg_nodes_drug',
 't2kg_primekg_nodes_gene_protein',
 't2kg_primekg_edges',
 't2kg_primekg_nodes_cellular_component',
 't2kg_primekg_nodes_molecular_function',
 't2kg_primekg_nodes',
 't2kg_primekg_nodes_disease']

In [85]:
# Define the data path
DATA_PATH = "../../../aiagents4pharma/talk2knowledgegraphs/tests/files"

# Define the agent state
state = {
    "selections": {
        "gene/protein": [],
        "molecular_function": [],
        "cellular_component": [],
        "biological_process": [],
        "drug": [],
        "disease": []
    },
    "uploaded_files": [],
    "topk_nodes": 5,
    "topk_edges": 5,
    "dic_source_graph": [
        {
            "name": "BioBridge",
            "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
            "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
        }
    ],
}

In [86]:
# Update state
state["uploaded_files"] = [
    {
        "file_name": "multimodal-analysis_single_gene.xlsx",
        "file_path": f"{DATA_PATH}/multimodal-analysis_single_gene.xlsx",
        "file_type": "multimodal",
        "uploaded_by": "VPEUser",
        "uploaded_timestamp": "2025-05-12 00:00:00",
    }
]

In [87]:
# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [139]:
# Initialize dataframes
logger.log(logging.INFO, "Initializing dataframes")
multimodal_df = df.DataFrame({"name": [], "node_type": []})
query_df = []

# Loop over the uploaded files and find multimodal files
logger.log(logging.INFO, "Looping over uploaded files")
for i in range(len(state["uploaded_files"])):
    # Check if multimodal file is uploaded
    if state["uploaded_files"][i]["file_type"] == "multimodal":
        # Read the Excel file
        multimodal_df = pd.read_excel(state["uploaded_files"][i]["file_path"],
                                        sheet_name=None)
        
        
# Check if the multimodal_df is empty
logger.log(logging.INFO, "Checking if multimodal_df is empty")
if len(multimodal_df) > 0:
    # Prepare multimodal_df
    logger.log(logging.INFO, "Preparing multimodal_df")
    # Merge all obtained dataframes into a single dataframe
    multimodal_df = pd.concat(multimodal_df).reset_index()
    multimodal_df = df.DataFrame(multimodal_df)
    multimodal_df.drop(columns=["level_1"], inplace=True)
    multimodal_df.rename(columns={"level_0": "q_node_type",
                                  "name": "q_node_name"}, inplace=True)
    # Since an excel sheet name could not contain a `/`,
    # but the node type can be 'gene/protein' as exists in the PrimeKG
    multimodal_df["q_node_type"] = multimodal_df["q_node_type"].str.replace('-', '_')
    
    # Query the Milvus database for each node type in multimodal_df
    logger.log(logging.INFO, "Querying Milvus database for each node type in multimodal_df")
    for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
        print(f"Processing node type: {node_type}")
        
        # Load the collection
        collection = Collection(name=f"{milvus_database}_nodes_{node_type.replace('/', '_')}")
        collection.load()

        # Query the collection with node names from multimodal_df
        q_node_names =  getattr(node_type_df['q_node_name'], 
                                "to_pandas", 
                                lambda: node_type_df['q_node_name'])().tolist()
        q_columns = ["node_id", "node_type", "feat", "feat_emb", "desc", "desc_emb"]
        res = collection.query(
            expr=f'node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]',
            output_fields=q_columns,
        )
        # Convert the embeedings into floats
        for r_ in res:
            r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
            r_['desc_emb'] = [float(x) for x in r_['desc_emb']]

        # Convert the result to a DataFrame
        res_df = df.DataFrame(res)[q_columns]
        res_df["use_description"] = False
        
        # Append the results to query_df
        query_df.append(res_df)
    
    # Concatenate all results into a single DataFrame
    logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
    query_df = df.concat(query_df, ignore_index=True)
    
    # Update the state by adding the the selected node IDs
    logger.log(logging.INFO, "Updating state with selected node IDs")
    state["selections"] = query_df.to_pandas().groupby(
        "node_type"
    )["node_id"].apply(list).to_dict()


INFO:__main__:Initializing dataframes
INFO:__main__:Looping over uploaded files
INFO:__main__:Checking if multimodal_df is empty
INFO:__main__:Preparing multimodal_df
INFO:__main__:Querying Milvus database for each node type in multimodal_df
INFO:__main__:Concatenating all results into a single DataFrame
INFO:__main__:Updating state with selected node IDs


Processing node type: gene_protein


In [140]:
# Prompt
prompt_text = "List the drugs that target Interleukin-6."
# Embed the prompt using Ollama
prompt = {
    "text": prompt_text,
    "emb": [EmbeddingWithOllama(model_name='nomic-embed-text').embed_query(prompt_text)]
}

INFO:httpx:HTTP Request: GET http://127.0.0.1:11434/api/tags "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"


In [141]:
# Append a user prompt to the query dataframe
# query_df = []
# query_df = df.DataFrame({
#     'node_id': 'user_prompt',
#     'node_type': 'prompt',
#     'feat': prompt["text"],
#     'feat_emb': prompt["emb"],
#     'desc': prompt["text"],
#     'desc_emb': prompt["emb"],
#     'use_description': True  # set to True for user prompt embedding
# })


logger.log(logging.INFO, "Adding user prompt to query dataframe")
query_df = df.concat([
    query_df,
    df.DataFrame({
    'node_id': 'user_prompt',
    'node_type': 'prompt',
    'feat': prompt["text"],
    'feat_emb': prompt["emb"],
    'desc': prompt["text"],
    'desc_emb': prompt["emb"],
    'use_description': True  # set to True for user prompt embedding
})
], ignore_index=True)

INFO:__main__:Adding user prompt to query dataframe


In [116]:
query_df

Unnamed: 0,node_id,node_type,feat,feat_emb,desc,desc_emb,use_description
0,IL7R_(625),gene/protein,MTILGTTFGMVFSLLQVVSGESGYAQNGDLEDAELDDYSFSCYSQL...,"[0.005123700015246868, 0.005596505478024483, 0...",IL7R belongs to gene/protein node. IL7R is int...,"[0.04506901279091835, 0.008911124430596828, -0...",False
1,TCF7_(5195),gene/protein,MPQLDSGGGGAGGGDDLGAPDELLAFQDEGEEQDDKSRDSAAGPER...,"[0.0020665288902819157, 0.00025825444026850164...",TCF7 belongs to gene/protein node. TCF7 is tra...,"[0.03699759021401405, 0.038098547607660294, -0...",False
2,user_prompt,prompt,List the drugs that target Interleukin-6.,"[-0.020274699, -0.0045663384, -0.18202391, 0.0...",List the drugs that target Interleukin-6.,"[-0.020274699, -0.0045663384, -0.18202391, 0.0...",True


In [117]:
# extract_subgraph
# text_emb = py.array(query_df.loc[0]['desc_emb'][0]).reshape(1, -1).astype(py.float32)
# query_emb = py.array(query_df.loc[0]['feat_emb'][0]).reshape(1, -1).astype(py.float32)
i = 0
text_emb = query_df.loc[i]['desc_emb'][i]
query_emb = query_df.loc[i]['feat_emb'][i]
modality = query_df.loc[i]['node_type'][i]
use_description = query_df.loc[i]['use_description'][i]

In [118]:
use_description, modality

(False, 'gene/protein')

In [119]:
# Initialize variables
topk = 3
topk_e = 3
c_const = 0.01

collection = Collection(name=f"{milvus_database}_nodes")
collection.load()
topk = min(topk, collection.num_entities)

n_prizes = df.Series(py.zeros(collection.num_entities), dtype=py.float32)

### Node Prize

In [120]:
# Intialize several variables
collection = Collection(name=f"{milvus_database}_nodes")
collection.load()
topk = min(topk, collection.num_entities)
n_prizes = py.zeros(collection.num_entities, dtype=py.float32)

# Calculate cosine similarity for text features and update the score
if use_description:
    # Search the collection with the query embedding
    res = collection.search(
        data=[query_emb],
        anns_field="desc_emb",
        param={"metric_type": "IP"},
        limit=topk,
        output_fields=["node_id"])
else:
    # Load the collection for the specific node type
    col_name = f"{milvus_database}_nodes_{modality.replace('/', '_')}"
    print(f"Loading collection: {col_name}")
    collection = Collection(name=col_name)
    collection.load()

    # Search the collection with the query embedding
    res = collection.search(
        data=[query_emb],
        anns_field="feat_emb",
        param={"metric_type": "IP"},
        limit=topk,
        output_fields=["node_id"])

# Update the prizes based on the search results
n_prizes[[r['node_index'] for r in res[0]]] = py.arange(topk, 0, -1).astype(py.float32)

Loading collection: t2kg_primekg_nodes_gene_protein


In [121]:
[r['distance'] for r in res[0]]

[0.9999999403953552, 0.9898543357849121, 0.9895499348640442]

In [122]:
[r.id for r in res[0]]

[5, 845, 12]

In [123]:
[r["node_index"] for r in res[0]]

[5, 845, 12]

### Edge Prize

In [124]:
# Intialize several variables
collection = Collection(name=f"{milvus_database}_edges")
collection.load()
topk_e = min(topk_e, collection.num_entities)
e_prizes = py.zeros(collection.num_entities, dtype=py.float32)

# Search the collection with the query embedding
res = collection.search(
    data=[text_emb],
    anns_field="feat_emb",
    param={"metric_type": "IP"},
    limit=topk_e,
    output_fields=["head_id", "tail_id"])

# Update the prizes based on the search results
e_prizes[[r['triplet_index'] for r in res[0]]] = [r['distance'] for r in res[0]]

In [125]:
unique_prizes, inverse_indices = py.unique(e_prizes, return_inverse=True)
topk_e_values = unique_prizes[py.argsort(-unique_prizes)[:topk_e]]
last_topk_e_value = topk_e
for k in range(topk_e):
    indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
    value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
    e_prizes[indices] = value
    last_topk_e_value = value * (1 - c_const)

### Create Index

In [126]:
# Intialize several variables
collection = Collection(name=f"{milvus_database}_edges")
collection.load()

edges = collection.query(
    expr="triplet_index >= 0",
    output_fields=["head_index", "tail_index"]
)
edge_index = [
    [r['head_index'] for r in edges],
    [r['tail_index'] for r in edges]
]
edge_index = py.array(edge_index)

In [127]:
edge_index

array([[ 123,   99,  320, ..., 2509, 2304, 2142],
       [  36,   47,   47, ...,   22,  838,  846]])

### Run Subgraph Extraction

In [None]:
# query_df = df.read_parquet("../../../aiagents4pharma/talk2knowledgegraphs/tests/files/query_df.parquet")

In [129]:
query_df

Unnamed: 0,node_id,node_type,feat,feat_emb,desc,desc_emb,use_description
0,user_prompt,prompt,List the drugs that target Interleukin-6.,"[-0.020274699, -0.0045663384, -0.18202391, 0.0...",List the drugs that target Interleukin-6.,"[-0.020274699, -0.0045663384, -0.18202391, 0.0...",True


In [130]:
from types import SimpleNamespace

top = 10
top_k_e = 10
cost_e = 0.5
c_const = 0.01
root = -1
num_clusters = 1
pruning = "gw"
verbosity_level = 0
use_description = False
cfg_fe = SimpleNamespace(**{
    "milvus_db": SimpleNamespace(**{
        "alias": "default",
        "host": "localhost",
        "port": "19530",
        "uri": "http://localhost:19530",
        "token": "root:Milvus",
        "user": "root",
        "password": "Milvus",
        "database_name": "t2kg_primekg",
        "collection_edges": "t2kg_primekg_edges",
        "collection_nodes": "t2kg_primekg_nodes",
        "collection_nodes_gene_protein": "t2kg_primekg_nodes_gene_protein",
        "collection_nodes_molecular_function": "t2kg_primekg_nodes_molecular_function",
        "collection_nodes_cellular_component": "t2kg_primekg_nodes_cellular_component",
        "collection_nodes_biological_process": "t2kg_primekg_nodes_biological_process",
        "collection_nodes_drug": "t2kg_primekg_nodes_drug",
        "collection_nodes_disease": "t2kg_primekg_nodes_disease"
    })
})

In [131]:
for q in getattr(query_df, "to_pandas", lambda: query_df)().iterrows():
    print(q[1]['node_type'])

prompt


In [132]:
 import pickle
 with open("../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_pyg_graph.pkl", "rb") as f:
     pyg_graph= pickle.load(f)

In [133]:
# Initialize the subgraph dictionary
subgraphs = []
unified_subgraph = {
    "nodes": [],
    "edges": []
}
# subgraphs = {}
# subgraphs["nodes"] = []
# subgraphs["edges"] = []

# Loop over query embeddings and modalities
for q in getattr(query_df, "to_pandas", lambda: query_df)().iterrows():
    logger.log(logging.INFO, f"Processing query {q[1]['node_id']}")
    # Prepare the PCSTPruning object and extract the subgraph
    # Parameters were set in the configuration file obtained from Hydra
    subgraph = MultimodalPCSTPruningOld(
        topk=state["topk_nodes"],
        topk_e=state["topk_edges"],
        cost_e=cost_e,
        c_const=c_const,
        root=root,
        num_clusters=num_clusters,
        pruning=pruning,
        verbosity_level=verbosity_level,
        use_description=q[1]['use_description'],
    ).extract_subgraph(pyg_graph,
                       torch.tensor(q[1]['desc_emb']), # description embedding
                       torch.tensor(q[1]['feat_emb']), # modal-specific embedding
                       q[1]['node_type'])

    # Append the extracted subgraph to the dictionary
    unified_subgraph["nodes"].append(subgraph["nodes"].tolist())
    unified_subgraph["edges"].append(subgraph["edges"].tolist())
    subgraphs.append((q[1]['node_id'],
                        subgraph["nodes"].tolist(),
                        subgraph["edges"].tolist()))

# Concatenate and get unique node and edge indices
unified_subgraph["nodes"] = py.unique(
    py.concatenate([py.array(list_) for list_ in unified_subgraph["nodes"]])
).tolist()
unified_subgraph["edges"] = py.unique(
    py.concatenate([py.array(list_) for list_ in unified_subgraph["edges"]])
).tolist()

# Convert the unified subgraph and subgraphs to cudf DataFrames
unified_subgraph = df.DataFrame([("Unified Subgraph",
                                    unified_subgraph["nodes"],
                                    unified_subgraph["edges"])],
                                    columns=["name", "nodes", "edges"])
subgraphs = df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])

# Concate both DataFrames
subgraphs = df.concat([unified_subgraph, subgraphs], ignore_index=True)

INFO:__main__:Processing query user_prompt


In [135]:
for sub in getattr(subgraphs, "to_pandas", lambda: subgraphs)().itertuples(index=False):
    print(sub.name)
    print(sub.nodes)
    print(sub.edges)

Unified Subgraph
[8, 14, 51, 547, 743, 749, 816, 1590, 2484, 2531, 2756, 2903]
[532, 551, 1107, 1154, 1159, 3844, 3851, 5871, 6107, 6116, 10384, 10844, 11135]
user_prompt
[8, 14, 51, 547, 743, 749, 816, 1590, 2484, 2531, 2756, 2903]
[3844, 1107, 3851, 532, 551, 11135, 10844, 10384, 1154, 1159, 5871, 6107, 6116]


In [136]:
pyg_graph.triplet_index[sub.edges].unique()

tensor([  182,   186,   187,  4351,  4352,  4752,  4753,  5818,  5823,  5959,
         7416,  8147, 10433])

In [142]:
# Initialize the subgraph dictionary
subgraphs = []
unified_subgraph = {
    "nodes": [],
    "edges": []
}
# subgraphs = {}
# subgraphs["nodes"] = []
# subgraphs["edges"] = []

# Loop over query embeddings and modalities
for q in getattr(query_df, "to_pandas", lambda: query_df)().iterrows():
    logger.log(logging.INFO, f"Processing query {q[1]['node_id']}")
    # Prepare the PCSTPruning object and extract the subgraph
    # Parameters were set in the configuration file obtained from Hydra
    subgraph = MultimodalPCSTPruningNew(
        topk=state["topk_nodes"],
        topk_e=state["topk_edges"],
        cost_e=cost_e,
        c_const=c_const,
        root=root,
        num_clusters=num_clusters,
        pruning=pruning,
        verbosity_level=verbosity_level,
        use_description=q[1]['use_description'],
    ).extract_subgraph(q[1]['desc_emb'],
                       q[1]['feat_emb'],
                       q[1]['node_type'],
                       cfg_fe)

    # Append the extracted subgraph to the dictionary
    unified_subgraph["nodes"].append(subgraph["nodes"].tolist())
    unified_subgraph["edges"].append(subgraph["edges"].tolist())
    subgraphs.append((q[1]['node_id'],
                        subgraph["nodes"].tolist(),
                        subgraph["edges"].tolist()))

# Concatenate and get unique node and edge indices
unified_subgraph["nodes"] = py.unique(
    py.concatenate([py.array(list_) for list_ in unified_subgraph["nodes"]])
).tolist()
unified_subgraph["edges"] = py.unique(
    py.concatenate([py.array(list_) for list_ in unified_subgraph["edges"]])
).tolist()

# Convert the unified subgraph and subgraphs to cudf DataFrames
unified_subgraph = df.DataFrame([("Unified Subgraph",
                                    unified_subgraph["nodes"],
                                    unified_subgraph["edges"])],
                                    columns=["name", "nodes", "edges"])
subgraphs = df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])

# Concate both DataFrames
subgraphs = df.concat([unified_subgraph, subgraphs], ignore_index=True)

INFO:__main__:Processing query IL7R_(625)
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:compute_prizes
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:_compute_node_prizes
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:_compute_edge_prizes
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:Creating edge index
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:compute_subgraph_costs
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:Running PCST algorithm
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:Getting subgraph nodes and edges
INFO:__main__:Processing query TCF7_(5195)
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst:compute_prizes
INFO:aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pc

In [143]:
for sub in getattr(subgraphs, "to_pandas", lambda: subgraphs)().itertuples(index=False):
    print(sub.name)
    print(sub.nodes)
    print(sub.edges)

Unified Subgraph
[5, 8, 12, 14, 22, 27, 30, 54, 72, 82, 547, 743, 749, 816, 840, 845, 846, 855, 1254, 1320, 1415, 1863, 1952, 2102, 2263, 2479, 2484, 2531, 2756, 2779, 2827, 2903]
[182, 186, 187, 323, 888, 920, 921, 942, 1004, 1083, 1191, 1216, 1268, 1650, 1760, 1846, 2486, 4508, 4509, 4557, 4558, 4753, 4958, 5110, 5818, 5823, 6613, 6762, 7397, 7400, 7416, 8147, 10433]
IL7R_(625)
[5, 12, 22, 82, 840, 845, 855, 1952, 2479, 2827]
[7400, 6613, 6762, 7397, 920, 921, 1760, 1846, 4958]
TCF7_(5195)
[27, 30, 54, 72, 840, 846, 855, 1320, 1415, 1863, 2102, 2263, 2779]
[1268, 1216, 1083, 1004, 1191, 888, 942, 1650, 2486, 4557, 4558, 5110]
user_prompt
[8, 14, 547, 743, 749, 816, 1254, 2484, 2531, 2756, 2903]
[4509, 4508, 4753, 323, 8147, 7416, 10433, 182, 186, 187, 5818, 5823]


In [33]:
coll_name = f"{milvus_database}_nodes"
node_coll = Collection(name=coll_name)
node_coll.load()
nodes = node_coll.query(
    expr=f'node_index IN [{",".join(f"{n}" for n in [8, 14, 547, 743, 749, 816, 1254, 2484, 2531, 2756, 2903])}]',
    output_fields=['node_id', 'node_name', 'node_type', 'desc']
)
nodes = df.DataFrame(nodes)
nodes

Unnamed: 0,node_id,node_name,node_type,desc,node_index
0,RELA_(772),RELA,gene/protein,RELA belongs to gene/protein node. RELA is REL...,8
1,IL6_(1567),IL6,gene/protein,IL6 belongs to gene/protein node. IL6 is inter...,14
2,Binimetinib_(15391),Binimetinib,drug,"Binimetinib belongs to drug node. Binimetinib,...",547
3,Atiprimod_(17591),Atiprimod,drug,Atiprimod belongs to drug node. Investigat...,743
4,Dilmapimod_(17602),Dilmapimod,drug,Dilmapimod belongs to drug node. Dilmapimod ha...,749
5,SC-236_(20186),SC-236,drug,SC-236 belongs to drug node. SC-236 is a poten...,816
6,regulation of transcription initiation from RN...,regulation of transcription initiation from RN...,biological_process,regulation of transcription initiation from RN...,1254
7,interleukin-6-mediated signaling pathway_(108960),interleukin-6-mediated signaling pathway,biological_process,interleukin-6-mediated signaling pathway belon...,2484
8,cellular response to interleukin-6_(109554),cellular response to interleukin-6,biological_process,cellular response to interleukin-6 belongs to ...,2531
9,interleukin-6 receptor binding_(117413),interleukin-6 receptor binding,molecular_function,interleukin-6 receptor binding belongs to mole...,2756


In [34]:
coll_name = f"{milvus_database}_edges"
edge_coll = Collection(name=coll_name)
edge_coll.load()
graph_edges = edge_coll.query(
    expr=f'triplet_index IN [{",".join(f"{e}" for e in [182, 186, 187, 323, 4508, 4509, 4753, 5818, 5823, 7416, 8147, 10433])}]',
    output_fields=['head_id', 'tail_id', 'edge_type']
)
graph_edges = df.DataFrame(graph_edges)
graph_edges

Unnamed: 0,head_id,tail_id,edge_type,triplet_index
0,Atiprimod_(17591),IL6_(1567),drug|target|gene/protein,182
1,Binimetinib_(15391),IL6_(1567),drug|target|gene/protein,186
2,Dilmapimod_(17602),IL6_(1567),drug|target|gene/protein,187
3,SC-236_(20186),RELA_(772),drug|target|gene/protein,323
4,IL6_(1567),regulation of transcription initiation from RN...,gene/protein|interacts with|biological_process,4508
5,RELA_(772),regulation of transcription initiation from RN...,gene/protein|interacts with|biological_process,4509
6,RELA_(772),cellular response to interleukin-6_(109554),gene/protein|interacts with|biological_process,4753
7,IL6_(1567),Atiprimod_(17591),gene/protein|target|drug,5818
8,IL6_(1567),Dilmapimod_(17602),gene/protein|target|drug,5823
9,interleukin-6 receptor binding_(117413),IL6_(1567),molecular_function|interacts with|gene/protein,7416


### Optimization

In [150]:
local_dir = "../../../aiagents4pharma/talk2knowledgegraphs/tests/files"
# Input
with open(f"{local_dir}/input_edge_index.pkl", "rb") as f:
    edge_index = pickle.load(f)
with open(f"{local_dir}/input_prizes.pkl", "rb") as f:
    prizes = pickle.load(f)
num_nodes = 2991
# Output
with open(f"{local_dir}/output_edges_dict.pkl", "rb") as f:
    output_edges_dict = pickle.load(f)
with open(f"{local_dir}/output_prizes.pkl", "rb") as f:
    output_prizes = pickle.load(f)
with open(f"{local_dir}/output_costs.pkl", "rb") as f:
    output_costs = pickle.load(f)
with open(f"{local_dir}/output_mapping.pkl", "rb") as f:
    output_mapping = pickle.load(f)

In [151]:
edge_index

array([[ 123,   99,  320, ..., 2509, 2304, 2142],
       [  36,   47,   47, ...,   22,  838,  846]])

In [156]:
py.unique(prizes["nodes"])

array([0., 1., 2., 3., 4., 5.], dtype=float32)

In [157]:
py.unique(prizes["edges"])

array([0., 1., 2., 3., 4., 5.], dtype=float32)

In [158]:
updated_cost_e = min(
    cost_e,
    prizes["edges"].max().item() * (1 - c_const / 2),
)
updated_cost_e


0.5

In [165]:
# Initialize variables
edges = []
costs = []
virtual = {
    "n_prizes": [],
    "edges": [],
    "costs": [],
}
mapping = {"nodes": {}, "edges": {}}

In [None]:
# Logic to reduce the cost of the edges such that at least one edge is selected
updated_cost_e = min(
    cost_e,
    prizes["edges"].max().item() * (1 - self.c_const / 2),
)

# Initialize variables
edges = []
costs = []
virtual = {
    "n_prizes": [],
    "edges": [],
    "costs": [],
}
mapping = {"nodes": {}, "edges": {}}

# Compute the costs, edges, and virtual variables based on the prizes
for i, (src, dst) in enumerate(edge_index.T):
    prize_e = prizes["edges"][i].item()
    if prize_e <= updated_cost_e:
        mapping["edges"][len(edges)] = i
        edges.append((src.item(), dst.item()))
        costs.append(updated_cost_e - prize_e)
    else:
        virtual_node_id = num_nodes + len(virtual["n_prizes"])
        mapping["nodes"][virtual_node_id] = i
        virtual["edges"].append((src.item(), virtual_node_id))
        virtual["edges"].append((virtual_node_id, dst.item()))
        virtual["costs"].append(0)
        virtual["costs"].append(0)
        virtual["n_prizes"].append(prize_e - updated_cost_e)
prizes = py.concatenate([prizes["nodes"], py.array(virtual["n_prizes"])])
edges_dict = {}
edges_dict["edges"] = edges
edges_dict["num_prior_edges"] = len(edges)
# Final computation of the costs and edges based on the virtual costs and virtual edges
if len(virtual["costs"]) > 0:
    costs = py.array(costs + virtual["costs"])
    edges = py.array(edges + virtual["edges"])
    edges_dict["edges"] = edges

return edges_dict, prizes, costs, mapping

In [172]:
c_const, cost_e

(0.01, 0.5)

In [176]:
isinstance(prizes["edges"], py.ndarray)

True

In [228]:
# Initialize the collections dictionary
colls = {}

# Load the collection for nodes
colls["nodes"] = Collection(name=f"{milvus_database}_nodes")

if modality != "prompt":
    # Load the collection for the specific node type
    colls["nodes_type"] = Collection(
        f"{milvus_database}_nodes_{modality.replace('/', '_')}"
    )

# Load the collection for edges
colls["edges"] = Collection(name=f"{milvus_database}_edges")

# Load the collections
for coll in colls.values():
    coll.load()

In [227]:
colls["nodes"].load()

In [179]:
def compute_subgraph_costs(edge_index, num_nodes, prizes, c_const=0.01, cost_e=0.5):
    src, dst = edge_index  # shape: (2, num_edges)
    prizes_edges = prizes["edges"]
    
    # Update edge cost threshold
    updated_cost_e = min(
        cost_e,
        py.max(prizes_edges).item() * (1 - c_const / 2),
    )

    # Masks for real and virtual edges
    mask_real = prizes_edges <= updated_cost_e
    mask_virtual = ~mask_real

    # Real edge indices
    real_indices = py.nonzero(mask_real)[0]
    real_src = src[real_indices]
    real_dst = dst[real_indices]
    real_edges = py.stack([real_src, real_dst], axis=1)
    real_costs = updated_cost_e - prizes_edges[real_indices]

    # Edge index mapping: local real edge idx -> original global index
    mapping_edges = {int(i): int(j) for i, j in enumerate(real_indices)}

    # Virtual edge handling
    virtual_indices = py.nonzero(mask_virtual)[0]
    virtual_src = src[virtual_indices]
    virtual_dst = dst[virtual_indices]
    virtual_prizes = prizes_edges[virtual_indices] - updated_cost_e

    # Generate virtual node IDs
    num_virtual = virtual_indices.shape[0]
    virtual_node_ids = py.arange(num_nodes, num_nodes + num_virtual)

    # Virtual edges: (src → virtual), (virtual → dst)
    v_edges_1 = py.stack([virtual_src, virtual_node_ids], axis=1)
    v_edges_2 = py.stack([virtual_node_ids, virtual_dst], axis=1)
    virtual_edges = py.concatenate([v_edges_1, v_edges_2], axis=0)
    virtual_costs = py.zeros((virtual_edges.shape[0],), dtype=real_costs.dtype)

    # Combine real and virtual edges/costs
    all_edges = py.concatenate([real_edges, virtual_edges], axis=0)
    all_costs = py.concatenate([real_costs, virtual_costs], axis=0)

    # Final prizes
    final_prizes = py.concatenate([prizes["nodes"], virtual_prizes], axis=0)

    # Mapping virtual node ID -> edge index in original graph
    mapping_nodes = {int(nid): int(idx) for nid, idx in zip(virtual_node_ids, virtual_indices)}

    # Build return values
    edges_dict = {
        "edges": all_edges,
        "num_prior_edges": real_edges.shape[0],
    }
    mapping = {
        "edges": mapping_edges,
        "nodes": mapping_nodes,
    }

    return edges_dict, final_prizes, all_costs, mapping


In [219]:
src, dst = edge_index

In [220]:
src

array([ 123,   99,  320, ..., 2509, 2304, 2142])

In [222]:
dst

array([ 36,  47,  47, ...,  22, 838, 846])

In [221]:
edge_index[0]

array([ 123,   99,  320, ..., 2509, 2304, 2142])

In [223]:
edge_index[1]

array([ 36,  47,  47, ...,  22, 838, 846])

In [218]:
prizes["edges"]

array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

In [180]:
outputs = compute_subgraph_costs(
    edge_index,
    num_nodes,
    prizes,
)

In [197]:
py.unique(output_edges_dict["edges"].flatten()).shape

(2993,)

In [208]:
output_prizes.shape, output_costs.shape

((2996,), (11277,))

In [217]:
sum(py.unique(output_edges_dict["edges"].flatten()) == py.unique(outputs[0]['edges'].flatten()))
output_edges_dict['num_prior_edges'] == outputs[0]['num_prior_edges']
output_prizes == outputs[1]
output_costs == outputs[2]
output_mapping['nodes'], outputs[3]['nodes']
for k in output_mapping['edges'].keys():
    print(output_mapping['edges'][k] == outputs[3]['edges'][k])

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [None]:
output_edges_dict['edges'], final_prizes, all_costs, mapping

({'edges': array([[ 123,   36],
         [  99,   47],
         [ 320,   47],
         ...,
         [2993,  744],
         [2994,  748],
         [2995,  119]]),
  'num_prior_edges': 11267},
 array([0. , 0. , 0. , ..., 4.5, 0.5, 3.5], dtype=float32),
 array([0.5, 0.5, 0.5, ..., 0. , 0. , 0. ], dtype=float32),
 {'edges': {0: 0,
   1: 1,
   2: 2,
   3: 3,
   4: 4,
   5: 5,
   6: 6,
   7: 7,
   8: 8,
   9: 9,
   10: 10,
   11: 11,
   12: 12,
   13: 13,
   14: 14,
   15: 15,
   16: 16,
   17: 17,
   18: 18,
   19: 19,
   20: 20,
   21: 21,
   22: 22,
   23: 23,
   24: 24,
   25: 25,
   26: 26,
   27: 27,
   28: 28,
   29: 29,
   30: 30,
   31: 31,
   32: 32,
   33: 33,
   34: 34,
   35: 35,
   36: 36,
   37: 37,
   38: 38,
   39: 39,
   40: 40,
   41: 41,
   42: 42,
   43: 43,
   44: 44,
   45: 45,
   46: 46,
   47: 47,
   48: 48,
   49: 49,
   50: 50,
   51: 51,
   52: 52,
   53: 53,
   54: 54,
   55: 55,
   56: 56,
   57: 57,
   58: 58,
   59: 59,
   60: 60,
   61: 61,
   62: 62,
   63:

In [149]:
edges_dict, prizes, costs, mapping

({'edges': array([[ 123,   36],
         [  99,   47],
         [ 320,   47],
         ...,
         [2994,  365],
         [  24, 2995],
         [2995,  119]]),
  'num_prior_edges': 11267},
 array([0. , 0. , 0. , ..., 3.5, 1.5, 2.5]),
 array([0.5, 0.5, 0.5, ..., 0. , 0. , 0. ]),
 {'nodes': {2991: 167, 2992: 5801, 2993: 5803, 2994: 5810, 2995: 5814},
  'edges': {0: 0,
   1: 1,
   2: 2,
   3: 3,
   4: 4,
   5: 5,
   6: 6,
   7: 7,
   8: 8,
   9: 9,
   10: 10,
   11: 11,
   12: 12,
   13: 13,
   14: 14,
   15: 15,
   16: 16,
   17: 17,
   18: 18,
   19: 19,
   20: 20,
   21: 21,
   22: 22,
   23: 23,
   24: 24,
   25: 25,
   26: 26,
   27: 27,
   28: 28,
   29: 29,
   30: 30,
   31: 31,
   32: 32,
   33: 33,
   34: 34,
   35: 35,
   36: 36,
   37: 37,
   38: 38,
   39: 39,
   40: 40,
   41: 41,
   42: 42,
   43: 43,
   44: 44,
   45: 45,
   46: 46,
   47: 47,
   48: 48,
   49: 49,
   50: 50,
   51: 51,
   52: 52,
   53: 53,
   54: 54,
   55: 55,
   56: 56,
   57: 57,
   58: 58,
   59: 5

In [None]:
graph_edges['edge_type'].str.split('|')

In [None]:
sub.nodes, sub.edges

In [None]:
sub.edges

In [None]:
subgraphs

In [None]:
unified_subgraph

In [None]:
subgraphs

In [None]:
n_prizes.tolist()

In [None]:
# Intialize several variables
collection = Collection(name=f"{milvus_database}_edges")
collection.load()

results = collection.query(
    expr="triplet_index >= 0",
    output_fields=["head_index", "tail_index"],
)

In [None]:
np.array([[r['head_index'], r['tail_index']] for r in results])

In [None]:
topk_e

In [None]:
inverse_indices

In [None]:
indices = inverse_indices == (unique_prizes == topk_e_values[0]).nonzero()[0]

In [None]:
py.where(py.array(indices) == True)

In [None]:
indices.sum().item()

In [None]:
(topk_e - 0) / indices.sum().item()

In [None]:
min((topk_e - 0) / indices.sum().item(), last_topk_e_value)

In [None]:
[r['distance'] for r in res[0]]

In [None]:
e_prizes[[r['triplet_index'] for r in res[0]]]

In [None]:
[r['triplet_index'] for r in res[0]]

In [None]:
[r['head_id'] for r in res[0]]

In [None]:
[r['tail_id'] for r in res[0]]

In [None]:
# Vector similarity search in Milvus
vector_to_search = nodes_df["feat_emb"].iloc[0]
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"]
)
results

In [None]:
for node_type, group in multimodal_df.groupby("q_node_type"):
    print(f"Processing node type: {node_type}")
    
    # Load the collection
    collection = Collection(name=f"{milvus_database}_nodes_{node_type.replace('/', '_')}")
    collection.load()

    # Query the collection with node names from multimodal_df
    q_node_names =  getattr(multimodal_df['q_node_name'], 
                            "to_pandas", 
                            lambda: multimodal_df['q_node_name'])().tolist()
    q_columns = ["node_id", "node_type", "feat", "feat_emb", "desc", "desc_emb"]
    res = collection.query(
        expr=f'node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]',
        output_fields=q_columns,
    )
    
    # Convert the result to a DataFrame
    res_df = df.DataFrame(res)[q_columns]
    


In [None]:
res_df

In [None]:
[r['node_id'] for r in res]

In [None]:
[r['node_type'] for r in res]

In [None]:
[r['desc'] for r in res]

In [None]:
multimodal_df['q_node_name']

In [None]:
multimodal_df[['q_node_name']].to_csv(header=False, index=False, sep=",")

In [None]:
getattr(multimodal_df['q_node_name'], "to_pandas", lambda: multimodal_df['q_node_name'])().tolist()

In [None]:
milvus