<a href="https://colab.research.google.com/github/sharik31/SQL-Generator/blob/main/preprocessing_spider_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Convert `train_spider.json` to a CSV with `db_id`, `input`, and `output` columns

In [2]:
import json
import csv

# Load the raw Spider training split. JSON interchange is UTF-8 (RFC 8259), so pass the
# encoding explicitly instead of relying on the platform's locale default; the CSV
# below is already written as UTF-8, and this keeps read and write symmetric.
with open('/content/train_spider.json', encoding='utf-8') as f:
    train_data = json.load(f)

# One CSV row per example: the database id, the question prefixed with the task
# instruction ("translate to SQL: ..."), and the gold SQL query as the target.
with open('spider_train_with_dbid.csv', 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['db_id', 'input', 'output'])
    writer.writeheader()

    for item in train_data:
        db_id = item['db_id']
        question = item['question']
        sql = item['query']
        input_text = f"translate to SQL: {question}"
        writer.writerow({'db_id': db_id, 'input': input_text, 'output': sql})

print("Saved spider_train_with_dbid.csv with columns: db_id, input, output")


Saved spider_train_with_dbid.csv with columns: db_id, input, output


In [3]:
import pandas as pd
# Read the CSV back from the same working-directory-relative path it was written to,
# not a hard-coded Colab '/content/' path (that only coincides when cwd is /content).
df = pd.read_csv('spider_train_with_dbid.csv')

In [4]:
# Show the first five rows of the freshly written train CSV.
df.head(5)

Unnamed: 0,db_id,input,output
0,department_management,translate to SQL: How many heads of the depart...,SELECT count(*) FROM head WHERE age > 56
1,department_management,"translate to SQL: List the name, born state an...","SELECT name , born_state , age FROM head ORD..."
2,department_management,"translate to SQL: List the creation year, name...","SELECT creation , name , budget_in_billions ..."
3,department_management,translate to SQL: What are the maximum and min...,"SELECT max(budget_in_billions) , min(budget_i..."
4,department_management,translate to SQL: What is the average number o...,SELECT avg(num_employees) FROM department WHER...


## Add the Corresponding Table Schema from `tables.json` as a Separate Column

In [5]:
import json
import pandas as pd

# Re-read the train CSV written earlier in the notebook.
df = pd.read_csv('spider_train_with_dbid.csv')

# tables.json holds a list of per-database schema records, each carrying a 'db_id'.
# Explicit UTF-8 so decoding does not depend on the platform locale.
with open('tables.json', encoding='utf-8') as f:
    schemas = json.load(f)

# Index the schema records by database id for direct lookup per example.
schema_map = {schema['db_id']: schema for schema in schemas}

def serialize_schema(schema):
    """Render one schema record as a single line of text.

    The result looks like ``"Table t1 (c1, c2) Table t2 (c3)"``: one
    ``Table <name> (<columns>)`` chunk per table, in table order, separated
    by single spaces.

    ``table_names_original`` / ``column_names_original`` are preferred; when a
    record lacks them, ``table_names`` / ``column_names`` are used, and a
    missing list is treated as empty.

    Each column entry is a ``(table_index, column_name)`` pair. Entries with
    index -1 (not attached to any table) or an index past the last table are
    skipped.
    """
    names = schema.get('table_names_original', schema.get('table_names', []))
    pairs = schema.get('column_names_original', schema.get('column_names', []))
    n_tables = len(names)

    columns_by_table = {}
    for position in range(n_tables):
        columns_by_table[position] = []

    for table_idx, column_name in pairs:
        if table_idx == -1 or table_idx >= n_tables:
            continue  # not attached to a listed table
        columns_by_table[table_idx].append(column_name)

    chunks = []
    for position, name in enumerate(names):
        column_list = ", ".join(columns_by_table[position])
        chunks.append(f"Table {name} ({column_list})")
    return " ".join(chunks)

def get_schema_text(row):
    """Return the serialized schema for ``row['db_id']``, or "" when it is unknown.

    ``row`` is anything indexable by ``'db_id'`` (here, a DataFrame row passed
    by ``df.apply(..., axis=1)``). The id is looked up in the module-level
    ``schema_map``.
    """
    key = row['db_id']
    if key not in schema_map:
        return ""
    return serialize_schema(schema_map[key])

# Attach each example's serialized database schema as a new column.
df['schema'] = df.apply(get_schema_text, axis=1)

# get_schema_text yields "" (not NaN) for a db_id missing from tables.json, so a
# df.isnull() check on this frame cannot see such rows (they only turn into NaN
# once the CSV is re-read by pandas). Report them explicitly instead.
missing_db_ids = sorted(set(df['db_id']) - set(schema_map), key=str)
if missing_db_ids:
    print(f"Warning: {len(missing_db_ids)} db_id(s) have no schema in tables.json: {missing_db_ids}")

df.to_csv('spider_train_with_separate_schema_column.csv', index=False)

