In [1]:
import asyncio
import json
import os
import neo4j

In [3]:
from dotenv import load_dotenv
from tqdm.asyncio import tqdm_asyncio
from tqdm import tqdm
from neo4j_graphrag.llm import OpenAILLM, OllamaLLM
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings

In [3]:
# Load environment variables from a .env file. override=True is passed so that
# values from the file take precedence over variables already set in the
# process environment (python-dotenv's documented meaning of the flag).
load_dotenv(override=True)

True

In [4]:
# Initialize the Neo4j driver from connection settings held in the environment
# (NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD).
#
# os.getenv() returns None for an unset variable, and that None would be
# passed straight to the driver. Fail fast with a message that names the
# missing setting instead of relying on whatever error the driver produces.
if not os.getenv('NEO4J_URI'):
    raise RuntimeError(
        "NEO4J_URI is not set; define it in the environment or in the .env file."
    )

neo4j_driver = neo4j.GraphDatabase.driver(
    os.getenv('NEO4J_URI'),
    auth=(os.getenv('NEO4J_USERNAME'), os.getenv('NEO4J_PASSWORD'))
)

In [5]:
# Wire up the OpenAI-backed components used to build the knowledge graph.

# Chat model: capped at 2000 output tokens, with a JSON-object response format
# requested through the model parameters.
llm = OpenAILLM(
    model_name="gpt-4o",
    model_params=dict(
        max_tokens=2000,
        response_format=dict(type="json_object"),
    ),
)

# Embedding model handed to the pipeline.
embedder = OpenAIEmbeddings(model="text-embedding-3-large")

# Pipeline fed with plain text rather than PDFs (from_pdf=False), using the
# shared driver and targeting the "neo4j" database.
kg_builder = SimpleKGPipeline(
    driver=neo4j_driver,
    llm=llm,
    embedder=embedder,
    neo4j_database="neo4j",
    from_pdf=False,
)

In [6]:
async def clear_graph_db():
    """
    Clear the entire graph database by deleting all nodes and relationships.

    The module-level ``neo4j_driver`` is used through its synchronous session
    API, so the blocking delete is offloaded to the running event loop's
    default executor instead of being executed on the loop itself.

    Prints "Graph database cleared." once the delete has finished.

    Raises:
        Any exception raised by the driver while opening the session or
        executing the delete query.
    """
    def _clear():
        with neo4j_driver.session(database="neo4j") as session:
            # Consume the result explicitly so the query has fully executed
            # (and any server-side failure has been raised here) before the
            # session is closed and success is reported.
            session.run("MATCH (n) DETACH DELETE n").consume()

    # Inside a coroutine the running loop is the one to use;
    # asyncio.get_event_loop() is discouraged in this context.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _clear)
    print("Graph database cleared.")


In [7]:
async def process_policy(policy: dict,
                         semaphore: asyncio.Semaphore,
                         max_retries: int = 3,
                         initial_backoff: int = 60):
    """
    Process a single policy document, retrying on errors.

    Calls ``kg_builder.run_async`` on the policy's ``"extracted"`` text while
    holding ``semaphore``. Any exception from the call is printed and the call
    is retried after an exponentially growing delay.

    Args:
        policy: Mapping with at least the keys ``"url"`` (used in log
            messages) and ``"extracted"`` (the text passed to the builder).
        semaphore: Limits how many policies are processed at once. It is held
            for the whole retry sequence, including the backoff sleeps.
        max_retries: Maximum number of attempts.
        initial_backoff: Wait in seconds before the first retry; the wait
            doubles at each subsequent retry (60, 120, 240, ...).

    Returns:
        True if an attempt succeeded, False if every attempt failed.

    Raises:
        KeyError: If ``policy`` has no ``"url"`` key (looked up before the
            retry loop, so it is not retried).
    """
    # Use info from policy metadata for logging.
    source = policy["url"]

    async with semaphore:
        for attempt in range(1, max_retries + 1):
            try:
                print(f"Processing: {source} (Attempt {attempt})")
                await kg_builder.run_async(text=policy["extracted"])
                print(f"Successfully processed: {source}")
                return True  # Done; no further attempts needed.
            except Exception as e:
                print(f"Error processing {source} on attempt {attempt}: {e}")
                if attempt == max_retries:
                    print(f"Failed to process {source} after {max_retries} attempts")
                else:
                    # Exponential backoff: initial_backoff, then double each
                    # time. (The previous formula, initial_backoff *
                    # (attempt - 1), waited 0 seconds before the first retry
                    # and grew only linearly.)
                    delay = initial_backoff * 2 ** (attempt - 1)
                    print(f"Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)

    return False

In [11]:
# Load the policy documents to ingest. Each entry is later read via its "url"
# and "extracted" keys. The file is opened explicitly as UTF-8 so decoding
# does not depend on the platform's default locale encoding.
with open("policies.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# To smoke-test on a small sample, uncomment the next line.
# data = data[:5]
print(len(data))

437


In [12]:
async def main():
    """
    Rebuild the knowledge graph from every policy in ``data``.

    Steps:
      1. Wipe the graph database with ``clear_graph_db()``.
      2. Schedule one ``process_policy`` task per policy; a shared semaphore
         allows at most two of them to run at a time.
      3. Advance a progress bar as tasks finish.

    Exceptions that escape a task are reported and counted rather than
    silently discarded; they do not stop the remaining tasks.
    """
    # Step 1: Clear the graph database.
    await clear_graph_db()

    # Step 2: Set a concurrency limit.
    semaphore = asyncio.Semaphore(2)

    # Create a task for each policy document.
    tasks = [
        asyncio.create_task(process_policy(policy, semaphore))
        for policy in data
    ]

    unexpected_errors = 0

    # The context manager closes the progress bar even if the loop is
    # interrupted part-way through.
    with tqdm(total=len(tasks), desc="Processing Policies") as progress_bar:
        # As each task completes, update the progress bar.
        for fut in asyncio.as_completed(tasks):
            try:
                await fut
            except Exception as exc:
                # process_policy handles errors from the builder call itself,
                # so anything arriving here is unexpected (for example a
                # policy entry missing a required key). Surface it.
                unexpected_errors += 1
                print(f"Unexpected error while processing a policy: {exc!r}")
            progress_bar.update(1)

    if unexpected_errors:
        print(f"{unexpected_errors} task(s) ended with an unexpected error.")
    print("All policies have been processed.")

In [None]:
# Start the full ingestion run. Top-level `await` is used here because the
# notebook cell is executed inside an already-running event loop.
# NOTE: main() begins by calling clear_graph_db(), which deletes every node
# and relationship in the "neo4j" database before any policy is processed.
await main()