In [1]:
# !pip install neo4j
# !pip install langchain
# !pip install PyPDF2
# !pip install tiktoken
# !pip install openai  # Only if you want to use the OpenAI API
# !pip install transformers  # For open (HF) models
# !pip install sentence_transformers
# !pip install -U langchain-community
# !pip install graphdatascience
# For advanced community detection with Leiden, you might need external libraries (e.g., igraph, networkx, etc.).


In [2]:
#!ollama pull llama3.1

In [3]:
import os
from typing import List, Dict, Any
import tqdm
import concurrent.futures

# -----------------------
# Neo4j Database imports
# -----------------------
from neo4j import GraphDatabase

# -----------------------
# LLM / Embeddings imports
# -----------------------
# If using HuggingFace transformers:
from transformers import pipeline

# If using LangChain for retrieval + QA
from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain

# LangChain GraphRag Setup
from langchain_community.graphs import Neo4jGraph
from langchain_experimental.graph_transformers import LLMGraphTransformer
# from langchain_core.documents import Document
from langchain_experimental.llms.ollama_functions import OllamaFunctions
from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama


from langchain.vectorstores import Milvus
from langchain.schema import Document
from langchain.schema import HumanMessage, SystemMessage

# If you want to use OpenAI, uncomment:
import openai
openai_model="gpt-4o-mini"
# -----------------------
# Load environment variables
import dotenv
dotenv.load_dotenv()

# -----------------------
# ArXiv API
# -----------------------
import arxiv

# -----------------------
# PDF Parsing library
# -----------------------
import PyPDF2  # or "pypdf" if needed

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
#############################################
# 1) CONFIGURATION: toggle open vs. OpenAI
#############################################

USE_OPENAI = True  # Set to True if you want to switch to OpenAI’s ChatGPT
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# For Neo4j:
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

In [5]:
##################################################
# 2) NEO4J CONNECTION AND GRAPH FUNCTIONS
##################################################

def add_chunk_node(tx, chunk_text: str, chunk_id: str, embedding: List[float]):
    """
    Create or merge a chunk node in Neo4j to represent a piece of text.
    Store its embedding as well.
    """
    embedding_str = ",".join([str(x) for x in embedding])

    query = """
    MERGE (c:Chunk {chunk_id: $chunk_id})
    ON CREATE SET c.text = $chunk_text,
                  c.embedding = $embedding_str
    """
    tx.run(query, chunk_id=chunk_id, chunk_text=chunk_text, embedding_str=embedding_str)


def add_document_relationship(tx, doc_id: str, chunk_id: str):
    """
    Link each chunk node to a parent Document node in Neo4j.
    """
    query = """
    MERGE (d:Document {doc_id: $doc_id})
    MERGE (c:Chunk {chunk_id: $chunk_id})
    MERGE (d)-[:HAS_CHUNK]->(c)
    """
    tx.run(query, doc_id=doc_id, chunk_id=chunk_id)


def add_element_instance(tx, element_data: Dict[str, Any]):
    """
    (Step 2.2) Insert extracted graph node/edge relationships from text chunk.
    Example structure of element_data might be:
    {
        "entity_name": "...",
        "entity_type": "...",
        "entity_description": "...",
        "relationship": {
            "source_entity": "...",
            "target_entity": "...",
            "description": "...",
        },
        ...
    }

    In a real pipeline, you might store these as separate nodes/edges. 
    For demonstration, we store them in a single node with a property that can be parsed.
    """
    query = """
    CREATE (e:ElementInstance {data: $element_data})
    """
    tx.run(query, element_data=str(element_data))


def retrieve_relevant_chunks(tx, user_query: str, limit: int = 5) -> List[Dict[str, Any]]:
    """
    A simplistic retrieval function that returns chunk nodes based on naive text search.
    In a real scenario, you'd have a vector similarity search using embeddings.
    """
    query = """
    MATCH (c:Chunk)
    WHERE c.text CONTAINS $user_query
    RETURN c.chunk_id AS chunk_id, c.text AS text
    LIMIT $limit
    """
    result = tx.run(query, user_query=user_query, limit=limit)
    return [record.data() for record in result]

