In [24]:
import pandas as pd
import os
from datasets import Dataset
from tqdm import tqdm
import time

Load in the test/validation dataset

In [4]:
# Save df if does not exist
csv_file_path = '/Users/ajaykallepalli/data/SQL_df_validation.csv'
if not os.path.exists(csv_file_path):
    splits = {'test': 'spider/train-00000-of-00001.parquet', 'validation': 'spider/validation-00000-of-00001.parquet'}
    df_test = pd.read_parquet("hf://datasets/xlangai/spider/" + splits["validation"])
    df_test.to_csv(csv_file_path, index=False)
    print("Downloaded data")
else:
    print("Already downloaded")

Downloaded data


In [5]:
df_test.head()

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]"
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]"
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin..."
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ..."
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","[SELECT, avg, (, age, ), ,, min, (, age, ), ,,...","[select, avg, (, age, ), ,, min, (, age, ), ,,...","[What, is, the, average, ,, minimum, ,, and, m..."


In [6]:
#Generate test_gold.sql to test the responses against
queries = df_test['query']

with open('/Users/ajaykallepalli/data/test_gold.sql', 'w') as file:
    for query in queries:
        file.write(query + '\n')

In [8]:
### Creating the input text for the llm
unique_dbs_test = df_test['db_id'].unique()
unique_dbs_test

array(['concert_singer', 'pets_1', 'car_1', 'flight_2',
       'employee_hire_evaluation', 'cre_Doc_Template_Mgt', 'course_teach',
       'museum_visit', 'wta_1', 'battle_death',
       'student_transcripts_tracking', 'tvshow', 'poker_player',
       'voter_1', 'world_1', 'orchestra', 'network_1', 'dog_kennels',
       'singer', 'real_estate_properties'], dtype=object)

In [11]:
# # creating a db with db ID and db schema as columns
df_schemas = pd.DataFrame(columns=['db_id', 'db_schema'])

In [7]:
schema_location = '/Users/ajaykallepalli/data/spider/database/'
schema_name = 'schema.sql'

In [12]:

for uniq_db in unique_dbs_test:
    schema_file = os.path.join(schema_location, uniq_db, schema_name)
    print(schema_file)
    if os.path.exists(schema_file):
        with open(schema_file, 'r') as f:
            schema_content = f.read()
        df_schemas = pd.concat([df_schemas, pd.DataFrame({'db_id': [uniq_db], 'db_schema': [schema_content]})], ignore_index=True)
    else:
        print(f"Schema file not found for {uniq_db}")

/Users/ajaykallepalli/data/spider/database/concert_singer/schema.sql
/Users/ajaykallepalli/data/spider/database/pets_1/schema.sql
/Users/ajaykallepalli/data/spider/database/car_1/schema.sql
Schema file not found for car_1
/Users/ajaykallepalli/data/spider/database/flight_2/schema.sql
Schema file not found for flight_2
/Users/ajaykallepalli/data/spider/database/employee_hire_evaluation/schema.sql
/Users/ajaykallepalli/data/spider/database/cre_Doc_Template_Mgt/schema.sql
/Users/ajaykallepalli/data/spider/database/course_teach/schema.sql
/Users/ajaykallepalli/data/spider/database/museum_visit/schema.sql
/Users/ajaykallepalli/data/spider/database/wta_1/schema.sql
Schema file not found for wta_1
/Users/ajaykallepalli/data/spider/database/battle_death/schema.sql
/Users/ajaykallepalli/data/spider/database/student_transcripts_tracking/schema.sql
/Users/ajaykallepalli/data/spider/database/tvshow/schema.sql
/Users/ajaykallepalli/data/spider/database/poker_player/schema.sql
/Users/ajaykallepalli/

In [13]:
# Save the df_schemas DataFrame to a CSV file
csv_schemas_file_path = '/Users/ajaykallepalli/data/df_schemas.csv'

if not os.path.exists(csv_schemas_file_path):
    df_schemas.to_csv(csv_schemas_file_path, index=False)
    print(f"Schemas CSV file saved to {csv_schemas_file_path}")
else:
    print(f"Schemas CSV file already exists at {csv_schemas_file_path}")

# Print the shape of df_schemas
print("Shape of df_schemas:", df_schemas.shape)

Schemas CSV file already exists at /Users/ajaykallepalli/data/df_schemas.csv
Shape of df_schemas: (15, 2)


In [14]:
## Joining the df_schema to the test data.
df_test = df_test.merge(df_schemas, on='db_id', how='left')

# Check for any missing schemas
missing_schemas = df_test[df_test['db_schema'].isnull()]
if not missing_schemas.empty:
    print(f"Warning: {len(missing_schemas)} rows have missing schemas.")
    print("Unique db_ids with missing schemas:", missing_schemas['db_id'].unique())

# Verify the merge
print("Shape of df_test after merge:", df_test.shape)
print("Columns in df_test:", df_test.columns)

# Check for null values in the merged dataframe
null_counts = df_test.isnull().sum()
print("Null value counts:\n", null_counts)



Unique db_ids with missing schemas: ['car_1' 'flight_2' 'wta_1' 'voter_1' 'world_1']
Shape of df_test after merge: (1034, 7)
Columns in df_test: Index(['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value',
       'question_toks', 'db_schema'],
      dtype='object')
Null value counts:
 db_id                    0
query                    0
question                 0
query_toks               0
query_toks_no_value      0
question_toks            0
db_schema              369
dtype: int64


In [16]:
# Remove all NA values from df_test
df_schema_test = df_test.dropna()
# Group by 'db_id' and keep only the first 20 rows for each schema
df_schema_test = df_schema_test.groupby('db_id').apply(lambda x: x.head(20)).reset_index(drop=True)

# Print the shape of the new dataframe after keeping only 20 of each schema
print("Shape of df_schema_test after keeping 20 of each schema:", df_schema_test.shape)

# Print the number of unique schemas
print("Number of unique schemas:", df_schema_test['db_id'].nunique())

# Define the path for the CSV file
csv_schema_file_path = '/Users/ajaykallepalli/data/df_schema_test.csv'

# Save the new dataframe to the same folder
if not os.path.exists(csv_schema_file_path):
    df_schema_test.to_csv(csv_schema_file_path, index=False)
    print(f"File saved to {csv_schema_file_path}")
else:
    print(f"File already exists at {csv_schema_file_path}")

# Print the shape of the new dataframe
print("Shape of df_schema_test after removing NA values:", df_schema_test.shape)

Shape of df_schema_test after keeping 20 of each schema: (278, 7)
Number of unique schemas: 15
File saved to /Users/ajaykallepalli/data/df_schema_test.csv
Shape of df_schema_test after removing NA values: (278, 7)


  df_schema_test = df_schema_test.groupby('db_id').apply(lambda x: x.head(20)).reset_index(drop=True)


In [20]:
df_schema_test['instruction'] = 'You are an agent designed to interact with a SQL database. \n Given an input question, create a syntactically correct SQL query to provide to the user based on the below schema:\n'

In [21]:
df_schema_test.head()

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks,db_schema,instruction
0,battle_death,SELECT count(*) FROM ship WHERE disposition_of...,How many ships ended up being 'Captured'?,"[SELECT, count, (, *, ), FROM, ship, WHERE, di...","[select, count, (, *, ), from, ship, where, di...","[How, many, ships, ended, up, being, 'Captured...","PRAGMA foreign_keys = ON;\nCREATE TABLE ""battl...",You are an agent designed to interact with a S...
1,battle_death,"SELECT name , tonnage FROM ship ORDER BY name...",List the name and tonnage ordered by in descen...,"[SELECT, name, ,, tonnage, FROM, ship, ORDER, ...","[select, name, ,, tonnage, from, ship, order, ...","[List, the, name, and, tonnage, ordered, by, i...","PRAGMA foreign_keys = ON;\nCREATE TABLE ""battl...",You are an agent designed to interact with a S...
2,battle_death,"SELECT name , date FROM battle","List the name, date and result of each battle.","[SELECT, name, ,, date, FROM, battle]","[select, name, ,, date, from, battle]","[List, the, name, ,, date, and, result, of, ea...","PRAGMA foreign_keys = ON;\nCREATE TABLE ""battl...",You are an agent designed to interact with a S...
3,battle_death,"SELECT max(killed) , min(killed) FROM death",What is maximum and minimum death toll caused ...,"[SELECT, max, (, killed, ), ,, min, (, killed,...","[select, max, (, killed, ), ,, min, (, killed,...","[What, is, maximum, and, minimum, death, toll,...","PRAGMA foreign_keys = ON;\nCREATE TABLE ""battl...",You are an agent designed to interact with a S...
4,battle_death,SELECT avg(injured) FROM death,What is the average number of injuries caused ...,"[SELECT, avg, (, injured, ), FROM, death]","[select, avg, (, injured, ), from, death]","[What, is, the, average, number, of, injuries,...","PRAGMA foreign_keys = ON;\nCREATE TABLE ""battl...",You are an agent designed to interact with a S...


Creating a List of input questions for the llm using the format of system instruction, schema, and Answer:

In [None]:
SQL_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}
{}

### Input:
{}

### Response: """

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    schema = examples["db_schema"]
    inputs = examples["question"]
    outputs = examples["query"]
    texts = []
    for instruction, schema, input, output in zip(instructions, schema, inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = SQL_prompt.format(instruction, schema[0:15000], input, output)
        texts.append(text)
    return { "text" : texts, }
pass
dataset = Dataset.from_pandas(df_schema_test)
dataset = dataset.map(formatting_prompts_func, batched = True,)