In [2]:
!pip install llama-index llama-index-graph-stores-neo4j graspologic numpy==1.24.4 scipy==1.12.0 future

Collecting llama-index
  Using cached llama_index-0.12.34-py3-none-any.whl.metadata (12 kB)
Collecting llama-index-graph-stores-neo4j
  Using cached llama_index_graph_stores_neo4j-0.4.6-py3-none-any.whl.metadata (694 bytes)
Collecting graspologic
  Using cached graspologic-3.4.1-py3-none-any.whl.metadata (5.8 kB)
Collecting numpy==1.24.4
  Downloading numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl.metadata (5.6 kB)
Collecting scipy==1.12.0
  Downloading scipy-1.12.0-cp310-cp310-macosx_12_0_arm64.whl.metadata (112 kB)
Collecting future
  Using cached future-1.0.0-py3-none-any.whl.metadata (4.0 kB)
Collecting llama-index-agent-openai<0.5,>=0.4.0 (from llama-index)
  Using cached llama_index_agent_openai-0.4.6-py3-none-any.whl.metadata (727 bytes)
Collecting llama-index-cli<0.5,>=0.4.1 (from llama-index)
  Using cached llama_index_cli-0.4.1-py3-none-any.whl.metadata (1.5 kB)
Collecting llama-index-core<0.13,>=0.12.34 (from llama-index)
  Downloading llama_index_core-0.12.34.post1-py3-none

In [1]:
import pandas as pd
from llama_index.core import Document 

# Download the sample news-article CSV and keep only the first 50 rows.
news = pd.read_csv(
    "https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv"
).iloc[:50]
news.head()

Unnamed: 0,title,date,text
0,Chevron: Best Of Breed,2031-04-06T01:36:32.000000000+00:00,JHVEPhoto Like many companies in the O&G secto...
1,FirstEnergy (NYSE:FE) Posts Earnings Results,2030-04-29T06:55:28.000000000+00:00,FirstEnergy (NYSE:FE – Get Rating) posted its ...
2,Dáil almost suspended after Sinn Féin TD put p...,2023-06-15T14:32:11.000000000+00:00,The Dáil was almost suspended on Thursday afte...
3,Epic’s latest tool can animate hyperrealistic ...,2023-06-15T14:00:00.000000000+00:00,"Today, Epic is releasing a new tool designed t..."
4,"EU to Ban Huawei, ZTE from Internal Commission...",2023-06-15T13:50:00.000000000+00:00,The European Commission is planning to ban equ...


In [2]:
# One Document per article: the title is prepended to the article text.
documents = [
    Document(text=f"{title}: {body}")
    for title, body in zip(news["title"], news["text"])
]
print(documents[0])

Doc ID: 9e80a75b-0a14-4385-887c-bfcb3b2f6166
Text: Chevron: Best Of Breed: JHVEPhoto Like many companies in the O&G
sector, the stock of Chevron (NYSE:CVX) has declined about 10% over
the past 90-days despite the fact that Q2 consensus earnings estimates
have risen sharply (~25%) during that same time frame. Over the years,
Chevron has kept a very strong balance sheet. That allowed the...


In [6]:
from llama_index.llms.ollama import Ollama
from llama_index.core.settings import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Global LlamaIndex defaults: a HuggingFace embedding model and an Ollama-served LLM.
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")

llm = Ollama(model="llama3", base_url="http://localhost:11434", request_timeout=120.0)
Settings.llm = llm


In [7]:
import asyncio 
import nest_asyncio

nest_asyncio.apply()

from typing import Any, List, Callable, Optional, Union, Dict 
from IPython.display import Markdown, display

