# Load Model

In [2]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

In [3]:
def translate_to_sql(text, tokenizer, model, max_length=64, max_input_length=None):
    """Translate one natural-language request into a SQL string with a seq2seq model.

    Args:
        text: input string handed to the tokenizer (callers in this notebook prefix it
            with 'translate to SQL: ').
        tokenizer: tokenizer paired with ``model``; encodes ``text`` and decodes the output.
        model: model exposing ``generate``.
        max_length: maximum length of the generated sequence, passed to ``model.generate``.
        max_input_length: if given, truncate the encoded input to this many tokens.
            Defaults to None, i.e. no truncation.

    Returns:
        The first generated sequence decoded to text, with special tokens removed.
    """
    encode_kwargs = {'padding': 'longest', 'return_tensors': 'pt'}
    if max_input_length is not None:
        # The tokenizer only honours `max_length` when truncation is switched on explicitly;
        # with padding='longest' alone a `max_length` argument is ignored.
        encode_kwargs.update(truncation=True, max_length=max_input_length)
    inputs = tokenizer(text, **encode_kwargs)

    # Put the encoded tensors on the same device as the model so this keeps working
    # after e.g. model.to('cuda'); the tensors are created on the CPU.
    device = getattr(model, 'device', None)
    if device is not None:
        inputs = inputs.to(device)

    output = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Execution Accuracy & Exact Matching

In [5]:
import os
from google.cloud import storage
# Bucket holding the CoSQL data. Read the variable with [] so a missing COSQL_BUCKET fails
# here with a KeyError naming it, instead of passing None on to client.bucket().
cosql_bucket_name = os.environ['COSQL_BUCKET']
client = storage.Client()
bucket = client.bucket(cosql_bucket_name)

In [6]:
import nltk
from process_sql import get_schema, Schema, get_sql
from evaluation import build_valid_col_units, rebuild_sql_val, rebuild_sql_col, build_foreign_key_map_from_json, Evaluator, eval_exec_match
# Foreign-key maps, looked up per database id during evaluation; tables.json is presumably
# fetched from `bucket` because table_uri=True.
kmaps = build_foreign_key_map_from_json(
    "tables.json",
    table_uri=True,
    bucket=bucket,
)
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/amanchopra/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [7]:
def _read_lines(path, bucket=None, from_uri=False):
    """Return the text at ``path`` split on newline characters (lines are not stripped).

    When ``from_uri`` is true, ``path`` is a blob name inside the GCS ``bucket``;
    otherwise it is a local file path.

    Raises:
        FileNotFoundError: if ``from_uri`` is true and the blob does not exist.
    """
    if from_uri:
        blob = bucket.get_blob(path)
        if blob is None:
            # get_blob returns None for a missing object; report that clearly instead of
            # failing with an AttributeError on None.download_as_text().
            raise FileNotFoundError(f"blob {path!r} not found in bucket")
        text = blob.download_as_text()
    else:
        with open(path) as f:
            text = f.read()
    return text.split('\n')


def _empty_sql():
    """Return a fresh parsed-SQL dict with every clause empty.

    Stands in for a prediction that cannot be parsed so it is still scored against the
    gold query. A new dict is built per call so no state is shared between examples.
    """
    return {
        "except": None,
        "from": {
            "conds": [],
            "table_units": []
        },
        "groupBy": [],
        "having": [],
        "intersect": None,
        "limit": None,
        "orderBy": [],
        "select": [
            False,
            []
        ],
        "union": None,
        "where": []
    }


