In [1]:
_ = !pip install pandas
_ = !pip install langchain

In [2]:
import os
import sqlite3
import pandas as pd
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

### Create and populate local database

In [3]:
# source: https://gist.githubusercontent.com/seankross
data_path = 'data/mtcars.csv'
!curl 'https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv' >> data/mtcars.csv

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1700  100  1700    0     0  17708      0 --:--:-- --:--:-- --:--:-- 17708


In [4]:
df = pd.read_csv(data_path, header=0)
df.head(15)

Unnamed: 0,model,mpg,cyl,disp,hp,drat,wt,qsec,vs,am,gear,carb
0,Mazda RX4,21.0,6,160.0,110,3.9,2.62,16.46,0,1,4,4
1,Mazda RX4 Wag,21.0,6,160.0,110,3.9,2.875,17.02,0,1,4,4
2,Datsun 710,22.8,4,108.0,93,3.85,2.32,18.61,1,1,4,1
3,Hornet 4 Drive,21.4,6,258.0,110,3.08,3.215,19.44,1,0,3,1
4,Hornet Sportabout,18.7,8,360.0,175,3.15,3.44,17.02,0,0,3,2
5,Valiant,18.1,6,225.0,105,2.76,3.46,20.22,1,0,3,1
6,Duster 360,14.3,8,360.0,245,3.21,3.57,15.84,0,0,3,4
7,Merc 240D,24.4,4,146.7,62,3.69,3.19,20.0,1,0,4,2
8,Merc 230,22.8,4,140.8,95,3.92,3.15,22.9,1,0,4,2
9,Merc 280,19.2,6,167.6,123,3.92,3.44,18.3,1,0,4,4


In [5]:
df_sub = df[['model', 'mpg', 'cyl']]
df_sub.head(10)

Unnamed: 0,model,mpg,cyl
0,Mazda RX4,21.0,6
1,Mazda RX4 Wag,21.0,6
2,Datsun 710,22.8,4
3,Hornet 4 Drive,21.4,6
4,Hornet Sportabout,18.7,8
5,Valiant,18.1,6
6,Duster 360,14.3,8
7,Merc 240D,24.4,4
8,Merc 230,22.8,4
9,Merc 280,19.2,6


In [6]:
conn = sqlite3.connect('cars.sqlite')
cur = conn.cursor()

table_name = 'cars'
query = f"""CREATE TABLE IF NOT EXISTS {table_name} (model TEXT, mpg REAL, cyl INTEGER)"""
cur.execute(query)
conn.commit()

df_sub.to_sql(table_name, conn, if_exists='replace', index=False)
conn.commit()

# somehow the header ends up as a database record. Remove them
cur.execute(f"""DELETE FROM {table_name} WHERE model='model';""")
conn.commit()
conn.close()

In [7]:
!ls | grep cars.sqlite

cars.sqlite


In [8]:
def read_query(sql):
    conn = sqlite3.connect('cars.sqlite')
    cur = conn.cursor()
    cur.execute(sql)
    rows = cur.fetchall()
    for row in rows:
        print(row)
    conn.close()

In [9]:
read_query('SELECT * FROM cars LIMIT 10;')

('Mazda RX4', '21', '6')
('Mazda RX4 Wag', '21', '6')
('Datsun 710', '22.8', '4')
('Hornet 4 Drive', '21.4', '6')
('Hornet Sportabout', '18.7', '8')
('Valiant', '18.1', '6')
('Duster 360', '14.3', '8')
('Merc 240D', '24.4', '4')
('Merc 230', '22.8', '4')
('Merc 280', '19.2', '6')


### Connect to OpenAI via LangChain

In [10]:
db = SQLDatabase.from_uri('sqlite:///cars.sqlite')

In [11]:
os.environ['OPENAI_API_KEY'] = 'your API key here'
llm = OpenAI(temperature=0)
# to obtain an API key, create an account at https://openai.com/ 
# NOTE: if you create a paid account, charges will apply. Refer to pricing page on the website for details.

In [12]:
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

In [13]:
db_chain.run('Which model has the highest mpg?')



[1m> Entering new SQLDatabaseChain chain...[0m
Which model has the highest mpg? 
SQLQuery:[32;1m[1;3m SELECT model, mpg FROM cars ORDER BY mpg DESC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Toyota Corolla', '33.9')][0m
Answer:[32;1m[1;3m The model with the highest mpg is Toyota Corolla with 33.9 mpg.[0m
[1m> Finished chain.[0m


' The model with the highest mpg is Toyota Corolla with 33.9 mpg.'

In [14]:
read_query('SELECT model, mpg FROM cars ORDER BY mpg DESC LIMIT 1;')

('Toyota Corolla', '33.9')


In [15]:
db_chain.run('What is the maximum mpg by number of cyl?')



[1m> Entering new SQLDatabaseChain chain...[0m
What is the maximum mpg by number of cyl? 
SQLQuery:[32;1m[1;3m SELECT cyl, MAX(mpg) FROM cars GROUP BY cyl;[0m
SQLResult: [33;1m[1;3m[('4', '33.9'), ('6', '21.4'), ('8', '19.2')][0m
Answer:[32;1m[1;3m The maximum mpg by number of cyl is 33.9 for 4 cyl, 21.4 for 6 cyl, and 19.2 for 8 cyl.[0m
[1m> Finished chain.[0m


' The maximum mpg by number of cyl is 33.9 for 4 cyl, 21.4 for 6 cyl, and 19.2 for 8 cyl.'

In [16]:
read_query('SELECT cyl, MAX(mpg) FROM cars GROUP BY cyl;')

('4', '33.9')
('6', '21.4')
('8', '19.2')


In [17]:
db_chain.run('What is the most efficient car that has 6 cylinders?')



[1m> Entering new SQLDatabaseChain chain...[0m
What is the most efficient car that has 6 cylinders? 
SQLQuery:[32;1m[1;3m SELECT model, mpg FROM cars WHERE cyl = 6 ORDER BY mpg DESC LIMIT 5;[0m
SQLResult: [33;1m[1;3m[('Hornet 4 Drive', '21.4'), ('Hornet 4 Drive', '21.4'), ('Hornet 4 Drive', '21.4'), ('Hornet 4 Drive', '21.4'), ('Hornet 4 Drive', '21.4')][0m
Answer:[32;1m[1;3m The most efficient car with 6 cylinders is the Hornet 4 Drive with an mpg of 21.4.[0m
[1m> Finished chain.[0m


' The most efficient car with 6 cylinders is the Hornet 4 Drive with an mpg of 21.4.'

In [20]:
read_query('SELECT DISTINCT model, mpg FROM cars WHERE cyl = 6 ORDER BY mpg DESC LIMIT 50;')

('Hornet 4 Drive', '21.4')
('Mazda RX4', '21')
('Mazda RX4 Wag', '21')
('Ferrari Dino', '19.7')
('Merc 280', '19.2')
('Valiant', '18.1')
('Merc 280C', '17.8')


### Disclaimer:
Note: OpenAI provides a free API key for initial testing. Once you move to a paid subscription, calling the API in the way demonstrated in this example will incur monetary charges. Refer to OpenAI's pricing information for details.

Be aware that information, such as files and databases to train OpenAI's LLM can become public if applied in the way this demo demonstrates. Refer to OpenAI's usage policy for details.

This demo is for educational purposes only and for demonstrating machine learning methods. The author makes no claims that the outcomes shown here or any outcomes that could be produced by this method are accurate or reliable.