In [128]:
!pip install sqlparse



In [129]:
!pip install sql_metadata



In [130]:
import sqlparse

def extract_table_and_columns(sql):
    parsed = sqlparse.parse(sql)[0]
    tables = set()
    columns = set()

    print(parsed.tokens)

    for token in parsed.tokens:
        print(token.ttype)
        if token.ttype is sqlparse.tokens.Name or token.ttype in (sqlparse.tokens.Wildcard,):
            columns.add(str(token))
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                if isinstance(identifier, sqlparse.sql.Identifier):
                    value = str(identifier)
                    if '.' in value:
                        table, column = value.split('.')
                        tables.add(table)
                        columns.add(column)
                    else:
                        tables.add(value)

    return tables, columns

In [131]:
sql1 = "SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15"
sql2 = "SELECT name FROM head WHERE born_state != 'California'"

tables1, columns1 = extract_table_and_columns(sql1)
#tables2, columns2 = extract_table_and_columns(sql2)

print("First Query - Tables:", tables1, "Columns:", columns1)
#print("Second Query - Tables:", tables2, "Columns:", columns2)

[<DML 'SELECT' at 0x7E5E729EF040>, <Whitespace ' ' at 0x7E5E729EEE60>, <Function 'avg(nu...' at 0x7E5E72983BC0>, <Whitespace ' ' at 0x7E5E729EC160>, <Keyword 'FROM' at 0x7E5E729EF400>, <Whitespace ' ' at 0x7E5E729EE020>, <Identifier 'depart...' at 0x7E5E72980B30>, <Whitespace ' ' at 0x7E5E729ED240>, <Where 'WHERE ...' at 0x7E5E72982E30>]
Token.Keyword.DML
Token.Text.Whitespace
None
Token.Text.Whitespace
Token.Keyword
Token.Text.Whitespace
None
Token.Text.Whitespace
None
First Query - Tables: set() Columns: set()


In [132]:
from sql_metadata import Parser

Parser("SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15").columns
Parser("SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15").tables

['department']

In [133]:
query = "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'"

print(Parser(query).columns)
print(Parser(query).tables)

['department.creation', 'department.department_id', 'management.department_id', 'management.head_id', 'head.head_id', 'head.born_state']
['department', 'management', 'head']


In [134]:
query = "SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1"

print(Parser(query).columns)
print(Parser(query).tables)

['creation']
['department']


In [135]:
query = "SELECT T1.name ,  T1.num_employees FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id WHERE T2.temporary_acting  =  'Yes'"

print(Parser(query).columns)
print(Parser(query).tables)

['department.name', 'department.num_employees', 'department.department_id', 'management.department_id', 'management.temporary_acting']
['department', 'management']


In [136]:
query = "SELECT count(*) FROM department WHERE department_id NOT IN (SELECT department_id FROM management)"

print(Parser(query).columns)
print(Parser(query).tables)

['department_id']
['department', 'management']


In [137]:
query = "SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T1.name  =  'Treasury' INTERSECT SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T1.name  =  'Homeland Security'"

print(Parser(query).columns)
print(Parser(query).tables)

question = "List the states where both the secretary of 'Treasury' department and the secretary of 'Homeland Security' were born."

['head.born_state', 'department.department_id', 'management.department_id', 'management.head_id', 'head.head_id', 'department.name']
['department', 'management', 'head']


In [138]:
query = "SELECT name FROM enzyme WHERE name LIKE \"%ALA%\""

print(Parser(query).columns)
print(Parser(query).tables)


['name', '%ALA%']
['enzyme']


In [139]:
query = "SELECT T1.department_id ,  T1.name ,  count(*) FROM management AS T2 JOIN department AS T1 ON T1.department_id  =  T2.department_id GROUP BY T1.department_id HAVING count(*)  >  1"

print(Parser(query).columns)
print(Parser(query).tables)

['department.department_id', 'department.name', 'management.department_id']
['management', 'department']


