In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoConfig, AutoTokenizer, T5Tokenizer, Trainer, TrainingArguments, PreTrainedTokenizerFast, Seq2SeqTrainer, Seq2SeqTrainingArguments, convert_slow_tokenizer
from utils import filter_function, preprocess_function, tokenize, create_metrics_computer
import torch
import wandb

In [None]:
# Only the architecture hyper-parameters are taken from the published checkpoint.
config = AutoConfig.from_pretrained(pretrained_model_name_or_path="google/t5-efficient-tiny")
# Build the model from scratch using the configuration alone.
model = AutoModelForSeq2SeqLM.from_config(config=config)

In [None]:
# Wrap the project's SentencePiece model file in a fast tokenizer:
# slow T5Tokenizer -> converted `tokenizers` backend -> PreTrainedTokenizerFast.
slow_tokenizer = T5Tokenizer("tokenizers/sp_8k_bpe_1.model", legacy=False)
fast_backend = convert_slow_tokenizer.convert_slow_tokenizer(slow_tokenizer)
tokenizer = PreTrainedTokenizerFast(tokenizer_object=fast_backend)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Alternative tokenizer setups, currently disabled:
# tokenizer = AutoTokenizer.from_pretrained("google/t5-efficient-tiny")
# tokenizer = T5Tokenizer(vocab_file="tokenizers/sp_16k_bpe_1.model", legacy=False, load_from_cache_file=False)
# NOTE(review): the id assigned to '[PAD]' may differ from the pad /
# decoder-start ids stored in the model config -- confirm they agree.
# Resize the embedding matrix to the tokenizer's size, pad token included.
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Sanity check: encode one sample input and show how it is split into pieces.
sample_question = "How often did germany win gold in the 1994 olympics?[SEP]name[SEP]team[SEP]country[SEP]ikhasbd"
tokens = tokenizer.encode(sample_question)
print(tokens)
# Decode every id on its own so each piece can be inspected individually.
decoded_pieces = list(map(tokenizer.decode, tokens))
print(decoded_pieces)

In [None]:
def _count_trainable(module):
    """Return how many parameters of `module` have `requires_grad` set."""
    return sum(param.numel() for param in module.parameters() if param.requires_grad)


token_embedding = model.shared  # Shared token embedding layer
num_token_embedding_params = _count_trainable(token_embedding)
print(f"Number of trainable parameters in the token embedding layer: {num_token_embedding_params}")

total_params = _count_trainable(model)
print(f"Total number of trainable parameters: {total_params}")

In [None]:
# Root of the local WikiSQL copy; later cells also build the database path from it.
path = '../datasets/wikisql'
data_dir = f"{path}/data"
dataset = load_dataset(data_dir)

In [None]:
def _tokenize_batch(batch):
    """Tokenize one batch with the notebook-level tokenizer."""
    return tokenize(batch, tokenizer)


# Both passes run in batched mode with the same batch size.
map_options = {"batched": True, "batch_size": 2048}
preprocessed_dataset = dataset.map(preprocess_function, **map_options)
tokenized_dataset = preprocessed_dataset.map(_tokenize_batch, **map_options)

In [None]:
# Pull out the two splits used by the rest of the notebook.
train_data, val_data = (tokenized_dataset[split] for split in ("train", "validation"))
# Display the validation split summary as the cell output.
val_data

In [None]:
# Optional step, currently disabled: keep only the training samples accepted by `filter_function`.
# train_data = train_data.filter(lambda sample: filter_function(sample, tokenizer), batched=False)

In [None]:
def experiment(project, experiment_name, batch_size=32):
    seeds = [1337, 69, 42]
    compute_metrics = create_metrics_computer(val_data, tokenizer, path+'/tables/validation/dev.db')
    full_metrics = []
    for run in range(3):
        run_name = experiment_name + "_" + str(run+1)
        training_args = Seq2SeqTrainingArguments(
            output_dir="./results/"+run_name,
            run_name=run_name,
            report_to="wandb",
            save_strategy="epoch",
            save_total_limit=1,
            load_best_model_at_end=True,
            eval_strategy="epoch",
            num_train_epochs=25,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=512,
            # learning_rate=1e-4,
            # weight_decay=experiment[3],
            predict_with_generate=True,
            generation_max_length=48,
            generation_num_beams=5,
            seed=seeds[run],
            optim="lion_32bit"
        )

        # Trainer
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=val_data.select(range(1024)), # evaluation is slow, do it on subset
            compute_metrics=compute_metrics
        )

        # Train
        wandb.init(project=project, group=experiment_name, name=run_name)
        trainer.train()
        # Evaluate on the full dataset after training
        full_metrics.append(trainer.evaluate(eval_dataset=val_data))
        wandb.finish()

In [None]:
experiment("ablation-studies", "lion_32bit_customTokenizer")

In [None]:
# Average every metric over the runs; the first run's keys define the metric set.
num_runs = len(full_metrics)
average_metrics = {}
for metric_name in full_metrics[0]:
    metric_total = sum(run[metric_name] for run in full_metrics)
    average_metrics[metric_name] = metric_total / num_runs
