In [1]:
import torch
# Pick the dataframe/array backend once: `df` and `py` alias the GPU libraries
# when CUDA is available and fall back to pandas/numpy otherwise.
if torch.cuda.is_available():
    import cudf as df
    import cupy as py
    # The GPU-only packages are imported inside this branch so that the CPU
    # fallback below stays usable on machines without the GPU stack.
    import cudf
    import dask_cudf
else:
    import pandas as df
    import numpy as py
import pandas as pd
from pymilvus import (
    db,
    connections,
    utility,
    Collection,
)
import os
import glob
import hydra
import logging
import sys
# Make the repository root importable before importing the project package
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the test fixture files referenced by this notebook
DATA_PATH = "../../../aiagents4pharma/talk2knowledgegraphs/tests/files"

# Agent state: one (initially empty) selection list per node type, the list of
# uploaded files, the top-k settings and the source knowledge graph entry
state = {
    "selections": {
        node_type: []
        for node_type in (
            "gene/protein",
            "molecular_function",
            "cellular_component",
            "biological_process",
            "drug",
            "disease",
        )
    },
    "uploaded_files": [],
    "topk_nodes": 3,
    "topk_edges": 3,
    "dic_source_graph": [
        dict(
            name="BioBridge",
            kg_pyg_path=f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
            kg_text_path=f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
        )
    ],
}

In [3]:
# Register the multimodal Excel file as the only uploaded file in the state
state["uploaded_files"] = [
    dict(
        file_name="multimodal-analysis_single_gene.xlsx",
        file_path=f"{DATA_PATH}/multimodal-analysis_single_gene.xlsx",
        file_type="multimodal",
        uploaded_by="VPEUser",
        uploaded_timestamp="2025-05-12 00:00:00",
    )
]

In [4]:
# Initialize logger
# Configure logging at INFO level and create the logger used by the cells below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [10]:
# Resolve the node names listed in the uploaded multimodal Excel file against
# the per-node-type Milvus collections and store the matching node IDs in
# state["selections"].
# NOTE: this cell reads `milvus_database` and needs an open Milvus connection;
# both are set up by the Milvus configuration/connection cells, which must be
# executed before this one.

# Initialize dataframes
logger.info("Initializing dataframes")
multimodal_df = df.DataFrame({"name": [], "node_type": []})
query_df = []

# Loop over the uploaded files and find multimodal files
# (if several multimodal files are uploaded, the last one read is the one used)
logger.info("Looping over uploaded files")
for uploaded_file in state["uploaded_files"]:
    # Check if multimodal file is uploaded
    if uploaded_file["file_type"] == "multimodal":
        # sheet_name=None reads every sheet and returns {sheet name: DataFrame}
        multimodal_df = pd.read_excel(uploaded_file["file_path"],
                                      sheet_name=None)

# Check if the multimodal_df is empty
logger.info("Checking if multimodal_df is empty")
if len(multimodal_df) > 0:
    # Prepare multimodal_df
    logger.info("Preparing multimodal_df")
    # Merge all sheets into a single dataframe; after reset_index() the sheet
    # name is in `level_0` and the per-sheet row index in `level_1`
    multimodal_df = df.DataFrame(pd.concat(multimodal_df).reset_index())
    multimodal_df = multimodal_df.drop(columns=["level_1"]).rename(
        columns={"level_0": "q_node_type", "name": "q_node_name"}
    )
    # An Excel sheet name cannot contain a `/`, while a PrimeKG node type can
    # (e.g. 'gene/protein'). Sheet names are presumably written with '-'
    # instead (TODO confirm); normalize '-' to '_' so the value matches the
    # suffix of the Milvus collection names.
    multimodal_df["q_node_type"] = multimodal_df["q_node_type"].str.replace('-', '_')

    # Fields retrieved from Milvus for every matched node
    q_columns = ["node_id", "node_type", "feat", "feat_emb", "desc", "desc_emb"]

    # Query the Milvus database for each node type in multimodal_df
    logger.info("Querying Milvus database for each node type in multimodal_df")
    for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
        print(f"Processing node type: {node_type}")

        # Load the collection
        collection = Collection(name=f"{milvus_database}_nodes_{node_type.replace('/', '_')}")
        collection.load()

        # Query the collection with node names from multimodal_df.
        # `to_pandas` is used when the series provides it, otherwise the
        # series is used as is.
        q_node_names = getattr(node_type_df['q_node_name'],
                               "to_pandas",
                               lambda: node_type_df['q_node_name'])().tolist()
        # Build the quoted name list separately: nesting the same quote type
        # inside an f-string is a syntax error before Python 3.12.
        quoted_names = ",".join(f'"{name}"' for name in q_node_names)
        res = collection.query(
            expr=f"node_name IN [{quoted_names}]",
            output_fields=q_columns,
        )

        # Nothing matched for this node type: selecting q_columns from an
        # empty frame would raise, so skip it.
        if not res:
            logger.warning("No nodes found in Milvus for node type %s", node_type)
            continue

        # Convert the result to a DataFrame
        res_df = df.DataFrame(res)[q_columns]
        res_df["use_description"] = False

        # Append the results to query_df
        query_df.append(res_df)

    if query_df:
        # Concatenate all results into a single DataFrame
        logger.info("Concatenating all results into a single DataFrame")
        query_df = df.concat(query_df, ignore_index=True)

        # Update the state with the selected node IDs, grouped by node type.
        # Fall back to the frame itself when it has no `to_pandas` (pandas backend).
        logger.info("Updating state with selected node IDs")
        selections_df = getattr(query_df, "to_pandas", lambda: query_df)()
        state["selections"] = selections_df.groupby(
            "node_type"
        )["node_id"].apply(list).to_dict()
    else:
        # No node name matched anything: leave state["selections"] untouched
        logger.warning("No matching nodes found in Milvus; selections left unchanged")