def evaluate(gold, predict, db_dir, kmaps, bucket=None, gold_uri=False, predict_uri=False):
    """Score predicted SQL against gold SQL by execution match and exact match.

    Args:
        gold: path (blob name when ``gold_uri``) of a file with one tab-separated
            "gold SQL, database id" line per example. Blank lines are ignored.
        predict: path (blob name when ``predict_uri``) of a file with one predicted SQL per
            line, in the same order as ``gold``. Only the text before the first tab is used.
            Blank lines count as empty (wrong) predictions; trailing blank lines are ignored.
        db_dir: directory containing one ``<database id>.sqlite`` file per database.
            NOTE(review): Spider-style layouts nest these as <db_dir>/<id>/<id>.sqlite — confirm.
        kmaps: foreign-key maps keyed by database id.
        bucket: GCS bucket used when ``gold_uri`` or ``predict_uri`` is true.
        gold_uri, predict_uri: read the corresponding file from ``bucket`` instead of disk.

    Returns:
        {"execute": fraction of examples whose execution check passes,
         "exact": fraction of examples that pass the exact-match check}.

    Raises:
        ValueError: if there are no (prediction, gold) pairs to score.
        FileNotFoundError: if a requested blob is missing from ``bucket``.
    """
    import warnings

    glist = [line.strip().split('\t') for line in _read_lines(gold, bucket, gold_uri) if line.strip()]

    # Blank prediction lines are KEPT so prediction i stays paired with gold i; dropping them
    # would shift every later prediction onto the wrong gold query. Only trailing blank lines
    # (such as the one after the final newline) are discarded.
    plines = _read_lines(predict, bucket, predict_uri)
    while plines and not plines[-1].strip():
        plines.pop()
    plist = [line.strip().split('\t')[0] for line in plines]

    if len(plist) != len(glist):
        warnings.warn(
            f"{len(plist)} predictions vs {len(glist)} gold queries; "
            f"only the first {min(len(plist), len(glist))} pairs are scored"
        )

    total = execute = exact = 0
    for p_str, (g_str, db_name) in zip(plist, glist):
        db = os.path.join(db_dir, db_name + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        total += 1

        if not p_str:
            # An empty prediction cannot match; it counts as wrong on both metrics.
            continue

        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            # Unparsable prediction: score an empty SQL against the gold query instead.
            p_sql = _empty_sql()

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        if eval_exec_match(db, p_str, g_str, p_sql, g_sql):
            execute += 1

        evaluator = Evaluator()
        if evaluator.eval_exact_match(p_sql, g_sql):
            exact += 1

    if total == 0:
        raise ValueError("no (prediction, gold) pairs to evaluate")
    return {"execute": execute / total, "exact": exact / total}

# Evaluate

In [8]:
import json
# CoSQL dev split used for SQL state tracking; later cells read each record's
# ['final']['query'], ['final']['utterance'] and ['database_id'].
dev_blob = bucket.get_blob('sql_state_tracking/cosql_dev.json')
lst = json.loads(dev_blob.download_as_text())
len(lst)

293

In [9]:
import os
# Checkpoint location of the t5_wikisql_cosql model, taken from the environment
# (a KeyError here means the variable is not set).
CKPT: str = os.environ['T5_WIKISQL_COSQL_CHECKPOINT_PATH']

In [10]:
# dev_gold_final: one "<gold SQL>\t<database_id>" line per dev example, the
# tab-separated gold format that evaluate() reads.
# Create the output directory first; open(..., "w") does not create missing directories.
os.makedirs("./predictions", exist_ok=True)
with open("./predictions/dev_gold_final.txt", "w") as f:
    for data in lst:
        f.write(data['final']["query"] + "\t" + data["database_id"] + "\n")

## t5_wikisql_cosql

In [24]:
# Local scratch directory for checkpoint files; makedirs also creates ../temp if it is missing
# and, with exist_ok=True, is a no-op when the directory already exists.
os.makedirs('../temp/ckpt', exist_ok=True)

ckpt_bucket = client.bucket(os.environ['MODEL_CHECKPOINT_BUCKET'])
# `prefix` must be passed by keyword: the first positional parameter of Bucket.list_blobs
# is `max_results`, so passing the path positionally would not filter by prefix at all.
blobs = ckpt_bucket.list_blobs(prefix=os.environ['T5_WIKISQL_COSQL_CHECKPOINT_PATH'])

In [28]:
# Iterator over the CoSQL bucket's blobs; displaying it shows the iterator object, not blob names.
cosql_blobs = bucket.list_blobs()
cosql_blobs

<google.api_core.page_iterator.HTTPIterator at 0x12ce0ad30>

In [7]:
# Fine-tuned model and its tokenizer, both loaded from the same checkpoint location.
model_tuned, tokenizer_tuned = (
    T5ForConditionalGeneration.from_pretrained(CKPT),
    AutoTokenizer.from_pretrained(CKPT),
)

In [17]:
from tqdm import tqdm

In [32]:
# One prediction per line, in the same order as `lst` (and hence the gold file), because
# evaluation pairs line i of both files. A newline or tab inside a generated query would
# break that pairing, so they are flattened to single spaces before writing.
with open(f"{CKPT}/predict_final.txt", "w") as f:
    for data in tqdm(lst):
        sql = translate_to_sql('translate to SQL: ' + data['final']["utterance"], tokenizer=tokenizer_tuned, model=model_tuned)
        f.write(' '.join(sql.splitlines()).replace('\t', ' ') + "\n")

100%|██████████| 293/293 [08:16<00:00,  1.70s/it]


In [33]:
# Cross-check with the command-line evaluation script, run as a subprocess over the same files.
# NOTE(review): the "huggingface/tokenizers ... forked" warning it prints can be silenced by setting
# TOKENIZERS_PARALLELISM=false; confirm PATH_TO_COSQL is defined earlier so the {...} placeholders expand.
!python evaluation.py --gold="{PATH_TO_COSQL}/sql_state_tracking/dev_gold_final.txt" --pred="{CKPT}/predict_final.txt" --db="{PATH_TO_COSQL}/database" --table="{PATH_TO_COSQL}/tables.json" --etype="all"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
eval_err_num:1
medium pred: SELECT count(*), T1.maker_full_name FROM car_model AS T1 JOIN maker AS T2 ON T1.car_model_id = T2.car_model_id GROUP BY T1.maker_full_name
medium gold: SELECT Count(*) ,  T2.FullName ,  T2.id FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker  =  T2.Id GROUP BY T2.id;

eval_err_num:2
medium pred: SELECT T1.Name, count(*) FROM singer AS T1 JOIN concert AS T2 ON T1.Song_ID = T2.Song_ID GROUP BY T1.Name
medium gold: SELECT T2.name ,  count(*) FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id  =  T2.singer_id GROUP BY T2.singer_id

eval_err_num:3
easy pred: SELECT Template_ID FROM Document_Templates GROUP BY Template_ID HAVING count(*) > 1
easy gold: SELECT template_

In [36]:
# t5_wikisql_cosql scores from this notebook's own evaluate().
# NOTE(review): confirm PATH_TO_COSQL is defined earlier in the notebook, and that this gold file
# matches ./predictions/dev_gold_final.txt, which is the gold file the notebook itself writes.
evaluate(
    gold=f"{PATH_TO_COSQL}/sql_state_tracking/dev_gold_final.txt",
    predict=f"{CKPT}/predict_final.txt",
    db_dir=f"{PATH_TO_COSQL}/database",
    kmaps=kmaps,
)

{'execute': 0.040955631399317405, 'exact': 0.040955631399317405}

## t5_small_cosql

In [21]:
CKPT = 't5_small_cosql'
# Fine-tuned model and tokenizer from the local t5_small_cosql checkpoint directory.
model_tuned, tokenizer_tuned = (
    T5ForConditionalGeneration.from_pretrained(CKPT),
    AutoTokenizer.from_pretrained(CKPT),
)

In [None]:
# One prediction per line, in `lst` order, so evaluation can pair line i with gold line i.
# tqdm replaces the manual print(i) counter, and newlines/tabs inside a generated query are
# flattened to spaces so a single prediction never spans or splits lines.
with open("t5_small_cosql/predict_final.txt", "w") as f:
    for data in tqdm(lst):
        sql = translate_to_sql('translate to SQL: ' + data["final"]["utterance"], tokenizer=tokenizer_tuned, model=model_tuned)
        f.write(' '.join(sql.splitlines()).replace('\t', ' ') + "\n")

In [22]:
# t5_small_cosql scores; relative paths resolve against the kernel's working directory.
evaluate(
    gold="dev_gold_final.txt",
    predict="t5_small_cosql/predict_final.txt",
    db_dir="database",
    kmaps=kmaps,
)

{'execute': 0.0, 'exact': 0.0}

## t5_small_spider_cosql

In [23]:
CKPT = 't5_small_spider_cosql'
# Fine-tuned model and tokenizer from the local t5_small_spider_cosql checkpoint directory.
model_tuned, tokenizer_tuned = (
    T5ForConditionalGeneration.from_pretrained(CKPT),
    AutoTokenizer.from_pretrained(CKPT),
)

In [None]:
# One prediction per line, in `lst` order, so evaluation can pair line i with gold line i.
# tqdm replaces the manual print(i) counter, and newlines/tabs inside a generated query are
# flattened to spaces so a single prediction never spans or splits lines.
with open("t5_small_spider_cosql/predict_final.txt", "w") as f:
    for data in tqdm(lst):
        sql = translate_to_sql('translate to SQL: ' + data["final"]["utterance"], tokenizer=tokenizer_tuned, model=model_tuned)
        f.write(' '.join(sql.splitlines()).replace('\t', ' ') + "\n")

In [24]:
# t5_small_spider_cosql scores.
# NOTE(review): this reads "gold_final.txt" while the t5_small_cosql cell reads "dev_gold_final.txt";
# confirm both sections are scored against the same gold file.
evaluate(
    gold="gold_final.txt",
    predict="t5_small_spider_cosql/predict_final.txt",
    db_dir="database",
    kmaps=kmaps,
)

{'execute': 0.040955631399317405, 'exact': 0.040955631399317405}