In [8]:
import os

from dotenv import find_dotenv, load_dotenv
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.chat_models import ChatOpenAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

# Load the variables from the .env file located by find_dotenv() into the
# process environment. The result is bound to `_` so the call's return value
# is not echoed as this cell's output.
_ = load_dotenv(dotenv_path=find_dotenv())

In [9]:
# Reference: Databricks SQL execution tutorial
# https://learn.microsoft.com/en-us/azure/databricks/sql/admin/sql-execution-tutorial#step-1-execute-a-sql-statement-and-save-the-data-result-as-json

# temperature=0 keeps the generated SQL as repeatable as possible between runs.
llm = ChatOpenAI(temperature=0, model_name="gpt-4")

# SECURITY: the Databricks personal access token must never be written into the
# notebook. It is read from the DATABRICKS_TOKEN environment variable (e.g. set
# in the .env file loaded by the first cell); a missing variable raises KeyError
# here instead of failing later with an opaque authentication error.
# Any token that was ever committed in plain text should be revoked and rotated.
#
# Host and warehouse id are not secrets, so the original values stay as defaults
# but can be overridden per environment.
db = SQLDatabase.from_databricks(
    host=os.environ.get(
        "DATABRICKS_HOST", "adb-7365545121955305.5.azuredatabricks.net"
    ),
    api_token=os.environ["DATABRICKS_TOKEN"],
    warehouse_id=os.environ.get("DATABRICKS_WAREHOUSE_ID", "01de5ba4dfbb71da"),
    catalog="samples",
    schema="nyctaxi",
)

In [11]:
# Chain-based approach: SQLDatabaseChain asks the LLM for a SQL query, runs it
# against `db`, and has the LLM phrase the result as an answer.
night_ride_question = (
    "What is the average duration of taxi rides that start between midnight and 6am?"
)
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)
db_chain.run(night_ride_question)

# The call takes roughly 2-3 s (Hong Kong server); most of that time is spent
# generating the SQL.
# NOTE(review): in the recorded run the generated query filters with
# `HOUR(...) BETWEEN 0 AND 6`, which also counts pickups from 06:00 to 06:59,
# and appends a redundant `LIMIT 5` to an aggregate -- verify generated SQL
# before trusting the answer.



[1m> Entering new SQLDatabaseChain chain...[0m
What is the average duration of taxi rides that start between midnight and 6am?
SQLQuery:

[32;1m[1;3mSELECT AVG(UNIX_TIMESTAMP(tpep_dropoff_datetime) - UNIX_TIMESTAMP(tpep_pickup_datetime)) as average_duration
FROM trips
WHERE HOUR(tpep_pickup_datetime) BETWEEN 0 AND 6
LIMIT 5[0m
SQLResult: [33;1m[1;3m[(933.2543103448276,)][0m
Answer:[32;1m[1;3mThe average duration of taxi rides that start between midnight and 6am is approximately 933.25 seconds.[0m
[1m> Finished chain.[0m


'The average duration of taxi rides that start between midnight and 6am is approximately 933.25 seconds.'

In [12]:
# Agent-based approach: instead of a single fixed chain, the agent is given the
# SQL toolkit and decides step by step which tool to call (in the recorded run:
# list tables, then inspect the `trips` schema) before answering.
# `create_sql_agent` and `SQLDatabaseToolkit` are imported in the notebook's
# top import cell so all dependencies are visible in one place.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

agent.run("What is the longest trip distance and how long did it take?")



[1m> Entering new AgentExecutor chain...[0m


[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m
Observation: [38;5;200m[1;3mtrips[0m
Thought:[32;1m[1;3mThe 'trips' table seems to be the most relevant for this query. I should check its schema to see what columns it has.
Action: sql_db_schema
Action Input: "trips"[0m
Observation: [33;1m[1;3m
CREATE TABLE trips (
	tpep_pickup_datetime TIMESTAMP, 
	tpep_dropoff_datetime TIMESTAMP, 
	trip_distance FLOAT, 
	fare_amount FLOAT, 
	pickup_zip INT, 
	dropoff_zip INT
) USING DELTA

/*
3 rows from trips table:
tpep_pickup_datetime	tpep_dropoff_datetime	trip_distance	fare_amount	pickup_zip	dropoff_zip
2016-02-14 16:52:13+00:00	2016-02-14 17:16:04+00:00	4.94	19.0	10282	10171
2016-02-04 18:44:19+00:00	2016-02-04 18:46:00+00:00	0.28	3.5	10110	10110
2016-02-17 17:13:57+00:00	2016-02-17 17:17:55+00:00	0.7	5.0	10103	10023
*/[0m
Thought:[32;1m[1;3mThe 'trips' table has the columns 'trip_distance' and 'tpep_pickup_datetime' and 'tpep_dropoff_datetime' which I can use to calculate 

'The longest trip distance is 30.6 miles and it took approximately 43 minutes and 31 seconds.'