In [12]:
import asyncio
# import getpass
import os
from datetime import datetime
from hashlib import md5
from typing import Dict, List

import pandas as pd
# import seaborn as sns
# import tiktoken
from langchain_community.graphs import Neo4jGraph
# from langchain_community.tools import WikipediaQueryRun
# from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_text_splitters import TokenTextSplitter
from pydantic import BaseModel, Field
import openai
from neo4j import GraphDatabase

In [6]:
# Fail fast with a readable message when a credential is missing. The previous
# `os.environ[k] = os.getenv(k)` round-trip was a no-op when the variable was
# set, and raised an opaque TypeError (None is not a str) when it was not.
REQUIRED_ENV_VARS = ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD", "OPENAI_API_KEY")
missing_env_vars = [name for name in REQUIRED_ENV_VARS if not os.getenv(name)]
if missing_env_vars:
    raise RuntimeError(
        f"Missing required environment variables: {', '.join(missing_env_vars)}"
    )

# Neo4jGraph is built without explicit credentials, so it relies on the
# NEO4J_* environment variables checked above.
graph = Neo4jGraph(refresh_schema=False)

# Uniqueness constraints on `id` for every node label the import query MERGEs on.
for label in ("Chunk", "AtomicFact", "KeyElement", "Document"):
    graph.query(f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:{label}) REQUIRE n.id IS UNIQUE")

[]

In [13]:
# Read the first 20 paper records. For each paper whose extracted text file
# exists on disk, keep its id (document_names) and raw file contents
# (documents_full). Papers without a text file are skipped silently.
data = pd.read_json('data/arxiv_ml_papers.jsonl', nrows = 20, lines=True)
document_names = []
documents_full = []
for paper_id in data["id"].values:
    # NOTE(review): the leading "0" is presumably re-added because pandas parses
    # ids such as "0704.0001" as the float 704.0001. That also drops trailing
    # zeros (e.g. 0704.0010 -> 704.001), so such papers would be skipped by the
    # exists() check below. Consider pd.read_json(..., dtype={"id": str}) and
    # dropping the prefix; confirm against the data files.
    file_path = os.path.join("data/arxiv_ml_text", f"0{paper_id}.json")
    if os.path.exists(file_path):
        document_names.append(paper_id)
        # Context manager closes the handle instead of leaking it.
        with open(file_path, encoding="utf-8") as text_file:
            documents_full.append(text_file.read())


## Construct the extraction prompt

In [7]:
# System prompt for the extraction step: instructs the LLM to pull key elements
# and atomic facts out of each chunk it is given.
construction_system = """
You are now an intelligent assistant tasked with meticulously extracting both key elements and
atomic facts from multiple articles.
1. Key Elements: The essential nouns (e.g., authors, theorems, definitions, year, places) and verbs (e.g.,
referencing, citing) that are pivotal to the text’s narrative.
2. Atomic Facts: The smallest, indivisible facts, presented as concise sentences. These include
propositions, theories, existences, concepts, and implicit elements like logic, causality, event
sequences, interpersonal relationships, timelines, etc.
Requirements:
#####
1. Ensure that all identified key elements are reflected within the corresponding atomic facts.
2. You should extract key elements and atomic facts comprehensively, especially those that are
important and potentially query-worthy and do not leave out details.
3. Whenever applicable, replace pronouns with their specific noun counterparts (e.g., change I, He,
She to actual names).
4. Ensure that the key elements and atomic facts you extract are presented in the same language as
the original text (e.g., English or Chinese).
"""

# Human turn template; {input} is filled with one chunk's text at call time.
# This single definition is what the prompt below actually uses (previously a
# separate, slightly different copy was inlined and this variable went unused).
construction_human = (
    "Use the given format to extract information from the "
    "following input: {input}"
)

construction_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", construction_system),
        ("human", construction_human),
    ]
)

In [8]:
# Schema for a single extracted fact. This class is passed (via Extraction) to
# `model.with_structured_output`, so its field names, types and `description`
# strings shape the output schema the LLM is asked to fill. Edit with care.
# `#` comments are used instead of a class docstring because a docstring would
# likely be forwarded as the schema description too.
class AtomicFact(BaseModel):
    # Key elements mentioned by this fact. Each string becomes a KeyElement
    # node whose id is the raw string (see import_query).
    key_elements: List[str] = Field(description="""The essential nouns (e.g., authors, theorems, definitions, year, places) and verbs (e.g.,
referencing, citing) that are pivotal to the text’s narrative.""")
    # The fact sentence itself. process_document uses its MD5 hex digest as
    # the AtomicFact node id.
    atomic_fact: str = Field(description="""The smallest, indivisible facts, presented as concise sentences. These include
propositions, theories, existences, concepts, and implicit elements like logic, causality, event
sequences, interpersonal relationships, timelines, etc.""")

# Top-level structured-output schema: the chain returns one Extraction per
# chunk it is invoked on.
class Extraction(BaseModel):
    atomic_facts: List[AtomicFact] = Field(description="List of atomic facts")

## Model and extraction chain


In [9]:
# Extraction LLM. The low temperature keeps structured output close to
# deterministic between runs.
model = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.1,
)

# Bind the Extraction schema so each call yields an Extraction object.
structured_llm = model.with_structured_output(Extraction)

# prompt -> structured LLM; invoked as construction_chain.ainvoke({"input": chunk_text}).
construction_chain = construction_prompt | structured_llm