In [6]:
##################################################
# 2.1 SOURCE DOCUMENTS → TEXT CHUNKS
##################################################
def find_metadata(doc_id: str) -> Dict[str, Any]:
    """
    (Step 2.1) Retrieve metadata for a document from a database or API.
    """
    client = arxiv.Client()
    search = arxiv.Search(
        id_list=[doc_id]
    )

    try:
        result = next(client.results(search))
        return \
            {"title": result.title, 
            "summary": result.summary, 
            "url": result.entry_id,
            "authors": ', '.join([a.name for a in result.authors]),
            "categories": ', '.join(result.categories)
            }
    except StopIteration:
        return {}

def parse_pdf(pdf_path: str) -> str:
    """
    Extract raw text from a PDF file using PyPDF2.
    """
    text = ""
    with open(pdf_path, 'rb') as f:
        reader = PyPDF2.PdfReader(f)
        for page in reader.pages:
            text += page.extract_text() + "\n"
    return text


def chunk_text(text: str, chunk_size: int = 600, chunk_overlap: int = 100) -> List[str]:
    """
    (Step 2.1) Split text into chunks. 
    Following the guidance in 2.1, we use a smaller chunk size (e.g., ~600 tokens).
    This can improve entity recall at the cost of more LLM calls.
    """
    # Note: the chunk_size is in characters by default using this splitter;
    # you may want to adapt to token-based splitting.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    
    chunks = text_splitter.split_text(text)
    return chunks

In [7]:
##################################################
# 4) EMBEDDING UTILITIES
##################################################

def get_hf_embedding_function(model_name: str = "sentence-transformers/all-MiniLM-L6-v2", device: str = "mps"):
    """
    Returns a function that can generate embeddings using a HuggingFace model.
    """
    hf_embed = HuggingFaceEmbeddings(model_name=model_name, device=device)
    return hf_embed.embed_documents


# If using OpenAI embeddings, uncomment and implement:
# def get_openai_embedding_function(model_name: str = "text-embedding-ada-002"):
#     def _embeddings(texts: List[str]) -> List[List[float]]:
#         response = openai.Embedding.create(
#             input=texts,
#             model=model_name
#         )
#         embeddings = [item["embedding"] for item in response["data"]]
#         return embeddings
#     return _embeddings


In [8]:
# Function to check and initialize Milvus database
def init_milvus_db(collection_name: str, uri: str, embedding_function):
    """
    Initialize the Milvus database if it does not exist.
    """
    if not os.path.exists(uri.replace("sqlite://", "")):
        print(f"Creating database at {uri}")
    vectorstore = Milvus(
        collection_name=collection_name,
        embedding=embedding_function,
        connection_args={"uri": uri},
    )
    return vectorstore

# Function to add multiple documents to Milvus
def add_documents_to_milvus(docs: list, embedding_function, collection_name: str = "rag_milvus", uri: str = "sqlite://./vector_db_graphRAG/milvus_ingest.db"):
    """
    Add multiple documents to the Milvus vector store.

    Args:
        docs (list): List of tuples containing text and metadata.
        collection_name (str): Name of the Milvus collection.
        uri (str): URI for the Milvus database.
    
    # Example usage
    docs = [
        {"text": "Chunk 1 of the document", "metadata": {"doc_id": "doc_1", "chunk": 1}},
        {"text": "Chunk 2 of the document", "metadata": {"doc_id": "doc_1", "chunk": 2}}
    ]

    add_documents_to_milvus(docs)
    """
    # Initialize the database
    vectorstore = init_milvus_db(collection_name, uri, embedding_function)

    # Prepare documents
    document_list = []
    for doc in docs:
        text = doc.get("text", "")
        metadata = doc.get("metadata", {})
        document_list.append(Document(page_content=text, metadata=metadata))

    # Add documents to the vector store
    vectorstore.add_documents(document_list)


In [9]:

