In [19]:
# pip install llama-index graspologic numpy==1.24.4 scipy==1.12.0

In [64]:
from llama_index.core import Document

text= """
The History and Impact of Quantum Mechanics

Quantum mechanics is a fundamental theory in physics that describes nature at the smallest scales, such as atoms and subatomic particles. It was developed in the early 20th century, revolutionizing our understanding of the physical world.

The origins of quantum mechanics trace back to Max Planck’s work in 1900, where he introduced the idea of quantized energy levels to explain blackbody radiation. This concept was further developed by Albert Einstein in 1905 when he explained the photoelectric effect by proposing that light consists of discrete packets of energy called photons. Einstein’s work earned him the Nobel Prize in Physics in 1921.

Building on these ideas, Niels Bohr proposed the Bohr model of the atom in 1913, which described electrons orbiting the nucleus in fixed energy levels. This model successfully explained the spectral lines of hydrogen but had limitations with more complex atoms.

In the 1920s, Werner Heisenberg formulated matrix mechanics, and Erwin Schrödinger developed wave mechanics, two equivalent formulations of quantum mechanics. Schrödinger’s wave equation describes how the quantum state of a physical system changes over time.

Werner Heisenberg is also famous for the uncertainty principle, which states that certain pairs of physical properties, like position and momentum, cannot be simultaneously known to arbitrary precision.

Quantum mechanics has had profound implications beyond physics. It laid the foundation for quantum chemistry, explaining chemical bonding and reactions. It also enabled the development of semiconductors, which are the basis for modern electronics, including computers and smartphones.

Richard Feynman, a prominent physicist of the 20th century, contributed significantly to quantum electrodynamics (QED), a quantum theory of the interaction between light and matter. Feynman introduced Feynman diagrams, a pictorial representation of particle interactions.

The theory of quantum mechanics also paved the way for emerging fields such as quantum computing and quantum cryptography. Quantum computers use quantum bits, or qubits, which can represent both 0 and 1 simultaneously, promising exponential speedups for certain computational problems.

Despite its successes, quantum mechanics challenges our classical intuitions about reality. Concepts like superposition and entanglement defy everyday experience but have been experimentally confirmed.

The Copenhagen interpretation, primarily developed by Niels Bohr and Werner Heisenberg, is one of the earliest and most widely taught interpretations of quantum mechanics. It emphasizes the probabilistic nature of quantum measurements and the role of the observer.

Today, research in quantum mechanics continues to expand, with scientists exploring quantum gravity, quantum field theory, and applications in materials science and information technology.

"""

In [65]:
documents = [Document(text=text)]

In [54]:
from dotenv import load_dotenv
import os
# 1. Load environment variables for API keys
load_dotenv()
GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

In [55]:
# pip install llama-index-llms-groq

In [66]:
from llama_index.llms.gemini import Gemini
from llama_index.embeddings.openai import OpenAIEmbedding
llm = Gemini(api_key=GEMINI_API_KEY, model="gemini-1.5-flash")  # or "gemini-pro", etc.

# from llama_index.llms.groq import Groq
# llm = Groq(api_key=GROQ_API_KEY, model="llama3-70b-8192")  # or "llama-3-70b-8192", etc.

# ---- Use OPenai ----
# llm = OpenAI(model="gpt-4-turbo")  # You can use "gpt-3.5-turbo" for lower cost
embed_model = OpenAIEmbedding(model="text-embedding-3-small")


  llm = Gemini(api_key=GEMINI_API_KEY, model="gemini-1.5-flash")  # or "gemini-pro", etc.


In [67]:
import asyncio
import nest_asyncio

nest_asyncio.apply()

from typing import Any, List, Callable, Optional, Union, Dict
from IPython.display import Markdown, display

