In [9]:
import re
import os
import asyncio 
from io import BytesIO
from typing import IO, List, Optional
from dotenv import load_dotenv
from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain.chains import GraphCypherQAChain
from langchain_unstructured import UnstructuredLoader
from langchain_community.vectorstores import Neo4jVector
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.graph_transformers import LLMGraphTransformer
from pydantic import BaseModel, Field
from graphdatascience import GraphDataScience
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import logging

# Populate os.environ from a local .env file; the NEO4J_URI / NEO4J_USERNAME /
# NEO4J_PASSWORD variables used by OptimizeGraph are expected there (presumably
# OPENAI_API_KEY as well).
load_dotenv()

# Root logging at INFO. This surfaces this module's logger.info/warning/error records,
# and also third-party INFO records such as the httpx request lines seen in the outputs.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Structured-output schema for ONE group of entity ids that the LLM judges to be duplicates.
# Documented with `#` comments instead of a docstring: pydantic can surface a class
# docstring as the schema description, which would change what is sent to the LLM.
class DuplicateEntities(BaseModel):
    # Entity ids (as they appear in the candidate list shown to the LLM) to merge into one node.
    entities: List[str] = Field(
        description="Entities that represent the same object or real-world entity and should be merged"
    )

# Top-level structured-output schema: ChatOpenAI(...).with_structured_output(Disambiguate)
# in OptimizeGraph parses the LLM reply into this model.
# Documented with `#` comments instead of a docstring, for the same schema reason as above.
class Disambiguate(BaseModel):
    # One DuplicateEntities per group to merge. A None/empty value is treated by the caller
    # as "nothing to merge".
    # NOTE(review): no default is given, so under pydantic v2 the key is still required
    # (it may be null, but not absent) — confirm this is intended.
    merge_entities: Optional[List[DuplicateEntities]] = Field(
        description="Lists of entities that represent the same object or real-world entity and should be merged"
    )

