In [1]:
from typing import Type, Annotated
import logging
import pickle
import numpy as np
import pandas as pd
import hydra
import networkx as nx
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
import torch
from torch_geometric.data import Data

import sys
sys.path.append("../../../")
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.multimodal_pcst import MultimodalPCSTPruning
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _prepare_query_modalities(prompt_emb: list,
                              state: Annotated[dict, InjectedState],
                              pyg_graph: Data) -> pd.DataFrame:
    """
    Prepare the modality-specific query for subgraph extraction.

    Args:
        prompt_emb: The embedding of the user prompt in a list.
        state: The injected state for the tool.
        pyg_graph: The PyTorch Geometric graph Data.

    Returns:
        A DataFrame containing the query embeddings and modalities.
    """
    # Initialize dataframes
    multimodal_df = pd.DataFrame({"name": []})
    query_df = pd.DataFrame({"node_id": [],
                                "node_type": [],
                                "x": [],
                                "desc_x": [],
                                "use_description": []})

    # Loop over the uploaded files and find multimodal files
    for i in range(len(state["uploaded_files"])):
        # Check if multimodal file is uploaded
        if state["uploaded_files"][i]["file_type"] == "multimodal":
            # Read the Excel file
            multimodal_df = pd.read_excel(state["uploaded_files"][i]["file_path"],
                                            sheet_name=None)

    # Check if the multimodal_df is empty
    if len(multimodal_df) > 0:
        # Merge all obtained dataframes into a single dataframe
        multimodal_df = pd.concat(multimodal_df).reset_index()
        multimodal_df.drop(columns=["level_1"], inplace=True)
        multimodal_df.rename(columns={"level_0": "q_node_type",
                                    "name": "q_node_name"}, inplace=True)
        # Since an excel sheet name could not contain a `/`,
        # but the node type can be 'gene/protein' as exists in the PrimeKG
        multimodal_df["q_node_type"] = multimodal_df.q_node_type.apply(
            lambda x: x.replace('-', '/')
        )

        # Convert PyG graph to a DataFrame for easier filtering
        graph_df = pd.DataFrame({
            "node_id": pyg_graph.node_id,
            "node_name": pyg_graph.node_name,
            "node_type": pyg_graph.node_type,
            "x": pyg_graph.x,
            "desc_x": pyg_graph.desc_x.tolist(),
        })

        # Make a query dataframe by merging the graph_df and multimodal_df
        query_df = graph_df.merge(multimodal_df, how='cross')
        query_df = query_df[
            query_df.apply(
                lambda x:
                (x['q_node_name'].lower() in x['node_name'].lower()) & # node name
                (x['node_type'] == x['q_node_type']), # node type
                axis=1
            )
        ]
        query_df = query_df[['node_id', 'node_type', 'x', 'desc_x']].reset_index(drop=True)
        query_df['use_description'] = False # set to False for modal-specific embeddings

        # Update the state by adding the the selected node IDs
        state["selections"] = query_df.groupby("node_type")["node_id"].apply(list).to_dict()

    # Append a user prompt to the query dataframe
    query_df = pd.concat([
        query_df,
        pd.DataFrame({
            'node_id': 'user_prompt',
            'node_type': 'prompt',
            'x': prompt_emb,
            'desc_x': prompt_emb,
            'use_description': True # set to True for user prompt embedding
        })
    ]).reset_index(drop=True)

    return query_df

