In [1]:
from langchain_ollama import ChatOllama
from typing import Annotated, List
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.message import AnyMessage, add_messages
from typing_extensions import TypedDict
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import tools_condition
import uuid
import os
from dotenv import load_dotenv
# Load settings from ../config/.env (path is relative to the notebook's working directory).
ENV_PATH = os.path.join('../config/', '.env')
load_dotenv(ENV_PATH)

True

In [2]:
class State(TypedDict):
    """Graph state shared by every node.

    Only the conversation's message list is tracked. Because the field is
    annotated with the ``add_messages`` reducer, messages returned by a node are
    merged into the existing list rather than replacing it.
    """

    # Conversation history; node updates are combined via the add_messages reducer.
    messages: Annotated[list[AnyMessage], add_messages]


class Assistant:
    """LangGraph node that calls the assistant runnable on the current state.

    If the model replies with neither a tool call nor any text, the node appends
    a user nudge ("Respond with a real output.") to a local copy of the state and
    asks again, at most ``max_retries`` extra times.
    """

    def __init__(self, runnable: Runnable, max_retries: int = 3):
        """
        Initialize the Assistant with a runnable object.

        Args:
            runnable (Runnable): The runnable instance to invoke (in this
                notebook, the prompt | llm.bind_tools(...) chain).
            max_retries (int): Maximum number of re-prompts after an empty
                reply. This bounds the loop so a model that keeps returning
                empty output cannot hang the graph. Defaults to 3.
        """
        self.runnable = runnable
        self.max_retries = max_retries

    @staticmethod
    def _is_empty_reply(result) -> bool:
        """Return True when ``result`` has no tool calls and no usable text."""
        if result.tool_calls:
            return False
        content = result.content
        if not content:
            return True
        if isinstance(content, list):
            # Content blocks may be plain strings or dicts such as
            # {"type": "text", "text": ...}; only the first block is inspected.
            first = content[0]
            text = first.get("text") if isinstance(first, dict) else first
            return not text
        return False

    def __call__(self, state: State, config: RunnableConfig):
        """
        Invoke the LLM and re-prompt it while its reply is empty.

        Args:
            state (State): The current state containing messages.
            config (RunnableConfig): The run configuration; it is forwarded to
                the runnable so callbacks/tracing propagate to the LLM call.

        Returns:
            dict: {"messages": result}, where result is the last model reply.
            If every attempt was empty, that last (empty) reply is returned.
        """
        result = self.runnable.invoke(state, config)
        for _ in range(self.max_retries):
            if not self._is_empty_reply(result):
                break
            # Work on a copy so the graph's own state is not mutated; the nudge
            # is only seen by the model, not persisted.
            messages = state["messages"] + [("user", "Respond with a real output.")]
            state = {**state, "messages": messages}
            result = self.runnable.invoke(state, config)
        return {"messages": result}



In [3]:
def create_tool_node_with_fallback(tools: list) -> Runnable:
    """
    Build a ToolNode for ``tools`` that falls back to ``handle_tool_error``.

    If the ToolNode raises, the fallback runs with the original input plus the
    raised exception stored under the "error" key. It then answers each pending
    tool call with an error ToolMessage.

    Args:
        tools (list): Tools to expose to the graph.

    Returns:
        Runnable: The ToolNode wrapped with the error-handling fallback.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


def handle_tool_error(state: State) -> dict:
    """
    Fallback for the tool node: answer every pending tool call with an error.

    Reads the exception stored under ``state["error"]`` and builds one
    ToolMessage per tool call in the last message. Each call id therefore
    receives a reply that asks the model to correct itself.

    Args:
        state (State): Graph state, plus the "error" key set by the fallback.

    Returns:
        dict: {"messages": [ToolMessage, ...]} with one entry per tool call.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = [
        ToolMessage(
            content=f"Error: {failure!r}\n please fix your mistakes.",
            tool_call_id=call["id"],
        )
        for call in pending_calls
    ]
    return {"messages": replies}


def conversation_state_tracker(query: str) -> str:
    """
    Analyzes the conversation history to determine the current state of the conversation.
    
    Parameters:
    - query (str): Pass the query and find the current state of the conversation.
    
    Returns:
    - One from the following (str):
        - Initial -> meaning that the user is at The beginning stage of the conversation.
        - Exploring -> meaning that the user is at an In-depth discussion and exploration of topics.
        - Probing -> meaning that the user is Asking deeper questions to uncover more information.
        - Concluding -> meaning that the user is at Wrapping up the conversation and reaching a conclusion.
    """
    # NOTE: the docstring above is left unchanged on purpose. bind_tools()
    # derives the tool description the model sees from it, so editing it
    # changes model behaviour.
    # Relies on the module-level `llm`, which is defined in a later cell and
    # looked up at call time.
    prompt = f"""
        Analyze the following conversation history:
        {query}
        Determine the current conversation state in Socratic learning Method to decide what to do next. 
        Consider factors such as the topic, depth of discussion, and user engagement. 
        Respond with only one of the possible states:
            - Initial -> meaning that the user is at The beginning stage of the conversation.
            - Exploring -> meaning that the user is at an In-depth discussion and exploration of topics.
            - Probing -> meaning that the user is Asking deeper questions to uncover more information.
            - Concluding -> meaning that the user is at Wrapping up the conversation and reaching a conclusion.
        
        Return one from this List [Initial,Exploring,Probing,Concluding]
        
        Example:
            -   query: "Hi, I'm new to Machine Learning, where should I start?" 
                return: 'Initial'
            -   query: "Can you explain the difference between supervised and unsupervised learning?" 
                return: 'Exploring'
            -   query: "What happens if we use a high learning rate in training?" 
                return: 'Probing'
            -   query: "Got it, thanks for your help with Machine Learning basics." 
                return: 'Concluding'
    """
    response = llm.invoke(prompt)
    # llm.invoke returns a chat message object, not a str. Return only its text
    # so the result matches the `-> str` annotation and the tool message holds
    # just the label. Surrounding quotes are stripped because the few-shot
    # examples above show the labels in single quotes.
    return str(response.content).strip().strip("'\"").strip()

