In [3]:
'''
The user asks a question and we take the following steps:

1. We decide which of these categories the question falls into:
    -  Summarization: e.g. Tell me about Daniel
    -  Comparison: e.g. Who has a better job between Daniel and James?
    -  Basic_Query: e.g. How old is Daniel?
    -  General_Knowledge/message: e.g. Hi, Hello (questions that don't require fetching context)
2. We draw up a plan to answer the question based on its category.
    For instance, let's take a comparison problem:
    question = who has a better job between Daniel and James?
    steps:
     - context_search(what does daniel do?)
     - context_search(what does james do?)
     - compare results
     - draft your final response

'''

"\nThe user asks a question and we take the following steps:\n\n1. We decide which of these categories the question falls into:\n    -  Summarization: e.g. Tell me about Daniel\n    -  Comparison: e.g. Who has a better job between Daniel and James?\n    -  Basic_Query: e.g. How old is Daniel?\n    -  General_Knowledge/message: e.g. Hi, Hello (questions that don't require fetching context)\n2. We draw up a plan to answer the question based on its category.\n    For instance, let's take a comparison problem:\n    question = who has a better job between Daniel and James?\n    steps:\n     - context_search(what does daniel do?)\n     - context_search(what does james do?)\n     - compare results\n     - draft your final response\n\n"

In [4]:
import chromadb
from chromadb.utils import embedding_functions
from langchain_openai import ChatOpenAI
from langchain_core.messages import (HumanMessage,SystemMessage, FunctionMessage)
from langchain_core.tools import tool
from langchain.tools.render import format_tool_to_openai_function
import json

default_ef = embedding_functions.DefaultEmbeddingFunction()

COLLECTION_NAME = "daniel_and_james"

client = chromadb.PersistentClient(path="contexts")

# Start from a clean collection on every run of this cell. `delete_collection`
# raises if the collection does not exist (e.g. the first run against a fresh
# "contexts" directory), so only delete when the name is already present.
# Depending on the chromadb version, `list_collections()` may yield Collection
# objects or plain name strings; `getattr(..., "name", c)` accepts both.
existing_names = {getattr(c, "name", c) for c in client.list_collections()}
if COLLECTION_NAME in existing_names:
    client.delete_collection(COLLECTION_NAME)
db = client.create_collection(COLLECTION_NAME, embedding_function=default_ef)


# Facts per person. Each document is tagged with a lower-case `name` metadata
# field, which is what the retrieval helpers below filter on.
context_db_for_dan = [
    "He is 45 years old",
    "Daniel works in tech",
    "Daniel is a developer"
]
context_db_for_james = [
    "He is 49 years old",
    "James works in finance",
    "James is a Bank  Manager"
]

# Ids are consecutive string integers across both people ('0'..'5'); each
# document gets its own metadata dict (no shared/aliased dict objects).
n_dan = len(context_db_for_dan)
n_james = len(context_db_for_james)
db.add(
    documents=context_db_for_dan,
    ids=[str(i) for i in range(n_dan)],
    metadatas=[{"name": "daniel"} for _ in range(n_dan)],
)
db.add(
    documents=context_db_for_james,
    ids=[str(i) for i in range(n_dan, n_dan + n_james)],
    metadatas=[{"name": "james"} for _ in range(n_james)],
)


In [5]:
# Unfiltered semantic search: the two documents closest to "age" across the whole collection.
db.query(n_results=2, query_texts=["age"])

{'ids': [['3', '0']],
 'distances': [[1.0891575047416255, 1.227068588394885]],
 'metadatas': [[{'name': 'james'}, {'name': 'daniel'}]],
 'embeddings': None,
 'documents': [['He is 49 years old', 'He is 45 years old']],
 'uris': None,
 'data': None}

In [6]:
# Same "age" search, restricted by metadata to documents tagged name == "james".
filters = {"name": {"$eq": "james"}}
db.query(
    query_texts=["age"],
    where=filters,
    n_results=2,
)

{'ids': [['3', '5']],
 'distances': [[1.0891575047416255, 1.920391985364417]],
 'metadatas': [[{'name': 'james'}, {'name': 'james'}]],
 'embeddings': None,
 'documents': [['He is 49 years old', 'James is a Bank  Manager']],
 'uris': None,
 'data': None}

