<a href="https://colab.research.google.com/github/YichengShen/cis5220-project/blob/main/cis5220_final_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text-to-SQL

Team: Query Marksman

## Section 1: Setup

### Install & imports

In [1]:
!pip install nltk

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import shutil

import re
import io
import json
import numpy as np
import os
from typing import List, Dict, Tuple, Any, Union

Mount Drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Load data into Colab notebook

Before running the code below, download the Spider dataset from [here](https://yale-lily.github.io/spider) and upload its zip file (`spider.zip`) to your Google Drive.

Copy data from Drive into the current runtime

In [4]:
# Change this path to where you store spider.zip in your Drive
dataset_zip_path_in_drive = "/content/drive/MyDrive/CIS5220_final_project/spider.zip"
dataset_zip_path_in_runtime = "/content/data/spider.zip"

# Create the destination folder if it does not exist. It is derived from the
# destination path (instead of `mkdir -p data`, which is relative to the current
# working directory) so the copy below cannot fail on a missing directory.
os.makedirs(os.path.dirname(dataset_zip_path_in_runtime), exist_ok=True)

shutil.copy(dataset_zip_path_in_drive, dataset_zip_path_in_runtime)

'/content/data/spider.zip'

Unzip

In [5]:
# Extract the copied archive into /content/data/ (-q: quiet, -o: overwrite existing
# files without prompting, so re-running this cell does not stop at a prompt).
!unzip -q -o /content/data/spider.zip -d /content/data/

## Section 2: Data Preparation & Cleaning (Milestone 2)

### Helper Functions

In [6]:
def _cond_value(cond_unit: List) -> Any:
    """Return a condition's right-hand side: ``val1``, or ``[val1, val2]`` when a second value is set."""
    # The second value is only populated for two-operand operators (presumably BETWEEN — confirm).
    if cond_unit[4] is not None:
        return [cond_unit[3], cond_unit[4]]
    return cond_unit[3]


def _select_parts(select_clause: List) -> Tuple[List, List]:
    """Return parallel (aggregator ids, column ids) lists for a parsed SELECT clause."""
    # select_clause[1] holds one entry per selected item: entry[0] is the aggregator
    # id and entry[1][1][1] is the column id of the item's first column unit.
    items = select_clause[1]
    return [item[0] for item in items], [item[1][1][1] for item in items]


def _where_parts(where_clause: List) -> Tuple[List[List], List]:
    """Split a parsed WHERE clause into ``[col_id, op_id, value]`` triples and conjunctions."""
    # The clause alternates condition, conjunction, condition, ... so even positions
    # are conditions and odd positions are the joining 'and'/'or' tokens.
    # NOTE(review): cond[0] (Spider's not_op flag) is never recorded, so e.g. IN and
    # NOT IN yield identical triples — confirm this is acceptable for the model.
    conds = [[cond[2][1][1], cond[1], _cond_value(cond)] for cond in where_clause[::2]]
    conjs = list(where_clause[1::2])
    return conds, conjs


def _group_parts(parsed_sql: Dict) -> List:
    """Return GROUP BY column ids followed by exactly one HAVING entry ([] when there is no HAVING)."""
    group = [col_unit[1] for col_unit in parsed_sql['groupBy']]
    having_cond = []
    if parsed_sql['having']:
        first = parsed_sql['having'][0]  # only the first HAVING condition is kept
        col_unit = first[2][1]
        # [aggregator, column, operator, value]
        having_cond = [col_unit[0], col_unit[1], first[1], _cond_value(first)]
    group.append(having_cond)
    return group


def _order_parts(order_clause: List) -> List:
    """Return ``[aggregator ids, column ids, direction]`` for a parsed ORDER BY clause.

    ``direction`` is 1 for 'asc', 0 for any other direction, and -1 when the query
    has no ORDER BY.
    """
    if not order_clause:
        return [[], [], -1]
    direction, val_units = order_clause[0], order_clause[1]
    aggs = [val_unit[1][0] for val_unit in val_units]
    cols = [val_unit[1][1] for val_unit in val_units]
    return [aggs, cols, 1 if direction == 'asc' else 0]


def _special_code(parsed_sql: Dict) -> int:
    """Return 1/2/3 for an INTERSECT/EXCEPT/UNION sub-query (checked in that order), else 0."""
    for code, key in ((1, 'intersect'), (2, 'except'), (3, 'union')):
        if parsed_sql[key] is not None:
            return code
    return 0


def process(sql_data: List[Dict], 
            table_data: List[Dict]) -> Tuple[List[Dict], Dict[str, Dict]]:
    """Flatten Spider-format examples and schemas into the compact layout used in this notebook.

    Args:
        sql_data: Example dicts with 'question', 'question_toks', 'query',
            'query_toks', 'db_id' and a parsed 'sql' dict (Spider's format).
        table_data: Schema dicts with at least 'db_id' and 'column_names'.

    Returns:
        ``(output_sql, output_tab)`` where:

        * ``output_tab`` maps each ``db_id`` to ``{'col_map': column_names}``
          (a later schema with the same ``db_id`` replaces an earlier one).
        * each ``output_sql`` entry has keys 'question', 'question_tok', 'query',
          'query_tok', 'table_id' and 'sql'. ``sql`` in turn holds:

          - 'agg' / 'sel': parallel lists of aggregator and column ids.
          - 'cond': ``[col_id, op_id, value]`` triples.
          - 'conj': the joining 'and'/'or' tokens.
          - 'group': group-by column ids plus one HAVING entry.
          - 'order': ``[aggs, cols, direction]``.
          - 'special': 0 none / 1 intersect / 2 except / 3 union.

        LIMIT and FROM are not extracted.
    """
    output_tab = {table['db_id']: {'col_map': table['column_names']} for table in table_data}

    output_sql = []
    for example in sql_data:
        parsed = example['sql']
        aggs, sels = _select_parts(parsed['select'])
        conds, conjs = _where_parts(parsed['where'])
        sql_temp = {
            'agg': aggs,
            'sel': sels,
            'cond': conds,
            'conj': conjs,
            'group': _group_parts(parsed),
            'order': _order_parts(parsed['orderBy']),
            'special': _special_code(parsed),
        }
        output_sql.append({
            'question': example['question'],
            'question_tok': example['question_toks'],
            'query': example['query'],
            'query_tok': example['query_toks'],
            'table_id': example['db_id'],
            'sql': sql_temp,
        })
    return output_sql, output_tab

In [7]:
def load_data_new(sql_paths: Union[str, List[str]], 
                  table_paths: Union[str, List[str]], 
                  use_small: bool = False) -> Tuple[List[Dict], Dict[str, Dict]]:
    """Read example and schema JSON files and flatten them with ``process``.

    Args:
        sql_paths: A single path, or a list/tuple of paths, to JSON files that
            each contain a list of examples (e.g. ``train_spider.json``).
        table_paths: A single path, or a list/tuple of paths, to schema JSON
            files (e.g. ``tables.json``). Schemas from every file are combined.
        use_small: If True, read at most the first two files of each kind.

    Returns:
        The ``(examples, schemas_by_db_id)`` pair produced by ``process``.
    """
    def _as_path_list(paths) -> List:
        # Wrap a single path; accept tuples as well as lists (a tuple must not be
        # wrapped again, or it would be handed to ``open`` as a single "path").
        if isinstance(paths, (list, tuple)):
            return list(paths)
        return [paths]

    def _read_json_lists(paths) -> List:
        # Concatenate the JSON lists stored in ``paths``, capped at two files when use_small.
        limit = 2 if use_small else None
        merged = []
        for path in _as_path_list(paths)[:limit]:
            print(f"Loading data from {path}")
            # Explicit UTF-8 so loading does not depend on the platform's default encoding.
            with open(path, encoding="utf-8") as inf:
                merged += json.load(inf)
        return merged

    sql_data = _read_json_lists(sql_paths)
    # Accumulate across files so every schema file contributes, not just the last one read.
    table_data = _read_json_lists(table_paths)
    return process(sql_data, table_data)

### Load Clean Data

In [10]:
# Training split plus the shared schema file; `table_data` maps db_id -> {'col_map': ...}.
sql_data_train, table_data = load_data_new(
    ["/content/data/spider/train_spider.json"],
    ["/content/data/spider/tables.json"],
    use_small=False,
)

Loading data from /content/data/spider/train_spider.json
Loading data from /content/data/spider/tables.json


In [11]:
# Dev split. tables.json is the same schema file as above, so `table_data` is
# simply reloaded with the same schemas.
sql_data_dev, table_data = load_data_new(
    ["/content/data/spider/dev.json"],
    ["/content/data/spider/tables.json"],
    use_small=False,
)

Loading data from /content/data/spider/dev.json
Loading data from /content/data/spider/tables.json


## Section 3: EDA

### SQL Data

In [26]:
print(f"Number of training data: {len(sql_data_train)}")
print(f"Number of eval data: {len(sql_data_dev)}")

Number of training data: 7000
Number of eval data: 1034


A single processed training example looks like this:

In [21]:
sql_data_train[0]

{'question': 'How many heads of the departments are older than 56 ?',
 'question_tok': ['How',
  'many',
  'heads',
  'of',
  'the',
  'departments',
  'are',
  'older',
  'than',
  '56',
  '?'],
 'query': 'SELECT count(*) FROM head WHERE age  >  56',
 'query_tok': ['SELECT',
  'count',
  '(',
  '*',
  ')',
  'FROM',
  'head',
  'WHERE',
  'age',
  '>',
  '56'],
 'table_id': 'department_management',
 'sql': {'agg': [3],
  'sel': [0],
  'cond': [[10, 3, 56.0]],
  'conj': [],
  'group': [[]],
  'order': [[], [], -1],
  'special': 0}}

### Database Schema Data

In [30]:
len(table_data)

166

In [29]:
table_data['yelp']

{'col_map': [[-1, '*'],
  [0, 'bid'],
  [0, 'business id'],
  [0, 'name'],
  [0, 'full address'],
  [0, 'city'],
  [0, 'latitude'],
  [0, 'longitude'],
  [0, 'review count'],
  [0, 'is open'],
  [0, 'rating'],
  [0, 'state'],
  [1, 'id'],
  [1, 'business id'],
  [1, 'category name'],
  [2, 'uid'],
  [2, 'user id'],
  [2, 'name'],
  [3, 'cid'],
  [3, 'business id'],
  [3, 'count'],
  [3, 'day'],
  [4, 'id'],
  [4, 'business id'],
  [4, 'neighbourhood name'],
  [5, 'rid'],
  [5, 'business id'],
  [5, 'user id'],
  [5, 'rating'],
  [5, 'text'],
  [5, 'year'],
  [5, 'month'],
  [6, 'tip id'],
  [6, 'business id'],
  [6, 'text'],
  [6, 'user id'],
  [6, 'likes'],
  [6, 'year'],
  [6, 'month']]}