In [10]:
# Cypher that writes one document's chunks, atomic facts and key elements in a
# single query.
# Parameters:
#   $document_name -- id of the Document node; also stored on every Chunk.
#   $data          -- list of chunk rows: {chunk_id, chunk_text, index,
#                     atomic_facts: [{id, atomic_fact, key_elements}]}.
# Chunk.document_name is read from the $document_name parameter rather than
# from the row, because the rows built by process_document do not carry a
# document_name key. Reading row.document_name set the property to null.
import_query = """
MERGE (d:Document {id:$document_name})
WITH d
UNWIND $data AS row
MERGE (c:Chunk {id: row.chunk_id})
SET c.text = row.chunk_text,
    c.index = row.index,
    c.document_name = $document_name
MERGE (d)-[:HAS_CHUNK]->(c)
WITH c, row
UNWIND row.atomic_facts AS af
MERGE (a:AtomicFact {id: af.id})
SET a.text = af.atomic_fact
MERGE (c)-[:HAS_ATOMIC_FACT]->(a)
WITH c, a, af
UNWIND af.key_elements AS ke
MERGE (k:KeyElement {id: ke})
MERGE (a)-[:HAS_KEY_ELEMENT]->(k)
"""

def encode_md5(text):
    """Return the hexadecimal MD5 digest of ``text`` encoded as UTF-8.

    Used to derive content-addressed node ids, so equal text yields the same id.
    """
    utf8_bytes = text.encode("utf-8")
    hasher = md5(utf8_bytes)
    return hasher.hexdigest()

In [11]:
# Paper used 2k token size
async def process_document(text, document_name, chunk_size=2000, chunk_overlap=200):
    """Extract atomic facts from one document and import them into Neo4j.

    The text is split into token chunks and every chunk is sent to
    ``construction_chain`` concurrently. The results are written with
    ``import_query``, and consecutive chunks of the document are then linked
    with ``NEXT`` relationships ordered by ``Chunk.index``.

    Args:
        text: Full document text.
        document_name: Id of the ``Document`` node the chunks are attached to.
        chunk_size: Maximum tokens per chunk (passed to ``TokenTextSplitter``).
        chunk_overlap: Tokens shared between consecutive chunks.
    """
    start = datetime.now()
    print(f"Started extraction at: {start}")
    text_splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_text(text)
    print(f"Total text chunks: {len(texts)}")

    # All chunks are requested concurrently. gather() returns results in input
    # order, so results[i] belongs to texts[i].
    results = await asyncio.gather(
        *(construction_chain.ainvoke({"input": chunk_text}) for chunk_text in texts)
    )
    print(f"Finished LLM extraction after: {datetime.now() - start}")

    docs = []
    for index, (chunk_text, extraction) in enumerate(zip(texts, results)):
        doc = extraction.dict()
        # NOTE(review): the chunk id hashes only the chunk text, so identical
        # chunks in different documents MERGE into one Chunk node.
        doc["chunk_id"] = encode_md5(chunk_text)
        doc["chunk_text"] = chunk_text
        doc["index"] = index
        # import_query may read Chunk.document_name from the row. Without this
        # key that property would be set to null.
        doc["document_name"] = document_name
        for af in doc["atomic_facts"]:
            # Content-addressed id: identical fact sentences share one node.
            af["id"] = encode_md5(af["atomic_fact"])
        docs.append(doc)

    # Import chunks/atomic facts/key elements
    graph.query(import_query, params={"data": docs, "document_name": document_name})
    # Create NEXT relationships between consecutive chunks of this document
    graph.query(
        """MATCH (c:Chunk)<-[:HAS_CHUNK]-(d:Document)
WHERE d.id = $document_name
WITH c ORDER BY c.index WITH collect(c) AS nodes
UNWIND range(0, size(nodes) -2) AS index
WITH nodes[index] AS start, nodes[index + 1] AS end
MERGE (start)-[:NEXT]->(end)
""",
        params={"document_name": document_name},
    )
    # Elapsed time since start (the old label said "at" but printed a duration).
    print(f"Finished import after: {datetime.now() - start}")

In [14]:
for text, name in zip(documents_full, document_names):
    await process_document(text, name, chunk_size=500, chunk_overlap=100)

Started extraction at: 2024-11-25 13:48:14.471099
Total text chunks: 18
Finished LLM extraction after: 0:13:15.810646
Finished import at: 0:13:17.652449
Started extraction at: 2024-11-25 14:01:32.123913
Total text chunks: 33
Finished LLM extraction after: 0:00:15.344908
Finished import at: 0:00:17.130754
Started extraction at: 2024-11-25 14:01:49.254785
Total text chunks: 32
Finished LLM extraction after: 0:00:29.736022
Finished import at: 0:00:30.602721
Started extraction at: 2024-11-25 14:02:19.857874
Total text chunks: 65
Finished LLM extraction after: 0:00:22.396244
Finished import at: 0:00:23.596015
Started extraction at: 2024-11-25 14:02:43.454100
Total text chunks: 14
Finished LLM extraction after: 0:00:20.043014
Finished import at: 0:00:20.485119
Started extraction at: 2024-11-25 14:03:03.939445
Total text chunks: 26
Finished LLM extraction after: 0:00:22.271768
Finished import at: 0:00:22.802752
Started extraction at: 2024-11-25 14:03:26.742593
Total text chunks: 62


## Query

In [29]:
import os
import openai
from openai import OpenAI

In [14]:
# Direct neo4j driver for the query section. It reads the same NEO4J_*
# environment variables as the `graph` wrapper created during setup.
host, user, password = (
    os.getenv(env_var) for env_var in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
)
driver = GraphDatabase.driver(host, auth=(user, password))

