In [1]:
from datasets import load_dataset, DatasetDict, load_from_disk, concatenate_datasets, disable_caching
# Switch off the `datasets` library's caching for this session.
# NOTE(review): presumably so that `map` results are recomputed on every run
# instead of being reloaded from cache files — confirm.
disable_caching()

In [2]:
# Directory handed to `load_dataset` as its cache location.
cache_dir = "./datasets"

# Load the "b-mc2/sql-create-context" corpus, using `cache_dir` as the cache.
dataset = load_dataset(
    "b-mc2/sql-create-context",
    cache_dir=cache_dir,
)

In [3]:
# Show the loaded DatasetDict (split names, columns and row counts).
dataset

DatasetDict({
    train: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 78577
    })
})

In [4]:
# Restrict the working set to the first 2,000 rows of the "train" split,
# then peek at the first example to see the available fields.
full_train = dataset["train"]
train_data = full_train.take(2000)
train_data[0]

{'answer': 'SELECT COUNT(*) FROM head WHERE age > 56',
 'question': 'How many heads of the departments are older than 56 ?',
 'context': 'CREATE TABLE head (age INTEGER)'}

In [5]:
# Hold out 25% of the sampled rows (fixed seed for reproducibility) ...
train_testval_split = train_data.train_test_split(test_size=0.25, seed=42)

# ... and cut that held-out quarter in half, which gives
# 75% train / 12.5% validation / 12.5% test overall.
held_out = train_testval_split["test"]
test_val_split = held_out.train_test_split(test_size=0.5, seed=42)

# Gather the three parts under explicit split names.
split_dataset = DatasetDict(
    {
        "train": train_testval_split["train"],
        "val": test_val_split["train"],
        "test": test_val_split["test"],
    }
)

In [6]:
# Confirm the names and sizes of the train/val/test splits.
split_dataset

DatasetDict({
    train: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 1500
    })
    val: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 250
    })
    test: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 250
    })
})

In [7]:
# save_path = "./datasets/sql-create-context-split"
# split_dataset.save_to_disk(save_path)

In [8]:
# split_dataset = load_from_disk("./datasets/sql-create-context-split")

In [9]:
# split_dataset

In [10]:
# Load the JSON evaluation file into a DatasetDict with a single "eval" split.
eval_files = {"eval": "./datasets/sql_eval_dataset_executable_gold.json"}
eval_dataset = load_dataset("json", data_files=eval_files)
eval_dataset

DatasetDict({
    eval: Dataset({
        features: ['question', 'answer', 'db_name', 'context', 'query_category'],
        num_rows: 250
    })
})

In [11]:
# Split the evaluation rows 60/40 with a fixed seed; the result exposes the
# two parts under the keys "train" and "test".
eval_rows = eval_dataset["eval"]
eval_train_test = eval_rows.train_test_split(test_size=0.4, seed=42)
eval_train_test

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'db_name', 'context', 'query_category'],
        num_rows: 150
    })
    test: Dataset({
        features: ['question', 'answer', 'db_name', 'context', 'query_category'],
        num_rows: 100
    })
})

In [12]:
# Store a trimmed copy of the eval "train" part that keeps only the three
# columns the prompt builder reads (answer, question, context).
prompt_columns = ["answer", "question", "context"]
eval_train_test["train_append"] = eval_train_test["train"].select_columns(prompt_columns)
eval_train_test

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'db_name', 'context', 'query_category'],
        num_rows: 150
    })
    test: Dataset({
        features: ['question', 'answer', 'db_name', 'context', 'query_category'],
        num_rows: 100
    })
    train_append: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 150
    })
})

In [13]:
def prepare_dataset_for_training(dataset, prompt_file):
    """Render every row of ``dataset`` into a single training prompt.

    The template read from ``prompt_file`` is filled in with ``str.format``
    using three placeholders: ``{user_question}``, ``{table_metadata_string}``
    and ``{sql}``. The SQL answer is guaranteed to end with exactly one
    trailing ``;`` and the rendered prompt is stripped of surrounding
    whitespace.

    Args:
        dataset: Object exposing ``features`` (a mapping keyed by column name)
            and ``map(function, remove_columns=...)``. Each row must provide
            the ``question``, ``context`` and ``answer`` fields.
        prompt_file: Path to the prompt template file (read as UTF-8).

    Returns:
        The result of ``dataset.map``: rows carrying the rendered prompt in a
        ``text`` field, with all original columns requested for removal.

    Raises:
        OSError: If ``prompt_file`` cannot be opened.
        KeyError: If the template contains a placeholder other than the three
            listed above, or a row lacks one of the required fields.
    """
    # Decode explicitly as UTF-8 so the template reads the same regardless of
    # the platform's locale default encoding.
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt = f.read()

    # Materialise the column names as a plain list (the documented type for
    # `remove_columns`) instead of passing a live dict-keys view.
    columns = list(dataset.features)

    def preprocess_function(sample):
        # Trim trailing whitespace before checking for the terminator so an
        # answer such as "SELECT 1;\n" does not end up as "SELECT 1;\n;".
        sql = sample["answer"].rstrip()
        if not sql.endswith(";"):
            sql += ";"

        sample["text"] = prompt.format(
            user_question=sample["question"],
            table_metadata_string=sample["context"],
            sql=sql,
        ).strip()

        return sample

    train_dataset = dataset.map(
        preprocess_function,
        remove_columns=columns,
    )
    return train_dataset

