In [1]:
import os

from dotenv import load_dotenv
from langchain_groq import ChatGroq

# Groq model used by every cell below; change it here to try another model.
# NOTE(review): Groq retires model IDs from time to time; if this one is
# rejected, pick a current ID from Groq's model list.
GROQ_MODEL = "llama3-8b-8192"

# Pull variables (e.g. GROQ_API_KEY) from a local .env file into os.environ.
load_dotenv()

# ChatGroq reads GROQ_API_KEY from the environment. Fail fast with an actionable
# message rather than a less obvious error at construction or first call.
if not os.getenv("GROQ_API_KEY"):
    raise RuntimeError(
        "GROQ_API_KEY is not set; add it to your .env file or export it in the shell."
    )

llm = ChatGroq(model=GROQ_MODEL)

In [2]:
from pydantic import BaseModel, Field


# Schema-only tools: each Pydantic class describes a tool's arguments but has no
# implementation, so the model can *request* a call that nothing here executes.
# The lowercase class names are deliberate: they surface as the tool names the
# model emits (see the 'multiply' / 'add' tool calls in the outputs below).
# NOTE(review): the docstrings and Field descriptions are presumably sent to the
# model as tool/argument descriptions, so treat them as prompt text.
# NOTE(review): `add` and `multiply` are redefined later as @tool functions,
# which shadows these classes; running cells out of order mixes the two versions.
class add(BaseModel):
    """Add two integers."""

    a: int = Field(..., description="First integer")
    b: int = Field(..., description="Second integer")


class multiply(BaseModel):
    """Multiply two integers."""

    a: int = Field(..., description="First integer")
    b: int = Field(..., description="Second integer")

# Tool list bound to the model in the next code cell (`llm.bind_tools(tools)`).
tools = [add, multiply]

In [3]:
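# Alternative (not run): attach a Pydantic args schema to a function via the @tool decorator.
# Uncommenting it also requires `from langchain_core.tools import tool`, and it would
# redefine `multiply` as a tool named "multiplication-tool".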

# from pydantic import BaseModel, Field


# class CalculatorInput(BaseModel):
#     a: int = Field(description="first number")
#     b: int = Field(description="second number")


# @tool("multiplication-tool", args_schema=CalculatorInput, return_direct=True)
# def multiply(a: int, b: int) -> int:
#     """Multiply two numbers."""
#     return a * b

In [4]:
# Bind the schema-only tools so the model can reply with structured tool calls.
llm_with_tools = llm.bind_tools(tools)

# One-operation question; the recorded output shows a single `multiply` call
# with empty text content.
query = "What is 3 * 12?"
single_call_reply = llm_with_tools.invoke(query)
single_call_reply

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_wz66', 'function': {'arguments': '{"a":3,"b":12}', 'name': 'multiply'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 72, 'prompt_tokens': 1064, 'total_tokens': 1136, 'completion_time': 0.06, 'prompt_time': 0.169452123, 'queue_time': 0.0029757609999999934, 'total_time': 0.229452123}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_6a6771ae9c', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-3f87058c-7b2c-40c8-ae93-3dc015d1dd79-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_wz66', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1064, 'output_tokens': 72, 'total_tokens': 1136})

In [5]:
# Two independent operations in one prompt: the model may return several
# tool calls in a single reply.
query = "What is 3 * 12? Also, what is 11 + 49?"

# Display only the parsed tool calls rather than the whole AIMessage.
multi_call_reply = llm_with_tools.invoke(query)
multi_call_reply.tool_calls

[{'name': 'multiply',
  'args': {'a': 3, 'b': 12},
  'id': 'call_2hxj',
  'type': 'tool_call'},
 {'name': 'add',
  'args': {'a': 11, 'b': 49},
  'id': 'call_s6dj',
  'type': 'tool_call'}]

In [6]:
from langchain_core.tools import tool


# Executable tools: @tool wraps each function in a LangChain tool object, so the
# names the model calls ("add", "multiply") can now actually be run (see the
# `selected_tool.invoke(tool_call)` loop further down).
# NOTE(review): these definitions shadow the Pydantic `add` / `multiply` schema
# classes from the earlier cell. The function names and parameter names (a, b)
# form the tool schema, and the docstrings are presumably used as the tool
# descriptions, so editing any of them changes what the model sees.
@tool
def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


@tool
def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b


tools = [add, multiply]

# Rebinds `llm_with_tools`: the earlier binding used the schema-only classes,
# and this one uses the executable tool objects.
llm_with_tools = llm.bind_tools(tools)

In [7]:
from langchain_core.messages import HumanMessage

# Fresh conversation: a single user turn asking for two calculations.
query = "What is 3 * 12? Also, what is 11 + 49?"
messages = [HumanMessage(content=query)]

# The model answers with tool-call requests instead of text.
ai_msg = llm_with_tools.invoke(messages)
print(ai_msg.tool_calls)

# Keep the assistant turn in the history; tool results are answered against it.
messages += [ai_msg]

[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_e90r', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_nw36', 'type': 'tool_call'}]


In [8]:
# Run every tool call the model requested and append each result to the
# conversation, so the model can read the results on the next turn.
# Invoking a tool with the full ToolCall dict (not just its args) returns a
# ToolMessage that carries the matching tool_call_id.
tools_by_name = {"add": add, "multiply": multiply}  # loop-invariant, built once

# Make the cell safe to re-run: skip calls that already have a result in
# `messages` instead of appending duplicate ToolMessages.
answered_ids = {
    msg.tool_call_id for msg in messages if getattr(msg, "tool_call_id", None)
}

for tool_call in ai_msg.tool_calls:
    if tool_call["id"] in answered_ids:
        continue
    # .lower() tolerates the model varying the capitalisation of tool names.
    tool_name = tool_call["name"].lower()
    if tool_name not in tools_by_name:
        raise ValueError(
            f"Model requested unknown tool {tool_call['name']!r}; "
            f"available: {sorted(tools_by_name)}"
        )
    messages.append(tools_by_name[tool_name].invoke(tool_call))

messages

[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?', additional_kwargs={}, response_metadata={}),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_e90r', 'function': {'arguments': '{"a":3,"b":12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_nw36', 'function': {'arguments': '{"a":11,"b":49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 125, 'prompt_tokens': 1045, 'total_tokens': 1170, 'completion_time': 0.104166667, 'prompt_time': 0.286175094, 'queue_time': 0.005599607999999978, 'total_time': 0.390341761}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_a97cfe35ae', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8e304b2d-fc91-4453-a57a-f5744f03d7a0-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_e90r', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_nw36', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1045,

In [9]:
# Second round trip: with the tool results in `messages`, the model produces
# its final text answer.
# NOTE(review): that text is model-generated, not computed here. The recorded
# output restates the first result as "36 * 12 = 432", so check it against the
# ToolMessage results rather than trusting it.
final_answer = llm_with_tools.invoke(messages)
final_answer

AIMessage(content='36 * 12 = 432 and 11 + 49 = 60. Therefore, the answer is 432 and 60.', additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 29, 'prompt_tokens': 1169, 'total_tokens': 1198, 'completion_time': 0.024166667, 'prompt_time': 0.135963614, 'queue_time': 0.004036023, 'total_time': 0.160130281}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_6a6771ae9c', 'finish_reason': 'stop', 'logprobs': None}, id='run-80fae12b-3ebc-4361-bb1a-497a88bffd5f-0', usage_metadata={'input_tokens': 1169, 'output_tokens': 29, 'total_tokens': 1198})

In [10]:
### Streaming tool calls

In [12]:
query = "What is 3 * 12? Also, what is 11 + 49?"

async for chunk in llm_with_tools.astream(query):
    print(chunk.tool_call_chunks)

[]
[{'name': 'multiply', 'args': '{"a":3,"b":12}', 'id': 'call_k14g', 'index': 0, 'type': 'tool_call_chunk'}]
[{'name': 'add', 'args': '{"a":11,"b":49}', 'id': 'call_nxkt', 'index': 1, 'type': 'tool_call_chunk'}]
[]


In [13]:
first = True
async for chunk in llm_with_tools.astream(query):
    if first:
        gathered = chunk
        first = False
    else:
        gathered = gathered + chunk

    print(gathered.tool_call_chunks)

[]
[{'name': 'multiply', 'args': '{"a":3,"b":12}', 'id': 'call_803x', 'index': 0, 'type': 'tool_call_chunk'}]
[{'name': 'multiply', 'args': '{"a":3,"b":12}', 'id': 'call_803x', 'index': 0, 'type': 'tool_call_chunk'}, {'name': 'add', 'args': '{"a":11,"b":49}', 'id': 'call_2zek', 'index': 1, 'type': 'tool_call_chunk'}]
[{'name': 'multiply', 'args': '{"a":3,"b":12}', 'id': 'call_803x', 'index': 0, 'type': 'tool_call_chunk'}, {'name': 'add', 'args': '{"a":11,"b":49}', 'id': 'call_2zek', 'index': 1, 'type': 'tool_call_chunk'}]


In [14]:
# Raw chunk args are the unparsed JSON text streamed by the model.
first_chunk_args = gathered.tool_call_chunks[0]["args"]
print(type(first_chunk_args))

<class 'str'>


In [15]:
first = True
async for chunk in llm_with_tools.astream(query):
    if first:
        gathered = chunk
        first = False
    else:
        gathered = gathered + chunk

    print(gathered.tool_calls)

[]
[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_g7p8', 'type': 'tool_call'}]
[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_g7p8', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_qrwm', 'type': 'tool_call'}]
[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_g7p8', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_qrwm', 'type': 'tool_call'}]


In [16]:
# Parsed tool_calls expose args as a dict, unlike the raw chunk strings.
first_call_args = gathered.tool_calls[0]["args"]
print(type(first_call_args))

<class 'dict'>
