In [1]:
# extract the dataset into a csv file from the json file
import json
import csv

def extract_data(json_file, csv_file):
    """Convert a JSON file holding a list of records into a CSV file.

    The header row is taken from the keys of the first record. Every record
    then contributes one row whose cells follow the header order; a key that
    is missing from a record yields an empty cell.

    Args:
        json_file: Path of the JSON file to read. Its top-level value must be
            a list of dicts.
        csv_file: Path of the CSV file to write (overwritten if it exists).
            An empty input list produces an empty file.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # newline='' is what the csv module expects: without it, platforms that
    # translate '\n' on write turn the writer's '\r\n' terminator into '\r\r\n'.
    with open(csv_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if not data:
            # Nothing to export; avoid the IndexError on data[0].
            return
        header = list(data[0].keys())
        writer.writerow(header)
        for row in data:
            # Look values up by header key so that a record with a different
            # key order cannot shift values into the wrong column.
            writer.writerow([row.get(key, '') for key in header])

In [2]:
# Export the Spider training split (train_spider.json) to train_spider.csv.
json_file = './spider/train_spider.json'
csv_file = './spider/train_spider.csv'
extract_data(json_file, csv_file)

In [3]:
# Export the Spider dev split (dev.json) to dev_spider.csv.
json_file = './spider/dev.json'
csv_file = './spider/dev_spider.csv'
extract_data(json_file, csv_file)

In [4]:
# link the csv file to the database in ./spider/database

def dbid_to_schema(db_id):
    """Read ``./spider/database/<db_id>/schema.sql`` and return its lines.

    Args:
        db_id: Name of the database folder under ``./spider/database``.

    Returns:
        The file content as the list produced by ``readlines()``.
    """
    schema_path = ''.join(('./spider/database/', db_id, '/schema.sql'))
    with open(schema_path) as schema_file:
        return schema_file.readlines()

In [5]:
# get all db_ids in the ./spider/database folder
import os
db_ids_train = os.listdir('./spider/database')
db_ids_train = [db_id for db_id in db_ids_train if os.path.isdir('./spider/database/' + db_id)]

# split the db_ids: folders without a schema.sql file go to db_ids_no_schema,
# the others stay in db_ids_train.
# The check is an explicit file test rather than "try to read it and catch
# everything", so an unreadable or undecodable schema.sql is no longer
# silently classified as "missing".
db_ids_no_schema = [db_id for db_id in db_ids_train
                    if not os.path.isfile('./spider/database/' + db_id + '/schema.sql')]
db_ids_train = [db_id for db_id in db_ids_train if db_id not in db_ids_no_schema]

# db_ids that have no .sql file at all (hard-coded list); take them out of db_ids_no_schema
db_ids_no_sql = ['chinook_1', 'company_1', 'epinions_1', 'flight_4', 'icfp_1', 'small_bank_1', 'twitter_1', 'voter_1', 'world_1']
db_ids_no_schema = [db_id for db_id in db_ids_no_schema if db_id not in db_ids_no_sql]

# outliers: db_ids whose .sql file is not named after the db_id.
# schema_outliers[i] is the file name to use for db_ids_outliers[i].
db_ids_outliers = ['college_1', 'college_2']
schema_outliers = ['TinyCollege.sql', 'TextBookExampleSchema.sql']

# remove all outliers from db_ids_no_schema
db_ids_no_schema = [db_id for db_id in db_ids_no_schema if db_id not in db_ids_outliers]

# drop from the train and dev CSV files (rewritten in place) every example whose
# db_id is in db_ids_no_sql, db_ids_no_schema or db_ids_outliers
import pandas as pd
db_ids_excluded = set(db_ids_no_sql) | set(db_ids_no_schema) | set(db_ids_outliers)
for split_csv in ('./spider/train_spider.csv', './spider/dev_spider.csv'):
    df = pd.read_csv(split_csv)
    df = df[~df['db_id'].isin(db_ids_excluded)]
    df.to_csv(split_csv, index=False)

In [6]:
import csv

# Line prefixes that mark an INSERT statement: the keyword in one of three
# casings, optionally preceded by a single space or a single tab.
_INSERT_PREFIXES = tuple(indent + keyword
                         for indent in ('', ' ', '\t')
                         for keyword in ('INSERT', 'insert', 'Insert'))


def _keep_one_insert_per_table(lines):
    """Join schema lines into one string, keeping one INSERT line per table.

    Lines that do not start with one of ``_INSERT_PREFIXES`` are kept in their
    original order. INSERT lines are moved to the end; only the first one seen
    for each table name (the third whitespace-separated token of the line) is
    kept. A kept INSERT line with more than four ')' is cut right after its
    fifth ')', which also drops the rest of the line including its line ending.

    Args:
        lines: The lines of a .sql file, as returned by ``readlines()``.

    Returns:
        The filtered schema as a single string.
    """
    other_lines = []
    kept_inserts = []
    used_tables = set()
    for line in lines:
        if not line.startswith(_INSERT_PREFIXES):
            other_lines.append(line)
            continue
        tokens = line.split()
        if len(tokens) < 3:
            # No table name on this line (e.g. a statement split over several
            # lines): drop it instead of failing on tokens[2].
            continue
        table_name = tokens[2]
        if table_name in used_tables:
            continue
        used_tables.add(table_name)
        if line.count(')') > 4:
            cut = -1
            for _ in range(5):
                cut = line.find(')', cut + 1)
            line = line[:cut] + ')'
        kept_inserts.append(line)
    return ''.join(other_lines + kept_inserts)


# create one csv file with two columns: db_id and schema
csv_file = './spider/database_schema.csv'

# (db_id, path of its .sql file), in the order the rows are written:
# databases using <db_id>.sql, then those with schema.sql, then the outliers.
schema_sources = (
    [(db_id, './spider/database/' + db_id + '/' + db_id + '.sql') for db_id in db_ids_no_schema]
    + [(db_id, './spider/database/' + db_id + '/schema.sql') for db_id in db_ids_train]
    + [(db_id, './spider/database/' + db_id + '/' + sql_name)
       for db_id, sql_name in zip(db_ids_outliers, schema_outliers)]
)

# newline='' is what the csv module expects; the schema cells contain embedded
# newlines, which text-mode newline translation would otherwise rewrite.
with open(csv_file, 'w', newline='') as csv_out:
    writer = csv.writer(csv_out)
    writer.writerow(['db_id', 'schema'])
    for db_id, sql_path in schema_sources:
        with open(sql_path) as sql_in:
            schema = _keep_one_insert_per_table(sql_in.readlines())
        writer.writerow([db_id, schema])

In [7]:
# add insert statements to schema when it is missing:
# first collect the db_ids whose schema text contains no INSERT keyword

import pandas as pd
db_ids_not_insert = []
df = pd.read_csv('./spider/database_schema.csv')
db_ids = df['db_id'].tolist()
for db_id, schema in zip(db_ids, df['schema'].tolist()):
    # `('INSERT' or 'insert' or 'Insert') not in schema` evaluates the `or`
    # chain first and therefore only ever tested 'INSERT'; test each casing.
    if not any(keyword in schema for keyword in ('INSERT', 'insert', 'Insert')):
        db_ids_not_insert.append(db_id)

import sqlite3

def get_values(db_id):
    """Return the tables of a Spider SQLite database with up to two rows each.

    Opens ``./spider/database/<db_id>/<db_id>.sqlite``. Text columns are
    decoded with undecodable bytes ignored.

    Args:
        db_id: Name of the database folder (and of the .sqlite file inside it).

    Returns:
        A pair ``(tables, values)``: ``tables`` is the list of 1-tuples
        ``(table_name,)`` read from ``sqlite_master``; ``values[i]`` is the
        list of at most two row tuples of ``tables[i]``.
    """
    conn = sqlite3.connect(f'./spider/database/{db_id}/{db_id}.sqlite')
    try:
        conn.text_factory = lambda b: b.decode(errors = 'ignore')
        tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
        values = []
        for (table_name,) in tables:
            # Quote the identifier so that table names that are SQL keywords
            # or contain special characters do not break the statement.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            # fetchmany(2) avoids loading the whole table to keep two rows.
            values.append(conn.execute(f'SELECT * FROM {quoted_name}').fetchmany(2))
    finally:
        # Close the connection even when a query fails.
        conn.close()
    return tables, values

def _sql_literal(value):
    """Render one value fetched from SQLite as a SQL literal."""
    if value is None:
        return 'NULL'
    if isinstance(value, bytes):
        return "X'" + value.hex() + "'"
    if isinstance(value, (int, float)):
        return repr(value)
    # Text: single-quote it and double any embedded single quote.
    return "'" + str(value).replace("'", "''") + "'"


# For every schema without INSERT statements, append one INSERT per non-empty
# table built from the sample rows returned by get_values.
for db_id in db_ids_not_insert:
    tables, values = get_values(db_id)
    inserts = []
    for (table_name,), rows in zip(tables, values):
        if len(rows) != 0:
            # Format the rows as SQL tuples; interpolating the Python list
            # directly produced "VALUES [(1, 'a'), ...]", which is not SQL.
            rendered_rows = ', '.join(
                '(' + ', '.join(_sql_literal(value) for value in row) + ')'
                for row in rows
            )
            inserts.append(f'INSERT INTO {table_name} VALUES {rendered_rows};')
    inserts = '\n'.join(inserts)
    df.loc[df['db_id'] == db_id, 'schema'] = df.loc[df['db_id'] == db_id, 'schema'] + inserts
df.to_csv('./spider/database_schema.csv', index=False)

# Note: only 7 databases have no insert statements in the schema, but they don't have any data in the tables either

In [8]:
# sanity check: print the db_ids used in dev_spider.csv that have no row in database_schema.csv
df_dev = pd.read_csv('./spider/dev_spider.csv')
df_schema = pd.read_csv('./spider/database_schema.csv')

db_ids_dev = df_dev['db_id'].unique()
db_ids_schema = df_schema['db_id'].unique()

# keep the dev db_ids, in order of first appearance, that are absent from the schema file
db_ids_no_schema_dev = list(filter(lambda dev_id: dev_id not in db_ids_schema, db_ids_dev))
print(db_ids_no_schema_dev)

[]


In [9]:
# put every INSERT keyword at the start of its own line, normalised to upper case
def _insert_on_new_line(schema_column):
    """Rewrite each INSERT/insert/Insert keyword as a newline followed by 'INSERT'.

    Newlines that already precede the keyword are absorbed, so running the
    cell again does not keep adding blank lines. The word boundaries keep the
    keyword from being rewritten inside a longer word such as 'inserted'.

    Args:
        schema_column: pandas Series of schema strings.

    Returns:
        A new Series with the keyword rewritten.
    """
    return schema_column.str.replace(r'\n*\b(?:INSERT|insert|Insert)\b', '\nINSERT', regex=True)


df_schema = pd.read_csv('./spider/database_schema.csv')
# NOTE(review): no cell in the visible part of this notebook writes
# test_database_schema.csv; it is expected to exist already - confirm its origin.
df_test = pd.read_csv('./spider/test_database_schema.csv')

df_schema['schema'] = _insert_on_new_line(df_schema['schema'])
df_test['schema'] = _insert_on_new_line(df_test['schema'])

df_schema.to_csv('./spider/database_schema.csv', index=False)
df_test.to_csv('./spider/test_database_schema.csv', index=False)

In [10]:
# from database_schema.csv, create a clean_database_schema.csv file with the same columns,
# but whose schema keeps only the lines of CREATE TABLE statements and the INSERT lines;
# then do the same for test_database_schema.csv

import csv

csv_file = './spider/clean_database_schema.csv'

glossary_create = ['CREATE', 'create', 'Create']
glossary_insert = ['INSERT', 'insert', 'Insert']
glossary_table = ['TABLE ', 'table ', 'Table ']


def _clean_schema(schema):
    """Keep only the CREATE TABLE statements and INSERT lines of a schema.

    Everything before the first 'create' (any casing) is discarded. A line
    that contains a create keyword and a table keyword opens a CREATE TABLE
    statement; all following lines are kept until a line containing ';'.
    Outside such a statement only lines containing an insert keyword are kept.

    Args:
        schema: The schema text of one database.

    Returns:
        The cleaned schema text, or None when the text contains no 'create'.
    """
    create_index = schema.lower().find('create')
    if create_index == -1:
        return None
    # Normalise Windows line endings so that no stray '\r' survives the split.
    body = schema[create_index:].replace('\r\n', '\n')

    new_lines = []
    inside_create = False
    for line in body.split('\n'):
        if inside_create:
            # Keep every line of the statement body, including column
            # definitions whose name happens to contain "create".
            new_lines.append(line)
        elif any(word in line for word in glossary_create):
            if any(word in line for word in glossary_table):
                new_lines.append(line)
                inside_create = True
        elif any(word in line for word in glossary_insert):
            new_lines.append(line)
        if ';' in line:
            inside_create = False
    return '\n'.join(new_lines)


def _clean_schema_csv(src_csv, dst_csv):
    """Write to dst_csv the rows of src_csv with their schema cleaned.

    Rows whose schema contains no 'create' are skipped. Both files have a
    ('db_id', 'schema') header row.
    """
    # newline='' is what the csv module expects for both reading and writing;
    # the schema cells contain embedded newlines.
    with open(src_csv, 'r', newline='') as src, open(dst_csv, 'w', newline='') as dst:
        reader = csv.reader(src)
        writer = csv.writer(dst)
        writer.writerow(['db_id', 'schema'])
        next(reader, None)  # skip the header row of the source file
        for row in reader:
            db_id = row[0]
            schema = _clean_schema(row[1])
            if schema is None:
                continue
            writer.writerow([db_id, schema])


_clean_schema_csv('./spider/database_schema.csv', csv_file)

csv_file = './spider/clean_test_database_schema.csv'

_clean_schema_csv('./spider/test_database_schema.csv', csv_file)

In [11]:
def sample_schema(sample):
    """Build a new example dict that carries the schema text of its database.

    The schema is read from ``./spider/database/<db_id>/schema.sql`` and its
    lines are joined with a single space.

    Args:
        sample: Mapping with at least the keys 'db_id', 'query' and 'question'.

    Returns:
        A new dict with the keys 'schema', 'query', 'question' and 'db_id'.
    """
    db_id = sample['db_id']
    schema_path = ''.join(('./spider/database/', db_id, '/schema.sql'))
    with open(schema_path) as schema_file:
        schema_text = ' '.join(schema_file.readlines())
    return {
        'schema': schema_text,
        'query': sample['query'],
        'question': sample['question'],
        'db_id': db_id,
    }

In [12]:
# Try sample_schema on one example: the first row of train_spider.csv, as a dict.
import pandas as pd
df = pd.read_csv("./spider/train_spider.csv")
sample = df.iloc[0].to_dict()

# sample_schema reads ./spider/database/<db_id>/schema.sql and returns a new
# dict with the keys 'schema', 'query', 'question' and 'db_id'.
sample = sample_schema(sample)

In [13]:
# create new csv files for the train and dev datasets with only db_id, query and question, but without the schema
def clean_csv(csv_file, new_csv_file):
    """Copy a split CSV, keeping only the db_id, query and question columns.

    Args:
        csv_file: Path of the CSV file to read; it must contain the columns
            'db_id', 'query' and 'question'.
        new_csv_file: Path of the CSV file to write, with exactly those three
            columns in that order and no index column.
    """
    df = pd.read_csv(csv_file)
    # Select the columns in one step; filling a frame row by row with .loc
    # copies the data on every insertion.
    df[['db_id', 'query', 'question']].to_csv(new_csv_file, index=False)

# Write the three-column (db_id, query, question) versions of the train and dev splits.
clean_csv('./spider/train_spider.csv', './spider/train_spider_clean.csv')
clean_csv('./spider/dev_spider.csv', './spider/dev_spider_clean.csv')

In [14]:
# create a Huggingface dataset from the csv files
from datasets import Dataset
import pandas as pd

# Load the three cleaned splits.
# NOTE(review): no cell in the visible part of this notebook writes
# test_spider_clean.csv; it is expected to exist already - confirm its origin.
df_train = pd.read_csv('./spider/train_spider_clean.csv')
df_dev = pd.read_csv('./spider/dev_spider_clean.csv')
df_test = pd.read_csv('./spider/test_spider_clean.csv')

train_dataset = Dataset.from_pandas(df_train)
dev_dataset = Dataset.from_pandas(df_dev)
test_dataset = Dataset.from_pandas(df_test)

# Schema lookup tables (db_id -> cleaned schema text): one shared by train and
# dev, one for test.
df_schema = pd.read_csv('./spider/clean_database_schema.csv')
df_schema_test = pd.read_csv('./spider/clean_test_database_schema.csv')

# add to each train/dev example the schema text of its database
def replace_db_id_with_schema(sample):
    """Add a 'schema' entry to an example, looked up by its 'db_id'.

    The schema is the 'schema' value of the first row of the module-level
    ``df_schema`` frame whose 'db_id' matches. The 'db_id' entry is left in
    place; the same ``sample`` object is updated and returned.
    """
    matching_schemas = df_schema.loc[df_schema['db_id'] == sample['db_id'], 'schema']
    sample['schema'] = matching_schemas.values[0]
    return sample

def replace_db_id_with_schema_test(sample):
    """Add a 'schema' entry to a test example, looked up by its 'db_id'.

    The schema is the 'schema' value of the first row of the module-level
    ``df_schema_test`` frame whose 'db_id' matches. The 'db_id' entry is left
    in place; the same ``sample`` object is updated and returned.
    """
    matching_schemas = df_schema_test.loc[df_schema_test['db_id'] == sample['db_id'], 'schema']
    sample['schema'] = matching_schemas.values[0]
    return sample

# Attach the 'schema' column to every split: train and dev look it up in
# df_schema, test in df_schema_test.
train_dataset = train_dataset.map(replace_db_id_with_schema)
dev_dataset = dev_dataset.map(replace_db_id_with_schema)
test_dataset = test_dataset.map(replace_db_id_with_schema_test)

Map:   0%|          | 0/6016 [00:00<?, ? examples/s]

Map:   0%|          | 0/665 [00:00<?, ? examples/s]

Map:   0%|          | 0/1929 [00:00<?, ? examples/s]

In [15]:
# Bundle the train, dev and test splits into one DatasetDict.
# This cell only builds the object; the upload is the push_to_hub call in a later cell.
from datasets import DatasetDict

dataset_dict = DatasetDict({'train': train_dataset, 'dev': dev_dataset, 'test': test_dataset})
# dataset_dict.save_to_disk('./spider_dataset') # already done

In [17]:
from datasets import DatasetDict

# A Hugging Face token with write permission is needed before pushing.
print("Need to write token for huggingface-cli!")
#!huggingface-cli login --token [insert token with write permission here]

# Upload the DatasetDict built earlier as a private dataset repository.
# dataset_dict = DatasetDict.load_from_disk('./spider_dataset')
dataset_dict.push_to_hub('VictorDCh/spider-clean-text-to-sql-5', private=True)

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to C:\Users\vdubu\.cache\huggingface\token
Login successful


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/VictorDCh/spider-clean-text-to-sql-5/commit/731fc74c2e2e586de05ce4ac8fb07e743d90bd07', commit_message='Upload dataset', commit_description='', oid='731fc74c2e2e586de05ce4ac8fb07e743d90bd07', pr_url=None, pr_revision=None, pr_num=None)

In [18]:
from datasets import load_dataset

# Download the published dataset to inspect it.
# NOTE(review): this loads '...-3' while the push cell above uploads to
# '...-5' - confirm which repository is meant to be checked here.
dshf = load_dataset('VictorDCh/spider-clean-text-to-sql-3')

Downloading readme:   0%|          | 0.00/591 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/450k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/47.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/126k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6016 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/665 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1929 [00:00<?, ? examples/s]

In [19]:
# Print the first training example of the downloaded dataset (all of its fields).
sample = dshf['train'][0]

print(sample)

{'db_id': 'department_management', 'query': 'SELECT count(*) FROM head WHERE age  >  56', 'question': 'How many heads of the departments are older than 56 ?', 'schema': 'CREATE TABLE IF NOT EXISTS "department" (\r\n"Department_ID" int,\r\n"Name" text,\r\n"Creation" text,\r\n"Ranking" int,\r\n"Budget_in_Billions" real,\r\n"Num_Employees" real,\r\nPRIMARY KEY ("Department_ID")\r\n);\r\nCREATE TABLE IF NOT EXISTS "head" (\r\n"head_ID" int,\r\n"name" text,\r\n"born_state" text,\r\n"age" real,\r\nPRIMARY KEY ("head_ID")\r\n);\r\nCREATE TABLE IF NOT EXISTS "management" (\r\n"department_ID" int,\r\n"head_ID" int,\r\n"temporary_acting" text,\r\nPRIMARY KEY ("Department_ID","head_ID"),\r\nFOREIGN KEY ("Department_ID") REFERENCES `department`("Department_ID"),\r\nFOREIGN KEY ("head_ID") REFERENCES `head`("head_ID")\r\n);\r\nINSERT INTO department VALUES(1,\'State\',\'1789\',\'1\',9.9600000000000008526,30265.999999999999999);\r\nINSERT INTO head VALUES(1,\'Tiger Woods\',\'Alabama\',66.99999999999

In [20]:
# Print the schema text of the first training example with its line breaks rendered.
print(dshf["train"][0]["schema"])

CREATE TABLE IF NOT EXISTS "department" (
"Department_ID" int,
"Name" text,
"Creation" text,
"Ranking" int,
"Budget_in_Billions" real,
"Num_Employees" real,
PRIMARY KEY ("Department_ID")
);
CREATE TABLE IF NOT EXISTS "head" (
"head_ID" int,
"name" text,
"born_state" text,
"age" real,
PRIMARY KEY ("head_ID")
);
CREATE TABLE IF NOT EXISTS "management" (
"department_ID" int,
"head_ID" int,
"temporary_acting" text,
PRIMARY KEY ("Department_ID","head_ID"),
FOREIGN KEY ("Department_ID") REFERENCES `department`("Department_ID"),
FOREIGN KEY ("head_ID") REFERENCES `head`("head_ID")
);
INSERT INTO department VALUES(1,'State','1789','1',9.9600000000000008526,30265.999999999999999);
INSERT INTO head VALUES(1,'Tiger Woods','Alabama',66.999999999999999998);
INSERT INTO management VALUES(2,5,'Yes');
