In [1]:
# Import necessary libraries
import os
import networkx as nx
import openai 
import matplotlib.pyplot as plt
import sys
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import hydra
import numpy as np
# import pandas as pd
import cudf
import pickle
import torch

sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.multimodal_pcst import MultimodalPCSTPruning
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.sentence_transformer import EmbeddingWithSentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
os.environ["OPENAI_API_KEY"] = "XXX"
# Make sure to replace "your_api_key" with your actual API key.

In [3]:
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [4]:
# Load hydra configuration
with hydra.initialize(version_base=None, config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs"):
    cfg = hydra.compose(
        config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
    )
    cfg = cfg.tools.multimodal_subgraph_extraction
cfg

{'_target_': 'talk2knowledgegraphs.tools.multimodal_subgraph_extraction', 'ollama_embeddings': ['nomic-embed-text'], 'temperature': 0.1, 'streaming': False, 'topk': 5, 'topk_e': 5, 'cost_e': 0.5, 'c_const': 0.01, 'root': -1, 'num_clusters': 1, 'pruning': 'gw', 'verbosity_level': 0, 'node_id_column': 'node_id', 'node_attr_column': 'node_attr', 'edge_src_column': 'edge_src', 'edge_attr_column': 'edge_attr', 'edge_dst_column': 'edge_dst', 'node_colors_dict': {'gene/protein': '#6a79f7', 'molecular_function': '#82cafc', 'cellular_component': '#3f9b0b', 'biological_process': '#c5c9c7', 'drug': '#c4a661', 'disease': '#80013f'}}

In [5]:
# Define state
state = {
    "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
    "embedding_model": OpenAIEmbeddings(model="text-embedding-3-small"),
    "selected_genes": [], #["IL6_(1567)", "IL21_(34967)", "TNF_(2329)"],
    "selected_drugs": [], #["Remdesivir_(15267)", "Mesalazine_(15876)"],
    "uploaded_files": [
        {
            "file_name": "multimodal-analysis.csv",
            "file_path": '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/multimodal-analysis.csv',
            "file_type": "multimodal",
            "uploaded_by": "VPEUser",
            "uploaded_timestamp": "2024-11-05 00:00:00",
        },
    ],
    "topk_nodes": 10,
    "topk_edges": 10,
    "dic_source_graph": [
        {
            "name": "PrimeKG",
            "kg_pyg_path": "../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_pyg_graph.pkl",
            "kg_text_path": "../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_text_graph.pkl",
        }
    ],
    "dic_extracted_graph": []
}

# Define prompt
prompt = """
Extract all relevant information related to nodes of genes related to inflammatory bowel disease (IBD) 
that existed in the knowledge graph.

Please set the extraction name for this process as `subkg_12345`.
"""

In [6]:
# Retrieve source graph from the state
initial_graph = {}
initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
# logger.log(logging.INFO, "Source graph: %s", source_graph)

# Load the knowledge graph
with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
    initial_graph["pyg"] = pickle.load(f)
# with open(initial_graph["source"]["kg_text_path"], "rb") as f:
#     initial_graph["text"] = pickle.load(f)

pyg_graph = initial_graph["pyg"]

In [7]:
prompt_emb = [EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)]

INFO:httpx:HTTP Request: GET http://127.0.0.1:11434/api/tags "HTTP/1.1 200 OK"


INFO:httpx:HTTP Request: POST http://127.0.0.1:11434/api/embed "HTTP/1.1 200 OK"


In [8]:
# Load the data from the parquet files
local_dir = '../../../aiagents4pharma/talk2knowledgegraphs/tests/files'
nodes_df = cudf.read_parquet(os.path.join(local_dir, 'biobridge_nodes.parquet.gzip'))
edges_df = cudf.read_parquet(os.path.join(local_dir, 'biobridge_edges.parquet.gzip'))

In [9]:
# Initialize dataframes
multimodal_df = cudf.DataFrame({"name": [], "node_type": []})
query_df = cudf.DataFrame({"node_id": [],
                            "node_type": [],
                            "x": [],
                            "desc_x": [],
                            "use_description": []})

