# Building a Knowledge Graph with Small Language Models: A Comparative Approach

This notebook demonstrates how to build a knowledge graph from unstructured text using small, locally run large language models (LLMs). We will explore two approaches:

1.  **The Standard LangChain Approach**: Using the built-in `LLMGraphTransformer` from LangChain. We will see the challenges and limitations, particularly the high failure rate in extracting structured graph data with smaller models like Mistral.
2.  **The BAML-Enhanced Approach**: Leveraging BAML (Boundary AI Markup Language) to significantly improve the reliability and accuracy of graph extraction with a model like Llama 3.

Finally, after successfully generating high-quality graph documents with BAML, we will ingest them into a Neo4j graph database and perform various analyses, including community detection and entity resolution, to uncover insights from the data.

In [1]:
# Magic command to load the autoreload extension.
%load_ext autoreload
# Magic command to automatically reload all modules before executing a cell.
# This is useful for development when you are changing the source code of imported modules.
%autoreload 2

## Part 1: The Standard LangChain Approach with `LLMGraphTransformer`

In this section, we'll start with the standard method for graph extraction provided by LangChain. We will use `LLMGraphTransformer`, which is designed to take a sequence of documents and convert them into graph documents. We'll use the `mistral` model, a popular small LLM, to see how well this approach works out-of-the-box.

In [2]:
# Import necessary libraries from LangChain and other standard packages.
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_ollama import ChatOllama
import pandas as pd
from typing import List

# Import specific data structures for handling graph data.
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Field

In [3]:
# Define the model name to be used.
model = "mistral"

# Initialize the ChatOllama instance.
# This connects to a locally running Ollama service to use the specified model.
# We set a very low temperature to make the model's output more deterministic and less random.
llm = ChatOllama(model=model, temperature=0.001)

### Testing Structured Output with Pydantic

Before diving into the complex task of graph extraction, let's test the LLM's ability to produce structured output using a simpler Pydantic model. LangChain's `.with_structured_output` method allows us to specify a schema (like a Pydantic class) that the LLM should conform to. This is a good way to gauge the model's instruction-following capabilities.

In [4]:
# Define a Pydantic model named 'Joke'.
# This class specifies the desired structure for a joke, with a 'setup' and a 'punchline'.
# The Field descriptions guide the LLM on what content to generate for each attribute.
class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")

In [5]:
# Bind the LLM to the structured output format defined by the Joke class.
# This creates a new LangChain Runnable that asks the LLM (via tool calling, as the repr below shows)
# for output matching the Joke schema. Nothing is sent to the model until the Runnable is invoked.
llm.with_structured_output(Joke)


RunnableBinding(bound=ChatOllama(model='mistral', temperature=0.001), kwargs={'tools': [{'type': 'function', 'function': {'name': 'Joke', 'description': '', 'parameters': {'type': 'object', 'properties': {'setup': {'description': 'The setup of the joke', 'type': 'string'}, 'punchline': {'description': 'The punchline to the joke', 'type': 'string'}}, 'required': ['setup', 'punchline']}}}], 'tool_choice': 'any'}, config={}, config_factories=[])
| PydanticToolsParser(first_tool_only=True, tools=[<class '__main__.Joke'>])
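
As a rough illustration of what schema conformance means once a response arrives, the sketch below checks a raw JSON string against the two required `Joke` fields using only the standard library. The `JokeRecord` class, the `parse_joke` helper, and the sample strings are hypothetical stand-ins; LangChain's own parser does considerably more than this.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class JokeRecord:
    """Plain-Python stand-in for the Pydantic `Joke` model (illustrative only)."""
    setup: str
    punchline: str


def parse_joke(raw: str) -> Optional[JokeRecord]:
    """Return a JokeRecord if `raw` is JSON with string `setup` and `punchline` fields."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None  # not JSON at all, e.g. the model answered in prose
    if not isinstance(data, dict):
        return None
    setup, punchline = data.get("setup"), data.get("punchline")
    if isinstance(setup, str) and isinstance(punchline, str):
        return JokeRecord(setup=setup, punchline=punchline)
    return None  # a required field is missing or has the wrong type


# Hypothetical model responses: one well-formed, one missing a field, one free text.
good = parse_joke('{"setup": "Why did the graph cross the road?", "punchline": "To reach the other node."}')
missing = parse_joke('{"setup": "Why did the graph cross the road?"}')
prose = parse_joke("Sure! Here is a joke for you...")
print(good, missing, prose)
```

Only the first response yields a record; the other two return `None`, which is the kind of outcome a small model tends to produce more often as the schema grows.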

### Loading and Preparing the Dataset

We will use a dataset of news articles. To manage the LLM's context window and processing time, we need to know how large each article is. We'll use `tiktoken` to count the tokens in each article. The `gpt-4o` encoding used here is not the tokenizer of Mistral or Llama 3, so the counts are only a proxy for the input size.

In [6]:
# Import pandas for data manipulation and tiktoken for token counting.
import pandas as pd
import tiktoken

# Function to calculate the number of tokens in a string for a given model.
def num_tokens_from_string(string: str, model: str = "gpt-4o") -> int:
    """Returns the number of tokens in a text string."""
    # Get the encoding for the specified model.
    encoding = tiktoken.encoding_for_model(model)
    # Encode the string and return the length of the resulting token list.
    num_tokens = len(encoding.encode(string))
    return num_tokens

# Load the news articles dataset from a remote CSV file.
news = pd.read_csv(
    "https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv"
)
# Calculate the token count for each article by combining its title and text.
news["tokens"] = [
    num_tokens_from_string(f"{row['title']} {row['text']}")
    for i, row in news.iterrows()
]
# Display the first few rows of the DataFrame to inspect the data and token counts.
news.head()

| Unnamed: 0 | title | date | text | tokens |
|---|---|---|---|---|
| 0 | Chevron: Best Of Breed | 2031-04-06T01:36:32.000000000+00:00 | JHVEPhoto Like many companies in the O&G secto... | 78 |
| 1 | FirstEnergy (NYSE:FE) Posts Earnings Results | 2030-04-29T06:55:28.000000000+00:00 | FirstEnergy (NYSE:FE – Get Rating) posted its ... | 130 |
| 2 | Dáil almost suspended after Sinn Féin TD put p... | 2023-06-15T14:32:11.000000000+00:00 | The Dáil was almost suspended on Thursday afte... | 631 |
| 3 | Epic’s latest tool can animate hyperrealistic ... | 2023-06-15T14:00:00.000000000+00:00 | Today, Epic is releasing a new tool designed t... | 528 |
| 4 | EU to Ban Huawei, ZTE from Internal Commission... | 2023-06-15T13:50:00.000000000+00:00 | The European Commission is planning to ban equ... | 281 |


In [7]:
# Import plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns

# Create a histogram of the token counts to visualize the distribution of article lengths.
sns.histplot(news["tokens"], kde=False)
plt.title("Distribution of chunk sizes")
plt.xlabel("Token count")
plt.ylabel("Frequency")
plt.show()

### Extracting Graph Documents

Now we'll configure the `LLMGraphTransformer`. This component takes an LLM and prompts it to extract nodes and relationships from text. We specify that we want a `description` property for both nodes and relationships to capture more context.

To speed up the process, we'll run the extraction on multiple articles concurrently using a `ThreadPoolExecutor`.

In [8]:
# Initialize the LLMGraphTransformer with our local LLM.
llm_transformer = LLMGraphTransformer(
    llm=llm, # The language model to use for extraction.
    # Define the properties to extract for each node. Here, we only want a 'description'.
    node_properties=["description"],
    # Define the properties to extract for each relationship.
    relationship_properties=["description"]
)

In [9]:
# Define a helper function to process a single piece of text.
def process_text(text: str) -> List[GraphDocument]:
    # Create a LangChain Document object from the input text.
    doc = Document(page_content=text)
    # Use the transformer to convert the document into a list of GraphDocument objects.
    return llm_transformer.convert_to_graph_documents([doc])

In [10]:
# Import libraries for concurrent execution and progress tracking.
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# Set the maximum number of worker threads.
MAX_WORKERS = 10
# Set the number of articles to process for this experiment.
NUM_ARTICLES = 20
# Initialize an empty list to store the results.
graph_documents_mistral = []

# Use a ThreadPoolExecutor for parallel processing.
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Submit all text processing tasks to the executor.
    futures = [
        executor.submit(process_text, f"{row['title']} {row['text']}")
        for i, row in news.head(NUM_ARTICLES).iterrows()
    ]
    
    # Iterate through the completed futures as they finish and show a progress bar.
    for future in tqdm(
        as_completed(futures), total=len(futures), desc="Processing documents"
    ):
        # Get the result from the completed future.
        graph_document = future.result()
        # Extend the main list with the newly processed graph documents.
        graph_documents_mistral.extend(graph_document)


Processing documents: 100%|██████████| 20/20 [01:32<00:00,  4.64s/it]


In [11]:
# Display the first result to inspect the output.
# Note that for many articles, the nodes and relationships lists are empty, indicating a failure.
graph_documents_mistral[:1]

[GraphDocument(nodes=[], relationships=[], source=Document(metadata={}, page_content='XPeng Stock Rises. The Tesla Rival Rolled Out Self-Driving Tech. Chinese electric-vehicle maker\nXPeng\nsaid Thursday its assisted-driving technology has been launched in Beijing and three other cities. The\nTesla\nrival’s stock was rising in premarket trading.'))]

### Analyzing the Extraction Failure Rate

A key indicator of success is whether the LLM was able to extract any nodes and relationships. If a `GraphDocument` object has no nodes, it means the LLM failed to parse the text and return the structured data in the expected format. Let's calculate the percentage of these failures.

In [12]:
# Initialize a counter for failed extractions (empty graph documents).
empty_count = 0

# Iterate through all the processed documents.
for doc in graph_documents_mistral:
    # Check if the 'nodes' list is empty.
    if not doc.nodes:
        # Increment the counter if no nodes were extracted.
        empty_count += 1

In [13]:
# Calculate and print the percentage of documents for which the LLM failed to extract a graph.
print(f"Percentage missing: {empty_count/len(graph_documents_mistral)*100}")

Percentage missing: 75.0


As we can see, the standard `LLMGraphTransformer` with a small model like `mistral` has a **75% failure rate** on this 20-article sample. This is far too high for building a reliable knowledge graph. The small LLM struggles to consistently adhere to the complex JSON format required by the transformer's internal prompt.

This highlights the need for a more robust method to guide the LLM's output.
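
Since the same failure-rate calculation is needed again for the second approach, it can be wrapped in a small helper. The sketch below is one way to do that; `FakeGraphDocument` is a hypothetical stand-in that only mimics the `nodes` attribute, and the guard for an empty list is an addition that the cells above do not have.

```python
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class FakeGraphDocument:
    """Hypothetical stand-in exposing only the `nodes` attribute used here."""
    nodes: List[str] = field(default_factory=list)


def failure_rate(docs: Sequence) -> float:
    """Percentage of documents whose `nodes` list is empty (0.0 if there are no documents)."""
    if not docs:
        return 0.0  # avoid dividing by zero when nothing was processed
    empty = sum(1 for doc in docs if not doc.nodes)
    return empty / len(docs) * 100


# Three of these four stand-in documents have no nodes.
sample = [
    FakeGraphDocument(),
    FakeGraphDocument(["Tesla"]),
    FakeGraphDocument(),
    FakeGraphDocument(),
]
print(f"Percentage missing: {failure_rate(sample):.1f}")
```

In the notebook, the same function could be called on `graph_documents_mistral` directly, since it only reads `doc.nodes`.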

## Part 2: The BAML-Enhanced Approach

To address the high failure rate, we now turn to BAML (Boundary AI Markup Language). BAML is a configuration language that helps bridge the gap between natural language instructions and structured, typed outputs from LLMs. It allows us to define functions with clear input/output types, use Jinja for templating prompts, and set up robust parsing and retries, making it ideal for reliable structured data extraction.

In this part, we will use BAML with the `llama3` model to perform the same graph extraction task.
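
To make the idea of tolerant parsing concrete, here is a deliberately simplified sketch of why strict JSON parsing fails on chatty model output and how a more forgiving step can recover the payload. This is not how BAML's parser works internally; `extract_json_object` and the sample response are hypothetical and only illustrate the general problem.

```python
import json
import re
from typing import Any, Optional

FENCE = "`" * 3  # Markdown code-fence marker that models often wrap JSON in


def extract_json_object(raw: str) -> Optional[Any]:
    """Best-effort recovery of a single JSON object from chatty model output."""
    # Drop Markdown fence markers (with an optional "json" tag) if present.
    cleaned = re.sub(re.escape(FENCE) + r"(?:json)?", "", raw)
    # Keep only the span from the first "{" to the last "}".
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(cleaned[start : end + 1])
    except json.JSONDecodeError:
        return None


# Hypothetical response from a small model: valid JSON buried in extra text.
raw = (
    "Sure, here is the graph you asked for:\n"
    f"{FENCE}json\n"
    '{"nodes": [{"id": "XPeng", "type": "Organization"}], "relationships": []}\n'
    f"{FENCE}\n"
    "Let me know if you need anything else."
)

try:
    json.loads(raw)
    strict_ok = True
except json.JSONDecodeError:
    strict_ok = False

lenient = extract_json_object(raw)
print(strict_ok, lenient)
```

The strict parse fails on the surrounding prose, while the lenient step recovers the object. A real solution also has to handle type coercion and partially malformed JSON, which this sketch does not attempt.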

In [14]:
# Import the BAML client, which is the interface to our BAML-defined functions.
# This client is auto-generated from the `.baml` files in the `baml_src` directory.
import baml_client as client

In [15]:
# Import additional libraries for the BAML approach.
from langchain_core.prompts import ChatPromptTemplate

# Import the custom system prompt defined for our graph extraction task.
from prompts.graphragprompts import system_prompt

In [16]:
# Define the model name. For this improved approach, we'll use Llama 3.
model = "llama3"

# Initialize the ChatOllama instance for Llama 3.
llm = ChatOllama(model=model, temperature=0.001)

### Defining the BAML-Powered LangChain Runnable

Here, we define a set of functions that will be chained together to create our final graph extraction pipeline. The Python wrappers around our BAML functions are decorated with `@chain`, which turns them into LangChain Runnables that can be composed in a sequence.

- **Formatting Functions (`_format_nodes`, `_format_relationships`)**: Utility functions that standardize the LLM output: node IDs are title-cased, node types are capitalized, and relationship types are upper-cased with spaces replaced by underscores.
- **`get_graph`**: This is an async BAML function (`b.ExtractGraph`) that takes the raw text and calls the LLM to extract nodes and relationships based on the schemas defined in our `.baml` files.
- **`get_entities`**: This function is used later for entity resolution, calling another BAML function (`b.ExtractDeDupe`) to merge similar entities.
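
The string normalization applied by the formatting helpers can be previewed in isolation. The three small functions below are hypothetical wrappers around the same built-in string methods (`str.title`, `str.capitalize`, `str.replace`, `str.upper`) and exist only for illustration.

```python
def normalize_node_id(node_id: str) -> str:
    """Title-case a node ID, the same string method `_format_nodes` applies to IDs."""
    return node_id.title()


def normalize_node_type(node_type: str) -> str:
    """Capitalize a node type: first letter upper-case, the rest lower-case."""
    return node_type.capitalize()


def normalize_relationship_type(rel_type: str) -> str:
    """Upper-case a relationship type and replace spaces with underscores."""
    return rel_type.replace(" ", "_").upper()


print(normalize_node_id("elon musk"))           # Elon Musk
print(normalize_node_id("OPEN AI"))             # Open Ai
print(normalize_node_type("ORGANIZATION"))      # Organization
print(normalize_relationship_type("works at"))  # WORKS_AT
print(normalize_node_id("mcdonald's"))          # Mcdonald'S
```

The last two ID examples show the trade-off: title-casing makes IDs consistent, but it also lowercases acronyms and capitalizes the letter after an apostrophe, which can itself create near-duplicates that entity resolution has to clean up.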

In [17]:
# Import necessary components for building the chain.
from typing import Any
from baml_client.async_client import b
from langchain_core.runnables import chain
from langchain_experimental.graph_transformers.llm import create_simple_model

# Helper function to format the extracted nodes consistently.
# It title-cases string IDs and capitalizes the type for standardization.
def _format_nodes(nodes: List[Node]) -> List[Node]:
    return [
        Node(
            id=el.id.title() if isinstance(el.id, str) else el.id,
            type=el.type.capitalize() if el.type else None,
            properties=(
                el.properties.dict()
                if hasattr(el.properties, "dict")
                else el.properties
            ),
        )
        for el in nodes
    ]

# Helper function to map the BAML relationship object to the base LangChain Relationship object.
def map_to_base_relationship(rel: Any) -> Relationship:
    """Map the BAML Relationship to the base LangChain Relationship."""
    source = Node(id=rel.source_node_id, type=rel.source_node_type)
    target = Node(id=rel.target_node_id, type=rel.target_node_type)
    properties = {}
    if hasattr(rel, "properties") and rel.properties:
        properties = rel.properties.model_dump()
    return Relationship(
        source=source, target=target, type=rel.type, properties=properties
    )

# Helper function to format a list of relationships.
# It standardizes the source/target nodes and the relationship type (e.g., replacing spaces with underscores).
def _format_relationships(rels) -> List[Relationship]:
    relationships = [
        map_to_base_relationship(rel)
        for rel in rels
        if rel.type and rel.source_node_id and rel.target_node_id
    ]
    return [
        Relationship(
            source=_format_nodes([el.source])[0],
            target=_format_nodes([el.target])[0],
            type=el.type.replace(" ", "_").upper(),
            properties=(
                el.properties.dict()
                if hasattr(el.properties, "dict")
                else el.properties
            ),
        )
        for el in relationships
    ]

# Decorate the function with @chain to make it a LangChain Runnable.
# This function calls the BAML `ExtractGraph` function asynchronously.
@chain
async def get_graph(message):
    graph = await b.ExtractGraph(graph=message.content)
    return graph

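# Define another BAML-powered chain for deduplicating entities.
# This will be used later in the notebook for entity resolution.
# Note: `b` above is the async client, so calling it without `await` would return a coroutine
# and `.merged_results` would fail. This synchronous chain uses the sync client instead
# (assuming the generated `baml_client` package includes `sync_client`).
from baml_client.sync_client import b as sync_b

@chain
def get_entities(message):
    entities = sync_b.ExtractDeDupe(graph=message.content)
    return entities.merged_results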

### Assembling the LangChain Pipeline

With our components defined, we now create the prompt template and chain them together into a single pipeline. The flow will be:

1.  **Prompt**: Format the input text using our custom prompt template.
2.  **LLM**: Send the formatted prompt to the `llama3` model.
3.  **BAML Parser (`get_graph`)**: Take the raw LLM output and use BAML's robust parsing and type-checking to convert it into a structured `Graph` object.

In [18]:
# Create a ChatPromptTemplate from a system and human message.
default_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # The system prompt contains the detailed instructions and schema for the LLM.
            system_prompt,
        ),
        (
            "human",
            (
                "Tip: Make sure to answer in the correct format and do "
                "not include any explanations. "
                "Use the given format to extract information from the "
                "following input: {input}"
            ),
        ),
    ]
)

In [19]:
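# Define the complete extraction chain using the LangChain Expression Language (LCEL) pipe operator.
# Note: this rebinds the name `chain`, shadowing the `@chain` decorator imported earlier.
# Re-run that import before re-executing any cell that uses `@chain`.
chain = default_prompt | llm | get_graph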

### Asynchronous Processing with BAML

The BAML client we imported is asynchronous, which suits I/O-bound work such as issuing many LLM calls. We define async helper functions that process a list of documents concurrently with Python's `asyncio` library, which is typically faster than processing them one by one.

In [20]:
# Import the asyncio library for asynchronous programming.
import asyncio
from typing import List, Optional, Sequence

# Define an async function to process a single response.
async def aprocess_response(document: Document) -> GraphDocument:
    # Asynchronously invoke our chain with the document's content.
    resp = await chain.ainvoke({"input": document.page_content})
    # Format the structured response from BAML into a LangChain GraphDocument.
    return GraphDocument(
        nodes=_format_nodes(resp.nodes),
        relationships=_format_relationships(resp.relationships),
        source=document,
    )

# Define an async function to convert a sequence of documents to GraphDocuments.
async def aconvert_to_graph_documents(
    documents: Sequence[Document],
) -> List[GraphDocument]:
    # Create a list of async tasks, one for each document.
    tasks = [asyncio.create_task(aprocess_response(document)) for document in documents]
    # Run all tasks concurrently and wait for them to complete.
    results = await asyncio.gather(*tasks)
    return results

# Define a top-level async function to process a list of raw text strings.
async def aprocess_text(texts: List[str]) -> List[GraphDocument]:
    # Convert raw texts to LangChain Document objects.
    docs = [Document(page_content=text) for text in texts]
    # Call the conversion function to get the final graph documents.
    graph_docs = await aconvert_to_graph_documents(docs)
    return graph_docs
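
`asyncio.gather` as used above starts every task at once and raises as soon as one of them fails. The sketch below shows one possible variation, under stated assumptions: a semaphore caps the number of calls in flight, and `return_exceptions=True` keeps the successful results when some calls fail. `fake_extract` and `process_all` are hypothetical names, and the fake call merely simulates latency rather than contacting a model.

```python
import asyncio
from typing import List, Union


async def fake_extract(text: str) -> str:
    """Hypothetical stand-in for an LLM call; fails on empty input."""
    await asyncio.sleep(0.01)  # simulate network or inference latency
    if not text:
        raise ValueError("empty document")
    return text.upper()


async def process_all(
    texts: List[str], max_concurrency: int = 4
) -> List[Union[str, BaseException]]:
    """Process all texts concurrently with at most `max_concurrency` calls in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def guarded(text: str) -> str:
        async with semaphore:
            return await fake_extract(text)

    # return_exceptions=True returns failures as values instead of raising,
    # so one bad document does not discard the results of the others.
    return await asyncio.gather(*(guarded(t) for t in texts), return_exceptions=True)


# In a notebook, use `results = await process_all([...])` instead of asyncio.run.
results = asyncio.run(process_all(["alpha", "", "gamma"]))
successes = [r for r in results if not isinstance(r, BaseException)]
failures = [r for r in results if isinstance(r, BaseException)]
print(successes, failures)
```

Results come back in input order, so each failure can still be traced to the document that caused it.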

#### Testing the BAML Chain
Let's run a quick test on a simple sentence to see the BAML-powered chain in action.

In [21]:
# Run an async process on a single test document.
temp = await aprocess_response(Document(page_content="elon musk sued open ai"))
# Print the resulting GraphDocument.
print(temp)

nodes=[Node(id='Elon_Musk', type='Person', properties={'description': 'entrepreneur and business magnate'}), Node(id='Open_Ai', type='Organization', properties={'description': 'artificial intelligence research laboratory'})] relationships=[Relationship(source=Node(id='Elon_Musk', type='Person', properties={}), target=Node(id='Open_Ai', type='Organization', properties={}), type='PROTEST', properties={'description': None})] source=Document(metadata={}, page_content='elon musk sued open ai')


### Running Full Extraction with BAML

Now, we'll run the extraction process on a larger set of articles. We process the articles in chunks to manage memory and API call volume. The async nature of our pipeline allows for efficient processing of these chunks.

In [22]:
# Re-initialize the list for the new, successful results.
graph_documents = []
# Set a larger number of articles to process. (Note: reduced for demonstration purposes).
NUM_ARTICLES = 36
news_subset = news.head(NUM_ARTICLES)
titles = news_subset["title"]
texts = news_subset["text"]
# Define the size of each processing chunk.
chunk_size = 4

# Loop through the articles in chunks.
for i in tqdm(range(0, len(titles), chunk_size), desc="Processing Chunks"):
    # Get the current chunk of titles and texts.
    title_chunk = titles[i : i + chunk_size]
    text_chunk = texts[i : i + chunk_size]
    # Combine title and text for each article in the chunk.
    combined_docs = [f"{title} {text}" for title, text in zip(title_chunk, text_chunk)]

    # Use a try-except block to handle potential errors during the async processing.
    try:
        # Run the async text processing function.
        docs = await aprocess_text(combined_docs)
        graph_documents.extend(docs)
    except Exception as e:
        print(f"\n\n*************Error, Request failed: {str(e)}\n\n")

    print(f"--- End of Chunk {i} ---")

Processing Chunks:  11%|█         | 1/9 [00:15<02:02, 15.34s/it]

--- End of Chunk 0 ---


Processing Chunks:  22%|██▏       | 2/9 [00:30<01:43, 14.85s/it]

--- End of Chunk 4 ---


In [23]:
# Check the number of graph documents successfully created.
len(graph_documents)

36

### Analyzing the BAML-Enhanced Failure Rate

Let's perform the same failure analysis on the results from our BAML pipeline. We expect a dramatic improvement.

In [24]:
# Reset the counter.
empty_count = 0

# Iterate through the graph documents generated by the BAML chain.
for doc in graph_documents:
    # Check for empty node lists.
    if not doc.nodes:
        empty_count += 1
# Calculate and print the new, lower failure rate.
print(f"Percentage missing: {empty_count/len(graph_documents)*100}")

Percentage missing: 0.0


The results are clear: the failure rate has dropped from **75%** to **0%**. Note that the two runs differ in more than the parsing layer (Mistral on 20 articles versus Llama 3 on 36), so part of the gain may come from the model change. Even so, the result suggests that BAML's structure, prompting, and error handling help even small LLMs perform complex structured data extraction reliably.

With a high-quality set of graph documents, we can now proceed to build and analyze our knowledge graph.

### Saving Our Progress

Since the extraction process can be time-consuming, it's a good practice to save the generated `graph_documents`. We'll use `pickle` to serialize and save the list to a file, so we can easily load it back later without re-running the extraction.

In [25]:
# Import the pickle and os libraries for saving the results.
import pickle
import os

# Ensure the 'data' directory exists.
os.makedirs("data", exist_ok=True)

# Save the list of graph documents to a pickle file.
# 'wb' mode opens the file for writing in binary format.
with open("data/graph_documents.pkl", "wb") as f:
    pickle.dump(graph_documents, f)
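
Loading the file back is the mirror image of saving it. The sketch below round-trips a small stand-in list through a temporary directory so that it runs anywhere; in the notebook you would open `data/graph_documents.pkl` in `"rb"` mode instead. The demo data is hypothetical.

```python
import os
import pickle
import tempfile

# Stand-in data; in the notebook this would be the list of GraphDocument objects.
graph_documents_demo = [{"nodes": ["Elon Musk", "OpenAI"], "relationships": ["SUED"]}]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "graph_documents.pkl")

    # 'wb' writes the serialized bytes to disk.
    with open(path, "wb") as f:
        pickle.dump(graph_documents_demo, f)

    # 'rb' reads them back and reconstructs the original Python objects.
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored == graph_documents_demo)
```

Unpickling executes code embedded in the file, so only load pickle files that you created yourself or otherwise trust.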

## Part 3: Building and Analyzing the Knowledge Graph in Neo4j

In this final section, we will take our successfully extracted `GraphDocument` objects and ingest them into a Neo4j database. Once the data is in a native graph format, we can use powerful graph algorithms and queries to explore connections, identify communities, and derive meaningful insights.

In [26]:
# Import necessary libraries for interacting with Neo4j.
from langchain_community.graphs import Neo4jGraph

# Set environment variables for Neo4j connection.
# Replace with your actual credentials and database name.
os.environ["NEO4J_URI"] = "bolt://localhost:7687"
os.environ["NEO4J_USERNAME"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = "password" # CHANGE THIS TO YOUR PASSWORD
os.environ["DATABASE"] = "graphragdemo"

# Initialize the Neo4jGraph instance which provides an interface to the database.
graph = Neo4jGraph(
    url=os.environ["NEO4J_URI"],
    username=os.environ["NEO4J_USERNAME"],
    password=os.environ["NEO4J_PASSWORD"],
    database=os.environ["DATABASE"],
)

### Ingesting Graph Documents

We use the `.add_graph_documents()` method to populate our Neo4j database. This method intelligently merges nodes with the same ID and creates relationships between them, effectively building the graph structure from our extracted data.

In [27]:
# Clear the existing graph to start fresh.
graph.query("MATCH (n) DETACH DELETE n")

# Add the generated graph documents to the Neo4j graph.
# `baseEntityLabel=True` adds a `__Entity__` label to all nodes for easier querying.
# `include_source=True` creates a `Document` node for each source article and links entities to it.
graph.add_graph_documents(graph_documents, baseEntityLabel=True, include_source=True)

### Graph Analytics with GraphDataScience Library

With our data in Neo4j, we can analyze it with Cypher queries and, further below, with the Graph Data Science (GDS) library. We'll start by exploring the basic properties of our graph using plain Cypher.

#### Entity Count vs. Token Count
Let's see if there is a correlation between the length of an article (token count) and the number of entities extracted from it.

In [28]:
# Query the graph to get the text and entity count for each document.
entity_dist = graph.query(
    """
MATCH (d:Document)
RETURN d.text AS text,
       count {(d)-[:MENTIONS]->()} AS entity_count
"""
)
# Convert the query result into a pandas DataFrame.
entity_dist_df = pd.DataFrame(entity_dist)
# Calculate the token count for each document's text.
entity_dist_df["token_count"] = [
    num_tokens_from_string(str(el)) for el in entity_dist_df["text"]
]
# Create a scatter plot with a regression line to visualize the relationship.
sns.lmplot(
    x="token_count", y="entity_count", data=entity_dist_df, line_kws={"color": "red"}
)
plt.title("Entity Count vs Token Count Distribution")
plt.xlabel("Token Count")
plt.ylabel("Entity Count")
plt.show()

#### Node Degree Distribution

Node degree is a measure of how many connections a node has. Analyzing its distribution helps us understand the overall structure of the graph. A power-law distribution (a long tail), which is common in real-world networks, would indicate the presence of a few highly connected hub nodes.

In [29]:
# Import numpy for numerical operations.
import numpy as np

# Query the graph to get the degree of each entity node.
# The `[:!MENTIONS]` syntax excludes the MENTIONS relationship from the count.
degree_dist = graph.query(
    """
MATCH (e:__Entity__)
RETURN count {(e)-[:!MENTIONS]-()} AS node_degree
"""
)
degree_dist_df = pd.DataFrame.from_records(degree_dist)

# Calculate descriptive statistics for the node degrees.
mean_degree = np.mean(degree_dist_df["node_degree"])
percentiles = np.percentile(degree_dist_df["node_degree"], [25, 50, 75, 90])

# Create a histogram to visualize the distribution.
plt.figure(figsize=(12, 6))
sns.histplot(degree_dist_df["node_degree"], bins=50, kde=False, color="blue")
# Use a logarithmic scale on the y-axis to better visualize the long tail.
plt.yscale("log")
plt.title("Node Degree Distribution")
plt.xlabel("Node Degree")
plt.ylabel("Count (log scale)")

# Add vertical lines for mean and percentiles to the plot.
plt.axvline(
    mean_degree,
    color="red",
    linestyle="dashed",
    linewidth=1,
    label=f"Mean: {mean_degree:.2f}",
)
plt.axvline(
    percentiles[0],
    color="purple",
    linestyle="dashed",
    linewidth=1,
    label=f"25th Percentile: {percentiles[0]:.2f}",
)
plt.axvline(
    percentiles[1],
    color="orange",
    linestyle="dashed",
    linewidth=1,
    label=f"50th Percentile: {percentiles[1]:.2f}",
)
plt.axvline(
    percentiles[2],
    color="yellow",
    linestyle="dashed",
    linewidth=1,
    label=f"75th Percentile: {percentiles[2]:.2f}",
)
plt.axvline(
    percentiles[3],
    color="brown",
    linestyle="dashed",
    linewidth=1,
    label=f"90th Percentile: {percentiles[3]:.2f}",
)
plt.legend()
plt.show()
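
For intuition about what the degree query returns, the following sketch computes node degrees from a small, hypothetical edge list in plain Python. Each relationship counts once for both of its endpoints, which matches the undirected pattern `(e)-[...]-()` in the Cypher above.

```python
from collections import Counter
from statistics import mean, median

# Hypothetical relationships between entities, given as (source, target) pairs.
edges = [("A", "B"), ("A", "C"), ("A", "D"), ("B", "C"), ("E", "A")]

# Count every edge for both endpoints, ignoring direction.
degree = Counter()
for source, target in edges:
    degree[source] += 1
    degree[target] += 1

degrees = sorted(degree.values())
print(dict(degree))
print(f"mean={mean(degrees)}, median={median(degrees)}, max={max(degrees)}")
```

Node `A` acts as a hub here: its degree is twice the mean, which is the kind of long-tail pattern the histogram is meant to reveal at scale.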

In [30]:
# Query to check the number of nodes and relationships with non-null descriptions.
# This helps assess the richness of the extracted information.
graph.query(
    """
MATCH (n:`__Entity__`)
RETURN "node" AS type,
       count(*) AS total_count,
       count(n.description) AS non_null_descriptions
UNION ALL
MATCH (n)-[r:!MENTIONS]->()
RETURN "relationship" AS type,
       count(*) AS total_count,
       count(r.description) AS non_null_descriptions
"""
)

[{'type': 'node', 'total_count': 1879, 'non_null_descriptions': 962},
 {'type': 'relationship', 'total_count': 1450, 'non_null_descriptions': 358}]

### Generating and Storing Node Embeddings

To perform more advanced graph algorithms like similarity search and community detection, we need to represent our nodes numerically. We'll generate a vector embedding for each entity node with the `llama3` model, served through Ollama. These embeddings capture the semantic meaning of the node's ID and description. The `Neo4jVector` class simplifies this process: `from_existing_graph` generates embeddings for nodes that don't have them yet and stores them back in the database.

In [31]:
# Import libraries for vector stores and the GDS library.
from langchain_community.vectorstores import Neo4jVector
from langchain_ollama import OllamaEmbeddings
from graphdatascience import GraphDataScience

# Initialize the GDS client, pointing to the correct database.
gds = GraphDataScience(
    os.environ["NEO4J_URI"],
    auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"]),
)
gds.set_database(os.environ["DATABASE"])

# Initialize the embeddings model.
embeddings = OllamaEmbeddings(model="llama3")

# Initialize Neo4jVector from the existing graph.
# It will use the 'id' and 'description' properties to generate embeddings.
# The generated vectors will be stored in the 'embedding' property of the nodes.
vector = Neo4jVector.from_existing_graph(
    embeddings,
    node_label="__Entity__",
    text_node_properties=["id", "description"],
    embedding_node_property="embedding",
    database=os.environ["DATABASE"],
)

print("Embedding update complete.")

Embedding update complete.


### Finding Similar Nodes with k-Nearest Neighbors (KNN)

With embeddings in place, we can find semantically similar nodes. We'll use the KNN algorithm in GDS to create `SIMILAR` relationships between nodes whose embeddings are close to each other (above a certain cosine similarity threshold). This helps enrich the graph by adding inferred connections.
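
The thresholding idea can be sketched without a database. The example below computes cosine similarity between a few made-up three-dimensional vectors and keeps the pairs above the cutoff. It is a brute-force simplification: the GDS KNN algorithm is approximate and also limits each node to its top-K neighbours, and real embeddings have far more dimensions. All names and vectors here are hypothetical.

```python
import math
from itertools import combinations
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either vector is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def similar_pairs(
    vectors: Dict[str, List[float]], cutoff: float
) -> List[Tuple[str, str]]:
    """Return all name pairs whose cosine similarity is at least `cutoff`."""
    return [
        (name_a, name_b)
        for name_a, name_b in combinations(vectors, 2)
        if cosine_similarity(vectors[name_a], vectors[name_b]) >= cutoff
    ]


# Made-up embeddings: the first two point in nearly the same direction.
toy_embeddings = {
    "Elon Musk": [1.0, 0.0, 0.1],
    "Elonmusk": [0.98, 0.05, 0.12],
    "Chevron": [0.0, 1.0, 0.0],
}

pairs = similar_pairs(toy_embeddings, cutoff=0.95)
print(pairs)
```

Only the two near-identical vectors pass the 0.95 cutoff, which is the behaviour the `SIMILAR` relationships are meant to capture.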

In [32]:
# Project the graph into the GDS in-memory catalog.
# This is a necessary step before running GDS algorithms.
G, result = gds.graph.project(
    "entities",                   # Name for the in-memory graph
    "__Entity__",                 # Node label to project
    "*",                          # Project all relationship types
    nodeProperties=["embedding"]  # Include the embedding property
)

In [33]:
# Define a similarity threshold.
similarity_threshold = 0.95

# Run the k-NN algorithm.
# This will create new 'SIMILAR' relationships in-memory for nodes with a cosine similarity
# score above the defined threshold.
gds.knn.mutate(
    G,
    nodeProperties=["embedding"],
    mutateRelationshipType="SIMILAR",
    mutateProperty="score",
    similarityCutoff=similarity_threshold,
)

K-Nearest Neighbours: 100%|██████████| 100.0/100 [00:01<00:00, 110.46%/s]

ranIterations                                                            10
nodePairsConsidered                                                  886993
didConverge                                                            True
preProcessingMillis                                                       2
computeMillis                                                          1498
mutateMillis                                                             28
postProcessingMillis                                                      0
nodesCompared                                                          1879
relationshipsWritten                                                   1934
similarityDistribution    {'min': 0.9500160217285156, 'p5': 0.9507484436...
configuration             {'mutateProperty': 'score', 'jobId': '198c5612...
Name: 0, dtype: object

### Entity Resolution: Merging Duplicate Nodes

LLM-based extraction isn't perfect and often creates duplicate entities (e.g., "Elon Musk" and "Elonmusk"). We can combine our graph structure with another LLM call to resolve these duplicates.

1.  **Candidate Identification**: We use a Cypher query to find potential duplicates. It looks for nodes that fall in the same community (a weakly connected component computed over the `SIMILAR` relationships) and have similar names (Levenshtein distance via `apoc.text.distance`).
2.  **LLM-based Merging**: We then pass these candidate groups to another BAML function (`ExtractDeDupe`) which asks the LLM to decide on a canonical, single name for each group.
3.  **Graph Refactoring**: Finally, we use the `apoc.refactor.mergeNodes` procedure in Neo4j to merge the identified duplicate nodes into a single node, cleaning up our graph.
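
The name-similarity test in step 1 can be sketched in plain Python. This is a textbook dynamic-programming Levenshtein distance, shown only to illustrate the kind of comparison `apoc.text.distance` performs; the `is_candidate` helper is hypothetical, and its default threshold mirrors the `word_edit_distance = 3` value and the strict `<` comparison used in the notebook's query.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character inserts, deletes, or substitutions."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(
                min(
                    previous[j] + 1,         # delete char_a
                    current[j - 1] + 1,      # insert char_b
                    previous[j - 1] + cost,  # substitute (or match)
                )
            )
        previous = current
    return previous[-1]


def is_candidate(a: str, b: str, max_distance: int = 3) -> bool:
    """Case-insensitive check: edit distance strictly below max_distance."""
    return levenshtein(a.lower(), b.lower()) < max_distance


print(levenshtein("elon musk", "elonmusk"))
print(is_candidate("Elon Musk", "Elonmusk"))
print(is_candidate("Cyb003", "Cyb004"))
```

Note that a pure edit-distance test also accepts pairs such as `Cyb003` and `Cyb004`, which differ by one character but may be distinct entities. That is one reason a second, LLM-based decision step is useful.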

In [34]:
# Run the Weakly Connected Components (WCC) algorithm based on SIMILAR relationships.
# This assigns a community ID (`wcc`) to each node, which helps us scope our search for duplicates.
# The `.write()` mode persists this new property back to the Neo4j database.
gds.wcc.write(G, writeProperty="wcc", relationshipTypes=["SIMILAR"])



writeMillis                                                             20
nodePropertiesWritten                                                 1879
componentCount                                                        1419
componentDistribution    {'min': 1, 'p5': 1, 'max': 103, 'p999': 31, 'p...
postProcessingMillis                                                     5
preProcessingMillis                                                      0
computeMillis                                                            5
configuration            {'writeProperty': 'wcc', 'jobId': '4a4318a6-29...
Name: 0, dtype: object
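
For intuition about what WCC computes, the following is a minimal union-find sketch in plain Python. It is not how GDS implements the algorithm internally; the node names and the `weakly_connected_components` helper are made up for illustration. Edge direction is ignored, which is what "weakly" connected means.

```python
from typing import Dict, List, Tuple


def weakly_connected_components(
    nodes: List[str], edges: List[Tuple[str, str]]
) -> Dict[str, int]:
    """Assign a component ID to every node, treating edges as undirected."""
    parent = {node: node for node in nodes}

    def find(node: str) -> str:
        # Follow parent pointers to the root, shortening the path as we go.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    for source, target in edges:
        root_source, root_target = find(source), find(target)
        if root_source != root_target:
            parent[root_target] = root_source

    # Number the components in order of first appearance.
    component_ids: Dict[str, int] = {}
    component_of: Dict[str, int] = {}
    for node in nodes:
        root = find(node)
        component_of[node] = component_ids.setdefault(root, len(component_ids))
    return component_of


nodes = ["Elon Musk", "Elonmusk", "Market", "Markets", "Chevron"]
similar_edges = [("Elon Musk", "Elonmusk"), ("Markets", "Market")]
components = weakly_connected_components(nodes, similar_edges)
print(components)
```

Nodes with no `SIMILAR` relationship end up in a component of size one, which is consistent with the large `componentCount` relative to the number of nodes in the output above.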

In [35]:
# Set the edit-distance threshold for finding similar strings (e.g., 'Market' vs 'Markets').
# The query uses a strict `<` comparison, so names may differ by at most 2 edits.
word_edit_distance = 3

# This complex Cypher query identifies groups of potential duplicate entities.
potential_duplicate_candidates = graph.query(
    """MATCH (e:`__Entity__`)
    WHERE size(e.id) > 4 // longer than 4 characters
    WITH e.wcc AS community, collect(e) AS nodes, count(*) AS count
    WHERE count > 1
    UNWIND nodes AS node
    // Add text distance
    WITH distinct
      [n IN nodes WHERE apoc.text.distance(toLower(node.id), toLower(n.id)) < $distance | n.id] AS intermediate_results
    WHERE size(intermediate_results) > 1
    WITH collect(intermediate_results) AS results
    // combine groups together if they share elements
    UNWIND range(0, size(results)-1, 1) as index
    WITH results, index, results[index] as result
    WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
            CASE WHEN index <> index2 AND
                size(apoc.coll.intersection(acc, results[index2])) > 0
                THEN apoc.coll.union(acc, results[index2])
                ELSE acc
            END
    )) as combinedResult
    WITH distinct(combinedResult) as combinedResult
    // extra filtering
    WITH collect(combinedResult) as allCombinedResults
    UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
    WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
    WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1)
        WHERE x <> combinedResultIndex
        AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
    )
    RETURN combinedResult
    """,
    params={"distance": word_edit_distance},
)
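
The "combine groups together if they share elements" step of the query can be sketched in Python as follows. This is an illustration under assumptions, not a translation of the Cypher: the sketch merges overlapping groups until no two groups share a name, whereas the query does a single `reduce` pass and then drops groups that are subsets of others, so the two can differ on long chains of overlaps. The `merge_overlapping_groups` helper and the sample groups are hypothetical.

```python
from typing import List, Set


def merge_overlapping_groups(groups: List[List[str]]) -> List[List[str]]:
    """Union any groups that share at least one name."""
    merged: List[Set[str]] = []
    for group in groups:
        current = set(group)
        disjoint: List[Set[str]] = []
        for existing in merged:
            if existing & current:
                # Absorb every previously seen group that overlaps this one.
                current |= existing
            else:
                disjoint.append(existing)
        disjoint.append(current)
        merged = disjoint
    # Sort for a stable, readable result.
    return sorted(sorted(group) for group in merged)


candidate_groups = [
    ["Market", "Markets"],
    ["Markets", "Marketing"],
    ["Elon Musk", "Elonmusk"],
]
combined = merge_overlapping_groups(candidate_groups)
print(combined)
```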

In [36]:
# Display the first 5 groups of potential duplicates found by the query.
potential_duplicate_candidates[:5]

[{'combinedResult': ['David Van', 'Davidvan']},
 {'combinedResult': ['Cyb003', 'Cyb004']},
 {'combinedResult': ['Delta Air Lines', 'Delta_Air_Lines']},
 {'combinedResult': ['Elon Musk', 'Elonmusk']},
 {'combinedResult': ['Market', 'Markets']}]

In [37]:
# This cell was commented out in the original notebook.
# It's used to reset BAML environment variables if they were set in a specific way,
# which is not necessary for this demonstration.
# from baml_client import reset_baml_env_vars
# import os
#
# reset_baml_env_vars(dict(os.environ))

In [38]:
# Import the prompts specific to the deduplication task.
from prompts.graphragprompts import system_prompt_duplicates, user_template

# Create the prompt template for this specific task.
extraction_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            system_prompt_duplicates,
        ),
        (
            "human",
            user_template,
        ),
    ]
)

In [39]:
# Assemble the chain for entity resolution by piping the prompt, LLM, and BAML parser.
extraction_chain = extraction_prompt | llm | get_entities

In [40]:
# Test the deduplication chain with a sample list of entities.
entities = ["Star Ocean The Second Story R", "Star Ocean: The Second Story R"]
print(extraction_chain.invoke(entities))

['Star Ocean: The Second Story R']


In [41]:
# A simple wrapper function to call the extraction chain.
# The chain returns one list of names; wrapping it yields a list of merge groups.
def entity_resolution(entities: List[str]) -> List[List[str]]:
    return [extraction_chain.invoke(entities)]

# Test the function on the first list of candidates from our Cypher query.
entity_resolution(potential_duplicate_candidates[0]["combinedResult"])

[['David Van']]

In [42]:
# Import libraries for parallel execution, progress reporting, and retry delays.
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional
from tqdm import tqdm
import time

# Set the number of parallel workers.
MAX_WORKERS = 3

# Define a robust entity resolution function with retries to handle potential API failures.
# It returns a list containing one merge group, or None if every attempt fails.
def entity_resolution(
    entities: List[str], retries: int = 3, delay: float = 30.0
) -> Optional[List[List[str]]]:
    for attempt in range(1, retries + 1):
        try:
            # Invoke the chain to get the canonical entity names.
            return [extraction_chain.invoke(entities)]
        except Exception as e:
            print(f"Attempt {attempt} failed for entities: {entities}, Error: {e}")
            if attempt < retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached. Returning None.")
                return None

# Process all candidate lists in parallel to get the final merged entity lists.
merged_entities = []
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures = [
        executor.submit(entity_resolution, el["combinedResult"])
        for el in potential_duplicate_candidates
    ]
    for future in tqdm(
        as_completed(futures), total=len(futures), desc="Processing documents"
    ):
        try:
            to_merge = future.result()
            if to_merge:
                merged_entities.extend(to_merge)
        except Exception as e:
            print(f"Error in future result: {e}")
            continue

Processing documents: 100%|██████████| 26/26 [00:23<00:00,  1.12it/s]


In [43]:
# Execute a Cypher query to merge the nodes in the database.
# It iterates through the lists of duplicates and merges them into a single node using APOC procedures.
graph.query(
    """
UNWIND $data AS candidates
CALL {
  WITH candidates
  MATCH (e:__Entity__) WHERE e.id IN candidates
  RETURN collect(e) AS nodes
}
CALL apoc.refactor.mergeNodes(nodes, {properties: {
    `.*`: 'discard'
}})
YIELD node
RETURN count(*)
""",
    params={"data": merged_entities},
)

[{'count(*)': 18}]
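
To illustrate what a node merge does to the graph, here is a small in-memory sketch. It assumes that the `'discard'` strategy keeps the first node's value when a property exists on several nodes and fills in properties the first node lacks; treat that as an assumption and check the APOC documentation for the exact semantics. The data, the `merge_nodes` helper, and the tuple-based edge representation are hypothetical.

```python
from typing import Dict, List, Tuple

Edge = Tuple[str, str, str]  # (start node, relationship type, end node)


def merge_nodes(
    properties: Dict[str, Dict[str, object]],
    edges: List[Edge],
    duplicates: List[str],
) -> Tuple[Dict[str, Dict[str, object]], List[Edge]]:
    """Fold all duplicates into the first one and rewire their relationships."""
    survivor, *others = duplicates
    merged = {
        name: dict(props) for name, props in properties.items() if name not in others
    }
    for other in others:
        for key, value in properties[other].items():
            # Assumed 'discard' behaviour: existing values on the survivor win.
            merged[survivor].setdefault(key, value)
    # Relationships of removed nodes are re-attached to the survivor.
    rewired = [
        (survivor if start in others else start, rel, survivor if end in others else end)
        for start, rel, end in edges
    ]
    return merged, rewired


properties = {
    "Elon Musk": {"description": "CEO of Tesla"},
    "Elonmusk": {"description": "Entrepreneur", "wcc": 7},
    "Tesla": {},
}
edges = [
    ("Elon Musk", "LEADS", "Tesla"),
    ("Elonmusk", "LEADS", "Tesla"),
    ("Elonmusk", "FOUNDED", "SpaceX"),
]
merged_properties, rewired = merge_nodes(properties, edges, ["Elon Musk", "Elonmusk"])
print(merged_properties)
print(rewired)
```

The sketch keeps parallel relationships after the merge rather than collapsing them, which is a simplification chosen for clarity.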

### Hierarchical Community Detection with Leiden Algorithm

Community detection algorithms help uncover clusters of densely connected nodes in a graph. We will use the Leiden algorithm, which builds on the older Louvain method: Louvain can produce communities that are internally disconnected, whereas Leiden adds a refinement step that guarantees well-connected communities. By running it hierarchically, we can see how smaller communities merge into larger, more abstract ones at different levels of granularity.

In [44]:
# Project the graph again, this time including relationship weights for the community detection algorithm.
G, result = gds.graph.project(
    "communities",  #  Name for the new in-memory graph
    "__Entity__",  #  Node projection
    {
        "_ALL_": {
            "type": "*",
            "orientation": "UNDIRECTED",
            "properties": {"weight": {"property": "*", "aggregation": "COUNT"}},
        }
    },
)

In [45]:
# Run WCC stats to see the number of disconnected components before running Leiden.
wcc = gds.wcc.stats(G)
print(f"Component count: {wcc['componentCount']}")
print(f"Component distribution: {wcc['componentDistribution']}")

Component count: 722
Component distribution: {'min': 1, 'p5': 1, 'max': 305, 'p999': 305, 'p99': 14, 'p1': 1, 'p10': 1, 'p90': 4, 'p50': 1, 'p25': 1, 'p75': 2, 'p95': 6, 'mean': 2.596952908587258}


In [46]:
# Run the Leiden algorithm to detect communities and write the results back to the database.
# `includeIntermediateCommunities=True` allows us to see the hierarchical structure.
gds.leiden.write(
    G,
    writeProperty="communities",
    includeIntermediateCommunities=True,
    relationshipWeightProperty="weight",
)

writeMillis                                                             19
nodePropertiesWritten                                                 1875
ranLevels                                                                4
didConverge                                                           True
nodeCount                                                             1875
communityCount                                                         732
communityDistribution    {'min': 1, 'p5': 1, 'max': 77, 'p999': 77, 'p9...
modularity                                                        0.970618
modularities             [0.8718561236623066, 0.9618311533888227, 0.970...
postProcessingMillis                                                     1
preProcessingMillis                                                      0
computeMillis                                                           62
configuration            {'writeProperty': 'communities', 'theta': 0.01...
Name: 0, dtype: object
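
The `modularity` value in the output above measures how much denser the connections inside communities are than would be expected by chance. As a rough illustration, the sketch below computes standard (unweighted, resolution 1) modularity for a toy graph of two triangles joined by one edge. The notebook's run uses relationship weights, so this is a simplified version of the quantity GDS reports, and the `modularity` helper and toy graph are assumptions for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def modularity(edges: List[Tuple[int, int]], community_of: Dict[int, int]) -> float:
    """Q = sum over communities of (internal_edges / m) - (degree_sum / 2m)^2."""
    m = len(edges)
    internal = defaultdict(int)    # edges with both endpoints in the community
    degree_sum = defaultdict(int)  # total degree of the community's nodes
    for u, v in edges:
        degree_sum[community_of[u]] += 1
        degree_sum[community_of[v]] += 1
        if community_of[u] == community_of[v]:
            internal[community_of[u]] += 1
    return sum(
        internal[c] / m - (degree_sum[c] / (2 * m)) ** 2 for c in degree_sum
    )


# Two triangles (0-1-2 and 3-4-5) connected by the single edge (2, 3).
toy_edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
two_communities = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
one_community = {node: 0 for node in range(6)}

q_split = modularity(toy_edges, two_communities)
q_single = modularity(toy_edges, one_community)
print(round(q_split, 4), round(q_single, 4))
```

Splitting the toy graph into its two triangles scores higher than lumping everything together, which is the signal a modularity-optimizing algorithm follows.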

### Structuring Communities in the Graph

With `includeIntermediateCommunities=True`, Leiden writes a list of community IDs to each node's `communities` property, one entry per hierarchy level, with index 0 being the finest level. The following queries transform this flat list into a hierarchical graph structure by creating `(__Community__)` nodes and linking them together, as well as linking entities to their respective communities.

In [47]:
# Create a uniqueness constraint on community IDs to ensure data integrity and improve query performance.
graph.query(
    "CREATE CONSTRAINT IF NOT EXISTS FOR (c:__Community__) REQUIRE c.id IS UNIQUE;"
)

[]

In [48]:
# This query creates the community nodes and their relationships.
graph.query(
    """
MATCH (e:`__Entity__`)
UNWIND range(0, size(e.communities) - 1 , 1) AS index
// Create the link from an entity to its lowest-level community
CALL {
  WITH e, index
  WHERE index = 0
  MERGE (c:`__Community__` {id: toString(index) + '-' + toString(e.communities[index])})
  ON CREATE SET c.level = index
  MERGE (e)-[:IN_COMMUNITY]->(c)
  RETURN count(*) AS count_0
}
// Create links between hierarchical community levels
CALL {
  WITH e, index
  WHERE index > 0
  MERGE (current:`__Community__` {id: toString(index) + '-' + toString(e.communities[index])})
  ON CREATE SET current.level = index
  MERGE (previous:`__Community__` {id: toString(index - 1) + '-' + toString(e.communities[index - 1])})
  ON CREATE SET previous.level = index - 1
  MERGE (previous)-[:IN_COMMUNITY]->(current)
  RETURN count(*) AS count_1
}
RETURN count(*)
"""
)

[{'count(*)': 7500}]
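
The transformation performed by the query can be sketched in plain Python. Each entity's `communities` list becomes one `IN_COMMUNITY` link to its level-0 community plus one link between each pair of consecutive levels, with community IDs formatted as `level-id` as in the query. The `community_edges` helper and the sample data are hypothetical; using a set mimics the idempotent behaviour of `MERGE`.

```python
from typing import Dict, List, Set, Tuple


def community_edges(entity_communities: Dict[str, List[int]]) -> Set[Tuple[str, str]]:
    """Return (child, parent) pairs that represent IN_COMMUNITY relationships."""
    edges: Set[Tuple[str, str]] = set()
    for entity, levels in entity_communities.items():
        ids = [f"{level}-{community}" for level, community in enumerate(levels)]
        if ids:
            # The entity itself links only to its lowest-level community.
            edges.add((entity, ids[0]))
        # Each community links to its parent one level up.
        edges.update(zip(ids, ids[1:]))
    return edges


sample = {
    "Chevron": [1, 10],
    "Q2": [1, 10],
    "Tesla": [2, 10],
}
hierarchy = community_edges(sample)
print(sorted(hierarchy))
```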

In [49]:
# This query calculates a 'community_rank' for each community based on how many unique documents are linked to it.
graph.query(
    """
MATCH (c:__Community__)<-[:IN_COMMUNITY*]-(:__Entity__)<-[:MENTIONS]-(d:Document)
WITH c, count(distinct d) AS rank
SET c.community_rank = rank;
"""
)

[]

### Analyzing Community Sizes

Finally, let's look at the size distribution of our detected communities at each level of the hierarchy. This helps us understand the structure of the topics within our document set, from very specific small clusters to broader thematic groups.

In [50]:
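# `np.percentile` is used below, so import NumPy here in case it was not imported earlier.
import numpy as np

# Query to get the size (number of entities) of each community at each level.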
community_size = graph.query(
    """
MATCH (c:__Community__)<-[:IN_COMMUNITY*]-(e:__Entity__)
WITH c, count(distinct e) AS entities
RETURN split(c.id, '-')[0] AS level, entities
"""
)
community_size_df = pd.DataFrame.from_records(community_size)

# Calculate percentile data for community sizes at each level.
percentiles_data = []
for level in sorted(community_size_df["level"].unique()):
    subset = community_size_df[community_size_df["level"] == level]["entities"]
    num_communities = len(subset)
    percentiles = np.percentile(subset, [25, 50, 75, 90, 99])
    percentiles_data.append(
        [
            level,
            num_communities,
            percentiles[0],
            percentiles[1],
            percentiles[2],
            percentiles[3],
            percentiles[4],
            max(subset),
        ]
    )

# Create a DataFrame to display the percentile statistics for easy comparison.
percentiles_df = pd.DataFrame(
    percentiles_data,
    columns=[
        "Level", "Number of communities", "25th Percentile", "50th Percentile",
        "75th Percentile", "90th Percentile", "99th Percentile", "Max",
    ],
)
percentiles_df

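| | Level | Number of communities | 25th Percentile | 50th Percentile | 75th Percentile | 90th Percentile | 99th Percentile | Max |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 858 | 1.0 | 1.0 | 2.0 | 4.0 | 10.43 | 37 |
| 1 | 1 | 749 | 1.0 | 1.0 | 2.0 | 5.0 | 18.52 | 77 |
| 2 | 2 | 734 | 1.0 | 1.0 | 2.0 | 5.0 | 27.67 | 77 |
| 3 | 3 | 732 | 1.0 | 1.0 | 2.0 | 5.0 | 27.69 | 77 |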
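
Community sizes are integers, yet the 99th percentile column contains fractional values such as 10.43. That follows from linear interpolation between neighbouring sorted values, which is the default method of `np.percentile`. The sketch below reimplements that interpolation in plain Python on made-up sizes; the `percentile` helper is hypothetical and only meant to show where the fractions come from.

```python
import math
from typing import List


def percentile(values: List[float], q: float) -> float:
    """Percentile with linear interpolation between the two nearest ranks."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100
    lower = math.floor(position)
    upper = math.ceil(position)
    fraction = position - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction


# Hypothetical community sizes: mostly tiny, with one large outlier.
sizes = [1, 1, 1, 2, 2, 4, 10, 37]
median = percentile(sizes, 50)
p90 = percentile(sizes, 90)
largest = percentile(sizes, 100)
print(median, round(p90, 2), largest)
```

Here the 90th percentile falls 30% of the way between the two largest sizes (10 and 37), so it is not itself the size of any community.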


In [51]:
# Query to retrieve detailed information about the communities, including their nodes and relationships.
# This data can be used for further summarization or visualization.
community_info = graph.query(
    """
MATCH (c:`__Community__`)<-[:IN_COMMUNITY*]-(e:__Entity__)
WHERE c.level IN [0,1]
WITH c, collect(e) AS nodes
WHERE size(nodes) > 1
CALL apoc.path.subgraphAll(nodes[0], {
    whitelistNodes: nodes
})
YIELD relationships
RETURN c.id AS communityId,
       [n in nodes | {id: n.id, description: n.description, type: [el in labels(n) WHERE el <> '__Entity__'][0]}] AS nodes,
       [r in relationships | {start: startNode(r).id, type: type(r), end: endNode(r).id, description: r.description}] AS rels
"""
)

In [52]:
# Display information for the first community as an example.
community_info[:1]

[{'communityId': '0-1',
  'nodes': [{'id': 'Chevron',
    'description': 'Energy company',
    'type': 'Company'},
   {'id': 'O&G Sector', 'description': None, 'type': 'Industry'},
   {'id': 'Q2', 'description': None, 'type': 'Period of time'}],
  'rels': [{'start': 'Chevron',
    'description': None,
    'type': 'IS_A_PART_OF',
    'end': 'O&G Sector'},
   {'start': 'Q2',
    'description': 'risen sharply (~25%) during that same time frame',
    'type': 'EARNINGS_ESTIMATES',
    'end': 'Chevron'}]}]