In [1]:
%pip install llama-index
%pip install groq
!pip install llama-index-llms-groq

Collecting llama-index
  Downloading llama_index-0.12.42-py3-none-any.whl.metadata (12 kB)
Collecting llama-index-agent-openai<0.5,>=0.4.0 (from llama-index)
  Downloading llama_index_agent_openai-0.4.11-py3-none-any.whl.metadata (439 bytes)
Collecting llama-index-cli<0.5,>=0.4.2 (from llama-index)
  Downloading llama_index_cli-0.4.3-py3-none-any.whl.metadata (1.4 kB)
Collecting llama-index-core<0.13,>=0.12.42 (from llama-index)
  Downloading llama_index_core-0.12.42-py3-none-any.whl.metadata (2.4 kB)
Collecting llama-index-embeddings-openai<0.4,>=0.3.0 (from llama-index)
  Downloading llama_index_embeddings_openai-0.3.1-py3-none-any.whl.metadata (684 bytes)
Collecting llama-index-indices-managed-llama-cloud>=0.4.0 (from llama-index)
  Downloading llama_index_indices_managed_llama_cloud-0.7.7-py3-none-any.whl.metadata (3.3 kB)
Collecting llama-index-llms-openai<0.5,>=0.4.0 (from llama-index)
  Downloading llama_index_llms_openai-0.4.5-py3-none-any.whl.metadata (3.0 kB)
Collecting llama

In [2]:
import os
from getpass import getpass

# Credentials must never be stored in the notebook itself. Use the key already
# present in the environment, and only prompt (without echoing the input) when
# it is missing.
# NOTE(review): the key that was previously hardcoded in this cell has been
# exposed and should be revoked in the Groq console.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass("Enter your Groq API key: ")


In [3]:
from pydantic import BaseModel, ConfigDict

from llama_index.core.tools import BaseTool

class AgentConfig(BaseModel):
    """Static configuration describing one agent.

    Attributes:
        name: Human-readable name of the agent.
        description: Short description of what the agent is for.
        system_prompt: Optional instruction text; the workflow prepends it
            (followed by the current user state) to every LLM call.
        tools: Optional list of tools offered to the LLM.
    """

    # Permit field types pydantic cannot validate natively; presumably needed
    # because `BaseTool` is not itself a pydantic model -- TODO confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    description: str
    system_prompt: str | None = None
    tools: list[BaseTool] | None = None

In [4]:
from llama_index.core.tools import FunctionTool

def add_two_numbers(a: int, b: int) -> int:
    """Add two integers and return their sum.

    This docstring is also what the LLM sees as the tool description once the
    function is wrapped with ``FunctionTool.from_defaults``, so it states
    plainly what the tool does.

    Args:
        a: The first number to add.
        b: The second number to add.

    Returns:
        The sum ``a + b``.
    """
    return a + b

# Wrap the plain Python function so it can be offered to the LLM as a tool.
add_two_numbers_tool = FunctionTool.from_defaults(fn=add_two_numbers)

# Instructions that restrict the agent to addition and fix the refusal wording.
addition_only_prompt = (
    "You are an AI agent ONLY capable of performing addition. "
    "If the user asks anything outside of adding two numbers, respond with: "
    "'Sorry, I can only help with addition problems.' "
    "Do not answer unrelated questions."
)

# Configuration consumed by the workflow: prompt plus the single addition tool.
agent_config = AgentConfig(
    name="Addition Agent",
    description="Used to add two numbers together.",
    system_prompt=addition_only_prompt,
    tools=[add_two_numbers_tool],
)

In [5]:
from llama_index.core.workflow import Event, Workflow, Context, StartEvent, StopEvent, step
from llama_index.core.llms import ChatMessage, LLM
from llama_index.core.tools import ToolSelection
from llama_index.llms.groq import Groq


class LLMCallEvent(Event):
    """Payload-free signal that the chat history is ready for an LLM call.

    Returned by ``BasicAgent.setup`` and ``BasicAgent.aggregate_tool_results``
    and consumed by ``BasicAgent.speak_with_agent``.
    """

class ToolCallEvent(Event):
    """Request to execute one tool call selected by the LLM.

    No fields are declared here: the producer passes ``tool_call`` and
    ``tools`` as keyword arguments when constructing the event, and the
    consuming step reads them back as attributes.
    """