# Loop over the uploaded files and find multimodal files
for i in range(len(state["uploaded_files"])):
    # Check if multimodal file is uploaded
    if state["uploaded_files"][i]["file_type"] == "multimodal":
        # Read the Excel file
        multimodal_df = cudf.read_csv(state["uploaded_files"][i]["file_path"])

# Check if the multimodal_df is empty
if len(multimodal_df) > 0:
    # Merge all obtained dataframes into a single dataframe
    multimodal_df.rename(columns={"name": "q_node_name", "node_type": "q_node_type"}, inplace=True)

    # Make and process a query dataframe by merging the graph_df and multimodal_df
    query_df = nodes_df[['node_id', 'node_name', 'node_type', 'enriched_node', 'x', 'desc', 'desc_x']].merge(multimodal_df, how='cross')
    query_df['q_node_name'] = query_df['q_node_name'].str.lower()
    query_df['node_name'] = query_df['node_name'].str.lower()
    # Get the mask for filtering based on the query
    mask = (
        query_df['node_name'].str.contains(query_df['q_node_name']) &
        (query_df['node_type'] == query_df['q_node_type'])
    )
    query_df = query_df[mask]
    query_df = query_df[['node_id', 'node_type', 'enriched_node', 'x', 'desc', 'desc_x']].reset_index(drop=True)
    query_df['use_description'] = False # set to False for modal-specific embeddings

    # Update the state by adding the the selected node IDs
    state["selections"] = query_df.to_pandas().groupby("node_type")["node_id"].apply(list).to_dict()

# Append a user prompt to the query dataframe
query_df = cudf.concat([
    query_df,
    cudf.DataFrame({
        'node_id': 'user_prompt',
        'node_type': 'prompt',
        # 'enriched_node': prompt,
        'x': prompt_emb,
        # 'desc': prompt,
        'desc_x': prompt_emb,
        'use_description': True # set to True for user prompt embedding
    })
]).reset_index(drop=True)


### Before

In [10]:
from torch_geometric.data import Data

topk = state["topk_nodes"]  
topk_e = state["topk_edges"]
c_const = 0.01

def _compute_node_prizes(graph: Data,
                         query_emb: torch.Tensor,
                         modality: str,
                         use_description: bool=False) :
    """
    Compute the node prizes based on the cosine similarity between the query and nodes.

    Args:
        graph: The knowledge graph in PyTorch Geometric Data format.
        query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
            a prompt, sequence, or any other feature to be used for the subgraph extraction.
        modality: The modality to use for the subgraph extraction based on the node type.

    Returns:
        The prizes of the nodes.
    """
    # Convert PyG graph to a DataFrame
    graph_df = cudf.DataFrame({
        "node_type": graph.node_type,
        "desc_x": [x.tolist() for x in graph.desc_x],
        "x": [list(x) for x in graph.x],
        "score": [0.0 for _ in range(len(graph.node_id))],
    })

    # Calculate cosine similarity for text features and update the score
    if use_description:
        graph_df.loc[:, "score"] = torch.nn.CosineSimilarity(dim=-1)(
                query_emb,
                torch.tensor(list(graph_df.desc_x.values)) # Using textual description features
            ).tolist()
    else:
        graph_df.loc[graph_df["node_type"] == modality,
                        "score"] = torch.nn.CosineSimilarity(dim=-1)(
                query_emb,
                torch.tensor(list(graph_df[graph_df["node_type"]== modality].x.values))
            ).tolist()

    # Set the prizes for nodes based on the similarity scores
    n_prizes = torch.tensor(graph_df.score.values, dtype=torch.float32)
    # n_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.x)
    topk = min(topk, graph.num_nodes)
    _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
    n_prizes = torch.zeros_like(n_prizes)
    n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()

    return n_prizes