class OptimizeGraph:
    """Entity de-duplication pipeline for a Neo4j knowledge graph.

    ``optimize_graph()`` runs three steps:

    1. ``add_embeddings_to_nodes`` -- embed ``__Entity__`` nodes from ``id`` and
       ``description`` into their ``embedding`` property;
    2. ``add_relationships_similar`` -- kNN over those embeddings adds in-memory
       ``SIMILAR`` relationships, and each node's weakly connected component over
       ``SIMILAR`` is written back to Neo4j as ``e.wcc``;
    3. ``LLM_finnal_decision_to_merge`` -- name-similar entities inside each ``wcc``
       community are shown to an LLM, and the groups it confirms are merged with
       ``apoc.refactor.mergeNodes``.

    Connection settings come from the ``NEO4J_URI``, ``NEO4J_USERNAME`` and
    ``NEO4J_PASSWORD`` environment variables.
    """

    def __init__(self):
        self.graph_chain = None    # not used by any method of this class
        self.store = None          # Neo4jVector, set by init_vector()
        self.graph = None          # Neo4jGraph, set by init_graph() / _ensure_graph()
        self.merged_entities = []  # id groups accepted for merging by the last merge run

    def init_graph(self):
        """(Re)connect ``self.graph`` to Neo4j using the NEO4J_* environment variables."""
        self.graph = Neo4jGraph(
            url=os.environ["NEO4J_URI"],
            username=os.environ["NEO4J_USERNAME"],
            password=os.environ["NEO4J_PASSWORD"],
        )

    def _ensure_graph(self):
        """Return ``self.graph``, connecting only if no connection exists yet.

        Used inside the pipeline instead of ``init_graph()`` so that one run reuses a
        single connection rather than opening a new one in every method.
        """
        if self.graph is None:
            self.init_graph()
        return self.graph

    def init_vector(self):
        """Attach ``self.store`` to the existing Neo4j vector index named ``"vector"``."""
        self.store = Neo4jVector.from_existing_index(
            OpenAIEmbeddings(),
            url=os.environ["NEO4J_URI"],
            username=os.environ["NEO4J_USERNAME"],
            password=os.environ["NEO4J_PASSWORD"],
            index_name="vector",
        )

    def add_embeddings_to_nodes(self):
        """Store OpenAI embeddings of ``id`` + ``description`` on ``__Entity__`` nodes.

        Delegates to LangChain's ``Neo4jVector.from_existing_graph``. The vectors go to
        the ``embedding`` property, which ``add_relationships_similar`` projects for kNN.
        Credentials are passed explicitly, matching ``init_vector``.
        """
        Neo4jVector.from_existing_graph(
            OpenAIEmbeddings(),
            url=os.environ["NEO4J_URI"],
            username=os.environ["NEO4J_USERNAME"],
            password=os.environ["NEO4J_PASSWORD"],
            node_label="__Entity__",
            text_node_properties=["id", "description"],
            embedding_node_property="embedding",
        )

    def add_relationships_similar(self, similarity_threshold: float = 0.95,
                                  graph_name: str = "entities"):
        """Link near-identical entities and write their component id to ``e.wcc``.

        Projects ``__Entity__`` nodes (with ``embedding``) into a GDS in-memory graph named
        ``graph_name``. It then runs kNN, adding ``SIMILAR`` relationships (property
        ``score``) to the projection for neighbours whose similarity is not below
        ``similarity_threshold``. Finally it writes each node's weakly connected
        component over ``SIMILAR`` back to Neo4j as ``wcc``.

        Any stale projection with the same name is dropped first, and the new one is
        dropped afterwards. This makes the method safe to re-run; otherwise a second
        call fails because the name is already taken in the GDS catalog.

        Args:
            similarity_threshold: kNN ``similarityCutoff``.
            graph_name: name of the temporary GDS projection.
        """
        gds = GraphDataScience(
            os.environ["NEO4J_URI"],
            auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"]),
        )
        try:
            # A projection left behind by an earlier (possibly crashed) run would make
            # gds.graph.project() raise because the name is already in use.
            if gds.graph.exists(graph_name)["exists"]:
                gds.graph.get(graph_name).drop()

            G, _ = gds.graph.project(
                graph_name,                    # Graph name
                "__Entity__",                  # Node projection
                "*",                           # Relationship projection
                nodeProperties=["embedding"],  # Configuration parameters
            )
            try:
                # mutate: SIMILAR edges exist only in the in-memory projection.
                gds.knn.mutate(
                    G,
                    nodeProperties=["embedding"],
                    mutateRelationshipType="SIMILAR",
                    mutateProperty="score",
                    similarityCutoff=similarity_threshold,
                )
                # Only the component id is persisted; the candidate query groups by it.
                gds.wcc.write(
                    G,
                    writeProperty="wcc",
                    relationshipTypes=["SIMILAR"],
                )
            finally:
                # Catalog graphs occupy memory until they are explicitly dropped.
                G.drop()
        finally:
            gds.close()

    def get_potential_duplicate_candidates(self, word_edit_distance: int = 3,
                                           min_id_length: int = 3):
        """Return groups of entity ids that may name the same thing.

        Only nodes with a ``wcc`` component id are considered. Nodes without one (e.g.
        created after the last WCC run) would otherwise all share the ``null`` key and
        be compared as one giant community. Within each community, an entity is grouped
        with the entities whose lower-cased id is fewer than ``word_edit_distance``
        Levenshtein edits away, or whose id is a substring of its own. Groups that
        overlap are unioned, and groups contained in another group are dropped.

        Args:
            word_edit_distance: exclusive upper bound on ``apoc.text.distance``.
            min_id_length: ids of this many characters or fewer are ignored.

        Returns:
            Rows shaped ``{"combinedResult": [id, ...]}``, or ``[]`` when there are none.
        """
        graph = self._ensure_graph()
        rows = graph.query(
            """MATCH (e:`__Entity__`)
            WHERE size(e.id) > $min_length // skip very short ids
              AND e.wcc IS NOT NULL        // only nodes that received a WCC component id
            WITH e.wcc AS community, collect(e) AS nodes, count(*) AS count
            WHERE count > 1
            UNWIND nodes AS node
            // Add text distance
            WITH distinct
            [n IN nodes WHERE apoc.text.distance(toLower(node.id), toLower(n.id)) < $distance 
                        OR node.id CONTAINS n.id | n.id] AS intermediate_results
            WHERE size(intermediate_results) > 1
            WITH collect(intermediate_results) AS results
            // combine groups together if they share elements
            UNWIND range(0, size(results)-1, 1) as index
            WITH results, index, results[index] as result
            WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
                    CASE WHEN index <> index2 AND
                        size(apoc.coll.intersection(acc, results[index2])) > 0
                        THEN apoc.coll.union(acc, results[index2])
                        ELSE acc
                    END
            )) as combinedResult
            WITH distinct(combinedResult) as combinedResult
            // extra filtering
            WITH collect(combinedResult) as allCombinedResults
            UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
            WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
            WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1)
                WHERE x <> combinedResultIndex
                AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
            )
            RETURN combinedResult
            """,
            params={"distance": word_edit_distance, "min_length": min_id_length},
        )
        return rows or []

    @staticmethod
    def _build_entity_resolution_chain():
        """Build the prompt | LLM chain whose output is parsed into ``Disambiguate``."""
        system_prompt = """You are a data processing assistant. Your task is to identify duplicate entities in a list and decide which of them should be merged.
The entities might be slightly different in format or content, but essentially refer to the same thing. Use your analytical skills to determine duplicates.

Here are the rules for identifying duplicates:
1. Entities with minor typographical differences should be considered duplicates.
2. Entities with different formats but the same content should be considered duplicates.
3. Entities that refer to the same real-world object or concept, even if described differently, should be considered duplicates.
4. If it refers to different numbers, dates, or products, do not merge results
"""
        user_template = """
Here is the list of entities to process:
{entities}

Please identify duplicates, merge them, and provide the merged list.
"""
        extraction_llm = ChatOpenAI(model_name="gpt-4o-mini").with_structured_output(
            Disambiguate
        )
        extraction_prompt = ChatPromptTemplate([
            SystemMessage(content=system_prompt),
            ("human", user_template),
        ])
        return extraction_prompt | extraction_llm

    @staticmethod
    def _clean_merge_group(group: List[str], offered: List[str]) -> List[str]:
        """Normalise one LLM-proposed merge group; ``[]`` means "do not merge".

        Duplicate ids are removed, keeping first-seen order. Without this, the
        ``count(e) == len(group)`` existence check rejects the group. The group is
        rejected if it names any id that was not among ``offered`` (the candidates
        actually sent to the LLM), so nothing outside the candidate query can be merged.
        It is also rejected if it has fewer than two distinct ids.
        """
        unique = list(dict.fromkeys(group))
        if len(unique) < 2 or not set(unique) <= set(offered):
            return []
        return unique

    @staticmethod
    def _merge_group(graph, group: List[str]) -> bool:
        """Merge the ``__Entity__`` nodes whose ids are in ``group``; return True on success.

        Existence is checked right before merging, not up front for all groups, because
        an earlier merge in the same run may already have folded one of these nodes away.
        """
        result = graph.query("""
            MATCH (e:__Entity__)
            WHERE e.id IN $ids
            RETURN count(e) as count
        """, params={'ids': group})
        if not result or result[0]['count'] != len(group):
            logger.warning(f"Skipping merge for group {group} due to missing nodes.")
            return False
        try:
            # description: 'combine' keeps every node's description; `.*`: 'discard' keeps
            # the surviving node's value for all other properties (APOC merge strategies).
            graph.query("""
                MATCH (e:__Entity__)
                WHERE e.id IN $candidates
                WITH collect(e) AS nodes
                CALL apoc.refactor.mergeNodes(nodes, {properties: {
                    description:'combine',
                    `.*`: 'discard'
                }})
                YIELD node
                RETURN node
            """, params={"candidates": group})
        except Exception as e:
            logger.error(f"Failed to merge nodes {group}: {e}")
            return False
        return True

    def LLM_finnal_decision_to_merge(self, max_workers: int = 10):
        """Ask an LLM which candidate duplicates are really the same entity, then merge them.

        Candidate groups come from ``get_potential_duplicate_candidates()``. Each group is
        sent to ``gpt-4o-mini`` on a thread pool of ``max_workers`` threads. Each returned
        group is cleaned by ``_clean_merge_group`` and stored in ``self.merged_entities``,
        then merged one by one. Failures of individual LLM calls or merges are logged and
        skipped. (The misspelled method name is kept for existing callers.)
        """
        graph = self._ensure_graph()
        self.merged_entities = []

        candidate_groups = [
            row.get("combinedResult") or []
            for row in self.get_potential_duplicate_candidates()
        ]
        # A group with fewer than two ids cannot contain duplicates; skip the LLM call.
        candidate_groups = [group for group in candidate_groups if len(group) > 1]
        if not candidate_groups:
            logger.info("No entities to merge.")
            return

        extraction_chain = self._build_entity_resolution_chain()

        def entity_resolution(entities: List[str]) -> List[List[str]]:
            result = extraction_chain.invoke({"entities": entities})
            if result and result.merge_entities:
                return [el.entities for el in result.merge_entities]
            logger.warning("Entity resolution returned None or unexpected structure.")
            return []

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Keep the offered ids per future so LLM answers can be checked against them.
            future_to_group = {
                executor.submit(entity_resolution, group): group
                for group in candidate_groups
            }
            for future in tqdm(
                as_completed(future_to_group), total=len(future_to_group), desc="Processing documents"
            ):
                try:
                    proposed = future.result() or []
                except Exception as e:
                    logger.error(f"Error processing future: {e}")
                    continue
                offered = future_to_group[future]
                for group in proposed:
                    cleaned = self._clean_merge_group(group, offered)
                    if cleaned:
                        self.merged_entities.append(cleaned)
                    else:
                        logger.debug(f"Ignoring LLM merge group {group} (offered: {offered})")

        if not self.merged_entities:
            logger.info("No entities to merge.")
            return

        for group in self.merged_entities:
            self._merge_group(graph, group)

    def optimize_graph(self):
        """Run the full pipeline: embed entities, build SIMILAR/WCC communities, LLM-merge."""
        self.init_graph()
        self.add_embeddings_to_nodes()
        self.add_relationships_similar()
        self.LLM_finnal_decision_to_merge()