def extract_element_instances_from_chunk(
    chunk_text: str,
    gleaning_rounds: int = 1,
    USE_OPENAI: bool = True,
    local_llm: str = "llama3.1"
) -> List[Dict[str, Any]]:
    """
    (Step 2.2) Use an LLM prompt to identify entity references, relationships, and covariates.
    - Identifies entities (name, type, description) and relationships.
    - Supports multiple rounds of "gleanings" to find any missed entities.
    """
    extracted_elements = []
    not_parsed = []

    # Select the LLM to use
    if USE_OPENAI:
        llm = ChatOpenAI(model=openai_model, temperature=0.0)
    else:
        llm = ChatOllama(model=local_llm, temperature=0)

    # Base prompt for extracting entities and relationships
    base_prompt = (
        "Extract entities and relationships from the following text. "
        "For each entity, provide its name, type, and description. "
        "For each relationship, provide the source entity, target entity, and description. "
        "Text: \n{chunk_text}\n"
        "Output format: List of dictionaries with keys 'entity_name', 'entity_type', 'entity_description', 'relationship'. "
        "Follow this format in the example: [{{\"entity_name\": \"Alice\", \"entity_type\": \"Person\", \"entity_description\": \"A person of interest.\", \"relationship\": {{\"source_entity\": \"Alice\", \"target_entity\": \"Bob\", \"description\": \"Knows\"}}}}]"
        "Return just the list, so that we can parse it."
        "No bullet list or asterisks needed."
        "It must be a unique list, do NOT separate entities and relationships in different lists."
    )

    # Loop through gleaning rounds
    for round_num in range(gleaning_rounds):
        prompt = base_prompt.format(chunk_text=chunk_text)

        # Send prompt to the selected LLM
        response = llm.invoke([HumanMessage(content=prompt)])

        # Parse the LLM response
        new_elements = response.content
        new_elements = new_elements.replace('```json','').replace('```','')

        # Assume the response is already in JSON format
        try:
            new_elements = eval(new_elements)  # Convert string to list of dicts
        except Exception as e:
            print(f"Error parsing LLM output: {e}")
            not_parsed.extend(new_elements)
            new_elements = []

        # Add new elements to the result
        extracted_elements.extend(new_elements)

        # Check if gleaning is needed (e.g., ask LLM if entities were missed)
        if round_num < gleaning_rounds - 1:
            print("Asking for validation...")
            validation_prompt = (
                "Were any entities or relationships missed in the previous extraction? "
                "Answer 'Yes' or 'No'."
            )
            validation_response = llm.invoke(
                [ HumanMessage(content=validation_prompt)], 
                chat_history=
                [HumanMessage(content=prompt), SystemMessage(content=new_elements)]
            )

            # If LLM says 'No', break early
            if 'No' in validation_response.content:
                break

    # Return all extracted elements
    return extracted_elements, not_parsed


In [10]:
summarization_prompt = """
You are an expert in text summarization and knowledge graph construction. I will provide you with:

1. A list of initial nodes, each containing:
   - A unique identifier
   - A textual description extracted from the source
   - Additional properties (optional)

2. A list of relationships between these nodes (e.g., an entity "Einstein" related to another entity "Relativity").

Your task is to:
- Identify when multiple nodes actually refer to the same entity or concept.
- Generate *summarized nodes* by consolidating their textual descriptions and removing duplicates or near-duplicates.
- Maintain references to each node's original ID within your summarized node.
- Create new relationships among these summarized nodes that reflect the original relationships, but merged and simplified where appropriate.

**Important requirements and format details:**
1. Each summarized node should have:
   - A `summary` field with the merged description.
   - A list of `original_ids` that were merged into this new summary node.
   - Any relevant `type` or `label` (e.g., Person, Theory, Location) if it can be inferred from the text.
   - (Optional) A short list of `keywords` extracted from the descriptions.

2. Each relationship should:
   - Include `source` and `target` references to the new summarized nodes.
   - Provide a `relation_type` (e.g., "INVENTED", "WORKS_ON", "LOCATED_IN", etc.).
   - Have a `weight` or `relevance_score` if it can be inferred (e.g., frequency or importance).
   - (Optional) Include an `original_relationships` list indicating which original relationships were merged.

3. Return the final data in **JSON** format, containing two top-level keys: `summarized_nodes` and `summarized_relationships`.

4. Be concise but ensure the summaries and relationships accurately capture the original meaning.

---

### **Here is the initial Nodes and Relationships**:

{initial_data}

---

### **Instructions to the LLM**:
1. **Identify duplicates or near-duplicates** (e.g., "Albert Einstein" and "Einstein" might refer to the same entity).
2. **Create a new summarized node** that merges the descriptions of "N1" and "N2" if they represent the same entity (in this case, Albert Einstein).
3. **Consolidate relationships** so that if multiple original relationships lead to the same concept, you unify them into a single relationship with an updated weight (e.g., sum or average of the original).
4. Provide your final answer in the following **JSON** structure:

```json
{{
  "summarized_nodes": [
    {{
      "title": "NewTitle1",
      "summary": "Your merged summary text here...",
      "original_ids": ["ExampleID1", "ExampleID2", ...],
      "type": "Person",
      "keywords": ["Einstein", "relativity", "physics"]
    }},
    {{
      "title": "NewTitle2",
      "summary": "Your summary text here...",
      "original_ids": ["ExampleID3", "ExampleID4"],
      "type": "Theory",
      "keywords": ["relativity", "physics"]
    }}
  ],
  "summarized_relationships": [
    {{
      "source": "NewTitle1",
      "target": "NewTitle2",
      "relation_type": "DEVELOPED_OR_ASSOCIATED_WITH",
      "weight": 5,
      "original_relationships": ["N1->N3(DEVELOPED)", "N2->N3(ASSOCIATED_WITH)"]
    }}
  ]
}}
```

Please **only** output valid JSON in the format described above, without additional commentary so that we can parse it correctly. 
Make sure to capture the essence of each original node and relationship in your summarized version.
"""