def _compute_edge_prizes(graph: Data,
                         text_emb: torch.Tensor) :
    """
    Compute the node prizes based on the cosine similarity between the query and nodes.

    Args:
        graph: The knowledge graph in PyTorch Geometric Data format.
        text_emb: The textual description embedding in PyTorch Tensor format.

    Returns:
        The prizes of the nodes.
    """
    # Note that as of now, the edge features are based on textual features
    # Compute prizes for edges
    e_prizes = torch.nn.CosineSimilarity(dim=-1)(text_emb, graph.edge_attr)
    unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
    topk_e = min(topk_e, unique_prizes.size(0))
    topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
    e_prizes[e_prizes < topk_e_values[-1]] = 0.0
    last_topk_e_value = topk_e
    for k in range(topk_e):
        indices = inverse_indices == (
            unique_prizes == topk_e_values[k]
        ).nonzero(as_tuple=True)[0]
        value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
        e_prizes[indices] = value
        last_topk_e_value = value * (1 - c_const)

    return e_prizes

In [156]:
import pandas as pd

graph = initial_graph["pyg"]
text_emb = torch.tensor(query_df.iloc[0]['desc_x'][0])
query_emb = torch.tensor(query_df.iloc[0]['x'][0])
modality = query_df.iloc[0]['node_type'][0]

# Convert PyG graph to a DataFrame
graph_df = pd.DataFrame({
    "node_type": graph.node_type,
    "desc_x": [x.tolist() for x in graph.desc_x],
    "x": [list(x) for x in graph.x],
    "score": [0.0 for _ in range(len(graph.node_id))],
})

graph_df.loc[graph_df["node_type"] == modality, "score"]  = torch.nn.CosineSimilarity(dim=-1)(
        query_emb,
        torch.tensor(list(graph_df[graph_df["node_type"]== modality].x.values))
    ).tolist()

In [164]:
# Set the prizes for nodes based on the similarity scores
n_prizes = torch.tensor(graph_df.score.values, dtype=torch.float32)
topk = min(topk, graph.num_nodes)
_, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
n_prizes = torch.zeros_like(n_prizes)
n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
n_prizes

tensor([0., 0., 0.,  ..., 0., 0., 0.])

### After

In [128]:
import cudf
import cupy as cp
import cuvs
from cuvs.distance import pairwise_distance

# Initialize several variables
sim_scores = cudf.Series(cp.zeros(len(nodes_df), dtype=cp.float32))
text_emb = torch.tensor(query_df.iloc[0]['desc_x'][0]) # torch.Size([768])
query_emb = torch.tensor(query_df.iloc[0]['x'][0]) # torch.Size([2560])
modality = query_df.iloc[0]['node_type'][0] # `gene/protein`

# Compute cosine distance and similarity
mask = (nodes_df.node_type == modality)
cosine_distance = pairwise_distance(cp.array(nodes_df[mask]["x"].to_arrow().to_pylist()).astype(cp.float32), 
                                    cp.array(query_emb.cpu().numpy()).reshape(1, -1).astype(cp.float32), 
                                    metric="cosine")  # shape [N, 1]
cosine_similarity = 1 - cp.asarray(cosine_distance).ravel()

# Store scores in the graph_df
sim_scores[mask] = cosine_similarity

In [133]:
sim_scores.sort_values(ascending=False).index[:10]

Index([5, 845, 12, 82, 22, 849, 850, 59, 9, 15], dtype='int64')

In [165]:
topk_n_indices

tensor([  5, 845,  12,  82,  22, 849, 850,  59,   9,  15])

In [None]:
sim_scores.sort_values()

89      0.0
90      0.0
91      0.0
92      0.0
93      0.0
       ... 
848    <NA>
849    <NA>
850    <NA>
851    <NA>
852    <NA>
Length: 2991, dtype: float32

In [92]:
cp.arange(topk, 0, -1)

array([10,  9,  8,  7,  6,  5,  4,  3,  2,  1])

In [82]:
topk = min(topk, sim_scores.size)
n_prizes_ = cudf.Series(0.0, index=cp.arange(sim_scores.size))
n_prizes_[(-sim_scores).sort_values()[:topk].index] = cp.arange(topk, 0, -1).astype(cp.float32)

In [104]:
sim_scores[sim_scores.isna()]

837    <NA>
838    <NA>
842    <NA>
843    <NA>
844    <NA>
845    <NA>
846    <NA>
847    <NA>
848    <NA>
849    <NA>
850    <NA>
851    <NA>
852    <NA>
dtype: float64

In [94]:
a, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

In [95]:
a

