In [1]:
from langchain.chat_models import init_chat_model
from langchain.agents.middleware import TodoListMiddleware
from dotenv import load_dotenv
from langchain.agents import create_agent
from langchain.tools import tool
import pprint
load_dotenv()

all_calls = []

supervisor_llm = init_chat_model(
    model = "gpt-4o-mini",
    model_provider='openai',
    )


rag_llm = init_chat_model(
    model = "gpt-4o-mini",
    model_provider='openai',
)

calc_llm = init_chat_model(
    model = 'gpt-4o-mini',
    model_provider='openai',
)




@tool#(parse_docstring=True)
def addition(a: float, b: float)->float:
    """
    Adds two floating point numbers
    Args:
        a : First floating point number
        b : Second floating point number
    Returns:
        a+b
    """

    return a+b

@tool#(parse_docstring=True)
def subtraction(a:float, b: float) -> float:
    """
    Performs subtraction of two floating point numbers.

    Args:
        a : First floating point number
        b : Second floating point number
    Returns:
        a-b

    """

    return a-b

@tool#(parse_docstring=True)
def multiplication(a: float, b: float) -> float:
    """
    Performs the product of two floating point numbers.
    
    Args:
        a : First floating point number
        b : Second floating point number
    Returns:
        a*b
    """

    return a*b

@tool#(parse_docstring=True)
def division(a: float, b: float) -> float:
    """
    Performs the division of two floating point numbers.
    
    Args:
        a : First floating point number
        b : Second floating point number
    Returns:
        a/b
    """

    return a/b

calc_agent = create_agent(
    model=calc_llm,
    tools = [addition, subtraction, multiplication, division],
    middleware=[TodoListMiddleware()],
    system_prompt="You are an arithmetic agent. You have access to addition, subtraction, multiplication, and division tools "
)



@tool
def retreive_augment_context(query: str) -> str:
    """
    Retreives policy information about TCS company.
    
    Args:
        query : Query about the company

    Returns:
        Returns the data retreived from the vector db to answer the query and augments the query.
    """
    return f"{query} : TCS has very strict policies about personal information"



@tool
def generate_respond() -> str:
    """
    Generates final response for the query.
    Returns:
        Returns the final response for the query.
    """
    return "TCS has very strict policies about personal information"


rag_agent = create_agent(
    model=rag_llm,
    tools = [retreive_augment_context, generate_respond],
    middleware=[TodoListMiddleware()],
    system_prompt="""You are a RAG agent. Your task is to answer user queries that are relates to TCS policies.
    Use the below tools to respond to all questions related to TCS policies.
    Do not hallucinate, just invoke the tools to get the information regarding all TCS policies.
    You have access to retreive_augment_context, and generate_respond tools 
    Do not add extra text other than what tools provide"""
)

@tool
def calc_agent_tool(query: str):
    """
    Capable of performing expression evaluation that contain addition, subtraction, multiplication, and division operations

    Args:
        query: Contains the expression to evaluate.

    Returns:
        The final value after evaluation
    """
    calc_resp =  calc_agent.invoke(
        {
            "messages": [{
                "role": "human",
                "content" : query
                }
            ]
        }
    )
    all_calls.append(calc_resp)
    return calc_resp['messages'][-1].content

@tool
def rag_policy_agent_tool(query: str):
    """
    Capable of generating responses for any TCS related Policies.

    Args:
        query: Contains the query related to TCS Policies.

    Returns:
        The response to the query related to TCS Policies
    """
    rag_resp =  rag_agent.invoke(
        {
            "messages": [{
                "role": "human",
                "content" : query
                }
            ]
        }
    )
    all_calls.append(rag_resp)
    return rag_resp['messages'][-1].content

supervisor_agent = create_agent(
    model=supervisor_llm,
    tools=[rag_policy_agent_tool, calc_agent_tool],
    middleware=[TodoListMiddleware()],
    system_prompt=""" You are an expert agent that can repond to queries related to expression evaluation and
    TCS policies. You have access to "calc_agent_tool", and "rag_policy_agent_tool" to provide responses to
    expression evaluation and TCS policies respectively.

    Donot hallicunate and respond.
    Respond only with the information obtained from the tools available in hand. 

"""
)