In [11]:
##################################################
# 2.3 ELEMENT INSTANCES → ELEMENT SUMMARIES
##################################################

def summarize_element_instances(
    element_instances: List[Dict[str, Any]], 
    USE_OPENAI: bool = True,
    local_llm: str = "llama3.1") -> str:
    """
    (Step 2.3) Summarize extracted nodes/relationships into a single descriptive block of text
    for each chunk. This is an additional LLM-based summarization step, forming "element summaries."
    """
    # Select the LLM to use
    if USE_OPENAI:
        llm = ChatOpenAI(model=openai_model, temperature=0.0)
    else:
        llm = ChatOllama(model=local_llm, temperature=0)

    prompt = summarization_prompt.format(initial_data=element_instances)
    response = llm.invoke([HumanMessage(content=prompt)])
    response = response.content

    try:
        response = response.replace('```json','').replace('```','')
        response = eval(response)
    except Exception as e:
        print(f"Error parsing LLM output: {e}")
    
    return response

In [12]:
def store_element_summary_in_graph(tx, data: Dict[str, Any], doc_id: str):
    """
    Load the summarized graph data into Neo4j.
    """
    
    # Creazione dei nodi
    for node in data["summarized_nodes"]:
        query_create_node = """
        CREATE (n:SummarizedNode {
            title: $title,
            summary: $summary,
            original_ids: $original_ids,
            type: $type,
            keywords: $keywords,
            doc_id : $doc_id
        })
        """
        tx.run(
            query_create_node,
            title=node.get("title"),
            summary=node.get("summary"),
            original_ids=node.get("original_ids"),
            type=node.get("type"),
            keywords=node.get("keywords"),
            doc_id=doc_id
        )

    # Creazione delle relazioni
    for rel in data["summarized_relationships"]:
        query_create_rel = f"""
        MATCH (source:SummarizedNode {{title: $source_id}})
        MATCH (target:SummarizedNode {{title: $target_id}})
        CREATE (source)-[:RELATIONSHIP_TYPE {{
            type : $relation_type,
            weight: $weight,
            original_relationships: $original_rels
        }}]->(target)
        """
        tx.run(
            query_create_rel,
            source_id=rel["source"],
            target_id=rel["target"],
            relation_type=rel["relation_type"],
            weight=rel.get("weight", 1),  # default=1 if not provided
            original_rels=rel.get("original_relationships", [])
        )

In [43]:
##################################################
# 2.4 ELEMENT SUMMARIES → GRAPH COMMUNITIES
##################################################

