Approach 2, second try:
- The personality scientist receives the conversation.
- It goes through the conversational flow.
- It queries the graph.
- Each Cypher query is saved in a history of queries.
- Each query's results are saved in a history of query results.
- ToolMessages are not saved in the conversation messages.
- A separate state variable holds the extracted occupations or traits.

In [1]:
# LangChain
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langchain.graphs import Neo4jGraph
from langchain_core.tools import tool
from langchain_groq import ChatGroq

# LangGraph
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

# LangSmith
from langsmith import traceable

# General Imports
import os
import ast
import prompts
import operator
from pathlib import Path
from dotenv import load_dotenv
from typing import TypedDict, Annotated # to construct the agent's state
from IPython.display import Image, display

# Connect to graph
# Load credentials from the .env file one directory above this notebook.
dotenv_path = Path('../.env')
load_dotenv(dotenv_path=dotenv_path)

# Neo4jGraph() is created without arguments below, so its connection settings
# must already be present in the environment. Check them up front and fail
# with a readable message (re-assigning `os.environ[name] = os.getenv(name)`
# would instead raise an opaque "TypeError: str expected, not NoneType").
_missing_neo4j_vars = [
    name
    for name in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    if os.getenv(name) is None
]
if _missing_neo4j_vars:
    raise RuntimeError(
        f"Missing Neo4j settings {_missing_neo4j_vars}; "
        f"define them in {dotenv_path.resolve()} or in the environment."
    )
graph = Neo4jGraph()

# Create the tool to be used by the Agent
@tool
def query_graph(query: str):
    """Query from Neo4j knowledge graph using Cypher."""
    # With @tool, the docstring above becomes the tool description sent to the
    # LLM and the `query: str` annotation becomes its argument schema, so the
    # model is told the argument is a string.
    # Runs `query` against the module-level Neo4jGraph `graph` and returns its
    # result unchanged.
    return graph.query(query)

In [2]:
# Create Agent's State
class AgentState(TypedDict):
    """State shared by the Agent's LangGraph nodes, checkpointed per thread.

    Fields annotated with a reducer (operator.add / operator.or_) accumulate
    across node updates; the un-annotated field is replaced by each update.
    """
    # Chat history; messages returned by a node are appended (operator.add).
    conversation: Annotated[list[ AnyMessage ], operator.add]
    # Cypher text -> str() of its non-empty query result; merged with dict `|`.
    good_cypher_and_outputs: Annotated[dict[ str, str ], operator.or_]
    # Cypher queries whose result came back empty; appended across updates.
    bad_cypher: Annotated[list[ str ], operator.add]
    # Text the LLM extracted from query results; one entry appended per extraction.
    extracted_data: Annotated[list[ str ], operator.add]

    # existing_cyphers: list[str]
    # Keys of good_cypher_and_outputs selected for the current turn. No reducer,
    # so each update replaces it; the example run passes [] in the initial input.
    graph_data_to_be_used: list[str]

