In [None]:
import os
from dotenv import load_dotenv

from src.config import LLMConf, EmbedderConf, KnowledgeGraphConfig
from src.factory.embeddings import get_embeddings
from src.graph.knowledge_graph import KnowledgeGraph
from neo4j import Query, Session
from src.schema import Chunk
# Load the variables defined in `config.env`; `env` keeps load_dotenv's return value.
env = load_dotenv('config.env')

from src.utils.logger import get_logger
# Logger for this notebook; the helper functions defined in later cells report failures through it.
logger = get_logger(__name__)

In [None]:
# Neo4j connection settings, read from the environment (no credentials in the notebook).
kg_config = KnowledgeGraphConfig(
    uri=os.environ.get("NEO4J_URI"),
    user=os.environ.get("NEO4J_USERNAME"),
    password=os.environ.get("NEO4J_PASSWORD"),
    index_name='vector',
)

# Embedding backend used by the knowledge graph.
embedder_conf = EmbedderConf(
    type=os.environ.get("EMBEDDINGS_TYPE"),
    model=os.environ.get("EMBEDDINGS_MODEL_NAME"),
)

# LLM settings, all taken from the RE_* environment variables.
# NOTE(review): the temperature is forwarded as the raw environment string (or None
# when unset) -- confirm that LLMConf converts it to a number.
model_conf = LLMConf(
    type=os.environ.get("RE_MODEL_TYPE"),
    model=os.environ.get("RE_MODEL_NAME"),
    temperature=os.environ.get("RE_MODEL_TEMPERATURE"),
    deployment=os.environ.get("RE_MODEL_DEPLOYMENT"),
    api_key=os.environ.get("RE_API_KEY"),
    endpoint=os.environ.get("RE_MODEL_ENDPOINT"),
    # An unset or empty API version is normalised to None.
    api_version=os.environ.get("RE_MODEL_API_VERSION") or None,
)

embeddings = get_embeddings(conf=embedder_conf)

# Knowledge graph handle used by every cell below.
kg = KnowledgeGraph(conf=kg_config, embeddings_model=embeddings)

In [None]:
# Fetch the communities from the knowledge graph.
communities = kg.get_communities()

# Chunks of the first community in the returned list (index 0); the cells below
# use individual items of this list as example inputs.
chunks_comm_1 = communities[0].chunks

# Display the chunks as the cell output.
chunks_comm_1

In [None]:
# Display every community returned by kg.get_communities() for inspection.
communities

In [None]:
from typing import Any, Dict, List


def get_mentioned_entities(session: Session, chunk: Chunk, n_hops: int=1) -> List[Dict[str, Any]]:
    """
    Collect the nodes that a Chunk points at through `MENTIONS` relationships.

    The Chunk is looked up by comparing `elementId(c)` with `chunk.chunk_id`, so
    `chunk.chunk_id` must hold the node's element id when this is called.

    Args:
        session: open session used to run the Cypher query.
        chunk: the Chunk whose outgoing `MENTIONS` relationships are followed.
        n_hops: intended number of relationship layers to traverse. It is
            currently ignored: only direct mentions are returned.

    Returns:
        One property dictionary per mentioned node; an empty list when the
        query yields no record.
    """
    # TODO perform n-hops retrieval
    cypher = """
        MATCH (c:Chunk)
        WHERE elementId(c) = $elementId
        MATCH (c)-[:MENTIONS]->(mentioned)
        RETURN collect(mentioned) AS mentioned_nodes
    """
    row = session.run(cypher, elementId=chunk.chunk_id).single()
    if not row:
        return []
    # Each node is flattened to a plain dict of its properties.
    return [dict(entity) for entity in row["mentioned_nodes"]]

In [None]:
# Open a session on the graph's driver (reached through the private `_driver`
# attribute) and collect the entities mentioned by the first chunk.
with kg._driver.session() as session:
    mentioned_ents = get_mentioned_entities(session, chunk=chunks_comm_1[0])
    
# Display the collected entity dictionaries.
mentioned_ents

In [None]:
from typing import Tuple


