In [1]:
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, PreTrainedTokenizerFast, convert_slow_tokenizer
from datasets import load_dataset
from utils import preprocess_function, parse_sql_to_canonical, tokenize, encode_rare_chars
import torch
import json
from lib.dbengine import DBEngine
from lib.query import Query
import os

In [2]:
# Relative path to the trained checkpoint folder that the model is loaded from.
checkpoint_path = "../models/4_heads_2e-4_lr_constant_512MappingTokenizer_128_bs_64_dff_32_kv_128d_1"

# Load the seq2seq model from the checkpoint; `model` is read as a notebook-level
# global by `inference` further down.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)

In [3]:
# Build the fast tokenizer in explicit steps: load the SentencePiece model as a
# slow T5 tokenizer, convert it, then wrap the result in PreTrainedTokenizerFast.
# (The former `load_from_cache_file=False` argument was dropped: it is a
# `datasets.map` option, not a tokenizer argument.)
slow_tokenizer = T5Tokenizer("tokenizers/sp_512_bpe_encoded.model", legacy=False)
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=convert_slow_tokenizer.convert_slow_tokenizer(slow_tokenizer)
)
# Register the padding token on the fast tokenizer; the call's return value
# (number of tokens added) is displayed as the cell output.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

1

In [4]:
# Root folder of the local WikiSQL copy; the test database further down is
# located relative to this same folder.
path = '../datasets/wikisql'
dataset = load_dataset(f"{path}/data")

In [5]:
# Position (within the test split) of the single example to run through the pipeline.
index = 345  # TODO find smart way to choose random samples
# Select a one-element range so `sample` is a Dataset object that the
# `.map` calls in the following cells can be applied to.
sample = dataset["test"].select(range(index, index+1))
print(sample)

Dataset({
    features: ['phase', 'question', 'table', 'sql'],
    num_rows: 1
})


In [6]:
# TODO: alternatively, overwrite the question here so people can try custom questions on the same table.
# TODO: display the table.
# sample.question[0] = "Custom question about the same table"

In [7]:
# Build the model input: `preprocess_function` concatenates the question with the
# table headers using the custom [SEP] token.
preprocessed_sample = sample.map(preprocess_function, batched=True)
# Show the resulting input text and the target text for the selected example.
print(preprocessed_sample["input_text"])
print(preprocessed_sample["label_text"])

[" what's the\xa0notes\xa0where\xa0number range\xa0is 17[SEP]Number Range[SEP]Introduced[SEP]Builder[SEP]Engine[SEP]Weight (long tons)[SEP]Seats[SEP]Withdrawn[SEP]Notes"]
['SELECT Notes FROM table WHERE Number Range = 17']


In [8]:
# The tokenizer does not know rare characters, so they are encoded via a mapping
# before tokenization. Both JSON tables are loaded here; the reverse table is
# passed to post-processing later in the notebook.
mapping_file_path = 'mapping.json'
reverse_mapping_file_path = 'reverse_mapping.json'

with open(mapping_file_path, 'r', encoding='utf-8') as mapping_file:
    mapping = json.loads(mapping_file.read())

with open(reverse_mapping_file_path, 'r', encoding='utf-8') as reverse_mapping_file:
    reverse_mapping = json.loads(reverse_mapping_file.read())

# The lambda argument is named `batch` so it does not shadow the notebook-level `sample`.
encoded_preprocessed_sample = preprocessed_sample.map(
    lambda batch: encode_rare_chars(batch, mapping),
    batched=True,
)

In [9]:
# Tokenize the encoded text with the custom tokenizer and show the ids that are
# later passed to `inference`.
tokenized_sample = encoded_preprocessed_sample.map(
    lambda batch: tokenize(batch, tokenizer),
    batched=True,
)
print(tokenized_sample["input_ids"])

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

[[245, 400, 340, 21, 4, 376, 382, 364, 5, 337, 100, 44, 4, 376, 382, 364, 5, 357, 342, 229, 4, 376, 382, 364, 5, 337, 176, 130, 13, 113, 4, 376, 382, 364, 5, 22, 49, 382, 3, 371, 176, 88, 13, 113, 3, 388, 337, 117, 109, 348, 344, 60, 3, 389, 348, 102, 346, 8, 3, 343, 337, 355, 12, 333, 3, 349, 333, 300, 66, 341, 10, 355, 6, 10, 340, 379, 3, 351, 333, 7, 340, 3, 349, 154, 346, 312, 337, 3, 371, 100, 44, 2, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511]]


