In [None]:
import json
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, PreTrainedTokenizerFast, convert_slow_tokenizer
from utils import preprocess_function, tokenize, encode_rare_chars, create_metrics_computer
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch
import wandb

In [None]:
# Directory of the fine-tuned checkpoint under evaluation. The same path is
# reused further down to derive the evaluation output directory.
checkpoint_path = "../models/4_heads_2e-4_lr_constant_512MappingTokenizer_128_bs_64_dff_32_kv_128d_1"

# Load the seq2seq model weights and config from the checkpoint directory.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)

In [None]:
mapping_file_path = 'mapping.json'
reverse_mapping_file_path = 'reverse_mapping.json'


def _load_json(json_path):
    """Read a UTF-8 encoded JSON file and return the parsed object."""
    with open(json_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# `mapping` is handed to encode_rare_chars during preprocessing;
# `reverse_mapping` is handed to post_processing when decoding model output.
mapping = _load_json(mapping_file_path)
reverse_mapping = _load_json(reverse_mapping_file_path)

In [None]:
# Step 1: instantiate the slow T5 tokenizer from the local vocabulary file.
# NOTE(review): `load_from_cache_file` looks like a `datasets.map` option
# rather than a tokenizer argument - confirm whether it has any effect here.
slow_tokenizer = T5Tokenizer(
    "tokenizers/sp_512_bpe_encoded.model",
    legacy=False,
    load_from_cache_file=False,
)

# Step 2: convert it to a `tokenizers` backend and wrap that in the fast API.
fast_backend = convert_slow_tokenizer.convert_slow_tokenizer(slow_tokenizer)
tokenizer = PreTrainedTokenizerFast(tokenizer_object=fast_backend)

# Step 3: register the padding token, then make sure the embedding matrix
# has one row per tokenizer entry.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Root of the local WikiSQL copy; also used later to locate the validation DB.
path = '../datasets/wikisql'

# The dataset files live in the `data` sub-directory of that root.
data_dir = path + '/data'
dataset = load_dataset(data_dir)

In [None]:
# Stage 1: project-specific preprocessing of the raw examples.
parsed_dataset = dataset.map(
    preprocess_function,
    batched=True,
    batch_size=2048,
)

# Stage 2: apply the character mapping loaded from mapping.json.
preprocessed_dataset = parsed_dataset.map(
    lambda batch: encode_rare_chars(batch, mapping),
    batched=True,
    batch_size=2048,
)

# Stage 3: tokenize with the tokenizer built above.
tokenized_dataset = preprocessed_dataset.map(
    lambda batch: tokenize(batch, tokenizer),
    batched=True,
    batch_size=2048,
)

In [None]:
# Pick the two splits used by the trainer; the bare expression on the last
# line displays the training split summary as the cell output.
train_data, val_data = (
    tokenized_dataset["train"],
    tokenized_dataset["validation"],
)
train_data

In [None]:
# Evaluation artefacts are written to an `eval` folder under the checkpoint path.
eval_output_dir = "./" + checkpoint_path + "/eval"

# NOTE(review): only `trainer.evaluate()` is called in the visible cells, so
# the training-only settings (epochs, optimizer, train batch size) are
# presumably carried over from the training run - confirm.
training_kwargs = dict(
    output_dir=eval_output_dir,
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    eval_strategy="epoch",
    num_train_epochs=50,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=512,
    predict_with_generate=True,
    generation_max_length=64,
    generation_num_beams=1,
    optim="lion_32bit",
)
training_args = Seq2SeqTrainingArguments(**training_kwargs)

# Metric callback built from the validation split, the tokenizer and the
# validation database file.
compute_metrics = create_metrics_computer(val_data, tokenizer, path+'/tables/validation/dev.db')

# Trainer
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)
trainer = Seq2SeqTrainer(**trainer_kwargs)

In [None]:
# Run evaluation on the validation split and show the returned metrics as the
# cell output.
eval_metrics = trainer.evaluate()
eval_metrics

