In [None]:
# !pip install -r requirements.txt

In [1]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from torch_geometric.data import Data
import json
import gensim
from torch import nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader
from pcst_fast import pcst_fast
import warnings

# Load embedding modules

In [3]:
# Hugging Face repository the sentence encoder and tokenizer are loaded from.
pretrained_repo = "sentence-transformers/all-roberta-large-v1"
# Number of texts pushed through the encoder per forward pass; tune as needed.
batch_size = 1_024

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over already-tokenized text.

    The token-id tensor and the attention-mask tensor are stored under the
    keys ``"input_ids"`` and ``"att_mask"``; each item is a dict holding the
    row at the requested position for every tensor that was provided.
    """

    def __init__(self, input_ids=None, attention_mask=None):
        super().__init__()
        self.data = dict(input_ids=input_ids, att_mask=attention_mask)

    def __len__(self):
        # The sample count is the size of the first dimension of input_ids.
        return self.data["input_ids"].size(0)

    def __getitem__(self, index):
        # A tensor index is unwrapped into a plain Python scalar first.
        position = index.item() if isinstance(index, torch.Tensor) else index
        # Fields left as None at construction time are omitted from the sample.
        return {
            name: tensor[position]
            for name, tensor in self.data.items()
            if tensor is not None
        }
        
class Sentence_Transformer(nn.Module):
    """Sentence encoder: pretrained transformer + masked mean pooling + L2 normalization."""

    def __init__(self, pretrained_repo):
        super().__init__()
        print(f"inherit model weights from {pretrained_repo}")
        self.bert_model = AutoModel.from_pretrained(pretrained_repo)

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings of each sequence, weighting tokens by the attention mask."""
        # The first element of the model output holds the per-token embeddings.
        token_embeddings = model_output[0]
        # Broadcast the mask over the embedding dimension, in the embeddings' dtype.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp the denominator so a fully masked row does not divide by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids, att_mask):
        """Return one L2-normalized embedding per input sequence."""
        bert_out = self.bert_model(input_ids=input_ids, attention_mask=att_mask)
        pooled = self.mean_pooling(bert_out, att_mask)
        return F.normalize(pooled, p=2, dim=1)

def load_sbert():
    """Build the sentence encoder and its tokenizer, ready for inference.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model is wrapped in
        ``nn.DataParallel`` when more than one CUDA device is visible, moved to
        CUDA if available (CPU otherwise) and switched to eval mode.
    """
    model = Sentence_Transformer(pretrained_repo)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_repo)

    # Spread batches over every visible GPU when there are several.
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print(f'Using {gpu_count} GPUs')
        model = nn.DataParallel(model)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    return model, tokenizer, device


def sbert_text2embedding(model, tokenizer, device, text):
    """Embed a list of strings with the sentence encoder.

    Args:
        model: Callable taking ``input_ids`` and ``att_mask`` keyword tensors and
            returning one embedding row per input sequence.
        tokenizer: Tokenizer called with ``padding=True, truncation=True,
            return_tensors='pt'``; its result must expose ``input_ids`` and
            ``attention_mask``.
        device: Device each batch is moved to before the forward pass.
        text: List of strings to embed.

    Returns:
        A CPU tensor with one row per input string, in input order. For an empty
        list an empty ``(0, 1024)`` tensor is returned; the width 1024 is
        hard-coded here (presumably the encoder's embedding size — confirm if
        the pretrained model is changed).
    """
    if len(text) == 0:
        return torch.zeros((0, 1024))

    encoding = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    dataset = Dataset(input_ids=encoding.input_ids, attention_mask=encoding.attention_mask)

    # shuffle=False keeps the output rows aligned with the order of `text`.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    all_embeddings = []

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for batch in dataloader:
            batch = {key: value.to(device) for key, value in batch.items()}
            embeddings = model(input_ids=batch["input_ids"], att_mask=batch["att_mask"])

            # Move every batch off the accelerator as soon as it is produced so
            # that device memory stays bounded by a single batch instead of
            # growing with the total number of texts.
            all_embeddings.append(embeddings.cpu())

    return torch.cat(all_embeddings, dim=0)

# Registries keyed by encoder name: one maps to the loader, the other to the
# matching text -> embedding function.
load_model = dict(sbert=load_sbert)

load_text2embedding = dict(sbert=sbert_text2embedding)

# Preprocess the data

