<a href="https://colab.research.google.com/github/YichengShen/cis5220-project/blob/main/t5_base_with_db_reg_1e-2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
%%capture
! pip install datasets
! pip install transformers
! pip install pytorch-lightning==1.5.10

Mount Drive

In [None]:
# Mount Google Drive at /content/drive so the shared-drive paths used by the
# following cells (scripts folder, spider.zip) can be read from the runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import shutil
import subprocess
import json
import nltk
import gc

In [None]:
# Fetch the NLTK "punkt" package into the runtime's nltk_data directory.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

Helper function for garbage collection

In [None]:
def garbage_collect():
    """Free unreferenced Python objects, then release cached CUDA memory.

    The Python collector runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to give its unused blocks back; in the opposite order those blocks would
    still be held when the cache is emptied.

    Returns:
        None
    """
    # Local import: this helper is defined before the notebook's main
    # `import torch` cell, so it must not depend on cell execution order.
    import torch

    gc.collect()
    torch.cuda.empty_cache()

### Load Python scripts

In [None]:
# Source folder on the shared Drive and its destination inside the Colab runtime.
scripts_path_in_drive = "/content/drive/Shareddrives/CIS 522/scripts"
scripts_path_in_runtime = "/content/scripts"

# Overrides previous scripts folder: remove any copy left by an earlier run so
# the runtime folder always mirrors what is currently on Drive.
if os.path.exists(scripts_path_in_runtime):
    shutil.rmtree(scripts_path_in_runtime)
shutil.copytree(scripts_path_in_drive, scripts_path_in_runtime)

'/content/scripts'

## Load Data

We use the Spider dataset. Hugging Face hosts this dataset, but its version does not include the database schema information. To supplement the Hugging Face version, we downloaded the original Spider dataset from Yale, which provides the schemas (`tables.json`).

### Load dataset into Colab runtime

In [None]:
# Change this path to where you store spider.zip in your Drive
dataset_zip_path_in_drive = "/content/drive/Shareddrives/CIS 522/spider.zip"
dataset_zip_path_in_runtime = "/content/data/spider.zip"

# Create the destination folder if it does not exist. It is derived from the
# target path itself, so the copy works regardless of the current working
# directory (a relative `mkdir data` only matches when cwd is /content).
os.makedirs(os.path.dirname(dataset_zip_path_in_runtime), exist_ok=True)

# Last expression of the cell: shows the path of the copied archive.
shutil.copy(dataset_zip_path_in_drive, dataset_zip_path_in_runtime)

'/content/data/spider.zip'

Unzip

In [None]:
!unzip -q -o /content/data/spider.zip -d /content/data/

### Load dataset from Huggingface

In [None]:
from datasets import load_dataset

import pandas as pd
from sklearn.model_selection import train_test_split

In [None]:
# Load the Hugging Face copy of Spider; the repr below lists its splits,
# features and row counts.
dataset = load_dataset("spider")
dataset

Downloading builder script:   0%|          | 0.00/3.94k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/4.69k [00:00<?, ?B/s]

Downloading and preparing dataset spider/spider to /root/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa...


Downloading data:   0%|          | 0.00/99.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1034 [00:00<?, ? examples/s]

Dataset spider downloaded and prepared to /root/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 7000
    })
    validation: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 1034
    })
})

