# Table Question Answering: WikiSQL dataset
In this notebook, we will see how to fine-tune and evaluate a question generation model on WikiSQL dataset.

In [1]:
do_train=True
model_name_or_path="google/tapas-base"
do_eval=True 
dataset_name="wikisql" 
data_path_root="data/wikisql/" 
output_dir="../../models/tableqa/wikisql_nb"
learning_rate=4e-4

In [2]:
import logging
from primeqa.tableqa.metrics.answer_accuracy import compute_denotation_accuracy
from primeqa.tableqa.models.tableqa_model import TableQAModel
from primeqa.tableqa.trainers.tableqa_trainer import TableQATrainer
from transformers import TapasConfig
from transformers import (
    DataCollator,
    HfArgumentParser,
    TrainingArguments,
    set_seed,default_data_collator,
)
from primeqa.tableqa.run_tableqa import TableQAArguments
from primeqa.tableqa.utils.data_collator import TapasCollator
from primeqa.tableqa.preprocessors.wikisql_preprocessor import load_data
from primeqa.tableqa.postprocessor.wikisql import WikiSQLPostprocessor
from primeqa.tableqa.metrics.answer_accuracy import compute_denotation_accuracy

## Loading the TableQA specific arguments needed for TAPAS training


In [3]:

tqa_args = TableQAArguments()
tqa_args.dataset_name=dataset_name
tqa_args.data_path_root=data_path_root
tqa_args.use_answer_as_supervision=True
config = TapasConfig(tqa_args)
tableqa_model = TableQAModel("google/tapas-base",config=config)
model = tableqa_model.model
tokenizer = tableqa_model.tokenizer

post_obj = WikiSQLPostprocessor(tokenizer,tqa_args)
train_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
)


Some weights of TapasForQuestionAnswering were not initialized from the model checkpoint at google/tapas-base and are newly initialized: ['output_bias', 'column_output_weights', 'output_weights', 'column_output_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Loading the wikisql data 
Note: The call load_data also internally converts the data to TAPAS traiming format.

In [4]:
# only a small fraction of training, dev data has been used for demonstration purpose
train_dataset,eval_dataset = load_data(tqa_args.data_path_root,tokenizer,100,50)

Preprocessing wikisql dataset


Using custom data configuration default
Using custom data configuration default


Preprocessing done
Preprocessing done


## Use the TableQATrainer with TAPAS specific collator

In [5]:
trainer = TableQATrainer(model=model,
                                args=train_args,
                                train_dataset=train_dataset if train_args.do_train else None,
                                eval_dataset=eval_dataset if train_args.do_eval else None,
                                tokenizer=tableqa_model.tokenizer,
                                data_collator=TapasCollator(),
                                post_process_function= post_obj.postprocess_prediction,
                                compute_metrics=compute_denotation_accuracy         
                                )

## check the trainer metrics for training and validation phase

In [6]:
if train_args.do_train:
    train_result = trainer.train()
    trainer.save_model()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

***** Running training *****
  Num examples = 68
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 27


Step,Training Loss




Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ../../models/tableqa/wikisql_nb
Configuration saved in ../../models/tableqa/wikisql_nb/config.json
Model weights saved in ../../models/tableqa/wikisql_nb/pytorch_model.bin
tokenizer config file saved in ../../models/tableqa/wikisql_nb/tokenizer_config.json
Special tokens file saved in ../../models/tableqa/wikisql_nb/special_tokens_map.json


***** train metrics *****
  epoch                    =        3.0
  total_flos               =    49988GF
  train_loss               =     1.0374
  train_runtime            = 0:05:27.66
  train_samples_per_second =      0.623
  train_steps_per_second   =      0.082


In [8]:
if train_args.do_eval:
      metrics = trainer.evaluate()
      trainer.log_metrics("eval", metrics)
      trainer.save_metrics("eval", metrics)

***** Running Evaluation *****
  Num examples = 42
  Batch size = 8


***** eval metrics *****
  epoch                    =    3.0
  eval_Denotation accuracy = 0.4524
