The following files are needed for this notebook:


1.   db/train.db
2.   db/test.db
3.   db/bad_indices.csv

Please update the paths in every cell that references these files so that they point to the location of the files in your own Google Drive.



In [None]:
# Mount Google Drive so the db/ files used below are reachable under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# IPython shell escape: install the `transformers` package into the Colab runtime.
!pip install transformers

In [None]:
# IPython shell escape: install `simpletransformers`, which provides the Seq2SeqModel used below.
!pip install simpletransformers

In [None]:
# IPython shell escape: install `datasets`, which provides load_dataset used below.
!pip install datasets

In [None]:
from datasets import load_dataset

# Load WikiSQL and report how many examples each split holds.
dataset = load_dataset("wikisql")
train_n, val_n, test_n = (len(dataset[split]) for split in ('train', 'validation', 'test'))
print(f"Train: {train_n}, Val: {val_n}, Test: {test_n}")

In [None]:
# Test the queries on train.db to check db connection
import sqlite3
from contextlib import closing
from pathlib import Path

train_db = Path('/content/drive/MyDrive/db/train.db')  # change this to your drive path
# sqlite3.connect() silently creates an empty database when the file is missing,
# which would make this check print [] instead of failing -- so check explicitly.
if not train_db.is_file():
  raise FileNotFoundError(f'train.db not found at {train_db}; update the path above')
# closing() releases the connection even if the query raises.
with closing(sqlite3.connect(train_db)) as conn:
  cur = conn.cursor()
  cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
  # Lists every table stored in the database file.
  print(cur.fetchall())

In [3]:
# Function for evaluation
import sqlite3
import re
import sys

def is_number(s):
    '''
    Return True if s parses as a finite number, otherwise False.

    parse_query() uses this to decide whether a condition value can stay an
    unquoted SQL literal. Only parseability matters: a truthiness test on the
    parsed value would misclassify zero-valued strings such as "0" or "0.0".
    Non-finite values ("nan", "inf") are rejected because, unquoted, SQLite
    would read them as identifiers rather than numbers.

    s - the value to test (normally a condition string); None or other
        unparseable input yields False instead of raising.
    '''
    import math

    try:
        value = float(s)
    except (TypeError, ValueError):
        return False
    return math.isfinite(value)

def evaluate(path_to_db: str, test_data: list, predictions: list, bad_indices: list, compare_rows: bool = False):
  '''
  Compute an execution-based score for predicted queries.

  For every sample whose index is not in bad_indices, both the ground-truth
  query and the prediction are rewritten with parse_query() (using the
  ground-truth sample's table/condition metadata for both) and executed
  against the SQLite database. A sample counts as correct when the results match.

  path_to_db - The path in the filesystem which has the db file. Example: /content/drive/MyDrive/db/test.db
  test_data - The test data list. This is the list contained in dataset['test']. This should contain all of the original keys.
  predictions - The list containing predictions by the model. The indices should match with the test data.
  bad_indices - List of indices in the test data which are to be skipped.
  compare_rows - If False (default; the original behaviour) a sample matches when both
                 queries return the same NUMBER of rows. If True, the returned rows must
                 be equal as a multiset (order-insensitive), a stricter check.

  Returns the fraction of non-skipped samples that matched (0.0 if none were evaluated).
  Samples whose ground-truth query fails to execute still count in the denominator.
  Raises sqlite3.Error if the database cannot be opened.
  '''
  from collections import Counter

  try:
    conn = sqlite3.connect(path_to_db)
  except sqlite3.Error as e:
    print(f'Error connecting to db: {e}')
    raise
  skipped = set(bad_indices)  # O(1) membership instead of scanning a list per sample
  count = 0
  evaluated = 0
  try:
    cur = conn.cursor()
    for k, sample in enumerate(test_data):
      if k in skipped:
        continue
      evaluated += 1
      ground_truth = sample['sql']['human_readable']
      prediction = predictions[k]
      # Both queries are canonicalised with the ground truth's metadata.
      parse_kwargs = {
          'column_index': sample['sql']['conds']['column_index'],
          'column_names': sample['table']['header'],
          'conditions': sample['sql']['conds']['condition'],
          'agg': sample['sql']['agg'],
          'table_id': sample['table']['id'],
      }
      parsed_gt = parse_query(ground_truth, **parse_kwargs)
      parsed_prediction = parse_query(prediction, **parse_kwargs)
      try:
        cur.execute(parsed_gt)
        gt_rows = cur.fetchall()
      except Exception:
        # Ground truth that cannot execute is skipped but still counts as a miss.
        continue
      try:
        cur.execute(parsed_prediction)
        predicted_rows = cur.fetchall()
      except Exception:
        # Model output is arbitrary text; any execution failure is a wrong prediction.
        # (Exception, unlike a bare except, still lets KeyboardInterrupt stop the loop.)
        print(f"Execution failed for index {k}, GT: {ground_truth}, Predicted: {prediction}")
        continue
      if compare_rows:
        matched = Counter(gt_rows) == Counter(predicted_rows)
      else:
        matched = len(gt_rows) == len(predicted_rows)
      if matched:
        count += 1
  finally:
    conn.close()
  return count / evaluated if evaluated else 0.0

