In [1]:
import numpy as np
import pandas as pd
import json
import re

In [2]:
def spider_json_to_pandas(file_path):
    # Read the JSON file
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    
    # Convert each JSON object in the array
    converted_data = []
    for item in data:
        converted_item = {
            'db_id': item['db_id'],
            'query': item['query'],
            'question': item['question'],
            'from': int(bool(item['sql']['from'])),
            'select': int(bool(item['sql']['select'])),
            'where': int(bool(item['sql']['where'])),
            'groupBy': int(bool(item['sql']['groupBy'])),
            'having': int(bool(item['sql']['having'])),
            'orderBy': int(bool(item['sql']['orderBy'])),
            'limit': int(bool(item['sql']['limit'])),
            'intersect': int(bool(item['sql']['intersect'])),
            'union': int(bool(item['sql']['union'])),
            'except': int(bool(item['sql']['except']))
        }
        converted_data.append(converted_item)
    
    columns = ['db_id',
            'query',
            'question',
            'from',
            'select',
            'where',
            'groupBy',
            'having',
            'orderBy',
            'limit',
            'intersect',
            'union',
            'except',
            ]
    
    dataset = pd.DataFrame(converted_data, columns=columns)

    return dataset

In [3]:
def func_preprocess(text):
    text = text.lower().strip()
    
    text = re.sub(r"\(", ' ( ', text)
    text = re.sub(r'\)', ' ) ', text)
    text = re.sub(r'=', ' = ', text)
    text = re.sub(r'  ', ' ', text)

    return text


def all_func_columns(df, wikidataset=False):

    if wikidataset:
        funcs = ['db_id', 'from', 'select', 'where', 'groupBy',
       'having', 'orderBy', 'limit', 'intersect', 'union', 'except',
       'distinct', 'count', 'join' ,'max', 'avg', 'min', 'sum', 'and', 'or']
    else:
        funcs = ["distinct", "count", "join", "max", "avg", "min", "sum", "and", "or"]

    for func in funcs:
        df[func] = df['query'].astype(str).str.contains(func, case=False).astype(int)

    return df

## Train

In [4]:
file_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/train_others.json'
spider_train_other = spider_json_to_pandas(file_path)

spider_train_other.head()

Unnamed: 0,db_id,query,question,from,select,where,groupBy,having,orderBy,limit,intersect,union,except
0,geo,SELECT city_name FROM city WHERE population =...,what is the biggest city in wyoming,1,1,1,0,0,0,0,0,0,0
1,geo,SELECT city_name FROM city WHERE population =...,what wyoming city has the largest population,1,1,1,0,0,0,0,0,0,0
2,geo,SELECT city_name FROM city WHERE population =...,what is the largest city in wyoming,1,1,1,0,0,0,0,0,0,0
3,geo,SELECT city_name FROM city WHERE population =...,where is the most populated area of wyoming,1,1,1,0,0,0,0,0,0,0
4,geo,SELECT city_name FROM city WHERE population =...,which city in wyoming has the largest population,1,1,1,0,0,0,0,0,0,0


In [5]:
spider_train_other['query'] = spider_train_other['query'].map(func_preprocess)
spider_train_other = all_func_columns(spider_train_other)
spider_train_other.head()

Unnamed: 0,db_id,query,question,from,select,where,groupBy,having,orderBy,limit,...,except,distinct,count,join,max,avg,min,sum,and,or
0,geo,select city_name from city where population =...,what is the biggest city in wyoming,1,1,1,0,0,0,0,...,0,0,0,0,1,0,1,0,1,0
1,geo,select city_name from city where population =...,what wyoming city has the largest population,1,1,1,0,0,0,0,...,0,0,0,0,1,0,1,0,1,0
2,geo,select city_name from city where population =...,what is the largest city in wyoming,1,1,1,0,0,0,0,...,0,0,0,0,1,0,1,0,1,0
3,geo,select city_name from city where population =...,where is the most populated area of wyoming,1,1,1,0,0,0,0,...,0,0,0,0,1,0,1,0,1,0
4,geo,select city_name from city where population =...,which city in wyoming has the largest population,1,1,1,0,0,0,0,...,0,0,0,0,1,0,1,0,1,0


In [67]:
file_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/train_spider.json'
train_spider = spider_json_to_pandas(file_path)
train_spider['query'] = train_spider['query'].map(func_preprocess)
train_spider = all_func_columns(train_spider)
train_spider.head()

Unnamed: 0,db_id,query,question,from,select,where,groupBy,having,orderBy,limit,...,except,distinct,count,join,max,avg,min,sum,and,or
0,department_management,select count ( * ) from head where age > 56,How many heads of the departments are older th...,1,1,1,0,0,0,0,...,0,0,1,0,0,0,0,0,0,0
1,department_management,"select name , born_state , age from head order...","List the name, born state and age of the heads...",1,1,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,1
2,department_management,"select creation , name , budget_in_billions fr...","List the creation year, name and budget of eac...",1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,department_management,"select max ( budget_in_billions ) , min ( budg...",What are the maximum and minimum budget of the...,1,1,0,0,0,0,0,...,0,0,0,0,1,0,1,0,0,0
4,department_management,select avg ( num_employees ) from department w...,What is the average number of employees of the...,1,1,1,0,0,0,0,...,0,0,0,0,0,1,0,0,1,0