class CommunityDetection:
    def __init__(self, driver: Any = None, uri: str = '', user: str = '', password: str = ''):
        """
        Initialize the CommunityDetection class with a Neo4j driver or connection details.

        Args:
            driver (Any): An existing Neo4j driver instance. If provided, `uri`, `user`, and `password` are ignored.
            uri (str): The URI for the Neo4j database.
            user (str): The username for the Neo4j database.
            password (str): The password for the Neo4j database.
        """
        # If a driver is provided, use it; otherwise, create a new driver instance
        self.driver = driver or GraphDatabase.driver(uri, auth=(user, password))
        self.graph_name = 'summarizedGraph'

    def close(self):
        self.driver.close()

    def project_graph(self, relationship_weight_property: str = "weight") -> None:
        """
        Projects the graph into memory for analysis.
        """
        with self.driver.session() as session:
            session.run(
                f"""
                CALL gds.graph.project(
                    '{self.graph_name}',
                    'SummarizedNode',
                    {{ RELATIONSHIP_TYPE: {{ orientation: 'UNDIRECTED', properties: ['{relationship_weight_property}'] }} }}
                )
                """
            )


    def set_communities(self, relationship_weight_property: str = "weight") -> List[Dict[str, Any]]:
        """
        Sets the community IDs directly into the graph database.
        """
        with self.driver.session() as session:
            result = session.run(
                f"""
                CALL gds.leiden.stream('{self.graph_name}', {{ relationshipWeightProperty: '{relationship_weight_property}' }})
                YIELD nodeId, communityId
                SET gds.util.asNode(nodeId).communityId = communityId
                """
            )

            return True
        
        return False
        
    def retrieve_communities(self) -> List[Dict[str, Any]]:
        """
        Retrieve the community assignments from the graph.
        """
        with self.driver.session() as session:
            result = session.run(
                f"""
                MATCH (n:SummarizedNode)
                RETURN
                    n.communityId AS communityId,
                    n.title AS nodeTitle,      
                    n.summary AS nodeSummary,
                    n.keywords AS nodeKeywords 
                """
            )

            communities = {}
            for record in result:
                community_id = record["communityId"]
                node_title = record["nodeTitle"]
                node_summary = record["nodeSummary"]
                node_keywords = record["nodeKeywords"]
                
                communities[community_id] = communities.get(community_id, []) + [{
                    "title": node_title,
                    "summary": node_summary,
                    "keywords": node_keywords
                }]


            return communities

    def drop_graph(self) -> None:
        """
        Drops the graph from memory.
        """
        with self.driver.session() as session:
            session.run(f"CALL gds.graph.drop('{self.graph_name}') YIELD graphName")


In [14]:
##################################################
# 2.5 GRAPH COMMUNITIES → COMMUNITY SUMMARIES
##################################################

def summarize_communities():
    """
    (Step 2.5) Summarize each community (or sub-community in a hierarchical approach).
    - Gather all element summaries (nodes, edges, covariates) in that community.
    - Summarize them, potentially chunking if they don't fit in an LLM context window.
    """
    # Placeholder logic
    print("[Community Summaries] Placeholder: gather summaries and do hierarchical summarization.")

In [None]:
import openai

# Function to generate a summary for each community
def generate_community_summary(community_nodes, api_key, model="gpt-4"):
    """
    Generates summaries for communities using OpenAI's GPT model.

    Args:
        community_nodes (dict): A dictionary where keys are community IDs and values are lists of node titles.
        api_key (str): OpenAI API key.
        model (str): Model name to use (default is 'gpt-4').

    Returns:
        dict: Summaries for each community.
    """

    # Initialize the OpenAI API
    openai.api_key = api_key

    summaries = {}

    for community_id, nodes in community_nodes.items():
        # Create a prompt with node information
        prompt = f"Summarize the main topics and themes for the following nodes in community {community_id}:\n"
        prompt += "\n".join(nodes)

        try:
            # Call the OpenAI API
            response = openai.ChatCompletion.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are an expert data analyst."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=200
            )

            # Extract and store the summary
            summaries[community_id] = response["choices"][0]["message"]["content"].strip()

        except Exception as e:
            summaries[community_id] = f"Error generating summary: {str(e)}"

    return summaries


