## Intelligent Video Processing: Knowledge Graph Enrichment
### Part 2: Knowledge Graph creation and Enrichment

Importing and updating required libraries

In [None]:
%pip install --upgrade --quiet  langchain langchain-community langchain-google-genai langchain-experimental neo4j tiktoken yfiles_jupyter_graphs

Accessing data from Google Drive

In [None]:
# Mount Google Drive at /content/drive so the transcript folders used below
# (e.g. /content/drive/MyDrive/10_sample) can be loaded.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
pip install json-repair

In [None]:
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

In [None]:
import os

from google.colab import userdata

# Never hardcode API keys in a notebook: they leak through sharing and commits.
# Use the environment variable when set, otherwise Colab's Secrets panel
# (secret name "GEMINI_API_KEY"); userdata.get raises if the secret is missing.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or userdata.get("GEMINI_API_KEY")


In [None]:
from typing import Tuple, List, Optional

In [None]:
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser

In [None]:
from langchain_core.runnables import ConfigurableField

In [None]:
from yfiles_jupyter_graphs import GraphWidget
from neo4j import GraphDatabase


In [None]:
import os

In [None]:
# Enable Colab's custom widget manager, presumably needed for the yFiles
# GraphWidget to render. Outside Colab the import fails and there is nothing to do.
try:
  import google.colab
  from google.colab import output
except ImportError:
  pass  # not running in Colab
else:
  # Errors here are real problems and are no longer swallowed by a bare except.
  output.enable_custom_widget_manager()

In [None]:
from langchain_community.vectorstores import Neo4jVector

Required credentials for Neo4j AuraDB

In [None]:

import os

# Neo4j AuraDB connection settings, read from the environment instead of being
# hardcoded in the notebook. They default to "" when unset, as before.
NEO4J_URI = os.environ.get("NEO4J_URI", "")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "")

In [None]:
# Publish the credentials as environment variables: showGraph() reads the
# NEO4J_* values from os.environ, and Neo4jGraph() is constructed without
# arguments, so it presumably picks them up from here too.
os.environ.update({
    "GEMINI_API_KEY": GEMINI_API_KEY,
    "NEO4J_URI": NEO4J_URI,
    "NEO4J_USERNAME": NEO4J_USERNAME,
    "NEO4J_PASSWORD": NEO4J_PASSWORD,
})

In [None]:
from langchain_community.graphs import Neo4jGraph

In [None]:
# Shared Neo4j connection used by every cell below. No arguments are passed, so
# the connection settings presumably come from the NEO4J_* environment variables.
graph = Neo4jGraph()

Using the LangChain package for the graph transformer.

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_experimental.graph_transformers import LLMGraphTransformer

# Gemini chat model, reused by the graph transformer, the entity chain and the
# final QA chain. The API key is passed explicitly from GEMINI_API_KEY.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=GEMINI_API_KEY)
# Turns text chunks into GraphDocuments (nodes + relationships) via the LLM.
llm_transformer = LLMGraphTransformer(llm=llm)
# Alternative models tried:
# gemini-2.5-pro
# gemini-2.0-flash

Creating the knowledge graph

In [None]:
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import TokenTextSplitter
import time
import math

# Load all .txt transcripts from Drive.
INPUT_DIR = "/content/drive/MyDrive/10_sample"
loader = DirectoryLoader(INPUT_DIR, glob="*.txt", loader_cls=TextLoader)
raw_documents = loader.load()

# ~512-token chunks with a small overlap so facts spanning a boundary survive.
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)

# Batching setup. The pause is presumably there to stay under the Gemini API
# rate limit between batches.
batch_size = 10
pause_duration = 60  # seconds between batches
total_batches = math.ceil(len(raw_documents) / batch_size)

# Batch processing loop. `documents` and `graph_documents` intentionally stay
# defined afterwards (holding the last batch) for the inspection cells below.
for i in range(0, len(raw_documents), batch_size):
    batch_number = i // batch_size + 1
    print(f"\n🔄 Processing batch {batch_number} / {total_batches}")
    batch_docs = raw_documents[i:i + batch_size]

    # Step 1: Chunking
    documents = text_splitter.split_documents(batch_docs)

    # Step 2: LLM → Graph
    graph_documents = llm_transformer.convert_to_graph_documents(documents)

    # Step 3: Push to Neo4j. baseEntityLabel adds the __Entity__ label used by the
    # full-text index; include_source also stores the source chunks.
    graph.add_graph_documents(
        graph_documents,
        baseEntityLabel=True,
        include_source=True
    )

    # Step 4: Pause only when another batch follows; sleeping after the last
    # batch just wastes time.
    if batch_number < total_batches:
        print(f"✅ Batch {batch_number} completed. Sleeping for {pause_duration} seconds...\n")
        time.sleep(pause_duration)
    else:
        print(f"✅ Batch {batch_number} completed.\n")




In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI


In [None]:
from langchain_experimental.graph_transformers import LLMGraphTransformer
# NOTE(review): redundant; this recreates the transformer already built next to
# `llm` above. Harmless, but it can be removed.
llm_transformer = LLMGraphTransformer(llm=llm)

In [None]:
# The batch loop above already extracted (and ingested) this last batch of
# `documents`. Calling the LLM again spends API quota and may return a different
# extraction, which would then be merged as extra/divergent triples. Only
# extract when no result exists yet (`del graph_documents` to force a re-run).
if "graph_documents" not in globals():
    graph_documents = llm_transformer.convert_to_graph_documents(documents)

In [None]:
# Inspect the extracted nodes/relationships for the last processed batch.
graph_documents

In [None]:
# Write `graph_documents` (the last batch) to Neo4j.
# NOTE(review): the batch loop already wrote this batch; re-running this is only
# needed if that write failed. Whether a second write is a no-op depends on
# add_graph_documents merging existing nodes/relationships; TODO confirm.
graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True
)

In [None]:
# Default query for showGraph(): draws every relationship except MENTIONS
# (presumably the chunk-to-entity links added by include_source=True),
# capped at 5000 rows.
default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 5000"

In [None]:
from yfiles_jupyter_graphs import GraphWidget
from neo4j import GraphDatabase

In [None]:
# Enable Colab's custom widget manager, presumably needed for the yFiles
# GraphWidget to render. Outside Colab the import fails and there is nothing to do.
try:
  import google.colab
  from google.colab import output
except ImportError:
  pass  # not running in Colab
else:
  # Errors here are real problems and are no longer swallowed by a bare except.
  output.enable_custom_widget_manager()