# response = supervisor_agent.invoke({"messages": [{"role": "human", "content": "query : Answer about TCS policies?"}]})
response = supervisor_agent.invoke({"messages": [{"role": "human", "content": "query : 2*3-2?"}]})



In [20]:
from langchain.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
for message in response['messages']:
    if(type(message) == HumanMessage):
        print(f"Human : {message.content}")
    elif(type(message) == SystemMessage):
        print(f"System : {message.content}")
    elif(type(message) == AIMessage):
        if(message.content == ''):
            print(f"AI : \n suggested tool_calls : \n name : {message.tool_calls[0]['name']} \n args : {message.tool_calls[0]['args']}")
        else:
            print(f"AI : {message.content}")
    elif(type(message) == ToolMessage):
        print(f"Tool : {message.content}")
    # print(type(message) , message)

Human : query : 2*3-2?
AI : 
 suggested tool_calls : 
 name : calc_agent_tool 
 args : {'query': '2*3-2'}
Tool : The result of \(2 \times 3 - 2\) is \(4\).
AI : The result of the expression \(2 \times 3 - 2\) is \(4\).


In [49]:
for message in all_calls[0]['messages']:
    if(type(message) == HumanMessage):
        print(f"Human Prompt: {message.content}")
    elif(type(message) == SystemMessage):
        print(f"System Prompt: {message.content}")
    elif(type(message) == AIMessage):
        if(message.content == ''):
            if(len(message.tool_calls) >1):
                for i in range(len(message.tool_calls)):
                    print(f"AI Response: \n suggested tool_calls : \n name : {message.tool_calls[i]['name']} \n args : {message.tool_calls[i]['args']}")
            else:
                print(f"AI Response: \n suggested tool_calls : \n name : {message.tool_calls[0]['name']} \n args : {message.tool_calls[0]['args']}")
        else:
            print(f"AI Response: {message.content}")
    elif(type(message) == ToolMessage):
        print(f"Tool Resposne: {message.content}")

Human Prompt: 2*3-2
AI Response: 
 suggested tool_calls : 
 name : multiplication 
 args : {'a': 2, 'b': 3}
AI Response: 
 suggested tool_calls : 
 name : subtraction 
 args : {'a': 6, 'b': 2}
Tool Resposne: 6.0
Tool Resposne: 4.0
AI Response: The result of \(2 \times 3 - 2\) is \(4\).


In [23]:
all_calls

[{'messages': [HumanMessage(content='2*3-2', additional_kwargs={}, response_metadata={}, id='95f3332c-d8ba-48a3-8703-a9c3a43978a3'),
   AIMessage(content='', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 1310, 'total_tokens': 1363, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_provider': 'openai', 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_560af6e559', 'id': 'chatcmpl-CgEoX6vZ1KBq2cxhAHAV5lIpz5OGP', 'service_tier': 'default', 'finish_reason': 'tool_calls', 'logprobs': None}, id='lc_run--f0544469-9357-4dc0-a6f4-083aa2ed5157-0', tool_calls=[{'name': 'multiplication', 'args': {'a': 2, 'b': 3}, 'id': 'call_wtovDtSJ4tlzEkkqtCDehfOw', 'type': 'tool_call'}, {'name': 'subtraction', 'args': {'a': 6, 'b': 2}, 'id': 'call_Zlx79BD5jpFwF0GSyeTHX

In [36]:
tool_calls=[{'name': 'multiplication', 'args': {'a': 2, 'b': 3}, 'id': 'call_wtovDtSJ4tlzEkkqtCDehfOw', 'type': 'tool_call'}, {'name': 'subtraction', 'args': {'a': 6, 'b': 2}, 'id': 'call_Zlx79BD5jpFwF0GSyeTHXjNK', 'type': 'tool_call'}]

In [48]:
tool_calls[0]['name'], tool_calls[0]['args'], tool_calls[1]['name'], tool_calls[0]['args']

('multiplication', {'a': 2, 'b': 3}, 'subtraction', {'a': 2, 'b': 3})