In [None]:
# Validation split as a DataFrame, so the text columns can be combined with
# pandas in the preprocessing cells below.
df_val = dataset['validation'].to_pandas()
df_val

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]"
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]"
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin..."
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ..."
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","[SELECT, avg, (, age, ), ,, min, (, age, ), ,,...","[select, avg, (, age, ), ,, min, (, age, ), ,,...","[What, is, the, average, ,, minimum, ,, and, m..."
...,...,...,...,...,...,...
1029,singer,SELECT Citizenship FROM singer WHERE Birth_Yea...,What are the citizenships that are shared by s...,"[SELECT, Citizenship, FROM, singer, WHERE, Bir...","[select, citizenship, from, singer, where, bir...","[What, are, the, citizenships, that, are, shar..."
1030,real_estate_properties,SELECT count(*) FROM Other_Available_Features,How many available features are there in total?,"[SELECT, count, (, *, ), FROM, Other_Available...","[select, count, (, *, ), from, other_available...","[How, many, available, features, are, there, i..."
1031,real_estate_properties,SELECT T2.feature_type_name FROM Other_Availab...,What is the feature type name of feature AirCon?,"[SELECT, T2.feature_type_name, FROM, Other_Ava...","[select, t2, ., feature_type_name, from, other...","[What, is, the, feature, type, name, of, featu..."
1032,real_estate_properties,SELECT T2.property_type_description FROM Prope...,Show the property type descriptions of propert...,"[SELECT, T2.property_type_description, FROM, P...","[select, t2, ., property_type_description, fro...","[Show, the, property, type, descriptions, of, ..."


In [None]:
# Training split as a DataFrame, so the text columns can be combined with
# pandas in the preprocessing cells below.
df_train = dataset['train'].to_pandas()
df_train

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks
0,department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...","[How, many, heads, of, the, departments, are, ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","List the name, born state and age of the heads...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","[List, the, name, ,, born, state, and, age, of..."
2,department_management,"SELECT creation , name , budget_in_billions ...","List the creation year, name and budget of eac...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","[List, the, creation, year, ,, name, and, budg..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...",What are the maximum and minimum budget of the...,"[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...","[What, are, the, maximum, and, minimum, budget..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,What is the average number of employees of the...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...","[What, is, the, average, number, of, employees..."
...,...,...,...,...,...,...
6995,culture_company,SELECT T1.company_name FROM culture_company AS...,What are all the company names that have a boo...,"[SELECT, T1.company_name, FROM, culture_compan...","[select, t1, ., company_name, from, culture_co...","[What, are, all, the, company, names, that, ha..."
6996,culture_company,"SELECT T1.title , T3.book_title FROM movie AS...",Show the movie titles and book titles for all ...,"[SELECT, T1.title, ,, T3.book_title, FROM, mov...","[select, t1, ., title, ,, t3, ., book_title, f...","[Show, the, movie, titles, and, book, titles, ..."
6997,culture_company,"SELECT T1.title , T3.book_title FROM movie AS...",What are the titles of movies and books corres...,"[SELECT, T1.title, ,, T3.book_title, FROM, mov...","[select, t1, ., title, ,, t3, ., book_title, f...","[What, are, the, titles, of, movies, and, book..."
6998,culture_company,SELECT T2.company_name FROM movie AS T1 JOIN c...,Show all company names with a movie directed i...,"[SELECT, T2.company_name, FROM, movie, AS, T1,...","[select, t2, ., company_name, from, movie, as,...","[Show, all, company, names, with, a, movie, di..."


## Preprocess

### Load and preprocess table data

In [None]:
table_paths = "/content/data/spider/tables.json"

# Accept either a single path or a list of paths.
if not isinstance(table_paths, list):
    table_paths = (table_paths, )

# Accumulate the schema entries of every file. Re-assigning `table_data`
# inside the loop would keep only the last file when several paths are given.
# Each file is expected to hold a JSON list of per-database schema dicts,
# which is how the following cells consume `table_data`.
table_data = []
for TABLE_PATH in table_paths:
    print(f"Loading data from {TABLE_PATH}")
    with open(TABLE_PATH, encoding="utf-8") as inf:
        table_data.extend(json.load(inf))

Loading data from /content/data/spider/tables.json


In [None]:
def format_dict(input_dict):
    """Serialize one database schema entry into a single-line description.

    Args:
        input_dict: mapping with the keys 'db_id', 'table_names' (list of
            table names) and 'column_names' (sequence of entries whose item 0
            is the index of the owning table and item 1 is the column name).

    Returns:
        dict: a one-item dict ``{db_id: "table : col, col | table : col"}``.
        Spaces inside column names are replaced by underscores; table names
        are emitted unchanged. Columns whose table index matches no table
        (such as index -1) are left out.
    """
    def columns_of(table_index):
        # Columns owned by the given table, in their original order.
        return [
            entry[1].replace(" ", "_")
            for entry in input_dict['column_names']
            if entry[0] == table_index
        ]

    table_descriptions = [
        "{} : {}".format(name, ', '.join(columns_of(position)))
        for position, name in enumerate(input_dict['table_names'])
    ]
    return {input_dict['db_id']: " | ".join(table_descriptions)}

# One {db_id: schema_string} dict per schema entry loaded from tables.json.
formatted_table_data = list(map(format_dict, table_data))
# Flatten the per-database dicts into a single lookup keyed by db_id.
merged_formatted_table_data = {
    db_id: schema_text
    for per_db in formatted_table_data
    for db_id, schema_text in per_db.items()
}

In [None]:
merged_formatted_table_data['perpetrator']

'perpetrator : perpetrator_id, people_id, date, year, location, country, killed, injured | people : people_id, name, height, weight, home_town'

### Concat and format data for training

Training features

In [None]:
# Attach each example's serialized schema, looked up by its db_id
# (raises KeyError if a db_id has no entry in the schema lookup).
df_train['db_schema'] = df_train['db_id'].apply(lambda x: merged_formatted_table_data[x])
# Model input: question, db_id and schema joined by ' | '.
df_train['source_text'] = df_train[['question', 'db_id', 'db_schema']].agg(' | '.join, axis=1)
# Show the first example as a sanity check.
df_train['source_text'][0]

'How many heads of the departments are older than 56 ? | department_management | department : department_id, name, creation, ranking, budget_in_billions, num_employees | head : head_id, name, born_state, age | management : department_id, head_id, temporary_acting'

Training labels

In [None]:
# Training target: the db_id followed by the gold SQL query, joined by ' | '.
df_train['target_text'] = df_train[['db_id', 'query']].agg(' | '.join, axis=1)
# Show the first example as a sanity check.
df_train['target_text'][0]

'department_management | SELECT count(*) FROM head WHERE age  >  56'

Val features

In [None]:
# Same input construction as for the training frame: schema lookup by db_id
# (raises KeyError if a db_id has no entry in the schema lookup) ...
df_val['db_schema'] = df_val['db_id'].apply(lambda x: merged_formatted_table_data[x])
# ... then question, db_id and schema joined by ' | '.
df_val['source_text'] = df_val[['question', 'db_id', 'db_schema']].agg(' | '.join, axis=1)
# Show the first example as a sanity check.
df_val['source_text'][0]

'How many singers do we have? | concert_singer | stadium : stadium_id, location, name, capacity, highest, lowest, average | singer : singer_id, name, country, song_name, song_release_year, age, is_male | concert : concert_id, concert_name, theme, stadium_id, year | singer in concert : concert_id, singer_id'

Val labels

In [None]:
# Validation target: the db_id followed by the gold SQL query, joined by ' | '.
df_val['target_text'] = df_val[['db_id', 'query']].agg(' | '.join, axis=1)
# Show the first example as a sanity check.
df_val['target_text'][0]

'concert_singer | SELECT count(*) FROM singer'

### Finalize df used in training

In [None]:
# Keep only the two columns the dataset class reads: "source_text" and "target_text".
df_f_train = df_train[['source_text','target_text']]
df_f_train

Unnamed: 0,source_text,target_text
0,How many heads of the departments are older th...,department_management | SELECT count(*) FROM h...
1,"List the name, born state and age of the heads...","department_management | SELECT name , born_st..."
2,"List the creation year, name and budget of eac...","department_management | SELECT creation , nam..."
3,What are the maximum and minimum budget of the...,department_management | SELECT max(budget_in_b...
4,What is the average number of employees of the...,department_management | SELECT avg(num_employe...
...,...,...
6995,What are all the company names that have a boo...,culture_company | SELECT T1.company_name FROM ...
6996,Show the movie titles and book titles for all ...,"culture_company | SELECT T1.title , T3.book_t..."
6997,What are the titles of movies and books corres...,"culture_company | SELECT T1.title , T3.book_t..."
6998,Show all company names with a movie directed i...,culture_company | SELECT T2.company_name FROM ...


In [None]:
# Keep only the two columns the dataset class reads: "source_text" and "target_text".
df_f_val = df_val[['source_text','target_text']]
df_f_val

Unnamed: 0,source_text,target_text
0,How many singers do we have? | concert_singer ...,concert_singer | SELECT count(*) FROM singer
1,What is the total number of singers? | concert...,concert_singer | SELECT count(*) FROM singer
2,"Show name, country, age for all singers ordere...","concert_singer | SELECT name , country , age..."
3,"What are the names, countries, and ages for ev...","concert_singer | SELECT name , country , age..."
4,"What is the average, minimum, and maximum age ...","concert_singer | SELECT avg(age) , min(age) ,..."
...,...,...
1029,What are the citizenships that are shared by s...,singer | SELECT Citizenship FROM singer WHERE ...
1030,How many available features are there in total...,real_estate_properties | SELECT count(*) FROM ...
1031,What is the feature type name of feature AirCo...,real_estate_properties | SELECT T2.feature_typ...
1032,Show the property type descriptions of propert...,real_estate_properties | SELECT T2.property_ty...


## Model

We customize the simpleT5 wrapper around the T5 model. Credits to https://github.com/Shivanandroy/simpleT5.

In [None]:
import torch
import numpy as np
import pandas as pd
from transformers import (
    T5ForConditionalGeneration,
    MT5ForConditionalGeneration,
    ByT5Tokenizer,
    PreTrainedTokenizer,
    T5TokenizerFast as T5Tokenizer,
    MT5TokenizerFast as MT5Tokenizer,
)
from transformers import AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelWithLMHead, AutoTokenizer
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from typing import List, Union

# Release cached CUDA memory left over from earlier cells, then fix the
# global random seed (42) so runs are repeatable.
torch.cuda.empty_cache()
pl.seed_everything(42)

INFO:pytorch_lightning.utilities.seed:Global seed set to 42


42

### Dataset & dataloader

In [None]:
class PyTorchDataModule(Dataset):
    """  PyTorch Dataset that tokenizes (source_text, target_text) rows on access  """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: PreTrainedTokenizer,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
    ):
        """
        initiates a PyTorch Dataset Module for input data
        Args:
            data (pd.DataFrame): input pandas dataframe. Dataframe must have 2 columns --> "source_text" and "target_text"
            tokenizer (PreTrainedTokenizer): a PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, or ByT5Tokenizer)
            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len

    def __len__(self):
        """ returns length of data """
        return len(self.data)

    def _encode(self, text, max_length):
        """ tokenizes one string, padded/truncated to exactly `max_length` tokens """
        return self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def __getitem__(self, index: int):
        """
        returns dictionary of input tensors to feed into T5/MT5 model
        Args:
            index (int): positional row index into the dataframe
        Returns:
            dict: 1-D tensors under the keys "source_text_input_ids",
            "source_text_attention_mask", "labels" and "labels_attention_mask"
        """
        data_row = self.data.iloc[index]
        source_text_encoding = self._encode(
            data_row["source_text"], self.source_max_token_len
        )
        target_text_encoding = self._encode(
            data_row["target_text"], self.target_max_token_len
        )

        # Use the tokenizer's own padding id instead of a hard-coded 0, so the
        # masking stays correct for tokenizers whose pad id differs. Fall back
        # to 0 when the tokenizer does not declare one.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = 0

        # Padding positions are replaced by -100 so they do not count as
        # prediction targets (the label value the T5 loss skips).
        target_ids = target_text_encoding["input_ids"]
        labels = target_ids.masked_fill(target_ids == pad_token_id, -100)

        return dict(
            source_text_input_ids=source_text_encoding["input_ids"].flatten(),
            source_text_attention_mask=source_text_encoding["attention_mask"].flatten(),
            labels=labels.flatten(),
            labels_attention_mask=target_text_encoding["attention_mask"].flatten(),
        )


class LightningDataModule(pl.LightningDataModule):
    """ Lightning data module serving the train and evaluation dataframes as DataLoaders """

    def __init__(
        self,
        train_df: pd.DataFrame,
        test_df: pd.DataFrame,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 4,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
        num_workers: int = 2,
    ):
        """
        stores the dataframes and loader settings; datasets are built in setup()
        Args:
            train_df (pd.DataFrame): training dataframe with the columns "source_text" & "target_text"
            test_df (pd.DataFrame): evaluation dataframe with the same two columns; it backs both the validation and the test loader
            tokenizer (PreTrainedTokenizer): PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, or ByT5Tokenizer)
            batch_size (int, optional): batch size of every loader. Defaults to 4.
            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
            num_workers (int, optional): worker processes per loader. Defaults to 2.
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.train_df, self.test_df = train_df, test_df
        self.batch_size, self.num_workers = batch_size, num_workers
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len

    def _as_dataset(self, frame):
        """ wraps a dataframe in the tokenizing dataset with this module's length limits """
        return PyTorchDataModule(
            frame,
            self.tokenizer,
            self.source_max_token_len,
            self.target_max_token_len,
        )

    def _as_loader(self, dataset, shuffle):
        """ builds a DataLoader with this module's batch size and worker count """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        """ builds the train and evaluation datasets; `stage` is ignored """
        self.train_dataset = self._as_dataset(self.train_df)
        self.test_dataset = self._as_dataset(self.test_df)

    def train_dataloader(self):
        """ shuffled loader over the training dataset """
        return self._as_loader(self.train_dataset, shuffle=True)

    def test_dataloader(self):
        """ ordered loader over the evaluation dataset """
        return self._as_loader(self.test_dataset, shuffle=False)

    def val_dataloader(self):
        """ ordered loader over the evaluation dataset (same data as the test loader) """
        return self._as_loader(self.test_dataset, shuffle=False)

### Lightning

In [None]:
class LightningModel(pl.LightningModule):
    """ PyTorch Lightning Model class: trains the wrapped model and saves it at epoch end """

    def __init__(
        self,
        tokenizer,
        model,
        reg_lambda = 0,
        outputdir: str = "outputs",
        save_only_last_epoch: bool = False,
    ):
        """
        initiates a PyTorch Lightning Model
        Args:
            tokenizer : T5/MT5/ByT5 tokenizer
            model : T5/MT5/ByT5 model
            reg_lambda (float) : weight of the norm penalty added to the training loss
                (sum of the L2 norms of all model parameters, not squared). 0 disables it.
            outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
            save_only_last_epoch (bool, optional): If True, save just the last epoch else models are saved for every epoch
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.outputdir = outputdir
        self.reg_lambda = reg_lambda
        # Filled in at epoch end; both are embedded in the checkpoint folder name.
        self.average_training_loss = None
        self.average_validation_loss = None
        self.save_only_last_epoch = save_only_last_epoch

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """ forward step: returns the (loss, logits) pair of the wrapped model """
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )

        return output.loss, output.logits

    def _step(self, batch, apply_regularization=False):
        """
        runs one batch through the model and returns its loss
        Args:
            batch (dict): batch with the keys "source_text_input_ids",
                "source_text_attention_mask", "labels" and "labels_attention_mask"
            apply_regularization (bool, optional): if True, add the norm penalty
                weighted by reg_lambda. Defaults to False.
        Returns:
            the model loss, plus the weighted penalty when requested
        """
        loss, _ = self(
            input_ids=batch["source_text_input_ids"],
            attention_mask=batch["source_text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"],
        )

        # Skip the penalty entirely when its weight is zero: it would add
        # nothing to the loss, yet walking every parameter on every step is
        # costly and a zero-weighted term can still inject NaN gradients.
        if apply_regularization and self.reg_lambda:
            l2_reg = sum(torch.norm(param, p=2) for param in self.model.parameters())
            return loss + self.reg_lambda * l2_reg

        return loss

    def training_step(self, batch, batch_size):
        """ training step (regularized loss); the second argument is unused """
        # NOTE(review): Lightning passes the batch index in this position, so
        # `batch_size` is presumably a misnomer - confirm before renaming.
        total_loss = self._step(batch, apply_regularization=True)
        self.log(
            "train_loss", total_loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
        )
        return total_loss

    def validation_step(self, batch, batch_size):
        """ validation step (no regularization); the second argument is unused """
        loss = self._step(batch)
        self.log(
            "val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
        )
        return loss

    def test_step(self, batch, batch_size):
        """ test step (no regularization); the second argument is unused """
        loss = self._step(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """ configure optimizers: AdamW with a fixed learning rate of 1e-4 """
        return AdamW(self.parameters(), lr=0.0001)

    def training_epoch_end(self, training_step_outputs):
        """ save tokenizer and model on epoch end """
        self.average_training_loss = np.round(
            torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
            4,
        )
        # The folder name records the epoch and both rounded average losses.
        path = f"{self.outputdir}/customt5-epoch-{self.current_epoch}-train-loss-{str(self.average_training_loss)}-val-loss-{str(self.average_validation_loss)}"
        # Save every epoch, or only the final one when save_only_last_epoch is set.
        if (not self.save_only_last_epoch) or self.current_epoch == self.trainer.max_epochs - 1:
            self.tokenizer.save_pretrained(path)
            self.model.save_pretrained(path)

    def validation_epoch_end(self, validation_step_outputs):
        """ store the rounded mean validation loss of the epoch """
        _loss = [x.cpu() for x in validation_step_outputs]
        self.average_validation_loss = np.round(
            torch.mean(torch.stack(_loss)).item(),
            4,
        )

### Custom T5

In [None]:
class CustomT5:
    """ Custom T5 class """

    def __init__(self) -> None:
        """ initiates Custom T5 class """
        pass

    def from_pretrained(self, model_type="t5", 
                        model_name="t5-base", 
                        checkpoint_path=None) -> None:
        """
        loads T5/MT5 Model model for training/finetuning
        Args:
            model_type (str, optional): "t5" or "mt5" . Defaults to "t5".
            model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
        """
        if model_type == "t5":
            self.tokenizer = T5Tokenizer.from_pretrained(f"{model_name}")
            if checkpoint_path:
                self.model = T5ForConditionalGeneration.from_pretrained(
                    checkpoint_path, return_dict=True
                )
                print(f"Continue training from previous checkpoint: {checkpoint_path}")
            else:
                self.model = T5ForConditionalGeneration.from_pretrained(
                    f"{model_name}", return_dict=True
                )
        elif model_type == "mt5":
            self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_name}")
            if checkpoint_path:
                self.model = T5ForConditionalGeneration.from_pretrained(
                    checkpoint_path, return_dict=True
                )
            else:
                self.model = MT5ForConditionalGeneration.from_pretrained(
                    f"{model_name}", return_dict=True
                )
        elif model_type == "byt5":
            self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
            if checkpoint_path:
                self.model = T5ForConditionalGeneration.from_pretrained(
                    checkpoint_path, return_dict=True
                )
            else:
                self.model = T5ForConditionalGeneration.from_pretrained(
                    f"{model_name}", return_dict=True
                )

    def train(
        self,
        train_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
        batch_size: int = 8,
        max_epochs: int = 5,
        l2_reg_lambda = 0,
        use_gpu: bool = True,
        outputdir: str = "outputs",
        early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping feature
        precision=32,
        logger="default",
        dataloader_num_workers: int = 2,
        save_only_last_epoch: bool = False,
    ):
        """
        fine-tunes the loaded model on a custom dataset with PyTorch Lightning
        Args:
            train_df (pd.DataFrame): training dataframe with the columns "source_text" and "target_text"
            eval_df (pd.DataFrame): validation dataframe with the same two columns
            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
            batch_size (int, optional): batch size. Defaults to 8.
            max_epochs (int, optional): max number of epochs. Defaults to 5.
            l2_reg_lambda (float, optional): regularization weight handed to the Lightning model as reg_lambda. Defaults to 0.
            use_gpu (bool, optional): if True, one GPU is requested from the trainer, otherwise none. Defaults to True.
            outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
            early_stopping_patience_epochs (int, optional): if > 0, training stops once "val_loss" has not improved for this many checks. 0 disables early stopping. Defaults to 0.
            precision (int, optional): precision passed to the trainer - double (64), full (32) or half (16). Defaults to 32.
            logger (pytorch_lightning.loggers): any logger supported by PyTorch Lightning. "default" selects Lightning's default logger.
            dataloader_num_workers (int, optional): number of workers in train/test/val dataloader. Defaults to 2.
            save_only_last_epoch (bool, optional): If True, saves only the last epoch else models are saved at every epoch
        """
        # Data side: both dataframes wrapped with the already-loaded tokenizer.
        self.data_module = LightningDataModule(
            train_df,
            eval_df,
            self.tokenizer,
            batch_size=batch_size,
            source_max_token_len=source_max_token_len,
            target_max_token_len=target_max_token_len,
            num_workers=dataloader_num_workers,
        )

        # Model side: Lightning wrapper around the already-loaded model.
        self.T5Model = LightningModel(
            tokenizer=self.tokenizer,
            model=self.model,
            reg_lambda=l2_reg_lambda,
            outputdir=outputdir,
            save_only_last_epoch=save_only_last_epoch,
        )

        # Progress bar is always present; early stopping only on request.
        callbacks = [TQDMProgressBar(refresh_rate=5)]
        if early_stopping_patience_epochs > 0:
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    min_delta=0.00,
                    patience=early_stopping_patience_epochs,
                    verbose=True,
                    mode="min",
                )
            )

        # "default" maps to True, which selects Lightning's built-in logger.
        if logger == "default":
            loggers = True
        else:
            loggers = logger

        trainer = pl.Trainer(
            logger=loggers,
            callbacks=callbacks,
            max_epochs=max_epochs,
            gpus=1 if use_gpu else 0,
            precision=precision,
            log_every_n_steps=1,
        )

        trainer.fit(self.T5Model, self.data_module)

    def load_model(
        self, model_type: str = "t5", model_dir: str = "outputs", use_gpu: bool = False
    ):
        """
        loads a checkpoint for inferencing/prediction
        Args:
            model_type (str, optional): "t5" or "mt5". Defaults to "t5".
            model_dir (str, optional): path to model directory. Defaults to "outputs".
            use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to True.
        """
        if model_type == "t5":
            self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
            self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
        elif model_type == "mt5":
            self.model = MT5ForConditionalGeneration.from_pretrained(f"{model_dir}")
            self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
        elif model_type == "byt5":
            self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
            self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")

        if use_gpu:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                raise "exception ---> no gpu found. set use_gpu=False, to use CPU"
        else:
            self.device = torch.device("cpu")

        self.model = self.model.to(self.device)

    def predict(
        self,
        source_text: Union[str, List[str]],
        source_max_token_len: int = 512,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """
        generates prediction for T5/MT5 model, for one source text or a list of them
        Args:
            source_text (str | list[str]): a single text, or a list of texts that is
                tokenized and generated as one batch.
            source_max_token_len (int, optional): length every text of a list input is
                padded/truncated to. Not applied to a single str, which is encoded
                without padding or truncation. Defaults to 512.
            max_length (int, optional): max token length of prediction. Defaults to 512.
            num_return_sequences (int, optional): number of predictions to be returned. Defaults to 1.
            num_beams (int, optional): number of beams. Defaults to 2.
            top_k (int, optional): Defaults to 50.
            top_p (float, optional): Defaults to 0.95.
            do_sample (bool, optional): Defaults to True.
            repetition_penalty (float, optional): Defaults to 2.5.
            length_penalty (float, optional): Defaults to 1.0.
            early_stopping (bool, optional): Defaults to True.
            skip_special_tokens (bool, optional): Defaults to True.
            clean_up_tokenization_spaces (bool, optional): Defaults to True.
        Returns:
            list[str]: decoded predictions, one entry per sequence returned by
                ``generate``. An empty list input yields an empty list.
        Raises:
            TypeError: if ``source_text`` is neither a str nor a list.
        """
        # Generation and decoding settings are identical for both input forms,
        # so they are assembled once instead of being repeated per branch.
        generation_kwargs = dict(
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
            do_sample=do_sample,
        )
        decode_kwargs = dict(
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )

        # Predict for a single string
        if isinstance(source_text, str):
            input_ids = self.tokenizer.encode(
                source_text, return_tensors="pt", add_special_tokens=True
            ).to(self.device)
            generated_ids = self.model.generate(
                input_ids=input_ids, **generation_kwargs
            )
            return [
                self.tokenizer.decode(g, **decode_kwargs) for g in generated_ids
            ]

        # Predict for a list of strings
        if isinstance(source_text, list):
            if not source_text:
                # Nothing to tokenize or generate.
                return []
            source_text_encoding = self.tokenizer(
                source_text,
                max_length=source_max_token_len,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            input_ids = source_text_encoding["input_ids"].to(self.device)
            attention_mask = source_text_encoding["attention_mask"].to(self.device)
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generation_kwargs,
            )
            return self.tokenizer.batch_decode(generated_ids, **decode_kwargs)

        # Previously any other type fell through and silently returned None.
        raise TypeError(
            "source_text must be a str or a list of str, got "
            f"{type(source_text).__name__}"
        )

    def predict_in_batches(
        self,
        source_text: List[str],
        batch_size: int = 8,
        *args,
        **kwargs,
    ):
        """
        Runs ``predict`` over consecutive slices of ``source_text`` and concatenates the results.

        Args:
            source_text (List[str]): Source texts to generate predictions for.
            batch_size (int, optional): Number of texts handed to ``predict`` per call. Defaults to 8.
            *args: Positional arguments forwarded unchanged to ``predict``.
            **kwargs: Keyword arguments forwarded unchanged to ``predict``.

        Returns:
            List[str]: The predictions of every batch, in input order.
        """
        # Start offset of every slice; the last slice may be shorter than batch_size.
        chunk_starts = range(0, len(source_text), batch_size)
        # Lazy generator: predict is still invoked one batch at a time, in order.
        per_chunk_preds = (
            self.predict(source_text[start:start + batch_size], *args, **kwargs)
            for start in chunk_starts
        )
        return [pred for chunk in per_chunk_preds for pred in chunk]

## Train

**IMPORTANT: Create a new folder before training. Beware of overwriting previous checkpoints!**

In [None]:
# Empty the CUDA cache and run the Python garbage collector before training.
garbage_collect()

In [None]:
# Checkpoint of an earlier run. It is not referenced anywhere in this cell:
# from_pretrained() below is called with checkpoint_path=None.
OLD_CHECKPOINTS_PATH = "/content/drive/Shareddrives/CIS 522/model_checkpoints/plain_simple_t5_v0/simplet5-epoch-2-train-loss-0.2109-val-loss-0.6322"

# Directory handed to train() as `outputdir`; the checkpoint loaded later in the
# notebook lives inside this folder.
NEW_CHECKPOINTS_SAVING_DIR = "/content/drive/Shareddrives/CIS 522/model_checkpoints/Marc_reg_t5_experiments_3"


model = CustomT5()
# Load "t5-base" with no local checkpoint (checkpoint_path=None).
model.from_pretrained(model_type="t5", model_name="t5-base", checkpoint_path=None)
# NOTE(review): l2_reg_lambda is presumably the weight of an L2 penalty on the
# training loss — confirm against CustomT5.train.
model.train(train_df=df_f_train,
            eval_df=df_f_val, 
            source_max_token_len=256, 
            target_max_token_len=64, 
            batch_size=16, 
            max_epochs=4, 
            l2_reg_lambda = 1e-2,
            precision=32,
            use_gpu=True,
            outputdir=NEW_CHECKPOINTS_SAVING_DIR)

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.distributed:GPU available: True, used: True
INFO:pytorch_lightning.utilities.distributed:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Try one prediction

In [None]:
# Model input: "<question> | <db_id> " followed by the entry of
# merged_formatted_table_data for that database (presumably its formatted
# schema — confirm where merged_formatted_table_data is built).
question = "List the name, born state and age of the heads of departments ordered by age. | department_management " + merged_formatted_table_data['department_management']
# NOTE(review): load_model(..., use_gpu=True) below assigns self.device itself,
# so this assignment looks redundant — confirm before removing.
model.device = "cuda" if torch.cuda.is_available() else "cpu"

# Don't forget to change the path of model checkpoint here
model.load_model("t5","/content/drive/Shareddrives/CIS 522/model_checkpoints/Marc_reg_t5_experiments_3/customt5-epoch-3-train-loss-165.5583-val-loss-0.5594", use_gpu=True)
# A single str input returns a list of decoded predictions.
model.predict(question)

['department_management | SELECT name, born_state, age FROM head ORDER BY age']

## Code for Evaluation

In [None]:
def evaluate(preds_file, labels_file, evaluation_type="all",
             database_dir="./data/spider/database",
             table_file="./data/spider/tables.json",
             verbose="False"):
    """
    Runs the evaluation script for the Spider dataset using the provided labels and predictions files.
    It prints the evaluation results to the console and returns the subprocess result object.
    If the script exits with a non-zero code, its exit code and stderr are printed as well.

    Args:
        preds_file (str): Path to the predictions file, passed to the script as `--pred`.
                          In this file, each line is a predicted SQL.
        labels_file (str): Path to the labels (gold) file, passed to the script as `--gold`.
                           In this file, each line is `a ground-truth SQL \\t db_id`.
        evaluation_type (str): Evaluation type, can be 'all', 'exec', or 'match'.
        database_dir (str): Path to the directory containing the Spider dataset's database files.
        table_file (str): Path to the tables.json file from the Spider dataset.
        verbose (str | bool): Flag to turn on or off printing details. It is converted with
                              `str()` before being passed to the script, so "True"/"False"
                              and True/False are both accepted.

    Returns:
        result (subprocess.CompletedProcess): A CompletedProcess instance representing the evaluation subprocess.
                                              It contains attributes like 'stdout' and 'stderr' to access the output
                                              and error messages respectively.
    """

    cmd = [
        "python3", "scripts/evaluation.py",
        "--gold", labels_file,
        "--pred", preds_file,
        "--etype", evaluation_type,
        "--db", database_dir,
        "--table", table_file,
        # subprocess arguments must be strings; str() lets callers pass a bool.
        "--verbose", str(verbose)
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    print(result.stdout)

    # stderr is captured, so without this a crashed script would only show
    # an empty stdout and the failure would go unnoticed.
    if result.returncode != 0:
        print(f"evaluation.py exited with code {result.returncode}:\n{result.stderr}")

    return result

## Evaluation

In [None]:
# Empty the CUDA cache and run the Python garbage collector before loading the model for evaluation.
garbage_collect()

In [None]:
preds_filename = "preds.txt"
labels_filename = "labels.txt"

model = CustomT5()
model.load_model("t5","/content/drive/Shareddrives/CIS 522/model_checkpoints/Marc_reg_t5_experiments_3/customt5-epoch-3-train-loss-165.5583-val-loss-0.5594", use_gpu=True)
print("Evaluating on:", model.device)

preds_no_format = model.predict_in_batches(list(df_f_val['source_text']),
                                           batch_size=16)

# Predictions file: one predicted SQL per line.
# NOTE(review): an empty prediction writes a blank line — confirm that
# scripts/evaluation.py keeps blank lines aligned with the gold file.
with open(preds_filename, 'w') as output_file:
    preds = []
    for pred in preds_no_format:
        if ' | ' in pred:
            # Model output is "db_id | SQL". maxsplit=1 cuts only at the first
            # separator, so a ' | ' inside the SQL itself no longer truncates it.
            pred_formatted = pred.split(' | ', 1)[1]
        else:
            # No separator: keep the raw output as the prediction.
            pred_formatted = pred
        preds.append(pred_formatted)
        output_file.write(pred_formatted + '\n')

# Labels (gold) file: "SQL<TAB>db_id" per line.
with open(labels_filename, 'w') as output_file:
    labels = []
    for label in df_f_val['target_text']:
        # Target text is "db_id | SQL"; maxsplit=1 keeps the whole SQL in one piece.
        label_split = label.split(' | ', 1)
        label_formatted = label_split[1] + '\t' + label_split[0]
        labels.append(label_formatted)
        output_file.write(label_formatted + '\n')

Evaluating on: cuda


In [None]:
# Scores preds.txt against labels.txt, the same file names the previous cell
# assigns to preds_filename and labels_filename.
# evaluation_type="all" or "exec" might explode RAM, be careful
# The returned CompletedProcess is kept in `evaluation`; evaluate() also prints its stdout.
evaluation = evaluate(preds_file="preds.txt", 
                      labels_file="labels.txt", 
                      evaluation_type="all", 
                      database_dir="./data/spider/database", 
                      table_file="./data/spider/tables.json",
                      verbose="False")

                     easy                 medium               hard                 extra                all                 
count                248                  446                  174                  166                  1034                
execution            0.544                0.321                0.247                0.054                0.319               

exact match          0.556                0.314                0.236                0.048                0.316               

---------------------PARTIAL MATCHING ACCURACY----------------------
select               0.903                0.903                0.965                0.818                0.905               
select(no AGG)       0.933                0.924                0.965                0.818                0.925               
where                0.900                0.855                0.700                0.500                0.811               
where(no OP)         0.900                0.855