In [None]:
def showGraph(cypher: str = default_cypher):
    """Render the result of a Cypher query as an interactive yFiles graph widget.

    Connection settings are read from the NEO4J_URI, NEO4J_USERNAME and
    NEO4J_PASSWORD environment variables.

    Args:
        cypher: Query whose returned nodes/relationships are drawn. Defaults to
            `default_cypher`.

    Returns:
        The GraphWidget that was displayed.
    """
    # Driver and session are context-managed so each call releases its
    # connections instead of leaking a new driver every time the graph is drawn.
    with GraphDatabase.driver(
        uri=os.environ["NEO4J_URI"],
        auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"]),
    ) as driver:
        with driver.session() as session:
            # .graph() consumes the result into an in-memory Graph, which stays
            # usable after the session is closed.
            result_graph = session.run(cypher).graph()
    widget = GraphWidget(graph=result_graph)
    widget.node_label_mapping = 'id'
    display(widget)
    return widget

Showing the created knowledge graph

In [None]:
# Draw the graph built so far using the default (non-MENTIONS) query.
showGraph()

In [None]:
from typing import Tuple, List, Optional

In [None]:
from langchain_community.vectorstores import Neo4jVector

In [None]:
from langchain_google_genai import GoogleGenerativeAIEmbeddings
# Hybrid-search index over the `text` property of `Document` nodes, with vectors
# kept in each node's `embedding` property.
# NOTE(review): presumably only nodes existing when this cell runs get
# embeddings; re-run it after adding documents (e.g. after enrichment). TODO confirm.
vector_index = Neo4jVector.from_existing_graph(
    GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GEMINI_API_KEY),
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding"
)

# Alternative embedding models:
# gemini-embedding-exp-03-07
# models/embedding-001

In [None]:
# Full-text index named `entity` over __Entity__.id; structured_retriever looks
# entities up in it via db.index.fulltext.queryNodes('entity', ...).
graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

In [None]:
from langchain_core.pydantic_v1 import BaseModel, Field
# Extract entities from text
# Output schema for `llm.with_structured_output(Entities)`: `.names` holds the
# extracted entity strings.
# NOTE: the class docstring and the Field description are presumably sent to the
# LLM as part of the schema; treat them as prompt text, not just documentation.
# NOTE(review): langchain_core.pydantic_v1 is deprecated in recent langchain-core
# releases; consider switching to pydantic's own BaseModel/Field.
class Entities(BaseModel):
    """Identifying information about entities."""

    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that "
        "appear in the text",
    )


In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

In [None]:
# Entity-extraction prompt: fixed system instruction + the user's {question}.
# NOTE: `prompt` is rebound to the QA template further down; re-run this cell
# before rebuilding `entity_chain` out of order.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting organization and person entities from the text."),
    ("human", "Use the given format to extract information from the following input: {question}"),
])

In [None]:
# prompt -> LLM constrained to the `Entities` schema; invoke(...) yields an object
# whose `.names` lists the extracted entities.
entity_chain = prompt | llm.with_structured_output(Entities)

In [None]:
# Smoke test: list the entity names the chain extracts from a sample question.
entity_chain.invoke({"question": "What is Chain of Thought (CoT) Reasoning and which papers is it mentioned in?"}).names

In [None]:
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars

In [None]:
def generate_full_text_query(input: str) -> str:
    """Build a fuzzy Lucene query that requires every word of `input`.

    Lucene special characters are stripped first; each remaining word gets a
    `~2` fuzziness suffix (edit distance up to 2) and words are joined with AND.

    Args:
        input: Free text, typically one extracted entity name.

    Returns:
        The query string, e.g. "alpha~2 AND beta~2", or "" when `input` contains
        no words after stripping (previously this raised IndexError).
    """
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        return ""
    return " AND ".join(f"{word}~2" for word in words)


Structured retriever from the graph database

In [None]:
# Fulltext index query
def structured_retriever(question: str) -> str:
    """Collect graph neighbourhoods for the entities mentioned in `question`.

    Entities are extracted with `entity_chain`, looked up fuzzily in the
    `entity` full-text index (top 2 hits each), and every non-MENTIONS
    relationship around a hit is rendered as "source - TYPE -> target"
    (at most 50 lines per entity).

    Args:
        question: User question in natural language.

    Returns:
        Newline-separated relationship lines across all entities; "" when
        nothing matches.
    """
    lines = []
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        query = generate_full_text_query(entity)
        if not query:
            # The entity had no searchable words; an empty Lucene query would
            # likely be rejected by the full-text procedure.
            continue
        # Plain read-only subquery. `CALL { ... } IN TRANSACTIONS` is meant for
        # batched writes and is only allowed in implicit transactions.
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": query},
        )
        lines.extend(el['output'] for el in response)
    # Join once across all entities so consecutive entities' lines are separated
    # by a newline instead of being glued together.
    return "\n".join(lines)

In [None]:
# Smoke test: graph-only retrieval for one question.
print(structured_retriever("what is INFERENCE-TIME-COMPUTE?"))

Hybrid retriever: vector + structured retriever

In [None]:
def retriever(question: str):
    """Build the answer context for `question` from graph and vector retrieval.

    Combines the relationship lines from `structured_retriever` with the text of
    the chunks returned by `vector_index.similarity_search`. Each chunk is placed
    on its own line, prefixed with "#Document ".

    Args:
        question: Standalone question to retrieve context for.

    Returns:
        The context string inserted into the answer prompt.
    """
    print(f"Search query: {question}")
    structured_data = structured_retriever(question)
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    # Label every chunk (the first one included) and keep chunks on separate
    # lines; a bare "#Document ".join glued each chunk onto the next label.
    documents_block = "\n".join(f"#Document {text}" for text in unstructured_data)
    final_data = f"""Structured data:
{structured_data}
Unstructured data:
{documents_block}
    """
    return final_data

In [None]:
# Rewrites a follow-up question into a standalone one using {chat_history};
# used by the chat-history branch of _search_query.
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

In [None]:
# Prompt with {chat_history} and {question} variables for question condensing.
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

In [None]:
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer

In [None]:
# Question pre-processing for the QA chain: with chat history, condense the
# follow-up into a standalone question; otherwise pass the question through.
_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        # Pass the key explicitly. Otherwise the model falls back to the
        # GOOGLE_API_KEY environment variable, which this notebook never sets
        # (it exports GEMINI_API_KEY).
        | ChatGoogleGenerativeAI(
            model="gemini-2.5-pro", temperature=0, google_api_key=GEMINI_API_KEY
        )
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(lambda x : x["question"]),
)

In [None]:
# Answer prompt: the model must answer only from the retrieved {context}.
template = """Answer the question based only on the following context:
{context}

Question: {question}
Use natural language and be concise.
Answer:"""

In [None]:
# NOTE: rebinds `prompt` (previously the entity-extraction prompt) to the QA
# template; `entity_chain` keeps the object it was built with.
prompt = ChatPromptTemplate.from_template(template)