In [5]:
# Encoder registry key and the working directories for per-graph artifacts.
model_name = 'sbert'
path = '.'
path_nodes, path_edges, path_graphs = (
    f'{path}/{subdir}' for subdir in ('nodes', 'edges', 'graphs')
)

In [7]:
from pymilvus import MilvusClient, DataType, FieldSchema, CollectionSchema

# Connection settings are read from the environment so that credentials never
# live in the notebook (or in its version history).
# NOTE(review): the token that used to be hardcoded in this cell has been
# exposed and should be revoked/rotated.
CLUSTER_ENDPOINT = os.environ.get(
    "MILVUS_CLUSTER_ENDPOINT",
    "https://in03-7a5f9d2a1aa84ef.serverless.gcp-us-west1.cloud.zilliz.com",
)
API_KEY = os.environ.get("MILVUS_API_KEY")
if not API_KEY:
    raise RuntimeError(
        "MILVUS_API_KEY is not set; export your Milvus/Zilliz Cloud token "
        "before running this cell."
    )

milvus_client = MilvusClient(uri=CLUSTER_ENDPOINT, token=API_KEY)

COLLECTION_NAME = "graph_embeddings"
VECTOR_DIM      = 1024          # SBERT default; change if yours differs
METRIC_TYPE     = "COSINE"     # or "IP" for inner-product

In [9]:
# Drop the collection if it already exists (for a clean slate)
# Drop the collection if it already exists (for a clean slate).
# WARNING: this deletes every embedding stored so far in COLLECTION_NAME.
if COLLECTION_NAME in milvus_client.list_collections():
    milvus_client.drop_collection(COLLECTION_NAME)

# Define the schema: explicit integer primary key, the embedding vector, and a
# metadata field holding the graph's index.
schema = CollectionSchema(
    fields=[
        FieldSchema(
            name="graph_id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=False
        ),
        FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            dim=VECTOR_DIM
        ),
        FieldSchema(                 # optional metadata field
            name="graph_idx",
            dtype=DataType.INT64
        )
    ],
    description="Mean-pooled SBERT embeddings of knowledge graphs"
)

index_params = milvus_client.prepare_index_params()
# Index build parameters go under `params`; the metric is taken from
# METRIC_TYPE so that index creation and search cannot drift apart.
index_params.add_index(
    field_name="embedding",
    index_type="IVF_FLAT",
    metric_type=METRIC_TYPE,
    params={"nlist": 64},
)
milvus_client.create_collection(
    collection_name=COLLECTION_NAME,
    schema=schema,
    consistency_level="Strong",
    index_params=index_params
)

In [17]:
def preprocessing_step_one(path):
    """Turn every graph's triples into a node table and an edge table on disk.

    Reads ``graphs.json`` from ``path``; each entry is a list of
    ``(head, relation, tail)`` triples. For graph ``i`` two CSV files are
    written: ``{path_nodes}/{i}.csv`` with columns ``node_id, node_attr`` and
    ``{path_edges}/{i}.csv`` with columns ``src, edge_attr, dst``.

    Note that the output directories are the module-level ``path_nodes`` and
    ``path_edges``; they are not derived from the ``path`` argument.

    Args:
        path: Directory containing ``graphs.json`` (with or without a trailing
            path separator).
    """
    # os.path.join works whether or not `path` ends with a separator, and the
    # explicit encoding keeps non-ASCII labels intact regardless of the locale.
    with open(os.path.join(path, 'graphs.json'), 'r', encoding='utf-8') as f:
        graphs = json.load(f)

    # Create directories if they don't exist
    os.makedirs(path_nodes, exist_ok=True)
    os.makedirs(path_edges, exist_ok=True)

    # Process each graph
    for i, triples in enumerate(tqdm(graphs)):
        node_map = {}   # Maps lower-cased node label -> node ID (assigned in order of first appearance)
        edges = []

        for h, r, t in triples:
            # Labels are lower-cased so that differently-cased spellings share one node.
            h = h.lower()
            t = t.lower()
            if h not in node_map:
                node_map[h] = len(node_map)
            if t not in node_map:
                node_map[t] = len(node_map)
            edges.append({'src': node_map[h], 'edge_attr': r, 'dst': node_map[t]})

        # Convert node map to DataFrame
        nodes_df = pd.DataFrame(
            [{'node_id': v, 'node_attr': k} for k, v in node_map.items()],
            columns=['node_id', 'node_attr']
        )

        # Convert edge list to DataFrame
        edges_df = pd.DataFrame(edges, columns=['src', 'edge_attr', 'dst'])

        # Save to CSV
        nodes_df.to_csv(f'{path_nodes}/{i}.csv', index=False)
        edges_df.to_csv(f'{path_edges}/{i}.csv', index=False)

