# Table Question Generation: WikiSQL Dataset
In this notebook, we show how to fine-tune and evaluate a question generation model on the WikiSQL dataset.

## Configuration

We start by setting the parameters that configure training and evaluation. Depending on the GPU being used, you may need to adjust the batch sizes.
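If the default batch size does not fit in GPU memory, one common approach is to halve it until a training step succeeds. The sketch below illustrates that idea. `find_batch_size` and the `try_step` callable are hypothetical names, not part of PrimeQA or transformers, and a real GPU run would raise `torch.cuda.OutOfMemoryError` rather than `MemoryError`.

```python
from typing import Callable, Tuple, Type


def find_batch_size(
    try_step: Callable[[int], None],
    start: int = 8,
    minimum: int = 1,
    oom_errors: Tuple[Type[BaseException], ...] = (MemoryError,),
) -> int:
    """Return the first batch size in start, start/2, ... for which try_step succeeds.

    try_step is a hypothetical callable that runs one training step at the given
    per-device batch size. On a GPU, pass (torch.cuda.OutOfMemoryError,) as oom_errors.
    """
    batch_size = start
    while batch_size >= minimum:
        try:
            try_step(batch_size)
            return batch_size
        except oom_errors:
            batch_size //= 2  # halve and retry
    raise RuntimeError(f"no batch size >= {minimum} fits in memory")


# Toy stand-in for a training step: pretend anything above 2 runs out of memory.
def fake_step(batch_size: int) -> None:
    if batch_size > 2:
        raise MemoryError


print(find_batch_size(fake_step, start=8))  # 2
```

Recent versions of transformers also offer an `auto_find_batch_size` training argument that automates a similar search.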

In [1]:
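model_name_or_path="t5-small"            # pretrained seq2seq checkpoint to fine-tune
modality="table"                         # "table" for WikiSQL; "passage" for text-based QG
dataset_name="wikisql"
max_len=200                              # max input length (tokens) of the linearized SQL input
target_max_len=40                        # max target question length (tokens)
output_dir="../../models/qg/wikisql_nb"  # where checkpoints and the final model are saved
learning_rate=0.0001
num_train_epochs=2
per_device_train_batch_size=8            # reduce if you run out of GPU memory
per_device_eval_batch_size=32
evaluation_strategy='epoch'              # evaluate at the end of every epoch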

In [2]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    evaluation_strategy=evaluation_strategy,
    learning_rate=learning_rate,
    # Generate full questions during evaluation so that ROUGE can be computed
    predict_with_generate=True,
    prediction_loss_only=False,
    # Keep dataset columns that are not arguments of the model's forward()
    remove_unused_columns=False,
    )

---
## WikiSQL data
Here we load one WikiSQL instance and visualize it. <font color='red'>This part of the code is not needed to train the model.</font>

In [3]:
from datasets import load_dataset
from tabulate import tabulate

def print_wikisql_instance(train_instance):
    table = train_instance['table']
    print('Table:\n',tabulate(table['rows'], headers=table['header'], tablefmt='grid'))

    print('Question = ',train_instance['question'])
    print('SQL = ', train_instance['sql']['human_readable'])

train_instance = load_dataset('wikisql', split='train[1001:1002]')
print_wikisql_instance(train_instance[0])

Table:
+-------------------+-------+---------------+----------------+------------------+---------------------+
| Player            |   No. | Nationality   | Position       | Years for Jazz   | School/Club Team    |
+===================+=======+===============+================+==================+=====================+
| Fred Saunders     |    12 | United States | Forward        | 1977-78          | Syracuse            |
+-------------------+-------+---------------+----------------+------------------+---------------------+
| Danny Schayes     |    24 | United States | Forward-Center | 1981-83          | Syracuse            |
+-------------------+-------+---------------+----------------+------------------+---------------------+
| Carey Scurry      |    22 | United States | Forward        | 1985-88          | Long Island         |
+-------------------+-------+---------------+----------------+------------------+---------------------+
| Robert Smith      |     5 | United States | Guard          | 1979-80          | UNLV                |
+-------------------+-------+---------------+----------------+------------------+---------------------+

The SQL query is converted into a string that serves as the generator's input for producing the question.

In [4]:
from primeqa.qg.processors.table_qg.sql_processor import SqlProcessor

processed_data = SqlProcessor.preprocess_data(train_instance)
print('Question = ', processed_data['label'][0])
print('\nInput to generator = ', processed_data['input'][0])

Question =  Which position does John Starks play?

Input to generator =  select <<sep>> Position <<sep>> Player <<cond>> equal <<cond>> John Starks <<answer>> shooting guard <<header>> Player <<hsep>> No. <<hsep>> Nationality <<hsep>> Position <<hsep>> Years for Jazz <<hsep>> School/Club Team
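To make the format concrete, the helper below rebuilds the generator input shown above for a single-condition query. It is a hypothetical sketch reverse-engineered from this one output, not SqlProcessor's code: aggregation operators and queries with several conditions are not handled, because their exact encoding is not shown here.

```python
from typing import List


def linearize_sql(select_col: str, cond_col: str, cond_op: str,
                  cond_value: str, answer: str, header: List[str]) -> str:
    """Hypothetical: mirror the '<<sep>>'/'<<cond>>' string printed above (one condition only)."""
    return (
        f"select <<sep>> {select_col} "
        f"<<sep>> {cond_col} <<cond>> {cond_op} <<cond>> {cond_value} "
        f"<<answer>> {answer} "
        "<<header>> " + " <<hsep>> ".join(header)
    )


header = ["Player", "No.", "Nationality", "Position", "Years for Jazz", "School/Club Team"]
print(linearize_sql("Position", "Player", "equal", "John Starks", "shooting guard", header))
```

The printed string matches the "Input to generator" line above: the selected column and condition come first, then the answer, then the table header, each section introduced by its own marker token.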


---
## Loading the Model

Here we load the model specified by the model_name_or_path and modality parameters set above. For WikiSQL we use modality='table'; the other option is modality='passage'.

In [5]:
from primeqa.qg.models.qg_model import QGModel

qg_model = QGModel(model_name_or_path, modality=modality)

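## Loading Data

Here we load the training and validation data. To keep the run short, only the first 100 training examples and the first 50 validation examples are used.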

In [6]:
from primeqa.qg.processors.data_loader import QGDataLoader

qgdl = QGDataLoader(
    tokenizer=qg_model.tokenizer,
    modality=modality,
    dataset_name=dataset_name,
    input_max_len=max_len,
    target_max_len=target_max_len
    )

train_dataset = qgdl.create(dataset_split="train[:100]")
valid_dataset = qgdl.create(dataset_split="validation[:50]")
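The trainer batches these examples with T2TDataCollator, whose implementation is not shown in this notebook. The sketch below illustrates what a seq2seq collator of this kind typically does, assuming each example already holds fixed-length `input_ids`, `attention_mask` and `labels` lists (these key names are assumptions). Padding positions in the labels are replaced by -100, the index that PyTorch's cross-entropy loss ignores by default, so padding does not count toward the loss; 0 is the pad id of the t5-small tokenizer.

```python
from typing import Dict, List

PAD_ID = 0           # <pad> token id in the t5-small tokenizer
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def collate(batch: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Hypothetical list-based collator; a real one would return torch tensors."""
    return {
        "input_ids": [ex["input_ids"] for ex in batch],
        "attention_mask": [ex["attention_mask"] for ex in batch],
        # Mask target padding so it does not contribute to the loss.
        "labels": [[IGNORE_INDEX if tok == PAD_ID else tok for tok in ex["labels"]]
                   for ex in batch],
    }


batch = [
    {"input_ids": [13, 27, 1, 0], "attention_mask": [1, 1, 1, 0], "labels": [8, 1, 0, 0]},
    {"input_ids": [42, 5, 9, 1], "attention_mask": [1, 1, 1, 1], "labels": [3, 4, 1, 0]},
]
print(collate(batch)["labels"])  # [[8, 1, -100, -100], [3, 4, 1, -100]]
```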

## Training with QGTrainer
Here we create a QG trainer with the training arguments defined above and use it to train on the WikiSQL training data (or any custom data in the same format).

In [7]:
from primeqa.qg.trainers.qg_trainer import QGTrainer
from primeqa.qg.metrics.generation_metrics import rouge_metrics
from primeqa.qg.utils.data_collator import T2TDataCollator

compute_metrics = rouge_metrics(qg_model.tokenizer)

trainer = QGTrainer(
    model=qg_model.model,
    tokenizer=qg_model.tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=T2TDataCollator(),
    compute_metrics=compute_metrics
    )

train_results = trainer.train()
trainer.save_model()
print(train_results.metrics)

***** Running training *****
  Num examples = 100
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 72
  Gradient Accumulation steps = 1
  Total optimization steps = 4


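| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | RougeL | RougeLsum |
|-------|---------------|-----------------|--------|--------|--------|-----------|
| 1 | No log | 4.202608 | 18.2795 | 10.3615 | 17.0244 | 17.0949 |
| 2 | No log | 4.054491 | 18.0121 | 9.7805 | 16.5388 | 16.7104 |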


***** Running Evaluation *****
  Num examples = 50
  Batch size = 288
***** Running Evaluation *****
  Num examples = 50
  Batch size = 288


Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 48.8531, 'train_samples_per_second': 4.094, 'train_steps_per_second': 0.082, 'total_flos': 10573578240000.0, 'train_loss': 4.659993648529053, 'epoch': 2.0}
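The batch sizes in the log (72 for training, 288 for evaluation) are larger than the configured 8 and 32. They are consistent with 9 GPUs being visible to a single process, with no gradient accumulation; that GPU count is inferred from the numbers, not stated in the notebook. The sketch below, a simplified model of how the trainer derives these figures, reproduces the logged values. The "No log" training-loss entries appear because the default `logging_steps` (500) exceeds the 4 optimization steps; the final `train_loss` is still reported.

```python
import math


def trainer_batch_stats(num_examples: int, per_device_train: int, per_device_eval: int,
                        n_gpu: int, grad_accum: int, epochs: int):
    # With single-process multi-GPU training, each step's batch spans all GPUs.
    train_batch = per_device_train * max(1, n_gpu)
    total_train_batch = train_batch * grad_accum
    eval_batch = per_device_eval * max(1, n_gpu)
    # Batches per epoch (the last one may be partial), then optimizer updates.
    batches_per_epoch = math.ceil(num_examples / train_batch)
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    return total_train_batch, eval_batch, total_steps


# Assumption: 9 visible GPUs (72 / 8), gradient accumulation of 1 as logged.
print(trainer_batch_stats(100, 8, 32, n_gpu=9, grad_accum=1, epochs=2))  # (72, 288, 4)
```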


## Evaluation

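Here we evaluate the trained model on the validation set. Because evaluation also ran at the end of the final epoch, these scores match the epoch-2 results logged during training.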

In [8]:
metrics = trainer.evaluate()
print(metrics)

***** Running Evaluation *****
  Num examples = 50
  Batch size = 288


{'eval_loss': 4.05449104309082, 'eval_rouge1': 18.0121, 'eval_rouge2': 9.7805, 'eval_rougeL': 16.5388, 'eval_rougeLsum': 16.7104, 'eval_runtime': 1.1546, 'eval_samples_per_second': 43.304, 'eval_steps_per_second': 0.866, 'epoch': 2.0}
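The ROUGE scores measure n-gram overlap between generated and reference questions and appear to be reported on a 0 to 100 scale. For intuition, the function below computes a simplified ROUGE-1 F1 (unigram overlap on lowercased alphanumeric tokens, no stemming). It is an illustration, not the implementation behind rouge_metrics, and the example question pair is made up.

```python
import re
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: clipped unigram overlap between prediction and reference."""
    pred = Counter(re.findall(r"[a-z0-9]+", prediction.lower()))
    ref = Counter(re.findall(r"[a-z0-9]+", reference.lower()))
    overlap = sum((pred & ref).values())  # each token counted at most min(count) times
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("What position does John Starks play?",
                  "Which position does John Starks play?")
print(round(100 * score, 2))  # 83.33
```

Five of the six unigrams overlap, so precision and recall are both 5/6; real ROUGE implementations differ in tokenization and optional stemming, so exact values can vary.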
