In [None]:
# Mount Google Drive so artifacts written under /content/drive outlive the Colab session.
import google.colab.drive as drive
drive.mount('/content/drive')

The following cells install the libraries used in this project.

In [None]:
! pip install transformers

In [None]:
!pip install simpletransformers

In [None]:
!pip install datasets

Next, we load the WikiSQL dataset, which is available out of the box through the HuggingFace `datasets` library.

In [None]:
from datasets import load_dataset

# Download (or reuse the cached copy of) WikiSQL, then show one raw training
# example to expose the nested `sql` / `table` fields used during preprocessing.
dataset = load_dataset("wikisql")
first_train_example = dataset["train"][0]
print(first_train_example)

Using custom data configuration default
Reusing dataset wiki_sql (/root/.cache/huggingface/datasets/wiki_sql/default/0.1.0/2e98053891fd8f9b2c4348bba609ce40cc0a4d7f621191cebcd7cb558b5f8a70)


{'phase': 1, 'question': 'Tell me what the notes are for South Australia ', 'sql': {'agg': 0, 'conds': {'column_index': [3], 'condition': ['SOUTH AUSTRALIA'], 'operator_index': [0]}, 'human_readable': 'SELECT Notes FROM table WHERE Current slogan = SOUTH AUSTRALIA', 'sel': 5}, 'table': {'caption': '', 'header': ['State/territory', 'Text/background colour', 'Format', 'Current slogan', 'Current series', 'Notes'], 'id': '1-1000181-1', 'name': 'table_1000181_1', 'page_id': '', 'page_title': '', 'rows': [['Australian Capital Territory', 'blue/white', 'Yaa·nna', 'ACT · CELEBRATION OF A CENTURY 2013', 'YIL·00A', 'Slogan screenprinted on plate'], ['New South Wales', 'black/yellow', 'aa·nn·aa', 'NEW SOUTH WALES', 'BX·99·HI', 'No slogan on current series'], ['New South Wales', 'black/white', 'aaa·nna', 'NSW', 'CPX·12A', 'Optional white slimline series'], ['Northern Territory', 'ochre/white', 'Ca·nn·aa', 'NT · OUTBACK AUSTRALIA', 'CB·06·ZZ', 'New series began in June 2011'], ['Queensland', 'maroo

Next, we preprocess the data to make it suitable for training. From the train and validation splits we extract each natural-language question together with its ground-truth SQL query. We collect these pairs in a list and convert it to a pandas DataFrame with `input_text` and `target_text` columns.

In [None]:
import logging

import pandas as pd
from simpletransformers.seq2seq import Seq2SeqModel

# Quiet the `transformers` logger down to warnings and above...
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
# ...while everything else (including training progress) logs at INFO.
logging.basicConfig(level=logging.INFO)

def to_seq2seq_frame(split):
    """Turn a WikiSQL split into a two-column (input_text, target_text) DataFrame.

    Each example's natural-language ``question`` becomes ``input_text`` and its
    ``sql["human_readable"]`` query string becomes ``target_text``.

    Args:
        split: an iterable of WikiSQL examples (dicts with ``question`` and
            ``sql["human_readable"]`` keys), e.g. ``dataset["train"]``.

    Returns:
        pandas.DataFrame with one row per example, in iteration order.
    """
    records = [
        (example["question"], example["sql"]["human_readable"])
        for example in split
    ]
    return pd.DataFrame(records, columns=["input_text", "target_text"])


train_df = to_seq2seq_frame(dataset["train"])
eval_df = to_seq2seq_frame(dataset["validation"])

Next, we define the training arguments and the model itself.

In [None]:
# Training and generation settings passed to simpletransformers' Seq2SeqModel.
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    # Input (question) truncation length, presumably in tokens.
    "max_seq_length": 50,
    "train_batch_size": 64,
    "num_train_epochs": 20,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_generated_text": True,
    # NOTE(review): presumably only takes effect when "evaluate_during_training"
    # is enabled and eval data is given to train_model; neither is done here.
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": False,
    # NOTE(review): cap on generated query length; long multi-condition SQL
    # targets might be cut off at 25 — confirm against the target lengths.
    "max_length": 25,
    "manual_seed": 4,
    "weight_decay": 0.01,
    "learning_rate": 8e-5,
    "warmup_steps": 1000,
    # Written under the mounted Google Drive so the trained model persists.
    "output_dir": "/content/drive/MyDrive/cs685_models/BART1",
}

# Pretrained BART-base encoder-decoder, fine-tuned on question -> SQL pairs.
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-base",
    args=model_args,
    use_cuda=True,
)

We can now start training the model!

In [None]:
# Fine-tune on the question -> SQL training pairs (checkpoints go to model_args["output_dir"]).
model.train_model(train_data=train_df)

We now evaluate the model on the validation split. The validation loss is used to tune the hyperparameters.

In [None]:
# Score the fine-tuned model on the held-out validation pairs.
results = model.eval_model(eval_data=eval_df)