In [6]:
import json

with open('../configs/openai_api_key.json') as f:
    config = json.load(f)

In [7]:
import os
import openai

openai.api_key = config['key']
openai.Model.list()

<OpenAIObject list at 0x1117a5180> JSON: {
  "data": [
    {
      "created": 1649358449,
      "id": "babbage",
      "object": "model",
      "owned_by": "openai",
      "parent": null,
      "permission": [
        {
          "allow_create_engine": false,
          "allow_fine_tuning": false,
          "allow_logprobs": true,
          "allow_sampling": true,
          "allow_search_indices": false,
          "allow_view": true,
          "created": 1669085501,
          "group": null,
          "id": "modelperm-49FUp5v084tBB49tC4z8LPH5",
          "is_blocking": false,
          "object": "model_permission",
          "organization": "*"
        }
      ],
      "root": "babbage"
    },
    {
      "created": 1649359874,
      "id": "davinci",
      "object": "model",
      "owned_by": "openai",
      "parent": null,
      "permission": [
        {
          "allow_create_engine": false,
          "allow_fine_tuning": false,
          "allow_logprobs": true,
          "allow_sampl

In [8]:
model_id = 'text-davinci-003'

prompt = '''
"""SQL tables (and columns):
* Customers(customer_id, signup_date)
* Streaming(customer_id, video_id, watch_date, watch_minutes)

A well-written SQL query that lists customers who signed up during March 2020 and watched more than 50 hours of video in their first 30 days:"""'''

prompt1 = '''"""Table customers, columns = [CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId]
Create a MySQL query for all customers in Texas named Jane
"""'''

prompt3 = '''
"""SQL tables (and columns):
* Customers(CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId)

Create a MySQL query for all customers in Texas named Jane:"""'''


prompts = [prompt, prompt1, prompt1, prompt3]
response = openai.Completion.create(
    model=model_id,
    prompt=prompts,
    max_tokens=256)

In [9]:
response

<OpenAIObject text_completion id=cmpl-71fy9KScx4ze7rkThkMjBOdnGeuPb at 0x130d93c20> JSON: {
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "text": "\n\nSELECT c.customer_id\nFROM Customers c\nJOIN Streaming s\nON c.customer_id = s.customer_id\nWHERE c.signup_date BETWEEN '2020-03-01' and '2020-03-31'\nAND s.watch_date BETWEEN c.signup_date AND DATE_ADD(c.signup_date, INTERVAL 30 DAY)\nGROUP BY c.customer_id\nHAVING SUM(s.watch_minutes) > 5000;"
    },
    {
      "finish_reason": "stop",
      "index": 1,
      "logprobs": null,
      "text": "\n\nSELECT * \nFROM Customers \nWHERE FirstName = 'Jane' AND State = 'Texas';"
    },
    {
      "finish_reason": "stop",
      "index": 2,
      "logprobs": null,
      "text": "\n\nSELECT *\nFROM customers\nWHERE FirstName = 'Jane' AND State = 'Texas';"
    },
    {
      "finish_reason": "stop",
      "index": 3,
      "logprobs": null,
      "text": "\n\nSELECT * FROM Customers\nWHERE Firs

Seems like this model is quite deterministic across the 3 prompts.

Let's continue by processing the dev set for inference on Codex.

In [16]:
import os
path_to_cosql = os.environ['COSQL_PATH']
with open(f'{path_to_cosql}/sql_state_tracking/cosql_dev.json') as f:
    dev_cosql = json.load(f)
len(dev_cosql)

293

In [28]:
import pandas as pd
from process_sql import get_schema

def process_cosql(dataset, scope: str) -> pd.DataFrame:
    scope_valid = ['all', 'final', 'interaction']
    if scope not in scope_valid:
        raise ValueError(f'scope must be one of {scope_valid}')

    processed_input = []
    processed_target = []
    processed_db = []
    schema = []
    
    for dialog in dataset:
        if scope in ['all', 'final']:
            processed_input.append(dialog['final']['utterance'])
            processed_target.append(dialog['final']['query'])
        if scope in ['all', 'interaction']:
            for turn in dialog['interaction']:
                processed_input.append(turn['utterance'])
                processed_target.append(turn['query'])
        processed_db.append(dialog['database_id'])
        schema.append(get_schema(f"{path_to_cosql}/database/{dialog['database_id']}/{dialog['database_id']}.sqlite"))

                
    processed_dataset = pd.DataFrame({
        'input': processed_input,
        'target': processed_target,
        'db': processed_db,
        'schema': schema
    })

    return processed_dataset

In [29]:
processed_cosql = process_cosql(dev_cosql, scope='final')
#processed_cosql = processed_cosql.groupby('input').first().reset_index() # select first target for duplicate inputs
processed_cosql

Unnamed: 0,input,target,db,schema
0,How many car models are produced by each maker...,"SELECT Count(*) , T2.FullName , T2.id FROM M...",car_1,"{'continents': ['contid', 'continent'], 'count..."
1,List singer names and number of concerts for e...,"SELECT T2.name , count(*) FROM singer_in_conc...",concert_singer,"{'stadium': ['stadium_id', 'location', 'name',..."
2,Show ids for all templates that are used by mo...,SELECT template_id FROM Documents GROUP BY tem...,cre_Doc_Template_Mgt,"{'ref_template_types': ['template_type_code', ..."
3,Find the first name of the students who perman...,SELECT T1.first_name FROM Students AS T1 JOIN ...,student_transcripts_tracking,"{'addresses': ['address_id', 'line_1', 'line_2..."
4,Show names for all stadiums except for stadium...,SELECT name FROM stadium EXCEPT SELECT T2.name...,concert_singer,"{'stadium': ['stadium_id', 'location', 'name',..."
...,...,...,...,...
288,"Show names, results and bulgarian commanders o...","SELECT name , RESULT , bulgarian_commander F...",battle_death,"{'battle': ['id', 'name', 'date', 'bulgarian_c..."
289,Show the ids of high schoolers who have friend...,SELECT student_id FROM Friend INTERSECT SELECT...,network_1,"{'highschooler': ['id', 'name', 'grade'], 'fri..."
290,How many documents are using the template with...,SELECT count(*) FROM Documents AS T1 JOIN Temp...,cre_Doc_Template_Mgt,"{'ref_template_types': ['template_type_code', ..."
291,"For the cars with 4 cylinders, which model has...",SELECT T1.Model FROM CAR_NAMES AS T1 JOIN CARS...,car_1,"{'continents': ['contid', 'continent'], 'count..."


In [50]:
def construct_prompt(row):
    prompt = ''
    for table in row['schema']:
        prompt += f"Table {table}, columns = [{', '.join(row['schema'][table])}]\n"
    prompt += '\n'
    
    prompt += f"Create a MySQL query to answer the following question: {row['input']}."
    return '"""' + prompt + '"""'

In [53]:
processed_cosql['prompt'] = processed_cosql.apply(construct_prompt, axis=1)
processed_cosql

Unnamed: 0,input,target,db,schema,prompt
0,How many car models are produced by each maker...,"SELECT Count(*) , T2.FullName , T2.id FROM M...",car_1,"{'continents': ['contid', 'continent'], 'count...","""""""Table continents, columns = [contid, contin..."
1,List singer names and number of concerts for e...,"SELECT T2.name , count(*) FROM singer_in_conc...",concert_singer,"{'stadium': ['stadium_id', 'location', 'name',...","""""""Table stadium, columns = [stadium_id, locat..."
2,Show ids for all templates that are used by mo...,SELECT template_id FROM Documents GROUP BY tem...,cre_Doc_Template_Mgt,"{'ref_template_types': ['template_type_code', ...","""""""Table ref_template_types, columns = [templa..."
3,Find the first name of the students who perman...,SELECT T1.first_name FROM Students AS T1 JOIN ...,student_transcripts_tracking,"{'addresses': ['address_id', 'line_1', 'line_2...","""""""Table addresses, columns = [address_id, lin..."
4,Show names for all stadiums except for stadium...,SELECT name FROM stadium EXCEPT SELECT T2.name...,concert_singer,"{'stadium': ['stadium_id', 'location', 'name',...","""""""Table stadium, columns = [stadium_id, locat..."
...,...,...,...,...,...
288,"Show names, results and bulgarian commanders o...","SELECT name , RESULT , bulgarian_commander F...",battle_death,"{'battle': ['id', 'name', 'date', 'bulgarian_c...","""""""Table battle, columns = [id, name, date, bu..."
289,Show the ids of high schoolers who have friend...,SELECT student_id FROM Friend INTERSECT SELECT...,network_1,"{'highschooler': ['id', 'name', 'grade'], 'fri...","""""""Table highschooler, columns = [id, name, gr..."
290,How many documents are using the template with...,SELECT count(*) FROM Documents AS T1 JOIN Temp...,cre_Doc_Template_Mgt,"{'ref_template_types': ['template_type_code', ...","""""""Table ref_template_types, columns = [templa..."
291,"For the cars with 4 cylinders, which model has...",SELECT T1.Model FROM CAR_NAMES AS T1 JOIN CARS...,car_1,"{'continents': ['contid', 'continent'], 'count...","""""""Table continents, columns = [contid, contin..."


In [62]:
def get_sql_codex(prompt):
    response = openai.Completion.create(model=model_id, prompt=prompt, max_tokens=256)
    return response.choices[0].text

In [69]:
from tqdm import tqdm
preds = []
for prompt in tqdm(processed_cosql['prompt']):
    preds.append(get_sql_codex(prompt))
processed_cosql['preds'] = preds
processed_cosql

Unnamed: 0,input,target,db,schema,prompt,preds
0,How many car models are produced by each maker...,"SELECT Count(*) , T2.FullName , T2.id FROM M...",car_1,"{'continents': ['contid', 'continent'], 'count...","""""""Table continents, columns = [contid, contin...","\n\nSELECT car_makers.fullname, COUNT(model_li..."
1,List singer names and number of concerts for e...,"SELECT T2.name , count(*) FROM singer_in_conc...",concert_singer,"{'stadium': ['stadium_id', 'location', 'name',...","""""""Table stadium, columns = [stadium_id, locat...","\n\nSELECT s.name AS 'Singer Name', COUNT(sic...."
2,Show ids for all templates that are used by mo...,SELECT template_id FROM Documents GROUP BY tem...,cre_Doc_Template_Mgt,"{'ref_template_types': ['template_type_code', ...","""""""Table ref_template_types, columns = [templa...",\n\nSELECT template_id\nFROM templates\nWHERE ...
3,Find the first name of the students who perman...,SELECT T1.first_name FROM Students AS T1 JOIN ...,student_transcripts_tracking,"{'addresses': ['address_id', 'line_1', 'line_2...","""""""Table addresses, columns = [address_id, lin...",\nSELECT first_name\nFROM students\nINNER JOIN...
4,Show names for all stadiums except for stadium...,SELECT name FROM stadium EXCEPT SELECT T2.name...,concert_singer,"{'stadium': ['stadium_id', 'location', 'name',...","""""""Table stadium, columns = [stadium_id, locat...",\n\nSELECT name\nFROM stadium\nWHERE stadium_i...
...,...,...,...,...,...,...
288,"Show names, results and bulgarian commanders o...","SELECT name , RESULT , bulgarian_commander F...",battle_death,"{'battle': ['id', 'name', 'date', 'bulgarian_c...","""""""Table battle, columns = [id, name, date, bu...","\n\nSELECT b.name, b.result, b.bulgarian_comma..."
289,Show the ids of high schoolers who have friend...,SELECT student_id FROM Friend INTERSECT SELECT...,network_1,"{'highschooler': ['id', 'name', 'grade'], 'fri...","""""""Table highschooler, columns = [id, name, gr...",\n\nSELECT h.id \nFROM highschooler h\nJOIN fr...
290,How many documents are using the template with...,SELECT count(*) FROM Documents AS T1 JOIN Temp...,cre_Doc_Template_Mgt,"{'ref_template_types': ['template_type_code', ...","""""""Table ref_template_types, columns = [templa...",\n\nSELECT COUNT(DISTINCT d.document_id)\nFROM...
291,"For the cars with 4 cylinders, which model has...",SELECT T1.Model FROM CAR_NAMES AS T1 JOIN CARS...,car_1,"{'continents': ['contid', 'continent'], 'count...","""""""Table continents, columns = [contid, contin...",\n\nSELECT model\nFROM cars_data \nJOIN model_...


In [73]:
processed_cosql['preds'] = processed_cosql['preds'].map(lambda x: x.replace('\n', '')) # remove newlines

In [76]:
codex_path = os.environ['TEXT_DAVINCI_CHECKPOINT_PATH']
with open(f"{codex_path}/predict_final.txt","w") as f:
  for data in processed_cosql['preds']:
    f.write(data + "\n")

**Evaluate**

In [77]:
!python evaluation.py --gold="{path_to_cosql}/sql_state_tracking/dev_gold_final.txt" --pred="{codex_path}/predict_final.txt" --db="{path_to_cosql}/database" --table="{path_to_cosql}/tables.json" --etype="all"

medium pred: SELECT car_makers.fullname, COUNT(model_list.model) FROM car_makers JOIN model_list ON car_makers.id = model_list.maker GROUP BY car_makers.fullname;
medium gold: SELECT Count(*) ,  T2.FullName ,  T2.id FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker  =  T2.Id GROUP BY T2.id;

eval_err_num:1
medium pred: SELECT s.name AS 'Singer Name', COUNT(sic.concert_id) AS 'Number of Concerts' FROM singer sJOIN singer_in_concert sic ON s.singer_id = sic.singer_idGROUP BY s.name;
medium gold: SELECT T2.name ,  count(*) FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id  =  T2.singer_id GROUP BY T2.singer_id

eval_err_num:2
easy pred: SELECT template_idFROM templatesWHERE template_id IN (SELECT template_id                      FROM documents                      GROUP BY template_id                      HAVING COUNT(document_id) > 1)
easy gold: SELECT template_id FROM Documents GROUP BY template_id HAVING count(*)  >  1

eval_err_num:3
extra pred: SELECT first_nameFRO

In [78]:
import nltk
from process_sql import Schema, get_sql
from evaluation import build_valid_col_units, rebuild_sql_val, rebuild_sql_col, build_foreign_key_map_from_json, Evaluator, eval_exec_match
kmaps = build_foreign_key_map_from_json(f"{path_to_cosql}/tables.json")
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/amanchopra/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [79]:
def evaluate(gold, predict, db_dir, kmaps):
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
    # glist = [("SELECT max(SHARE) ,  min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]

    all = 0
    execute = 0
    exact = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        all += 1.0

        try:
            p_sql = get_sql(schema, p_str)
        except:
            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
            p_sql = {
            "except": None,
            "from": {
                "conds": [],
                "table_units": []
            },
            "groupBy": [],
            "having": [],
            "intersect": None,
            "limit": None,
            "orderBy": [],
            "select": [
                False,
                []
            ],
            "union": None,
            "where": []
            }

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
        if exec_score:
            execute += 1.0

        evaluator = Evaluator()
        exact_score = evaluator.eval_exact_match(p_sql, g_sql)
        if exact_score:
            exact += 1.0

    return {"execute":execute/all, "exact": exact/all}

In [80]:
evaluate(f"{path_to_cosql}/sql_state_tracking/dev_gold_final.txt", f"{codex_path}/predict_final.txt", f"{path_to_cosql}/database", kmaps)

{'execute': 0.05460750853242321, 'exact': 0.03754266211604096}