average_metrics

In [None]:
# Function to log or update the summary table
def update_results_table(experiment_name, metrics, project_name="ablation-studies"):
    """Append one experiment's metrics to the shared W&B results table.

    The latest ``experiment_results`` artifact is fetched (if there is one),
    its rows are copied into a new table together with a row for this
    experiment, and the result is logged as a new artifact version.

    Args:
        experiment_name: Value stored in the ``Experiment`` column.
        metrics: Mapping of metric name to value for this experiment.
        project_name: W&B project that holds the results artifact.
    """
    artifact_name = "experiment_results"
    columns = ["Experiment"] + list(metrics.keys())
    rows = []

    try:
        previous = wandb.Api().artifact(f"{project_name}/{artifact_name}:latest")
        previous_table = previous.get("results_table")
    except Exception as error:
        # Best effort: a missing artifact (first use) or an API failure starts
        # a new table, but the reason is reported instead of silently dropped.
        print(f"No existing results table loaded ({error!r}); starting a new one.")
        previous_table = None

    if previous_table is not None:
        columns = list(previous_table.columns)
        rows = [list(row) for row in previous_table.data]
        # Metrics without a stored column get a new one; older rows are padded.
        for key in metrics:
            if key not in columns:
                columns.append(key)
                for row in rows:
                    row.append(None)

    # Fill the new row by column name so values cannot end up under the wrong header.
    rows.append([experiment_name] + [metrics.get(column) for column in columns[1:]])

    # Build a new artifact holding the updated table; the fetched one is not modified.
    artifact = wandb.Artifact(artifact_name, type="results_summary")
    artifact.add(wandb.Table(columns=columns, data=rows), "results_table")

    # Logging an artifact needs an active run; open one if none is active.
    if wandb.run is None:
        wandb.init(project=project_name, group=experiment_name, name=experiment_name + "_summary")
    wandb.log_artifact(artifact)

# `experiment_name` must already exist at notebook scope (the base name passed to `experiment(...)`).
update_results_table(experiment_name, average_metrics)
wandb.finish()

In [None]:
# manually validate model
# Uses the tokenized validation split defined above (`val_data`).
input_ids = val_data["input_ids"]
labels = val_data["labels"]

# Put the inputs on whatever device the model lives on instead of assuming CUDA.
generate_inputs = {"input_ids": torch.tensor(input_ids).to(model.device)}
# Pass the attention mask when the dataset provides one.
if "attention_mask" in val_data.column_names:
    generate_inputs["attention_mask"] = torch.tensor(val_data["attention_mask"]).to(model.device)

# Run the model to generate predictions
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation
    predictions = model.generate(**generate_inputs)

print(predictions, labels)

In [None]:
def _decode_all(sequences):
    """Decode each id sequence to text, dropping special tokens."""
    return [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in sequences]


# Turn inputs, model outputs and reference labels back into text.
# NOTE(review): assumes `labels` holds only valid token ids -- confirm against `tokenize`.
input_text = _decode_all(input_ids)
predictions_text = _decode_all(predictions)
labels_text = _decode_all(labels)
for decoded in (input_text, predictions_text, labels_text):
    print(decoded)

In [None]:
# Log one table row per example: input, prediction, reference, exact-match flag.
# NOTE(review): the run opened here is not finished in this cell.
wandb.init(project="ablation-studies", name="predictions_table")
# Initialize the wandb.Table
table = wandb.Table(columns=["Input", "Prediction", "Correct Output", "Match"])

# Add rows to the table
for idx, (inp, pred, correct) in enumerate(zip(input_text, predictions_text, labels_text)):
    match = pred == correct
    # `idx` is the row position; it was previously undefined and raised NameError.
    print(f"Adding row: {idx}, {pred}, {correct}, {match}")  # Debugging
    table.add_data(inp, pred, correct, match)

# Log the table
wandb.log({"Predictions Table": table})


In [None]:
# Checkpoint directory of a saved run; the evaluation output path below is derived from it.
checkpoint_dir = 'results/lion_32bit_bs16_3/checkpoint-3523'

# Replace the notebook-level `model` with the weights stored in that checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=checkpoint_dir)

In [None]:
# Build a trainer around the reloaded checkpoint so it can be evaluated
# with the same generation settings as the training runs.
training_args = Seq2SeqTrainingArguments(
    output_dir="./" + checkpoint_dir + "/eval",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    eval_strategy="epoch",
    num_train_epochs=25,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=512,
    predict_with_generate=True,
    generation_max_length=48,
    generation_num_beams=5,
    optim="lion_32bit"
)

# `val_data` / `train_data` are the tokenized splits defined earlier in the
# notebook; the `tokenized_*` names used here before were never defined.
compute_metrics = create_metrics_computer(val_data, tokenizer, path+'/tables/validation/dev.db')

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics
)

In [None]:
# Evaluate the reloaded checkpoint on the trainer's eval dataset and show the metrics.
checkpoint_metrics = trainer.evaluate()
checkpoint_metrics