In [None]:
# Full RAG chain. The (possibly condensed) question drives hybrid retrieval for
# {context}; the answer prompt receives the user's question text.
chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            # Extract the question string. A bare RunnablePassthrough() would put
            # the whole input dict ("{'question': ...}") into the prompt.
            "question": RunnableLambda(lambda x: x["question"]),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)

Q&A validation before the enrichment process

In [None]:
# Baseline answers before enrichment; several of these questions are asked
# again after enrichment for comparison.
chain.invoke({"question": "What is the time stamp for Problem of Faithfulness?"})

In [None]:
chain.invoke({"question": "What is Inference-Time-Compute (ITC) Models?"})

In [None]:
chain.invoke({"question": "what is Adversarial Robustness?"})

In [None]:
chain.invoke({"question": "title of the document which mentions about Adversarial Robustness?"})

In [None]:
chain.invoke({"question": "Is there any use of AI agents or LLMs mentioned in TRADING INFERENCE-TIME COMPUTE FOR ADVERSARIAL ROBUSTNESS?"})

In [None]:
# NOTE(review): "that" refers to the previous question, but no chat_history is
# passed, so the condense branch never runs and the reference stays unresolved.
# Pass chat_history=[(previous_question, previous_answer)] for a real follow-up.
chain.invoke({"question": "what is the usage of AI Agents mentoned in that?"})

In [None]:
chain.invoke({"question": "Is there any relation between Antidistillation Sampling and INFERENCE-TIME COMPUTE FOR ADVERSARIAL ROBUSTNESS?"})

In [None]:
chain.invoke({"question": "What exploration strategies do large language models use in open-ended tasks?"})

In [None]:

chain.invoke({"question": "How does AlphaGeometry2 achieve gold-medalist performance in IMO geometry problems?"})

In [None]:
chain.invoke({"question": "What roles do the crawler and selector play in the PaSa agent system?"})


In [None]:
chain.invoke({"question": "What is the purpose of Continuous Concept Mixing (CoCoMix) in LLM pretraining?"})

In [None]:
chain.invoke({"question": "Which language model performed best in the Little Alchemy 2 exploration study?"})

In [None]:
chain.invoke({"question": "What improvements does AG2 have over AG1 in symbolic engine design?"})

In [None]:
chain.invoke({"question": "Which techniques does AlphaGeometry2 use for domain language expansion in geometry problems?"})

Statistics of the created knowledge graph

In [None]:
def run_and_print(query, description):
    """Run a Cypher `query` against the global `graph` and print each row as a dict.

    Args:
        query: Cypher query to execute.
        description: Heading printed before the rows.
    """
    print(f"\n📊 {description}")
    for row in graph.query(query):
        print(dict(row))


# (query, heading) pairs printed in order by the loop at the end of this cell.
# - Label statistics UNWIND labels(n) instead of reading labels(n)[0]: entity
#   nodes carry the __Entity__ base label (baseEntityLabel=True) in addition to
#   any type label, and the first list element is not reliably the informative one.
# - Per-node rankings return id/labels instead of the whole node, so the
#   `embedding` vectors stored on Document nodes are not dumped to the output.
GRAPH_STAT_QUERIES = [
    (
        "MATCH (n) RETURN count(n) AS total_nodes",
        "1. Total number of nodes",
    ),
    (
        "MATCH ()-[r]->() RETURN count(r) AS total_relationships",
        "2. Total number of relationships",
    ),
    (
        "MATCH (n) UNWIND labels(n) AS label RETURN label, count(*) AS count ORDER BY count DESC",
        "3. Number of nodes per label",
    ),
    (
        "MATCH ()-[r]->() RETURN type(r) AS rel_type, count(*) AS count ORDER BY count DESC",
        "4. Number of relationships per type",
    ),
    (
        """
        MATCH (n)
        WITH n, COUNT { (n)-->() } + COUNT { ()-->(n) } AS totalDegree
        RETURN n.id AS id, labels(n) AS labels, totalDegree
        ORDER BY totalDegree DESC
        LIMIT 5
        """,
        "5. Top 5 most connected nodes (total degree)",
    ),
    (
        """
        MATCH (n)
        WITH n, COUNT { ()-->(n) } AS inDegree
        RETURN n.id AS id, labels(n) AS labels, inDegree
        ORDER BY inDegree DESC
        LIMIT 5
        """,
        "6. Top 5 nodes by incoming relationships",
    ),
    (
        """
        MATCH (n)
        WITH n, COUNT { (n)-->() } AS outDegree
        RETURN n.id AS id, labels(n) AS labels, outDegree
        ORDER BY outDegree DESC
        LIMIT 5
        """,
        "7. Top 5 nodes by outgoing relationships",
    ),
    (
        """
        MATCH (n)
        WITH COUNT { (n)-->() } + COUNT { ()-->(n) } AS degree
        RETURN avg(degree) AS avg_degree
        """,
        "8. Average degree per node",
    ),
    (
        "MATCH ()-[r]->() RETURN DISTINCT type(r) AS relationship_type",
        "9. All distinct relationship types",
    ),
    (
        "MATCH (n) UNWIND labels(n) AS label RETURN DISTINCT label",
        "10. All distinct node labels",
    ),
]

for stat_query, stat_heading in GRAPH_STAT_QUERIES:
    run_and_print(stat_query, stat_heading)


In [None]:
# graph.query("MATCH (n) DETACH DELETE n")


In [None]:
%pip install -U langchain-community

Getting statistics before enrichment

In [None]:

def get_graph_stats(title="Graph Statistics"):
    """Print summary statistics of the current Neo4j graph under a heading.

    Args:
        title: Heading printed above the statistics.
    """
    print(f"\n📊 {title}")
    queries = {
        "Total Nodes": "MATCH (n) RETURN count(n) AS count",
        "Total Relationships": "MATCH ()-[r]->() RETURN count(r) AS count",
        # UNWIND so every label is counted, not whichever one happens to be first.
        "Nodes by Label": "MATCH (n) UNWIND labels(n) AS label RETURN label, count(*) AS count ORDER BY count DESC",
        "Relationships by Type": "MATCH ()-[r]->() RETURN type(r) AS type, count(*) AS count ORDER BY count DESC",
        # Undirected pattern: each relationship counts for both of its endpoints.
        "Top Connected Nodes": """
            MATCH (n)-[r]-()
            RETURN n.id AS node_id, count(r) AS connections
            ORDER BY connections DESC LIMIT 10
        """,
    }
    for desc, query in queries.items():
        print(f"\n🔹 {desc}:")
        for row in graph.query(query):
            print("   ", dict(row))



In [None]:
# Snapshot before enrichment; compare with the "AFTER ENRICHMENT" output later.
get_graph_stats("üì• BEFORE ENRICHMENT")


Method 1: Knowledge graph enrichment with manual (keyword-based) categorization