# Example usage
if __name__ == "__main__":
    # Replace with your OpenAI API key
    api_key = "your-openai-api-key"

    # Example community nodes
    community_nodes = {
        1: ["Node A", "Node B", "Node C"],
        2: ["Node D", "Node E"],
        3: ["Node F", "Node G", "Node H"]
    }

    summaries = generate_community_summary(community_nodes, api_key)

    for community_id, summary in summaries.items():
        print(f"Community {community_id}: {summary}")


// Step 1: Create the Graph in GDS
CALL gds.graph.project(
    'summarizedGraph',
    'SummarizedNode',
    {
        RELATIONSHIP_TYPE: {
            orientation: 'UNDIRECTED',
            properties: 'weight'
        }
    }
);

// Step 2: Run the Leiden algorithm
CALL gds.leiden.stream('summarizedGraph', { relationshipWeightProperty: 'weight' })
YIELD nodeId, communityId
RETURN gds.util.asNode(nodeId).title AS nodeTitle, communityId;

// Step 3: Drop the graph from memory
CALL gds.graph.drop('summarizedGraph');


In [15]:
##################################################
# 2.6 COMMUNITY SUMMARIES → COMMUNITY ANSWERS → GLOBAL ANSWER
##################################################

def answer_query_from_communities(user_query: str) -> str:
    """
    (Step 2.6) Use the hierarchical community summaries to answer user queries globally.
    - In an actual implementation, you'd fetch the relevant community summaries, chunk them,
      run partial QA on each chunk, rank answers by helpfulness, and then produce a final answer.
    - Below is a simplified approach that just returns a single, direct LLM-based QA.
    """
    # Placeholder logic
    # If you have multiple community summaries, you'd do partial QA in parallel, rank by
    # self-reported "helpfulness" (0-100), then combine or reduce them into a global answer.

    return f"Global answer to '{user_query}' (placeholder)."

In [16]:
##################################################
# 5) INGEST PDF -> STORE IN GRAPH (Putting Steps 2.1 and 2.2+ in context)
##################################################

# Define parallelized function for processing a single chunk
def process_chunk_wrapper(i, chunk_text_str):
    # Step 5: Extract element instances
    #print(f"Extracting element instances from chunk {i}...")
    elements_instance, not_parsed_instance = extract_element_instances_from_chunk(chunk_text_str, USE_OPENAI=USE_OPENAI)
    #print(f"Extracted {len(elements_instance)} elements from chunk {i}\n\n")
    return elements_instance, not_parsed_instance


