In [94]:
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, PreTrainedTokenizerFast, convert_slow_tokenizer
from datasets import load_dataset
from utils import preprocess_function, parse_sql_to_canonical, tokenize
import torch
import json
from lib.dbengine import DBEngine
from lib.query import Query
import os

In [76]:
# Path to checkpoint folder (a relative path, so it is resolved against the
# notebook's current working directory)
checkpoint_path = "../models/2_heads_2e-4_lr_constant_512MappingTokenizer_128_bs_32_dff_1"

# Load the model; `model` is the module-level object that `inference` calls
# `.eval()` and `.generate()` on further down.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)

In [77]:
# Build the tokenizer in three explicit steps: load the local SentencePiece model
# with the slow T5Tokenizer, run it through convert_slow_tokenizer, and wrap the
# result in a PreTrainedTokenizerFast.
# NOTE(review): `load_from_cache_file` looks like a `datasets.map` option rather
# than a tokenizer argument — confirm it has any effect here.
slow_tokenizer = T5Tokenizer(
    "tokenizers/sp_512_bpe_encoded.model",
    legacy=False,
    load_from_cache_file=False,
)
converted_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(slow_tokenizer)
tokenizer = PreTrainedTokenizerFast(tokenizer_object=converted_tokenizer)

# Register the padding token. This stays the last expression of the cell so its
# return value is shown as the cell output.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

1

In [78]:
# Root folder of the local WikiSQL copy. `path` is reused later to locate the
# SQLite database, so it is kept as its own variable.
path = '../datasets/wikisql'
dataset = load_dataset(f"{path}/data")

In [176]:
# Position of the example to inspect within the test split.
# TODO find smart way to choose random samples
index = 1338

# Keep the selection as a one-row Dataset (not a plain dict) so that the
# `.map(...)` calls in the following cells can be applied to it.
test_split = dataset["test"]
sample = test_split.select(range(index, index + 1))
print(sample)

Dataset({
    features: ['phase', 'question', 'table', 'sql'],
    num_rows: 1
})


In [177]:
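# TODO: optionally let users overwrite the question here so they can try custom
# questions on the same table, and display the table so they can see its columns.
# NOTE(review): the sketch below probably won't work as written on a `datasets.Dataset`
# (no attribute access / in-place assignment) — verify, and consider `sample.map(...)` instead.
# sample.question[0] = "Custom question about the same table"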

In [178]:
# preprocess_function concatenates each question with its table headers using the
# custom [SEP] token.
preprocessed_sample = sample.map(preprocess_function, batched=True)

# Show the model input text followed by the target text.
for text_column in ("input_text", "label_text"):
    print(preprocessed_sample[text_column])

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

['What theme name has the original artist of Dolly Parton?[SEP]Week #[SEP]Theme[SEP]Song choice[SEP]Original artist[SEP]Order #[SEP]Result']
['SELECT Theme FROM table WHERE Original artist = Dolly Parton']


In [179]:
# Rare tokens are encoded through `mapping` because the tokenizer doesn't know them.
# `reverse_mapping` is loaded alongside it and handed to `post_processing` later on.
# NOTE(review): `encode_rare_chars` is not among the names imported in the import
# cell at the top of this notebook — presumably it lives in `utils`; confirm and
# import it so this cell runs on a fresh kernel.
mapping_file_path = 'mapping.json'
reverse_mapping_file_path = 'reverse_mapping.json'

with open(mapping_file_path, 'r', encoding='utf-8') as mapping_file, \
        open(reverse_mapping_file_path, 'r', encoding='utf-8') as reverse_mapping_file:
    mapping = json.load(mapping_file)
    reverse_mapping = json.load(reverse_mapping_file)

encoded_preprocessed_sample = preprocessed_sample.map(
    lambda batch: encode_rare_chars(batch, mapping),
    batched=True,
)

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [180]:
# Tokenize the encoded sample with the project's `tokenize` helper and the
# tokenizer built above, then show the resulting ids.
tokenized_sample = encoded_preprocessed_sample.map(
    lambda batch: tokenize(batch, tokenizer),
    batched=True,
)
print(tokenized_sample["input_ids"])

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

[[64, 21, 345, 333, 103, 74, 168, 21, 332, 14, 296, 332, 174, 252, 51, 129, 94, 341, 356, 68, 174, 10, 372, 3, 349, 246, 270, 3, 353, 16, 345, 333, 3, 351, 10, 355, 69, 342, 336, 55, 333, 3, 361, 339, 296, 332, 174, 252, 3, 361, 339, 346, 8, 270, 3, 248, 2, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511]]


In [181]:
def inference(input_ids, num_beams: int = 10, max_length: int = 128) -> str:
    """Generate text for already-tokenized input with the module-level model.

    Args:
        input_ids: Batch of token-id sequences, e.g. ``tokenized_sample["input_ids"]``.
            Nested lists and existing tensors are both accepted.
        num_beams: Beam width forwarded to ``model.generate`` (default 10).
        max_length: Maximum length forwarded to ``model.generate`` (default 128).

    Returns:
        The decoded text of the first sequence returned by ``model.generate``,
        with its first token dropped and special tokens skipped.
    """
    model.eval()  # Set the model to evaluation mode
    # as_tensor converts lists and passes existing tensors through unchanged
    batch = torch.as_tensor(input_ids)
    with torch.no_grad():  # Disable gradient computation
        outputs = model.generate(input_ids=batch, num_beams=num_beams, max_length=max_length)
    # The leading token of the generated sequence is not removed by
    # skip_special_tokens, so it is cut off manually.
    # NOTE(review): presumably this is the decoder start token — confirm.
    return tokenizer.decode(token_ids=outputs[0][1:], skip_special_tokens=True)

In [182]:
def post_processing(query, table_header, reverse_mapping):
    """Turn a query string into its canonical form.

    Args:
        query: Query text to convert (model output or a reference query).
        table_header: Header of the table the query refers to.
        reverse_mapping: Mapping forwarded to ``parse_sql_to_canonical``.

    Returns:
        A ``(canonical, human_readable)`` pair. ``human_readable`` is always
        ``None`` for now because the conversion function does not exist yet.
    """
    canonical_form = parse_sql_to_canonical(query, table_header, reverse_mapping)
    # TODO: produce this with make_canonical_human_readable once that function
    # exists; until then the second element is a None placeholder.
    readable_form = None
    return canonical_form, readable_form

In [183]:
# Pull the repeated lookups into names. This also avoids reusing double quotes
# inside a double-quoted f-string, which is a SyntaxError before Python 3.12.
table_header = sample["table"][0]["header"]
gold_sql = sample["sql"][0]["human_readable"]

pred_canonical, pred_human_readable = post_processing(
    inference(tokenized_sample["input_ids"]), table_header, reverse_mapping
)
# The reference query goes through the same post-processing to obtain its canonical
# form: the canonical form stored in the dataset is in a format the db_engine
# does not support.
correct_canonical, correct_human_readable = post_processing(gold_sql, table_header, reverse_mapping)

# post_processing currently returns None for the human-readable form, so the
# first line prints "None".
print(f"Model output: {pred_human_readable}")
print(f"Correct query: {gold_sql}")

Model output: None
Correct query: SELECT Theme FROM table WHERE Original artist = Dolly Parton


In [184]:
# Show the predicted and reference canonical forms on consecutive lines for comparison.
print(
    f"Model output: {pred_canonical}",
    f"Correct query: {correct_canonical}",
    sep="\n",
)

Model output: {'sel': 1, 'agg': 0, 'conds': {(3, 0, 'Dolly Parton')}}
Correct query: {'sel': 1, 'agg': 0, 'conds': {(3, 0, 'Dolly Parton')}}


In [185]:
def _execute_or_report(engine, table, query):
    """Run `query` on `table` via `engine`; return the result, or an error string if it raises."""
    try:
        return engine.execute_query(table, query)
    except Exception as error:
        # Keep the failure as text so the prediction and the reference can still be
        # printed side by side.
        return f"Execution error: {error}"


# Locate the test database and fail early with a clear message if it is missing.
db_path = os.path.abspath(f"{path}/tables/test/test.db")
if not os.path.exists(db_path):
    raise FileNotFoundError(f"Database file not found: {db_path}")
db_engine = DBEngine(db_path)

# Build both queries from their canonical dicts and run them against the sample's table.
table_id = sample["table"][0]["id"]
pred_query = Query.from_dict(pred_canonical)
correct_query = Query.from_dict(correct_canonical)

pred_result = _execute_or_report(db_engine, table_id, pred_query)
gold_result = _execute_or_report(db_engine, table_id, correct_query)

print(f"Query result: {pred_result}")
print(f"Correct result: {gold_result}")

Query result: ['dolly parton']
Correct result: ['dolly parton']