from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import (
    default_parse_triplets_fn,
)
from llama_index.core.graph_stores.types import (
    EntityNode,
    KG_NODES_KEY,
    KG_RELATIONS_KEY,
    Relation,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import (
    DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
)
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.bridge.pydantic import BaseModel, Field


class GraphRAGExtractor(TransformComponent):
    """Extract triples from a graph.

    Uses an LLM and a simple prompt + output parsing to extract paths (i.e. triples) and entity, relation descriptions from text.

    Args:
        llm (LLM):
            The language model to use.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples.
        parse_fn (callable):
            A function to parse the output of the language model.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params."""
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes."""
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a node."""
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        entity_metadata = node.metadata.copy()
        for entity, entity_type, description in entities:
            entity_metadata["entity_description"] = description
            entity_node = EntityNode(
                name=entity, label=entity_type, properties=entity_metadata
            )
            existing_nodes.append(entity_node)

        relation_metadata = node.metadata.copy()
        for triple in entities_relationship:
            subj, obj, rel, description = triple
            relation_metadata["relationship_description"] = description
            rel_node = Relation(
                label=rel,
                source_id=subj,
                target_id=obj,
                properties=relation_metadata,
            )

            existing_relations.append(rel_node)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async."""
        jobs = []
        for node in nodes:
            jobs.append(self._aextract(node))

        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )

In [68]:
# pip install future

In [69]:
import re
from llama_index.core.graph_stores import SimplePropertyGraphStore
import networkx as nx
from graspologic.partition import hierarchical_leiden

from llama_index.core.llms import ChatMessage


class GraphRAGStore(SimplePropertyGraphStore):
    community_summary = {}
    max_cluster_size = 5

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM."""
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation->relationship_description. Your task is to create a summary of these "
                    "relationships. The summary should include the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = Gemini().chat(messages)
        clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return clean_response

    def build_communities(self):
        """Builds communities from the graph and summarizes them."""
        nx_graph = self._create_nx_graph()
        community_hierarchical_clusters = hierarchical_leiden(
            nx_graph, max_cluster_size=self.max_cluster_size
        )
        community_info = self._collect_community_info(
            nx_graph, community_hierarchical_clusters
        )
        self._summarize_communities(community_info)

    def _create_nx_graph(self):
        """Converts internal graph representation to NetworkX graph."""
        nx_graph = nx.Graph()
        for node in self.graph.nodes.values():
            nx_graph.add_node(str(node))
        for relation in self.graph.relations.values():
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties["relationship_description"],
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Collect detailed information for each node based on their community."""
        community_mapping = {item.node: item.cluster for item in clusters}
        community_info = {}
        for item in clusters:
            cluster_id = item.cluster
            node = item.node
            if cluster_id not in community_info:
                community_info[cluster_id] = []

            for neighbor in nx_graph.neighbors(node):
                if community_mapping[neighbor] == cluster_id:
                    edge_data = nx_graph.get_edge_data(node, neighbor)
                    if edge_data:
                        detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                        community_info[cluster_id].append(detail)
        return community_info

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community."""
        for community_id, details in community_info.items():
            details_text = (
                "\n".join(details) + "."
            )  # Ensure it ends with a period
            self.community_summary[
                community_id
            ] = self.generate_community_summary(details_text)

    def get_community_summaries(self):
        """Returns the community summaries, building them if not already done."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary

In [70]:
# pip install graspologic

In [71]:
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.llms import LLM


class GraphRAGQueryEngine(CustomQueryEngine):
    graph_store: GraphRAGStore
    llm: LLM

    def custom_query(self, query_str: str) -> str:
        """Process all community summaries to generate answers to a specific query."""
        community_summaries = self.graph_store.get_community_summaries()
        community_answers = [
            self.generate_answer_from_summary(community_summary, query_str)
            for _, community_summary in community_summaries.items()
        ]

        final_answer = self.aggregate_answers(community_answers)
        return final_answer

    def generate_answer_from_summary(self, community_summary, query):
        """Generate an answer from a community summary based on a given query using LLM."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        response = self.llm.chat(messages)
        cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return cleaned_response

    def aggregate_answers(self, community_answers):
        """Aggregate individual community answers into a final, coherent response."""
        # intermediate_text = " ".join(community_answers)
        prompt = "Combine the following intermediate answers into a final, concise response."
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers: {community_answers}",
            ),
        ]
        final_response = self.llm.chat(messages)
        cleaned_final_response = re.sub(
            r"^assistant:\s*", "", str(final_response)
        ).strip()
        return cleaned_final_response

In [72]:
from llama_index.core.node_parser import SentenceSplitter

splitter = SentenceSplitter(
    chunk_size=1024,
    chunk_overlap=20,
)
nodes = splitter.get_nodes_from_documents(documents)

In [73]:
len(nodes)

1

In [74]:
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.

-An Output Example-
{
  "entities": [
    {
      "entity_name": "Albert Einstein",
      "entity_type": "Person",
      "entity_description": "Albert Einstein was a theoretical physicist who developed the theory of relativity and made significant contributions to physics."
    },
    {
      "entity_name": "Theory of Relativity",
      "entity_type": "Scientific Theory",
      "entity_description": "A scientific theory developed by Albert Einstein, describing the laws of physics in relation to observers in different frames of reference."
    },
    {
      "entity_name": "Nobel Prize in Physics",
      "entity_type": "Award",
      "entity_description": "A prestigious international award in the field of physics, awarded annually by the Royal Swedish Academy of Sciences."
    }
  ],
  "relationships": [
    {
      "source_entity": "Albert Einstein",
      "target_entity": "Theory of Relativity",
      "relation": "developed",
      "relationship_description": "Albert Einstein is the developer of the theory of relativity."
    },
    {
      "source_entity": "Albert Einstein",
      "target_entity": "Nobel Prize in Physics",
      "relation": "won",
      "relationship_description": "Albert Einstein won the Nobel Prize in Physics in 1921."
    }
  ]
}

