In [0]:
from datasets import load_dataset , Dataset, concatenate_datasets 
import numpy as np
import pandas as pd
import random

Create a pandas DataFrame from the Hugging Face dataset "b-mc2/sql-create-context" (https://huggingface.co/datasets/b-mc2/sql-create-context).
The dataset's creator combined the open text-to-SQL datasets WikiSQL and Spider and did some excellent preprocessing on the result.

In [0]:
# Fetch the Hugging Face dataset and wrap it in a pandas DataFrame; the next
# cell reads the 'train' column of this frame.
DATASET_ID = "b-mc2/sql-create-context"
ds = load_dataset(DATASET_ID)
dspd = pd.DataFrame(data=ds)
display(dspd)

Normalize the JSON-like records of the `train` split into tabular form, with one column per field.

In [0]:
import json

# Convert the 'train' split straight to pandas: one column per record field.
# This avoids building a DataFrame of dict rows and re-flattening it with
# json_normalize, which yields the same columns but with an extra Python pass.
dsdf_ = ds['train'].to_pandas()
display(dsdf_)

In [0]:
def split_clean_ctas(statement):
    """Split a string of ``CREATE TABLE`` statements into cleaned table schemas.

    Each ';'-separated segment has the ``CREATE TABLE`` keyword removed,
    ``VARCHAR`` rewritten to ``STRING``, and surrounding whitespace stripped.

    Args:
        statement: one or more ``CREATE TABLE ...`` statements separated by ';'.

    Returns:
        A list of schema strings such as ``'head (age INTEGER)'``. Empty
        segments are dropped *after* cleaning, so a segment that contains only
        the ``CREATE TABLE`` keyword does not produce an empty-string schema.
    """
    cleaned = (
        segment.replace("CREATE TABLE", "").replace("VARCHAR", "STRING").strip()
        for segment in statement.split(";")
    )
    return [schema for schema in cleaned if schema]

Apply this function to the `context` column of the DataFrame.

In [0]:
# Replace each raw CREATE TABLE string with its list of cleaned table schemas.
dsdf_['context'] = dsdf_['context'].map(split_clean_ctas)


In [0]:
# Inspect the frame now that 'context' holds lists of cleaned table schemas.
display(dsdf_)

In [0]:
# (rows, columns) of the cleaned dataset before context augmentation.
dsdf_.shape

The goal of the next few code cells is to build a `context` for the model that contains more than just the table relevant to the question. This teaches the model to select the schema of the most relevant table when it is given several candidate tables.

In [0]:
# Pool of every distinct table schema in the dataset; used as distractor candidates.
# dropna(): explode() turns an empty context list into NaN, which must never be
# sampled as if it were a schema.
unique_items = dsdf_['context'].explode().dropna().unique().tolist()


In [0]:
# Number of table schemas in each record's context (1 means a single-table query).
dsdf_['count'] = dsdf_['context'].map(len)


Here, for 30% of the records whose query uses a single table, we add one other, different table schema to the context. For another 20%, we add two other table schemas.

In [0]:
# Fractions of single-table records that receive one / two extra "distractor" schemas.
ONE_DISTRACTOR_FRAC = 0.3
TWO_DISTRACTOR_FRAC = 0.2
# Fixed seed so Restart & Run All reproduces the same augmented dataset.
AUGMENT_SEED = 42


def _pick_distractors(existing, candidates, k, rng):
    """Return ``k`` distinct items from ``candidates`` that are not in ``existing``.

    Uses rejection sampling, so a call costs O(k) on average instead of
    scanning the whole candidate pool. Assumes ``candidates`` has no duplicates.
    If too few eligible candidates might remain, it falls back to an exact
    filtered sample, which raises ValueError when fewer than ``k`` exist.
    """
    taken = set(existing)
    # With unique candidates, at least len(candidates) - len(taken) are eligible;
    # below k we cannot guarantee rejection sampling terminates, so filter exactly.
    if len(candidates) - len(taken) < k:
        return rng.sample([c for c in candidates if c not in taken], k)
    picks = []
    while len(picks) < k:
        choice = rng.choice(candidates)
        if choice not in taken:
            taken.add(choice)  # keeps the k picks distinct, like random.sample
            picks.append(choice)
    return picks


rng = random.Random(AUGMENT_SEED)

# Choose the augmented rows at random rather than taking a leading slice: the
# data is assembled from WikiSQL and Spider, so the first N rows may not be
# representative of the whole dataset.
single_table_idx = dsdf_[dsdf_['count'] == 1].index.tolist()
rng.shuffle(single_table_idx)

n_one = int(ONE_DISTRACTOR_FRAC * len(single_table_idx))
n_two = int(TWO_DISTRACTOR_FRAC * len(single_table_idx))

# Each 'context' cell holds a list, which is extended in place.
for i in single_table_idx[:n_one]:
    dsdf_.at[i, 'context'].extend(_pick_distractors(dsdf_.at[i, 'context'], unique_items, 1, rng))

for i in single_table_idx[n_one:n_one + n_two]:
    dsdf_.at[i, 'context'].extend(_pick_distractors(dsdf_.at[i, 'context'], unique_items, 2, rng))

# 'count' was only needed to find single-table rows. NOTE: because the lists are
# mutated and 'count' is dropped, this cell is not re-runnable on its own;
# re-run from the split step.
dsdf_ = dsdf_.drop(columns='count')


In [0]:
# Inspect contexts after distractor schemas were appended to some single-table records.
display(dsdf_)

In [0]:
# Augmentation only lengthens 'context' lists and drops the helper 'count' column,
# so the row count should match the earlier shape check.
dsdf_.shape

The next step is to shuffle the order of the table schemas in each context, so that the model cannot learn to 'cheat' by always picking the first schema in the list.

In [0]:
# Shuffle the order of the schemas inside each context. Distractors were appended
# after the original table, so without this the relevant schema would always come
# first. A dedicated seeded RNG keeps Restart & Run All reproducible.
SHUFFLE_SEED = 42
_shuffle_rng = random.Random(SHUFFLE_SEED)
dsdf_['context'] = dsdf_['context'].map(lambda tables: _shuffle_rng.sample(tables, len(tables)))
display(dsdf_)

In the next few cells, we convert each context list to a string and combine the question and context in a fixed format, so that each record can be treated as a sequence-to-sequence pair.

In [0]:
# Flatten each list of schemas into a single comma-separated string.
dsdf_['context'] = dsdf_['context'].map(lambda tables: ', '.join(str(t) for t in tables))
display(dsdf_)

In [0]:
# Model input: "<question>, schema: <comma-separated table schemas>".
dsdf_['question_formatted'] = dsdf_['question'].str.cat(dsdf_['context'], sep=', schema: ')

In [0]:
# Rename the target column to 'query' (dsdf_ itself keeps the rename), keep only
# the input/target pair, and convert it to a Spark DataFrame for persistence.
dsdf_ = dsdf_.rename(columns={'answer': 'query'})
dsdf_final = dsdf_.loc[:, ['question_formatted', 'query']]
spdf_final = spark.createDataFrame(dsdf_final)
display(spdf_final)

Write the result out to a table so that it can easily be retrieved later for fine-tuning the model.

In [0]:
# mode('overwrite'): the default save mode errors if the table already exists,
# which breaks re-running the notebook end to end. The fully qualified name
# matches the table read back by the SQL cell below.
spdf_final.write.mode('overwrite').saveAsTable('hive_metastore.default.hf_sql_dataset')

In [0]:
%sql
-- Inspect the saved question/query training table.
SELECT *
FROM hive_metastore.default.hf_sql_dataset