In [1]:
import os
import openai
# from openai import OpenAI
# from langchain.chains import RetrievalQA
# from langchain.chat_models import ChatOpenAI

from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.embeddings.openai import OpenAIEmbeddings

from graphdatascience import GraphDataScience

from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Optional
from langchain_core.prompts import ChatPromptTemplate
from retry import retry
from pydantic import BaseModel
from langchain_openai import ChatOpenAI

from graphdatascience import GraphDataScience
from langchain_community.graphs import Neo4jGraph
from neo4j import GraphDatabase

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

import matplotlib.pyplot as plt

import pandas as pd
import tiktoken
import seaborn as sns
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm

For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


# Entity resolution

### NEO4j setup

In [2]:
os.environ["NEO4J_URI"] = os.getenv('NEO4J_URI')
os.environ["NEO4J_USERNAME"] = os.getenv('NEO4J_USERNAME')
os.environ["NEO4J_PASSWORD"] = os.getenv('NEO4J_PASSWORD')
os.environ["OPENAI_API_KEY"] =  os.getenv('OPENAI_API_KEY')
NEO4j_URI = os.getenv('NEO4J_URI')
NEO4j_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4j_PASSWORD = os.getenv('NEO4J_PASSWORD')

graph = Neo4jGraph(url = NEO4j_URI, username=NEO4j_USERNAME, password = NEO4j_PASSWORD, refresh_schema=False)

graph.query("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Chunk) REQUIRE c.id IS UNIQUE")
graph.query("CREATE CONSTRAINT IF NOT EXISTS FOR (c:AtomicFact) REQUIRE c.id IS UNIQUE")
graph.query("CREATE CONSTRAINT IF NOT EXISTS FOR (c:KeyElement) REQUIRE c.id IS UNIQUE")
graph.query("CREATE CONSTRAINT IF NOT EXISTS FOR (d:Document) REQUIRE d.id IS UNIQUE")

[]

<langchain_community.graphs.neo4j_graph.Neo4jGraph at 0x177c950d0>

In [191]:
host = os.getenv('NEO4J_URI')
user = os.getenv('NEO4J_USERNAME')
password = os.getenv('NEO4J_PASSWORD')
driver = GraphDatabase.driver(host, auth=(user, password))

NEO4J_URI = os.getenv('NEO4J_URI')
username = os.getenv('NEO4J_USERNAME')
password = os.getenv('NEO4J_PASSWORD')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

### generate embeddings for key elements

In [192]:
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY),
    url=NEO4J_URI,
    username=username,
    password= password,
    node_label="KeyElement", #["Document","Chunk", "AtomicFact", "KeyElement"],
    text_node_properties=["id"],#["id", "text", "index"], #['name', 'description', 'status'],
    embedding_node_property="embedding", #'embedding'
)

### Graph data science set up

In [193]:
# project graph
gds = GraphDataScience(
    NEO4J_URI,
    auth=(username,password)
)

G, result = gds.graph.project(
    "entities",                   #  Graph name
    "KeyElement",                 #  Node projection
    "*",                          #  Relationship projection
    nodeProperties=["embedding"]  #  Configuration parameters
)

Loading: 100%|██████████| 100.0/100 [00:01<00:00, 85.29%/s] 


### k-nearest neighbors similarity search using the embedding property

In [195]:
similarity_threshold = 0.95

gds.knn.mutate(
  G,
  nodeProperties=['embedding'],
  mutateRelationshipType= 'SIMILAR',
  mutateProperty= 'score',
  similarityCutoff=similarity_threshold
)

K-Nearest Neighbours: 100%|██████████| 100.0/100 [00:04<00:00, 20.75%/s]