INFO:__main__:Initializing dataframes
INFO:__main__:Looping over uploaded files
INFO:__main__:Checking if multimodal_df is empty
INFO:__main__:Preparing multimodal_df
INFO:__main__:Querying Milvus database for each node type in multimodal_df
INFO:__main__:Concatenating all results into a single DataFrame
INFO:__main__:Updating state with selected node IDs


Processing node type: gene_protein


In [11]:
# Display the query results collected by the previous cell
query_df

Unnamed: 0,node_id,node_type,feat,feat_emb,desc,desc_emb,use_description
0,IL7R_(625),gene/protein,MTILGTTFGMVFSLLQVVSGESGYAQNGDLEDAELDDYSFSCYSQL...,"[0.005123700015246868, 0.005596505478024483, 0...",IL7R belongs to gene/protein node. IL7R is int...,"[0.04506901279091835, 0.008911124430596828, -0...",False
1,TCF7_(5195),gene/protein,MPQLDSGGGGAGGGDDLGAPDELLAFQDEGEEQDDKSRDSAAGPER...,"[0.0020665288902819157, 0.00025825444026850164...",TCF7 belongs to gene/protein node. TCF7 is tra...,"[0.03699759021401405, 0.038098547607660294, -0...",False


In [121]:
# Display the selections stored in the state (node IDs keyed by node type)
state["selections"]

{'gene/protein': ['IL7R_(625)', 'TCF7_(5195)']}

In [6]:
# Configuration for Milvus
# Every setting can be overridden through an environment variable so that real
# credentials do not have to be stored in the notebook; the fallbacks are the
# values this notebook previously hard-coded.
milvus_host = os.environ.get("MILVUS_HOST", "localhost")
milvus_port = os.environ.get("MILVUS_PORT", "19530")
milvus_user = os.environ.get("MILVUS_USER", "root")
milvus_password = os.environ.get("MILVUS_PASSWORD", "Milvus")
milvus_database = os.environ.get("MILVUS_DATABASE", "t2kg_primekg")
# Derived from the settings above so they cannot drift apart
milvus_uri = f"http://{milvus_host}:{milvus_port}"
milvus_token = f"{milvus_user}:{milvus_password}"

In [7]:
# Connect to Milvus
# Opens the connection under the "default" alias using the host, port, user
# and password defined in the Milvus configuration cell.
connections.connect(
    alias="default",
    host=milvus_host,
    port=milvus_port,
    user=milvus_user,
    password=milvus_password
)

In [8]:
# Check whether a connection exists under the "default" alias
connections.has_connection("default")

True

In [9]:
# Select `milvus_database` as the database to use, then list the available collections
db.using_database(milvus_database)

utility.list_collections()

['t2kg_primekg_nodes_molecular_function',
 't2kg_primekg_edges',
 't2kg_primekg_nodes',
 't2kg_primekg_nodes_biological_process',
 't2kg_primekg_nodes_gene_protein',
 't2kg_primekg_nodes_cellular_component',
 't2kg_primekg_nodes_disease',
 't2kg_primekg_nodes_drug']

