In [None]:
import os
import warnings

from dotenv import load_dotenv

from src.agents.graph_qa import GraphAgentResponder
from src.config import LLMConf, EmbedderConf, KnowledgeGraphConfig
from src.graph.knowledge_graph import KnowledgeGraph
from src.ingestion.embedder import ChunkEmbedder


# Load the Neo4j / QA-model settings from config.env into os.environ.
# load_dotenv returns False when nothing was loaded (e.g. the file is missing
# from the kernel's working directory). In that case every os.getenv(...) in the
# next cell silently yields None, so surface it here instead of failing later
# with a confusing connection/config error. Variables already exported in the
# process environment still work, so this is a warning rather than an error.
env = load_dotenv('config.env')
if not env:
    warnings.warn(
        "config.env was not found or contained no variables; "
        "falling back to the existing process environment."
    )

In [None]:
# Neo4j connection settings, read from the environment populated by config.env.
kg_config = KnowledgeGraphConfig(
    uri=os.getenv("NEO4J_URI"),
    user=os.getenv("NEO4J_USERNAME"),
    password=os.getenv("NEO4J_PASSWORD"),
    index_name=os.getenv("INDEX_NAME")
)

# LLM settings shared by the QA, Cypher and rephrase steps, all taken from the
# QA_* environment variables. A local alternative used during development was:
#   LLMConf(type="ollama", model="llama3.2:latest", temperature=0.0)

# os.getenv always returns a string (or None), so convert the temperature to a
# float; an unset or empty variable becomes None, mirroring api_version below.
qa_temperature_raw = os.getenv("QA_MODEL_TEMPERATURE")
qa_temperature = float(qa_temperature_raw) if qa_temperature_raw else None

llm_conf = LLMConf(
    type=os.getenv("QA_MODEL_TYPE"),
    model=os.getenv("QA_MODEL_NAME"),
    temperature=qa_temperature,
    deployment=os.getenv("QA_MODEL_DEPLOYMENT"),
    api_key=os.getenv("QA_API_KEY"),
    endpoint=os.getenv("QA_MODEL_ENDPOINT"),
    # An empty QA_MODEL_API_VERSION is treated the same as an unset one.
    api_version=os.getenv("QA_MODEL_API_VERSION") or None
)

# Embedding model for vector search over the graph. NOTE(review): presumably
# this must match the model used when the vector index was built — confirm.
embedder_conf = EmbedderConf(
    type="ollama",
    model="mxbai-embed-large",
)

In [None]:
# Build the embedder, then the knowledge graph that uses its embeddings model.
embedder = ChunkEmbedder(conf=embedder_conf)
knowledge_graph = KnowledgeGraph(conf=kg_config, embeddings_model=embedder.embeddings)

# Sanity-check the database connection before running any queries.
# `_driver` is a private attribute of KnowledgeGraph (presumably the underlying
# Neo4j driver — confirm). The result of the final call is shown as cell output.
neo4j_driver = knowledge_graph._driver
neo4j_driver.verify_connectivity()
neo4j_driver.verify_authentication()

In [None]:
# Inspect the graph's structured schema (shown as cell output).
# NOTE(review): accessed without a call, so this assumes get_structured_schema
# is a property; if it is a method this only displays the bound method — confirm.
knowledge_graph.get_structured_schema # too long to pass in prompt

In [None]:
# Inspect the labels reported by the knowledge graph (shown as cell output).
knowledge_graph.labels

In [None]:
# Inspect the relationships reported by the knowledge graph (shown as cell output).
knowledge_graph.relationships

In [None]:
# Question-answering agent over the graph. For now a single LLM configuration
# drives every role: Cypher generation, question rephrasing and final answering.
responder = GraphAgentResponder(
    graph=knowledge_graph,
    cypher_llm_conf=llm_conf,
    rephrase_llm_conf=llm_conf,
    qa_llm_conf=llm_conf,  # TODO try different LLMs
)

In [None]:
# Ask the agent; `query` stays bound for inspection and the answer is the cell output.
responder.answer(query := "What countries are mentioned in the graph?")

In [None]:
# Ask the agent; `query` stays bound for inspection and the answer is the cell output.
responder.answer(query := "Is Russia mentioned in the graph?")

In [None]:
# Ask the agent; `query` stays bound for inspection and the answer is the cell output.
responder.answer(query := "What document mentions Ukraine?")

In [None]:
# Ask the agent; `query` stays bound for inspection and the answer is the cell output.
responder.answer(query := "When was the request for payment from Germany submitted?")

In [None]:
# Ask the agent; `query` stays bound for inspection and the answer is the cell output.
responder.answer(query := "What happened in the baltic sea?")