<a href="https://colab.research.google.com/github/Theedon/AIAgentDemos/blob/main/multi_tool_calling_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the Chinook sample SQLite database into sample_data/ (used by the Chinook DbTool).
# -f: fail on HTTP errors instead of saving an error page as chinook.zip; -L: follow redirects.
!curl -fL -o chinook.zip https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip

# -o: overwrite without prompting, so re-running the cell cannot block on an interactive prompt.
!unzip -o chinook.zip
# sample_data/ exists by default on Colab only; create it elsewhere.
!mkdir -p sample_data
!mv chinook.db sample_data/chinook.db
!rm chinook.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   167  100   167    0     0   2146      0 --:--:-- --:--:-- --:--:--  2168
100  298k  100  298k    0     0   860k      0 --:--:-- --:--:-- --:--:-- 2426k
Archive:  chinook.zip
  inflating: chinook.db              


In [None]:
%%capture
# %pip (not !pip) installs into the environment of the running kernel.
# TODO: pin versions (pkg==x.y.z) so Restart & Run All stays reproducible.
%pip install langchain_experimental langchain_openai langchain_google_community redis
%pip install --upgrade openai
%pip install -qU httpx

In [None]:
import os
from google.colab import userdata

# Copy each Colab secret into the environment variable of the same name, in this order.
for secret_name in (
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_API_KEY",
    "OPENAI_API_VERSION",
    "OPENAI_API_KEY",
    "GOOGLE_API_KEY",
    "GOOGLE_CSE_ID",
    "REDIS_URL",
):
    os.environ[secret_name] = userdata.get(secret_name)

In [None]:
import os

# Databases exposed to the agent; each entry becomes one DbTool.
#   database_name     -> tool name the model calls
#   connection_string -> SQLAlchemy URI passed to SQLDatabase.from_uri
#   description       -> tool description the model uses to pick a database
#   prompt_query      -> optional custom system prompt for that database's SQL agent
#
# SECURITY: never commit credentials in a notebook. The Datamart URI embeds a username
# and password, so it is read from the environment. Set it from Colab secrets, e.g.
#   os.environ["DATAMART_CONNECTION_STRING"] = userdata.get("DATAMART_CONNECTION_STRING")
# Any password that was ever committed in this cell should be treated as leaked and rotated.
databases = [
    {
        "database_name": "Datamart",
        # Empty when unset, so the Chinook-only demos still run; Datamart queries will then fail.
        "connection_string": os.environ.get("DATAMART_CONNECTION_STRING", ""),
        "description": "Datamart is a database that contains information on stocks and orders.",
        "prompt_query": ""
    },
    {
        "database_name": "Chinook",
        "connection_string": "sqlite:///sample_data/chinook.db",
        "description": "Chinook contains information about a music store.",
        "prompt_query": ""
    }
]

In [None]:
from google.colab import userdata
from langchain_openai import ChatOpenAI
from langchain.llms import OpenAI

from langchain_openai import (
    AzureChatOpenAI,
    AzureOpenAIEmbeddings,
)

# Shared Azure OpenAI chat model: passed to both DbTools, to the orchestrator agent,
# and read as a global by ChartTool.
# NOTE(review): endpoint, key and API version are presumably picked up from the
# AZURE_OPENAI_* / OPENAI_API_VERSION environment variables set above — confirm.
LLM = AzureChatOpenAI(azure_deployment="4o-mini-deploy")

In [None]:
from typing import Any, Optional, Type
import json
from langchain.tools import BaseTool
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import BaseModel