class ToolCallResultEvent(Event):
    """Outcome of a single tool invocation, emitted by ``handle_tool_call``."""

    # A role="tool" message holding the tool output (or an error description);
    # ``aggregate_tool_results`` appends it to the chat history.
    chat_message: ChatMessage

class ProgressEvent(Event):
    """Human-readable progress update written to the workflow's event stream."""

    # Text describing what just happened (e.g. which tool ran and its result).
    msg: str


class BasicAgent(Workflow):
    """Minimal function-calling agent loop implemented as a workflow.

    Step flow::

        setup -> speak_with_agent -> StopEvent            (no tool calls)
                       |  ^
                       v  |
             handle_tool_call (one per call) -> aggregate_tool_results

    Keyword arguments accepted by ``run`` (read from the ``StartEvent``):
        agent_config: Required ``AgentConfig`` (system prompt and tools).
        user_msg: Required user message for this turn.
        llm: Optional function-calling LLM; defaults to a Groq client.
        chat_history: Optional list of prior ``ChatMessage`` objects. The list
            is copied, so the caller's list is never mutated.
        initial_state: Optional dict rendered into the system prompt.

    The workflow result is a dict with keys ``"response"`` (final assistant
    text) and ``"chat_history"`` (full updated message list).
    """

    @step
    async def setup(
        self, ctx: Context, ev: StartEvent
    ) -> LLMCallEvent:
        """Validate the run arguments and seed the workflow context.

        Raises:
            ValueError: If a required argument is missing or the LLM does not
                support function calling.
        """
        agent_config = ev.get("agent_config")
        user_msg = ev.get("user_msg")
        # NOTE: the default client is constructed eagerly, even when the caller
        # supplies ``llm``; kept as-is to preserve the explicit-None behavior.
        llm: LLM = ev.get("llm", default=Groq(model="llama3-70b-8192", temperature=0.3))
        chat_history = ev.get("chat_history", default=[])
        initial_state = ev.get("initial_state", default={})

        # Every later step dereferences agent_config, so fail early and clearly.
        if agent_config is None:
            raise ValueError("agent_config is required!")

        if (
            user_msg is None
            or llm is None
            or chat_history is None
        ):
            raise ValueError(
                "User message, llm, and chat_history are required!"
            )

        if not llm.metadata.is_function_calling_model:
            raise ValueError("LLM must be a function calling model!")

        await ctx.set("agent_config", agent_config)
        await ctx.set("llm", llm)

        # Build a new list instead of appending in place so the caller's
        # history object is left untouched.
        chat_history = [*chat_history, ChatMessage(role="user", content=user_msg)]
        await ctx.set("chat_history", chat_history)

        # Tolerate an explicit ``initial_state=None``.
        await ctx.set("user_state", initial_state if initial_state is not None else {})

        return LLMCallEvent()

    @step
    async def speak_with_agent(
        self, ctx: Context, ev: LLMCallEvent
    ) -> ToolCallEvent | StopEvent:
        """Call the LLM and either finish or fan out the requested tool calls.

        Returns a ``StopEvent`` carrying the final response when the LLM asks
        for no tools. Otherwise one ``ToolCallEvent`` per requested call is
        sent through the context and the step itself returns ``None``.
        """
        agent_config: AgentConfig = await ctx.get("agent_config")
        chat_history = await ctx.get("chat_history")
        llm = await ctx.get("llm")
        user_state = await ctx.get("user_state")

        user_state_str = "\n".join(f"{k}: {v}" for k, v in user_state.items())
        # ``system_prompt`` and ``tools`` are optional on AgentConfig, so fall
        # back to empty values instead of failing on None.
        system_prompt = (
            (agent_config.system_prompt or "").strip()
            + f"\n\nHere is the current user state:\n{user_state_str}"
        )
        tools = agent_config.tools or []

        llm_input = [ChatMessage(role="system", content=system_prompt), *chat_history]
        response = await llm.achat_with_tools(tools, chat_history=llm_input)

        tool_calls: list[ToolSelection] = llm.get_tool_calls_from_response(
            response, error_on_no_tool_call=False
        )

        # The assistant message is recorded in both outcomes.
        chat_history.append(response.message)
        await ctx.set("chat_history", chat_history)

        if not tool_calls:
            return StopEvent(
                result={
                    "response": response.message.content,
                    "chat_history": chat_history,
                }
            )

        # aggregate_tool_results uses this count to know when all results are in.
        await ctx.set("num_tool_calls", len(tool_calls))

        for tool_call in tool_calls:
            ctx.send_event(ToolCallEvent(tool_call=tool_call, tools=tools))

        return None

    @step(num_workers=4)
    async def handle_tool_call(
        self, ctx: Context, ev: ToolCallEvent
    ) -> ToolCallResultEvent:
        """Execute one tool call and wrap its outcome in a tool message.

        An unknown tool name or an exception raised by the tool is reported
        back as the content of the tool message rather than propagated.
        """
        tool_call = ev.tool_call
        tools_by_name = {tool.metadata.get_name(): tool for tool in ev.tools}
        tool = tools_by_name.get(tool_call.tool_name)

        # Use the requested name rather than ``tool.metadata``: the lookup may
        # have returned None, and for a found tool both names are identical.
        additional_kwargs = {
            "tool_call_id": tool_call.tool_id,
            "name": tool_call.tool_name,
        }

        if tool is None:
            tool_msg = ChatMessage(
                role="tool",
                content=f"Tool {tool_call.tool_name} does not exist",
                additional_kwargs=additional_kwargs,
            )
        else:
            try:
                tool_output = await tool.acall(**tool_call.tool_kwargs)
                tool_msg = ChatMessage(
                    role="tool",
                    content=tool_output.content,
                    additional_kwargs=additional_kwargs,
                )
            except Exception as e:
                # Surface the failure to the LLM instead of aborting the run.
                tool_msg = ChatMessage(
                    role="tool",
                    content=f"Encountered error in tool call: {e}",
                    additional_kwargs=additional_kwargs,
                )

        ctx.write_event_to_stream(
            ProgressEvent(
                msg=f"Tool {tool_call.tool_name} called with {tool_call.tool_kwargs} returned {tool_msg.content}"
            )
        )

        return ToolCallResultEvent(chat_message=tool_msg)

    @step
    async def aggregate_tool_results(
        self, ctx: Context, ev: ToolCallResultEvent
    ) -> LLMCallEvent:
        """Wait for every pending tool result, then hand control back to the LLM.

        Returns ``None`` until ``collect_events`` has gathered as many results
        as ``num_tool_calls`` recorded by ``speak_with_agent``.
        """
        num_tool_calls = await ctx.get("num_tool_calls")
        results = ctx.collect_events(ev, [ToolCallResultEvent] * num_tool_calls)
        if not results:
            return None

        chat_history = await ctx.get("chat_history")
        chat_history.extend(result.chat_message for result in results)
        await ctx.set("chat_history", chat_history)

        return LLMCallEvent()