tensor([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.])

In [96]:
topk_n_indices

tensor([ 5, 94, 12, 82, 22, 98, 99, 59,  9, 15])

In [93]:
topk_n_indices

tensor([ 5, 94, 12, 82, 22, 98, 99, 59,  9, 15])

In [84]:
(-sim_scores).sort_values()[:topk].index

Index([5, 12, 82, 22, 59, 9, 15, 44, 11, 62], dtype='int64')

In [85]:
sim_scores[5, 94, 12, 82, 22, 98, 99, 59,  9, 15]

5     1.000000
94    0.000000
12    0.989550
82    0.981857
22    0.980781
98    0.000000
99    0.000000
59    0.979544
9     0.978507
15    0.978297
dtype: float64

In [83]:
a = torch.tensor(n_prizes_.to_arrow().to_pylist()).float()
a[:100]

tensor([ 0.,  0.,  0.,  0.,  0., 10.,  0.,  0.,  0.,  5.,  0.,  2.,  9.,  0.,
         0.,  4.,  0.,  0.,  0.,  0.,  0.,  0.,  7.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  3.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  6.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  8.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.])

In [19]:
sim_scores.to_arrow().to_pylist()

[0.9762842655181885,
 0.8985072374343872,
 0.9174196124076843,
 0.974480390548706,
 0.9627967476844788,
 0.9999995231628418,
 0.968204915523529,
 0.8944925665855408,
 0.9456265568733215,
 0.9785073399543762,
 0.9755682945251465,
 0.977044403553009,
 0.9895495176315308,
 0.9467787742614746,
 0.9746270179748535,
 0.9782971143722534,
 0.9550138711929321,
 0.9491298198699951,
 0.9711894989013672,
 0.9494197368621826,
 0.9682755470275879,
 0.9343744516372681,
 0.9807805418968201,
 0.9611757397651672,
 0.9460437893867493,
 0.9518918395042419,
 0.9323171973228455,
 0.9732439517974854,
 0.8083146214485168,
 0.906587541103363,
 0.9759802222251892,
 0.9287912845611572,
 0.924881100654602,
 0.9658037424087524,
 0.9527099132537842,
 0.7955286502838135,
 0.9495731592178345,
 0.9741109609603882,
 0.9529463648796082,
 0.968661367893219,
 0.9751885533332825,
 0.968982994556427,
 0.97170490026474,
 0.9425373673439026,
 0.9780073165893555,
 0.9757586717605591,
 0.9519517421722412,
 0.9541404247283936,
 

In [None]:
# Set the prizes for nodes based on the similarity scores
n_prizes = torch.tensor(scores, dtype=torch.float32)
topk = min(topk, graph.num_nodes)
_, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
# n_prizes = torch.zeros_like(n_prizes)
# n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()

In [155]:
min(topk, graph.num_nodes)

10

In [167]:
scores_cp = cp.asarray(scores, dtype=cp.float32)
topk = min(topk, nodes_df.shape[0])
cuvs.selection.select_k(scores_cp, k=topk, select_type='kth_largest')

AttributeError: module 'cuvs' has no attribute 'selection'

In [154]:
n_prizes

tensor([0.9763, 0.8985, 0.9174, 0.9745, 0.9628, 1.0000, 0.9682, 0.8945, 0.9456,
        0.9785, 0.9756, 0.9770, 0.9896, 0.9468, 0.9746, 0.9783, 0.9550, 0.9491,
        0.9712, 0.9494, 0.9683, 0.9344, 0.9808, 0.9612, 0.9460, 0.9519, 0.9323,
        0.9732, 0.8083, 0.9066, 0.9760, 0.9288, 0.9249, 0.9658, 0.9527, 0.7955,
        0.9496, 0.9741, 0.9529, 0.9687, 0.9752, 0.9690, 0.9717, 0.9425, 0.9780,
        0.9758, 0.9520, 0.9541, 0.9324, 0.9445, 0.9207, 0.9587, 0.9603, 0.9509,
        0.9756, 0.9468, 0.9251, 0.9670, 0.9164, 0.9795, 0.9578, 0.9455, 0.9765,
        0.9129, 0.9676, 0.9683, 0.9736, 0.9461, 0.9411, 0.9675, 0.9416, 0.9740,
        0.9711, 0.9286, 0.9672, 0.9546, 0.9340, 0.9451, 0.9537, 0.9714, 0.9391,
        0.9758, 0.9819, 0.9691, 0.9474, 0.9592, 0.9749, 0.9646, 0.9730, 0.9561,
        0.9326, 0.9668, 0.9626, 0.9579, 0.9899, 0.9696, 0.9358, 0.9430, 0.9804,
        0.9797, 0.9782, 0.9713])

In [151]:
n_prizes

tensor([ 0.,  0.,  0.,  0.,  0., 10.,  0.,  0.,  0.,  2.,  0.,  0.,  8.,  0.,
         0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  6.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  3.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  7.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  9.,  0.,  0.,  0.,
         5.,  4.,  0.,  0.])

In [146]:
scores

[0.976284384727478,
 0.8985079526901245,
 0.9174203872680664,
 0.9744806885719299,
 0.9627973437309265,
 1.000000238418579,
 0.9682051539421082,
 0.8944931626319885,
 0.9456267952919006,
 0.978507399559021,
 0.9755687117576599,
 0.9770446419715881,
 0.9895502924919128,
 0.9467793107032776,
 0.9746269583702087,
 0.978297770023346,
 0.9550144076347351,
 0.9491303563117981,
 0.9711899161338806,
 0.949420154094696,
 0.9682762622833252,
 0.9343748092651367,
 0.980780839920044,
 0.961176335811615,
 0.9460442662239075,
 0.9518922567367554,
 0.932317852973938,
 0.9732446074485779,
 0.8083150386810303,
 0.9065879583358765,
 0.9759810566902161,
 0.9287917613983154,
 0.924881637096405,
 0.9658040404319763,
 0.9527105689048767,
 0.7955290079116821,
 0.9495737552642822,
 0.9741115570068359,
 0.9529469609260559,
 0.9686618447303772,
 0.9751890301704407,
 0.9689838886260986,
 0.9717056751251221,
 0.9425379633903503,
 0.9780075550079346,
 0.9757587909698486,
 0.9519528746604919,
 0.954140841960907,
 0

In [143]:
sim_scores[:10]

0    0.976284
1    0.898507
2    0.917420
3    0.974480
4    0.962797
5    1.000000
6    0.968205
7    0.894493
8    0.945627
9    0.978507
dtype: float64

In [96]:

node_mask = (graph.node_type == modality)
len(graph.x[node_mask])

2560

In [111]:
cudf.Series(initial_graph["pyg"].node_type) == 'gene/protein'

0        True
1        True
2        True
3        True
4        True
        ...  
2986    False
2987    False
2988    False
2989    False
2990    False
Length: 2991, dtype: bool

In [94]:
scores[:10], scores_new[:10]

([0.976284384727478,
  0.8985079526901245,
  0.9174203872680664,
  0.9744806885719299,
  0.9627973437309265,
  1.000000238418579,
  0.9682051539421082,
  0.8944931626319885,
  0.9456267952919006,
  0.978507399559021],
 0    0.976284
 1    0.898507
 2    0.917420
 3    0.974480
 4    0.962797
 5    1.000000
 6    0.968205
 7    0.894493
 8    0.945627
 9    0.978507
 dtype: float32)

array([2.3715734e-02, 1.0149276e-01, 8.2580388e-02, 2.5519609e-02,
       3.7203252e-02, 4.7683716e-07, 3.1795084e-02, 1.0550743e-01,
       5.4373443e-02, 2.1492660e-02, 2.4431705e-02, 2.2955596e-02,
       1.0450482e-02, 5.3221226e-02, 2.5372982e-02, 2.1702886e-02,
       4.4986129e-02, 5.0870180e-02, 2.8810501e-02, 5.0580263e-02,
       3.1724453e-02, 6.5625548e-02, 1.9219458e-02, 3.8824260e-02,
       5.3956211e-02, 4.8108160e-02, 6.7682803e-02, 2.6756048e-02,
       1.9168538e-01, 9.3412459e-02, 2.4019778e-02, 7.1208715e-02,
       7.5118899e-02, 3.4196258e-02, 4.7290087e-02, 2.0447135e-01,
       5.0426841e-02, 2.5889039e-02, 4.7053635e-02, 3.1338632e-02,
       2.4811447e-02, 3.1017005e-02, 2.8295100e-02, 5.7462633e-02,
       2.1992683e-02, 2.4241328e-02, 4.8048258e-02, 4.5859575e-02,
       6.7596853e-02, 5.5534005e-02, 7.9296470e-02, 4.1286826e-02,
       3.9689958e-02, 4.9071908e-02, 2.4416745e-02, 5.3175926e-02,
       7.4861526e-02, 3.3016801e-02, 8.3598554e-02, 2.0456493e

In [40]:
graph_df.loc[graph_df["node_type"] == modality,
                "score"] = torch.nn.CosineSimilarity(dim=-1)(
        query_emb,
        torch.tensor(list(graph_df[graph_df["node_type"]== modality].x.values))
    ).tolist()

TypeError: *obj* doesn't implement the cuda array interface.

In [33]:
graph = initial_graph["pyg"]
text_emb = torch.tensor(query_df.iloc[0]['desc_x'][0])
query_emb = torch.tensor(query_df.iloc[0]['x'][0])
modality = query_df.iloc[0]['node_type'][0]

# Compute prizes for nodes
n_prizes = _compute_node_prizes(graph, query_emb, modality)

# Compute prizes for edges
e_prizes = _compute_edge_prizes(graph, text_emb)

TypeError: *obj* doesn't implement the cuda array interface.

In [10]:
# Initialize the subgraph dictionary
subgraphs = {}
subgraphs["nodes"] = []
subgraphs["edges"] = []

# Loop over query embeddings and modalities
for q in query_df.to_pandas().iterrows():
    # Prepare the PCSTPruning object and extract the subgraph
    # Parameters were set in the configuration file obtained from Hydra
    subgraph = MultimodalPCSTPruning(
        topk=state["topk_nodes"],
        topk_e=state["topk_edges"],
        cost_e=cfg.cost_e,
        c_const=cfg.c_const,
        root=cfg.root,
        num_clusters=cfg.num_clusters,
        pruning=cfg.pruning,
        verbosity_level=cfg.verbosity_level,
        use_description=q[1]['use_description'],
    ).extract_subgraph(pyg_graph,
                        torch.tensor(q[1]['desc_x']), # description embedding
                        torch.tensor(q[1]['x']), # modal-specific embedding
                        q[1]['node_type'])

    # Append the extracted subgraph to the dictionary
    subgraphs["nodes"].append(subgraph["nodes"].tolist())
    subgraphs["edges"].append(subgraph["edges"].tolist())

# Concatenate and get unique node and edge indices
subgraphs["nodes"] = np.unique(
    np.concatenate([np.array(list_) for list_ in subgraphs["nodes"]])
)
subgraphs["edges"] = np.unique(
    np.concatenate([np.array(list_) for list_ in subgraphs["edges"]])
)

In [13]:
# Before optimiziation
before_ = subgraphs
before_

{'nodes': array([   3,    5,    9,   10,   12,   15,   22,   27,   30,   31,   40,
          44,   54,   59,   62,   63,   71,   72,   73,   82,   87,  839,
         840,  841,  845,  846,  849,  850,  851,  853,  854,  855,  869,
        1176, 1320, 1372, 1415, 1734, 1747, 1918, 1933, 1952, 2102, 2113,
        2263, 2435, 2447, 2479, 2779, 2827, 2886]),
 'edges': array([  166,   167,   272,   273,   276,   277,   291,   292,   296,
          299,   301,   302,   553,  1922,  2055,  2097,  2703,  2930,
         2931,  2932,  3930,  3936,  3938,  3940,  3944,  3950,  3951,
         3952,  3953,  3954,  4410,  4671,  4683,  4703,  5270,  6303,
         6320,  6357,  6363,  6372,  6381,  6398,  6490,  6564,  6575,
         6712,  6743,  6819,  9015,  9140,  9142,  9143,  9597,  9694,
        11056, 11070, 11078])}