<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction: TAPAS

* Original TAPAS paper (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.398/
* Follow-up paper on intermediate pre-training (EMMNLP Findings 2020): https://www.aclweb.org/anthology/2020.findings-emnlp.27/
* Original Github repository: https://github.com/google-research/tapas
* Blog post: https://ai.googleblog.com/2020/04/using-neural-networks-to-find-answers.html

TAPAS is an algorithm that (among other tasks) can answer questions about tabular data. It is essentially a BERT model with relative position embeddings and additional token type ids that encode tabular structure, and 2 classification heads on top: one for **cell selection** and one for (optionally) performing an **aggregation** among selected cells (such as summing or counting).

Similar to BERT, the base `TapasModel` is pre-trained using the masked language modeling (MLM) objective on a large collection of tables from Wikipedia and associated texts. In addition, the authors further pre-trained the model on an second task (table entailment) to increase the numerical reasoning capabilities of TAPAS (as explained in the follow-up paper), which further improves performance on downstream tasks. 

In this notebook, we are going to fine-tune `TapasForQuestionAnswering` on [Sequential Question Answering (SQA)](https://www.microsoft.com/en-us/research/publication/search-based-neural-structured-learning-sequential-question-answering/), a dataset built by Microsoft Research which deals with asking questions related to a table in a **conversational set-up**. We are going to do so as in the original paper, by adding a randomly initialized cell selection head on top of the pre-trained base model (note that SQA does not have questions that involve aggregation and hence no aggregation head), and then fine-tuning them altogether.

First, we install both the Transformers library as well as the dependency on [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter), which the model requires.

In [2]:
! pip install torch-scatter

Collecting torch-scatter
  Downloading torch_scatter-2.1.1.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... [?25ldone
[?25h  Created wheel for torch-scatter: filename=torch_scatter-2.1.1-cp311-cp311-macosx_10_9_x86_64.whl size=294139 sha256=570fb2db14389c19dceb50d143c222b0f39fe0f6947372ac40807df41c7b24bd
  Stored in directory: /Users/evgenynazarenko/Library/Caches/pip/wheels/e1/45/de/7e6c2b34bf0c92ea931392eb9930fa25ac12bf455e68ae1d6e
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.1


We also install a small portion from the SQA training dataset, for demonstration purposes. This is a TSV file containing table-question pairs. Besides this, we also download the `table_csv` directory, which contains the actual tabular data.

Note that you can download the entire SQA dataset on the [official website](https://www.microsoft.com/en-us/download/details.aspx?id=54253).

In [1]:
import requests, zipfile, io
import os

def download_files(dir_name):
  if not os.path.exists(dir_name): 
    # 28 training examples from the SQA training set + table csv data
    urls = ["https://www.dropbox.com/s/2p6ez9xro357i63/sqa_train_set_28_examples.zip?dl=1",
            "https://www.dropbox.com/s/abhum8ssuow87h6/table_csv.zip?dl=1"
    ]
    for url in urls:
      r = requests.get(url)
      z = zipfile.ZipFile(io.BytesIO(r.content))
      z.extractall()

dir_name = "sqa_data"
download_files(dir_name)

## Prepare the data 

Let's look at the first few rows of the dataset:

In [None]:
## here is generated dataset for month schedule 


In [187]:
import pandas as pd

data = pd.read_excel("sqa_train_set_28_examples.xlsx")
data.head(10)


Unnamed: 0,id,annotator,position,question,table_file,answer_coordinates,answer_text
0,nt-639,0,0,where are the players from?,table_csv/203_149.csv,"['(0, 4)', '(1, 4)', '(2, 4)', '(3, 4)', '(4, ...","['Louisiana State University', 'Valley HS (Las..."
1,nt-639,0,1,which player went to louisiana state university?,table_csv/203_149.csv,"['(0, 1)']",['Ben McDonald']
2,nt-639,1,0,who are the players?,table_csv/203_149.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Ben McDonald', 'Tyler Houston', 'Roger Salke..."
3,nt-639,1,1,which ones are in the top 26 picks?,table_csv/203_149.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Ben McDonald', 'Tyler Houston', 'Roger Salke..."
4,nt-639,1,2,"and of those, who is from louisiana state univ...",table_csv/203_149.csv,"['(0, 1)']",['Ben McDonald']
5,nt-639,2,0,who are the players in the top 26?,table_csv/203_149.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Ben McDonald', 'Tyler Houston', 'Roger Salke..."
6,nt-639,2,1,"of those, which one was from louisiana state u...",table_csv/203_149.csv,"['(0, 1)']",['Ben McDonald']
7,nt-11649,0,0,what are all the names of the teams?,table_csv/204_135.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Cordoba CF', 'CD Malaga', 'Granada CF', 'UD ..."
8,nt-11649,0,1,"of these, which teams had any losses?",table_csv/204_135.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Cordoba CF', 'CD Malaga', 'Granada CF', 'UD ..."
9,nt-11649,0,2,"of these teams, which had more than 21 losses?",table_csv/204_135.csv,"['(15, 1)']",['CD Villarrobledo']


In [188]:
data_schedule=pd.read_csv("24_sqa_train_set.csv")
data_schedule


Unnamed: 0,id,annotator,position,question,table_file,answer_coordinates,answer_text
0,ss-1030,0,0,What is scheduled on Monday at 9?,table_csv/schedule.csv,"['(0, 1)']",['TCS']
1,ss-1030,0,1,What is the next lecture?,table_csv/schedule.csv,"['(1, 1)']",['SSA']
2,ss-1030,1,0,What subject is at 11 on Tuesday?,table_csv/schedule.csv,"['(1, 2)']",['Project 2-2']
3,ss-1030,1,1,What is the next lecture?,table_csv/schedule.csv,"['(2, 2)']",['SSA']
4,ss-1030,2,0,What is happening on Wednesday at 13?,table_csv/schedule.csv,"['(2, 3)']",['TCS']
5,ss-1030,2,1,What is the next lecture?,table_csv/schedule.csv,"['(3, 3)']",['MM']
6,ss-1030,3,0,What is the schedule for Thursday at 15?,table_csv/schedule.csv,"['(3, 4)']",['Logic']
7,ss-1030,3,1,What is the next lecture?,table_csv/schedule.csv,"['(4, 4)']",['SSA']
8,ss-1030,4,0,What's on Friday at 17?,table_csv/schedule.csv,"['(4, 5)']",['Project 2-2']
9,ss-1030,5,0,What is the first class on Monday?,table_csv/schedule.csv,"['(0, 1)']",['TCS']


As you can see, each row corresponds to a question related to a table. 
* The `position` column identifies whether the question is the first, second, ... in a sequence of questions related to a table. 
* The `table_file` column identifies the name of the table file, which refers to a CSV file in the `table_csv` directory.
* The `answer_coordinates` and `answer_text` columns indicate the answer to the question. The `answer_coordinates` is a list of tuples, each tuple being a (row_index, column_index) pair. The `answer_text` column is a list of strings, indicating the cell values.

However, the `answer_coordinates` and `answer_text` columns are currently not recognized as real Python lists of Python tuples and strings respectively. Let's do that first using the `.literal_eval()`function of the `ast` module:

In [189]:
import ast

def _parse_answer_coordinates(answer_coordinate_str):
  """Parses the answer_coordinates of a question.
  Args:
    answer_coordinate_str: A string representation of a Python list of tuple
      strings.
      For example: "['(1, 4)','(1, 3)', ...]"
  """

  try:
    answer_coordinates = []
    # make a list of strings
    coords = ast.literal_eval(answer_coordinate_str)
    # parse each string as a tuple
    for row_index, column_index in sorted(
        ast.literal_eval(coord) for coord in coords):
      answer_coordinates.append((row_index, column_index))
  except SyntaxError:
    raise ValueError('Unable to evaluate %s' % answer_coordinate_str)
  
  return answer_coordinates


def _parse_answer_text(answer_text):
  """Populates the answer_texts field of `answer` by parsing `answer_text`.
  Args:
    answer_text: A string representation of a Python list of strings.
      For example: "[u'test', u'hello', ...]"
    answer: an Answer object.
  """
  try:
    answer = []
    for value in ast.literal_eval(answer_text):
      answer.append(value)
  except SyntaxError:
    raise ValueError('Unable to evaluate %s' % answer_text)

  return answer

data_schedule['answer_coordinates'] = data_schedule['answer_coordinates'].apply(lambda coords_str: _parse_answer_coordinates(coords_str))
data_schedule['answer_text'] = data_schedule['answer_text'].apply(lambda txt: _parse_answer_text(txt))

data_schedule.head(10)

Unnamed: 0,id,annotator,position,question,table_file,answer_coordinates,answer_text
0,ss-1030,0,0,What is scheduled on Monday at 9?,table_csv/schedule.csv,"[(0, 1)]",[TCS]
1,ss-1030,0,1,What is the next lecture?,table_csv/schedule.csv,"[(1, 1)]",[SSA]
2,ss-1030,1,0,What subject is at 11 on Tuesday?,table_csv/schedule.csv,"[(1, 2)]",[Project 2-2]
3,ss-1030,1,1,What is the next lecture?,table_csv/schedule.csv,"[(2, 2)]",[SSA]
4,ss-1030,2,0,What is happening on Wednesday at 13?,table_csv/schedule.csv,"[(2, 3)]",[TCS]
5,ss-1030,2,1,What is the next lecture?,table_csv/schedule.csv,"[(3, 3)]",[MM]
6,ss-1030,3,0,What is the schedule for Thursday at 15?,table_csv/schedule.csv,"[(3, 4)]",[Logic]
7,ss-1030,3,1,What is the next lecture?,table_csv/schedule.csv,"[(4, 4)]",[SSA]
8,ss-1030,4,0,What's on Friday at 17?,table_csv/schedule.csv,"[(4, 5)]",[Project 2-2]
9,ss-1030,5,0,What is the first class on Monday?,table_csv/schedule.csv,"[(0, 1)]",[TCS]


Let's create a new dataframe that groups questions which are asked in a sequence related to the table. We can do this by adding a `sequence_id` column, which is a combination of the `id` and `annotator` columns:

In [190]:
def get_sequence_id(example_id, annotator):
  if "-" in str(annotator):
    raise ValueError('"-" not allowed in annotator.')
  return f"{example_id}-{annotator}"

data['sequence_id'] = data.apply(lambda x: get_sequence_id(x.id, x.annotator), axis=1)
data.head()

Unnamed: 0,id,annotator,position,question,table_file,answer_coordinates,answer_text,sequence_id
0,nt-639,0,0,where are the players from?,table_csv/203_149.csv,"['(0, 4)', '(1, 4)', '(2, 4)', '(3, 4)', '(4, ...","['Louisiana State University', 'Valley HS (Las...",nt-639-0
1,nt-639,0,1,which player went to louisiana state university?,table_csv/203_149.csv,"['(0, 1)']",['Ben McDonald'],nt-639-0
2,nt-639,1,0,who are the players?,table_csv/203_149.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Ben McDonald', 'Tyler Houston', 'Roger Salke...",nt-639-1
3,nt-639,1,1,which ones are in the top 26 picks?,table_csv/203_149.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Ben McDonald', 'Tyler Houston', 'Roger Salke...",nt-639-1
4,nt-639,1,2,"and of those, who is from louisiana state univ...",table_csv/203_149.csv,"['(0, 1)']",['Ben McDonald'],nt-639-1


In [191]:
def get_sequence_id(example_id, annotator):
  if "-" in str(annotator):
    raise ValueError('"-" not allowed in annotator.')
  return f"{example_id}-{annotator}"

data_schedule['sequence_id'] = data_schedule.apply(lambda x: get_sequence_id(x.id, x.annotator), axis=1)
data_schedule.head()

Unnamed: 0,id,annotator,position,question,table_file,answer_coordinates,answer_text,sequence_id
0,ss-1030,0,0,What is scheduled on Monday at 9?,table_csv/schedule.csv,"[(0, 1)]",[TCS],ss-1030-0
1,ss-1030,0,1,What is the next lecture?,table_csv/schedule.csv,"[(1, 1)]",[SSA],ss-1030-0
2,ss-1030,1,0,What subject is at 11 on Tuesday?,table_csv/schedule.csv,"[(1, 2)]",[Project 2-2],ss-1030-1
3,ss-1030,1,1,What is the next lecture?,table_csv/schedule.csv,"[(2, 2)]",[SSA],ss-1030-1
4,ss-1030,2,0,What is happening on Wednesday at 13?,table_csv/schedule.csv,"[(2, 3)]",[TCS],ss-1030-2


In [192]:
# let's group table-question pairs by sequence id, and remove some columns we don't need 
grouped = data.groupby(by='sequence_id').agg(lambda x: x.tolist())
grouped = grouped.drop(columns=['id', 'annotator', 'position'])
grouped['table_file'] = grouped['table_file'].apply(lambda x: x[0])
grouped.head(10)

Unnamed: 0_level_0,question,table_file,answer_coordinates,answer_text
sequence_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ns-1292-0,"[who are all the athletes?, where are they fro...",table_csv/204_521.csv,"[['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4,...","[['Tommy Green', 'Janis Dalins', 'Ugo Frigerio..."
nt-10730-0,[what was the production numbers of each revol...,table_csv/203_253.csv,"[['(0, 4)', '(1, 4)', '(2, 4)', '(3, 4)', '(4,...","[['1,900 (estimated)', '14,500 (estimated)', '..."
nt-10730-1,[what three revolver models had the least amou...,table_csv/203_253.csv,"[['(0, 0)', '(6, 0)', '(7, 0)'], ['(0, 0)']]","[['Remington-Beals Army Model Revolver', 'New ..."
nt-10730-2,"[what are all of the remington models?, how ma...",table_csv/203_253.csv,"[['(0, 0)', '(1, 0)', '(2, 0)', '(3, 0)', '(4,...","[['Remington-Beals Army Model Revolver', 'Remi..."
nt-11649-0,"[what are all the names of the teams?, of thes...",table_csv/204_135.csv,"[['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4,...","[['Cordoba CF', 'CD Malaga', 'Granada CF', 'UD..."
nt-11649-1,"[what are the losses?, what team had more than...",table_csv/204_135.csv,"[['(0, 6)', '(1, 6)', '(2, 6)', '(3, 6)', '(4,...","[['6', '6', '9', '10', '10', '12', '12', '11',..."
nt-11649-2,"[what were all the teams?, what were the loss ...",table_csv/204_135.csv,"[['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4,...","[['Cordoba CF', 'CD Malaga', 'Granada CF', 'UD..."
nt-639-0,"[where are the players from?, which player wen...",table_csv/203_149.csv,"[['(0, 4)', '(1, 4)', '(2, 4)', '(3, 4)', '(4,...","[['Louisiana State University', 'Valley HS (La..."
nt-639-1,"[who are the players?, which ones are in the t...",table_csv/203_149.csv,"[['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4,...","[['Ben McDonald', 'Tyler Houston', 'Roger Salk..."
nt-639-2,"[who are the players in the top 26?, of those,...",table_csv/203_149.csv,"[['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4,...","[['Ben McDonald', 'Tyler Houston', 'Roger Salk..."


Each row in the dataframe above now consists of a **table and one or more questions** which are asked in a **sequence**. Let's visualize the first row, i.e. a table, together with its queries:

In [197]:

# let's drop not needed columns
#data_schedule['sequence_id'] = data_schedule['sequence_id'].str.strip()
#qa_df['sequence_id'] = qa_df['sequence_id'].str.replace('\n', '')
data_schedule_grouped = data_schedule.groupby(by = 'sequence_id').agg(lambda x: x.tolist()).reset_index()

#qa_df_clean = qa_df.groupby('sequence_id').agg(lambda x: x.tolist())
data_schedule_grouped = data_schedule_grouped.drop(columns=['position','annotator','id'])
data_schedule_grouped['table_file'] = data_schedule_grouped['table_file'].apply(lambda x: x[0])
data_schedule_grouped.head(10)
#qa_df_clean['sequence_id'].iloc[0]== 'nt-1'

Unnamed: 0,sequence_id,question,table_file,answer_coordinates,answer_text
0,ss-1030-0,"[What is scheduled on Monday at 9?, What is th...",table_csv/schedule.csv,"[[(0, 1)], [(1, 1)]]","[[TCS], [SSA]]"
1,ss-1030-1,"[What subject is at 11 on Tuesday?, What is th...",table_csv/schedule.csv,"[[(1, 2)], [(2, 2)]]","[[Project 2-2], [SSA]]"
2,ss-1030-10,[What class is at 15 on Monday?],table_csv/schedule.csv,"[[(3, 1)]]",[[Project 2-2]]
3,ss-1030-11,[What's the first class on Tuesday?],table_csv/schedule.csv,"[[(0, 2)]]",[[Logic]]
4,ss-1030-12,[What's the last class on Wednesday?],table_csv/schedule.csv,"[[(4, 3)]]",[[Calculus]]
5,ss-1030-13,[What class is at 9 on Thursday?],table_csv/schedule.csv,"[[(0, 4)]]",[[Calculus]]
6,ss-1030-14,[What's happening at 11 on Friday?],table_csv/schedule.csv,"[[(1, 5)]]",[[TCS]]
7,ss-1030-15,[What's the second class on Monday?],table_csv/schedule.csv,"[[(1, 1)]]",[[SSA]]
8,ss-1030-16,[What is the last class on Tuesday?],table_csv/schedule.csv,"[[(4, 2)]]",[[TCS]]
9,ss-1030-17,[What is the first class on Wednesday?],table_csv/schedule.csv,"[[(0, 3)]]",[[TCS]]


In [204]:
table_schedule = pd.read_csv("student_schedule_week.csv").astype(str)
display(table_schedule)
item = data_schedule_grouped.iloc[0]
print(item.question)
print(item.answer_coordinates)
print(item.answer_text) 


Unnamed: 0,Time,Monday,Tuesday,Wednesday,Thursday,Friday
0,9,TCS,Logic,TCS,Calculus,MM
1,11,SSA,Project 2-2,Project 2-2,Project 2-2,TCS
2,13,MM,SSA,TCS,SSA,TCS
3,15,Project 2-2,Calculus,MM,Logic,Calculus
4,17,MM,TCS,Calculus,SSA,Project 2-2


['What is scheduled on Monday at 9?', 'What is the next lecture?']
[[(0, 1)], [(1, 1)]]
[['TCS'], ['SSA']]


In [203]:
# path to the directory containing all csv files
table_csv_path = "table_csv"

item = grouped.iloc[0]
print(item)
table = pd.read_csv(table_csv_path + item.table_file[9:]).astype(str) 

display(table)
print("")
print(item.question)


question              [who are all the athletes?, where are they fro...
table_file                                        table_csv/204_521.csv
answer_coordinates    [['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4,...
answer_text           [['Tommy Green', 'Janis Dalins', 'Ugo Frigerio...
Name: ns-1292-0, dtype: object


Unnamed: 0,Rank,Name,Nationality,Time (hand),Notes
0,,Tommy Green,Great Britain,4:50:10,OR
1,,Janis Dalins,Latvia,4:57:20,
2,,Ugo Frigerio,Italy,4:59:06,
3,4.0,Karl Hahnel,Germany,5:06:06,
4,5.0,Ettore Rivolta,Italy,5:07:39,
5,6.0,Paul Sievert,Germany,5:16:41,
6,7.0,Henri Quintric,France,5:27:25,
7,8.0,Ernie Crosbie,United States,5:28:02,
8,9.0,Bill Chisholm,United States,5:51:00,
9,10.0,Alfred Maasik,Estonia,6:19:00,



['who are all the athletes?', 'where are they from?', 'along with paul sievert, which athlete is from germany?']


We can see that there are 3 sequential questions asked related to the contents of the table. 

We can now use `TapasTokenizer` to batch encode this, as follows:

In [212]:
import torch
from transformers import TapasTokenizer

# initialize the tokenizer
tokenizer = TapasTokenizer.from_pretrained("google/tapas-small-finetuned-sqa")

In [206]:
encoding = tokenizer(table=table_schedule, queries=item.question, answer_coordinates=item.answer_coordinates, answer_text=item.answer_text,
                     truncation=True, padding="max_length", return_tensors="pt")
encoding.keys()

dict_keys(['input_ids', 'labels', 'numeric_values', 'numeric_values_scale', 'token_type_ids', 'attention_mask'])

TAPAS basically flattens every table-question pair before feeding it into a BERT like model:

In [213]:
tokenizer.decode(encoding["input_ids"][0])

'[CLS] what is scheduled on monday at 9? [SEP] time monday tuesday wednesday thursday friday 9 tcs logic tcs calculus mm 11 ssa project 2 - 2 project 2 - 2 project 2 - 2 tcs 13 mm ssa tcs ssa tcs 15 project 2 - 2 calculus mm logic calculus 17 mm tcs calculus ssa project 2 - 2 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

The `token_type_ids` created here will be of shape (batch_size, sequence_length, 7), as TAPAS uses 7 different token types to encode tabular structure. Let's verify this:

In [209]:
encoding["token_type_ids"].shape

torch.Size([2, 512, 7])

In [210]:
assert encoding["token_type_ids"].shape == (2, 512, 7)



One thing we can verify is whether the `prev_label` token type ids are created correctly. These indicate which tokens were (part of) an answer to the previous table-question pair. 

The prev_label token type ids of the first example in a batch must always be zero (since there's no previous table-question pair). Let's verify this:

In [12]:
assert encoding["token_type_ids"][0][:,3].sum() == 0

However, the `prev_label` token type ids of the second table-question pair in the batch must be set to 1 for the tokens which were an answer to the previous (i.e. the first) table question pair in the batch. The answers to the first table-question pair are the following:

In [211]:
print(item.answer_text[0])

['TCS']


So let's now verify whether the `prev_label` ids of the second table-question pair are set correctly:

In [215]:
for id, prev_label in zip (encoding["input_ids"][1], encoding["token_type_ids"][1][:,3]):
  if id != 0: # we skip padding tokens
    print(tokenizer.decode([id]), prev_label.item())

[CLS] 0
what 0
is 0
the 0
next 0
lecture 0
? 0
[SEP] 0
time 0
monday 0
tuesday 0
wednesday 0
thursday 0
friday 0
9 0
tc 1
##s 1
logic 0
tc 0
##s 0
calculus 0
mm 0
11 0
ss 0
##a 0
project 0
2 0
- 0
2 0
project 0
2 0
- 0
2 0
project 0
2 0
- 0
2 0
tc 0
##s 0
13 0
mm 0
ss 0
##a 0
tc 0
##s 0
ss 0
##a 0
tc 0
##s 0
15 0
project 0
2 0
- 0
2 0
calculus 0
mm 0
logic 0
calculus 0
17 0
mm 0
tc 0
##s 0
calculus 0
ss 0
##a 0
project 0
2 0
- 0
2 0


This looks OK! Be sure to check this, because the token type ids are critical for the performance of TAPAS.

Let's create a PyTorch dataset and corresponding dataloader. Note the __getitem__ method here: in order to properly set the prev_labels token types, we must check whether a table-question pair is the first in a sequence or not. In case it is, we can just encode it. In case it isn't, we need to encode it together with the previous table-question pair.

Note that this is not the most efficient approach, because we're effectively tokenizing each table-question pair twice when applied on the entire dataset (feel free to ping me a more efficient solution).

In [216]:
class TableDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        table = pd.read_csv("student_schedule_week.csv").astype(str) # TapasTokenizer expects the table data to be text only
        if item.position != 0:
          # use the previous table-question pair to correctly set the prev_labels token type ids
          previous_item = self.df.iloc[idx-1]
          encoding = self.tokenizer(table=table, 
                                    queries=[previous_item.question, item.question], 
                                    answer_coordinates=[previous_item.answer_coordinates, item.answer_coordinates], 
                                    answer_text=[previous_item.answer_text, item.answer_text],
                                    padding="max_length",
                                    truncation=True,
                                    return_tensors="pt"
          )
          # use encodings of second table-question pair in the batch
          encoding = {key: val[-1] for key, val in encoding.items()}
        else:
          # this means it's the first table-question pair in a sequence
          encoding = self.tokenizer(table=table, 
                                    queries=item.question, 
                                    answer_coordinates=item.answer_coordinates, 
                                    answer_text=item.answer_text,
                                    padding="max_length",
                                    truncation=True,
                                    return_tensors="pt"
          )
          # remove the batch dimension which the tokenizer adds 
          encoding = {key: val.squeeze(0) for key, val in encoding.items()}
        return encoding

    def __len__(self):
        return len(self.df)

train_dataset = TableDataset(df=data_schedule, tokenizer=tokenizer)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2)

In [217]:
train_dataset[0]["token_type_ids"].shape

torch.Size([512, 7])

In [218]:
train_dataset[1]["input_ids"].shape

torch.Size([512])

In [219]:
batch = next(iter(train_dataloader))

In [220]:
batch["input_ids"].shape

torch.Size([2, 512])

In [221]:
batch["token_type_ids"].shape

torch.Size([2, 512, 7])

Let's decode the first table-question pair:

In [222]:
tokenizer.decode(batch["input_ids"][0])

'[CLS] what is scheduled on monday at 9? [SEP] time monday tuesday wednesday thursday friday 9 tcs logic tcs calculus mm 11 ssa project 2 - 2 project 2 - 2 project 2 - 2 tcs 13 mm ssa tcs ssa tcs 15 project 2 - 2 calculus mm logic calculus 17 mm tcs calculus ssa project 2 - 2 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

In [223]:
#first example should not have any prev_labels set
assert batch["token_type_ids"][0][:,3].sum() == 0

Let's decode the second table-question pair and verify some more:

In [224]:
tokenizer.decode(batch["input_ids"][1])

'[CLS] what is the next lecture? [SEP] time monday tuesday wednesday thursday friday 9 tcs logic tcs calculus mm 11 ssa project 2 - 2 project 2 - 2 project 2 - 2 tcs 13 mm ssa tcs ssa tcs 15 project 2 - 2 calculus mm logic calculus 17 mm tcs calculus ssa project 2 - 2 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

In [225]:
assert batch["labels"][0].sum() == batch["token_type_ids"][1][:,3].sum()
print(batch["token_type_ids"][1][:,3].sum())

tensor(2)


In [226]:
for id, prev_label in zip(batch["input_ids"][1], batch["token_type_ids"][1][:,3]):
  if id != 0:
    print(tokenizer.decode([id]), prev_label.item())

[CLS] 0
what 0
is 0
the 0
next 0
lecture 0
? 0
[SEP] 0
time 0
monday 0
tuesday 0
wednesday 0
thursday 0
friday 0
9 0
tc 1
##s 1
logic 0
tc 0
##s 0
calculus 0
mm 0
11 0
ss 0
##a 0
project 0
2 0
- 0
2 0
project 0
2 0
- 0
2 0
project 0
2 0
- 0
2 0
tc 0
##s 0
13 0
mm 0
ss 0
##a 0
tc 0
##s 0
ss 0
##a 0
tc 0
##s 0
15 0
project 0
2 0
- 0
2 0
calculus 0
mm 0
logic 0
calculus 0
17 0
mm 0
tc 0
##s 0
calculus 0
ss 0
##a 0
project 0
2 0
- 0
2 0


## Define the model

Here we initialize the model with a pre-trained base and randomly initialized cell selection head, and move it to the GPU (if available).

Note that the `google/tapas-base` checkpoint has (by default) an SQA configuration, so we don't need to specify any additional hyperparameters.

In [228]:
from transformers import TapasForQuestionAnswering

model = TapasForQuestionAnswering.from_pretrained("google/tapas-small-finetuned-sqa")
#device = torch.device("mps")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

TapasForQuestionAnswering(
  (tapas): TapasModel(
    (embeddings): TapasEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings_0): Embedding(3, 512)
      (token_type_embeddings_1): Embedding(256, 512)
      (token_type_embeddings_2): Embedding(256, 512)
      (token_type_embeddings_3): Embedding(2, 512)
      (token_type_embeddings_4): Embedding(256, 512)
      (token_type_embeddings_5): Embedding(256, 512)
      (token_type_embeddings_6): Embedding(10, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): TapasEncoder(
      (layer): ModuleList(
        (0-3): 4 x TapasLayer(
          (attention): TapasAttention(
            (self): TapasSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=Tr

## Training the model

Let's fine-tune the model in well-known PyTorch fashion:

In [229]:
from transformers import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

for epoch in range(10):  # loop over the dataset multiple times
   print("Epoch:", epoch)
   for idx, batch in enumerate(train_dataloader):
        # get the inputs;
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        labels = batch["labels"].to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                       labels=labels)
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()



Epoch: 0
Loss: 6.7472333908081055
Loss: 1.426924705505371
Loss: 1.3566066026687622
Loss: 4.098746299743652
Loss: 6.059797783564136e-07
Loss: 1.9401189088821411
Loss: 0.03174411505460739
Loss: 2.1853280067443848
Loss: 0.4325258433818817
Loss: 0.44027256965637207
Loss: 0.6540981531143188
Loss: 0.3549450933933258
Loss: 0.09983222186565399
Epoch: 1
Loss: 0.29138728976249695
Loss: 1.3100438117980957
Loss: 0.8841657638549805
Loss: 0.6797493696212769
Loss: 0.03665916249155998
Loss: 0.4115382134914398
Loss: 2.5765963073354214e-05
Loss: 0.0015640822239220142
Loss: 0.01928963139653206
Loss: 2.301542554050684e-05
Loss: 0.00500061921775341
Loss: 0.00021047922200523317
Loss: 5.960463411724959e-08
Epoch: 2
Loss: 0.2278648167848587
Loss: 0.43987253308296204
Loss: 0.495622843503952
Loss: 0.2884819209575653
Loss: 4.4703435264636937e-07
Loss: 0.21275582909584045
Loss: 2.980228259730211e-07
Loss: 9.23868185509491e-07
Loss: 8.356442413059995e-05
Loss: 2.284842679500798e-07
Loss: 0.00016926703392527997
Los

In [None]:
torch.save(model, 'schedule-tapas-small-finetuned-sqa.pth')


## Inference

As SQA is a bit different due to its conversational nature, we need to run every training example of the a batch one by one through the model (sequentially), overwriting the `prev_labels` token types (which were created by the tokenizer) by the answer predicted by the model. It is based on the [following code](https://github.com/google-research/tapas/blob/f458b6624b8aa75961a0ab78e9847355022940d3/tapas/experiments/prediction_utils.py#L92) from the official implementation:

In [230]:
import collections
import numpy as np

def compute_prediction_sequence(model, data, device):
  """Computes predictions using model's answers to the previous questions."""
  
  # prepare data
  input_ids = data["input_ids"].to(device)
  attention_mask = data["attention_mask"].to(device)
  token_type_ids = data["token_type_ids"].to(device)

  all_logits = []
  prev_answers = None

  num_batch = data["input_ids"].shape[0]
  
  for idx in range(num_batch):
    
    if prev_answers is not None:
        coords_to_answer = prev_answers[idx]
        # Next, set the label ids predicted by the model
        prev_label_ids_example = token_type_ids_example[:,3] # shape (seq_len,)
        model_label_ids = np.zeros_like(prev_label_ids_example.cpu().numpy()) # shape (seq_len,)

        # for each token in the sequence:
        token_type_ids_example = token_type_ids[idx] # shape (seq_len, 7)
        for i in range(model_label_ids.shape[0]):
          segment_id = token_type_ids_example[:,0].tolist()[i]
          col_id = token_type_ids_example[:,1].tolist()[i] - 1
          row_id = token_type_ids_example[:,2].tolist()[i] - 1
          if row_id >= 0 and col_id >= 0 and segment_id == 1:
            model_label_ids[i] = int(coords_to_answer[(col_id, row_id)])

        # set the prev label ids of the example (shape (1, seq_len) )
        token_type_ids_example[:,3] = torch.from_numpy(model_label_ids).type(torch.long).to(device)   

    prev_answers = {}
    # get the example
    input_ids_example = input_ids[idx] # shape (seq_len,)
    attention_mask_example = attention_mask[idx] # shape (seq_len,)
    token_type_ids_example = token_type_ids[idx] # shape (seq_len, 7)
    # forward pass to obtain the logits
    outputs = model(input_ids=input_ids_example.unsqueeze(0), 
                    attention_mask=attention_mask_example.unsqueeze(0), 
                    token_type_ids=token_type_ids_example.unsqueeze(0))
    logits = outputs.logits
    all_logits.append(logits)

    # convert logits to probabilities (which are of shape (1, seq_len))
    dist_per_token = torch.distributions.Bernoulli(logits=logits)
    probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(dist_per_token.probs.device) 

    # Compute average probability per cell, aggregating over tokens.
    # Dictionary maps coordinates to a list of one or more probabilities
    coords_to_probs = collections.defaultdict(list)
    prev_answers = {}
    for i, p in enumerate(probabilities.squeeze().tolist()):
      segment_id = token_type_ids_example[:,0].tolist()[i]
      col = token_type_ids_example[:,1].tolist()[i] - 1
      row = token_type_ids_example[:,2].tolist()[i] - 1
      if col >= 0 and row >= 0 and segment_id == 1:
        coords_to_probs[(col, row)].append(p)

    # Next, map cell coordinates to 1 or 0 (depending on whether the mean prob of all cell tokens is > 0.5)
    coords_to_answer = {}
    for key in coords_to_probs:
      coords_to_answer[key] = np.array(coords_to_probs[key]).mean() > 0.5
    prev_answers[idx+1] = coords_to_answer
    
  logits_batch = torch.cat(tuple(all_logits), 0)
  
  return logits_batch

In [233]:
data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], 
        'Age': ["56", "45", "59"],
        'Number of movies': ["87", "53", "69"],
        'Date of birth': ["7 february 1967", "10 june 1996", "28 november 1967"]}
queries = ["How many movies has George Clooney played in?", "How old is he?", "What's his date of birth?"]

table = pd.DataFrame.from_dict(data)

inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
logits = compute_prediction_sequence(model, inputs, device)

In [240]:
our_queries=["What subject is at 11 on Wednesday?","What is the next lecture?"]
inputs = tokenizer(table=table_schedule, queries=our_queries, padding='max_length', return_tensors="pt")
logits = compute_prediction_sequence(model, inputs, device)

Finally, we can use the handy `convert_logits_to_predictions` function of `TapasTokenizer` to convert the logits into predicted coordinates, and print out the result:

In [241]:
predicted_answer_coordinates, = tokenizer.convert_logits_to_predictions(inputs, logits.cpu().detach())

In [242]:
# handy helper function in case inference on Pandas dataframe
answers = []
for coordinates in predicted_answer_coordinates:
  if len(coordinates) == 1:
    # only a single cell:
    answers.append(table_schedule.iat[coordinates[0]])
  else:
    # multiple cells
    cell_values = []
    for coordinate in coordinates:
      cell_values.append(table_schedule.iat[coordinate])
    answers.append(", ".join(cell_values))

display(table)
print("")
for query, answer in zip(our_queries, answers):
  print(query)
  print("Predicted answer: " + answer)

Unnamed: 0,Actors,Age,Number of movies,Date of birth
0,Brad Pitt,56,87,7 february 1967
1,Leonardo Di Caprio,45,53,10 june 1996
2,George Clooney,59,69,28 november 1967



What subject is at 11 on Wednesday?
Predicted answer: Project 2-2
What is the next lecture?
Predicted answer: MM


Note that the results here are not correct, that's obvious since we only trained on 28 examples and tested it on an entire different example. In reality, you should train on the entire dataset. The result of this is the `google/tapas-base-finetuned-sqa` checkpoint.

## Legacy

The code below was considered during the creation of this tutorial, but eventually not used.

In [None]:
# grouped = data.groupby(data.position)
# test = grouped.get_group(0)
# test.index

In [None]:
def custom_collate_fn(data):
  """
  A custom collate function to batch input_ids, attention_mask, token_type_ids and so on of different batch sizes.
  
  Args:
    data: 
      a list of dictionaries (each dictionary is what the __getitem__ method of TableDataset returns)
  """
  result = {}
  for k in data[0].keys():
      result[k] = torch.cat([x[k] for x in data], dim=0)

  return result

class TableDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        table = pd.read_csv(table_csv_path + item.table_file[9:]).astype(str) # TapasTokenizer expects the table data to be text only
        if item.position != 0:
          # use the previous table-question pair 
          previous_item = self.df.iloc[idx-1]
          encoding = self.tokenizer(table=table, 
                                    queries=[previous_item.question, item.question], 
                                    answer_coordinates=[previous_item.answer_coordinates, item.answer_coordinates], 
                                    answer_text=[previous_item.answer_text, item.answer_text],
                                    padding="max_length",
                                    truncation=True,
                                    return_tensors="pt"
          )
          # remove the batch dimension which the tokenizer adds 
          encoding = {key: val[-1] for key, val in encoding.items()}
          #encoding = {key: val.squeeze(0) for key, val in encoding.items()}
        else:
          # this means it's the first table-question pair in a sequence
          encoding = self.tokenizer(table=table, 
                                    queries=item.question, 
                                    answer_coordinates=item.answer_coordinates, 
                                    answer_text=item.answer_text,
                                    padding="max_length",
                                    truncation=True,
                                    return_tensors="pt"
          )
        return encoding

    def __len__(self):
        return len(self.df)

train_dataset = TableDataset(df=grouped, tokenizer=tokenizer)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, collate_fn=custom_collate_fn)