In [1]:
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, PreTrainedTokenizerFast, convert_slow_tokenizer
from datasets import load_dataset
from utils import preprocess_function, parse_sql_to_canonical, tokenize, encode_rare_chars
import torch
import json
from lib.dbengine import DBEngine
from lib.query import Query
import os

In [2]:
# Path to the checkpoint folder of the trained model (relative to the notebook's working directory).
# NOTE(review): the folder name looks like it encodes training hyperparameters — confirm before relying on it.
checkpoint_path = "../models/4_heads_2e-4_lr_constant_512MappingTokenizer_128_bs_64_dff_32_kv_128d_1"

# Load the seq2seq model from that checkpoint folder
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)

In [3]:
# Build the slow SentencePiece-backed T5 tokenizer first, then convert it to a fast tokenizer.
# `load_from_cache_file` was dropped: it is a `datasets.Dataset.map` option, not a tokenizer
# argument, and was only being swallowed by the tokenizer's **kwargs.
slow_tokenizer = T5Tokenizer("tokenizers/sp_512_bpe_encoded.model", legacy=False)
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=convert_slow_tokenizer.convert_slow_tokenizer(slow_tokenizer)
)
# Register the padding token; the call returns the number of tokens added to the vocabulary,
# which the notebook displays as the cell output.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

1

In [4]:
# Root folder of the local WikiSQL copy; the dataset files live in its `data` subfolder
path = '../datasets/wikisql'
dataset = load_dataset(f"{path}/data")

In [5]:
# Pick a single example from the test split. `select` is given a one-element range, so
# `sample` stays a Dataset (with one row) instead of becoming a plain dict.
index = 42  # TODO find smart way to choose random samples
sample = dataset["test"].select(range(index, index+1))
print(sample)

Dataset({
    features: ['phase', 'question', 'table', 'sql'],
    num_rows: 1
})


In [6]:
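# TODO: optionally overwrite the question here so that people can try custom questions on the same table.
# TODO: display the table so the reader can see which columns and values they can ask about.
# Sketch of the intended override (not working code yet):
# sample.question[0] = "Custom question about the same table"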

In [7]:
# preprocess_function concatenates the question with the table headers using the custom [SEP] token
preprocessed_sample = sample.map(preprocess_function, batched=True)
# Show the model input text and the target (label) text side by side
for text_column in ("input_text", "label_text"):
    print(preprocessed_sample[text_column])

['What is the premium associated with tariff code g9?[SEP]Scheme[SEP]Tariff code[SEP]BTs retail price (regulated)[SEP]Approx premium[SEP]Prefixes']
['SELECT Approx premium FROM table WHERE Tariff code = g9']


In [8]:
# The tokenizer doesn't know rare tokens, so they are encoded through a mapping beforehand.
# The reverse mapping is loaded here as well and used later during post-processing.
mapping_file_path = 'mapping.json'
reverse_mapping_file_path = 'reverse_mapping.json'

with open(mapping_file_path, 'r', encoding='utf-8') as fh:
    mapping = json.loads(fh.read())

with open(reverse_mapping_file_path, 'r', encoding='utf-8') as fh:
    reverse_mapping = json.loads(fh.read())

encoded_preprocessed_sample = preprocessed_sample.map(
    lambda batch: encode_rare_chars(batch, mapping),
    batched=True,
)

In [9]:
# Turn the encoded text into token ids with the tokenizer built above
tokenized_sample = encoded_preprocessed_sample.map(
    lambda batch: tokenize(batch, tokenizer),
    batched=True,
)
print(tokenized_sample["input_ids"])

[[64, 57, 21, 61, 73, 345, 338, 104, 332, 157, 131, 338, 50, 346, 170, 6, 24, 235, 359, 69, 219, 127, 375, 372, 3, 351, 344, 16, 345, 333, 3, 353, 24, 235, 359, 69, 219, 3, 389, 353, 340, 204, 335, 334, 102, 61, 339, 55, 333, 66, 73, 355, 105, 50, 346, 379, 3, 366, 111, 308, 395, 61, 73, 345, 338, 104, 3, 370, 73, 359, 338, 395, 44, 2, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511]]


In [10]:
def inference(input_ids) -> str:
    """
    Generate a query string for a batch of token ids with the notebook-level ``model``.

    :param input_ids: Nested list of token ids (one inner list per input sequence),
        as produced by the tokenization cell.
    :return: Decoded text of the first generated sequence.
    """
    model.eval()  # evaluation mode
    batch = torch.tensor(input_ids)
    with torch.no_grad():  # no gradients are needed for generation
        generated = model.generate(input_ids=batch, num_beams=10, max_length=128)
    best_sequence = generated[0]
    # The leading beginning-of-sentence token is not removed by skip_special_tokens
    # (per the original author's observation), so it is cut off manually.
    return tokenizer.decode(token_ids=best_sequence[1:], skip_special_tokens=True)

In [14]:
def canonical_to_human_readable(canonical_form, table_header, table_name="table"):
    """
    Convert a canonical query form to a human-readable SQL string.

    The result uses the same layout as the dataset's ``human_readable`` strings,
    e.g. ``SELECT Approx premium FROM table WHERE Tariff code = g9``.

    :param canonical_form: Mapping with the keys ``"sel"`` (index of the selected
        column in ``table_header``), ``"agg"`` (aggregation id) and ``"conds"``
        (iterable of ``(column_index, operator_id, value)`` triples).
    :param table_header: Sequence of column names, indexed by the column indices
        used in ``canonical_form``.
    :param table_name: Name placed in the ``FROM`` clause (defaults to ``"table"``).
    :return: Human-readable SQL query string.
    :raises KeyError: If ``canonical_form`` lacks one of the required keys.
    :raises IndexError: If a column index is outside ``table_header``.
    """
    # Id -> keyword lookups. Aggregation id 0 means "no aggregation".
    aggregations = {0: "", 1: "MAX", 2: "MIN", 3: "COUNT", 4: "SUM", 5: "AVG"}
    operators = {0: "=", 1: ">", 2: "<"}

    # SELECT clause; unknown aggregation ids fall back to no aggregation.
    selected_column = table_header[canonical_form["sel"]]
    aggregation = aggregations.get(canonical_form["agg"], "")
    if aggregation:
        select_clause = f"SELECT {aggregation}({selected_column})"
    else:
        select_clause = f"SELECT {selected_column}"

    # Conditions may arrive as a set, whose iteration order is not stable between
    # interpreter runs; sort those so the rendered query is reproducible.
    # Ordered containers (lists/tuples) keep the order they were given in.
    conds = canonical_form["conds"]
    if isinstance(conds, (set, frozenset)):
        conds = sorted(conds, key=lambda cond: (cond[0], cond[1], str(cond[2])))

    # Unknown operator ids fall back to "=".
    conditions = [
        f"{table_header[col]} {operators.get(op_id, '=')} {value}"
        for col, op_id, value in conds
    ]

    where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

    return f"{select_clause} FROM {table_name}{where_clause}"

In [15]:
def post_processing(query, table_header, reverse_mapping):
    """
    Parse a query string into canonical form and render it back as readable SQL.

    :param query: Query string to parse (model output or a reference query).
    :param table_header: Column names of the table the query refers to.
    :param reverse_mapping: Mapping passed through to ``parse_sql_to_canonical``.
    :return: Tuple ``(canonical_form, human_readable_query)``.
    """
    canonical = parse_sql_to_canonical(query, table_header, reverse_mapping)
    readable = canonical_to_human_readable(canonical, table_header)
    return canonical, readable

In [17]:
# Look the table header and the reference query up once instead of repeating the nested indexing.
table_header = sample["table"][0]["header"]
gold_sql = sample["sql"][0]["human_readable"]

output = inference(tokenized_sample["input_ids"])
pred_canonical, pred_human_readable = post_processing(output, table_header, reverse_mapping)
# The reference query is re-parsed from its human-readable string because, per the original
# author's note, the canonical form stored in the dataset is in a format the db engine does not support.
correct_canonical, correct_human_readable = post_processing(gold_sql, table_header, reverse_mapping)
print(f"Model output: {pred_human_readable}")
# No nested double quotes inside the f-string: that is a SyntaxError before Python 3.12.
print(f"Correct query: {gold_sql}")

Model output: SELECT COUNT(Approx premium) WHERE Tariff code = G9
Correct query: SELECT Approx premium FROM table WHERE Tariff code = g9


In [18]:
# Compare the canonical forms of the predicted and the reference query
for label, canonical in (("Model output", pred_canonical), ("Correct query", correct_canonical)):
    print(f"{label}: {canonical}")

Model output: {'sel': 3, 'agg': 3, 'conds': {(1, 0, 'G9')}}
Correct query: {'sel': 3, 'agg': 0, 'conds': {(1, 0, 'g9')}}


In [19]:
def _execute_or_report(engine, table_id, query):
    """Run ``query`` on ``table_id``; return its result, or an error string if execution fails."""
    try:
        return engine.execute_query(table_id, query)
    except Exception as e:
        # Deliberate best-effort: a query the engine cannot run should be shown
        # as an error message instead of aborting the notebook.
        return f"Execution error: {e}"


# Fail early with a clear message if the test database is missing
db_path = os.path.abspath(path+'/tables/test/test.db')
if not os.path.exists(db_path):
    raise FileNotFoundError(f"Database file not found: {db_path}")
db_engine = DBEngine(db_path)

# Execute the predicted and the reference query against the sample's table
table_id = sample["table"][0]["id"]
pred_query = Query.from_dict(pred_canonical)
correct_query = Query.from_dict(correct_canonical)
pred_result = _execute_or_report(db_engine, table_id, pred_query)
gold_result = _execute_or_report(db_engine, table_id, correct_query)

print(f"Query result: {pred_result}")
print(f"Correct result: {gold_result}")

Query result: [1]
Correct result: ['4p/min']