In [None]:
# Scratch version of the per-node-type Milvus lookup. It leaves `res` and
# `res_df` of the last processed node type in the namespace for inspection.
# Requires `multimodal_df` (with `q_node_type` / `q_node_name` columns) and
# `milvus_database` to be defined already.
q_columns = ["node_id", "node_type", "feat", "feat_emb", "desc", "desc_emb"]
for node_type, group in multimodal_df.groupby("q_node_type"):
    print(f"Processing node type: {node_type}")

    # Load the collection
    collection = Collection(name=f"{milvus_database}_nodes_{node_type.replace('/', '_')}")
    collection.load()

    # Query the collection with the node names of THIS node type only
    # (the group), not with every name in multimodal_df.
    q_node_names = getattr(group['q_node_name'],
                           "to_pandas",
                           lambda: group['q_node_name'])().tolist()
    # Build the quoted name list separately: nesting the same quote type
    # inside an f-string is a syntax error before Python 3.12.
    quoted_names = ",".join(f'"{name}"' for name in q_node_names)
    res = collection.query(
        expr=f"node_name IN [{quoted_names}]",
        output_fields=q_columns,
    )

    # Convert the result to a DataFrame
    res_df = df.DataFrame(res)[q_columns]
    


Processing node type: gene_protein


In [109]:
# Display the DataFrame built from the last Milvus query result
res_df

Unnamed: 0,node_id,node_type,feat,feat_emb,desc,desc_emb
0,IL7R_(625),gene/protein,MTILGTTFGMVFSLLQVVSGESGYAQNGDLEDAELDDYSFSCYSQL...,"[0.005123700015246868, 0.005596505478024483, 0...",IL7R belongs to gene/protein node. IL7R is int...,"[0.04506901279091835, 0.008911124430596828, -0..."
1,TCF7_(5195),gene/protein,MPQLDSGGGGAGGGDDLGAPDELLAFQDEGEEQDDKSRDSAAGPER...,"[0.0020665288902819157, 0.00025825444026850164...",TCF7 belongs to gene/protein node. TCF7 is tra...,"[0.03699759021401405, 0.038098547607660294, -0..."


In [102]:
# Node IDs of the records in the last Milvus query result (`res`)
[r['node_id'] for r in res]

['IL7R_(625)', 'TCF7_(5195)']

In [98]:
# Node types of the records in the last Milvus query result (`res`)
[r['node_type'] for r in res]

['gene/protein', 'gene/protein']

In [99]:
# Textual descriptions of the records in the last Milvus query result (`res`)
[r['desc'] for r in res]