In [None]:
import time
import math
import difflib
from langchain.text_splitter import TokenTextSplitter
from langchain.document_loaders import TextLoader
from langchain_community.document_loaders import DirectoryLoader
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser

def get_existing_entities_by_type(graph):
    """Fetch entity ids from the graph and bucket them with keyword heuristics.

    Only `__Entity__` nodes are read. Source `Document` nodes also have an `id`,
    but it is not an entity name, and including them polluted both the matching
    candidates and the known-entity list shown to the LLM.

    Args:
        graph: Object with a `query(cypher)` method returning dict-like rows
            (the Neo4jGraph used in this notebook).

    Returns:
        Dict of category -> list of unique entity ids, in query order. Keys:
        'papers', 'authors', 'models', 'datasets', 'concepts', 'methods',
        'other'. The current heuristics never fill 'authors' or 'concepts'.
    """
    result = graph.query("""
        MATCH (n:__Entity__)
        WHERE n.id IS NOT NULL
        RETURN DISTINCT n.id AS id
    """)

    entities_by_type = {
        'papers': [],
        'authors': [],
        'models': [],
        'datasets': [],
        'concepts': [],
        'methods': [],
        'other': [],
    }
    # Checked in order; the first group with a matching keyword wins.
    category_keywords = (
        ('papers', ('et al', 'paper', 'arxiv', 'proceedings')),
        ('models', ('model', 'llm', 'gpt', 'bert', 'llama')),
        ('datasets', ('dataset', 'benchmark', 'gsm', 'imagenet')),
        ('methods', ('training', 'learning', 'optimization')),
    )

    seen = set()
    for row in result:
        entity_id = row.get("id", "")
        # Skip non-string ids and repeats so each entity is listed once.
        if not isinstance(entity_id, str) or entity_id in seen:
            continue
        seen.add(entity_id)

        entity_lower = entity_id.lower()
        category = next(
            (
                name
                for name, keywords in category_keywords
                if any(keyword in entity_lower for keyword in keywords)
            ),
            'other',
        )
        entities_by_type[category].append(entity_id)

    return entities_by_type

def _citation_year(text):
    """Return the first standalone four-digit number in `text` (e.g. a year), or None."""
    import re
    found = re.search(r"(?<!\d)\d{4}(?!\d)", text)
    return found.group(0) if found else None


def smart_entity_matching(new_entity, entities_by_type, threshold=0.9):
    """Map `new_entity` to a near-identical known entity id, if one exists.

    The entity is categorised with the same keyword groups as
    `get_existing_entities_by_type`, and only candidates from that category are
    compared using difflib's similarity ratio (case-sensitive). Citations such
    as "Smith et al. (2020)" are never merged when both names contain different
    four-digit years.

    Args:
        new_entity: Entity id produced by the graph transformer.
        entities_by_type: Category -> list of known entity ids.
        threshold: Minimum difflib ratio (0-1) required for a match.

    Returns:
        The matched known id, or `new_entity` unchanged when nothing matches.
    """
    new_entity_lower = new_entity.lower()

    # Same keyword groups (and order) used when bucketing existing entities, so
    # a new entity is compared against the bucket it would itself be stored in.
    if any(keyword in new_entity_lower for keyword in ['et al', 'paper', 'arxiv', 'proceedings']):
        category = 'papers'
    elif any(keyword in new_entity_lower for keyword in ['model', 'llm', 'gpt', 'bert', 'llama']):
        category = 'models'
    elif any(keyword in new_entity_lower for keyword in ['dataset', 'benchmark', 'gsm', 'imagenet']):
        category = 'datasets'
    elif any(keyword in new_entity_lower for keyword in ['training', 'learning', 'optimization']):
        category = 'methods'
    else:
        category = 'other'
    candidates = entities_by_type.get(category, [])

    matches = difflib.get_close_matches(new_entity, candidates, n=1, cutoff=threshold)
    if not matches:
        return new_entity
    match = matches[0]

    # Citations: different publication years mean different works. Years are
    # found with a regex so tokens such as "(2020)" or "2020," are recognised.
    if 'et al' in new_entity_lower and 'et al' in match.lower():
        new_year = _citation_year(new_entity)
        match_year = _citation_year(match)
        if new_year and match_year and new_year != match_year:
            return new_entity

    return match

def create_enhanced_enrichment_prompt():
    """Build the chat prompt that rewrites a chunk with known entities in mind.

    Template variables:
        known_entities: entity names already in the graph (the caller passes a
            comma-separated string).
        input_text: the chunk text to analyse.

    Returns:
        A ChatPromptTemplate made of one system message and one human message.
    """
    system_message = """You are an expert knowledge graph enrichment assistant.

        Your task is to analyze new research text and identify entities that should be linked to existing knowledge.

        Guidelines:
        1. Preserve exact entity names when they represent the same concept
        2. Be conservative - only link entities if you're confident they're the same
        3. For author citations, match only if the main author and year are identical
        4. For models/datasets, match only if they're clearly the same version
        5. Extract all relevant entities, relationships, and properties

        Known entities in the graph: {known_entities}

        Analyze the following text and extract entities while being mindful of existing ones:"""
    human_message = "{input_text}"
    messages = [("system", system_message), ("human", human_message)]
    return ChatPromptTemplate.from_messages(messages)

def _enrich_chunks(documents, enrichment_chain, known_entities_text):
    """Rewrite each chunk with the enrichment LLM, keeping the original chunk on failure."""
    enriched_documents = []
    for doc in documents:
        try:
            enriched_text = enrichment_chain.invoke({
                "known_entities": known_entities_text,
                "input_text": doc.page_content,
            })
            enriched_documents.append(Document(
                page_content=enriched_text,
                metadata=doc.metadata
            ))
        except Exception as e:
            print(f"⚠️ Error enriching document: {e}")
            enriched_documents.append(doc)  # Use original if enrichment fails
    return enriched_documents


def _snap_to_known_entities(graph_documents, entities_by_type, threshold):
    """Rename nodes and relationship endpoints to matching known ids; return renamed-node count."""
    renamed_count = 0
    for gdoc in graph_documents:
        for node in gdoc.nodes:
            original_id = node.id
            matched_id = smart_entity_matching(original_id, entities_by_type, threshold=threshold)
            if original_id != matched_id:
                node.id = matched_id
                renamed_count += 1
                print(f"🔁 Renamed node '{original_id}' → '{matched_id}'")
        # Relationship endpoints may be separate Node objects, so rename them
        # too; otherwise edges would still point at the un-merged ids.
        for rel in gdoc.relationships:
            rel.source.id = smart_entity_matching(rel.source.id, entities_by_type, threshold=threshold)
            rel.target.id = smart_entity_matching(rel.target.id, entities_by_type, threshold=threshold)
    return renamed_count