def ingest_pdf_into_graph(pdf_path: str, doc_id: str, embed_opt: bool = False, BATCH_SIZE: int = 25):
    """
    1) Parse PDF into raw text.
    2) Chunk it (Step 2.1).
    3) Generate embeddings for each chunk.
    4) Store chunk nodes in Neo4j.
    5) For each chunk, call LLM to extract element instances (Step 2.2).
    6) Summarize them into a single descriptive block (Step 2.3).
    7) Optionally store the block in Neo4j for further community detection.
    """
    # Step 1: Parse PDF
    print(f"Parsing PDF at {pdf_path}...")
    raw_text = parse_pdf(pdf_path)
    print(f"Extracted {len(raw_text)} characters from {pdf_path} \n\n")

    # Step 1.1: Retrieve metadata
    print(f"Retrieving metadata for {doc_id}...")
    metadata = find_metadata(doc_id)
    print(f"Metadata: {metadata}\n\n")

    # Step 2: Chunk the text (default chunk_size=600 for improved recall)
    chunks = chunk_text(raw_text)
    print(f"Chunked {len(chunks)} segments from {pdf_path} \n\n")

    # Step 3: Embeddings
    if embed_opt:
        print("Generating embeddings for each chunk...")
        # if USE_OPENAI:
        #     # Implement an OpenAI embedding function if desired
        #     raise NotImplementedError("OpenAI embeddings not implemented here.")
        # else:
        embed_fn = get_hf_embedding_function()

        add_documents_to_milvus([
            {
                "text": chunk, 
                "metadata": {
                    "doc_id": doc_id, 
                    "chunk": i, 
                    **metadata
                    }
            } for i, chunk in enumerate(chunks)], embed_fn)
        print(f"Stored {len(chunks)} chunks in Milvus under Document {doc_id} \n\n")

    # Step 4: Store chunk nodes in Neo4j
    # Process chunks in parallel with batch size of 25
    for lim in tqdm.tqdm(range(0, len(chunks), BATCH_SIZE), desc="Processing chunks in parallel"):
        extracted_elements = []
        not_parsed_elements = []

        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit tasks
            futures = [
                executor.submit(process_chunk_wrapper, i, chunk_text_str) 
                for i, chunk_text_str in enumerate(chunks[lim:lim+BATCH_SIZE])
            ]
            
            # Process results as they complete
            for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing chunks"):
                elements_instance, not_parsed_instance = future.result()
                extracted_elements.extend(elements_instance)
                not_parsed_elements.extend(not_parsed_instance)

        print("Parallel processing completed.")
            
        # Step 6: Summarize them (element-level)
        print("Summarizing element instances...")
        element_summary = summarize_element_instances(extracted_elements, USE_OPENAI=USE_OPENAI)

        print(f"Summarized {len(element_summary['summarized_nodes'])} nodes and {len(element_summary['summarized_relationships'])} relationships\n\n")

        print("Storing backup files...")
        # Backup element_summary, extracted_elements and not_parsed_elements
        with open(f'backup_extraction_nodes/element_summary_{doc_id}_{lim}.json', 'w') as f:
            f.write(str(element_summary))
        
        with open(f'backup_extraction_nodes/extracted_elements_{doc_id}_{lim}.json', 'w') as f:
            f.write(str(extracted_elements))
        
        with open(f'backup_extraction_nodes/not_parsed_elements_{doc_id}_{lim}.json', 'w') as f:
            f.write(str(not_parsed_elements))
        
        print("Backup files stored.\n\n")

        # Step 7: Store the summary
        print("Storing element summary in Neo4j...")
        with driver.session() as session:
            # Step 7: Store the summary
            session.execute_write(store_element_summary_in_graph, element_summary, doc_id)
    
    print(f"Ingested {len(chunks)} chunks from {pdf_path} into Neo4j under Document {doc_id}")
    return element_summary, extracted_elements, not_parsed_elements

    

In [17]:
##################################################
# QA / RAG PIPELINE (Simplified)
##################################################

def perform_qa_with_graph(user_query: str) -> str:
    """
    A simplified RAG approach:
    1) Retrieve relevant chunks from the Neo4j graph (naive text search).
    2) Build a context from those chunks.
    3) Use either an open model or an OpenAI model for generative answer.

    NOTE: This doesn't incorporate full community-based summarization from 2.6.
    For a more complete approach, see `answer_query_from_communities()`.
    """
    with driver.session() as session:
        candidate_chunks = session.read_transaction(retrieve_relevant_chunks, user_query)

    # Build the context
    context = "\n\n".join([c["text"] for c in candidate_chunks])

    # Use a HuggingFace or OpenAI model for generation
    if USE_OPENAI:
        # If using OpenAI ChatCompletion:
        # openai.api_key = OPENAI_API_KEY
        # response = openai.ChatCompletion.create(
        #     model="gpt-3.5-turbo",
        #     messages=[
        #         {"role": "system", "content": "You are a helpful assistant."},
        #         {"role": "user", "content": f"Context: {context}\n\nQuestion: {user_query}"}
        #     ]
        # )
        # answer = response["choices"][0]["message"]["content"]
        raise NotImplementedError("OpenAI ChatCompletion usage not fully implemented here.")
    else:
        # Example with a local HF pipeline
        qa_pipeline = pipeline("text-generation", model="bigscience/bloom-560m")
        prompt = f"Context: {context}\nQuestion: {user_query}\nAnswer:"
        answer_list = qa_pipeline(prompt, max_new_tokens=100, do_sample=True)
        if answer_list:
            answer = answer_list[0]["generated_text"].split("Answer:")[-1].strip()
        else:
            answer = "No answer generated."

    return answer

In [239]:
#############################################
# MAIN EXECUTION EXAMPLE
#############################################

In [198]:
# 1) Ingest an arXiv PDF (Steps 2.1–2.3)
pdf_path = "data/docs/0704.2547.pdf"  # Replace with the path to your local arXiv PDF
doc_id = "0704.2547"           # Arbitrary doc ID for grouping in Neo4j
element_summary, extracted_elements, not_parsed_elements=ingest_pdf_into_graph(pdf_path, doc_id)