print("CSV with separate 'schema' column saved as 'spider_train_with_separate_schema_column.csv'")


CSV with separate 'schema' column saved as 'spider_train_with_separate_schema_column.csv'


In [6]:
# First five train rows, now including the serialized 'schema' column.
df.head(n=5)

Unnamed: 0,db_id,input,output,schema
0,department_management,translate to SQL: How many heads of the depart...,SELECT count(*) FROM head WHERE age > 56,"Table department (Department_ID, Name, Creatio..."
1,department_management,"translate to SQL: List the name, born state an...","SELECT name , born_state , age FROM head ORD...","Table department (Department_ID, Name, Creatio..."
2,department_management,"translate to SQL: List the creation year, name...","SELECT creation , name , budget_in_billions ...","Table department (Department_ID, Name, Creatio..."
3,department_management,translate to SQL: What are the maximum and min...,"SELECT max(budget_in_billions) , min(budget_i...","Table department (Department_ID, Name, Creatio..."
4,department_management,translate to SQL: What is the average number o...,SELECT avg(num_employees) FROM department WHER...,"Table department (Department_ID, Name, Creatio..."


In [7]:
# Spot-check 10 random rows. A fixed random_state makes the displayed sample
# identical on every Restart & Run All.
df.sample(n=10, random_state=42)

Unnamed: 0,db_id,input,output,schema
1766,gymnast,translate to SQL: List the names of gymnasts i...,SELECT T2.Name FROM gymnast AS T1 JOIN people ...,"Table gymnast (Gymnast_ID, Floor_Exercise_Poin..."
4711,department_store,translate to SQL: What are the product id and ...,"SELECT product_id , product_type_code FROM pr...","Table Addresses (address_id, address_details) ..."
3250,college_1,translate to SQL: What is the total number of ...,SELECT count(*) FROM professor WHERE prof_high...,"Table CLASS (CLASS_CODE, CRS_CODE, CLASS_SECTI..."
5003,soccer_2,translate to SQL: What are the names and hours...,"SELECT T1.pName , T1.HS FROM player AS T1 JOI...","Table College (cName, state, enr) Table Player..."
439,allergy_1,translate to SQL: How many allergies are there?,SELECT count(DISTINCT allergy) FROM Allergy_type,"Table Allergy_Type (Allergy, AllergyType) Tabl..."
2374,csu_1,translate to SQL: What is the number of facult...,SELECT faculty FROM faculty AS T1 JOIN campuse...,"Table Campuses (Id, Campus, Location, County, ..."
4222,cre_Doc_Tracking_DB,"translate to SQL: Show the location code, the ...","SELECT location_code , date_in_location_from ...","Table Ref_Document_Types (Document_Type_Code, ..."
2003,gas_company,translate to SQL: Show all main industry for a...,SELECT DISTINCT main_industry FROM company,"Table company (Company_ID, Rank, Company, Head..."
2308,perpetrator,translate to SQL: What are the names of people...,SELECT Name FROM People ORDER BY Height ASC,"Table perpetrator (Perpetrator_ID, People_ID, ..."
2797,election,translate to SQL: Which parties did not have a...,SELECT Party FROM party WHERE Party_ID NOT IN ...,"Table county (County_Id, County_name, Populati..."


In [8]:
# Number of training examples (rows) in the frame.
print(df.shape[0])

7000


In [9]:
# Per-column count of missing (NaN/None) values. Empty-string schemas produced
# for unknown db_ids are not NaN and therefore are not counted here.
print(df.isna().sum())


db_id     0
input     0
output    0
schema    0
dtype: int64


## Preprocess the Dev Set (`dev.json`)

In [10]:
# Load the Spider dev split. The examples are stored in `dev_data` rather than
# `train_data`, so the kernel's training examples are not silently overwritten and
# readers are not misled. JSON interchange is UTF-8 (RFC 8259), so pass the
# encoding explicitly instead of relying on the locale default.
with open('/content/dev.json', encoding='utf-8') as f:
    dev_data = json.load(f)

# Same row layout as the train CSV: db_id, prefixed question, gold SQL.
with open('spider_val_with_dbid.csv', 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['db_id', 'input', 'output'])
    writer.writeheader()

    for item in dev_data:
        db_id = item['db_id']
        question = item['question']
        sql = item['query']
        input_text = f"translate to SQL: {question}"
        writer.writerow({'db_id': db_id, 'input': input_text, 'output': sql})

print("Saved spider_val_with_dbid.csv with columns: db_id, input, output")

Saved spider_val_with_dbid.csv with columns: db_id, input, output


In [11]:
import pandas as pd
# Read the dev CSV back from the same working-directory-relative path it was written
# to, not a hard-coded Colab '/content/' path.
df = pd.read_csv('spider_val_with_dbid.csv')

In [12]:
# Show the first five rows of the freshly written dev CSV.
df.head(5)

