In [1]:
import openai
from google.cloud import secretmanager_v1

secretmanager_client = secretmanager_v1.SecretManagerServiceClient()
api_key = secretmanager_client.access_secret_version(name='projects/84043197426/secrets/openai-api-key/versions/1').payload.data

openai.api_key = api_key.decode("utf-8")
openai.Model.list()

<OpenAIObject list at 0x12e6e62c0> JSON: {
  "data": [
    {
      "created": 1649358449,
      "id": "babbage",
      "object": "model",
      "owned_by": "openai",
      "parent": null,
      "permission": [
        {
          "allow_create_engine": false,
          "allow_fine_tuning": false,
          "allow_logprobs": true,
          "allow_sampling": true,
          "allow_search_indices": false,
          "allow_view": true,
          "created": 1669085501,
          "group": null,
          "id": "modelperm-49FUp5v084tBB49tC4z8LPH5",
          "is_blocking": false,
          "object": "model_permission",
          "organization": "*"
        }
      ],
      "root": "babbage"
    },
    {
      "created": 1649359874,
      "id": "davinci",
      "object": "model",
      "owned_by": "openai",
      "parent": null,
      "permission": [
        {
          "allow_create_engine": false,
          "allow_fine_tuning": false,
          "allow_logprobs": true,
          "allow_sampl

In [26]:
model_id = 'text-davinci-003'

prompt = '''
"""SQL tables (and columns):
* Customers(customer_id, signup_date)
* Streaming(customer_id, video_id, watch_date, watch_minutes)

A well-written SQL query that lists customers who signed up during March 2020 and watched more than 50 hours of video in their first 30 days:"""'''

prompt1 = '''"""Table customers, columns = [CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId]
Create a MySQL query for all customers in Texas named Jane
"""'''

prompt3 = '''
"""SQL tables (and columns):
* Customers(CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId)

Create a MySQL query for all customers in Texas named Jane:"""'''


prompts = [prompt, prompt1, prompt1, prompt3]
response = openai.Completion.create(
    model=model_id,
    prompt=prompts,
    max_tokens=256)

In [18]:
response

<OpenAIObject text_completion id=cmpl-76QHFXCsfv8Qx9R2tWdPjLyf8dXPk at 0x12eb917c0> JSON: {
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "text": "\n\nSELECT c.customer_id\nFROM Customers c\nINNER JOIN Streaming s\n   ON c.customer_id = s.customer_id\nWHERE c.signup_date BETWEEN '2020-03-01' AND '2020-03-31'\n  AND s.watch_date <= DATE_ADD(c.signup_date, INTERVAL 30 DAY)\nGROUP BY c.customer_id\nHAVING SUM(s.watch_minutes) > 50 * 60;"
    },
    {
      "finish_reason": "stop",
      "index": 1,
      "logprobs": null,
      "text": "\nSELECT * FROM customers WHERE State = 'Texas' AND FirstName = 'Jane';"
    },
    {
      "finish_reason": "stop",
      "index": 2,
      "logprobs": null,
      "text": "\nSELECT * \nFROM customers \nWHERE FirstName = 'Jane' \nAND State = 'Texas';"
    },
    {
      "finish_reason": "stop",
      "index": 3,
      "logprobs": null,
      "text": "\n\nSELECT * FROM Customers\nWHERE State = 'Texas'\n

Seems like this model is quite deterministic across the 3 prompts.

Let's continue by processing the dev set for inference on Codex.

In [40]:
import os
from google.cloud import storage
import json

client = storage.Client()
cosql_bucket = client.bucket(os.environ['COSQL_BUCKET'])
spider_bucket = client.bucket(os.environ['SPIDER_BUCKET'])

train_cosql = json.loads(cosql_bucket.get_blob('sql_state_tracking/cosql_dev.json').download_as_text())
train_spider = json.loads(spider_bucket.get_blob('train_spider.json').download_as_text())
print(len(train_cosql))
print(len(train_spider))

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [3]:
import pandas as pd
from process_sql import get_schema

In [4]:
def process_spider(dataset):
    processed_input = []
    processed_target = []
    processed_db = []
    schema = []

    for record in dataset:
        processed_input.append(record['question'])
        processed_target.append(record['query'])
        processed_db.append(record['db_id'])
        if not os.path.isfile(f"../temp/spider_db/{record['db_id']}.sqlite"):
            spider_bucket.get_blob(f"database/{record['db_id']}/{record['db_id']}.sqlite").download_to_filename(f"../temp/spider_db/{record['db_id']}.sqlite")
        schema.append(get_schema(f"../temp/spider_db/{record['db_id']}.sqlite"))
    
    processed_dataset = pd.DataFrame({
        'input': processed_input,
        'target': processed_target,
        'db': processed_db,
        'schema': schema
    })

    return processed_dataset


In [5]:
def process_cosql(dataset, scope: str) -> pd.DataFrame:
    scope_valid = ['all', 'final', 'interaction']
    if scope not in scope_valid:
        raise ValueError(f'scope must be one of {scope_valid}')

    processed_input = []
    processed_target = []
    processed_db = []
    schema = []
    
    for dialog in dataset:
        if scope in ['all', 'final']:
            processed_input.append(dialog['final']['utterance'])
            processed_target.append(dialog['final']['query'])
        if scope in ['all', 'interaction']:
            for turn in dialog['interaction']:
                processed_input.append(turn['utterance'])
                processed_target.append(turn['query'])
        processed_db.append(dialog['database_id'])
        if not os.path.isfile(f"../temp/cosql_db/{dialog['database_id']}.sqlite"):
            cosql_bucket.get_blob(f"database/{dialog['database_id']}/{dialog['database_id']}.sqlite").download_to_filename(f"../temp/cosql_db/{dialog['database_id']}.sqlite")
        schema.append(get_schema(f"../temp/cosql_db/{dialog['database_id']}.sqlite"))

                
    processed_dataset = pd.DataFrame({
        'input': processed_input,
        'target': processed_target,
        'db': processed_db,
        'schema': schema
    })

    return processed_dataset

In [5]:
# make temp dir to store sqlite files for connections and evaluation
try:
    os.mkdir('../temp')
except FileExistsError:
    pass

try:
    os.mkdir('../temp/spider_db')
except FileExistsError:
    pass

try:
    os.mkdir('../temp/cosql_db')
except FileExistsError:
    pass

#processed = process_cosql(train_cosql, scope='final')
processed = process_spider(train_spider)
processed

Unnamed: 0,input,target,db,schema
0,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56,department_management,"{'department': ['department_id', 'name', 'crea..."
1,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD...",department_management,"{'department': ['department_id', 'name', 'crea..."
2,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ...",department_management,"{'department': ['department_id', 'name', 'crea..."
3,What are the maximum and minimum budget of the...,"SELECT max(budget_in_billions) , min(budget_i...",department_management,"{'department': ['department_id', 'name', 'crea..."
4,What is the average number of employees of the...,SELECT avg(num_employees) FROM department WHER...,department_management,"{'department': ['department_id', 'name', 'crea..."
...,...,...,...,...
6995,What are all the company names that have a boo...,SELECT T1.company_name FROM culture_company AS...,culture_company,"{'book_club': ['book_club_id', 'year', 'author..."
6996,Show the movie titles and book titles for all ...,"SELECT T1.title , T3.book_title FROM movie AS...",culture_company,"{'book_club': ['book_club_id', 'year', 'author..."
6997,What are the titles of movies and books corres...,"SELECT T1.title , T3.book_title FROM movie AS...",culture_company,"{'book_club': ['book_club_id', 'year', 'author..."
6998,Show all company names with a movie directed i...,SELECT T2.company_name FROM movie AS T1 JOIN c...,culture_company,"{'book_club': ['book_club_id', 'year', 'author..."


In [6]:
def construct_prompt(row):
    prompt = ''
    for table in row['schema']:
        prompt += f"Table {table}, columns = [{', '.join(row['schema'][table])}]\n"
    prompt += '\n'
    
    prompt += f"Create a MySQL query to answer the following question: {row['input']}"
    return '"""' + prompt + '"""'

In [7]:
processed['prompt'] = processed.apply(construct_prompt, axis=1)
processed

Unnamed: 0,input,target,db,schema,prompt
0,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56,department_management,"{'department': ['department_id', 'name', 'crea...","""""""Table department, columns = [department_id,..."
1,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD...",department_management,"{'department': ['department_id', 'name', 'crea...","""""""Table department, columns = [department_id,..."
2,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ...",department_management,"{'department': ['department_id', 'name', 'crea...","""""""Table department, columns = [department_id,..."
3,What are the maximum and minimum budget of the...,"SELECT max(budget_in_billions) , min(budget_i...",department_management,"{'department': ['department_id', 'name', 'crea...","""""""Table department, columns = [department_id,..."
4,What is the average number of employees of the...,SELECT avg(num_employees) FROM department WHER...,department_management,"{'department': ['department_id', 'name', 'crea...","""""""Table department, columns = [department_id,..."
...,...,...,...,...,...
6995,What are all the company names that have a boo...,SELECT T1.company_name FROM culture_company AS...,culture_company,"{'book_club': ['book_club_id', 'year', 'author...","""""""Table book_club, columns = [book_club_id, y..."
6996,Show the movie titles and book titles for all ...,"SELECT T1.title , T3.book_title FROM movie AS...",culture_company,"{'book_club': ['book_club_id', 'year', 'author...","""""""Table book_club, columns = [book_club_id, y..."
6997,What are the titles of movies and books corres...,"SELECT T1.title , T3.book_title FROM movie AS...",culture_company,"{'book_club': ['book_club_id', 'year', 'author...","""""""Table book_club, columns = [book_club_id, y..."
6998,Show all company names with a movie directed i...,SELECT T2.company_name FROM movie AS T1 JOIN c...,culture_company,"{'book_club': ['book_club_id', 'year', 'author...","""""""Table book_club, columns = [book_club_id, y..."


In [8]:
def get_sql_codex(prompt):
    response = openai.Completion.create(model=model_id, prompt=prompt, max_tokens=256)
    return response.choices[0].text

In [10]:
from tqdm import tqdm
model_id = 'text-davinci-003'
preds = []
for prompt in tqdm(processed['prompt']):
    preds.append(get_sql_codex(prompt))
processed['preds'] = preds
processed

 25%|██▍       | 1737/7000 [54:08<2:44:03,  1.87s/it]


APIError: The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error. (Please include the request ID f9428ffbc088262f9196e48f0b379ce7 in your email.) {
  "error": {
    "message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error. (Please include the request ID f9428ffbc088262f9196e48f0b379ce7 in your email.)",
    "type": "server_error",
    "param": null,
    "code": null
  }
}
 500 {'error': {'message': 'The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error. (Please include the request ID f9428ffbc088262f9196e48f0b379ce7 in your email.)', 'type': 'server_error', 'param': None, 'code': None}} {'Date': 'Tue, 02 May 2023 17:47:11 GMT', 'Content-Type': 'application/json', 'Content-Length': '366', 'Connection': 'keep-alive', 'access-control-allow-origin': '*', 'openai-model': 'text-davinci-003', 'openai-organization': 'user-qzdxmvf3qek0wmywdx3ntmoy', 'openai-processing-ms': '588', 'openai-version': '2020-10-01', 'strict-transport-security': 'max-age=15724800; includeSubDomains', 'x-ratelimit-limit-requests': '60', 'x-ratelimit-limit-tokens': '150000', 'x-ratelimit-remaining-requests': '58', 'x-ratelimit-remaining-tokens': '149744', 'x-ratelimit-reset-requests': '1.055s', 'x-ratelimit-reset-tokens': '102ms', 'x-request-id': 'f9428ffbc088262f9196e48f0b379ce7', 'CF-Cache-Status': 'DYNAMIC', 'Server': 'cloudflare', 'CF-RAY': '7c12167ebc8d0f6b-EWR', 'alt-svc': 'h3=":443"; ma=86400, h3-29=":443"; ma=86400'}

In [21]:
processed['preds'] = preds + (processed.shape[0] - len(preds))*[None]

In [23]:
processed['preds'] = processed['preds'].map(lambda x: x.replace('\n', ' ') if x is not None else x) # remove newlines

In [25]:
with open(f"./predictions/text_davinci_003_spider.txt","w") as f:
  for data in processed['preds']:
    if not data: data = ''
    f.write(data + "\n")

**Evaluate**

In [32]:
bucket = spider_bucket

In [42]:
!python evaluation.py --gold="spider_gold_subset.txt" --pred="predictions/text_davinci_003_spider.txt" --db="../temp/spider_db" --table="tables.json" --etype="all" --bucket="{bucket.name}" --table_uri=True --gold_uri=True, --pred_uri=True, --plug_value=True

spider_gold_subset.txt
g
./predictions/text_davinci_003_spider.txt
Traceback (most recent call last):
  File "/Users/amanchopra/Documents/School/MS/Spring 2023/Capstone/Code/S23-CUDSI-Loreal-Text2SQL/evaluation/evaluation.py", line 973, in <module>
    evaluate(args.gold, args.pred, args.db, args.etype, kmaps, args.plug_value, args.keep_distinct, args.progress_bar_for_each_datapoint, bucket=bucket, gold_uri=args.gold_uri, predict_uri=args.pred_uri)
  File "/Users/amanchopra/Documents/School/MS/Spring 2023/Capstone/Code/S23-CUDSI-Loreal-Text2SQL/evaluation/evaluation.py", line 537, in evaluate
    plist_lines = bucket.get_blob(predict).download_as_text().split('\n')
AttributeError: 'NoneType' object has no attribute 'download_as_text'


In [51]:
# remove temp dir with sqllite files
import shutil
shutil.rmtree('../temp')