In [1]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.llms import AzureOpenAI

In [2]:
import duckdb

In [3]:
# Open (creating it if absent) the on-disk DuckDB file that the SQL agent will query.
conn = duckdb.connect("profiles.db")

In [4]:
# Materialise the parquet summary as table `profile` inside profiles.db.
# OR REPLACE keeps this cell idempotent: profiles.db persists between kernel
# sessions, so a plain CREATE TABLE fails with "table already exists" on re-run.
conn.execute("CREATE OR REPLACE TABLE profile AS SELECT * FROM read_parquet('data/summary.parquet')")

<duckdb.DuckDBPyConnection at 0x29b94679430>

In [5]:
# Installs DuckDB's `sqlite` extension.
# NOTE(review): no visible cell reads SQLite data; the agent queries the DuckDB
# file profiles.db directly. Confirm nothing else needs this before removing it.
conn.execute("INSTALL sqlite;")

<duckdb.DuckDBPyConnection at 0x29b94679430>

In [6]:
# SQLAlchemy URL (duckdb_engine dialect) for the same DuckDB file opened with
# duckdb.connect above. Three slashes followed by a relative path, as in
# SQLAlchemy's sqlite URLs. The filename is named once so it is easy to change.
DB_PATH = "profiles.db"
uri = f"duckdb:///{DB_PATH}"

In [7]:
from sqlalchemy import create_engine

In [8]:
# SQLAlchemy engine bound to the DuckDB URL built above.
engine = create_engine(url=uri)

In [9]:
# LangChain wrapper over the database, restricted to the `profile` table so the
# agent only sees (and is only prompted with) that schema.
# Reuse the `engine` created above: SQLDatabase.from_uri would build a second,
# independent engine for the same DuckDB file and leave `engine` unused.
db = SQLDatabase(engine, include_tables=["profile"])



In [10]:
from langchain.chat_models import AzureChatOpenAI

In [11]:
import os

# Azure OpenAI chat model that drives the SQL agent.
# The API key comes from the AZURE_OPENAI_API_KEY environment variable instead
# of being written into the notebook. A key committed to a notebook leaks with
# every copy of the file. The key that used to be inlined here should be treated
# as compromised and rotated.
# temperature=0 minimises sampling randomness in the generated SQL.
llm = AzureChatOpenAI(model="gpt-35-turbo",
                      openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
                      openai_api_base="https://genai-interns.openai.azure.com/",
                      deployment_name="genai",
                      openai_api_version="2023-07-01-preview",
                      temperature=0,)

In [12]:
# Bundle the database wrapper and the LLM into the SQL tools the agent can call.
toolkit = SQLDatabaseToolkit(llm=llm, db=db)

In [13]:
from langchain.prompts.chat import ChatPromptTemplate

In [15]:
# Prompt prepended to each user question. It tells the agent what the columns of
# the `profile` table mean. `{query}` is the only template variable; it is
# filled in by PromptTemplate further down.
# Lines start at column 0 so no source indentation leaks into the prompt text.
template = """\
You are a SQL analyst querying a database that holds a summary of an airline.
The data is in the table `profile`.

Below is a description of each column, its data type, and what it contains:

- `customers` (int64): the total number of customers
- `journeys` (int64): the total number of journeys
- `bookers` (int64): the total number of bookers
- `bookings` (int64): the total number of bookings

Your job is to answer the following question:
{query}
"""

In [16]:
# Zero-shot ReAct agent that answers by calling the toolkit's SQL tools
# (list tables, inspect schema, run query). verbose=True prints every
# Thought / Action / Observation step.
agent = create_sql_agent(
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    toolkit=toolkit,
    llm=llm,
    verbose=True,
)

In [17]:
from langchain.prompts import PromptTemplate

In [18]:
# Wrap the raw template string; its single input variable is `query`.
prompt = PromptTemplate.from_template(template=template)

In [20]:
# Render the full prompt for one question, then let the agent answer it.
# The returned string is the cell's displayed output.
question = prompt.format(query="what is the total booking count")
agent.run(question)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m
Observation: [38;5;200m[1;3mprofile[0m
Thought:[32;1m[1;3mThe table "profile" seems relevant. I should query its schema.
Action: sql_db_schema
Action Input: "profile"[0m
Observation: [33;1m[1;3m
CREATE TABLE profile (
	customers BIGINT, 
	journeys BIGINT, 
	bookers BIGINT, 
	bookings BIGINT
)

/*
3 rows from profile table:
customers	journeys	bookers	bookings
17851840	37512228	1316354	20300158
*/[0m
Thought:[32;1m[1;3mThe relevant table is "profile" and it has a column named "bookings" which contains the total number of bookings.
Action: sql_db_query
Action Input: "SELECT SUM(bookings) FROM profile"[0m
Observation: [36;1m[1;3m[(20300158,)][0m
Thought:[32;1m[1;3mThe total booking count is 20,300,158.
Final Answer: The total booking count is 20,300,158.[0m

[1m> Finished chain.[0m


'The total booking count is 20,300,158.'