In [10]:
# Show which Neo4j instance and user the .env points at, but never echo the password:
# notebook outputs are saved with the file, so printing it leaks the credential.
print(os.environ["NEO4J_URI"])
print(os.environ["NEO4J_USERNAME"])
print("NEO4J_PASSWORD is set" if os.environ.get("NEO4J_PASSWORD") else "NEO4J_PASSWORD is NOT set")

neo4j+s://4ee31ee1.databases.neo4j.io
neo4j
<redacted: NEO4J_PASSWORD was printed here — rotate this credential>


In [11]:
# Constructing the pipeline object does no I/O: __init__ only initialises attributes to
# None / []. Neo4j and OpenAI are contacted only by the methods called on it later.
optimize = OptimizeGraph()

In [None]:
# optimize.optimize_graph()

In [12]:
# Run only the LLM merge step. Its candidate query groups entities by the `wcc` property,
# so this assumes add_relationships_similar() (or optimize_graph()) already wrote `wcc`
# in an earlier session; the full-pipeline cell above is commented out.
optimize.LLM_finnal_decision_to_merge()

Processing documents:   0%|          | 0/217 [00:00<?, ?it/s]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing documents:   0%|          | 1/217 [00:00<02:45,  1.31it/s]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing documents:   1%|          | 2/217 [00:00<01:22,  2.60it/s]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing documents:   2%|▏         | 5/217 [00:01<00:29,  7.09it/s]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Proc