def _perform_subgraph_extraction(state: Annotated[dict, InjectedState],
                                 cfg: dict,
                                 pyg_graph: Data,
                                 query_df: pd.DataFrame) -> dict:
    """
    Perform multimodal subgraph extraction based on modal-specific embeddings.

    Args:
        state: The injected state for the tool.
        cfg: The configuration dictionary.
        pyg_graph: The PyTorch Geometric graph Data.
        query_df: The DataFrame containing the query embeddings and modalities.

    Returns:
        A dictionary containing the extracted subgraph with nodes and edges.
    """
    # Initialize the subgraph dictionary
    subgraphs = {}
    subgraphs["nodes"] = []
    subgraphs["edges"] = []

    # Loop over query embeddings and modalities
    for q in query_df.iterrows():
        # Prepare the PCSTPruning object and extract the subgraph
        # Parameters were set in the configuration file obtained from Hydra
        subgraph = MultimodalPCSTPruning(
            topk=state["topk_nodes"],
            topk_e=state["topk_edges"],
            cost_e=cfg.cost_e,
            c_const=cfg.c_const,
            root=cfg.root,
            num_clusters=cfg.num_clusters,
            pruning=cfg.pruning,
            verbosity_level=cfg.verbosity_level,
            use_description=q[1]['use_description'],
        ).extract_subgraph(pyg_graph,
                            torch.tensor(q[1]['desc_x']), # description embedding
                            torch.tensor(q[1]['x']), # modal-specific embedding
                            q[1]['node_type'])

        # Append the extracted subgraph to the dictionary
        subgraphs["nodes"].append(subgraph["nodes"].tolist())
        subgraphs["edges"].append(subgraph["edges"].tolist())

    # Concatenate and get unique node and edge indices
    subgraphs["nodes"] = np.unique(
        np.concatenate([np.array(list_) for list_ in subgraphs["nodes"]])
    )
    subgraphs["edges"] = np.unique(
        np.concatenate([np.array(list_) for list_ in subgraphs["edges"]])
    )

    return subgraphs

def _prepare_final_subgraph(state:Annotated[dict, InjectedState],
                            subgraph: dict,
                            graph: dict,
                            cfg) -> dict:
    """
    Prepare the subgraph based on the extracted subgraph.

    Args:
        state: The injected state for the tool.
        subgraph: The extracted subgraph.
        graph: The initial graph containing PyG and textualized graph.
        cfg: The configuration dictionary.

    Returns:
        A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
    """
    # print(subgraph)
    # Prepare the PyTorch Geometric graph
    mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
    pyg_graph = Data(
        # Node features
        # x=pyg_graph.x[subgraph["nodes"]],
        x=[graph["pyg"].x[i] for i in subgraph["nodes"]],
        node_id=np.array(graph["pyg"].node_id)[subgraph["nodes"]].tolist(),
        node_name=np.array(graph["pyg"].node_id)[subgraph["nodes"]].tolist(),
        enriched_node=np.array(graph["pyg"].enriched_node)[subgraph["nodes"]].tolist(),
        num_nodes=len(subgraph["nodes"]),
        # Edge features
        edge_index=torch.LongTensor(
            [
                [
                    mapping[i]
                    for i in graph["pyg"].edge_index[:, subgraph["edges"]][0].tolist()
                ],
                [
                    mapping[i]
                    for i in graph["pyg"].edge_index[:, subgraph["edges"]][1].tolist()
                ],
            ]
        ),
        edge_attr=graph["pyg"].edge_attr[subgraph["edges"]],
        edge_type=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
        relation=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
        label=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
        enriched_edge=np.array(graph["pyg"].enriched_edge)[subgraph["edges"]].tolist(),
    )

    # Networkx DiGraph construction to be visualized in the frontend
    nx_graph = nx.DiGraph()
    # Add nodes with attributes
    node_colors = {n: cfg.node_colors_dict[k]
                    for k, v in state["selections"].items() for n in v}
    for n in pyg_graph.node_name:
        nx_graph.add_node(n, color=node_colors.get(n, None))

    # Add edges with attributes
    edges = zip(
        pyg_graph.edge_index[0].tolist(),
        pyg_graph.edge_index[1].tolist(),
        pyg_graph.edge_type
    )
    for src, dst, edge_type in edges:
        nx_graph.add_edge(
            pyg_graph.node_name[src],
            pyg_graph.node_name[dst],
            relation=edge_type,
            label=edge_type,
        )

    # Prepare the textualized subgraph
    # textualized_graph = (
    #     graph["text"]["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
    #     + "\n"
    #     + graph["text"]["edges"].iloc[subgraph["edges"]].to_csv(index=False)
    # )

    return {
        "graph_pyg": pyg_graph,
        "graph_nx": nx_graph,
        # "graph_text": textualized_graph,
    }