In [12]:
from llama_index.llms.groq import Groq


llm = Groq(model="llama3-70b-8192", temperature=0.3)
workflow = BasicAgent()


handler = workflow.run(
        agent_config=agent_config,
        user_msg="What is 10 + 10?",
        chat_history=[],
        initial_state={"user_name": "Akshay"},
        llm=llm,
    )

async for event in handler.stream_events():
        if isinstance(event, ProgressEvent):
            print(event.msg)

final_result = await handler
print(final_result["response"])


Tool add_two_numbers called with {'a': 10, 'b': 10} returned 20
The answer is 20.


In [13]:
from llama_index.core.memory import ChatMemoryBuffer

# Conversation buffer created with the same LLM the agent uses.
memory = ChatMemoryBuffer.from_defaults(llm=llm)

# Seed the buffer with the transcript returned by the previous run.
memory.set(final_result["chat_history"])

In [14]:
async def main():
  handler = workflow.run(
      agent_config=agent_config,
      user_msg="What is the capital of Canada?",
      chat_history=memory.get(),
      initial_state={"user_name": "Akshay"},
      llm=llm,
      memory=memory,
  )

  async for event in handler.stream_events():
      if isinstance(event, ProgressEvent):
          print(event.msg)

  print("-----------")

  final_result = await handler
  print(final_result["response"])

await main()

-----------
Sorry, I can only help with addition problems."}


In [15]:
# Persist the transcript into memory. This reads the notebook-level
# `final_result`; make sure it holds the result of the most recent run.
memory.set(final_result["chat_history"])