def parse_query(query:str, column_index:list, column_names:list, conditions:list, agg:int, table_id:str) -> str:
  '''
  Rewrite a human-readable query into SQL that evaluate() executes on the db.

  Steps, in order:
    1. strip all double quotes from the query;
    2. wrap every non-numeric condition value in double quotes;
    3. replace each column name with 'col<i>' (i = its position in
       column_names), longest names first;
    4. replace 'table' with 'table_<table_id>' ('-' mapped to '_');
    5. if agg != 0, turn 'AGG col' into 'AGG(col)'.

  query - the query to be parsed
  column_index - data['sql']['conds']['column_index']
                 (NOTE(review): unused -- the name is rebound by the column loop below)
  column_names - data['table']['header']
  conditions - data['sql']['conds']['condition']
  agg - data['sql']['agg']; index into agg_ops below (0 means no aggregation)
  table_id - data['table']['id']

  Returns the rewritten query string.
  '''
  agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
  # Drop existing double quotes; the condition-quoting step below adds its own.
  query = query.replace('"', '')
  # NOTE(review): presumably the db names each table 'table_<id>' with '-' -> '_'; confirm against test.db.
  table_name = 'table_'+table_id.replace('-','_')
  column_to_index = {}
  for i, column in enumerate(column_names):
    column_to_index[column] = str(i)
  # Rebuild the mapping ordered by descending name length so a longer column
  # name is replaced before any shorter name contained in it.
  sorted_column_to_index = {}
  for i in sorted(column_to_index, key=len, reverse=True):
      sorted_column_to_index[i] = column_to_index[i]
  column_to_index = sorted_column_to_index
  # print(column_to_index)
  seen = []
  for condition in conditions:
    # print(f'Condition: {condition}, Query: {query}')
    # NOTE(review): membership is tested on the raw value but the quote-stripped
    # value is appended, so dedup can miss values that contained '"'.
    if condition in seen:
      continue
    condition = condition.replace('"', '')
    if not is_number(condition):
      # Plain substring replace: quotes EVERY occurrence of this text in the
      # query, including inside other words, not only the condition literal.
      query = query.replace(condition, '"'+condition+'"')
      # query = re.sub(r"\b%s\b" % condition , '"'+condition+'"', query)
    seen.append(condition)
  # Replace column names with positional 'colN' identifiers.
  # NOTE(review): this also rewrites matching text inside quoted condition values.
  for column, column_index in column_to_index.items():
    query = query.replace(column, 'col'+column_index)
    # query = re.sub(r"\b%s\b" % column, 'col'+column_index, query)
  # NOTE(review): replaces every lowercase 'table' substring (e.g. inside a quoted
  # value such as "vegetable"), not only the FROM target.
  query = query.replace('table', table_name)
  agg_op = agg_ops[agg]
  if agg != 0:
    # Wrap the token after the aggregate keyword in parentheses,
    # e.g. 'SELECT MAX col2 FROM ...' -> 'SELECT MAX(col2) FROM ...'.
    # NOTE(review): if agg_op does not occur in query, find() returns -1 and
    # query[idx:end] is usually '' -- str.replace('', ...) then inserts the
    # replacement between every character, producing invalid SQL.
    idx = query.find(agg_op)
    start = idx + len(agg_op) + 1  # skip the keyword and the following space
    end = start
    while end < len(query) and query[end] != ' ':
      end += 1
    agg_columns = query[start:end]
    # replace() rewrites every occurrence of this span, not just the first one.
    query = query.replace(query[idx:end], agg_op+'('+agg_columns+')')
  return query