In [3]:
# Define the data path
DATA_PATH = "../../../aiagents4pharma/talk2knowledgegraphs/tests/files"

# Define the agent state
state = {
    "selections": {
        "gene/protein": [],
        "molecular_function": [],
        "cellular_component": [],
        "biological_process": [],
        "drug": [],
        "disease": []
    },
    "uploaded_files": [],
    "topk_nodes": 3,
    "topk_edges": 3,
    "dic_source_graph": [
        {
            "name": "BioBridge",
            "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
            "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
        }
    ],
}

In [4]:
# Parameters
prompt = """
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
(IBD) that existed in the knowledge graph.
Please set the extraction name for this process as `subkg_12345`.
"""
tool_call_id = "subgraph_extraction_tool"
extraction_name = "subgraph_12345"

In [5]:
%%time

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.log(logging.INFO, "Invoking subgraph_extraction tool")

# Load hydra configuration
with hydra.initialize(version_base=None, config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs"):
    cfg = hydra.compose(
        config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
    )
    cfg = cfg.tools.multimodal_subgraph_extraction

# Retrieve source graph from the state
initial_graph = {}
initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
# logger.log(logging.INFO, "Source graph: %s", source_graph)

# Load the knowledge graph
with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
    initial_graph["pyg"] = pickle.load(f)
# with open(initial_graph["source"]["kg_text_path"], "rb") as f:
#     initial_graph["text"] = pickle.load(f)

# Prepare the query embeddings and modalities
query_df = _prepare_query_modalities(
    [EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)],
    state,
    initial_graph["pyg"]
)

# Perform subgraph extraction
subgraphs = _perform_subgraph_extraction(state,
                                                cfg,
                                                initial_graph["pyg"],
                                                query_df)

# Prepare subgraph as a NetworkX graph and textualized graph
final_subgraph = _prepare_final_subgraph(state,
                                                subgraphs,
                                                initial_graph,
                                                cfg)

# Prepare the dictionary of extracted graph
dic_extracted_graph = {
    "name": extraction_name,
    "tool_call_id": tool_call_id,
    "graph_source": initial_graph["source"]["name"],
    "topk_nodes": state["topk_nodes"],
    "topk_edges": state["topk_edges"],
    "graph_dict": {
        "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
        "edges": list(final_subgraph["graph_nx"].edges(data=True)),
    },
    # "graph_text": final_subgraph["graph_text"],
    "graph_summary": None,
}

# Prepare the dictionary of updated state
dic_updated_state_for_model = {}
for key, value in {
    "dic_extracted_graph": [dic_extracted_graph],
}.items():
    if value:
        dic_updated_state_for_model[key] = value

INFO:__main__:Invoking subgraph_extraction tool
INFO:httpx:HTTP Request: GET http://127.0.0.1:11434/api/tags "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"


CPU times: user 982 ms, sys: 261 ms, total: 1.24 s
Wall time: 1.03 s


## CProfile

In [6]:
import cProfile

with cProfile.Profile() as pr:

    # Initialize logger
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.log(logging.INFO, "Invoking subgraph_extraction tool")

    # Load hydra configuration
    with hydra.initialize(version_base=None, config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs"):
        cfg = hydra.compose(
            config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
        )
        cfg = cfg.tools.multimodal_subgraph_extraction

    # Retrieve source graph from the state
    initial_graph = {}
    initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
    # logger.log(logging.INFO, "Source graph: %s", source_graph)

    # Load the knowledge graph
    with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
        initial_graph["pyg"] = pickle.load(f)
    # with open(initial_graph["source"]["kg_text_path"], "rb") as f:
    #     initial_graph["text"] = pickle.load(f)

    # Prepare the query embeddings and modalities
    query_df = _prepare_query_modalities(
        [EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)],
        state,
        initial_graph["pyg"]
    )

    # Perform subgraph extraction
    subgraphs = _perform_subgraph_extraction(state,
                                                    cfg,
                                                    initial_graph["pyg"],
                                                    query_df)

    # Prepare subgraph as a NetworkX graph and textualized graph
    final_subgraph = _prepare_final_subgraph(state,
                                                    subgraphs,
                                                    initial_graph,
                                                    cfg)

    # Prepare the dictionary of extracted graph
    dic_extracted_graph = {
        "name": extraction_name,
        "tool_call_id": tool_call_id,
        "graph_source": initial_graph["source"]["name"],
        "topk_nodes": state["topk_nodes"],
        "topk_edges": state["topk_edges"],
        "graph_dict": {
            "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
            "edges": list(final_subgraph["graph_nx"].edges(data=True)),
        },
        # "graph_text": final_subgraph["graph_text"],
        "graph_summary": None,
    }

    # Prepare the dictionary of updated state
    dic_updated_state_for_model = {}
    for key, value in {
        "dic_extracted_graph": [dic_extracted_graph],
    }.items():
        if value:
            dic_updated_state_for_model[key] = value