class DbTool(BaseTool):
    """LangChain tool that answers a natural-language question against one SQL database.

    Each call opens the database with ``SQLDatabase.from_uri`` and runs a LangChain SQL
    agent (``create_sql_agent``, tool-calling flavour) built on ``llm`` to answer it.
    """

    name: str
    description: str
    args_schema: Type[BaseModel]
    # SQLAlchemy database URI passed to SQLDatabase.from_uri.
    db_conn_details: str
    # Optional system prompt for the SQL agent; a generic one is used when empty/None.
    # It is parsed as a prompt template, so literal braces must be written as {{ }}.
    prompt_query: Optional[str] = None
    return_direct: bool = False
    # Chat model that drives the inner SQL agent.
    llm: Any

    def _run(self, question: str) -> str:
        """Answer ``question`` using this tool's database.

        Args:
            question: natural-language question forwarded to the SQL agent.

        Returns:
            The SQL agent's final answer text (its ``output``).

        Raises:
            Any connection or LLM error propagates unchanged to the calling agent.
        """
        system_prompt = self.prompt_query or "You are a database expert."
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("human", "{input}"),
                MessagesPlaceholder("agent_scratchpad"),
            ]
        )
        db = SQLDatabase.from_uri(database_uri=self.db_conn_details)
        agent_executor = create_sql_agent(
            self.llm,
            db=db,
            prompt=prompt,
            agent_type="tool-calling",
            verbose=True,
            extra_tools=[],
        )
        response = agent_executor.invoke({"input": question})
        # Return only the answer text; the full dict would also echo the question back.
        return response["output"]

    async def _arun(self, question: str):
        raise NotImplementedError("does not support async")

In [None]:
from langchain.tools import BaseTool
import os
import ast
import re
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from datetime import datetime
import uuid
import traceback
from redis import Redis
import json
from pydantic import Field

# Bare ``plt.<name>(...)`` statements removed from generated code: showing, saving
# and closing the figure are handled by the tool itself.
_PLT_CALLS_TO_STRIP = frozenset({"show", "savefig", "close"})


class _StripPltDisplayCalls(ast.NodeTransformer):
    """AST pass replacing bare ``plt.show/savefig/close(...)`` statements with ``pass``.

    Working on the AST (rather than with regexes) copes with nested parentheses such as
    ``plt.savefig(os.path.join(d, "x.png"))``. Replacing with ``pass`` keeps an enclosing
    block non-empty when the stripped call was its only statement.
    """

    def visit_Expr(self, node: ast.Expr) -> ast.stmt:
        call = node.value
        if (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and isinstance(call.func.value, ast.Name)
            and call.func.value.id == "plt"
            and call.func.attr in _PLT_CALLS_TO_STRIP
        ):
            return ast.copy_location(ast.Pass(), node)
        return node