In [7]:
def context_search(question, filter_by_name, k=2):
    """Return up to ``k`` stored documents about one person that best match ``question``.

    Parameters:
    question (str | list[str]): Query text (or several query texts) to search with.
    filter_by_name (str): Person's name; matched exactly against the lower-case
        ``name`` metadata stored with each document.
    k (int): Maximum number of results per query text.

    Returns:
    list[list[str]]: ``db.query(...)["documents"]`` -- one inner list of matching
    document strings per query text.
    """
    # Metadata names are stored lower-case; normalise case and stray whitespace so
    # names such as " Daniel" (e.g. from LLM-generated arguments) still match.
    name = filter_by_name.strip().lower()
    queries = [question] if isinstance(question, str) else list(question)
    context = db.query(query_texts=queries, n_results=k, where={"name": {"$eq": name}})
    return context["documents"]
    

### AGENTS GRAPH

In [8]:
import operator
from typing import Annotated,  Sequence,  TypedDict
from langchain_openai.chat_models import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from typing import TypedDict, Annotated 
from langchain_core.messages import BaseMessage
import operator
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.graph import END, StateGraph
import json
from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    AIMessage
    
)
from langchain.tools.render import format_tool_to_openai_function
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
from langchain_core.tools import tool
from typing import Annotated
from rich.markdown import Markdown
from rich import print as md



In [9]:
#define graph state
class AgentState(TypedDict, total=True):
    """State shared between the nodes of the agent graph."""

    # Declared in the schema but not read or written by any node in the visible
    # code. NOTE(review): confirm whether it is still needed.
    chat_history: list[BaseMessage]
    # Conversation so far. `operator.add` is attached as the reducer, so the list
    # of messages each node returns is appended to the history, not substituted.
    messages: Annotated[Sequence[BaseMessage], operator.add]

In [23]:
#context_generator agent
@tool
def query_person_info(keywords: str, filter_by_name: str, k: int = 2) -> list:
    """
    This function searches the database for context about one person, using search keywords and the name of the person to filter by.

    Parameters:
    keywords (str): The keywords to search context for, e.g. age, occupation, etc.
    filter_by_name (str): The name of the person to filter the search by (e.g. daniel or james).
    k (int): The maximum number of matching documents to return (default 2).

    Returns:
    list: A list containing one list of matching document strings from the database.
    """
    # Delegate to the notebook's `context_search` helper so the tool and the
    # direct helper share a single implementation of the name filter and query.
    return context_search(keywords, filter_by_name, k)
    
    
    
llm2 = ChatOpenAI(model="gpt-4-1106-preview")
tools = [query_person_info]
functions = [format_tool_to_openai_function(t) for t in tools]
context_provider_name = "context_provider"

# The system prompt is built from adjacent string literals, which Python joins
# with NO separator. Every literal therefore ends with "\n"; otherwise sentences
# and numbered steps run together ("personsThe", "steps:1 .", "question3.").
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Your name is {name}, and you are part of an AI system designed to answer questions about some persons.\n"
            "The questions are usually about a person named james or daniel.\n"
            "Take these steps:\n"
            "1. Categorise the question into one of the following:\n"
            "A. summarization: This category is for requests that require a summary of, or info about, a person's life (just a single person).\n"
            "B. comparison: This category is for requests that require comparing info about two or more persons.\n"
            "C. general_knowledge/message: This category is for general knowledge questions or messages that don't require fetching context about any person, e.g. greetings.\n"
            "2. If the question falls under A or B, call the {tool} function one or more times to gather the context needed to answer the question. Use the best search keywords based on the question.\n"
            "3. If the question falls under general_knowledge/message (C), just answer it based on what you know, without calling the function. Do not state the category of the question.\n"
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
# Fill the {name} and {tool} placeholders up front; only "messages" remains to be supplied.
prompt = prompt.partial(name=context_provider_name, tool=", ".join(t.name for t in tools))
context_provider_agent = prompt | llm2.bind_functions(functions)


