# Notebook 1: Primary Graph Building

This notebook processes the cleaned text decomposition responses from the LLM, including semantic units, entities, and relationships, to build the corresponding **S**, **N**, and **R** nodes and link them. The default edge weight is 1. 

The resulting graph is called **G1 – the primary graph**. A mapping of entity names to node IDs is also stored for quick retrieval by entity matching. 

Finally, a NetworkX GML file is stored for visualization using Gephi or any other compatible tool.


In [1]:
import os
dir_path = os.getcwd()
print("The directory of this script is:", dir_path)
root_path = os.path.dirname(dir_path)
print("The root directory is:", root_path)

The directory of this script is: c:\Users\HP\Desktop\Projects\NodeRAG\graphs
The root directory is: c:\Users\HP\Desktop\Projects\NodeRAG


In [None]:
#load llm decomposition response
import pandas as pd
medical_responses = pd.read_parquet(f"{root_path}\\text_decomposition/medical_responses_cleaned.parquet")
medical_responses

In [None]:
#load Node class definition
import sys
sys.path.append(root_path)
from graphs.Node import Node

In [None]:
#regex for matching entities in relationship triples (only match whole word)
import re
def entities_in_relationship(rel, entities):
    rel_text = rel.content.lower()
    found = []
    for e in entities:
        pattern = r'\b' + re.escape(e.content.lower()) + r'\b'
        if re.search(pattern, rel_text):
            found.append(e)
    return found

In [None]:
#function for creating and linking S, N, R nodes
import json
from tqdm import tqdm
def build_nodes(df, source_name):
    nodes = dict()
    entity_nodes = dict()
    for idx in tqdm(range(len(df))):
        row = df.iloc[idx]
        response = row["cleaned_response"]
        response = json.loads(response)
        for unit_idx,unit in enumerate(response):
            #data
            semantic_unit = unit["semantic_unit"]
            entities = unit["entities"]
            relationships = unit["relationships"]

            #ids
            source_id = f"{source_name}-{idx}"
            semantic_id = f"{source_id}-S-{unit_idx}"
            #entities_ids = [f"{semantic_id}-N-{e_idx}" for e_idx in range(len(entities))]
            relationships_ids = [f"{semantic_id}-R-{r_idx}" for r_idx in range(len(relationships))]

            #create semantic node
            semantic_node = Node(
                id=semantic_id,
                node_type = "S",
                source = source_id,
                content = semantic_unit
            )

            #create entity nodes
            current_entity_nodes = set()
            for e_idx, entity in enumerate(entities):
                key = entity.strip().upper()
                if key not in entity_nodes:
                    entity_node = Node(
                        id = f"{source_name}-N-{len(entity_nodes)}",
                        node_type = "N",
                        source = "",
                        content = key
                    )
                    entity_nodes[key] = entity_node
                else:
                    entity_node = entity_nodes[key]
                current_entity_nodes.add(entity_node)

            #create relationship nodes
            relationship_nodes = []
            for r_idx, relationship in enumerate(relationships):
                relationship_node = Node(
                    id = relationships_ids[r_idx],
                    node_type = "R",
                    source = source_id,
                    content = relationship
                )
                relationship_nodes.append(relationship_node)

            #link nodes
            for entity_node in current_entity_nodes:
                semantic_node.link(entity_node)
                entity_node.link(semantic_node)
            
            for relationship_node in relationship_nodes:
                ents = entities_in_relationship(relationship_node, current_entity_nodes)
                for ent in ents:
                    relationship_node.link(ent)
                    ent.link(relationship_node)
            
            current_nodes = [semantic_node] + relationship_nodes
            for n in current_nodes:
                nodes[n.id] = n
    nodes.update({v.id: v for v in entity_nodes.values()})
    #store entities
    entities_dict = {k: v.id for k, v in entity_nodes.items()}
    return nodes, entities_dict

In [7]:
medical_nodes, medical_entities = build_nodes(medical_responses, "medical")

100%|██████████| 554/554 [00:01<00:00, 439.18it/s]


In [8]:
import pickle
with open(f"{root_path}/graphs/data/graphs/G1_medical_primary_graph.pkl", "wb") as f:
    pickle.dump(medical_nodes, f)
with open(f"{root_path}/graphs/data/nodes/entity/medical_entities.pkl", "wb") as f:
    pickle.dump(medical_entities, f)

In [9]:
with open(f"{root_path}/graphs/data/graphs/G1_medical_primary_graph.pkl", "rb") as f:
    medical_nodes = pickle.load(f)
with open(f"{root_path}/graphs/data/nodes/entity/medical_entities.pkl", "rb") as f:
    medical_entities = pickle.load(f)

In [None]:
#sanity checks
for node_id in medical_nodes:
    node = medical_nodes[node_id]
    for edge in node.edges:
        if node.edges[edge] > 1:
            print(node_id,"-",edge,"-",node.edges[edge])

In [None]:
def nodes_to_dataframe(nodes):
    data = []
    for node in nodes.values():
        data.append({
            "id": node.id,
            "node_type": node.node_type,
            "source": node.source,
            "content": node.content,
            "edges": node.edges
        })
    return pd.DataFrame(data)

medical_nodes_df = nodes_to_dataframe(medical_nodes)
medical_nodes_df = medical_nodes_df.sort_values(by=["source", "node_type", "id"]).reset_index(drop=True)
medical_nodes_df


In [13]:
medical_entities

{'BASAL CELL SKIN CANCER': 'medical-N-0',
 'BASAL CELL CARCINOMA (BCC)': 'medical-N-1',
 '3 MILLION CASES': 'medical-N-2',
 'UNITED STATES': 'medical-N-3',
 'SURGERY': 'medical-N-4',
 'BASAL CELLS': 'medical-N-5',
 'EPIDERMIS': 'medical-N-6',
 'FACE': 'medical-N-7',
 'HEAD': 'medical-N-8',
 'NECK': 'medical-N-9',
 'ARMS': 'medical-N-10',
 'LEGS': 'medical-N-11',
 'TRUNK': 'medical-N-12',
 'SKIN': 'medical-N-13',
 'DERMIS': 'medical-N-14',
 'HYPODERMIS': 'medical-N-15',
 'SQUAMOUS CELLS': 'medical-N-16',
 'MELANOCYTES': 'medical-N-17',
 'FAT': 'medical-N-18',
 'CONNECTIVE TISSUE': 'medical-N-19',
 'ULTRAVIOLET (UV) RAYS': 'medical-N-20',
 'SQUAMOUS CELL SKIN CANCER': 'medical-N-21',
 'MELANOMA': 'medical-N-22',
 'NCCN GUIDELINES FOR PATIENTS': 'medical-N-23',
 'NCCN.ORG/PATIENTGUIDELINES': 'medical-N-24',
 'NCCN PATIENT GUIDES FOR CANCER APP': 'medical-N-25',
 'LIGHTER SKIN': 'medical-N-26',
 'LIGHTER HAIR': 'medical-N-27',
 'LIGHTER EYES': 'medical-N-28',
 'FLAT, PALE OR YELLOW AREAS':

In [None]:
total_edges = sum(len(node.edges) for node in medical_nodes.values())//2
print("Total edges:", total_edges)

Total edges: 67566


In [None]:
#data for visualization using Gephi
import networkx as nx
G = nx.Graph()
for node_id, node in medical_nodes.items():
    G.add_node(node_id, node_type=node.node_type, content=node.content, source=node.source)
    for target_id, weight in node.edges.items():
        if target_id in medical_nodes:
            G.add_edge(node_id, target_id, weight=weight)

nx.write_gml(G, "viz/medical_graph.gml")