# Create Agent
class Agent:
    """LangGraph agent that chats with the user, queries Neo4j with
    LLM-written Cypher, and recommends careers from the retrieved data.

    Flow: personality_scientist -> (if the reply contains tool calls)
    validate_cypher_then_query_graph -> extract_data -> recommend_careers -> END.
    Without tool calls the turn ends after personality_scientist.
    """

    def __init__(self, model, tools, system: str):
        """Build and compile the workflow.

        Args:
            model: chat model; it is bound to `tools` and used for every LLM call.
            tools: LangChain tools the model may call, looked up by `.name`.
            system: system prompt prepended to every conversational call.
        """
        # Named `workflow` (not `graph`) so it does not shadow the module-level
        # Neo4jGraph that the query nodes use.
        workflow = StateGraph(AgentState)
        memory = MemorySaver()  # checkpointer: state is kept per thread_id

        workflow.add_node("personality_scientist", self.call_groq)
        workflow.add_node("validate_cypher_then_query_graph", self.validate_cypher_then_query_graph)
        workflow.add_node("extract_data", self.extract_data)
        workflow.add_node("recommend_careers", self.recommend_careers)

        workflow.add_conditional_edges(
            "personality_scientist",
            self.validate_tool_call,
            {True: 'validate_cypher_then_query_graph', False: END},
        )
        workflow.add_edge("validate_cypher_then_query_graph", "extract_data")
        workflow.add_edge("extract_data", "recommend_careers")
        workflow.add_edge("recommend_careers", END)
        workflow.set_entry_point("personality_scientist")

        self.graph = workflow.compile(checkpointer=memory)
        self.system = system
        self.tools = {t.name: t for t in tools}  # tool name -> tool
        self.model = model.bind_tools(tools)

    ## Helper function that returns the cyphers written before. This is used when calling the LLM
    def get_previous_cyphers(self, state: AgentState):
        """Summarise earlier Cypher attempts for the system prompt.

        Returns:
            A prompt fragment listing queries that returned nothing and queries
            that returned rows, or None when no query has been attempted yet.
        """
        sections = []

        bad_cypher = state.get('bad_cypher', [])
        if bad_cypher:
            sections.append(f"Here are previously written cypher codes that did not return an output: {bad_cypher}")

        good_cypher = list(state.get('good_cypher_and_outputs', {}))
        if good_cypher:
            sections.append(f"Here are previously written cypher codes that successfully returned an output: {good_cypher}")

        if not sections:
            return None
        # Newline-separate the pieces so the sentences don't run together in the prompt.
        return "When you are ready to write the cypher code, use these as help:\n" + "\n".join(sections)

    def _build_conversation(self, state: AgentState):
        """Return [system message] + conversation, with earlier Cypher hints appended to the system prompt."""
        previous_cyphers = self.get_previous_cyphers(state=state)
        system_message = f"{self.system}. {previous_cyphers}" if previous_cyphers else self.system
        return [SystemMessage(content=system_message)] + state['conversation']

    ## Get the LLM's response and update the Agent's State by adding the response to the messages
    @traceable
    def call_groq(self, state: AgentState):
        """Ask the tool-bound model for the next turn and append its reply to the conversation."""
        ai_response = self.model.invoke(self._build_conversation(state))
        return {'conversation': [ai_response]}

    ## Check if the model called for an action by checking the last message in the conversation
    def validate_tool_call(self, state: AgentState):
        """Return True if the last message in the conversation requests at least one tool call."""
        ai_message = state['conversation'][-1]
        return bool(getattr(ai_message, 'tool_calls', None))

    def _is_same_cypher(self, new_cypher: str, old_cypher: str) -> bool:
        """Return True if the LLM judges two Cypher queries equivalent for the graph schema."""
        if new_cypher.strip() == old_cypher.strip():
            return True  # identical text: no need to spend an LLM call
        comparison = self.model.invoke(
            prompts.cypher_code_analyst_prompt.format(
                cypher_code_1=new_cypher, cypher_code_2=old_cypher, graph_schema=graph.structured_schema)
        )
        print(f"\n-------- {comparison.content}")
        # Tolerate surrounding whitespace or a trailing period around "True".
        return comparison.content.strip().strip('.').lower() == "true"

    ## Check if the new cypher codes have already been written by the LLM
    def validate_cypher_then_query_graph(self, state: AgentState):
        """Run the Cypher queries requested by the last AI message, reusing earlier results.

        For each tool call:
          1. if an equivalent query already returned rows, reuse its cached output;
          2. else if an equivalent query previously returned nothing, skip it;
          3. otherwise execute it; a non-empty result is cached as good, while an
             empty result or invalid Cypher is recorded as bad.

        Returns:
            A state update. `graph_data_to_be_used` is always included (possibly
            empty) so keys selected in an earlier turn are not reused.
        """
        print("\n------- in validate query")
        tool_calls = state['conversation'][-1].tool_calls
        known_good = list(state.get('good_cypher_and_outputs', {}))
        known_bad = state.get('bad_cypher', [])

        good_cypher_and_outputs = {}  # new cypher -> str(result) for non-empty results
        bad_cypher = []  # new cypher queries that returned nothing or failed
        graph_data_to_be_used = []  # cypher keys whose outputs feed this turn

        for tool_call in tool_calls:
            selected_tool = self.tools.get(tool_call['name'])
            if selected_tool is None:
                print("tool name not found in list of tools")
                continue

            new_cypher = tool_call.get('args', {}).get('query')
            if not new_cypher:
                print("----- Tool call has no query argument")
                continue

            # 1) Equivalent to a query that already returned rows -> reuse its output.
            print("----- Checker 1")
            reused = next((key for key in known_good if self._is_same_cypher(new_cypher, key)), None)
            if reused is not None:
                print("\n----- Good Cyphers are similar, save the cypher")
                graph_data_to_be_used.append(reused)
                continue

            # 2) Equivalent to a query that returned nothing -> don't run it again.
            # NOTE: there is no edge back to personality_scientist, so the model is
            # not re-prompted this turn; it only sees the failed query via the
            # system prompt on the next turn.
            print("----- Checker 2")
            if any(self._is_same_cypher(new_cypher, old) for old in known_bad):
                print("---- Bad Cyphers are similar, have the LLM query again")
                continue

            # 3) New query -> run it against the graph.
            print("----- Checker 3")
            try:
                query_output = selected_tool.invoke(new_cypher)
            except ValueError as err:
                # LLM-written Cypher can be invalid; langchain's Neo4jGraph reports
                # Cypher syntax errors as ValueError (verify for the installed
                # version). Record the query as bad instead of crashing the run.
                print(f"----- Invalid cypher: {err}")
                bad_cypher.append(new_cypher)
                continue

            content = str(query_output)
            if content not in ("", "[]"):
                good_cypher_and_outputs[new_cypher] = content
                graph_data_to_be_used.append(new_cypher)
                print("----- Successfully queried graph")
            else:
                bad_cypher.append(new_cypher)
                print("----- Bad query")

        update = {'graph_data_to_be_used': graph_data_to_be_used}
        if good_cypher_and_outputs:
            update['good_cypher_and_outputs'] = good_cypher_and_outputs
        if bad_cypher:
            update['bad_cypher'] = bad_cypher
        return update

    ## Extract the relevant facts from this turn's query results
    def extract_data(self, state: AgentState):
        """Have the LLM extract the relevant facts from this turn's query outputs.

        Returns:
            {'extracted_data': [text]} when there are outputs to use, else None (no state change).
        """
        print('\n-------> In extract data')

        selected = state.get('graph_data_to_be_used') or []
        if not selected:
            print("No queries to be made")
            return None

        data_to_give_to_the_LLM = [{cypher: state['good_cypher_and_outputs'][cypher]} for cypher in selected]
        extracted_data = self.model.invoke(prompts.extractor_prompt.format(queried_data=data_to_give_to_the_LLM))
        print("----- Data has been extracted")
        return {'extracted_data': [extracted_data.content]}

    ## Generate final output
    def recommend_careers(self, state: AgentState):
        """Generate the final reply, grounded on the latest extracted data when this turn produced any."""
        conversation = self._build_conversation(state)

        if state.get('graph_data_to_be_used') and state.get('extracted_data'):
            instruction = prompts.recommender_prompt_with_data.format(extracted_data=state['extracted_data'][-1])
            print("----- Giving extracted data to the agent")
        else:
            instruction = prompts.recommender_prompt_without_data
            print("----- No data to give to the agent")

        ai_output = self.model.invoke(conversation + [HumanMessage(content=instruction)])
        print("ready for recommendation")
        return {'conversation': [ai_output]}