INFO:__main__:Invoking subgraph_extraction tool
INFO:httpx:HTTP Request: GET http://127.0.0.1:11434/api/tags "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"


In [8]:
pr.print_stats()

         847091 function calls (832110 primitive calls) in 0.787 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.024    0.024 2555975276.py:1(_prepare_query_modalities)
        1    0.003    0.003    0.030    0.030 2555975276.py:137(_prepare_final_subgraph)
        1    0.000    0.000    0.452    0.452 2555975276.py:83(_perform_subgraph_extraction)
       15    0.000    0.000    0.000    0.000 <frozen _collections_abc>:1022(__iter__)
       20    0.000    0.000    0.000    0.000 <frozen _collections_abc>:306(__iter__)
      111    0.000    0.000    0.000    0.000 <frozen _collections_abc>:435(__subclasshook__)
       13    0.000    0.000    0.000    0.000 <frozen _collections_abc>:804(get)
       26    0.000    0.000    0.000    0.000 <frozen _collections_abc>:811(__contains__)
        2    0.000    0.000    0.000    0.000 <frozen _collections_abc>:819(keys)
        2    0.000    0.000    0.

## PyInstrument

In [9]:
from pyinstrument import Profiler

profiler = Profiler()
profiler.start()

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.log(logging.INFO, "Invoking subgraph_extraction tool")

# Load hydra configuration
with hydra.initialize(version_base=None, config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs"):
    cfg = hydra.compose(
        config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
    )
    cfg = cfg.tools.multimodal_subgraph_extraction

# Retrieve source graph from the state
initial_graph = {}
initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
# logger.log(logging.INFO, "Source graph: %s", source_graph)

# Load the knowledge graph
with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
    initial_graph["pyg"] = pickle.load(f)
# with open(initial_graph["source"]["kg_text_path"], "rb") as f:
#     initial_graph["text"] = pickle.load(f)

# Prepare the query embeddings and modalities
query_df = _prepare_query_modalities(
    [EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)],
    state,
    initial_graph["pyg"]
)

# Perform subgraph extraction
subgraphs = _perform_subgraph_extraction(state,
                                                cfg,
                                                initial_graph["pyg"],
                                                query_df)

# Prepare subgraph as a NetworkX graph and textualized graph
final_subgraph = _prepare_final_subgraph(state,
                                                subgraphs,
                                                initial_graph,
                                                cfg)

# Prepare the dictionary of extracted graph
dic_extracted_graph = {
    "name": extraction_name,
    "tool_call_id": tool_call_id,
    "graph_source": initial_graph["source"]["name"],
    "topk_nodes": state["topk_nodes"],
    "topk_edges": state["topk_edges"],
    "graph_dict": {
        "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
        "edges": list(final_subgraph["graph_nx"].edges(data=True)),
    },
    # "graph_text": final_subgraph["graph_text"],
    "graph_summary": None,
}

# Prepare the dictionary of updated state
dic_updated_state_for_model = {}
for key, value in {
    "dic_extracted_graph": [dic_extracted_graph],
}.items():
    if value:
        dic_updated_state_for_model[key] = value

profiler.stop()

profiler.open_in_browser()

INFO:__main__:Invoking subgraph_extraction tool
INFO:httpx:HTTP Request: GET http://127.0.0.1:11434/api/tags "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"


'/tmp/tmp9mfb829b.html'