def enrich_knowledge_graph_improved(input_path, batch_size=5, pause_duration=10,
                                    match_threshold=0.92, max_known_entities=100):
    """Add new transcripts to the graph, snapping new entities onto known ones.

    For each batch of `.txt` files under `input_path`, this:
    1. chunks the text;
    2. asks the LLM to rewrite each chunk with the known entity names in the prompt;
    3. extracts a graph from the rewritten chunks;
    4. renames nodes that closely match existing entities (`smart_entity_matching`);
    5. writes the result to the global `graph`.

    After each successful batch the known entities are re-read from the graph,
    so later batches can also match entities introduced by earlier ones.

    Args:
        input_path: Directory containing the `.txt` files to add.
        batch_size: Number of files per batch.
        pause_duration: Seconds to sleep between batches (not after the last one).
        match_threshold: difflib cutoff passed to `smart_entity_matching`.
        max_known_entities: How many known entity names are put in the prompt.

    Note:
        The graph is extracted from the LLM-rewritten text, so that text (not the
        raw chunk) is presumably what gets stored as the source document.
    """
    # Load documents
    loader = DirectoryLoader(input_path, glob="*.txt", loader_cls=TextLoader)
    raw_documents = loader.load()
    text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
    total_batches = math.ceil(len(raw_documents) / batch_size)

    # Initialize LLM
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.environ["GEMINI_API_KEY"])
    llm_transformer = LLMGraphTransformer(llm=llm)

    # Create enrichment chain
    enrichment_chain = create_enhanced_enrichment_prompt() | llm | StrOutputParser()

    # Get existing entities by type
    entities_by_type = get_existing_entities_by_type(graph)
    print(f"📊 Found {sum(len(ids) for ids in entities_by_type.values())} existing entities in the graph")

    for i in range(0, len(raw_documents), batch_size):
        batch_number = i // batch_size + 1
        print(f"\n🔄 Enriching batch {batch_number} / {total_batches}")
        documents = text_splitter.split_documents(raw_documents[i:i + batch_size])

        all_known_entities = [entity for ids in entities_by_type.values() for entity in ids]
        # Cap the list to keep the prompt within the model's token budget.
        known_entities_text = ", ".join(all_known_entities[:max_known_entities])
        enriched_documents = _enrich_chunks(documents, enrichment_chain, known_entities_text)

        graph_documents = llm_transformer.convert_to_graph_documents(enriched_documents)
        renamed_count = _snap_to_known_entities(graph_documents, entities_by_type, match_threshold)

        try:
            graph.add_graph_documents(
                graph_documents,
                baseEntityLabel=True,
                include_source=True
            )
            print(f"✅ Batch {batch_number} processed. {renamed_count} entities matched.")
            # Refresh so the next batch can match entities this batch introduced.
            entities_by_type = get_existing_entities_by_type(graph)
        except Exception as e:
            print(f"❌ Error adding batch to graph: {e}")

        # Only pause when another batch follows.
        if batch_number < total_batches:
            print(f"💤 Sleeping {pause_duration} seconds...\n")
            time.sleep(pause_duration)

    print("🎉 Knowledge graph enrichment completed!")

# Function to validate the enrichment results
def validate_enrichment_quality(sample_size=10):
    """Print heuristic signs of over-merged entities and missed duplicates.

    Only `__Entity__` nodes are inspected. Source `Document` nodes have long,
    non-name ids (presumably content hashes) and many MENTIONS links, so they
    would otherwise dominate the "over-merged" list and add noise to the
    duplicate scan.

    Args:
        sample_size: Maximum number of rows printed for each check.

    Note:
        The duplicate check needs APOC (`apoc.text.levenshteinSimilarity`) and
        compares all entity pairs, so its cost grows quadratically with the
        number of entities.
    """
    # Find nodes that might have been incorrectly merged
    suspicious_merges = graph.query("""
        MATCH (n:__Entity__)
        WHERE size(n.id) > 20  // Long entity names might indicate merged entities
        WITH n, COUNT { (n)--() } AS relationship_count
        WHERE relationship_count > 5  // Nodes with many relationships might be over-merged
        RETURN n.id AS entity, relationship_count
        ORDER BY relationship_count DESC
        LIMIT $sample_size
    """, {"sample_size": sample_size})

    print("🔍 Potentially over-merged entities:")
    for row in suspicious_merges:
        print(f"  - {row['entity']} ({row['relationship_count']} relationships)")

    # Check for duplicate-like entities that should have been merged.
    # The similarity is computed once per pair and reused for filter and output.
    potential_duplicates = graph.query("""
        MATCH (n1:__Entity__), (n2:__Entity__)
        WHERE n1.id < n2.id
        WITH n1, n2, apoc.text.levenshteinSimilarity(n1.id, n2.id) AS similarity
        WHERE n1.id CONTAINS n2.id OR n2.id CONTAINS n1.id OR similarity > 0.8
        RETURN n1.id AS entity1, n2.id AS entity2, similarity
        ORDER BY similarity DESC
        LIMIT $sample_size
    """, {"sample_size": sample_size})

    print("\n🔍 Potential duplicates that weren't merged:")
    for row in potential_duplicates:
        print(f"  - '{row['entity1']}' vs '{row['entity2']}' (similarity: {row['similarity']:.2f})")

# Usage example:
# Enrich the graph with the transcripts in enrich_10, then print the
# over-merge / missed-duplicate diagnostics.
enrich_knowledge_graph_improved("/content/drive/MyDrive/enrich_10", batch_size=5, pause_duration=10)
validate_enrichment_quality()

In [None]:
# Redraw the graph after Method 1 enrichment.
showGraph()

Knowledge graph statistics after enrichment

In [None]:
def run_and_print(query, description):
    """Run a Cypher `query` against the global `graph` and print each row as a dict.

    Args:
        query: Cypher query to execute.
        description: Heading printed before the rows.
    """
    print(f"\n📊 {description}")
    for row in graph.query(query):
        print(dict(row))