In [140]:
def calculate_accuracy(predicted_schema_tags, true_query):
  parsed_query = Parser(true_query)
  query_columns = parsed_query.columns
  query_tables = parsed_query.tables

  if len(query_tables) == 1:
    table = query_tables[0]
    query_columns = [f"{table.lower()}.{col.lower()}" for col in query_columns]

  query_tables_columns = query_columns + query_tables

  query_tables_columns = [item.lower() for item in query_tables_columns if "%" not in item]
  print(f"Query: {query_tables_columns}")

  pure_predicted_schema_tags = [tag for tag in predicted_schema_tags if tag != 'O' and not tag.startswith("O.")]
  pure_predicted_schema_tags_tables = [tag.split(".")[0] for tag in pure_predicted_schema_tags]
  pure_predicted_schema_tags_tables_columns = pure_predicted_schema_tags_tables + pure_predicted_schema_tags

  pure_predicted_schema_tags_tables_columns = list(set(pure_predicted_schema_tags_tables_columns))

  pure_predicted_schema_tags_tables_columns = [item.lower() for item in pure_predicted_schema_tags_tables_columns]
  print(f"Schema Tags: {pure_predicted_schema_tags_tables_columns}")

  correctly_found_tags = 0
  for query_item in query_tables_columns:
    if query_item in pure_predicted_schema_tags_tables_columns:
      correctly_found_tags += 1


  return correctly_found_tags / max(len(pure_predicted_schema_tags_tables_columns), len(query_tables_columns))

In [141]:
def calculate_table_accuracy(predicted_schema_tags, true_query):
  parsed_query = Parser(true_query)
  query_tables = parsed_query.tables

  query_tables = [item.lower() for item in query_tables]
  print(f"Query Tables: {query_tables}")

  pure_predicted_schema_tags_tables = [tag.split(".")[0] for tag in predicted_schema_tags]
  pure_predicted_schema_tags_tables = [tag for tag in pure_predicted_schema_tags_tables if tag != 'O']

  pure_predicted_schema_tags_tables = list(set(pure_predicted_schema_tags_tables))

  pure_predicted_schema_tags_tables = [item.lower() for item in pure_predicted_schema_tags_tables]
  print(f"Schema Tags Tables: {pure_predicted_schema_tags_tables}")

  correctly_found_tags = 0
  for query_item in query_tables:
    if query_item in pure_predicted_schema_tags_tables:
      correctly_found_tags += 1

  max_length = max(len(pure_predicted_schema_tags_tables), len(query_tables))
  if max_length == 0:
    return 0

  return correctly_found_tags / max_length

In [142]:
def calculate_column_accuracy(predicted_schema_tags, true_query):
  parsed_query = Parser(true_query)
  query_columns = parsed_query.columns
  query_tables = parsed_query.tables

  if len(query_tables) == 1:
    table = query_tables[0]
    query_columns = [f"{table}.{col}" for col in query_columns]

  query_columns = [item.lower() for item in query_columns]
  print(f"Query Columns: {query_columns}")

  pure_predicted_schema_tags_columns = [tag for tag in predicted_schema_tags if tag != 'O' and "." in tag]
  pure_predicted_schema_tags_columns = list(set(pure_predicted_schema_tags_columns))

  pure_predicted_schema_tags_columns = [item.lower() for item in pure_predicted_schema_tags_columns]
  print(f"Schema Tag Columns: {pure_predicted_schema_tags_columns}")

  correctly_found_tags = 0
  for query_item in query_columns:
    if query_item in pure_predicted_schema_tags_columns:
      correctly_found_tags += 1

  max_length = max(len(pure_predicted_schema_tags_columns), len(query_columns))
  if max_length == 0:
    return 0

  return correctly_found_tags / max_length

In [143]:
query = "SELECT head_id ,  name FROM head WHERE name LIKE '%Ha%'"
predicted_tags = ['head.head_id', 'head.name', 'head', 'department']

acc = calculate_accuracy(predicted_tags, query)
print(acc)

Query: ['head.head_id', 'head.name', 'head']
Schema Tags: ['head', 'head.name', 'department', 'head.head_id']
0.75


In [144]:
from google.colab import drive
import json

drive.mount('/content/drive')

spider_predicted_tags_path_1 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_0-1000.txt'
spider_predicted_tags_path_2 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_1000-2000.txt'
spider_predicted_tags_path_3 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_2000-3000.txt'
spider_predicted_tags_path_4 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_3000-4000.txt'
spider_predicted_tags_path_5 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_4000-5000.txt'
spider_predicted_tags_path_6 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_5000-6000.txt'
spider_predicted_tags_path_7 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_6000-7000.txt'
spider_predicted_tags_path_8 = '/content/drive/MyDrive/CS559_shared/implementation/13_spider_shema_tags_7000-8000.txt'
spider_train_path = '/content/drive/MyDrive/CS559_shared/implementation/spider/train_spider.json'
spider_train_others_path = '/content/drive/MyDrive/CS559_shared/implementation/spider/train_others.json'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [145]:
spider_train = open(spider_train_path)
spider_train_list = json.load(spider_train)

spider_others = open(spider_train_others_path)
spider_others_list = json.load(spider_others)