Unnamed: 0,db_id,input,output
0,concert_singer,translate to SQL: How many singers do we have?,SELECT count(*) FROM singer
1,concert_singer,translate to SQL: What is the total number of ...,SELECT count(*) FROM singer
2,concert_singer,"translate to SQL: Show name, country, age for ...","SELECT name , country , age FROM singer ORDE..."
3,concert_singer,"translate to SQL: What are the names, countrie...","SELECT name , country , age FROM singer ORDE..."
4,concert_singer,"translate to SQL: What is the average, minimum...","SELECT avg(age) , min(age) , max(age) FROM s..."


In [13]:
# Number of dev examples (rows) in the frame.
print(df.shape[0])

1034


In [14]:
# Re-read the dev CSV written earlier in the notebook.
df = pd.read_csv('spider_val_with_dbid.csv')

# tables.json holds a list of per-database schema records, each carrying a 'db_id'.
# Explicit UTF-8 so decoding does not depend on the platform locale.
with open('tables.json', encoding='utf-8') as f:
    schemas = json.load(f)

# Index the schema records by database id for direct lookup per example.
schema_map = {schema['db_id']: schema for schema in schemas}

# NOTE(review): this re-defines the serialize_schema already defined in the train
# section with identical logic; consider defining it once and reusing it.
def serialize_schema(schema):
    """Render one schema record as a single line of text.

    The result looks like ``"Table t1 (c1, c2) Table t2 (c3)"``: one
    ``Table <name> (<columns>)`` chunk per table, in table order, separated
    by single spaces.

    ``table_names_original`` / ``column_names_original`` are preferred; when a
    record lacks them, ``table_names`` / ``column_names`` are used, and a
    missing list is treated as empty.

    Each column entry is a ``(table_index, column_name)`` pair. Entries with
    index -1 (not attached to any table) or an index past the last table are
    skipped.
    """
    names = schema.get('table_names_original', schema.get('table_names', []))
    pairs = schema.get('column_names_original', schema.get('column_names', []))
    n_tables = len(names)

    columns_by_table = {}
    for position in range(n_tables):
        columns_by_table[position] = []

    for table_idx, column_name in pairs:
        if table_idx == -1 or table_idx >= n_tables:
            continue  # not attached to a listed table
        columns_by_table[table_idx].append(column_name)

    chunks = []
    for position, name in enumerate(names):
        column_list = ", ".join(columns_by_table[position])
        chunks.append(f"Table {name} ({column_list})")
    return " ".join(chunks)

def get_schema_text(row):
    """Return the serialized schema for ``row['db_id']``, or "" when it is unknown.

    ``row`` is anything indexable by ``'db_id'`` (here, a DataFrame row passed
    by ``df.apply(..., axis=1)``). The id is looked up in the module-level
    ``schema_map``.
    """
    key = row['db_id']
    if key not in schema_map:
        return ""
    return serialize_schema(schema_map[key])

# Attach each dev example's serialized database schema as a new column.
df['schema'] = df.apply(get_schema_text, axis=1)

# get_schema_text yields "" (not NaN) for a db_id missing from tables.json, so a
# df.isnull() check on this frame cannot see such rows (they only turn into NaN
# once the CSV is re-read by pandas). Report them explicitly instead.
missing_db_ids = sorted(set(df['db_id']) - set(schema_map), key=str)
if missing_db_ids:
    print(f"Warning: {len(missing_db_ids)} db_id(s) have no schema in tables.json: {missing_db_ids}")

df.to_csv('spider_val_with_separate_schema_column.csv', index=False)

print("CSV with separate 'schema' column saved as 'spider_val_with_separate_schema_column.csv'")

CSV with separate 'schema' column saved as 'spider_val_with_separate_schema_column.csv'


In [15]:
# First five dev rows, now including the serialized 'schema' column.
df.head(n=5)

Unnamed: 0,db_id,input,output,schema
0,concert_singer,translate to SQL: How many singers do we have?,SELECT count(*) FROM singer,"Table stadium (Stadium_ID, Location, Name, Cap..."
1,concert_singer,translate to SQL: What is the total number of ...,SELECT count(*) FROM singer,"Table stadium (Stadium_ID, Location, Name, Cap..."
2,concert_singer,"translate to SQL: Show name, country, age for ...","SELECT name , country , age FROM singer ORDE...","Table stadium (Stadium_ID, Location, Name, Cap..."
3,concert_singer,"translate to SQL: What are the names, countrie...","SELECT name , country , age FROM singer ORDE...","Table stadium (Stadium_ID, Location, Name, Cap..."
4,concert_singer,"translate to SQL: What is the average, minimum...","SELECT avg(age) , min(age) , max(age) FROM s...","Table stadium (Stadium_ID, Location, Name, Cap..."


In [16]:
# Row count after adding the schema column (should match the dev CSV row count).
print(df.shape[0])

1034


In [17]:
# Per-column count of missing (NaN/None) values. Empty-string schemas produced
# for unknown db_ids are not NaN and therefore are not counted here.
print(df.isna().sum())

db_id     0
input     0
output    0
schema    0
dtype: int64