In [None]:
def canonical_to_human_readable(canonical_form, agg_mapping=None, cond_mapping=None):
    """
    Convert canonical SQL form to a human-readable SQL string.

    The result is a display string of the shape
    ``SELECT [AGG(]col[)] [WHERE c1 op v1 AND c2 op v2 ...]``. It has no FROM
    clause and condition values are inserted verbatim (no quoting or
    escaping), so it is meant for reading and comparing, not for execution.

    :param canonical_form: Mapping with the keys "sel" (selected column),
        "agg" (aggregation ID) and "conds" (iterable of
        ``(column, operator_id, value)`` triples).
    :param agg_mapping: Optional dictionary mapping aggregation names to IDs.
        Defaults to ``{"": 0, "MAX": 1, "MIN": 2, "COUNT": 3, "SUM": 4, "AVG": 5}``.
    :param cond_mapping: Optional dictionary mapping condition operators to
        IDs. Defaults to ``{'=': 0, '>': 1, '<': 2}``.
    :return: Human-readable SQL query string.
    :raises KeyError: If "sel", "agg" or "conds" is missing from
        ``canonical_form``.
    """
    # Build the default tables per call so a caller can never mutate a shared
    # default object.
    if agg_mapping is None:
        agg_mapping = {
            "": 0,
            "MAX": 1,
            "MIN": 2,
            "COUNT": 3,
            "SUM": 4,
            "AVG": 5
        }
    if cond_mapping is None:
        cond_mapping = {'=': 0, '>': 1, '<': 2}

    # Reverse the mappings for easier lookup (ID -> symbol)
    rev_agg_mapping = {v: k for k, v in agg_mapping.items()}
    rev_cond_mapping = {v: k for k, v in cond_mapping.items()}

    # Extract the selected column and aggregation type; an unknown
    # aggregation ID falls back to "no aggregation".
    selected_column = canonical_form["sel"]
    aggregation = rev_agg_mapping.get(canonical_form["agg"], "")

    # Formulate the SELECT clause
    if aggregation:
        select_clause = f"SELECT {aggregation}({selected_column})"
    else:
        select_clause = f"SELECT {selected_column}"

    # Process conditions; an unknown operator ID falls back to "=".
    conditions = [
        f"{col} {rev_cond_mapping.get(op_id, '=')} {value}"
        for col, op_id, value in canonical_form["conds"]
    ]

    # Formulate the WHERE clause if conditions exist
    where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

    # Combine SELECT and WHERE clauses
    return select_clause + where_clause

In [None]:
# manually validate model
input_ids = val_data["input_ids"]
labels = val_data["labels"]
tables = val_data["table"]

# Run the model to generate predictions
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation
    # Send the inputs to whatever device the model currently lives on instead
    # of hard-coding CUDA: this works on CPU-only machines and avoids a
    # model/input device mismatch.
    # Assumes every row of `input_ids` has the same length (already padded),
    # otherwise torch.tensor cannot build a rectangular tensor - TODO confirm.
    # NOTE(review): no attention_mask is passed and the whole validation split
    # is generated in a single call; the beam/max_length settings also differ
    # from the trainer's generation settings - confirm this is intended.
    predictions = model.generate(
        input_ids=torch.tensor(input_ids).to(model.device),
        num_beams=5,
        max_length=128,
    )

# Show only a small sample; dumping the full label list floods the output.
print(predictions[:3], labels[:3])

In [None]:
def post_processing(query, table_header, reverse_mapping):
    """
    Turn a decoded query string into a cleaned, human-readable SQL string.

    The query is first parsed into canonical form with
    ``parse_sql_to_canonical`` and then rendered back to text with
    ``canonical_to_human_readable``.

    :param query: Decoded query text.
    :param table_header: Header of the table the query refers to.
    :param reverse_mapping: Mapping forwarded unchanged to the parser.
    :return: Human-readable SQL query string.
    """
    # NOTE(review): `parse_sql_to_canonical` is neither imported nor defined
    # in the visible cells - confirm where it comes from before Run All.
    return canonical_to_human_readable(
        parse_sql_to_canonical(query, table_header, reverse_mapping)
    )

In [None]:
# Decode predictions and labels
def _decode(token_ids):
    """Decode one token-id sequence to text, dropping special tokens."""
    return tokenizer.decode(token_ids, skip_special_tokens=True)


# Inputs and predictions are decoded and then passed through post_processing
# together with the header of their table.
input_text = [
    post_processing(_decode(ids), tbl["header"], reverse_mapping)
    for ids, tbl in zip(input_ids, tables)
]
predictions_text = [
    post_processing(_decode(ids), tbl["header"], reverse_mapping)
    for ids, tbl in zip(predictions, tables)
]

# NOTE(review): labels are only decoded, not post-processed, while a later
# cell compares them to the post-processed predictions with `==` - confirm
# that both sides really share the same textual format.
labels_text = [_decode(ids) for ids in labels]

for decoded in (input_text, predictions_text, labels_text):
    print(decoded)

In [None]:
wandb.init(project="ablation-studies2", name="wrong predictions")

# Table that collects only the examples the model got wrong
table = wandb.Table(columns=["Input", "Prediction", "Correct Output"])

# Walk the three parallel lists and keep every row whose prediction differs
# from the expected output (exact string comparison).
for inp, pred, correct in zip(input_text, predictions_text, labels_text):
    match = pred == correct
    if not match:
        print(f"Adding row: {inp}, {pred}, {correct}")  # Debugging
        table.add_data(inp, pred, correct)

# Upload the collected rows to the active run
wandb.log({"Predictions Table": table})
