In [1]:
# Import necessary libraries
import pickle
import pandas as pd

import sys
# Add the directory three levels above this notebook to the import path so
# that the `aiagents4pharma` package imported below is resolved from there
# (presumably the repository root -- confirm if the notebook is moved).
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.tools.multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool

  from .autonotebook import tqdm as notebook_tqdm
  register_pytree_node(
  register_pytree_node(


In [2]:
# Directory (relative to this notebook) that holds the knowledge-graph files
DATA_PATH = "../../../aiagents4pharma/talk2knowledgegraphs/tests/files"

# Agent state handed to the extraction tool.
# `selections` starts with one independent empty list per node type.
agent_state = {
    "selections": {
        node_type: []
        for node_type in (
            "gene/protein",
            "molecular_function",
            "cellular_component",
            "biological_process",
            "drug",
            "disease",
        )
    },
    "uploaded_files": [],
    "topk_nodes": 3,
    "topk_edges": 3,
    # Available source graphs; each entry points at the pickled PyG graph
    # and the pickled text graph of one knowledge graph.
    "dic_source_graph": [
        {
            "name": "BioBridge",
            "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
            "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
        }
    ],
}

In [3]:
# Pick the most recently listed source graph from the agent state
initial_graph = {"source": agent_state["dic_source_graph"][-1]}

# Read the pickled PyG graph that the selected source points to.
# NOTE: unpickling executes arbitrary code -- only load files you trust.
kg_pyg_path = initial_graph["source"]["kg_pyg_path"]
with open(kg_pyg_path, "rb") as pkl_file:
    initial_graph["pyg"] = pickle.load(pkl_file)


In [4]:
pyg_graph = initial_graph["pyg"]

# Collect the per-node attributes of the PyG graph as DataFrame columns so
# the nodes can be filtered and summarised with pandas.
node_columns = {
    attr: getattr(pyg_graph, attr)
    for attr in ("node_id", "node_name", "node_type", "x")
}
# `desc_x` is converted to a plain Python list before going into the frame
node_columns["desc_x"] = pyg_graph.desc_x.tolist()

graph_df = pd.DataFrame(node_columns)

In [5]:
# Count how many nodes of each type the graph contains
graph_df["node_type"].value_counts()

node_type
biological_process    1615
drug                   748
molecular_function     317
cellular_component     202
gene/protein           102
disease                  7
Name: count, dtype: int64

In [6]:
# Describe the multimodal spreadsheet and register it in the agent state
# as the only uploaded file.
uploaded_file = {
    "file_name": "multimodal-analysis.xlsx",
    "file_path": f"{DATA_PATH}/multimodal-analysis_single_gene.xlsx",
    "file_type": "multimodal",
    "uploaded_by": "VPEUser",
    "uploaded_timestamp": "2025-05-12 00:00:00",
}
agent_state["uploaded_files"] = [uploaded_file]

In [7]:
# Natural-language request passed to the extraction tool. The text starts and
# ends with a newline, exactly as a triple-quoted block would produce.
prompt = (
    "\n"
    "Extract all relevant information related to nodes of genes related to inflammatory bowel disease\n"
    "(IBD) that existed in the knowledge graph.\n"
    "Please set the extraction name for this process as `subkg_12345`.\n"
)

In [8]:
# Instantiate the multimodal subgraph extraction tool with its default
# constructor arguments; it is invoked with the prompt and agent state below.
subgraph_extraction_tool = MultimodalSubgraphExtractionTool()

In [9]:
%%time

# Assemble the tool input: the user prompt, a tool-call identifier, the agent
# state, and the name under which the extraction should be stored.
tool_input = {
    "prompt": prompt,
    "tool_call_id": "subgraph_extraction_tool",
    "state": agent_state,
    "arg_data": {"extraction_name": "subkg_12345"},
}

# Run the extraction and keep the tool's response
response = subgraph_extraction_tool.invoke(input=tool_input)

INFO:aiagents4pharma.talk2knowledgegraphs.tools.multimodal_subgraph_extraction:Invoking subgraph_extraction tool
INFO:httpx:HTTP Request: GET http://127.0.0.1:11434/api/tags "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"


CPU times: user 4.62 s, sys: 548 ms, total: 5.17 s
Wall time: 5.28 s


In [10]:
from pyinstrument import Profiler

# Profile one more invocation of the extraction tool with the same input as
# the timed run, to see where the wall time is spent.
profiler = Profiler()
profiler.start()
try:
    subgraph_extraction_tool.invoke(input={"prompt": prompt,
                                           "tool_call_id": "subgraph_extraction_tool",
                                           "state": agent_state,
                                           "arg_data": {"extraction_name": "subkg_12345"}})
finally:
    # Always stop the profiler, even when the tool call raises (e.g. the
    # embedding server is unreachable); otherwise it would stay active in the
    # kernel after the failed cell.
    profiler.stop()

# Write the report to a temporary HTML file and open it in the browser;
# the returned file path is shown as the cell output.
profiler.open_in_browser()


INFO:aiagents4pharma.talk2knowledgegraphs.tools.multimodal_subgraph_extraction:Invoking subgraph_extraction tool
INFO:httpx:HTTP Request: GET http://127.0.0.1:11434/api/tags "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"


'/tmp/tmp4pwstklq.html'