from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import (
    default_parse_triplets_fn,
)
from llama_index.core.graph_stores.types import (
    EntityNode, 
    KG_NODES_KEY,
    KG_RELATIONS_KEY, 
    Relation
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import (
    DEFAULT_KG_TRIPLET_EXTRACT_PROMPT, 
)
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.bridge.pydantic import BaseModel, Field 

class GraphRAGExtractor(TransformComponent):
    """Transform component that asks an LLM for entities and relationships per node.

    For every input node the extractor formats ``extract_prompt`` with the node
    text, sends it to ``llm``, parses the reply with ``parse_fn`` and stores the
    resulting ``EntityNode`` / ``Relation`` objects in the node metadata under
    ``KG_NODES_KEY`` and ``KG_RELATIONS_KEY``.

    Attributes:
        llm: Language model used for extraction.
        extract_prompt: Prompt formatted with ``text`` and ``max_knowledge_triplets``.
        parse_fn: Callable turning the raw LLM reply into ``(entities, relationships)``,
            where entities are ``(name, type, description)`` tuples and relationships
            are ``(source, target, relation, description)`` tuples.
        num_workers: Number of concurrent extraction jobs passed to ``run_jobs``.
        max_path_per_chunk: Value passed to the prompt as ``max_knowledge_triplets``.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_path_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_path_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Build the extractor, falling back to ``Settings.llm`` and the default prompt.

        Args:
            llm: LLM to use; defaults to ``Settings.llm`` when omitted.
            extract_prompt: Prompt template or raw template string; defaults to
                ``DEFAULT_KG_TRIPLET_EXTRACT_PROMPT`` when omitted.
            parse_fn: Parser applied to the LLM reply.
            max_path_per_chunk: Upper bound handed to the prompt.
            num_workers: Concurrency used by ``acall``.
        """
        from llama_index.core import Settings

        # A plain string is accepted for convenience and wrapped into a template.
        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_path_per_chunk=max_path_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the component name string."""
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Synchronous entry point: runs ``acall`` to completion via ``asyncio.run``."""
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract from all nodes concurrently with ``num_workers`` workers.

        Returns:
            The result of ``run_jobs`` over one ``_aextract`` job per input node.
        """
        jobs = [self._aextract(node) for node in nodes]
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Run extraction for a single node and attach the results to its metadata.

        A ``ValueError`` raised while prompting or parsing is treated as
        "nothing extracted" for this node.
        """
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_path_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        # Pop previously extracted graph data so it is not copied into the
        # per-entity / per-relation properties below.
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        base_metadata = node.metadata.copy()

        for entity, entity_type, description in entities:
            # Fresh dict per entity: the original reused one dict and overwrote
            # "entity_description" on every iteration, sharing it between nodes.
            entity_properties = {**base_metadata, "entity_description": description}
            existing_nodes.append(
                EntityNode(
                    name=entity, label=entity_type, properties=entity_properties
                )
            )

        for subj, obj, rel, description in entities_relationship:
            # Same reasoning as above: one properties dict per relation.
            relation_properties = {
                **base_metadata,
                "relationship_description": description,
            }
            existing_relations.append(
                Relation(
                    label=rel,
                    source_id=subj,
                    target_id=obj,
                    properties=relation_properties,
                )
            )

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node




    
    

In [8]:
import re 
import networkx as nx 
from graspologic.partition import hierarchical_leiden
from collections import defaultdict 

from llama_index.core.llms import ChatMessage
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.llms.ollama import Ollama


class GraphRAGStore(Neo4jPropertyGraphStore):
    """Neo4j property graph store that also clusters entities and summarizes clusters.

    ``build_communities`` converts the stored triplets into a NetworkX graph,
    partitions it with ``hierarchical_leiden`` and asks ``llm`` for one summary
    per community.
    """

    # Class-level defaults; build_communities() replaces both on the instance.
    community_summary = {}
    entity_info = None
    max_cluster_size = 5

    # LLM used for community summaries (created once when the class is defined).
    llm = Ollama(
        model="llama3",
        base_url="http://localhost:11434",
        request_timeout=120.0,
    )

    def generate_community_summary(self, text):
        """Ask the LLM to summarize a block of relationship lines.

        Args:
            text: Relationship details, one per line, in the
                ``entity1 -> entity2 -> relation -> description`` form produced
                by ``_collect_community_info``.

        Returns:
            The summary text with a leading ``assistant:`` role prefix removed.
        """
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation->relationship_description. Your task is to create a summary of these "
                    "relationships. The summary should include the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = self.llm.chat(messages)
        # Strip the role prefix the same way the query engine cleans its answers.
        return re.sub(r"^assistant:\s*", "", str(response)).strip()

    def build_communities(self):
        """Cluster the graph and (re)generate ``entity_info`` and ``community_summary``."""
        nx_graph = self._create_nx_graph()
        community_hierarchical_clusters = hierarchical_leiden(
            nx_graph, max_cluster_size=self.max_cluster_size
        )
        self.entity_info, community_info = self._collect_community_info(
            nx_graph, community_hierarchical_clusters
        )
        self._summarize_communities(community_info)

    def _create_nx_graph(self):
        """Build an undirected NetworkX graph from the triplets held in the store."""
        # Instantiate the graph; the bare class `nx.Graph` cannot hold nodes or edges.
        nx_graph = nx.Graph()
        for entity1, relation, entity2 in self.get_triplets():
            nx_graph.add_node(entity1.name)
            nx_graph.add_node(entity2.name)
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties["relationship_description"],
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Map entities to community ids and communities to relationship details.

        Args:
            nx_graph: Graph produced by ``_create_nx_graph``.
            clusters: Iterable of items exposing ``.node`` and ``.cluster``.

        Returns:
            ``(entity_info, community_info)`` where ``entity_info`` maps a node to
            the list of community ids it belongs to and ``community_info`` maps a
            community id to a list of ``"a -> b -> relation -> description"`` strings.
        """
        entity_info = defaultdict(set)
        community_info = defaultdict(list)

        for item in clusters:
            node = item.node
            cluster_id = item.cluster

            entity_info[node].add(cluster_id)

            for neighbor in nx_graph.neighbors(node):
                # Look the edge up on the graph instance, not on the networkx module.
                edge_data = nx_graph.get_edge_data(node, neighbor)
                if edge_data:
                    detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                    community_info[cluster_id].append(detail)

        return (
            {entity: list(ids) for entity, ids in entity_info.items()},
            dict(community_info),
        )

    def _summarize_communities(self, community_info):
        """Generate one LLM summary per community and store them on this instance."""
        summaries = {}
        for community_id, details in community_info.items():
            details_text = "\n".join(details) + "."
            summaries[community_id] = self.generate_community_summary(details_text)
        # Rebind on the instance so the class-level default dict is never mutated
        # (it would otherwise be shared by every GraphRAGStore).
        self.community_summary = summaries

    def get_community_summaries(self):
        """Return the community summaries, building them first if none exist yet."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary
        
        

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [22]:
from llama_index.core.query_engine import CustomQueryEngine 
from llama_index.core.llms import LLM 
from llama_index.core import PropertyGraphIndex 

import re

class GraphRAGQueryEngine(CustomQueryEngine):
    """Query engine that answers from the summaries of communities relevant to a query.

    Attributes:
        graph_store: Store providing ``entity_info`` and community summaries.
        index: Property graph index used to retrieve triplets for the query.
        llm: LLM used for per-community answers and for the final aggregation.
        similarity_top_k: Number of results requested from the index retriever.
    """

    graph_store: GraphRAGStore
    index: PropertyGraphIndex
    llm: LLM
    similarity_top_k: int = 20

    def custom_query(self, query_str: str) -> str:
        """Answer ``query_str`` from the summaries of the communities it touches."""
        # Fetch summaries first: get_community_summaries() builds the communities
        # on first use, which is also what populates graph_store.entity_info.
        community_summaries = self.graph_store.get_community_summaries()

        entities = self.get_entities(query_str, self.similarity_top_k)
        community_ids = set(
            self.retrieve_entity_communities(
                self.graph_store.entity_info, entities
            )
        )

        community_answers = [
            self.generate_answer_from_summary(community_summary, query_str)
            for community_id, community_summary in community_summaries.items()
            if community_id in community_ids
        ]

        return self.aggregate_answers(community_answers)

    def get_entities(self, query_str, similarity_top_k):
        """Retrieve nodes for the query and collect the entity names they mention.

        Each retrieved node's text is scanned for lines shaped like
        ``subject -> relation -> object``; subjects and objects are returned.
        """
        nodes_retrieved = self.index.as_retriever(
            similarity_top_k=similarity_top_k
        ).retrieve(query_str)

        entities = set()
        pattern = (
            r"^(\w+(?:\s+\w+)*)\s*->\s*([a-zA-Z\s]+?)\s*->\s*(\w+(?:\s+\w+)*)$"
        )

        for node in nodes_retrieved:
            matches = re.findall(
                pattern, node.text, re.MULTILINE | re.IGNORECASE
            )
            for subject, _relation, obj in matches:
                entities.add(subject)
                entities.add(obj)

        return list(entities)

    def retrieve_entity_communities(self, entity_info, entities):
        """Return the de-duplicated community ids of the given entities.

        Args:
            entity_info: Mapping of entity name to its community ids; ``None`` or
                empty yields an empty result.
            entities: Entity names to look up.
        """
        if not entity_info:
            return []

        community_ids = set()
        for entity in entities:
            if entity in entity_info:
                community_ids.update(entity_info[entity])

        return list(community_ids)

    def generate_answer_from_summary(self, community_summary, query):
        """Ask the LLM to answer ``query`` using a single community summary."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        response = self.llm.chat(messages)
        # Clean the reply that was just received (the original referenced an
        # undefined name here and raised NameError).
        cleaned_response = re.sub(
            r"^assistant:\s*", "", str(response)
        ).strip()
        return cleaned_response

    def aggregate_answers(self, community_answers):
        """Combine the per-community answers into one final response via the LLM."""
        prompt = "Combine the following intermediate answers into a final, concise response."
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers: {community_answers}",
            ),
        ]
        final_response = self.llm.chat(messages)
        cleaned_final_response = re.sub(
            r"^assistant:\s*", "", str(final_response)
        ).strip()
        return cleaned_final_response

In [23]:
from llama_index.core.node_parser import SentenceSplitter

# Split the documents into chunks (chunk_size 1024, chunk_overlap 20) for graph extraction.
splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
nodes = splitter.get_nodes_from_documents(documents)

In [24]:
# Number of chunks the splitter produced from the documents.
len(nodes)

50

In [25]:
# Extraction prompt: asks the model for a JSON object with "entities" and
# "relationships" lists. The placeholders {max_knowledge_triplets} and {text}
# are filled in when the prompt is formatted.
# NOTE(review): the JSON example below contains literal braces; confirm that the
# installed PromptTemplate leaves non-placeholder braces untouched when formatting.
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other

3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.

-An Output Example-
{
  "entities": [
    {
      "entity_name": "Albert Einstein",
      "entity_type": "Person",
      "entity_description": "Albert Einstein was a theoretical physicist who developed the theory of relativity and made significant contributions to physics."
    },
    {
      "entity_name": "Theory of Relativity",
      "entity_type": "Scientific Theory",
      "entity_description": "A scientific theory developed by Albert Einstein, describing the laws of physics in relation to observers in different frames of reference."
    },
    {
      "entity_name": "Nobel Prize in Physics",
      "entity_type": "Award",
      "entity_description": "A prestigious international award in the field of physics, awarded annually by the Royal Swedish Academy of Sciences."
    }
  ],
  "relationships": [
    {
      "source_entity": "Albert Einstein",
      "target_entity": "Theory of Relativity",
      "relation": "developed",
      "relationship_description": "Albert Einstein is the developer of the theory of relativity."
    },
    {
      "source_entity": "Albert Einstein",
      "target_entity": "Nobel Prize in Physics",
      "relation": "won",
      "relationship_description": "Albert Einstein won the Nobel Prize in Physics in 1921."
    }
  ]
}

-Real Data-
######################
text: {text}
######################
output:"""

In [29]:
import json 

def parse_fn(response_str: str) -> Any:
    """Parse the LLM's JSON reply into entity and relationship tuples.

    The first ``{`` through the last ``}`` of ``response_str`` is treated as the
    JSON payload, so surrounding chatter from the model is ignored.

    Args:
        response_str: Raw text returned by the LLM.

    Returns:
        A ``(entities, relationships)`` pair. ``entities`` is a list of
        ``(entity_name, entity_type, entity_description)`` tuples and
        ``relationships`` is a list of
        ``(source_entity, target_entity, relation, relationship_description)``
        tuples. Both lists are empty when no JSON object is found or it cannot
        be decoded; individual items missing a required key are skipped.
    """
    entities = []
    relationships = []

    match = re.search(r"\{.*\}", response_str, re.DOTALL)
    if not match:
        return entities, relationships

    try:
        data = json.loads(match.group(0))
    except json.JSONDecodeError as e:
        print("Error parsing JSON:", e)
        return entities, relationships

    # Anything other than a JSON object cannot carry the two expected keys.
    if not isinstance(data, dict):
        return entities, relationships

    for entity in data.get("entities") or []:
        try:
            entities.append(
                (
                    entity["entity_name"],
                    entity["entity_type"],
                    entity["entity_description"],
                )
            )
        except (KeyError, TypeError):
            # Skip a malformed entity instead of discarding the whole reply.
            continue

    for relation in data.get("relationships") or []:
        try:
            relationships.append(
                (
                    relation["source_entity"],
                    relation["target_entity"],
                    relation["relation"],
                    relation["relationship_description"],
                )
            )
        except (KeyError, TypeError):
            continue

    # Return on the success path too; the original only returned inside the
    # except block, so every successful parse yielded None.
    return entities, relationships


# Extractor wired to the custom JSON prompt and parser; the prompt is told to
# return at most 2 triplets per chunk.
kg_extractor = GraphRAGExtractor(
    llm=llm,
    parse_fn=parse_fn,
    extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
    max_path_per_chunk=2,
)
    

In [30]:
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore 

import os
from getpass import getpass

# Never keep database credentials in the notebook: read them from the environment
# and fall back to an interactive prompt for the password.
graph_store = GraphRAGStore(
    username=os.environ.get("NEO4J_USERNAME", "neo4j"),
    password=os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: "),
    url=os.environ.get("NEO4J_URL", "bolt://localhost:7687"),
)

In [32]:
from llama_index.core import PropertyGraphIndex

# Build the property graph index with the custom extractor.
# The keyword is `kg_extractors` (plural); the misspelled `kg_extractor` is not a
# recognized parameter, so the custom extractor was never applied.
# NOTE(review): the recorded "Could not load OpenAI embedding model" error suggests
# no embed model was configured in that session — run the Settings cell first.
index = PropertyGraphIndex(
    nodes=nodes,
    kg_extractors=[kg_extractor],
    property_graph_store=graph_store,
    show_progress=True,
)

ValueError: 
******
Could not load OpenAI embedding model. If you intended to use OpenAI, please check your OPENAI_API_KEY.
Original error:
No API key found for OpenAI.
Please set either the OPENAI_API_KEY environment variable or openai.api_key prior to initialization.
API keys can be found or created at https://platform.openai.com/account/api-keys

Consider using embed_model='local'.
Visit our documentation for more embedding options: https://docs.llamaindex.ai/en/stable/module_guides/models/embeddings.html#modules
******