In [2]:
import torch
if torch.cuda.is_available():
    import cudf as df
    import cupy as py
else:
    import pandas as df
    import numpy as py
import pandas as pd
from pymilvus import (
    db,
    connections,
    utility,
    Collection,
)
import os
import glob
import cudf
import dask_cudf
import hydra
import logging
from langchain_openai import OpenAIEmbeddings
import sys
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning as MultimodalPCSTPruningNew
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.multimodal_pcst import MultimodalPCSTPruning as MultimodalPCSTPruningOld

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Configuration for Milvus.
# Every setting can be overridden through an environment variable so that real
# credentials never have to be saved with the notebook; the fallbacks are the
# values this notebook originally hard-coded.
milvus_host = os.environ.get("MILVUS_HOST", "localhost")
milvus_port = os.environ.get("MILVUS_PORT", "19530")
milvus_user = os.environ.get("MILVUS_USER", "root")
milvus_password = os.environ.get("MILVUS_PASSWORD", "Milvus")
milvus_database = os.environ.get("MILVUS_DATABASE", "t2kg_primekg")

# Derived from the settings above instead of being duplicated by hand, so the
# URI and token cannot drift out of sync with host/port/user/password.
milvus_uri = f"http://{milvus_host}:{milvus_port}"
milvus_token = f"{milvus_user}:{milvus_password}"

In [4]:
# Register the "default" Milvus connection with the host, port and
# credentials taken from the configuration variables.
connections.connect(
    alias="default",
    user=milvus_user,
    password=milvus_password,
    host=milvus_host,
    port=milvus_port,
)

In [5]:
# Select the database named in the configuration, then display the
# collections it contains.
db.using_database(milvus_database)

available_collections = utility.list_collections()
available_collections

['t2kg_primekg_nodes',
 't2kg_primekg_nodes_biological_process',
 't2kg_primekg_nodes_cellular_component',
 't2kg_primekg_nodes_disease',
 't2kg_primekg_nodes_drug',
 't2kg_primekg_nodes_gene_protein',
 't2kg_primekg_nodes_molecular_function',
 't2kg_primekg_edges']

In [5]:
# Natural-language question that is embedded and searched against the
# knowledge graph; split over several literals purely for readability.
prompt = (
    "A patient is diagnosed with inflammatory bowel disease. "
    "The patient's omics profile reveals that there is an overexpression of JAK1 levels. "
    "What are the drugs you'd recommend to inhibit JAK1 to treat the disease? "
    "Please extract a subgraph to explain your answer."
)

In [None]:
# Embed the prompt with OpenAI.
# The API key has to come from the environment: it is deliberately not assigned
# here, so a real key that is already exported is never overwritten by a
# placeholder and no secret is ever saved inside the notebook.
if not os.environ.get('OPENAI_API_KEY'):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it before running this cell."
    )
prompt_emb = OpenAIEmbeddings(model='text-embedding-ada-002').embed_query(prompt)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


In [8]:
# Fetch the gene/protein nodes whose names match the genes of interest.
collection = Collection('t2kg_primekg_nodes_gene_protein')

# Load the collection into memory before querying, as the search cells do,
# so this cell does not depend on the collection having been loaded earlier.
collection.load()

q_node_names = ['JAK1', 'JAK2']
q_columns = ["node_id", "node_name", "node_type", "feat", "feat_emb", "desc", "desc_emb"]

# Build the quoted, comma-separated name list outside the f-string: nesting the
# same quote character inside an f-string only parses on Python 3.12+.
quoted_names = ','.join(f'"{name}"' for name in q_node_names)
res = collection.query(
    expr=f'node_name IN [{quoted_names}]',
    output_fields=q_columns,
)
res