In [3]:
# A single Groq chat model serves every LLM call made by the agent.
model = ChatGroq(
    temperature=0.8,
    groq_api_key=os.environ["GROQ_API_KEY"],
    model_name="llama-3.1-70b-versatile",
    max_retries=2,
)
# Alternative configuration:
# model = ChatGroq(temperature=0.0, groq_api_key=os.environ["GROQ_API_KEY"], model_name="llama3-70b-8192")

# The system prompt is formatted with the Neo4j graph's structured schema.
system_prompt = prompts.personality_scientist_prompt.format(schema=graph.structured_schema)
agent = Agent(model=model, tools=[query_graph], system=system_prompt)

# Uncomment to render the compiled LangGraph workflow:
# graph_png = Image(agent.graph.get_graph().draw_mermaid_png())
# display(graph_png)

In [17]:
# The MemorySaver checkpointer keeps a separate state per thread_id.
thread = {'configurable': {'thread_id': "100"}}

user_turn = HumanMessage(content="Do you know any occupations that suit my character and include Python and Psychology?")
# graph_data_to_be_used has no reducer, so it is explicitly reset for this turn.
initial_input = {"conversation": [user_turn], 'graph_data_to_be_used': []}
output = agent.graph.stream(initial_input, thread, stream_mode='values')

