In [12]:
import os
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import RetrievalQA
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.prompts import (
    PromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    ChatPromptTemplate
)
from langchain.tools import tool


In [13]:
# This part is for finding relevant PROJECTS based on their description
# Vector index over the graph's Project nodes, used to retrieve relevant
# projects for a question by embedding similarity.
# NOTE(review): only name/description/complexity are listed as text
# properties, while the QA prompt asks about status/startdate/enddate; the
# recorded sample answer only mentions these three fields — confirm whether
# the other properties ever reach the LLM.
project_embedding_model = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
neo4j_credentials = {
    "url": os.getenv("NEO4J_URI"),
    "username": os.getenv("NEO4J_USERNAME"),
    "password": os.getenv("NEO4J_PASSWORD"),
}
openai_project_index = Neo4jVector.from_existing_graph(
    embedding=project_embedding_model,
    index_name="openai_project_index",
    node_label="Project",
    text_node_properties=["name", "description", "complexity"],
    embedding_node_property="openai_embedding",
    **neo4j_credentials,
)

In [14]:
# System-prompt template for project_qa_tool. "{context}" is the slot the
# retrieval chain fills with the retrieved project documents. Written as
# adjacent literals so every trailing space and newline is explicit; the
# resulting string is identical to a triple-quoted version.
project_details_chat_template_str = (
    "\n"
    "Your job is to use the provided project data to answer \n"
    "questions about their status, description, startdate, enddate, etc. \n"
    "within the company. Use the following context to answer questions. \n"
    "Be as detailed as possible, but don't make up any information that's \n"
    "not from the context. If you don't know an answer, say you don't know.\n"
    "{context}\n"
)

In [19]:
@tool("project-qa-tool", return_direct=False)
def project_qa_tool(query: str) -> str:
    """Useful for answering questions about projects."""
    # NOTE: the docstring above doubles as the tool description an agent sees
    # when choosing tools, so implementation notes live in comments instead.
    #
    # Answers `query` with a "stuff" RetrievalQA chain: project documents are
    # retrieved from `openai_project_index` (built in an earlier cell), inserted
    # into the system prompt's {context} slot, and sent to gpt-3.5-turbo.
    # Returns the chain output's "result" entry.

    project_details_chat_system_prompt = SystemMessagePromptTemplate(
        prompt=PromptTemplate(
            input_variables=["context"],
            template=project_details_chat_template_str,
        )
    )

    human_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            input_variables=["question"],
            template="Can you provide details on: {question}?",
        )
    )

    # Fix: previously the raw template *string* was placed here instead of the
    # SystemMessagePromptTemplate above, so the instructions + context were sent
    # with the human role and the system prompt object was never used.
    qa_prompt = ChatPromptTemplate(
        messages=[project_details_chat_system_prompt, human_prompt],
        input_variables=["context", "question"],
    )

    llm = ChatOpenAI(model="gpt-3.5-turbo")

    # Hand the custom prompt to the chain through the public `chain_type_kwargs`
    # hook instead of overwriting `combine_documents_chain.llm_chain.prompt`
    # after construction. The "stuff" chain fills {context} with the retrieved
    # documents; RetrievalQA passes the user query as {question}.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=openai_project_index.as_retriever(),
        # other options: 'map_reduce', 'refine', 'map_rerank'
        chain_type="stuff",
        chain_type_kwargs={"prompt": qa_prompt},
    )

    response = qa_chain.invoke(query)
    return response.get("result")

In [20]:
@tool("general-qa-tool", return_direct=False)
def general_qa_tool(query: str) -> str:
    """Useful for answering general questions."""
    # Delegates the question to `cypher_chain` and returns the "result" entry
    # of its output.
    # NOTE(review): `cypher_chain` is not defined in any cell visible here, and
    # the recorded run of this tool raised — confirm it is created (e.g. a
    # graph Cypher QA chain) in an earlier cell before this tool is used.
    return cypher_chain.invoke(query).get("result")

In [21]:
# Smoke-test the project tool on a known project; the answer string is the
# cell's displayed output.
skyhawks_question = "Provide me the detailed information of Skyhawks."
project_qa_tool.invoke(skyhawks_question)


'Sure, here are the detailed information about the project named Skyhawks:\n\n- Name: Skyhawks\n- Description: Skyhawks is a solution designed for founders who aim to excel in content marketing. The project utilizes proprietary AI technology to analyze the competition related to the chosen topic. It helps in creating optimized content quickly by leveraging state-of-the-art AI generation capabilities. This allows users to produce high-quality content efficiently and stay ahead in the competitive content marketing landscape.\n- Complexity: High\n\nIf you need further information or have any specific questions about Skyhawks, feel free to ask!'

In [22]:
# Same question routed through the general (Cypher-backed) tool.
# NOTE(review): this cell raised in the recorded run; `cypher_chain` is not
# defined in any visible cell — confirm it exists before running this cell.
skyhawks_question = "Provide me the detailed information of Skyhawks."
general_qa_tool.invoke(skyhawks_question)


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/Users/tejasjay/Desktop/graphrag/.venv/lib/python3.13/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
    ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/90/h9hy_l096_jcf3041fnkklp00000gn/T/ipykernel_2249/2653480292.py", line 1, in <module>
    general_qa_tool.invoke("Provide me the detailed information of Skyhawks.")
    ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tejasjay/Desktop/graphrag/.venv/lib/python3.13/site-packages/langchain_core/tools/base.py", line 599, in invoke
    return self.run(tool_input, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tejasjay/Desktop/graphrag/.venv/lib/python3.13/site-packages/langchain_core/tools/base.py", line 883, in run
    raise error_to_raise
  File "/Users/tejasjay/Desktop/graphrag/.venv/lib/python3.13/site-packages/langchain_core/tools/