def context_provider_node(state):
    """Run the context-provider agent on the current graph state.

    Parameters:
    state (AgentState): Graph state; the prompt consumes ``state["messages"]``.

    Returns:
    dict: ``{"messages": [AIMessage]}`` -- a one-item list that the state's
    ``operator.add`` reducer appends to the history. The message keeps the model
    output's ``additional_kwargs`` (including any ``function_call``) and is
    tagged with ``name=context_provider_name``.
    """
    result = context_provider_agent.invoke(state)
    # The chat model returns an AI message (a requested tool call lives in
    # `additional_kwargs["function_call"]`), never a FunctionMessage, so the former
    # FunctionMessage branch was unreachable and has been removed. Function calls
    # are not validated here; `router` and `tool_node` handle them.
    # Rebuild the message with this agent's name. "type" and "name" are excluded so
    # AIMessage's own type and the name passed here are used.
    result = AIMessage(**result.dict(exclude={"type", "name"}), name=context_provider_name)
    return {"messages": [result]}

In [24]:
#Tool Executor
tool_executor = ToolExecutor(tools)


def tool_node(state):
    """Execute the function call requested by the last message and report the result.

    The graph only routes here (see `router`) when the last message's
    ``additional_kwargs`` contains a ``function_call``.

    Parameters:
    state (AgentState): Graph state; only ``state["messages"][-1]`` is read.

    Returns:
    dict: ``{"messages": [FunctionMessage]}`` whose content is
    ``"<tool name> response: <result>"``; the state's reducer appends it to the history.
    """
    function_call = state["messages"][-1].additional_kwargs["function_call"]
    tool_name = function_call["name"]
    raw_arguments = function_call["arguments"]
    try:
        tool_input = json.loads(raw_arguments)
    except json.JSONDecodeError:
        # Arguments that are not valid JSON are forwarded as-is under a "code" key
        # (the original fallback). NOTE(review): query_person_info has no `code`
        # parameter, so this fallback only helps code-execution style tools.
        tool_input = {"code": raw_arguments}
    # Single-argument tools may receive their input under "__arg1"; pass it by value.
    if isinstance(tool_input, dict) and len(tool_input) == 1 and "__arg1" in tool_input:
        tool_input = tool_input["__arg1"]
    action = ToolInvocation(tool=tool_name, tool_input=tool_input)
    response = tool_executor.invoke(action)
    # Returned as a list so the reducer appends it to the existing messages.
    function_message = FunctionMessage(content=f"{tool_name} response: {response}", name=tool_name)
    return {"messages": [function_message]}

In [25]:
# Build the agent graph over AgentState: one LLM node (the context provider)
# and one node ("call_tool") that executes the function calls it requests.
workflow = StateGraph(AgentState)
workflow.add_node(context_provider_name, context_provider_node)
workflow.add_node("call_tool", tool_node)

def router(state):
    """Choose the next step after the context provider has produced a message.

    Returns ``"call_tool"`` when the newest message requests a function call
    (``"function_call"`` is a key of its ``additional_kwargs``), regardless of
    which node sent it; otherwise ``"continue"``, which the conditional-edge
    mapping sends to END.
    """
    newest = state["messages"][-1]
    wants_tool = "function_call" in newest.additional_kwargs
    return "call_tool" if wants_tool else "continue"


# After the context provider runs: execute the requested tool, or finish.
workflow.add_conditional_edges(
    context_provider_name,
    router,
    {"call_tool": "call_tool", "continue": END},
)

# Every tool result is handed back to the context provider for another turn.
workflow.add_edge("call_tool", context_provider_name)

# Execution always starts at the context provider.
workflow.set_entry_point(context_provider_name)
graph = workflow.compile()

In [28]:
state = {
    "messages": [
        HumanMessage(content="hy, who is older btw daniel and james"),
    ],
}

# Stream the graph and render each context-provider message (or the function
# call it requests). Each chunk is keyed by node name; the first key is the
# node that produced the update.
for s in graph.stream(state, {"recursion_limit": 150}):
    agent = next(iter(s))
    last_message = s[agent]["messages"][-1]

    if agent == context_provider_name:
        if "function_call" in last_message.additional_kwargs:
            function_call = last_message.additional_kwargs["function_call"]
            content = (
                f"I am calling the function `{function_call['name']}` "
                f"with the following arguments: {function_call['arguments']}"
            )
        else:
            content = last_message.content
        md(Markdown(content))
    elif agent == "call_tool":
        # Tool results are not rendered; the graph feeds them back to the agent.
        pass