data: ["{'feat_emb': [0.0013380006, -0.0062988503, -0.0025861484, -0.0005466234, -0.0015513409, -0.0006619322, -0.0046984046, 0.012385055, 0.0045606005, 0.005080533, -0.006795768, -0.001498624, 0.003802797, 0.015259844, -0.003178962, -0.0052186693, -0.010775482, -0.0032112948, -0.0015902283, 0.0006502992, 0.0008241086, -0.0023181767, -0.0010040231, 0.00084730115, -0.0011392918, -0.006576022, -0.0032543375, -0.0033684445, -0.011699086, -0.010433914, 0.0031521916, 0.0066578193, -0.0043125222, -0.0028705427, 0.0025425453, -0.0018970314, 0.0005124194, 0.0015830761, -0.004088768, 0.0039491644, 0.0002761907, -0.00018609996, 0.007967015, -0.0013525843, -0.014276525, -0.0026016105, 0.002552265, 0.0042294576, -0.0017183361, -0.002268201, -0.0021365725, 0.0058297664, 0.00045224448, 0.005471259, -0.002217352, -0.0026397712, -0.004183228, 0.0009926874, -0.0015811849, 0.0010384733, -0.0006647584, -0.0040842462, -0.010013338, 0.0005960069, -0.002709248, 0.0016170277, 0.0034417668, -0.005432655, 0.01

In [11]:
# Node identifiers of the entities returned by the query.
[entity['node_id'] for entity in res]

['JAK1_(1410)', 'JAK2_(1618)']

In [10]:
# Node names of the entities returned by the query.
[entity['node_name'] for entity in res]

['JAK1', 'JAK2']

In [19]:
def normalize_matrix(m, axis=1):
    """
    Scale a 2D matrix so that every slice along ``axis`` has unit L2 norm.

    The computation uses whichever array module is bound to ``py`` by the
    import cell (CuPy when CUDA is available, NumPy otherwise).

    Parameters:
    m (array-like): 2D matrix to normalize.
    axis (int): Axis along which the norms are computed. The default of 1
        normalizes each row; 0 normalizes each column.

    Returns:
    ndarray: Normalized matrix with the same shape as ``m``. Slices whose norm
        is zero are returned as zeros instead of NaN.
    """
    # Accept lists/tuples as well as arrays, mirroring normalize_vector.
    m = py.asarray(m)
    norms = py.linalg.norm(m, axis=axis, keepdims=True)
    # Dividing by a zero norm would yield NaN/inf; substituting 1 leaves an
    # all-zero slice unchanged and does not affect any other slice.
    safe_norms = py.where(norms == 0, 1.0, norms)
    return m / safe_norms

def normalize_vector(v):
    """
    Scale a vector to unit L2 norm.

    The computation uses whichever array module is bound to ``py`` by the
    import cell (CuPy when CUDA is available, NumPy otherwise).

    Parameters:
    v (array-like): Vector to normalize; it is converted with ``py.asarray``.

    Returns:
    ndarray: Normalized vector with the same shape as ``v``. A vector whose
        norm is zero is returned as zeros instead of NaN.
    """
    v = py.asarray(v)
    norm = py.linalg.norm(v)
    # Dividing by a zero norm would yield NaN; substituting 1 leaves the
    # all-zero vector unchanged and does not affect any other input.
    safe_norm = py.where(norm == 0, 1.0, norm)
    return v / safe_norm

In [46]:
# Open the collection holding all PrimeKG nodes and load it into memory
# before searching.
collection = Collection('t2kg_primekg_nodes')
collection.load()

# The prompt embedding is searched as-is; the normalize_vector step that used
# to precede the search is left disabled.
vector_to_search = prompt_emb
search_params = dict(metric_type="IP")

# Ten nearest nodes by inner product against the description embeddings.
results = collection.search(
    anns_field="desc_emb",
    data=[vector_to_search],
    limit=10,
    param=search_params,
    output_fields=["node_id", "node_name", "desc"],
)

results

data: [[{'node_index': 15355, 'distance': 0.8468228578567505, 'entity': {'node_id': 'Baricitinib_(15666)', 'node_name': 'Baricitinib', 'desc': 'Baricitinib belongs to drug node. Baricitinib is a selective and reversible Janus kinase 1 (JAK1) and 2 (JAK2) inhibitor. Janus kinases belong to the tyrosine protein kinase family and play an important role in the proinflammatory pathway signalling that is frequently over-activated in autoimmune disorders such as rheumatoid arthritis. By blocking the actions of JAK1/2, baricitinib disrupts the activation of downstream signalling molecules and proinflammatory mediators.  JAK enzymes are part of the family of tyrosine kinases that constitutively bind to the intracellular domains of cytokine receptors and promote the signalling cascades of cytokines and growth factors involved in haematopoiesis, inflammation and immune function that are also implicated in the pathogenesis of rheumatoid arthritis. Circulating proinflammatory cytokines bind to thes

In [45]:
# Print the similarity score, node id and description of every search hit.
for rank, hit in enumerate(results[0], start=1):
    # Put each sentence of the description on its own line. The replacement is
    # done outside the f-string because a backslash inside an f-string
    # expression only parses on Python 3.12+.
    description = hit['desc'].replace('.', '.\n')
    print(f"Node {rank}:")
    print(f"Sim Score: {hit['distance']}")
    print(f"Node ID: {hit['node_id']}")
    print(f"Description: {description}")
    print('---')

Node 1:
Sim Score: 0.8468227982521057
Node ID: Baricitinib_(15666)
Description: Baricitinib belongs to drug node.
 Baricitinib is a selective and reversible Janus kinase 1 (JAK1) and 2 (JAK2) inhibitor.
 Janus kinases belong to the tyrosine protein kinase family and play an important role in the proinflammatory pathway signalling that is frequently over-activated in autoimmune disorders such as rheumatoid arthritis.
 By blocking the actions of JAK1/2, baricitinib disrupts the activation of downstream signalling molecules and proinflammatory mediators.
  JAK enzymes are part of the family of tyrosine kinases that constitutively bind to the intracellular domains of cytokine receptors and promote the signalling cascades of cytokines and growth factors involved in haematopoiesis, inflammation and immune function that are also implicated in the pathogenesis of rheumatoid arthritis.
 Circulating proinflammatory cytokines bind to these cell surface receptors.
 Upon binding of extracellular cy

In [33]:
# Open the collection holding the PrimeKG edges and load it into memory
# before searching.
collection = Collection('t2kg_primekg_edges')
collection.load()

# The prompt embedding is searched as-is; the normalize_vector step that used
# to precede the search is left disabled.
vector_to_search = prompt_emb
search_params = dict(metric_type="IP")

# Ten nearest edges by inner product against the edge feature embeddings.
results = collection.search(
    anns_field="feat_emb",
    data=[vector_to_search],
    limit=10,
    param=search_params,
    output_fields=["head_id", "tail_id", "edge_type"],
)

results

data: [[{'triplet_index': 3486988, 'distance': 0.881709098815918, 'entity': {'head_id': 'inflammatory bowel disease_(28158)', 'tail_id': 'JAK2_(1618)', 'edge_type': 'disease|associated with|gene/protein'}}, {'triplet_index': 2651638, 'distance': 0.8756547570228577, 'entity': {'head_id': 'JAK2_(1618)', 'tail_id': 'inflammatory bowel disease_(28158)', 'edge_type': 'gene/protein|associated with|disease'}}, {'triplet_index': 3515751, 'distance': 0.8587834239006042, 'entity': {'head_id': "Crohn's colitis_(83770)", 'tail_id': 'JAK2_(1618)', 'edge_type': 'disease|associated with|gene/protein'}}, {'triplet_index': 2989081, 'distance': 0.85621178150177, 'entity': {'head_id': 'JAK2_(1618)', 'tail_id': 'positive regulation of inflammatory response_(40214)', 'edge_type': 'gene/protein|interacts with|biological_process'}}, {'triplet_index': 2987134, 'distance': 0.8557138442993164, 'entity': {'head_id': 'JAK2_(1618)', 'tail_id': 'regulation of inflammatory response_(46101)', 'edge_type': 'gene/prote

In [34]:
# Collect the (head_id, tail_id) pair of every edge hit.
[(edge['head_id'], edge['tail_id']) for edge in results[0]]

[('inflammatory bowel disease_(28158)', 'JAK2_(1618)'),
 ('JAK2_(1618)', 'inflammatory bowel disease_(28158)'),
 ("Crohn's colitis_(83770)", 'JAK2_(1618)'),
 ('JAK2_(1618)', 'positive regulation of inflammatory response_(40214)'),
 ('JAK2_(1618)', 'regulation of inflammatory response_(46101)'),
 ('ulcerative colitis (disease)_(37785)', 'JAK2_(1618)'),
 ('Crohn disease_(37784)', 'JAK2_(1618)'),
 ('regulation of inflammatory response_(46101)', 'JAK2_(1618)'),
 ('positive regulation of inflammatory response_(40214)', 'JAK2_(1618)'),
 ('JAK2_(1618)', "Crohn's colitis_(83770)")]