# (query, heading) pairs printed in order by the loop at the end of this cell.
# - Label statistics UNWIND labels(n) instead of reading labels(n)[0]: entity
#   nodes carry the __Entity__ base label (baseEntityLabel=True) in addition to
#   any type label, and the first list element is not reliably the informative one.
# - Per-node rankings return id/labels instead of the whole node, so the
#   `embedding` vectors stored on Document nodes are not dumped to the output.
GRAPH_STAT_QUERIES = [
    (
        "MATCH (n) RETURN count(n) AS total_nodes",
        "1. Total number of nodes",
    ),
    (
        "MATCH ()-[r]->() RETURN count(r) AS total_relationships",
        "2. Total number of relationships",
    ),
    (
        "MATCH (n) UNWIND labels(n) AS label RETURN label, count(*) AS count ORDER BY count DESC",
        "3. Number of nodes per label",
    ),
    (
        "MATCH ()-[r]->() RETURN type(r) AS rel_type, count(*) AS count ORDER BY count DESC",
        "4. Number of relationships per type",
    ),
    (
        """
        MATCH (n)
        WITH n, COUNT { (n)-->() } + COUNT { ()-->(n) } AS totalDegree
        RETURN n.id AS id, labels(n) AS labels, totalDegree
        ORDER BY totalDegree DESC
        LIMIT 5
        """,
        "5. Top 5 most connected nodes (total degree)",
    ),
    (
        """
        MATCH (n)
        WITH n, COUNT { ()-->(n) } AS inDegree
        RETURN n.id AS id, labels(n) AS labels, inDegree
        ORDER BY inDegree DESC
        LIMIT 5
        """,
        "6. Top 5 nodes by incoming relationships",
    ),
    (
        """
        MATCH (n)
        WITH n, COUNT { (n)-->() } AS outDegree
        RETURN n.id AS id, labels(n) AS labels, outDegree
        ORDER BY outDegree DESC
        LIMIT 5
        """,
        "7. Top 5 nodes by outgoing relationships",
    ),
    (
        """
        MATCH (n)
        WITH COUNT { (n)-->() } + COUNT { ()-->(n) } AS degree
        RETURN avg(degree) AS avg_degree
        """,
        "8. Average degree per node",
    ),
    (
        "MATCH ()-[r]->() RETURN DISTINCT type(r) AS relationship_type",
        "9. All distinct relationship types",
    ),
    (
        "MATCH (n) UNWIND labels(n) AS label RETURN DISTINCT label",
        "10. All distinct node labels",
    ),
]

for stat_query, stat_heading in GRAPH_STAT_QUERIES:
    run_and_print(stat_query, stat_heading)


Q&A validation after knowledge graph enrichment

In [None]:
# Re-ask baseline questions (plus new cross-document ones) on the enriched graph.
# NOTE(review): vector_index was created before enrichment; Document nodes added
# since then may lack embeddings until Neo4jVector.from_existing_graph is re-run
# (TODO confirm), so the vector half of retrieval may not see the new chunks.
chain.invoke({"question": "Which techniques does AlphaGeometry2 use for domain language expansion in geometry problems?"})

In [None]:
chain.invoke({"question": "What exploration strategies do large language models use in open-ended tasks?"})

In [None]:
chain.invoke({"question": "How does AlphaGeometry2 achieve gold-medalist performance in IMO geometry problems?"})


In [None]:
chain.invoke({"question": "What roles do the crawler and selector play in the PaSa agent system?"})

In [None]:
chain.invoke({"question": "What is the purpose of Continuous Concept Mixing (CoCoMix) in LLM pretraining?"})


In [None]:
chain.invoke({"question": "What recent trends emerge from combining symbolic and neural approaches in LLM systems?"})
# Expectation: enrichment should surface a meta-pattern across multiple documents.

In [None]:
chain.invoke({"question": "What is the impact of reinforcement learning in academic paper search according to PaSa?"})
# Category: method + impact pairing.

In [None]:
chain.invoke({"question": "Which system better supports autonomous paper search: PaSa or AG2, and why?"})
# Requires semantic graph links between purpose, architecture, and performance.
# Category: implicit framework comparison.

In [None]:
chain.invoke({"question": "Why are sparse autoencoders important in both exploration analysis and CoCoMix concept modeling?"})
# Expectation: enrichment should surface that SAEs are a shared analytical backbone for latent-feature understanding in both tasks.
# Category: inferred relationships.

Graph statistics after the enrichment process

In [None]:
def get_graph_stats(title="Graph Statistics"):
    """Print summary statistics of the current Neo4j graph under a heading.

    Args:
        title: Heading printed above the statistics.
    """
    print(f"\n📊 {title}")
    queries = {
        "Total Nodes": "MATCH (n) RETURN count(n) AS count",
        "Total Relationships": "MATCH ()-[r]->() RETURN count(r) AS count",
        # UNWIND so every label is counted, not whichever one happens to be first.
        "Nodes by Label": "MATCH (n) UNWIND labels(n) AS label RETURN label, count(*) AS count ORDER BY count DESC",
        "Relationships by Type": "MATCH ()-[r]->() RETURN type(r) AS type, count(*) AS count ORDER BY count DESC",
        # Undirected pattern: each relationship counts for both of its endpoints.
        "Top Connected Nodes": """
            MATCH (n)-[r]-()
            RETURN n.id AS node_id, count(r) AS connections
            ORDER BY connections DESC LIMIT 10
        """,
    }
    for desc, query in queries.items():
        print(f"\n🔹 {desc}:")
        for row in graph.query(query):
            print("   ", dict(row))

In [None]:
# Snapshot after enrichment; compare with the "BEFORE ENRICHMENT" output above.
get_graph_stats("üì• AFTER ENRICHMENT")

Method 2: LLM- and FAISS-based knowledge graph enrichment

In [None]:
!pip install faiss-cpu

In [None]:
import os
import time
import math
import logging
import numpy as np
import faiss
from langchain.text_splitter import TokenTextSplitter
from langchain.document_loaders import TextLoader
from langchain_community.document_loaders import DirectoryLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_experimental.graph_transformers import LLMGraphTransformer

# ===============================
# Logging Setup
# ===============================
def setup_logger(log_path="kg_log.log"):
    """Configure and return the "kg_log" logger.

    INFO and above go to the console; DEBUG and above go to `log_path`.
    Calling this again (e.g. re-running the cell) replaces the previous handlers
    instead of stacking new ones, which would print every message repeatedly
    and keep old log files open.

    Args:
        log_path: File that receives the DEBUG-level log.

    Returns:
        The configured logging.Logger.
    """
    logger = logging.getLogger("kg_log")
    logger.setLevel(logging.DEBUG)

    # Drop and close handlers left over from earlier calls.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger

# Module-level logger used by the FAISS matching helpers (writes kg_log.log).
logger = setup_logger()

# ===============================
# FAISS-based Entity Matcher
# ===============================
def build_faiss_index(entities, embed_model):
    """Embed `entities` and index them for cosine-similarity search.

    Each entity is embedded with `embed_model.embed_query` and L2-normalised,
    then added to an exact inner-product index. On unit vectors the inner
    product equals cosine similarity.

    Args:
        entities: Sequence of entity names; index position i is entities[i].
        embed_model: Object exposing `embed_query(text) -> list[float]`.

    Returns:
        (index, norm_embeddings): the faiss index and the list of normalised
        vectors (float64 numpy arrays). An empty `entities` returns
        `(None, [])`, because there is nothing to embed and no dimension to
        size the index with.
    """
    if not entities:
        return None, []
    embeddings = np.asarray([embed_model.embed_query(ent) for ent in entities], dtype=np.float64)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Leave all-zero vectors as zeros instead of turning them into NaNs.
    normalized = np.divide(embeddings, norms, out=np.zeros_like(embeddings), where=norms > 0)
    index = faiss.IndexFlatIP(normalized.shape[1])
    index.add(normalized.astype('float32'))
    return index, list(normalized)

