In [13]:
import os

from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain.agents.agent_types import AgentType

In [3]:
OPENAI_KEY = os.getenv('OPENAI_API_KEY')

In [None]:
db_uri = os.getenv("DB_URI")
db = SQLDatabase.from_uri(db_uri)

In [6]:
db.get_table_info()

'\nCREATE TABLE `Album` (\n\t`AlbumId` INTEGER NOT NULL, \n\t`Title` VARCHAR(160) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`ArtistId` INTEGER NOT NULL, \n\tPRIMARY KEY (`AlbumId`), \n\tCONSTRAINT `FK_AlbumArtistId` FOREIGN KEY(`ArtistId`) REFERENCES `Artist` (`ArtistId`)\n)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE `Artist` (\n\t`ArtistId` INTEGER NOT NULL, \n\t`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, \n\tPRIMARY KEY (`ArtistId`)\n)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/\n\n\nCREATE TABLE `Customer` (\n\t`CustomerId` INTEGER NOT NULL, \n\t`FirstName` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`LastNa

In [8]:
llm = ChatOpenAI(
    temperature = 0,
    openai_api_key = OPENAI_KEY,
    model_name = 'gpt-4o'
)

In [9]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [10]:
args = {"return_intermediate_steps" : True}

In [14]:
template = """
    You are a data analyst at a company. You are interacting with a user who is asking you question about the company's database.
    Based on the table schema, question, write a natural language response. Take the conversation history into account.
    Conversation History: {chat_history}
    
    Answer the following questions as best you can. You have access to the following tools:

    {tools}

    Use the following format:

    Question: the input question you must answer
    Thought: you should always think about what to do
    Action: the action to take, should be one of [{tool_names}]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can repeat N times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question

    Begin!

    Question: {input}
    Thought:{agent_scratchpad}
    """
    
prompt = PromptTemplate(
        input_variables=["input","chat_history","tools", "tool_names", "agent_scratchpad"],
        template=template
        )

In [15]:
db_agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    prompt=prompt,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    agent_executor_kwargs=args
)
db_agent.handle_parsing_errors = True

In [16]:
# response = db_agent.invoke("how many artists do we have")
response = db_agent.invoke({"input": "how many artists do we have","chat_history": []})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mTo determine how many artists we have, I first need to check the list of tables in the database to identify which table might contain artist information.

Action: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3mThe table that likely contains information about artists is the "Artist" table. I will now check the schema of the "Artist" table to confirm this and then count the number of artists.

Action: sql_db_schema
Action Input: "Artist"[0m[33;1m[1;3m
CREATE TABLE `Artist` (
	`ArtistId` INTEGER NOT NULL, 
	`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	PRIMARY KEY (`ArtistId`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/[0m[32;1m[1;3mThe "Artist" table contains informa

In [17]:
response

{'input': 'how many artists do we have',
 'chat_history': [],
 'output': 'We have 275 artists in our database.',
 'intermediate_steps': [(AgentAction(tool='sql_db_list_tables', tool_input='', log='To determine how many artists we have, I first need to check the list of tables in the database to identify which table might contain artist information.\n\nAction: sql_db_list_tables\nAction Input: ""'),
   'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'),
  (AgentAction(tool='sql_db_schema', tool_input='Artist', log='The table that likely contains information about artists is the "Artist" table. I will now check the schema of the "Artist" table to confirm this and then count the number of artists.\n\nAction: sql_db_schema\nAction Input: "Artist"'),
   '\nCREATE TABLE `Artist` (\n\t`ArtistId` INTEGER NOT NULL, \n\t`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, \n\tPRIMARY KEY (`ArtistId`)\n)ENGINE=InnoDB DEFAU

In [52]:
response

{'input': 'how many artists do we have',
 'output': 'There are 275 artists in the database.',
 'intermediate_steps': [(AgentAction(tool='sql_db_list_tables', tool_input='', log='Action: sql_db_list_tables\nAction Input: ""'),
   'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'),
  (AgentAction(tool='sql_db_schema', tool_input='Artist', log='The "Artist" table seems to be the most relevant for finding out how many artists are in the database. I should check the schema of the "Artist" table to understand its structure and identify the relevant column to count the number of artists.\n\nAction: sql_db_schema\nAction Input: Artist'),
   '\nCREATE TABLE `Artist` (\n\t`ArtistId` INTEGER NOT NULL, \n\t`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, \n\tPRIMARY KEY (`ArtistId`)\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3

In [47]:
response['intermediate_steps'][0][0].log + " " + response['intermediate_steps'][0][1]

'Action: sql_db_list_tables\nAction Input: "" Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'

In [51]:
for steps in response['intermediate_steps']:
    # print((steps))
    print(steps[0].log+" " + steps[1])
    # thought_text = steps[0][0].log + " " + steps[0][1]
    # print(thought_text)


Action: sql_db_list_tables
Action Input: "" Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
The "Artist" table seems to be the most relevant for finding out how many artists are in the database. I should check the schema of the "Artist" table to understand its structure and identify the relevant column to count the number of artists.

Action: sql_db_schema
Action Input: Artist 
CREATE TABLE `Artist` (
	`ArtistId` INTEGER NOT NULL, 
	`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	PRIMARY KEY (`ArtistId`)
)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/
The "Artist" table contains the column "ArtistId" which can be used to count the number of artists. I will write a query to count the number of distinct "ArtistId" entries in the "Artist" table.

Action: sql_db_query_checker
Action Input: "SELECT COUNT(DISTINCT Arti