['IL7R belongs to gene/protein node. IL7R is interleukin 7 receptor. The protein encoded by this gene is a receptor for interleukin 7 (IL7). The function of this receptor requires the interleukin 2 receptor, gamma chain (IL2RG), which is a common gamma chain shared by the receptors of various cytokines, including interleukins 2, 4, 7, 9, and 15. This protein has been shown to play a critical role in V(D)J recombination during lymphocyte development. Defects in this gene may be associated with severe combined immunodeficiency (SCID). Alternatively spliced transcript variants have been found. [provided by RefSeq, Dec 2015].',
 'TCF7 belongs to gene/protein node. TCF7 is transcription factor 7. This gene encodes a member of the T-cell factor/lymphoid enhancer-binding factor family of high mobility group (HMG) box transcriptional activators. This gene is expressed predominantly in T-cells and plays a critical role in natural killer cell and innate lymphoid cell development. The encoded pro

data: ['{\'node_id\': \'SMAD3_(144)\', \'node_type\': \'gene/protein\', \'desc\': "SMAD3 belongs to gene/protein node. SMAD3 is SMAD family member 3. The SMAD family of proteins are a group of intracellular signal transducer proteins similar to the gene products of the Drosophila gene \'mothers against decapentaplegic\' (Mad) and the C. elegans gene Sma. The SMAD3 protein functions in the transforming growth factor-beta signaling pathway, and transmits signals from the cell surface to the nucleus, regulating gene activity and cell proliferation. This protein forms a complex with other SMAD proteins and binds DNA, functioning both as a transcription factor and tumor suppressor. Mutations in this gene are associated with aneurysms-osteoarthritis syndrome and Loeys-Dietz Syndrome 3. [provided by RefSeq, May 2022].", \'node_index\': 0}']

In [64]:
# Inspect the column of node names read from the multimodal file
multimodal_df['q_node_name']

0        ACTG1
1         AIF1
2      ALOX5AP
3      ANKRD12
4        ANXA1
        ...   
184      TXNIP
185     TYROBP
186       UCP2
187        VIM
188       ZEB2
Name: q_node_name, Length: 189, dtype: object

In [81]:
# Serialize the node names as CSV text without header or index; since no path
# is given, the CSV is returned as a string (one name per line).
multimodal_df[['q_node_name']].to_csv(header=False, index=False, sep=",")

'ACTG1\nAIF1\nALOX5AP\nANKRD12\nANXA1\nANXA5\nAPBB1IP\nARF6\nARHGDIB\nARID4B\nATP5D\nB2M\nBTG1\nC12ORF57\nC12ORF75\nC19ORF60\nC9ORF142\nCALM1\nCAPZB\nCCL5\nCCR7\nCD2\nCD3E\nCD52\nCD74\nCD8A\nCD8B\nCD96\nCEBPD\nCHCHD10\nCKLF\nCMC1\nCOA4\nCORO1B\nCOX4I1\nCRIP1\nCST7\nCTSW\nCX3CR1\nDDX5\nDSTN\nDUSP1\nEEF1A1\nEEF1G\nEEF2\nEFHD2\nEIF1\nEIF1AX\nEIF3E\nFAM173A\nFGD3\nFGFBP2\nFOS\nFTH1\nFUBP1\nFYB\nGLRX\nGLTSCR2\nGMFG\nGNAS\nGNLY\nGPR183\nGYPC\nGZMA\nGZMB\nGZMH\nGZMK\nGZMM\nH3F3A\nHCST\nHLA-A\nHLA-DPB1\nHMGB2\nHNRNPU\nHOPX\nHSP90B1\nIFITM2\nIL32\nIL7R\nILF2\nITGA4\nITGB1\nITM2B\nKLF6\nKLRB1\nKLRD1\nKLRF1\nKLRG1\nKLRK1\nKRTCAP2\nLEPROTL1\nLGALS1\nLGALS3\nLIME1\nLIMS1\nLTB\nLYAR\nMALT1\nMBP\nMT-ATP6\nMT-CO1\nMT-CO2\nMT-CO3\nMT-CYB\nMT-ND1\nMT-ND2\nMT-ND3\nMT-ND4\nMT-ND5\nMT-ND6\nMT2A\nMTIF3\nMTRNR2L12\nNCR3\nNDUFA13\nNDUFV2\nNKG7\nNKTR\nNOSIP\nNSA2\nNSG1\nPASK\nPFN1\nPHACTR2\nPHF11\nPHF3\nPLCG2\nPLEK\nPLIN2\nPNISR\nPPIB\nPPP2R5C\nPRF1\nPTPRC\nRNASET2\nRPL13A\nRPL14\nRPL17\nRPL21\nRPL23\nRPL30\nR

In [82]:
# Convert the node-name column to a plain Python list. `to_pandas` is called
# when the series provides it; otherwise the series itself is used.
getattr(multimodal_df['q_node_name'], "to_pandas", lambda: multimodal_df['q_node_name'])().tolist()

['ACTG1',
 'AIF1',
 'ALOX5AP',
 'ANKRD12',
 'ANXA1',
 'ANXA5',
 'APBB1IP',
 'ARF6',
 'ARHGDIB',
 'ARID4B',
 'ATP5D',
 'B2M',
 'BTG1',
 'C12ORF57',
 'C12ORF75',
 'C19ORF60',
 'C9ORF142',
 'CALM1',
 'CAPZB',
 'CCL5',
 'CCR7',
 'CD2',
 'CD3E',
 'CD52',
 'CD74',
 'CD8A',
 'CD8B',
 'CD96',
 'CEBPD',
 'CHCHD10',
 'CKLF',
 'CMC1',
 'COA4',
 'CORO1B',
 'COX4I1',
 'CRIP1',
 'CST7',
 'CTSW',
 'CX3CR1',
 'DDX5',
 'DSTN',
 'DUSP1',
 'EEF1A1',
 'EEF1G',
 'EEF2',
 'EFHD2',
 'EIF1',
 'EIF1AX',
 'EIF3E',
 'FAM173A',
 'FGD3',
 'FGFBP2',
 'FOS',
 'FTH1',
 'FUBP1',
 'FYB',
 'GLRX',
 'GLTSCR2',
 'GMFG',
 'GNAS',
 'GNLY',
 'GPR183',
 'GYPC',
 'GZMA',
 'GZMB',
 'GZMH',
 'GZMK',
 'GZMM',
 'H3F3A',
 'HCST',
 'HLA-A',
 'HLA-DPB1',
 'HMGB2',
 'HNRNPU',
 'HOPX',
 'HSP90B1',
 'IFITM2',
 'IL32',
 'IL7R',
 'ILF2',
 'ITGA4',
 'ITGB1',
 'ITM2B',
 'KLF6',
 'KLRB1',
 'KLRD1',
 'KLRF1',
 'KLRG1',
 'KLRK1',
 'KRTCAP2',
 'LEPROTL1',
 'LGALS1',
 'LGALS3',
 'LIME1',
 'LIMS1',
 'LTB',
 'LYAR',
 'MALT1',
 'MBP',
 'MT-ATP6',
 

In [None]:
# NOTE(review): `milvus` is not defined in any visible cell, so running this
# cell would raise NameError; it looks like a leftover scratch cell — TODO remove.
milvus