def match_entity_faiss(entity, known_entities, embed_model, index, norm_embeddings, threshold=0.92):
    """Return the known entity most similar to `entity`, or `entity` itself.

    Args:
        entity: Name to resolve.
        known_entities: Names in the same order as the vectors in `index`.
        embed_model: Object exposing `embed_query(text)`.
        index: Inner-product faiss index over normalised vectors, or None when
            there is nothing indexed.
        norm_embeddings: Not used; kept so existing calls keep working.
        threshold: Minimum cosine similarity required for a match.

    Returns:
        The matched known name, or `entity` when there is no confident match or
        matching fails (failures are logged as warnings).
    """
    # Nothing indexed: skip the embedding API call entirely.
    if index is None or getattr(index, "ntotal", 1) == 0:
        return entity
    try:
        query_vec = np.asarray(embed_model.embed_query(entity), dtype=np.float64)
        norm = np.linalg.norm(query_vec)
        if norm == 0:
            return entity  # a zero vector has no direction to compare
        scores, ids = index.search((query_vec / norm)[np.newaxis, :].astype('float32'), 1)
        best_score, best_id = float(scores[0][0]), int(ids[0][0])
        # faiss reports id -1 when no neighbour exists; indexing with it would
        # silently pick the last known entity.
        if best_id >= 0 and best_score >= threshold:
            match = known_entities[best_id]
            logger.info(f"🔁 Matched '{entity}' → '{match}' (score: {best_score:.2f})")
            return match
    except Exception as e:
        logger.warning(f"⚠️ Matching failed for: '{entity}' | {e}")
    return entity

# ===============================
# Prompt Template
# ===============================
def create_contextual_prompt():
    """Build the chat prompt used to enrich one research passage.

    The system message carries the instructions plus the ``{known_entities}``
    template slot; the human message is just the ``{input_text}`` slot.
    """
    system_lines = [
        "You are a knowledge graph enrichment expert.",
        "Given a research passage and known entities, extract new entities and their relationships.",
        "Only match existing entities if they refer to the same thing.",
        "Avoid near-duplicates unless truly distinct.",
        "Known entities: {known_entities}",
    ]
    messages = [
        ("system", "\n".join(system_lines)),
        ("human", "{input_text}"),
    ]
    return ChatPromptTemplate.from_messages(messages)

# ===============================
# Main Enrichment Logic
# ===============================
def _enrich_chunks(enrichment_chain, documents, entity_context):
    """Rewrite each chunk with the enrichment LLM chain, keeping the raw chunk on failure.

    Args:
        enrichment_chain: Runnable taking ``{"known_entities", "input_text"}``
            and returning a string.
        documents: Chunked ``Document`` objects for one batch.
        entity_context: Comma-separated known-entity names for the prompt.

    Returns:
        list: One ``Document`` per input chunk, in the same order.
    """
    enriched_documents = []
    for idx, doc in enumerate(documents):
        try:
            logger.info(f"   üìÑ Processing doc {idx+1}/{len(documents)}")
            enriched = enrichment_chain.invoke({
                "known_entities": entity_context,
                # Only the first 3000 characters of a chunk are sent to the LLM.
                "input_text": doc.page_content[:3000],
            })
            enriched_documents.append(Document(page_content=enriched, metadata=doc.metadata))
        except Exception as e:
            logger.error(f"‚ùå LLM error on document {idx+1}: {e}")
            enriched_documents.append(doc)
    return enriched_documents


def _remap_graph_entities(graph_docs, resolve_entity):
    """Rewrite node and relationship-endpoint ids via ``resolve_entity``; return the number of node ids changed."""
    matched_count = 0
    for gdoc in graph_docs:
        for node in gdoc.nodes:
            updated = resolve_entity(node.id)
            if updated != node.id:
                node.id = updated
                matched_count += 1
        # Endpoints are remapped separately, as in the original code — they
        # may be distinct objects from the entries in gdoc.nodes.
        for rel in gdoc.relationships:
            rel.source.id = resolve_entity(rel.source.id)
            rel.target.id = resolve_entity(rel.target.id)
    return matched_count