In [14]:
# Render the trimmed evaluation rows with the PostgreSQL prompt template.
eval_train = prepare_dataset_for_training(
    eval_train_test["train_append"],
    prompt_file="./prompts/prompt_v4_postgres_train.md",
)
eval_train

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Dataset({
    features: ['text'],
    num_rows: 150
})

In [15]:
# Spot-check the first prompt rendered from the evaluation rows.
eval_train[0]

{'text': '### Task\nGenerate a SQL query to answer [QUESTION]What is the total number of credits earned by students in each program?[/QUESTION]\n\n### Instructions\n- Use PostgreSQL Syntax\n- End the SQL query with ";"\n\n### Database Schema\nThe query will run on a database with the following schema:\nCREATE TABLE public.area (course_id BIGINT, area TEXT);\nCREATE TABLE public.comment_instructor (instructor_id BIGINT DEFAULT \'0\'::BIGINT NOT NULL, student_id BIGINT DEFAULT \'0\'::BIGINT NOT NULL, score BIGINT, comment_text TEXT);\nCREATE TABLE public.course (course_id BIGINT DEFAULT \'0\'::BIGINT NOT NULL, name TEXT, department TEXT, number TEXT, credits TEXT, advisory_requirement TEXT, enforced_requirement TEXT, description TEXT, num_semesters BIGINT, num_enrolled BIGINT, has_discussion BOOLEAN, has_lab BOOLEAN, has_projects BOOLEAN, has_exams BOOLEAN, num_reviews BIGINT, clarity_score BIGINT, easiness_score BIGINT, helpfulness_score BIGINT);\nCREATE TABLE public.course_offering (of

In [16]:
# Render the corpus training rows with the generic (non-PostgreSQL) template.
split_train = prepare_dataset_for_training(
    split_dataset["train"],
    prompt_file="./prompts/prompt_v4_train.md",
)
split_train

Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Dataset({
    features: ['text'],
    num_rows: 1500
})

In [17]:
# Spot-check the first prompt rendered from the corpus training rows.
split_train[0]

{'text': '### Task\nGenerate a SQL query to answer [QUESTION]Which type of policy is most frequently used? Give me the policy type code.[/QUESTION]\n\n### Instructions\n- End the SQL query with ";"\n- Do not explain the Answer SQL\n\n### Database Schema\nThe query will run on a database with the following schema:\nCREATE TABLE policies (policy_type_code VARCHAR)\n\n### Answer\nGiven the database schema, here is the SQL query that answers [QUESTION]Which type of policy is most frequently used? Give me the policy type code.[/QUESTION]\n[SQL]SELECT policy_type_code FROM policies GROUP BY policy_type_code ORDER BY COUNT(*) DESC LIMIT 1;[/SQL]'}

In [18]:
# Final training mix: the rendered corpus rows plus ten copies of the rendered
# evaluation rows, shuffled with a fixed seed.
# NOTE: this overwrites split_dataset["train"] in place, so the cell that
# builds `split_train` from it cannot be re-run afterwards without first
# recreating `split_dataset`.
training_parts = [split_train, *([eval_train] * 10)]
merged_train = concatenate_datasets(training_parts)
split_dataset["train"] = merged_train.shuffle(seed=42)
split_dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 3000
    })
    val: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 250
    })
    test: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 250
    })
})

In [19]:
# Attach both evaluation parts (with their full set of columns) to the same
# DatasetDict as the training data.
for target_split, source_split in (("eval_train", "train"), ("eval_test", "test")):
    split_dataset[target_split] = eval_train_test[source_split]
split_dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 3000
    })
    val: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 250
    })
    test: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 250
    })
    eval_train: Dataset({
        features: ['question', 'answer', 'db_name', 'context', 'query_category'],
        num_rows: 150
    })
    eval_test: Dataset({
        features: ['question', 'answer', 'db_name', 'context', 'query_category'],
        num_rows: 100
    })
})

In [20]:
# Persist every split of the assembled DatasetDict under one directory.
save_path = "./datasets/train_merge_150x10_1500_diff_prompt_gold"
split_dataset.save_to_disk(save_path)

Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/250 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/250 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/150 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

In [21]:
# Inspect the last row of the merged, shuffled training split.
split_dataset["train"][-1]

{'text': '### Task\nGenerate a SQL query to answer [QUESTION]What is the ratio of papers that have more than 1 keyphrases to papers that have 1 keyphrase?[/QUESTION]\n\n### Instructions\n- Use PostgreSQL Syntax\n- End the SQL query with ";"\n\n### Database Schema\nThe query will run on a database with the following schema:\nCREATE TABLE public.author (authorid bigint NOT NULL, authorname text);\n\nCREATE TABLE public.cite (citingpaperid bigint NOT NULL, citedpaperid bigint NOT NULL);\n\nCREATE TABLE public.dataset (datasetid bigint NOT NULL, datasetname text);\n\nCREATE TABLE public.field (fieldid bigint);\n\nCREATE TABLE public.journal (journalid bigint NOT NULL, journalname text);\n\nCREATE TABLE public.keyphrase (keyphraseid bigint NOT NULL, keyphrasename text);\n\nCREATE TABLE public.paper (paperid bigint NOT NULL, title text, venueid bigint, year bigint, numciting bigint, numcitedby bigint, journalid bigint);\n\nCREATE TABLE public.paperdataset (paperid bigint, datasetid bigint);\