In [75]:
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, PreTrainedTokenizerFast, convert_slow_tokenizer
from datasets import load_dataset
from utils import preprocess_function, parse_sql_to_canonical, tokenize
import torch
import json

In [76]:
# Path to checkpoint folder
checkpoint_path = "../models/2_heads_2e-4_lr_constant_512MappingTokenizer_128_bs_32_dff_1"

# Load the model
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)

In [77]:
tokenizer = PreTrainedTokenizerFast(tokenizer_object=convert_slow_tokenizer.convert_slow_tokenizer(T5Tokenizer("tokenizers/sp_512_bpe_encoded.model", legacy=False, load_from_cache_file=False)))
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

1

In [78]:
path = '../datasets/wikisql'
dataset = load_dataset(path+'/data')

In [79]:
index = 42  # TODO find smart way to choose random samples
sample = dataset["test"].select(range(index, index+1))
print(sample)

Dataset({
    features: ['phase', 'question', 'table', 'sql'],
    num_rows: 1
})


In [80]:
# TODO
# alternatively overwrite question here if people want to try custom questions on the same table
# display table
# sample.question[0] = "Custom question about the same table"

In [81]:
preprocessed_sample = sample.map(preprocess_function, batched=True) # concatenates questions with headers using custom [SEP] token
print(preprocessed_sample["input_text"])
print(preprocessed_sample["label_text"])

['What is the premium associated with tariff code g9?[SEP]Scheme[SEP]Tariff code[SEP]BTs retail price (regulated)[SEP]Approx premium[SEP]Prefixes']
['SELECT Approx premium FROM table WHERE Tariff code = g9']


In [82]:
# Encode rare tokens because the tokenizer doesn't know them
mapping_file_path = 'mapping.json'
reverse_mapping_file_path = 'reverse_mapping.json'

with open(mapping_file_path, 'r', encoding='utf-8') as mapping_file:
    mapping = json.load(mapping_file)

with open(reverse_mapping_file_path, 'r', encoding='utf-8') as reverse_mapping_file:
    reverse_mapping = json.load(reverse_mapping_file)

encoded_preprocessed_sample = preprocessed_sample.map(lambda sample: encode_rare_chars(sample, mapping), batched=True)

In [83]:
tokenized_sample = encoded_preprocessed_sample.map(lambda sample: tokenize(sample, tokenizer), batched=True)
print(tokenized_sample["input_ids"])

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

[[64, 57, 21, 61, 73, 345, 338, 104, 332, 157, 131, 338, 50, 346, 170, 6, 24, 235, 359, 69, 219, 127, 375, 372, 3, 351, 344, 16, 345, 333, 3, 353, 24, 235, 359, 69, 219, 3, 389, 353, 340, 204, 335, 334, 102, 61, 339, 55, 333, 66, 73, 355, 105, 50, 346, 379, 3, 366, 111, 308, 395, 61, 73, 345, 338, 104, 3, 370, 73, 359, 338, 395, 44, 2, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511, 511]]


In [84]:
def inference(input_ids) -> str:
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation
        outputs = model.generate(input_ids=torch.tensor(input_ids), num_beams=10, max_length=128)
    output = tokenizer.decode(token_ids=outputs[0][1:], skip_special_tokens=True) # for some reason the beginning of sentence token doesn't get removed properly so we cut it off manually
    return output

In [85]:
def post_processing(query, table_header, reverse_mapping):
    cleaned_canonical = parse_sql_to_canonical(query, table_header, reverse_mapping)
    cleaned_query = None
    # cleaned_query = make_canonical_human_readable  TODO (we need this function; remove line above after wards)
    return cleaned_canonical, cleaned_query

In [86]:
canonical, human_readable = post_processing(inference(tokenized_sample["input_ids"]), sample["table"][0]["header"], reverse_mapping)
print(f"Model output: {human_readable}")
print(f"Correct query: {sample["sql"][0]["human_readable"]}")

# TODO run canonical result and solution on database to see if it works
print(f"Query result: ")
print(f"Correct query: ")

Model output: None
Correct query: SELECT Approx premium FROM table WHERE Tariff code = g9
Query result: 
Correct query: 
