In [2]:
import os
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase

# Load variables from a .env file into the process environment so that later
# cells can read them with os.getenv (the cell output shows the call returned True).
load_dotenv()

True

**Set the environment variable and load the LLM**

In [3]:
# The project stores the key under OPEN_AI_API_KEY; re-export it under
# OPENAI_API_KEY for the OpenAI client.
# NOTE(review): presumably ChatOpenAI reads OPENAI_API_KEY from the environment — confirm.
openai_api_key = os.getenv("OPEN_AI_API_KEY")
if not openai_api_key:
    # Assigning None to os.environ raises an opaque TypeError; fail with a
    # message that tells the reader what to fix instead.
    raise RuntimeError(
        "OPEN_AI_API_KEY is not set. Add it to your .env file or environment "
        "before running this notebook."
    )
os.environ['OPENAI_API_KEY'] = openai_api_key

# temperature=0 is used for the model that generates the SQL.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# Alternative models (uncomment one to switch):
# llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# llm = ChatOpenAI(model="gpt-4o")

**Load and test the sqlite db**

In [5]:
# Locate the SQLite file via pyprojroot's `here` helper.
# Note: despite its name, `sqldb_directory` holds the path to the database file.
sqldb_directory = here("data/travel.sqlite")

# SQLite creates a new, empty database when the target file does not exist,
# which would only surface later as a confusing "no such table" error.
# Fail fast with a clear message instead.
if not os.path.isfile(sqldb_directory):
    raise FileNotFoundError(f"SQLite database file not found: {sqldb_directory}")

db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

# Smoke test: show the dialect, the visible tables, and a few sample rows.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM aircrafts_data LIMIT 10;")

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'car_rentals', 'flights', 'hotels', 'seats', 'ticket_flights', 'tickets', 'trip_recommendations']


"[('773', 'Boeing 777-300', 11100), ('763', 'Boeing 767-300', 7900), ('SU9', 'Sukhoi Superjet-100', 3000), ('320', 'Airbus A320-200', 5700), ('321', 'Airbus A321-200', 5600), ('319', 'Airbus A319-100', 6700), ('733', 'Boeing 737-300', 4200), ('CN1', 'Cessna 208 Caravan', 1200), ('CR2', 'Bombardier CRJ-200', 2700)]"

**Create the SQL query chain and run a test query**

In [6]:
# Build a chain that turns a natural-language question into a SQL statement
# for the connected database.
chain = create_sql_query_chain(llm, db)

# Ask a simple counting question. The chain returns the SQL text only;
# running it against the database is a separate step.
test_question = "How many rows are there in the aircrafts_data table?"
response = chain.invoke({"question": test_question})
response

'SELECT COUNT(*) AS total_rows FROM aircrafts_data;'

In [7]:
# Execute the SQL string produced by the chain and display the result.
# NOTE(review): `response` is LLM-generated text executed verbatim — inspect it
# (or use a read-only connection) before running against data that matters.
db.run(response)

'[(9,)]'