Parsing PDF at data/docs/0704.2547.pdf...
Extracted 148193 characters from data/docs/0704.2547.pdf 


Retrieving metadata for 0704.2547...
Metadata: {'title': 'Inferring DNA sequences from mechanical unzipping data: the large-bandwidth case', 'summary': 'The complementary strands of DNA molecules can be separated when stretched\napart by a force; the unzipping signal is correlated to the base content of the\nsequence but is affected by thermal and instrumental noise. We consider here\nthe ideal case where opening events are known to a very good time resolution\n(very large bandwidth), and study how the sequence can be reconstructed from\nthe unzipping data. Our approach relies on the use of statistical Bayesian\ninference and of Viterbi decoding algorithm. Performances are studied\nnumerically on Monte Carlo generated data, and analytically. We show how\nmultiple unzippings of the same molecule may be exploited to improve the\nquality of the prediction, and calculate analytically the n

Processing chunks: 100%|██████████| 25/25 [00:43<00:00,  1.76s/it]s]


Parallel processing completed.
Summarizing element instances...
Summarized 11 nodes and 7 relationships


Storing backup files...
Backup files stored.


Storing element summary in Neo4j...


Processing chunks: 100%|██████████| 25/25 [00:25<00:00,  1.04s/it]67.30s/it]


Parallel processing completed.
Summarizing element instances...
Summarized 15 nodes and 13 relationships


Storing backup files...
Backup files stored.


Storing element summary in Neo4j...


Processing chunks: 100%|██████████| 25/25 [00:42<00:00,  1.70s/it]72.03s/it]


Parallel processing completed.
Summarizing element instances...
Summarized 12 nodes and 6 relationships


Storing backup files...
Backup files stored.


Storing element summary in Neo4j...


Processing chunks: 100%|██████████| 25/25 [00:44<00:00,  1.77s/it]70.85s/it]


Parallel processing completed.
Summarizing element instances...
Summarized 11 nodes and 7 relationships


Storing backup files...
Backup files stored.


Storing element summary in Neo4j...


Processing chunks: 100%|██████████| 25/25 [00:29<00:00,  1.18s/it]68.25s/it]


Parallel processing completed.
Summarizing element instances...
Summarized 8 nodes and 5 relationships


Storing backup files...
Backup files stored.


Storing element summary in Neo4j...


Processing chunks: 100%|██████████| 9/9 [00:20<00:00,  2.24s/it], 60.01s/it]


Parallel processing completed.
Summarizing element instances...
Summarized 11 nodes and 9 relationships


Storing backup files...
Backup files stored.


Storing element summary in Neo4j...


Processing chunks in parallel: 100%|██████████| 6/6 [06:13<00:00, 62.18s/it]

Ingested 284 chunks from data/docs/0704.2547.pdf into Neo4j under Document 0704.2547





In [None]:
raise NotImplementedError("The following steps are not yet implemented.")

NotImplementedError: The following steps are not yet implemented.

In [44]:
# 2) Community detection & summarization (Steps 2.4–2.5)

# Initialize the community detection class
detector = CommunityDetection(driver)

# Project the graph for community detection
detector.project_graph()

# Set community IDs in the graph
detector.set_communities()

# Retrieve the communities
communities = detector.retrieve_communities()

# Drop the graph from memory
detector.drop_graph()

In [32]:
summarize_communities(communities)

[Community Summaries] Placeholder: gather summaries and do hierarchical summarization.


In [22]:
# 3) Ask a question (Step 2.6 simplified vs. full approach)
user_query = "What are the main contributions of the paper?"
# Simple QA (naive RAG):
answer = perform_qa_with_graph(user_query)
print(f"User's question: {user_query}")
print(f"Answer: {answer}")

  candidate_chunks = session.read_transaction(retrieve_relevant_chunks, user_query)


NotImplementedError: OpenAI ChatCompletion usage not fully implemented here.

In [None]:
# Alternatively, full approach using community-based QA:
# global_answer = answer_query_from_communities(user_query)
# print(f"Global Answer: {global_answer}")