def get_adjacent_chunks(session: Session, chunk: Chunk) -> Tuple[Chunk | None, Chunk | None, Chunk | None]:
    """
    Returns a tuple with the previous , current and following `Chunk` 
    given an initial node characterised by a `filename` and a `chunk_id`
    """
    base_query = """ 
        MATCH (current:Chunk)
        WHERE elementId(current) = $elementId

        OPTIONAL MATCH (prev:Chunk)-[:NEXT]->(current)
        OPTIONAL MATCH (current)-[:NEXT]->(next:Chunk)

        RETURN prev AS previous_chunk, current, next AS next_chunk
    """
    try: 
        result = session.run(base_query, elementId=chunk.chunk_id)
        record = result.single()
        
        previous_chunk = dict(record["previous_chunk"]) if record["previous_chunk"] else None
        if previous_chunk:
            previous_chunk = Chunk(
                chunk_id=previous_chunk['chunk_id'],
                filename=previous_chunk['filename'],
                text=previous_chunk["text"],
            )
            chunk.chunk_id = previous_chunk.chunk_id + 1 # original chunk id
        next_chunk = dict(record["next_chunk"]) if record["next_chunk"] else None
        if next_chunk:
            next_chunk = Chunk(
                chunk_id=next_chunk['chunk_id'],
                filename=next_chunk['filename'],
                text=next_chunk["text"],
            )
            chunk.chunk_id = next_chunk.chunk_id-1 # original chunk id
        
        return previous_chunk, chunk, next_chunk
    except Exception as e:
        logger.warning(f"Unable to retrieve adjacent chunks for Chunk: {chunk.chunk_id}")
        return None, chunk, None


In [None]:
# Look up the neighbours of the second chunk of the first community, using a
# session opened on the graph's private `_driver` attribute.
with kg._driver.session() as session:
    ajacent_chunks = get_adjacent_chunks(session, chunk=chunks_comm_1[1])
    
# Display the (previous, current, next) tuple.
ajacent_chunks

In [None]:
def filter_graph_by_communities(session: Session, community_ids: List[int], community_type: str="leiden") -> List[Dict[str, Any]]:
    """
    Returns the relationships whose start node belongs to one of the given communities.

    A node belongs to a community when its `community_<community_type>` property
    is one of `community_ids`. Only the start node `n` of each `(n)-[r]->(m)`
    pattern is filtered; the end node `m` may be in any community.

    Args:
        session: open session used to run the Cypher query.
        community_ids: community identifiers to keep. An empty list returns an
            empty result without querying the database.
        community_type: suffix of the community property to filter on
            (`"leiden"` reads the `community_leiden` property).

    Returns:
        One dictionary per matched relationship with the keys `node_1`,
        `relationship` and `node_2`, each holding the properties of the
        corresponding graph element. An empty list is returned (and a warning
        logged) when the query fails.
    """
    if not community_ids:
        # `IN []` can never match, so skip the round-trip.
        return []

    # The property name is passed as a query parameter and resolved with dynamic
    # property access (`n[$name]`) instead of being formatted into the query
    # text, so `community_type` cannot alter the Cypher statement.
    query = """
        MATCH (n)-[r]->(m)
        WHERE n[$community_property] IN $community_values
        RETURN n, r, m
    """
    try:
        result = session.run(
            query,
            community_property=f"community_{community_type}",
            community_values=list(community_ids),
        )
        # Nodes and relationships are flattened to plain property dictionaries.
        return [
            {
                "node_1": dict(record["n"]),
                "relationship": dict(record["r"]),
                "node_2": dict(record["m"]),
            }
            for record in result
        ]
    except Exception as e:
        logger.warning(f"Error while fetching subgraph: {e}")
        return []

In [None]:
# Extract the relationships starting from nodes of communities 3 and 4 (default
# community type), using a session opened on the graph's private `_driver` attribute.
with kg._driver.session() as session:
    subgraph = filter_graph_by_communities(session, community_ids=[3, 4])
    
# Display the list of node/relationship dictionaries.
subgraph