preprocessing_step_one('./spo/')

100%|██████████| 5/5 [00:00<00:00, 161.29it/s]


In [21]:
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub.file_download")

def preprocessing_step_two(path):
    """Embed every preprocessed graph, save it to disk and index it in Milvus.

    For each graph listed in ``{path}/graphs.json`` the node/edge CSV files
    under ``path_nodes`` / ``path_edges`` are read, node and edge texts are
    embedded, the result is saved as ``{path_graphs}/{index}.pt`` and the mean
    of the node embeddings is collected as the graph-level vector. All vectors
    are inserted into ``COLLECTION_NAME`` in a single call at the end.

    Graphs whose CSV files are missing, or that have no nodes, are skipped.

    Args:
        path: Directory containing ``graphs.json`` (with or without a trailing
            path separator).
    """
    print("Loading local knowledge base...")
    # os.path.join tolerates a missing trailing separator; explicit UTF-8 avoids
    # locale-dependent decoding.
    with open(os.path.join(path, 'graphs.json'), 'r', encoding='utf-8') as f:
        graphs = json.load(f)

    model, tokenizer, device = load_model[model_name]()
    text2embedding = load_text2embedding[model_name]

    print("Embedding and storing graphs in milvusDB...")
    os.makedirs(path_graphs, exist_ok=True)

    milvus_vectors = []

    for index in tqdm(range(len(graphs))):
        # --- Load nodes & edges ---
        nodes_path = f'{path_nodes}/{index}.csv'
        edges_path = f'{path_edges}/{index}.csv'
        if not os.path.exists(nodes_path) or not os.path.exists(edges_path):
            print(f'Skipping graph {index} (missing files)')
            continue

        # Force the text columns to str: otherwise pandas parses numeric-looking
        # labels (e.g. "2020") as numbers, which the tokenizer cannot handle.
        # Missing values become empty strings for the same reason.
        nodes = pd.read_csv(nodes_path, dtype={"node_attr": str}).fillna({"node_attr": ""})
        edges = pd.read_csv(edges_path, dtype={"edge_attr": str}).fillna({"edge_attr": ""})

        if len(nodes) == 0:
            print(f'Empty graph at index {index}')
            continue

        # --- Embed node and edge attributes ---
        x = text2embedding(model, tokenizer, device, nodes.node_attr.tolist())
        edge_attr = text2embedding(model, tokenizer, device, edges.edge_attr.tolist())
        edge_index = torch.LongTensor([edges.src.tolist(), edges.dst.tolist()])

        # --- Save graph as torch_geometric.Data ---
        pyg_graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=len(nodes))
        torch.save(pyg_graph, f'{path_graphs}/{index}.pt')

        # --- Compute graph-level embedding (mean of node embeddings) ---
        graph_embedding = torch.mean(x, dim=0).cpu().tolist()

        # --- Store in Milvus format: graph_id (primary key), embedding, graph_idx ---
        milvus_vectors.append({"graph_id": index, "embedding": graph_embedding, "graph_idx": index})

    # --- Final batch insert into Milvus ---
    if milvus_vectors:
        milvus_client.insert(
            collection_name=COLLECTION_NAME,
            data=milvus_vectors,
            auto_id=False
        )
        milvus_client.flush(COLLECTION_NAME)
        print(f"Inserted {len(milvus_vectors)} graph embeddings into Milvus.")
    else:
        print("No graphs were inserted into Milvus.")

preprocessing_step_two('./spo/')

Loading local knowledge base...
inherit model weights from sentence-transformers/all-roberta-large-v1
Embedding and storing graphs in milvusDB...


100%|██████████| 5/5 [02:16<00:00, 27.30s/it]


Inserted 5 graph embeddings into Milvus.