In [146]:
all_spider_questions = spider_train_list + spider_others_list

In [147]:
all_spider_questions[0]

{'db_id': 'department_management',
 'query': 'SELECT count(*) FROM head WHERE age  >  56',
 'query_toks': ['SELECT',
  'count',
  '(',
  '*',
  ')',
  'FROM',
  'head',
  'WHERE',
  'age',
  '>',
  '56'],
 'query_toks_no_value': ['select',
  'count',
  '(',
  '*',
  ')',
  'from',
  'head',
  'where',
  'age',
  '>',
  'value'],
 'question': 'How many heads of the departments are older than 56 ?',
 'question_toks': ['How',
  'many',
  'heads',
  'of',
  'the',
  'departments',
  'are',
  'older',
  'than',
  '56',
  '?'],
 'sql': {'from': {'table_units': [['table_unit', 1]], 'conds': []},
  'select': [False, [[3, [0, [0, 0, False], None]]]],
  'where': [[False, 3, [0, [0, 10, False], None], 56.0, None]],
  'groupBy': [],
  'having': [],
  'orderBy': [],
  'limit': None,
  'intersect': None,
  'union': None,
  'except': None}}

In [148]:
import re
import numpy as np

def retrieve_data_and_tag(path):
    file = open(path, 'r')
    sentences, sentence_tags = [], []
    for line in file:
        # line = re.sub(r"\".*\"","\"xxx\"",line)
        line = re.sub(r"'", "", line)
        word_tag_list = line.split()
        words = []
        tags = []
        for w in word_tag_list:
            item = w.split('/')
            if (len(item[0]) == 0):
                continue

            words.append(item[0])
            if (len(item) < 2):
                print(path, item)
            tags.append(item[1])
        sentences.append(words)
        sentence_tags.append(tags)
        if (len(sentences) != len(sentence_tags)):
            print(line)
    return sentences, sentence_tags

In [149]:
spider_questions_1, spider_predicted_tags_1 = retrieve_data_and_tag(spider_predicted_tags_path_1)
spider_questions_2, spider_predicted_tags_2 = retrieve_data_and_tag(spider_predicted_tags_path_2)
spider_questions_3, spider_predicted_tags_3 = retrieve_data_and_tag(spider_predicted_tags_path_3)
spider_questions_4, spider_predicted_tags_4 = retrieve_data_and_tag(spider_predicted_tags_path_4)
spider_questions_5, spider_predicted_tags_5 = retrieve_data_and_tag(spider_predicted_tags_path_5)
spider_questions_6, spider_predicted_tags_6 = retrieve_data_and_tag(spider_predicted_tags_path_6)
spider_questions_7, spider_predicted_tags_7 = retrieve_data_and_tag(spider_predicted_tags_path_7)
spider_questions_8, spider_predicted_tags_8 = retrieve_data_and_tag(spider_predicted_tags_path_8)

In [150]:
#all_spider_questions = spider_questions_1 + spider_questions_2 + spider_questions_3 + spider_questions_4 + spider_questions_5 + spider_questions_6 + spider_questions_7 + spider_questions_8
all_spider_predicted_tags = spider_predicted_tags_1 + spider_predicted_tags_2 + spider_predicted_tags_3 + spider_predicted_tags_4 + spider_predicted_tags_5 + spider_predicted_tags_6 + spider_predicted_tags_7 + spider_predicted_tags_8

In [151]:
print(all_spider_questions[0])
print(all_spider_predicted_tags[0])
print(len(all_spider_questions))

{'db_id': 'department_management', 'query': 'SELECT count(*) FROM head WHERE age  >  56', 'query_toks': ['SELECT', 'count', '(', '*', ')', 'FROM', 'head', 'WHERE', 'age', '>', '56'], 'query_toks_no_value': ['select', 'count', '(', '*', ')', 'from', 'head', 'where', 'age', '>', 'value'], 'question': 'How many heads of the departments are older than 56 ?', 'question_toks': ['How', 'many', 'heads', 'of', 'the', 'departments', 'are', 'older', 'than', '56', '?'], 'sql': {'from': {'table_units': [['table_unit', 1]], 'conds': []}, 'select': [False, [[3, [0, [0, 0, False], None]]]], 'where': [[False, 3, [0, [0, 10, False], None], 56.0, None]], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None, 'intersect': None, 'union': None, 'except': None}}
['O', 'O', 'department', 'O', 'O', 'department', 'O', 'O', 'O', 'department', 'O']
8659