def enrich_knowledge_graph_faiss(input_path, graph, batch_size=5, pause_duration=10, max_entity_context=100):
    """Enrich the Neo4j graph with entities extracted from ``*.txt`` files.

    For each batch of ``batch_size`` files, the function runs these steps:

    1. Split the files into 512-token chunks with a 24-token overlap.
    2. Rewrite each chunk with an LLM prompt that lists known entities.
    3. Extract graph documents with ``LLMGraphTransformer``.
    4. Snap node and relationship ids onto ids already in the graph, using
       FAISS similarity.
    5. Write the result with ``graph.add_graph_documents``.

    Args:
        input_path: Directory scanned for ``*.txt`` files.
        graph: Graph wrapper exposing ``query`` and ``add_graph_documents``.
        batch_size: Number of source files per batch; must be >= 1.
        pause_duration: Seconds to sleep between batches (not after the last).
        max_entity_context: Maximum number of known entity names in the prompt.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        KeyError: if ``GEMINI_API_KEY`` is not set in ``os.environ``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    loader = DirectoryLoader(input_path, glob="*.txt", loader_cls=TextLoader)
    raw_documents = loader.load()
    text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
    total_batches = math.ceil(len(raw_documents) / batch_size)

    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.environ["GEMINI_API_KEY"])
    embed_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=os.environ["GEMINI_API_KEY"])
    llm_transformer = LLMGraphTransformer(llm=llm)

    # Fetch known entities and build FAISS index
    existing_nodes = graph.query("MATCH (n) RETURN DISTINCT n.id AS id")
    known_entities = sorted({row['id'] for row in existing_nodes if row.get("id")})
    logger.info(f"üìä {len(known_entities)} known entities fetched from graph")

    if known_entities:
        index, norm_embeddings = build_faiss_index(known_entities, embed_model)
    else:
        # An empty graph has nothing to match against, and building an index
        # from zero vectors is not possible; skip matching entirely.
        index, norm_embeddings = None, []

    # Each distinct id costs one embedding API call; the same id commonly
    # appears as a node and as several relationship endpoints, so cache.
    match_cache = {}

    def resolve_entity(entity_id):
        if index is None:
            return entity_id
        if entity_id not in match_cache:
            match_cache[entity_id] = match_entity_faiss(
                entity_id, known_entities, embed_model, index, norm_embeddings
            )
        return match_cache[entity_id]

    enrichment_chain = create_contextual_prompt() | llm | StrOutputParser()
    # The entity list in the prompt is identical for every chunk.
    entity_context = ", ".join(known_entities[:max_entity_context])

    for batch_no, start in enumerate(range(0, len(raw_documents), batch_size), start=1):
        logger.info(f"üîÑ Enriching batch {batch_no} / {total_batches}")
        documents = text_splitter.split_documents(raw_documents[start:start + batch_size])
        enriched_documents = _enrich_chunks(enrichment_chain, documents, entity_context)

        # Extraction, remapping and the write share one guard so that a
        # failure in any of them skips this batch instead of aborting the run.
        try:
            graph_docs = llm_transformer.convert_to_graph_documents(enriched_documents)
            matched_count = _remap_graph_entities(graph_docs, resolve_entity)
            graph.add_graph_documents(graph_docs, baseEntityLabel=True, include_source=True)
            logger.info(f"‚úÖ Batch {batch_no} added. üîÅ {matched_count} entities matched.")
        except Exception as e:
            logger.error(f"‚ùå Graph update failed: {e}")

        # Sleeping after the final batch only delays completion.
        if batch_no < total_batches:
            logger.info(f"üïí Sleeping {pause_duration} seconds...\n")
            time.sleep(pause_duration)

    logger.info("üéâ Knowledge Graph enrichment completed.")


In [None]:
# Render the graph as it stands before the enrichment run in the next cell.
# NOTE(review): showGraph is not defined in this section — presumably an
# earlier notebook cell; confirm it is defined before this cell on a fresh kernel.
showGraph()

In [None]:
# Enrich the existing graph from the *.txt files in the Drive folder,
# five files per batch with a 10-second pause per batch.
enrich_knowledge_graph_faiss(
    graph=graph,
    input_path="/content/drive/MyDrive/enrich_10",
    pause_duration=10,
    batch_size=5,
)

# Quality checks (validate_graph_quality) run in their own cells further down.


In [None]:

def validate_graph_quality(graph, sample_size=20, min_degree=10, similarity_threshold=0.92):
    """Log simple quality diagnostics for the knowledge graph.

    Two checks are reported through the ``"kg_log"`` logger. Each runs in its
    own ``try`` block, so one failing query does not hide the other:

    1. High-degree nodes: nodes with more than ``min_degree`` relationships
       (either direction), highest first. These are candidates for
       over-merged entities.
    2. Near-duplicates: pairs of node ids whose APOC Levenshtein similarity
       exceeds ``similarity_threshold``. These are candidates for missed merges.

    Args:
        graph: Object with ``query(cypher, params) -> list[dict]``.
        sample_size: Maximum rows reported per check.
        min_degree: Degree strictly above which a node is reported.
        similarity_threshold: Similarity strictly above which a pair is reported.

    Notes:
        The duplicate check requires APOC (``apoc.text.levenshteinSimilarity``).
        It compares every pair of nodes, so its cost grows quadratically with
        graph size. Query failures are logged as errors, not raised.
    """
    logger = logging.getLogger("kg_log")
    logger.info("\nüìä Validating Knowledge Graph")

    # --------- High-Degree Nodes ---------
    logger.info("\nüîç High-degree nodes (likely over-merges):")
    try:
        high_degree = graph.query("""
            MATCH (n)
            WITH n, COUNT { (n)--() } AS deg
            WHERE deg > $min_degree
            RETURN n.id AS entity, deg
            ORDER BY deg DESC
            LIMIT $sample_size
        """, {"sample_size": sample_size, "min_degree": min_degree})

        if not high_degree:
            logger.info("  ‚úÖ No high-degree nodes detected.")
        else:
            for row in high_degree:
                logger.info(f"  üî∏ '{row['entity']}' has {row['deg']} connections")
    except Exception as e:
        logger.error(f"‚ùå Failed to query high-degree nodes: {e}")

    # --------- Duplicate Entity Detection ---------
    logger.info(f"\nüîç Potential duplicates (Levenshtein similarity > {similarity_threshold}):")
    try:
        # Similarity is computed once per pair (via WITH) instead of once in
        # WHERE and again in RETURN.
        similar = graph.query("""
            MATCH (n1), (n2)
            WHERE n1.id < n2.id
            WITH n1, n2, apoc.text.levenshteinSimilarity(n1.id, n2.id) AS sim
            WHERE sim > $threshold
            RETURN n1.id AS e1, n2.id AS e2, sim
            ORDER BY sim DESC
            LIMIT $sample_size
        """, {"sample_size": sample_size, "threshold": similarity_threshold})

        if not similar:
            logger.info("  ‚úÖ No near-duplicate entities detected.")
        else:
            for row in similar:
                logger.info(f"  ‚ö†Ô∏è '{row['e1']}' ‚Üî '{row['e2']}' (sim: {row['sim']:.2f})")
    except Exception as e:
        logger.error(f"‚ùå Failed to detect duplicates: {e}")


In [None]:
def generate_graph_summary(graph):
    """Log node and relationship totals plus up to 30 distinct label sets.

    Output goes to the ``"kg_log"`` logger. Any query failure is logged as an
    error instead of being raised. Nothing is returned.
    """
    log = logging.getLogger("kg_log")
    log.info("\nüìà Summary Report:")

    # (cypher, result column, caption) for each headline count.
    count_queries = (
        ("MATCH (n) RETURN count(n) AS total_nodes", "total_nodes", "Total Nodes"),
        ("MATCH ()-[r]->() RETURN count(r) AS total_rels", "total_rels", "Total Relationships"),
    )

    try:
        for cypher, column, caption in count_queries:
            rows = graph.query(cypher)
            log.info(f"üîπ {caption}: {rows[0][column] if rows else 0}")

        label_rows = graph.query("""
            MATCH (n)
            WITH labels(n) AS lbls
            RETURN DISTINCT lbls
            LIMIT 30
        """)
        log.info("üîπ Node Labels Sample:")
        for record in label_rows:
            log.info(f"   - {record['lbls']}")
    except Exception as exc:
        log.error(f"‚ùå Failed to generate graph summary: {exc}")


In [None]:
# Overall size report first, then the quality diagnostics (at most 20 rows per check).
generate_graph_summary(graph=graph)
validate_graph_quality(graph=graph, sample_size=20)


In [None]:
# Shorter validation pass: report at most 10 rows per check.
validate_graph_quality(graph=graph, sample_size=10)


In [None]:
# NOTE(review): compare_before_after_enrichment is not defined in this section
# of the notebook; confirm an earlier cell defines it, otherwise this cell
# raises NameError on a fresh kernel.
compare_before_after_enrichment(graph)
