In [2]:
import json

with open('../configs/openai_api_key.json') as f:
    config = json.load(f)

In [3]:
import os
import openai

openai.api_key = config['key']
openai.Model.list()

<OpenAIObject list at 0x7f86bee3fa90> JSON: {
  "data": [
    {
      "created": 1649358449,
      "id": "babbage",
      "object": "model",
      "owned_by": "openai",
      "parent": null,
      "permission": [
        {
          "allow_create_engine": false,
          "allow_fine_tuning": false,
          "allow_logprobs": true,
          "allow_sampling": true,
          "allow_search_indices": false,
          "allow_view": true,
          "created": 1669085501,
          "group": null,
          "id": "modelperm-49FUp5v084tBB49tC4z8LPH5",
          "is_blocking": false,
          "object": "model_permission",
          "organization": "*"
        }
      ],
      "root": "babbage"
    },
    {
      "created": 1649359874,
      "id": "davinci",
      "object": "model",
      "owned_by": "openai",
      "parent": null,
      "permission": [
        {
          "allow_create_engine": false,
          "allow_fine_tuning": false,
          "allow_logprobs": true,
          "allow_sa

In [89]:
model_id = 'text-davinci-003'

prompt = '''
"""SQL tables (and columns):
* Customers(customer_id, signup_date)
* Streaming(customer_id, video_id, watch_date, watch_minutes)

A well-written SQL query that lists customers who signed up during March 2020 and watched more than 50 hours of video in their first 30 days:"""'''

prompt1 = '''"""Table customers, columns = [CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId]
Create a MySQL query for all customers in Texas named Jane
"""'''

prompt3 = '''
"""SQL tables (and columns):
* Customers(CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId)

Create a MySQL query for all customers in Texas named Jane:"""'''


prompts = [prompt, prompt1, prompt1, prompt3]
response = openai.Completion.create(
    model=model_id,
    prompt=prompts,
    max_tokens=256)

In [90]:
response

<OpenAIObject text_completion id=cmpl-6ymmCPTuDtXRrcWoVkLgO3v5OOYl0 at 0x7f81f0d8c810> JSON: {
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "text": "\n\n\nSELECT *\nFROM Customers\nWHERE signup_date >= '2020-03-01'\n  AND signup_date <= '2020-03-31'\n  AND (SELECT sum(watch_minutes)\n       FROM Streaming\n       WHERE customer_id = Customers.customer_id\n         AND watch_date >= Customers.signup_date\n         AND watch_date <= date_add(Customers.signup_date, interval 30 day)) > 50*60"
    },
    {
      "finish_reason": "stop",
      "index": 1,
      "logprobs": null,
      "text": "\n\nSELECT *\nFROM customers\nWHERE City = 'Texas' and FirstName = 'Jane'"
    },
    {
      "finish_reason": "stop",
      "index": 2,
      "logprobs": null,
      "text": "SELECT * FROM customers\nWHERE State = \"TX\"\nAND FirstName = \"Jane\""
    },
    {
      "finish_reason": "stop",
      "index": 3,
      "logprobs": null,
      "text": "

There is definitely some non-determinism.

In [91]:
with open('../../../Data/CoSQL/sql_state_tracking/cosql_dev.json') as f:
    cosql = json.load(f)
len(cosql)

293

In [109]:
import pandas as pd

def process_cosql(dataset, scope: str) -> pd.DataFrame:
    scope_valid = ['all', 'final', 'interaction']
    if scope not in scope_valid:
        raise ValueError(f'scope must be one of {scope_valid}')

    processed_input = []
    processed_target = []
    processed_db = []
    
    for dialog in dataset:
        if scope in ['all', 'final']:
            processed_input.append(dialog['final']['utterance'])
            processed_target.append(dialog['final']['query'])
        if scope in ['all', 'interaction']:
            for turn in dialog['interaction']:
                processed_input.append(turn['utterance'])
                processed_target.append(turn['query'])
        processed_db.append(dialog['database_id'])
                
    processed_dataset = pd.DataFrame({
        'input': processed_input,
        'target': processed_target,
        'db': processed_db
    })

    return processed_dataset

In [163]:
processed_cosql = process_cosql(cosql, scope='final')
processed_cosql = processed_cosql.groupby('input').first().reset_index() # select first target for duplicate inputs
processed_cosql

Unnamed: 0,input,target,db
0,Among the cars with more than lowest horsepowe...,"SELECT T2.MakeId , T2.Make FROM CARS_DATA AS ...",car_1
1,Find all airlines that have fewer than 200 fli...,SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLI...,flight_2
2,Find all airlines that have flights from airpo...,SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLI...,flight_2
3,Find number of pets owned by students who are ...,SELECT count(*) FROM student AS T1 JOIN has_pe...,pets_1
4,Find the abbreviation and country of the airli...,"SELECT T1.Abbreviation , T1.Country FROM AIRL...",flight_2
...,...,...,...
241,find the minimum and maximum number of product...,"SELECT min(Number_products) , max(Number_prod...",employee_hire_evaluation
242,find the names of loser and winner who played ...,"SELECT winner_name , loser_name FROM matches ...",wta_1
243,find the package option of the tv channel that...,SELECT package_option FROM TV_Channel WHERE id...,tvshow
244,which countries' tv channels are not playing a...,SELECT country FROM TV_Channel EXCEPT SELECT T...,tvshow


In [164]:
import re 

def get_tables(query):
    query = query.lower()
    query_toks = query.split()
    
    tables = [query_toks[i+1] for i, tok in enumerate(query_toks) if tok == "from"]
    
    if 'join' in query_toks:
        tables += [query_toks[i+1] for i, tok in enumerate(query_toks) if tok == "join"]
    
    tables = [re.sub('[()]', '', table) for table in tables]
    return list(set(tables))

def get_cols(db, table):
    cols = []
    with open(f'../../../Data/CoSQL/database/{db}/schema.sql') as f:
        lines = f.readlines()

    schema_starting_ind = [i for i, line in enumerate(lines) if f'CREATE TABLE "{table}"' in line]

    if schema_starting_ind:
        schema_starting_ind = schema_starting_ind[0]

    for i in range(schema_starting_ind + 1, len(lines)):
        if ';' in lines[i] or 'primary key' in lines[i]: break
        cols.append(lines[i].split()[0][1:-1])
        
    return cols

In [165]:
processed_cosql['tables'] = processed_cosql['target'].map(get_tables)
processed_cosql

Unnamed: 0,input,target,db,tables
0,Among the cars with more than lowest horsepowe...,"SELECT T2.MakeId , T2.Make FROM CARS_DATA AS ...",car_1,"[car_names, cars_data]"
1,Find all airlines that have fewer than 200 fli...,SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLI...,flight_2,"[airlines, flights]"
2,Find all airlines that have flights from airpo...,SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLI...,flight_2,"[airlines, flights]"
3,Find number of pets owned by students who are ...,SELECT count(*) FROM student AS T1 JOIN has_pe...,pets_1,"[student, has_pet]"
4,Find the abbreviation and country of the airli...,"SELECT T1.Abbreviation , T1.Country FROM AIRL...",flight_2,"[airlines, flights]"
...,...,...,...,...
241,find the minimum and maximum number of product...,"SELECT min(Number_products) , max(Number_prod...",employee_hire_evaluation,[shop]
242,find the names of loser and winner who played ...,"SELECT winner_name , loser_name FROM matches ...",wta_1,[matches]
243,find the package option of the tv channel that...,SELECT package_option FROM TV_Channel WHERE id...,tvshow,"[cartoon, tv_channel]"
244,which countries' tv channels are not playing a...,SELECT country FROM TV_Channel EXCEPT SELECT T...,tvshow,"[cartoon, tv_channel]"


In [175]:
processed_cosql.iloc[0]['target']

'SELECT T2.MakeId ,  T2.Make FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id  =  T2.MakeId WHERE T1.Horsepower  >  (SELECT min(Horsepower) FROM CARS_DATA) AND T1.Cylinders  <=  3;'

In [171]:
def get_tables_cols(db_table_row):
    db = db_table_row['db']
    tables = db_table_row['tables']
    
    tables_cols = {table: get_cols(db, table) for table in tables}
    print(tables_cols)
        

In [172]:
processed_cosql[['db', 'tables']].iloc[:5].apply(get_tables_cols, axis=1)

FileNotFoundError: [Errno 2] No such file or directory: '../../../Data/CoSQL/database/car_1/schema.sql'

In [181]:
with open('../../../Data/CoSQL/tables.json') as f:
    tables = json.load(f)
len(tables)

178