ranIterations                                                            35
nodePairsConsidered                                                11996717
didConverge                                                            True
preProcessingMillis                                                       2
computeMillis                                                          5137
mutateMillis                                                            100
postProcessingMillis                                                      0
nodesCompared                                                         13710
relationshipsWritten                                                 102038
similarityDistribution    {'min': 0.9499969482421875, 'p5': 0.9520568847...
configuration             {'mutateProperty': 'score', 'jobId': 'f24cbc9d...
Name: 0, dtype: object

In [196]:
gds.wcc.write(
    G,
    writeProperty="wcc",
    relationshipTypes=["SIMILAR"]
)
     

writeMillis                                                            125
nodePropertiesWritten                                                13710
componentCount                                                        1354
componentDistribution    {'min': 1, 'p5': 1, 'max': 12067, 'p999': 19, ...
postProcessingMillis                                                     7
preProcessingMillis                                                      0
computeMillis                                                           25
configuration            {'writeProperty': 'wcc', 'jobId': 'bced7c1c-09...
Name: 0, dtype: object

### Identifying Duplicate Entities

In [197]:
word_edit_distance = 3
potential_duplicate_candidates = graph.query(
    """MATCH (e:`KeyElement`)
    WHERE size(e.id) > 4 // longer than 4 characters
    WITH e.wcc AS community, collect(e) AS nodes, count(*) AS count
    WHERE count > 1
    UNWIND nodes AS node
    // Add text distance
    WITH distinct
      [n IN nodes WHERE apoc.text.distance(toLower(node.id), toLower(n.id)) < $distance | n.id] AS intermediate_results
    WHERE size(intermediate_results) > 1
    WITH collect(intermediate_results) AS results
    // combine groups together if they share elements
    UNWIND range(0, size(results)-1, 1) as index
    WITH results, index, results[index] as result
    WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
            CASE WHEN index <> index2 AND
                size(apoc.coll.intersection(acc, results[index2])) > 0
                THEN apoc.coll.union(acc, results[index2])
                ELSE acc
            END
    )) as combinedResult
    WITH distinct(combinedResult) as combinedResult
    // extra filtering
    WITH collect(combinedResult) as allCombinedResults
    UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
    WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
    WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1)
        WHERE x <> combinedResultIndex
        AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
    )
    RETURN combinedResult
    """, params={'distance': word_edit_distance})
potential_duplicate_candidates[:5]

[{'combinedResult': ['Dose effect', 'Noise Effect']},
 {'combinedResult': ['Larger σ', 'large K', 'large k0', 'large τ']},
 {'combinedResult': ['Threshold', 'Threshold c', 'Threshold t', 'Thresholds']},
 {'combinedResult': ['Label', 'PATEL', 'Patel']},
 {'combinedResult': ['Fan et al. (1993)', 'Fang et al. (1990)']}]

### Processing Duplicate Entities with LLM

In [199]:
system_prompt = """You are a data processing assistant. Your task is to identify duplicate entities in a list and decide which of them should be merged.
The entities might be slightly different in format or content, but essentially refer to the same thing. Use your analytical skills to determine duplicates.

Here are the rules for identifying duplicates:
1. Entities with minor typographical differences should be considered duplicates.
2. Refrain from merging entities based on equations, as they may have different meanings.
3. Entities that refer to the same real-world object or concept, even if described differently, should be considered duplicates.
4. If it refers to different numbers, dates, or products, do not merge results
5. Mathematical equations must be identical to be considered duplicates.
"""
user_template = """
Here is the list of entities to process:
{entities}

Please identify duplicates, merge them, and provide the merged list.
"""

In [201]:
class DuplicateEntities(BaseModel):
    entities: List[str] = Field(
        description="Entities that represent the same object or real-world entity and should be merged"
    )


class Disambiguate(BaseModel):
    merge_entities: Optional[List[DuplicateEntities]] = Field(
        description="Lists of entities that represent the same object or real-world entity and should be merged"
    )

extraction_llm = ChatOpenAI(model_name="gpt-4o-mini").with_structured_output(
    Disambiguate
)

extraction_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            system_prompt,
        ),
        (
            "human",
            user_template,
        ),
    ]
)



In [202]:

extraction_chain = extraction_prompt | extraction_llm

@retry(tries=3, delay=2)
def entity_resolution(entities: List[str]) -> Optional[List[str]]:
    return [
        el.entities
        for el in extraction_chain.invoke({"entities": entities}).merge_entities
    ]


In [203]:

MAX_WORKERS = 4

merged_entities = []
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Submitting all tasks and creating a list of future objects
    futures = [
        executor.submit(entity_resolution, el['combinedResult'])
        for el in potential_duplicate_candidates
    ]

    for future in tqdm(
        as_completed(futures), total=len(futures), desc="Processing documents"
    ):
        try:
            to_merge = future.result()
            # Ensure to_merge is not None and is iterable
            if to_merge:  # this checks if it's not None or an empty value
                merged_entities.extend(to_merge)
        except Exception as e:
            # Handle any exception raised during the processing of a future
            print(f"Error processing future: {e}")

Processing documents:   4%|▎         | 58/1656 [00:24<13:41,  1.95it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:   4%|▍         | 63/1656 [00:26<11:53,  2.23it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  16%|█▌        | 263/1656 [01:17<06:39,  3.49it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  22%|██▏       | 371/1656 [01:43<05:22,  3.98it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  29%|██▉       | 488/1656 [02:10<04:34,  4.26it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  40%|████      | 670/1656 [02:59<11:24,  1.44it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  44%|████▍     | 735/1656 [03:19<04:55,  3.12it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  52%|█████▏    | 854/1656 [03:53<10:11,  1.31it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  52%|█████▏    | 859/1656 [03:55<04:31,  2.94it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  63%|██████▎   | 1049/1656 [04:40<03:38,  2.78it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  88%|████████▊ | 1458/1656 [06:17<00:41,  4.79it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents:  93%|█████████▎| 1540/1656 [06:37<00:36,  3.18it/s]

Error processing future: 'NoneType' object is not iterable


Processing documents: 100%|██████████| 1656/1656 [07:15<00:00,  3.80it/s]


### Merge entities in the graph using query

In [224]:
# break down merged entities into chunks of 100
merged_entities_chunks = [merged_entities[i:i + 100] for i in range(0, len(merged_entities), 100)]


for i, chunk in enumerate(merged_entities_chunks):
  try:
    graph.query("""
    UNWIND $data AS candidates
    CALL {
      WITH candidates
      MATCH (e:KeyElement) WHERE e.id IN candidates
      RETURN collect(e) AS nodes
    }
    CALL apoc.refactor.mergeNodes(nodes, {properties: {
        description:'combine',
        `.*`: 'discard'
    }})
    YIELD node
    RETURN count(*)
    """, params={"data": chunk})
  except Exception as e:
    print(e)
    print("failed to merge chunk", i)


### Inspect graph

In [225]:
graph.query("""MATCH (e) RETURN count(e)""")

[{'count(e)': 18401}]

In [226]:
graph.query("""MATCH (e:KeyElement) RETURN count(e)""")

[{'count(e)': 10795}]

In [82]:
G

Graph({'graphName': 'entities', 'nodeCount': 11484, 'relationshipCount': 85313, 'database': 'neo4j', 'configuration': {'relationshipProjection': {'__ALL__': {'aggregation': 'DEFAULT', 'orientation': 'NATURAL', 'indexInverse': False, 'properties': {}, 'type': '*'}}, 'readConcurrency': 4, 'relationshipProperties': {}, 'nodeProperties': {}, 'jobId': '359d3bf1-53f2-4d92-982a-6697ff240bcd', 'nodeProjection': {'KeyElement': {'label': 'KeyElement', 'properties': {'embedding': {'property': 'embedding', 'defaultValue': None}}}}, 'logProgress': True, 'validateRelationships': False, 'sudo': False}, 'schema': {'graphProperties': {}, 'nodes': {'KeyElement': {'embedding': 'List of Float (DefaultValue(null), PERSISTENT)'}}, 'relationships': {'__ALL__': {}, 'SIMILAR': {'score': 'Float (DefaultValue(NaN), PERSISTENT, Aggregation.NONE)'}}}, 'memoryUsage': '95 MiB'})

In [None]:
gds.graph.export(G, dbName = "thisisauniquename")

In [None]:
gds.graph.get('entities')