-Real Data-
######################
text: {text}
######################
output:"""

In [75]:
import json


def parse_fn(response_str: str) -> Any:
    json_pattern = r"\{.*\}"
    match = re.search(json_pattern, response_str, re.DOTALL)
    entities = []
    relationships = []
    if not match:
        return entities, relationships
    json_str = match.group(0)
    try:
        data = json.loads(json_str)
        entities = [
            (
                entity["entity_name"],
                entity["entity_type"],
                entity["entity_description"],
            )
            for entity in data.get("entities", [])
        ]
        relationships = [
            (
                relation["source_entity"],
                relation["target_entity"],
                relation["relation"],
                relation["relationship_description"],
            )
            for relation in data.get("relationships", [])
        ]
        return entities, relationships
    except json.JSONDecodeError as e:
        print("Error parsing JSON:", e)
        return entities, relationships


kg_extractor = GraphRAGExtractor(
    llm=llm,
    extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
    max_paths_per_chunk=2,
    parse_fn=parse_fn,
)

In [76]:
from llama_index.core import PropertyGraphIndex

index = PropertyGraphIndex(
    nodes=nodes,
    property_graph_store=GraphRAGStore(),
    kg_extractors=[kg_extractor],
    show_progress=True,
)

Extracting paths from text: 100%|██████████| 1/1 [00:02<00:00,  2.99s/it]
Generating embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
Generating embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.78it/s]


In [86]:
index.property_graph_store.graph.get_triplets()

[(EntityNode(label='Person', embedding=None, properties={'entity_description': 'A physicist who introduced the idea of quantized energy levels in 1900 to explain blackbody radiation, contributing to the origins of quantum mechanics.', 'triplet_source_id': '63381b8a-5226-4886-a823-f5d192731524'}, name='Max Planck'),
  Relation(label='contributed to', source_id='Max Planck', target_id='Quantum Mechanics', properties={'relationship_description': "Max Planck's work on quantized energy levels in 1900 was a foundational contribution to the development of quantum mechanics.", 'triplet_source_id': '63381b8a-5226-4886-a823-f5d192731524'}),
  EntityNode(label='Scientific Theory', embedding=None, properties={'entity_description': 'A fundamental theory in physics describing nature at the smallest scales, developed in the early 20th century, revolutionizing our understanding of the physical world. It has profound implications beyond physics, laying the foundation for quantum chemistry, semiconducto

In [None]:
list(index.property_graph_store.graph.nodes.values())[-1]

EntityNode(label='Person', embedding=None, properties={'entity_description': 'A physicist who explained the photoelectric effect in 1905 by proposing that light consists of discrete packets of energy called photons.  This work earned him the Nobel Prize in Physics in 1921 and further developed the concepts of quantum mechanics.', 'triplet_source_id': '63381b8a-5226-4886-a823-f5d192731524'}, name='Albert Einstein')

In [79]:
list(index.property_graph_store.graph.relations.values())[0]

Relation(label='contributed to', source_id='Max Planck', target_id='Quantum Mechanics', properties={'relationship_description': "Max Planck's work on quantized energy levels in 1900 was a foundational contribution to the development of quantum mechanics.", 'triplet_source_id': '63381b8a-5226-4886-a823-f5d192731524'})

In [80]:
list(index.property_graph_store.graph.relations.values())[0].properties[
    "relationship_description"
]

"Max Planck's work on quantized energy levels in 1900 was a foundational contribution to the development of quantum mechanics."

In [81]:
index.property_graph_store.build_communities()

  response = Gemini().chat(messages)


In [82]:
query_engine = GraphRAGQueryEngine(
    graph_store=index.property_graph_store, llm=llm
)

In [83]:
response = query_engine.query(
    "What are the main news discussed in the document?"
)
display(Markdown(f"{response.response}"))

Planck's quantized energy levels and Einstein's explanation of the photoelectric effect were foundational to the development of quantum mechanics.

In [84]:
response = query_engine.query("What waws the contribution of Albert Einstein?")
display(Markdown(f"{response.response}"))

Einstein explained the photoelectric effect, advancing quantum mechanics.

In [None]:
# pip install pyvis gradio

In [115]:
triplets = index.property_graph_store.graph.get_triplets()
trip = triplets[0]
trip[0].properties
    # source = triplets[i][0].properties.get("name") or triplets[i][0].properties.get("id") or "Unknown"
    # relation = triplets[i][1].label
    # target = triplets[i][2].properties.get("name") or triplets[i+2].properties.get("id") or "Unknown"
        

{'entity_description': 'A physicist who introduced the idea of quantized energy levels in 1900 to explain blackbody radiation, contributing to the origins of quantum mechanics.',
 'triplet_source_id': '63381b8a-5226-4886-a823-f5d192731524'}

In [116]:
from pyvis.network import Network
import webbrowser

# --- Get triplets from LlamaIndex GraphRAG
triplets = index.property_graph_store.graph.get_triplets()

# --- Colors for different entity types
type_colors = {
    "Person": "#FF6961",              # red
    "Scientific Theory": "#779ECB",   # blue
    "Concept": "#77DD77",             # green
    "Location": "#FFD700",            # yellow
    "Organization": "#FFB347",        # orange
    "Unknown": "#D3D3D3"              # grey fallback
}

def build_and_open_graph(triplets):
    net = Network(height="750px", width="100%", directed=True, notebook=False)
    
    for i in range(0, len(triplets), 3):
        print(triplets[i])
        try:
            node1 = triplets[i][0]
            edge = triplets[i][1]
            node2 = triplets[i][2]

            # Get node info
            source = node1.name
            source_type = node1.label
            source_desc = node1.properties.get("entity_description", "")

            target = node2.name or node2.properties.get("id") or "Unknown"
            target_type = node2.label or "Unknown"
            target_desc = node2.properties.get("entity_description", "")

            relation = edge.label or "related_to"
            relation_desc = edge.properties.get("relation_description", "")

            # Add source node
            net.add_node(
                source,
                label=source,
                title=f"{source_type}: {source_desc}",
                color=type_colors.get(source_type, type_colors["Unknown"])
            )

            # Add target node
            net.add_node(
                target,
                label=target,
                title=f"{target_type}: {target_desc}",
                color=type_colors.get(target_type, type_colors["Unknown"])
            )

            # Add edge with tooltip
            net.add_edge(
                source,
                target,
                label=relation,
                title=relation_desc
            )
        except Exception as e:
            print(f"Skipping triplet group {i}-{i+2}: {e}")
            continue

    net.repulsion()
    output_path = "graph.html"
    net.save_graph(output_path)
    webbrowser.open(f"file://{os.path.abspath(output_path)}")

# --- Run it
build_and_open_graph(triplets)


(EntityNode(label='Person', embedding=None, properties={'entity_description': 'A physicist who introduced the idea of quantized energy levels in 1900 to explain blackbody radiation, contributing to the origins of quantum mechanics.', 'triplet_source_id': '63381b8a-5226-4886-a823-f5d192731524'}, name='Max Planck'), Relation(label='contributed to', source_id='Max Planck', target_id='Quantum Mechanics', properties={'relationship_description': "Max Planck's work on quantized energy levels in 1900 was a foundational contribution to the development of quantum mechanics.", 'triplet_source_id': '63381b8a-5226-4886-a823-f5d192731524'}), EntityNode(label='Scientific Theory', embedding=None, properties={'entity_description': 'A fundamental theory in physics describing nature at the smallest scales, developed in the early 20th century, revolutionizing our understanding of the physical world. It has profound implications beyond physics, laying the foundation for quantum chemistry, semiconductors, a

In [None]:
from pyvis.network import Network
import gradio as gr
from llama_index.graph_stores import SimpleGraphStore  # or your actual GraphStore class

# --- Assuming you already have the index with property_graph_store
triplets = index.property_graph_store.graph.get_triplets()

def create_graph_html(triplets):
    g = Network(height="600px", width="100%", directed=True)

    for i in range(0, len(triplets), 3):
        source = triplets[i].properties.get("name") or triplets[i].properties.get("id") or "Unknown"
        relation = triplets[i+1].label
        target = triplets[i+2].properties.get("name") or triplets[i+2].properties.get("id") or "Unknown"
        
        g.add_node(source, label=source)
        g.add_node(target, label=target)
        g.add_edge(source, target, label=relation)

    g.repulsion()  # Makes layout prettier
    g_html_path = "graph.html"
    g.show(g_html_path)
    
    with open(g_html_path, "r", encoding="utf-8") as f:
        return f.read()

# --- Gradio Interface
def show_graph():
    return create_graph_html(triplets)

gr.Interface(fn=show_graph, inputs=[], outputs=gr.HTML()).launch()