In [10]:
def inference(input_ids) -> str:
    """Generate text for tokenized input with the notebook-level model and tokenizer.

    Args:
        input_ids: Token ids in any form `torch.tensor` accepts; the notebook
            passes a list containing one list of ids per example.

    Returns:
        The decoded text of the first sequence returned by `model.generate`.
    """
    model.eval()  # evaluation mode for generation
    input_tensor = torch.tensor(input_ids)
    with torch.no_grad():  # no gradients are needed for generation
        generated = model.generate(input_ids=input_tensor, num_beams=10, max_length=128)
    # The beginning-of-sentence token is not removed by `skip_special_tokens`,
    # so the first position is cut off manually before decoding.
    first_sequence = generated[0]
    return tokenizer.decode(token_ids=first_sequence[1:], skip_special_tokens=True)

In [11]:
def post_processing(query, table_header, reverse_mapping):
    """Convert a query string into canonical form via `parse_sql_to_canonical`.

    Args:
        query: Query text to parse (model output or the reference query).
        table_header: Header of the table the query refers to.
        reverse_mapping: Reverse mapping table passed through to the parser.

    Returns:
        A `(canonical, human_readable)` tuple. The human-readable slot is
        currently always `None`.
    """
    canonical = parse_sql_to_canonical(query, table_header, reverse_mapping)
    # TODO: fill the second slot using a make_canonical_human_readable function
    # once it exists; until then a None placeholder is returned.
    return canonical, None

In [12]:
# Look up the shared inputs once. Hoisting them also keeps double-quoted
# subscripts out of the f-strings below, which is a SyntaxError before Python 3.12.
table_header = sample["table"][0]["header"]
gold_sql = sample["sql"][0]["human_readable"]

# Run the model and convert its output into canonical form.
pred_canonical, pred_human_readable = post_processing(
    inference(tokenized_sample["input_ids"]), table_header, reverse_mapping
)
# The reference query goes through the same post-processing because the canonical
# form stored in the data is in a format the db engine does not support.
correct_canonical, correct_human_readable = post_processing(gold_sql, table_header, reverse_mapping)

# NOTE: `pred_human_readable` is whatever `post_processing` returns in its second
# slot, which is currently a None placeholder.
print(f"Model output: {pred_human_readable}")
print(f"Correct query: {gold_sql}")

Model output: None
Correct query: SELECT Notes FROM table WHERE Number Range = 17


In [13]:
# Show the canonical forms of the predicted and the reference query side by side.
print(f"Model output: {pred_canonical}")
print(f"Correct query: {correct_canonical}")

Model output: {'sel': 7, 'agg': 0, 'conds': {(0, 0, '17')}}
Correct query: {'sel': 7, 'agg': 0, 'conds': {(0, 0, '17')}}


In [14]:
# Locate the test database and fail early with a clear message if it is missing.
db_path = os.path.abspath(f"{path}/tables/test/test.db")
if not os.path.exists(db_path):
    raise FileNotFoundError(f"Database file not found: {db_path}")
db_engine = DBEngine(db_path)


def _execute_or_report(engine, table, query):
    """Run `query` against `table`; return the result or an error description string."""
    try:
        return engine.execute_query(table, query)
    except Exception as e:
        # Best effort: a failing query is reported as text instead of stopping the cell.
        return f"Execution error: {e}"


table_id = sample["table"][0]["id"]
pred_query = Query.from_dict(pred_canonical)
correct_query = Query.from_dict(correct_canonical)

# Execute the predicted query first, then the reference query.
pred_result = _execute_or_report(db_engine, table_id, pred_query)
gold_result = _execute_or_report(db_engine, table_id, correct_query)

print(f"Query result: {pred_result}")
print(f"Correct result: {gold_result}")

Query result: ['parcels car, capacity long tons (t; short tons)']
Correct result: ['parcels car, capacity long tons (t; short tons)']