In [None]:
import csv
import numpy as np

# CSV listing the test-set indices that evaluate() should skip.
# change this to your drive path
bad_indices_csv = '/content/drive/MyDrive/db/bad_indices.csv'
with open(bad_indices_csv, newline='') as csv_file:
    rows = csv.reader(csv_file)
    header = next(rows)  # the first line is a header row
    # The index is stored in the second column; blank lines are ignored.
    bad_indices = [int(row[1]) for row in rows if len(row) > 0]
print(len(bad_indices))

In [5]:
path_to_db = '/content/drive/MyDrive/db/test.db' # change this to your drive path
# Evaluation runs over the full test split.
test_data = dataset['test']

In [6]:
from simpletransformers.seq2seq import Seq2SeqModel

# use this for BART models:
bart_model_dir = "/content/drive/MyDrive/cs685_models/BART-INTER-FINE-TUNED/final_outputs_bart_int_50"  # add your trained model path here
model = Seq2SeqModel(
    encoder_decoder_type='bart',
    encoder_decoder_name=bart_model_dir,
    use_cuda=True,
)

# use this for other models (separate encoder / decoder checkpoints):
# model = Seq2SeqModel(
#     encoder_type='distilbert',
#     encoder_name="/content/drive/MyDrive/cs685_models/DISTILBERT-BERT/final_outputs_distilbert/encoder",
#     decoder_name="/content/drive/MyDrive/cs685_models/DISTILBERT-BERT/final_outputs_distilbert/decoder",
#     use_cuda=True,
# )

In [7]:
# Get the predictions from the loaded model

# Ground-truth SQL (human-readable form) for every test example, in dataset order.
# NOTE(review): this same list is also the model input below; a text-to-SQL model
# would usually be fed the natural-language question instead -- confirm intended.
queries = [sample['sql']['human_readable'] for sample in dataset['test']]
predictions = model.predict(queries)
assert len(predictions) == len(queries)

In [None]:
# A sanity check for the predictions: print one random GT/prediction pair.

import random
# Draw over the actual list length rather than a fixed bound, so every example
# is reachable and a shorter list cannot raise IndexError.
if queries:
    index = random.randrange(len(queries))
    print(f"GT: {queries[index]}, Prediction: {predictions[index]}")
else:
    print("No predictions to inspect.")

In [None]:
# Get the execution score (fraction of non-skipped test samples judged correct)
eval_score = evaluate(
    path_to_db=path_to_db,
    test_data=test_data,
    predictions=predictions,
    bad_indices=bad_indices,
)

In [None]:
print("Score:", eval_score)

In [None]:
# Check if the ground truth and predicted queries match

def compare_queries(queries: list, predictions: list):
  '''
  Return the fraction of positions where the prediction exactly equals the
  ground-truth query string.

  queries - ground-truth query strings
  predictions - predicted query strings, aligned index-by-index with queries

  Returns 0.0 when both lists are empty.
  Raises ValueError if the two lists have different lengths.
  '''
  if len(queries) != len(predictions):
    raise ValueError("Lengths don't match")
  if not queries:
    return 0.0
  matches = sum(gt == pred for gt, pred in zip(queries, predictions))
  return matches / len(queries)
score = compare_queries(queries, predictions)  # exact string-match rate
print("Score:", score)