def _save_current_chart() -> str:
    """Save the current matplotlib figure as a uniquely named PNG under ./charts and return its path."""
    chart_dir = os.path.join(os.getcwd(), "charts")
    os.makedirs(chart_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = uuid.uuid4().hex[:8]
    chart_file = os.path.join(chart_dir, f"chart_{timestamp}_{unique_id}.png")
    # TODO:: upload file to azure and get file link
    plt.savefig(chart_file)
    return chart_file


class CustomPythonReplToolForCharting(BaseTool):
    """Execute model-generated Python and persist any matplotlib chart it draws.

    WARNING: the code is run with ``exec`` inside this kernel, with full builtins and no
    sandboxing. This is NOT safe for untrusted input.
    """

    name: str = "custom_python_repl_charting"
    description: str = (
        "A custom REPL tool for executing Python code. Use this to evaluate python scripts "
        "and return results. ALL CHARTS ARE saved automatically: import matplotlib.pyplot "
        "as plt, draw the chart, and do not call plt.show(), plt.savefig() or plt.close(); "
        "the path of the saved PNG file is returned."
    )

    def _run(self, code: str) -> str:
        """Run ``code`` and report either the saved chart path or the last expression's value.

        Returns:
            ``"Chart saved at <path>\\n"`` if the code left a matplotlib figure with data.
            Otherwise, ``str()`` of the final expression statement's value, or
            ``"no result returned"``. Any exception is caught and returned as an
            ``"Error: ..."`` string with the traceback, so the calling agent can retry.
        """
        print("**********LLM GENERATED CODE***********\n{}\n**************************************".format(code))

        try:
            tree = _StripPltDisplayCalls().visit(ast.parse(code))
            ast.fix_missing_locations(tree)
            print("**********PYTHON CODE BEEN EXECUTED***********\n{}\n**************************************".format(ast.unparse(tree)))

            # Rewrite a trailing expression into ``_result = <expr>`` so its value can be read back.
            last_node = tree.body[-1] if tree.body else None
            if isinstance(last_node, ast.Expr):
                assign = ast.Assign(
                    targets=[ast.Name(id="_result", ctx=ast.Store())],
                    value=last_node.value,
                )
                tree.body[-1] = ast.copy_location(assign, last_node)
                ast.fix_missing_locations(tree)

            # One dict serves as both globals and locals. With separate dicts, functions and
            # comprehensions in the code could not see names bound at its own top level.
            namespace = {"__builtins__": __builtins__}
            exec(compile(tree, filename="<ast>", mode="exec"), namespace)

            if plt.get_fignums() and plt.gca().has_data():
                return f"Chart saved at {_save_current_chart()}\n"

            return str(namespace.get("_result", "no result returned"))

        except Exception as e:
            error_details = traceback.format_exc()
            return f"Error: {e}\n\nTraceback:\n{error_details}"

        finally:
            # Reset matplotlib state so figures never leak into the next call.
            plt.close("all")

    async def _arun(self, code: str) -> str:
        raise NotImplementedError("Asynchronous execution is not implemented yet.")

In [None]:
from langchain.tools import BaseTool
from pydantic import Field
from typing import List, Tuple, Any
from decimal import Decimal

from langchain.agents import (
    AgentExecutor,
    create_tool_calling_agent,
    AgentType,
    initialize_agent
)

class ChartToolInput(BaseModel):
    """Input for Chart Tool."""

    # These descriptions are shown to the model in ChartTool's parameter schema,
    # so they state the expected encoding explicitly.
    data: str = Field(
        ...,
        description="Data to generate chart from, encoded as a JSON string (e.g. a list of [label, value] pairs)",
    )
    chart_type: str = Field(..., description="Type of chart to generate (e.g. bar, line, pie)")

class ChartTool(BaseTool):
    """Tool that draws a chart from JSON data by delegating to a code-writing sub-agent.

    Each call builds a tool-calling agent on the notebook-level ``LLM`` whose only tool is
    ``CustomPythonReplToolForCharting``, and returns that sub-agent's final answer text.
    """

    name: str = "ChartTool"
    description: str = "Accepts data that is JSON string and then generates a chart for it."
    args_schema: Optional[Type[BaseModel]] = ChartToolInput

    def _run(self, data: str, chart_type: str) -> str:
        """Ask the charting sub-agent to draw a ``chart_type`` chart of ``data``.

        Args:
            data: data to plot, expected to be a JSON string.
            chart_type: kind of chart to draw (e.g. "bar", "pie").

        Returns:
            The sub-agent's final ``output`` text.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "You are a python expert. Write python codes to generate charts."),
                ("placeholder", "{chat_history}"),
                ("human", "{input}"),
                ("placeholder", "{agent_scratchpad}"),
            ]
        )
        # Use the same tool instance for the agent and the executor.
        repl_tools = [CustomPythonReplToolForCharting()]
        agent = create_tool_calling_agent(llm=LLM, tools=repl_tools, prompt=prompt)
        agent_executor = AgentExecutor(agent=agent, tools=repl_tools, verbose=True)

        result = agent_executor.invoke(
            {"input": f"Generate a {chart_type} chart from the data {data}"}
        )
        return result["output"]

    async def _arun(self, data: str, chart_type: str) -> str:
        raise NotImplementedError("Asynchronous execution is not implemented yet.")


In [None]:
from pydantic import BaseModel, Field, ValidationError

from typing import Optional, List

from typing_extensions import Annotated, TypedDict


# Shorthand handles for the two configured database entries, in `databases` order.
datamart, chinook = databases[0], databases[1]


class UserQuery(BaseModel):
    """get_info_from_prompt"""

    # The single argument every DbTool accepts. This description appears in the tool's
    # parameter schema that the model sees.
    question: str = Field(description="question from the prompt")


In [None]:
# One DbTool per database entry (Datamart first, then Chinook), followed by ChartTool.
tools = [
    DbTool(
        name=db_config["database_name"],
        description=db_config["description"],
        db_conn_details=db_config["connection_string"],
        prompt_query=db_config["prompt_query"],
        args_schema=UserQuery,
        llm=LLM,
        verbose=True,
    )
    for db_config in (datamart, chinook)
] + [ChartTool()]

In [None]:
# convert_to_openai_function replaces the deprecated format_tool_to_openai_function.
from langchain_core.utils.function_calling import convert_to_openai_function

# Inspect the function schema the model is given for the Chinook DbTool.
convert_to_openai_function(tools[1])

  format_tool_to_openai_function(tools[1])


{'name': 'Chinook',
 'description': 'Chinook contains information about a music store.',
 'parameters': {'properties': {'question': {'description': 'question from the prompt',
    'type': 'string'}},
  'required': ['question'],
  'type': 'object'}}

In [None]:
import json
from decimal import Decimal

# Call ChartTool directly with sample (SKU, quantity) pairs.
# ChartTool expects `data` as a JSON string plus a `chart_type`. Decimals are converted
# to int because json cannot serialise Decimal.
top_skus = [
    ("PCPET50", Decimal("106770934")),
    ("SKEPET30", Decimal("35057158")),
    ("MOFRPET40", Decimal("25529147")),
    ("7UPFRPET40", Decimal("24435012")),
    ("AFPET75", Decimal("21525268")),
    ("7UPPET50", Decimal("21288364")),
    ("MOPET50", Decimal("18992715")),
]
chart_tool = next(tool for tool in tools if isinstance(tool, ChartTool))
chart_tool.invoke(
    {
        "data": json.dumps([[sku, int(quantity)] for sku, quantity in top_skus]),
        "chart_type": "bar",
    }
)

[('data', [('PCPET50', Decimal('106770934')), ('SKEPET30', Decimal('35057158')), ('MOFRPET40', Decimal('25529147')), ('7UPFRPET40', Decimal('24435012')), ('AFPET75', Decimal('21525268')), ('7UPPET50', Decimal('21288364')), ('MOPET50', Decimal('18992715'))])]


In [None]:
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.llms import OpenAI
from langchain.agents import (
    AgentExecutor,
    create_tool_calling_agent,
    AgentType
)

# Lower-case alias for the shared Azure chat model, used to build the orchestrator agent below.
llm = LLM

from langchain_core.prompts import ChatPromptTemplate

def get_prompt_for_openai_functions_agent(system_message: str = ""):
    """Build the chat prompt for the top-level tool-calling agent.

    Args:
        system_message: text of the system message. The default (empty) gives the agent
            no instructions beyond the tool descriptions. It is parsed as a prompt
            template, so literal braces must be written as ``{{``/``}}``.

    Returns:
        A ChatPromptTemplate: system message, ``chat_history`` placeholder, the human
        ``input``, then the ``agent_scratchpad`` placeholder used by
        ``create_tool_calling_agent``.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )
    return prompt

