In [5]:
!pip install networkx graspologic matplotlib llama-index future
 

Collecting future
  Downloading future-1.0.0-py3-none-any.whl.metadata (4.0 kB)
Downloading future-1.0.0-py3-none-any.whl (491 kB)
Installing collected packages: future
Successfully installed future-1.0.0


In [1]:
"""
1. GraphRAGStore Class:
Initializes a graph using NetworkX and sets up attributes for communities, documents, and maximum cluster size.
Provides methods to add documents and triplets (subject-predicate-object relationships) to the graph.
Implements community detection using the Hierarchical Leiden algorithm and collects detailed community information.
Offers methods to query the graph for specific triplets and perform advanced queries using natural language processing.
Includes functionality to extract subgraphs from query results and visualize them.
Provides methods to save the graph store to a file and serialize it to JSON for easier inspection.

2. GraphRAGQueryEngine Class:
Initializes with a GraphRAGStore and a language model (LLM).
Processes queries by performing advanced graph queries, extracting subgraphs, and generating answers using the LLM.
Converts subgraphs to string representations for use in LLM prompts.

3. Main Execution:
Demonstrates example usage by creating a GraphRAGStore, adding sample documents and triplets, and building communities.
Performs basic and advanced queries on the graph and extracts subgraphs from the results.
Visualizes the subgraph and converts it to DOT format for Graphviz visualization.
Saves the graph store to a JSON file for inspection.

"""

import networkx as nx
from graspologic.partition import hierarchical_leiden
from typing import List, Dict, Any, Optional, Tuple
import pickle
import os
import matplotlib.pyplot as plt 
import re
import colorsys
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
import json
from triplets_extraction import extract_triplets_from_json
import tokenizers




In [6]:

class Document:
    def __init__(self, text: str, doc_id: str = None):
        self.text = text
        self.doc_id = doc_id