# stream_mode='values' yields the full state after each node; show its latest message.
for snapshot in output:
    latest_message = snapshot['conversation'][-1]
    latest_message.pretty_print()


Do you know any occupations that suit my character and include Python and Psychology?
Tool Calls:
  query_graph (call_mw3e)
 Call ID: call_mw3e
  Args:
    query: MATCH (n:Occupation)-[]->(m:Personality_Trait) WHERE m.title CONTAINS 'Psychology' OR n.title CONTAINS 'Python' RETURN n,m

------- in validate query
----- Checker 1

-------- False
----- Checker 2
----- Checker 3
----- Bad query
Tool Calls:
  query_graph (call_mw3e)
 Call ID: call_mw3e
  Args:
    query: MATCH (n:Occupation)-[]->(m:Personality_Trait) WHERE m.title CONTAINS 'Psychology' OR n.title CONTAINS 'Python' RETURN n,m

-------> In extract data
No queries to be made
----- No data to give to the agent
ready for recommendation

Based on our conversation, I understand that you:

* Are detail-oriented and tend to focus on the details to ensure everything is correct
* Are more reserved and reflective, enjoying quieter environments
* Tend to approach problems in a logical and analytical way, relying on data and facts to fin

In [14]:
# agent.graph.get_state(thread).values['conversation'][-1].tool_calls[0]['args']['query']
# agent.graph.get_state(thread).values['conversation']
# agent.graph.get_state(thread).values

# Inspect the checkpointed state for `thread`. A notebook only echoes the last
# bare expression of a cell, so display() is used to show both values.
state = agent.graph.get_state(thread).values
display(state['good_cypher_and_outputs'])
display(state['conversation'])