In [55]:
final_train = pd.concat([train_spider, spider_train_other], ignore_index=True)
print(final_train.shape)
final_train = final_train.drop_duplicates()
print(final_train.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

final_train.to_csv(f"{save_path}train_spider.csv")

(8659, 21)
(8651, 21)


## Test

In [51]:
file_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/test.json'
test_spider = spider_json_to_pandas(file_path)
test_spider['query'] = test_spider['query'].map(func_preprocess)
test_spider = all_func_columns(test_spider)

test_spider.head()

Unnamed: 0,db_id,query,question,from,select,where,groupBy,having,orderBy,limit,...,union,except,distinct,count,max,avg,min,sum,and,or
0,soccer_3,select count ( * ) from club,How many clubs are there?,1,1,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
1,soccer_3,select count ( * ) from club,Count the number of clubs.,1,1,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
2,soccer_3,select name from club order by name asc,List the name of clubs in ascending alphabetic...,1,1,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,1
3,soccer_3,select name from club order by name asc,"What are the names of clubs, ordered alphabeti...",1,1,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,1
4,soccer_3,"select manager , captain from club",What are the managers and captains of clubs?,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [56]:
print(test_spider.shape)
test_spider = test_spider.drop_duplicates()
print(test_spider.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

test_spider.to_csv(f"{save_path}test_spider.csv")

(2147, 21)
(2147, 21)


## Validation

In [53]:
file_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/dev.json'
val_spider = spider_json_to_pandas(file_path)
val_spider['query'] = val_spider['query'].map(func_preprocess)
val_spider = all_func_columns(val_spider)
val_spider.head()

Unnamed: 0,db_id,query,question,from,select,where,groupBy,having,orderBy,limit,...,union,except,distinct,count,max,avg,min,sum,and,or
0,concert_singer,select count ( * ) from singer,How many singers do we have?,1,1,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
1,concert_singer,select count ( * ) from singer,What is the total number of singers?,1,1,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
2,concert_singer,"select name , country , age from singer order ...","Show name, country, age for all singers ordere...",1,1,0,0,0,1,0,...,0,0,0,1,0,0,0,0,0,1
3,concert_singer,"select name , country , age from singer order ...","What are the names, countries, and ages for ev...",1,1,0,0,0,1,0,...,0,0,0,1,0,0,0,0,0,1
4,concert_singer,"select avg ( age ) , min ( age ) , max ( age )...","What is the average, minimum, and maximum age ...",1,1,1,0,0,0,0,...,0,0,0,1,1,1,1,0,0,0


In [57]:
print(val_spider.shape)
val_spider = val_spider.drop_duplicates()
print(val_spider.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

val_spider.to_csv(f"{save_path}validation_spider.csv")

(1034, 21)
(1034, 21)


## WikiSQL dataset

In [33]:
file_path = "/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/"
train_wiki_data = pd.read_csv(f"{file_path}train_wikisql.csv", index_col=0)
train_wiki_data = train_wiki_data.reset_index()
train_wiki_data.head()

Unnamed: 0,question,sql
0,Tell me what the notes are for South Australia,SELECT Notes FROM table WHERE Current slogan =...
1,What is the current series where the new serie...,SELECT Current series FROM table WHERE Notes =...
2,What is the format for South Australia?,SELECT Format FROM table WHERE State/territory...
3,Name the background colour for the Australian ...,SELECT Text/background colour FROM table WHERE...
4,how many times is the fuel propulsion is cng?,SELECT COUNT Fleet Series (Quantity) FROM tabl...


In [34]:
train_wiki_data = train_wiki_data.rename(columns={'sql':'query'}, errors="raise")
train_wiki_data['query'] = train_wiki_data['query'].map(func_preprocess)
train_wiki_data = all_func_columns(train_wiki_data, wikidataset=True)
train_wiki_data.head()

Unnamed: 0,question,query,db_id,from,select,where,groupBy,having,orderBy,limit,...,union,except,distinct,count,max,avg,min,sum,and,or
0,Tell me what the notes are for South Australia,select notes from table where current slogan =...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,What is the current series where the new serie...,select current series from table where notes =...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,What is the format for South Australia?,select format from table where state/territory...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,Name the background colour for the Australian ...,select text/background colour from table where...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
4,how many times is the fuel propulsion is cng?,select count fleet series ( quantity ) from ta...,0,1,1,1,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0


In [39]:
file_path = "/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/"
validation_wiki_data = pd.read_csv(f"{file_path}validation_wikisql.csv", index_col=0)
validation_wiki_data = validation_wiki_data.reset_index()
validation_wiki_data.head()

Unnamed: 0,question,sql
0,What position does the player who played for b...,SELECT Position FROM table WHERE School/Club T...
1,How many schools did player number 3 play at?,SELECT COUNT School/Club Team FROM table WHERE...
2,What school did player number 21 play for?,SELECT School/Club Team FROM table WHERE No. = 21
3,Who is the player that wears number 42?,SELECT Player FROM table WHERE No. = 42
4,What player played guard for toronto in 1996-97?,SELECT Player FROM table WHERE Position = Guar...


In [40]:
validation_wiki_data = validation_wiki_data.rename(columns={'sql':'query'}, errors="raise")
validation_wiki_data['query'] = validation_wiki_data['query'].map(func_preprocess)
validation_wiki_data = all_func_columns(validation_wiki_data, wikidataset=True)
validation_wiki_data.head()

Unnamed: 0,question,query,db_id,from,select,where,groupBy,having,orderBy,limit,...,union,except,distinct,count,max,avg,min,sum,and,or
0,What position does the player who played for b...,select position from table where school/club t...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,How many schools did player number 3 play at?,select count school/club team from table where...,0,1,1,1,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
2,What school did player number 21 play for?,select school/club team from table where no. = 21,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,Who is the player that wears number 42?,select player from table where no. = 42,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,What player played guard for toronto in 1996-97?,select player from table where position = guar...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,1,1


In [41]:
file_path = "/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/"
test_wiki_data = pd.read_csv(f"{file_path}test_wikisql.csv", index_col=0)
test_wiki_data = test_wiki_data.reset_index()
test_wiki_data.head()

Unnamed: 0,question,sql
0,What is terrence ross' nationality,SELECT Nationality FROM table WHERE Player = T...
1,What clu was in toronto 1995-96,SELECT School/Club Team FROM table WHERE Years...
2,which club was in toronto 2003-06,SELECT School/Club Team FROM table WHERE Years...
3,how many schools or teams had jalen rose,SELECT COUNT School/Club Team FROM table WHERE...
4,Where was Assen held?,SELECT Round FROM table WHERE Circuit = Assen


In [42]:
test_wiki_data = test_wiki_data.rename(columns={'sql':'query'}, errors="raise")
test_wiki_data['query'] = test_wiki_data['query'].map(func_preprocess)
test_wiki_data = all_func_columns(test_wiki_data, wikidataset=True)
test_wiki_data.head()

Unnamed: 0,question,query,db_id,from,select,where,groupBy,having,orderBy,limit,...,union,except,distinct,count,max,avg,min,sum,and,or
0,What is terrence ross' nationality,select nationality from table where player = t...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,What clu was in toronto 1995-96,select school/club team from table where years...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
2,which club was in toronto 2003-06,select school/club team from table where years...,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,how many schools or teams had jalen rose,select count school/club team from table where...,0,1,1,1,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
4,Where was Assen held?,select round from table where circuit = assen,0,1,1,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [43]:
print(train_wiki_data.shape)
train_wiki_data = train_wiki_data.drop_duplicates()
print(train_wiki_data.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

train_wiki_data.to_csv(f"{save_path}train_wiki_processed.csv")

(56355, 21)
(56164, 21)


In [44]:
print(validation_wiki_data.shape)
validation_wiki_data = validation_wiki_data.drop_duplicates()
print(validation_wiki_data.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

validation_wiki_data.to_csv(f"{save_path}validation_wiki_processed.csv")

(8421, 21)
(8392, 21)


In [45]:
print(test_wiki_data.shape)
test_wiki_data = test_wiki_data.drop_duplicates()
print(test_wiki_data.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

test_wiki_data.to_csv(f"{save_path}test_wiki_processed.csv")

(15878, 21)
(15836, 21)


# Final dataset

In [58]:
train = pd.concat([train_wiki_data, final_train])
validation = pd.concat([validation_wiki_data, val_spider])
test = pd.concat([test_wiki_data, test_spider])

In [60]:
train.shape, test.shape, validation.shape

((64815, 21), (17983, 21), (9426, 21))

In [61]:
print(train.shape)
train = train.drop_duplicates()
print(train.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

train.to_csv(f"{save_path}train.csv")

(64815, 21)
(64815, 21)


In [62]:
print(validation.shape)
validation = validation.drop_duplicates()
print(validation.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

validation.to_csv(f"{save_path}validation.csv")

(9426, 21)
(9426, 21)


In [63]:
print(test.shape)
test = test.drop_duplicates()
print(test.shape)
save_path = '/Users/vedanttibrewal/Documents/USC/lectures/sem_1/DSCI-551/project/chatDB-dsci551/dataset/'

test.to_csv(f"{save_path}test.csv")

(17983, 21)
(17983, 21)