In [4]:
# Local Ollama chat model. It is used by the assistant chain and by the
# conversation_state_tracker tool, which reads this global `llm` at call time.
# temperature=0 minimises sampling randomness. Alternative tried: "llama3.1".
OLLAMA_MODEL = "llama3-groq-tool-use"
llm = ChatOllama(model=OLLAMA_MODEL, temperature=0)

# The only tool exposed to the assistant.
tools = [conversation_state_tracker]

In [5]:
# Create the primary assistant prompt template.
# Adjacent string literals are concatenated verbatim, so every sentence ends
# with a trailing space. Without it the sentences run together
# ("...conversation.You have access...").
SYSTEM_PROMPT = (
    "You are a helpful assistant; your purpose is to return the current state of the conversation. "
    "You have access to one tool: conversation_state_tracker. "
    "You must use this tool whenever you're being queried and respond with the current state, don't add anything else to your response. "
    "You can use this tool to get the current state of the conversation by sending the user input as query to the tool. "
    "Strictly return the tool's output as your response."
)

primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("placeholder", "{messages}"),
    ]
)

# Prompt our LLM and bind tools.
# NOTE(review): langchain_ollama documents ChatOllama.bind_tools' `tool_choice`
# as not supported by Ollama (ignored). The model can therefore still answer
# without calling the tool, which matches the final output below (an AIMessage
# with no tool calls). Confirm this for the installed version.
assistant_runnable = primary_assistant_prompt | llm.bind_tools(tools, tool_choice='any')


In [6]:
# --- Build the ReAct-style graph ---------------------------------------------
builder = StateGraph(State)

# Nodes do the work. "assistant" calls the LLM; "tools" executes the tool calls
# it requested, and tool errors are turned into ToolMessages by the fallback.
builder.add_node("assistant", Assistant(assistant_runnable))
builder.add_node("tools", create_tool_node_with_fallback(tools))

# Edges decide control flow. tools_condition sends the assistant's latest
# message to "tools" when it contains tool calls and to END otherwise. Tool
# output always returns to the assistant for another turn.
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# Compile with an in-memory checkpointer so state persists per thread_id.
memory = MemorySaver()
react_graph = builder.compile(checkpointer=memory)

In [7]:
def predict_react_agent_answer(example: dict):
    """Run the graph on ``example["input"]`` in a fresh thread and package the result.

    Intended for answer evaluation. Every call uses a new random thread_id, so
    each call runs in its own checkpointer thread.

    Args:
        example: dict with an "input" key holding the user's message text.

    Returns:
        dict with "response" (content of the last message in the final state)
        and "messages" (the full final state returned by react_graph.invoke).
    """
    thread_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
    user_turn = ("user", example["input"])
    final_state = react_graph.invoke({"messages": user_turn}, thread_config)
    last_message = final_state["messages"][-1]
    return {"response": last_message.content, "messages": final_state}


# Run a single example through the graph. This reuses predict_react_agent_answer
# instead of duplicating its invoke logic. `messages` still holds the full final
# graph state so the next cell can display it.
example = {"input": "Hi, I wanna learn Datastructures"}
response = predict_react_agent_answer(example)
messages = response["messages"]

In [8]:
# Display the final graph state from the example run (a dict whose "messages" key holds the conversation).
messages

{'messages': [HumanMessage(content='Hi, I wanna learn Datastructures', id='1ee2f3b2-aa76-43ea-8e3f-ca49eabe3649'),
  AIMessage(content='Sure! What specific aspect of data structures are you interested in? For example, arrays, linked lists, stacks, queues, trees, graphs, etc.', response_metadata={'model': 'llama3-groq-tool-use', 'created_at': '2024-09-01T16:04:23.405101035Z', 'message': {'role': 'assistant', 'content': 'Sure! What specific aspect of data structures are you interested in? For example, arrays, linked lists, stacks, queues, trees, graphs, etc.'}, 'done_reason': 'stop', 'done': True, 'total_duration': 11721649583, 'load_duration': 3822687895, 'prompt_eval_count': 262, 'prompt_eval_duration': 3717044000, 'eval_count': 32, 'eval_duration': 4096965000}, id='run-37541952-ce9a-400d-9e17-b4b23e040016-0', usage_metadata={'input_tokens': 262, 'output_tokens': 32, 'total_tokens': 294})]}