# Top-level orchestrator: a tool-calling agent over every entry in `tools`
# (the Datamart and Chinook DbTools plus ChartTool).
orchestrator_prompt = get_prompt_for_openai_functions_agent()
agent = create_tool_calling_agent(llm=llm, tools=tools, prompt=orchestrator_prompt)

agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    return_intermediate_steps=True,
    handle_parsing_errors=True,
    verbose=True,
)

In [None]:
def run_demo(agent_executor: AgentExecutor, query: list | str):
    query_list = [query] if isinstance(query, str) else query

    for i in range(len(query_list)):
        res = agent_executor.invoke({"input": query_list[i]})

        # Print the result
        print(f"Result for query {i+1}: {res}")

        # Demarcation line after each result
        print("\n" + "="*50 + f" End of Query {i+1} " + "="*50 + "\n")

In [None]:
# Demo: Datamart question that also asks for a chart (DbTool, then ChartTool).
run_demo(agent_executor, "Generate a chart showing the top 7 skus in datamart?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Datamart` with `{'question': 'What are the top 7 SKUs?'}`


[0m

[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mCustomer Order, Customers, Hourly Rated Capacity, Product in Facility(Stocks), Production Output[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Product in Facility(Stocks)'}`


[0m[33;1m[1;3m
CREATE TABLE "Product in Facility(Stocks)" (
	"Region" TEXT, 
	"Period No" BIGINT, 
	"Date" TIMESTAMP WITHOUT TIME ZONE, 
	"SKU" TEXT, 
	quantity BIGINT, 
	index INTEGER DEFAULT nextval('"Product in Facility(Stocks)_index_seq"'::regclass) NOT NULL
)

/*
3 rows from Product in Facility(Stocks) table:
Region	Period No	Date	SKU	quantity	index
Ikeja	1	2022-04-01 00:00:00	2SDWCL1000	7535	1
Ikeja	1	2022-04-01 00:00:00	2SDWCL10L	40	2
Ikeja	1	2022-04-01 00:00:00	2SDWCL250	1789	3
*/[0m[32;1m[1;3m
Invoking: `sql_db_sche

In [None]:
# Demo: open-ended Datamart analysis with visualisations (the saved output hit an Azure 429 rate limit).
run_demo(agent_executor, "Conduct an analysis of the data in datamart with visualizations")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Datamart` with `{'question': 'Analyze the trends in stock prices and order volumes over the past year.'}`


[0m

[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`
responded: To analyze the trends in stock prices and order volumes over the past year, we typically need to gather data from specific tables in the database that contain information related to stock prices and order volumes. 

Here's how we can proceed:

1. **Identify the relevant tables**: We need to know which tables in the database contain stock price data and order volume data.
2. **Query the data**: Once we have identified the right tables, we can run SQL queries to retrieve data for the past year and analyze trends.

Let's start by listing all the tables in the database to identify those that might contain stock information.

[0m[38;5;200m[1;3mCustomer Order, Customers, Hourly Rated Capaci

RateLimitError: Error code: 429 - {'error': {'code': '429', 'message': 'Requests to the ChatCompletions_Create Operation under Azure OpenAI API version 2024-02-15-preview have exceeded token rate limit of your current OpenAI S0 pricing tier. Please retry after 86400 seconds. Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit.'}}

In [None]:
# Demo: Chinook question without a chart request.
run_demo(agent_executor, "What are the top artist with the most albums")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Chinook` with `{'question': 'Who are the top artists with the most albums?'}`


[0m

[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3malbums, artists, customers, employees, genres, invoice_items, invoices, media_types, playlist_track, playlists, tracks[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'albums, artists'}`


[0m[33;1m[1;3m
CREATE TABLE albums (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES artists ("ArtistId")
)

/*
3 rows from albums table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE artists (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from artists table:
Arti

In [None]:
# Demo: Chinook question that also asks for a pie chart.
run_demo(agent_executor, "Plot a pie chart of the top 5 artists with the most albums")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Chinook` with `{'question': 'List the top 5 artists by the number of albums they have released.'}`


[0m

[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3malbums, artists, customers, employees, genres, invoice_items, invoices, media_types, playlist_track, playlists, tracks[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'albums, artists'}`


[0m[33;1m[1;3m
CREATE TABLE albums (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES artists ("ArtistId")
)

/*
3 rows from albums table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE artists (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows fro