class GraphRAGStore:
    def __init__(self):
        self.graph = nx.Graph()
        self.communities = {}
        self.max_cluster_size = 5
        self.documents = {}  # New attribute to store documents

    def add_document(self, document: Document):
        """Add a document to the store."""
        if document.doc_id is None:
            document.doc_id = str(len(self.documents))
        self.documents[document.doc_id] = document

    def add_triplet(self, subject: str, predicate: str, object: str, description: str, doc_id: str = None):
        """Add a single triplet to the graph."""
        self.graph.add_edge(subject, object, relationship=predicate, description=description, doc_id=doc_id)

    def add_triplets(self, triplets: List[tuple]):
        """Add multiple triplets to the graph."""
        for triplet in triplets:
            if len(triplet) == 5:
                subject, predicate, object, description, doc_id = triplet
            else:
                subject, predicate, object, description = triplet
                doc_id = None
            self.add_triplet(subject, predicate, object, description, doc_id)

    def build_communities(self):
        """Builds communities from the graph."""
        community_hierarchical_clusters = hierarchical_leiden(
            self.graph, max_cluster_size=self.max_cluster_size
        )
        self.communities = self._collect_community_info(
            self.graph, community_hierarchical_clusters
        )

    def _collect_community_info(self, nx_graph, clusters):
        """Collect detailed information for each node based on their community."""
        community_mapping = {item.node: item.cluster for item in clusters}
        community_info = {}
        for item in clusters:
            cluster_id = item.cluster
            node = item.node
            if cluster_id not in community_info:
                community_info[cluster_id] = []

            for neighbor in nx_graph.neighbors(node):
                if community_mapping[neighbor] == cluster_id:
                    edge_data = nx_graph.get_edge_data(node, neighbor)
                    if edge_data:
                        detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                        if 'doc_id' in edge_data:
                            detail += f" (Document ID: {edge_data['doc_id']})"
                        community_info[cluster_id].append(detail)
        return community_info

    def get_communities(self) -> Dict[Any, List[str]]:
        """Returns the communities, building them if not already done."""
        if not self.communities:
            self.build_communities()
        return self.communities

    def print_graph_info(self):
        """Print basic information about the graph."""
        print(f"Number of nodes: {self.graph.number_of_nodes()}")
        print(f"Number of edges: {self.graph.number_of_edges()}")
        print(f"Number of communities: {len(self.communities)}")
        print(f"Number of documents: {len(self.documents)}")

    def print_communities(self):
        """Print the contents of each community."""
        for community_id, details in self.communities.items():
            print(f"\nCommunity {community_id}:")
            for detail in details:
                print(f"  {detail}")

    def save(self, filename='graph_store.pkl'):
        """Save the GraphRAGStore object to a file."""
        with open(filename, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, filename='graph_store.pkl'):
        """Load a GraphRAGStore object from a file."""
        with open(filename, 'rb') as f:
            return pickle.load(f)

    def query(self, subject: Optional[str] = None, predicate: Optional[str] = None, object: Optional[str] = None, 
              fuzzy_match: bool = True, case_sensitive: bool = False) -> List[Tuple]:
        """
        Query the graph for matching triplets with enhanced flexibility.
        """
        results = []

        def match_string(query: Optional[str], target: str) -> bool:
            if query is None:
                return True
            if not case_sensitive:
                query = query.lower()
                target = target.lower()
            if fuzzy_match:
                return query in target
            else:
                return query == target

        for s, o, data in self.graph.edges(data=True):
            p = data['relationship']
            d = data['description']
            doc_id = data.get('doc_id')

            if match_string(subject, s) and match_string(predicate, p) and match_string(object, o):
                result = (s, p, o, d)
                if doc_id:
                    result += (doc_id,)
                results.append(result)

        return results

    def advanced_query(self, query: str) -> List[Tuple]:
        """
        Perform an advanced query using natural language processing techniques.
        """
        # Tokenize the query
        tokens = re.findall(r'\w+', query.lower())
        
        # Remove stop words (you can expand this list)
        stop_words = set(['the', 'a', 'an', 'in', 'on', 'at', 'for', 'to', 'of', 'with', 'by', 'is', 'are', 'was', 'were'])
        tokens = [token for token in tokens if token not in stop_words]
        
        # Perform the query for each token
        all_results = []
        for token in tokens:
            results = self.query(subject=token, fuzzy_match=True, case_sensitive=False)
            results += self.query(predicate=token, fuzzy_match=True, case_sensitive=False)
            results += self.query(object=token, fuzzy_match=True, case_sensitive=False)
            all_results.extend(results)
        
        # Remove duplicates while preserving order
        seen = set()
        unique_results = []
        for result in all_results:
            result_key = (result[0], result[1], result[2])  # Use (subject, predicate, object) as the key
            if result_key not in seen:
                seen.add(result_key)
                unique_results.append(result)
        
        # Sort results by relevance (number of matching tokens)
        sorted_results = sorted(unique_results, key=lambda x: sum(token in ' '.join(x).lower() for token in tokens), reverse=True)
        
        # Return top 10 most relevant results
        return sorted_results[:50]

    def extract_subgraph(self, query_results):
        """Extract a subgraph based on query results."""
        subgraph = nx.Graph()
        for result in query_results:
            s, p, o, d = result[:4]
            edge_data = {'relationship': p, 'description': d}
            if len(result) == 5:
                edge_data['doc_id'] = result[4]
            subgraph.add_edge(s, o, **edge_data)
        return subgraph
    
    def visualize_graph(self, subgraph=None, figsize=(12, 8)):
        """
        Visualize the graph or a subgraph using networkx and matplotlib.
        
        Args:
            subgraph (nx.Graph, optional): A subgraph to visualize. If None, visualize the entire graph.
            figsize (tuple, optional): Figure size. Defaults to (12, 8).
        """
        plt.figure(figsize=figsize)
        graph_to_draw = subgraph if subgraph is not None else self.graph
        pos = nx.spring_layout(graph_to_draw)
        nx.draw(graph_to_draw, pos, with_labels=True, node_color='lightblue', 
                node_size=500, font_size=8, font_weight='bold')
        
        edge_labels = nx.get_edge_attributes(graph_to_draw, 'relationship')
        nx.draw_networkx_edge_labels(graph_to_draw, pos, edge_labels=edge_labels, font_size=6)
        
        plt.title("Graph Visualization" if subgraph is None else "Subgraph Visualization")
        plt.axis('off')
        plt.tight_layout()
        plt.show()
        
    def subgraph_to_dot(self, subgraph):
        """
        Convert a subgraph to DOT format for visualization with Graphviz.

        Args:
            subgraph (nx.Graph): The subgraph to convert.

        Returns:
            str: The DOT representation of the subgraph.
        """
        dot_string = "digraph G {\n"
        dot_string += "  rankdir=LR;\n"  # Left to right layout
        dot_string += "  node [style=filled];\n"  # Use filled style for nodes

        # Generate unique colors for relationships
        relationships = set(nx.get_edge_attributes(subgraph, 'relationship').values())
        colors = self._generate_colors(len(relationships))
        color_map = dict(zip(relationships, colors))

        # Add nodes
        subjects = set()
        objects = set()
        for s, o, _ in subgraph.edges(data=True):
            subjects.add(s)
            objects.add(o)

        for node in subgraph.nodes():
            if node in subjects and node in objects:
                color = "#FFA500"  # Orange for nodes that are both subject and object
            elif node in subjects:
                color = "#ADD8E6"  # Light blue for subjects
            else:
                color = "#90EE90"  # Light green for objects
            dot_string += f'  "{node}" [label="{node}", fillcolor="{color}"];\n'

        # Add edges
        for edge in subgraph.edges(data=True):
            source, target, data = edge
            relationship = data.get('relationship', '')
            description = data.get('description', '')
            doc_id = data.get('doc_id', 'N/A')
            edge_label = f"{relationship}\\n{description}\\n(Doc: {doc_id})"
            edge_color = color_map[relationship]
            dot_string += f'  "{source}" -> "{target}" [label="{edge_label}", color="{edge_color}"];\n'

        # Add legend
        dot_string += "  subgraph cluster_legend {\n"
        dot_string += "    label = \"Legend\";\n"
        dot_string += "    node [shape=box];\n"
        dot_string += '    "Subject" [fillcolor="#ADD8E6"];\n'
        dot_string += '    "Object" [fillcolor="#90EE90"];\n'
        dot_string += '    "Both" [fillcolor="#FFA500"];\n'
        for relationship, color in color_map.items():
            dot_string += f'    "{relationship}" [shape=plaintext, fillcolor="white"];\n'
            dot_string += f'    "dummy_{relationship}" [shape=point, style=invis];\n'
            dot_string += f'    "dummy_{relationship}" -> "{relationship}" [color="{color}"];\n'
        dot_string += "  }\n"

        dot_string += "}"
        return dot_string

    def _generate_colors(self, n):
        """Generate n distinct colors."""
        HSV_tuples = [(x * 1.0 / n, 0.5, 0.5) for x in range(n)]
        return ['#%02x%02x%02x' % tuple(int(x * 255) for x in colorsys.hsv_to_rgb(*hsv)) for hsv in HSV_tuples]
    
    def to_json(self, filename='graph_store_v8_1.json'):
        """Serialize the GraphRAGStore object to a JSON file."""
        data = {
            'documents': {doc_id: {'text': doc.text, 'doc_id': doc.doc_id} for doc_id, doc in self.documents.items()},
            'nodes': list(self.graph.nodes),
            'edges': [
                {
                    'subject': s,
                    'object': o,
                    'relationship': data['relationship'],
                    'description': data['description'],
                    'doc_id': data.get('doc_id')
                }
                for s, o, data in self.graph.edges(data=True)
            ],
            'communities': self.communities
        }
        with open(filename, 'w') as f:
            json.dump(data, f, indent=4)