In [152]:
all_accuracies = []
for idx in range(len(all_spider_predicted_tags)):

  spider_question = all_spider_questions[idx]["question"]
  spider_question_query = all_spider_questions[idx]["query"]

  spider_question_predicted_tags = all_spider_predicted_tags[idx]

  print(f"Question: {spider_question}")
  accuracy = calculate_accuracy(spider_question_predicted_tags, spider_question_query)

  all_accuracies.append(accuracy)

avg_acc = np.mean(all_accuracies)
print(f"Average accuracy: {avg_acc}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Schema Tags: ['city']
Question: Find the payment method code used by more than 3 parties.
Query: ['parties.payment_method_code', 'parties']
Schema Tags: ['party', 'parties']
Question: What are the payment method codes that have been used by more than 3 parties?
Query: ['parties.payment_method_code', 'parties']
Schema Tags: ['method', 'parties', 'parties.payment', 'code', 'payment']
Question: Find the name of organizations whose names contain "Party".
Query: ['organizations.organization_name', 'organizations']
Schema Tags: ['organization']
Question: What are the names of organizations that contain the word "Party"?
Query: ['organizations.organization_name', 'organizations']
Schema Tags: ['organization']
Question: How many distinct payment methods are used by parties?
Query: ['parties.payment_method_code', 'parties']
Schema Tags: ['parties', 'method', 'payment']
Question: Count the number of different payment method codes u

In [153]:
all_accuracies = []
for idx in range(len(all_spider_predicted_tags)):

  spider_question = all_spider_questions[idx]["question"]
  spider_question_query = all_spider_questions[idx]["query"]

  spider_question_predicted_tags = all_spider_predicted_tags[idx]

  print(f"Question: {spider_question}")
  accuracy = calculate_table_accuracy(spider_question_predicted_tags, spider_question_query)
  print(f"Acc: {accuracy}")

  all_accuracies.append(accuracy)

avg_acc = np.mean(all_accuracies)
print(f"Average Table Accuracy: {avg_acc}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Query Tables: ['faculty', 'student']
Schema Tags Tables: ['advisor']
Acc: 0.0
Question: Show the ids of students whose advisors are professors.
Query Tables: ['faculty', 'student']
Schema Tags Tables: ['advisor', 'professor', 'student']
Acc: 0.3333333333333333
Question: Which students have professors as their advisors? Find their student ids.
Query Tables: ['faculty', 'student']
Schema Tags Tables: ['student']
Acc: 0.5
Question: Show first name and last name for all the students advised by Michael Goodrich.
Query Tables: ['faculty', 'student']
Schema Tags Tables: ['student']
Acc: 0.5
Question: Which students are advised by Michael Goodrich? Give me their first and last names.
Query Tables: ['faculty', 'student']
Schema Tags Tables: ['student']
Acc: 0.5
Question: Show the faculty id of each faculty member, along with the number of students he or she advises.
Query Tables: ['faculty', 'student']
Schema Tags Tables: ['studen

In [154]:
all_accuracies = []
for idx in range(len(all_spider_predicted_tags)):

  spider_question = all_spider_questions[idx]["question"]
  spider_question_query = all_spider_questions[idx]["query"]

  spider_question_predicted_tags = all_spider_predicted_tags[idx]

  print(f"Question: {spider_question}")
  accuracy = calculate_column_accuracy(spider_question_predicted_tags, spider_question_query)
  print(f"Acc: {accuracy}")

  all_accuracies.append(accuracy)

avg_acc = np.mean(all_accuracies)
print(f"Average Column Accuracy: {avg_acc}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Query Columns: ['faculty.fname', 'faculty.lname', 'faculty.facid', 'student.advisor', 'student.fname', 'linda', 'student.lname', 'smith']
Schema Tag Columns: []
Acc: 0.0
Question: Show the ids of students whose advisors are professors.
Query Columns: ['student.stuid', 'faculty.facid', 'student.advisor', 'faculty.rank', 'professor']
Schema Tag Columns: ['o.lname']
Acc: 0.0
Question: Which students have professors as their advisors? Find their student ids.
Query Columns: ['student.stuid', 'faculty.facid', 'student.advisor', 'faculty.rank', 'professor']
Schema Tag Columns: []
Acc: 0.0
Question: Show first name and last name for all the students advised by Michael Goodrich.
Query Columns: ['student.fname', 'student.lname', 'faculty.facid', 'student.advisor', 'faculty.fname', 'michael', 'faculty.lname', 'goodrich']
Schema Tag Columns: []
Acc: 0.0
Question: Which students are advised by Michael Goodrich? Give me their first and