[HumanMessage(content='Hi'),
 AIMessage(content="Hello, I'm excited to help you discover some fantastic career paths that fit your personality. To get started, I have a couple of questions to help me understand you better.\n\nHere's my first question:\n\n1. When working on a project, what do you usually focus on: the details to ensure everything is correct, or the overall vision to see the big picture?", response_metadata={'token_usage': {'completion_tokens': 75, 'prompt_tokens': 690, 'total_tokens': 765, 'completion_time': 0.301961196, 'prompt_time': 0.228078331, 'queue_time': 0.319569941, 'total_time': 0.530039527}, 'model_name': 'llama-3.1-70b-versatile', 'system_fingerprint': 'fp_5c5d1b5cfb', 'finish_reason': 'stop', 'logprobs': None}, id='run-fb4bc51b-6caf-4dca-a61d-3f4790458186-0', usage_metadata={'input_tokens': 690, 'output_tokens': 75, 'total_tokens': 765}),
 HumanMessage(content='Option 1'),
 AIMessage(content="You're a detail-oriented person who likes to make sure everything

In [6]:
# tool_calls = state['conversation'][-2].tool_calls

# good_cypher_and_outputs = {} # stores the cypher queries that returned an output
# bad_cypher = [] # stores the cypher queries that did not return an output
# graph_data_to_be_used = [] # stores the queries that the model currently wants to use

# # Go over the cypher codes given by the LLM
# for i in range(len(tool_calls)):
#     current_tool_call = tool_calls[i]
#     new_cypher = current_tool_call['args']['query']

#     status = None

#     # Check if tool exists
#     if current_tool_call['name'] in agent.tools:
        
#         # LLM checks if the cypher query has already been made + if it returned an output
#         if status == None:
#             print(f"----- Checker 1")
#             good_cyphers = list(state['good_cypher_and_outputs'].keys())
#             for i in range(len( good_cyphers )):
#                 key = good_cyphers[i]
#                 comparison = agent.model.invoke(
#                     prompts.cypher_code_analyst_prompt.format(cypher_code_1=new_cypher, cypher_code_2=key, graph_schema=graph.structured_schema)
#                     )
#                 print(f"\n-------- {comparison.content}")
#                 # If Query exists => Save cypher code to later access the output and give it to the model
#                 if comparison.content.lower() == "true":
#                     print("\n----- Good Cyphers are similar, save the cypher")
#                     graph_data_to_be_used.append(key)
#                     status = "good_cypher"
#                     break
        
#         # LLM checks if the cypher query has already been made + if it didn't return an output
#         if status == None:
#             print(f"----- Checker 2")
#             for i in range(len(state['bad_cypher'])):
#                 comparison = agent.model.invoke(
#                     prompts.cypher_code_analyst_prompt.format(
#                         cypher_code_1=new_cypher, cypher_code_2=state['bad_cypher'][i], graph_schema=graph.structured_schema)
#                     )
#                 # Query exists => LLM should fix the query
#                 if comparison.content.lower() == "true":
#                     print("---- Bad Cyphers are similar, have the LLM query again")
#                     status = "bad_cypher"
#                     break

In [7]:
# Scratch check: `comparison` is only defined if the commented-out checker loop
# in the previous cell is re-enabled and run; otherwise this raises NameError.
print(comparison.content)

NameError: name 'comparison' is not defined

In [6]:
import streamlit as st
import pandas as pd
import networkx as nx
import plotly.graph_objects as go

In [None]:
def display_knowledge_graph(data, seed=None):
    """Build an interactive Plotly figure of (head, tail) node pairs.

    Args:
        data: list of query-result rows (e.g. Neo4jGraph.query output). The
            first two values of each row must be dicts with a 'title' key; the
            first is drawn as the head node, the second as the tail node.
        seed: optional random seed for the Fruchterman-Reingold layout so the
            same data yields the same picture; None keeps the layout random.

    Returns:
        A plotly.graph_objects.Figure (not shown; the caller displays it).
    """
    head = []
    tail = []

    # Row keys are the RETURN aliases (presumably e.g. 'n', 'm'); only their order is used.
    for row in data:
        values = list(row.values())
        head.append(values[0]['title'])
        tail.append(values[1]['title'])

    # One undirected edge per (head, tail) pair.
    G = nx.Graph()
    for head_title, tail_title in zip(head, tail):
        G.add_edge(head_title, tail_title, label="")

    pos = nx.fruchterman_reingold_layout(G, k=0.7, seed=seed)

    # Edge traces (lines between nodes)
    edge_traces = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_traces.append(go.Scatter(
            x=[x0, x1, None],
            y=[y0, y1, None],
            mode='lines',
            line=dict(width=0.3, color='gray'),
            hoverinfo='none'
        ))

    # Colour nodes by role; sets give O(1) membership instead of scanning lists.
    head_nodes = set(head)
    tail_nodes = set(tail)
    node_colors = []
    for node in G.nodes():
        if node in head_nodes:
            node_colors.append('lightblue')   # head nodes (first value of each row)
        elif node in tail_nodes:
            node_colors.append('lightcoral')  # tail nodes (second value of each row)
        else:
            node_colors.append('gold')        # defensive default; every node comes from a pair

    node_trace = go.Scatter(
        x=[pos[node][0] for node in G.nodes()],
        y=[pos[node][1] for node in G.nodes()],
        mode='markers+text',
        marker=dict(size=10, color=node_colors),
        text=list(G.nodes()),
        textposition='top center',
        hoverinfo='text',
        textfont=dict(size=7)
    )

    # Edge labels at edge midpoints (labels are currently empty strings).
    edge_label_trace = go.Scatter(
        x=[(pos[u][0] + pos[v][0]) / 2 for u, v in G.edges()],
        y=[(pos[u][1] + pos[v][1]) / 2 for u, v in G.edges()],
        mode='text',
        text=[G[u][v]['label'] for u, v in G.edges()],
        textposition='middle center',
        hoverinfo='none',
        textfont=dict(size=7)
    )

    # `title=dict(...)` replaces the deprecated `titlefont_size` / `title_x`
    # shortcuts (`titlefont` was removed in Plotly 6).
    layout = go.Layout(
        title=dict(text='', x=0.5, font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis_visible=False,
        yaxis_visible=False
    )

    return go.Figure(data=edge_traces + [node_trace, edge_label_trace], layout=layout)


# Cached outputs are stored as str(list_of_rows); parse one back into Python
# literals. The key must be a cypher that was actually cached for this thread.
cached_rows = state['good_cypher_and_outputs']['MATCH (n:Occupation)-[]->(m:Personality_Trait) RETURN n, m']
g = ast.literal_eval(cached_rows)
# display_knowledge_graph(g)
g