In [23]:
def retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk=3, topk_e=3, cost_e=0.5):
    """Extract a query-relevant subgraph by solving a prize-collecting Steiner tree problem.

    Nodes and edges receive prizes derived from the cosine similarity between
    their embeddings and ``q_emb``; ``pcst_fast`` then decides which vertices
    and edges to keep.

    Args:
        graph: Graph object providing ``x``, ``edge_index``, ``edge_attr``,
            ``num_nodes`` and ``num_edges``.
        q_emb: Query embedding compared against ``graph.x`` and ``graph.edge_attr``.
        textual_nodes: DataFrame of node rows, indexed positionally via ``iloc``
            (assumed to be row-aligned with ``graph.x`` — confirm against caller).
        textual_edges: DataFrame with ``src``, ``edge_attr`` and ``dst`` columns,
            indexed positionally via ``iloc`` (assumed to be row-aligned with
            ``graph.edge_index`` — confirm against caller).
        topk: Number of most similar nodes that receive a non-zero prize;
            ``0`` or less gives every node a zero prize.
        topk_e: Number of highest distinct edge-similarity values that receive
            a prize; ``0`` or less gives every edge a zero prize.
        cost_e: Base edge cost. When edge prizes are used it is lowered, if
            needed, to stay below the largest edge prize.

    Returns:
        A ``(data, desc)`` tuple. ``data`` is a ``Data`` object holding the
        selected nodes (re-indexed from 0) and edges; ``desc`` is the CSV text
        of the selected node rows, a newline, and the CSV text of the selected
        edge rows. If either DataFrame is empty, the whole graph and the full
        CSV text are returned instead.
    """
    # Small factor used below to decay the prize cap between tiers and to keep
    # the edge cost strictly under the best edge prize.
    c = 0.01
    if len(textual_nodes) == 0 or len(textual_edges) == 0:
        # Nothing to select from: describe everything and return a copy of the graph.
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
        return graph, desc

    # Fixed arguments forwarded to pcst_fast.
    root = -1  # unrooted
    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.x)
        topk = min(topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        # Replace similarities by rank-based prizes: topk for the best node
        # down to 1 for the last selected one; every other node gets 0.
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(graph.num_nodes)

    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.edge_attr)
        # Ranking is done over distinct similarity values, so cap topk_e accordingly.
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        # Zero out every edge whose similarity is below the smallest retained value.
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            # All edges sharing the k-th best similarity value split the rank
            # score (topk_e - k) equally, capped by last_topk_e_value.
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e-k)/sum(indices), last_topk_e_value)
            e_prizes[indices] = value
            # The cap for the next tier is slightly below this tier's prize.
            last_topk_e_value = value*(1-c)
        # reduce the cost of the edges such that at least one edge is selected
        cost_e = min(cost_e, e_prizes.max().item()*(1-c/2))
    else:
        e_prizes = torch.zeros(graph.num_edges)

    costs = []
    edges = []
    vritual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}  # virtual node id -> index of the original edge it stands for
    mapping_e = {}  # position in `edges` -> index of the original edge
    for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            # Prize does not exceed the cost: keep as a regular edge with the
            # remaining cost (cost_e - prize_e).
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            # Prize exceeds the cost: represent the edge as a virtual node that
            # carries the surplus prize and is linked to src and dst by two
            # zero-cost edges.
            virtual_node_id = graph.num_nodes + len(vritual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            vritual_n_prizes.append(prize_e - cost_e)

    # Real node prizes first, then virtual node prizes, matching the ids assigned above.
    prizes = np.concatenate([n_prizes, np.array(vritual_n_prizes)])
    # Count of regular edges, taken before the virtual edges are appended.
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs+virtual_costs)
        edges = np.array(edges+virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters, pruning, verbosity_level)

    # Ids below graph.num_nodes are real nodes; edge indices below num_edges are
    # regular edges and are translated back to original edge indices.
    selected_nodes = vertices[vertices < graph.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= graph.num_nodes]
    if len(virtual_vertices) > 0:
        # NOTE(review): this repeats the expression computed just above.
        virtual_vertices = vertices[vertices >= graph.num_nodes]
        # Each selected virtual node stands for one original edge.
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges+virtual_edges)

    edge_index = graph.edge_index[:, selected_edges]
    # Add the endpoints of every selected edge to the selected nodes.
    selected_nodes = np.unique(np.concatenate([selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False)+'\n'+e.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])

    # Re-index the selected nodes to 0..len(selected_nodes)-1 for the returned subgraph.
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}

    x = graph.x[selected_nodes]
    edge_attr = graph.edge_attr[selected_edges]
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]
    edge_index = torch.LongTensor([src, dst])
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=len(selected_nodes))

    return data, desc

In [25]:
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub.file_download")