def get_or_create_graph_store(filename='graph_store.pkl'):
    """Get an existing GraphRAGStore or create a new one if it doesn't exist."""
    return GraphRAGStore()
    if os.path.exists(filename):
        return GraphRAGStore.load(filename)
    else:
        return GraphRAGStore()
    
class GraphRAGQueryEngine:
    def __init__(self, graph_store: GraphRAGStore,  llm: OpenAI):
        self.graph_store = graph_store
        self.llm = llm
        

    def query(self, query_str: str) -> str:
        """Process the extracted subgraph to generate answers to a specific query."""
        # Perform an advanced query to get relevant results
        advanced_results = self.graph_store.advanced_query(query_str)
        print(f"Advanced results: {advanced_results}")
        # Extract a subgraph from the advanced query results
        subgraph = self.graph_store.extract_subgraph(advanced_results)
        print(f"Subgraph: {subgraph}")
        # Generate an answer from the subgraph
        answer = self.generate_answer_from_subgraph(subgraph, query_str)
        
        return answer

    def generate_answer_from_subgraph(self, subgraph: nx.Graph, query: str) -> str:
        """Generate an answer from the extracted subgraph based on a given query using LLM."""
        # Convert the subgraph to a string representation
        subgraph_info = self.subgraph_to_string(subgraph)
        
        prompt = (
            f"Given the following subgraph information: {subgraph_info}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        
        
        # Calculate input tokens
        # input_text = " ".join([msg.content for msg in messages])
        # input_tokens = len(self.tokenizer.encode(input_text))

        # Get response and calculate output tokens
        response = self.llm.chat(messages)
        # print("Response is : ", response)
        # output_tokens = len(self.tokenizer.encode(str(response)))

        # Calculate cost (assuming a hypothetical cost per token)
        # cost_per_input_token = 0.0000025  
        # cost_per_output_token = 0.00001
        # total_cost = (input_tokens * cost_per_input_token) + (output_tokens * cost_per_output_token)

        # Print or log the token counts and cost
        # print(f"Input Tokens: {input_tokens}, Output Tokens: {output_tokens}, Total Cost: ${total_cost:.4f}")
        
        
        # response = self.llm.chat(messages)
        return re.sub(r"^assistant:\s*", "", str(response)).strip()

    def subgraph_to_string(self, subgraph: nx.Graph) -> str:
        """Convert a subgraph to a string representation for use in LLM prompts."""
        subgraph_details = []
        for s, o, data in subgraph.edges(data=True):
            relationship = data.get('relationship', '')
            description = data.get('description', '')
            doc_id = data.get('doc_id', 'N/A')
            subgraph_details.append(f"{s} -> {o} -> {relationship} -> {description} (Document ID: {doc_id})")
        return "\n".join(subgraph_details)

In [7]:

# Example usage
# if __name__ == "__main__":
# Get or create GraphRAGStore
graph_store = get_or_create_graph_store()

# If the graph is empty, add sample documents and triplets
if graph_store.graph.number_of_edges() == 0:
    # Add sample documents
    with open('/Users/akshit/Documents/Projects/Python-all/information-extraction/Guidelines/Data/NCCN_prostate_4.2024_Graph_12_33.json', 'r') as f:
            data = json.load(f)
        
    # Create documents from graph nodes
    documents = []
    for node in data['@graph']:
        # Convert the entire node to a string representation
        node_text = json.dumps(node, indent=2)
        doc_id = node.get('@id', 'unknown')
        documents.append(Document(node_text, doc_id))

    # Add documents to graph store
    for doc in documents:
        graph_store.add_document(doc)


    json_file = '/Users/akshit/Documents/Projects/Python-all/information-extraction/Guidelines/Data/NCCN_prostate_4.2024_Graph_12_33.json'

    triplets = extract_triplets_from_json(json_file)
    print(triplets)
    graph_store.add_triplets(triplets)
    graph_store.build_communities()
    graph_store.save()  # Save the graph store after adding triplets and documents

# Print graph information
graph_store.print_graph_info()


graph_store.to_json()
print("GraphRAGStore has been saved to 'graph_store.json'")


[('Clinically localized prostate cancer (Any T, N0, M0 or Any T, NX, MX)', 'page_key', 'PROS-1', '', 'http://nccn-guideline.org/nsclc/0'), ('Clinically localized prostate cancer (Any T, N0, M0 or Any T, NX, MX)', 'page_no', '12', '', 'http://nccn-guideline.org/nsclc/0'), ('Clinically localized prostate cancer (Any T, N0, M0 or Any T, NX, MX)', 'label', 'INITIAL PROSTATE CANCER DIAGNOSIS{a,b,c}', '', 'http://nccn-guideline.org/nsclc/0'), ('Clinically localized prostate cancer (Any T, N0, M0 or Any T, NX, MX)', 'next_contained_nodes', 'Perform physical exam', '', 'http://nccn-guideline.org/nsclc/0'), ('Clinically localized prostate cancer (Any T, N0, M0 or Any T, NX, MX)', 'next_contained_nodes', 'Perform digital rectal exam (DRE) to confirm clinical stage', '', 'http://nccn-guideline.org/nsclc/0'), ('Clinically localized prostate cancer (Any T, N0, M0 or Any T, NX, MX)', 'next_contained_nodes', 'Perform and/or collect prostate-specific antigen (PSA) and calculate PSA density', '', 'http

In [8]:

# Perform queries with the new methods
# print("\nBasic query results for 'Bone imaging':")
# basic_results = graph_store.query(subject="Bone imaging", fuzzy_match=True, case_sensitive=False)
# for result in basic_results:
#     print(f"  {result}")


# print("\nAdvanced query results for : ", query)
# advanced_results = graph_store.advanced_query(query)
# for result in advanced_results:
#     print(f"  {result}")

# # Extract subgraph from advanced query results
# subgraph = graph_store.extract_subgraph(advanced_results)
# print(f"\nSubgraph information:")
# print(f"  Number of nodes: {subgraph.number_of_nodes()}")
# print(f"  Number of edges: {subgraph.number_of_edges()}")

# # Print subgraph edges
# print("\nSubgraph edges:")
# for s, o, data in subgraph.edges(data=True):
#     print(f"  {s} -> {o} -> {data['relationship']} -> {data['description']} (Document ID: {data.get('doc_id', 'N/A')})")

# # Visualize the subgraph
# # graph_store.visualize_graph(subgraph)

# # Convert subgraph to DOT format
# dot_representation = graph_store.subgraph_to_dot(subgraph)
# print("\nDOT representation of the subgraph:")
# print(dot_representation)

# # Save the DOT representation to a file
# dot_filename = "subgraph4.dot"
# with open(dot_filename, "w") as f:
#     f.write(dot_representation)

# print("\nDOT representation has been saved to : ", dot_filename)



query = "What are the clinical features for patients with intermediate clinically localised prostate cancer?"
llm = OpenAI(api_key=<api_key>, model="o1-mini-2024-09-12")

    # Create GraphRAGQueryEngine
query_engine = GraphRAGQueryEngine(graph_store=graph_store, llm=llm)

# Perform a query
# query = "How is bone imaging performed?"
response = query_engine.query(query)
print(f"\nQuery: {query}")
print(f"Response: {response}")


Advanced results: [('Because of the increased sensitivity and specificity of PSMA-PET tracers for detecting micrometastatic disease compared to conventional imaging (eg, CT, bone scan) at both initial staging and biochemical recurrence (BCR), the panel does not feel that conventional imaging is a necessary prerequisite to PSMA-PET and that PSMA-PET/CT or PSMA-PET/MRI can serve as an equally effective, if not more effective frontline imaging tool for these patients.', 'refers_to', 'INITIAL RISK STRATIFICATION AND STAGING WORKUP FOR CLINICALLY LOCALIZED DISEASE{i}Risk GroupClinical/Pathologic Features (Staging, ST-1)Additional Evaluation{f,m}Initial TherapyVery low{j}Has all of the following:• cT1c • Grade Group 1 • PSA <10 ng/mL • <3 prostate biopsy fragments/cores positive, ≤50% cancer in each fragment/core{k} • PSA density <0.15 ng/mL/g• Confirmatory testing can be used to assess the appropriateness of active surveillance (PROS-F 2 of 5) PROS-3Low{j}Has all of the following but does n