def retreival(question, k=3):
    """Retrieve query-relevant subgraphs for a natural-language question.

    The question is embedded, the ``k`` closest graph embeddings are looked up
    in the Milvus collection, and for each hit the stored graph is loaded and
    reduced with ``retrieval_via_pcst``.

    Args:
        question: Question text to embed and search with.
        k: Maximum number of graphs to fetch from Milvus.

    Returns:
        A ``(sub_graphs, descriptions)`` tuple of equal-length lists, one entry
        per retrieved graph that could be loaded. Both lists are empty when the
        collection cannot be loaded or the search returns no hits.
    """
    model, tokenizer, device = load_model[model_name]()
    text2embedding = load_text2embedding[model_name]
    # Encode question
    q_emb = text2embedding(model, tokenizer, device, [question])[0]

    # Ensure collection is loaded before search
    try:
        milvus_client.load_collection(COLLECTION_NAME)
    except Exception as e:
        print(f"Error loading collection: {e}")
        return [], []

    # Perform similarity search in Milvus
    search_results = milvus_client.search(
        collection_name=COLLECTION_NAME,
        data=[q_emb.tolist()],
        limit=k,
        search_params={"metric_type": METRIC_TYPE, "params": {}},
        output_fields=["graph_idx"]
    )

    # One result list is expected per query vector; guard against an empty
    # response instead of failing on search_results[0].
    hits = search_results[0] if search_results else []
    graph_indices = [hit["entity"]["graph_idx"] for hit in hits]
    if not graph_indices:
        return [], []

    # Collect subgraphs and descriptions
    sub_graphs = []
    descriptions = []

    for index in tqdm(graph_indices, desc="Retrieving subgraphs"):
        nodes_path = f'{path_nodes}/{index}.csv'
        edges_path = f'{path_edges}/{index}.csv'
        graph_path = f'{path_graphs}/{index}.pt'

        if not (os.path.exists(nodes_path) and os.path.exists(edges_path) and os.path.exists(graph_path)):
            print(f"Missing data for graph {index}")
            continue

        nodes = pd.read_csv(nodes_path)
        edges = pd.read_csv(edges_path)
        if len(nodes) == 0:
            print(f"Empty graph at index {index}")
            continue

        # The file holds a pickled Data object, not a plain state dict, so it
        # needs full unpickling; recent PyTorch releases default to
        # weights_only=True, which rejects such files.
        # SECURITY: full unpickling can execute arbitrary code — only load
        # graph files produced by this notebook.
        graph = torch.load(graph_path, weights_only=False)

        # Reduce the stored graph to the part relevant to the question.
        subg, desc = retrieval_via_pcst(
            graph=graph,
            q_emb=q_emb,
            textual_nodes=nodes,
            textual_edges=edges,
            topk=3,
            topk_e=5,
            cost_e=0.5
        )

        sub_graphs.append(subg)
        descriptions.append(desc)

    return sub_graphs, descriptions

question = "How does air pollution impact the treatment or worsening of asthma and COPD symptoms?"
question2 = "what are asthma symptoms?"
subgraphs, descriptions = retreival(question2, k=1)

inherit model weights from sentence-transformers/all-roberta-large-v1


Retrieving subgraphs: 100%|██████████| 1/1 [00:00<00:00, 10.99it/s]


In [27]:
# Show each retrieved subgraph description followed by a separator line.
for description in descriptions:
    print(description, '--------------', sep='\n')

node_id,node_attr
0,asthma
16,narrowed airways
24,harder breathing out
30,asthma_symptoms
129,respiratory_distress
134,early morning symptoms
143,breathlessness
146,flare-ups
165,airway irritation
174,wheezing
205,asthma_attack
245,asthma attack
257,nocturnal symptoms
266,shortness of breath
320,asthma exacerbation
329,chest_tightness
340,excessive mucus production
345,coughing
363,nocturnal_coughing
369,harder_to_breathe
370,swelling_of_airways
371,lightheadedness
372,worsened_asthma_symptoms
373,dyspnea
376,tightening_of_airways
381,asthma symptoms
382,worsened asthma symptoms

src,edge_attr,dst
245,occurs_with,382
381,leads_to_complication,245
0,has_symptom,16
0,has_symptom,24
0,has_symptom,30
0,has_symptom,129
0,has_symptom,134
0,has_symptom,143
0,has_symptom,146
0,has_symptom,165
0,has_symptom,174
0,has_symptom,245
0,has_symptom,257
0,has_symptom,266
0,has_symptom,320
0,has_symptom,329
0,has_symptom,340
0,has_symptom,345
0,has_symptom,363
205,has_symptom,369
205,has_symptom,370
20