# BERT for Question Answer

## Setup

In [1]:
# if using colab, uncomment the below
# !pip install torch "argilla" datasets accelerate transformers setfit
# !pip install wandb

In [2]:
from datasets import load_dataset

In [3]:
squadv2 = load_dataset('squad_v2')

Found cached dataset squad_v2 (/Users/arths/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
print(squadv2)

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})


In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

## Preprocessing Data

Our sequences will look like

```
[CLS] ...question tokens... [SEP] ...context tokens... [SEP]
```

In cases where the context is too long, we'll split into multiple sequences, like

```
[CLS] ...question tokens... [SEP] ...some context tokens... [SEP]
[CLS] ...question tokens... [SEP] ...overlap from prev sequence... ...more context tokens... [SEP]
...
```

Bassed on the question tokens, the model needs to get a contiguous subset of the context tokens as the answer. Our dataset contains the start position of the answer in the original context string.

The HuggingFace tokenizer is able to map each item in the tokenized sequence to the start and end indices in the original context string.

We need to find which indices in the tokenized sequence map to the start and end of the answer so that our model knows how to predict the contiguous answer section.

If there is no answer available in a sequence, we will set the answer start and end to the `[CLS]` token.

Additionally, for context split accross multiple tokenized sequences, for sequences without the answer (or with only a part of the answer), we will treat it the same as 'no answer' sequences.

In [6]:
def map_answer(offset, ans_start, ans_end, sequence_ids):

    # get start and end indices in tokenized sequence
    idx = 0
    while sequence_ids[idx] != 1: idx += 1
    context_start = idx
    while sequence_ids[idx] == 1: idx += 1
    context_end = idx - 1

    # start with [CLS]
    start, end = 0, 0

    # if answer is not fully in this tokenized sequence, map to [CLS]
    if offset[context_end][0] > ans_end or offset[context_end][1] < ans_start:
        return start, end
    
    idx = context_start
    while idx <= context_end and offset[idx][0] <= ans_start: idx += 1
    start = idx - 1

    idx = context_end
    while idx >= context_start and offset[idx][1] >= ans_end: idx -= 1
    end = idx + 1

    return start, end

def get_answer_mapped_data(batch):
    questions = batch['question']
    contexts = batch['context']
    answers = batch['answers']

    inputs = tokenizer(
        # add data for tokenizing and padding
        questions, contexts,        # data to tokenize
        max_length=400,             # max_length per sequence
        padding='max_length',       # pad til max_length

        # handling truncation
        truncation='only_second',   # only truncate context
        stride=128,                 # overlap size
        return_overflowing_tokens=True, # tokenizer automatically 
                                        # makes extra sequences

        # get mappings to original sentence
        return_offsets_mapping=True,# used to map answer to sequence
    )

    offset_mapping = inputs.pop('offset_mapping')
    sample_map = inputs.pop('overflow_to_sample_mapping')
    starts = []
    ends = []

    for i, offset in enumerate(offset_mapping):

        map_i = sample_map[i]

        answer = answers[map_i]
        text = answer['text']
        
        # SQuAD v2 has some adversarial examples with 'unanswerable' questions
        # in this case, map to [CLS]
        if len(text) < 1:
            starts.append(0)
            ends.append(0)
            continue

        ans_start = answer['answer_start'][0]
        ans_end = ans_start + len(text[0])
        sequence_ids = inputs.sequence_ids(map_i)

        start, end = map_answer(offset, ans_start, ans_end, sequence_ids)

        starts.append(start)
        ends.append(end)

    inputs['start_positions'] = starts
    inputs['end_positions'] = ends

    return inputs

In [7]:
tokenized_squadv2 = squadv2.map(get_answer_mapped_data,
                                batched=True,
                                remove_columns=squadv2['train'].column_names)

Loading cached processed dataset at /Users/arths/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-9b81bd08df7dc75b.arrow
Loading cached processed dataset at /Users/arths/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-7f101e47c6d41ad5.arrow


In [8]:
print(tokenized_squadv2['train'][0].keys())

dict_keys(['input_ids', 'attention_mask', 'start_positions', 'end_positions'])


# Train

### Set Up HuggingFace Training

In [9]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import DefaultDataCollator

We will use DistilBERT for lower memory usage and thus faster training (from larger batch sizes).

In [10]:
dbert_qa = AutoModelForQuestionAnswering.from_pretrained('distilbert-base-uncased')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

In [11]:
BATCH_SIZE = 16
LR = 2e-5
EPOCHS = 3
WEIGHT_DECAY = 0.01
CHKPT_DIR = 'checkpoints'

In [12]:
data_collator = DefaultDataCollator()

train_args = TrainingArguments(
    # save model every epoch
    output_dir=CHKPT_DIR,
    save_strategy='epoch',

    # epochs
    evaluation_strategy='epoch',
    num_train_epochs=EPOCHS,

    # batch sizes
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    
    # hyperparams
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,

    # log to wandb
    report_to='wandb',
)

trainer = Trainer(
    model=dbert_qa,
    args=train_args,
    train_dataset=tokenized_squadv2['train'],
    eval_dataset=tokenized_squadv2['validation'],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

### Run Training

In [13]:
import wandb

# use to log in to wandb if needed
# API_KEY = # wandb api key
# wandb.login(key=API_KEY)

wandb.init(
    project='SQuAD2.0 with Fine-Tuned DistilBERT',
    notes='Solving Standford\'s SQuAD 2.0 Q&A dataset with DistilBERT transfer learning.',
)

wandb.config = {
    'epochs': EPOCHS, 
    'learning_rate': LR, 
    'batch_size': BATCH_SIZE,
    'weight_decay': WEIGHT_DECAY,
}

trainer.train()

wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33marth-shukla[0m. Use [1m`wandb login --relogin`[0m to force relogin


  2%|▏         | 500/24654 [02:13<1:43:07,  3.90it/s]

{'loss': 0.3954, 'learning_rate': 1.9594386306481706e-05, 'epoch': 0.06}


  4%|▍         | 1000/24654 [04:21<1:41:50,  3.87it/s]

{'loss': 0.2009, 'learning_rate': 1.9188772612963417e-05, 'epoch': 0.12}


  6%|▌         | 1500/24654 [06:30<1:40:22,  3.84it/s]

{'loss': 0.161, 'learning_rate': 1.878315891944512e-05, 'epoch': 0.18}


  8%|▊         | 2000/24654 [08:38<1:37:03,  3.89it/s]

{'loss': 0.1758, 'learning_rate': 1.837754522592683e-05, 'epoch': 0.24}


 10%|█         | 2500/24654 [10:47<1:35:25,  3.87it/s]

{'loss': 0.1913, 'learning_rate': 1.7971931532408534e-05, 'epoch': 0.3}


 12%|█▏        | 3000/24654 [12:56<1:33:01,  3.88it/s]

{'loss': 0.1774, 'learning_rate': 1.7566317838890245e-05, 'epoch': 0.37}


 14%|█▍        | 3500/24654 [15:05<1:31:20,  3.86it/s]

{'loss': 0.1692, 'learning_rate': 1.716070414537195e-05, 'epoch': 0.43}


 16%|█▌        | 4000/24654 [17:13<1:29:00,  3.87it/s]

{'loss': 0.1874, 'learning_rate': 1.6755090451853653e-05, 'epoch': 0.49}


 18%|█▊        | 4500/24654 [19:22<1:26:11,  3.90it/s]

{'loss': 0.155, 'learning_rate': 1.6349476758335365e-05, 'epoch': 0.55}


 20%|██        | 5000/24654 [21:30<1:24:13,  3.89it/s]

{'loss': 0.1581, 'learning_rate': 1.594386306481707e-05, 'epoch': 0.61}


 22%|██▏       | 5500/24654 [23:39<1:22:10,  3.88it/s]

{'loss': 0.1583, 'learning_rate': 1.5538249371298777e-05, 'epoch': 0.67}


 24%|██▍       | 6000/24654 [25:48<1:20:11,  3.88it/s]

{'loss': 0.1549, 'learning_rate': 1.5132635677780483e-05, 'epoch': 0.73}


 26%|██▋       | 6500/24654 [27:57<1:18:10,  3.87it/s]

{'loss': 0.1657, 'learning_rate': 1.472702198426219e-05, 'epoch': 0.79}


 28%|██▊       | 7000/24654 [30:05<1:16:00,  3.87it/s]

{'loss': 0.1613, 'learning_rate': 1.4321408290743897e-05, 'epoch': 0.85}


 30%|███       | 7500/24654 [32:14<1:13:34,  3.89it/s]

{'loss': 0.158, 'learning_rate': 1.3915794597225603e-05, 'epoch': 0.91}


 32%|███▏      | 8000/24654 [34:23<1:11:25,  3.89it/s]

{'loss': 0.1445, 'learning_rate': 1.351018090370731e-05, 'epoch': 0.97}


                                                      
 33%|███▎      | 8218/24654 [36:22<1:03:55,  4.28it/s]

{'eval_loss': 0.11020325124263763, 'eval_runtime': 63.1212, 'eval_samples_per_second': 191.758, 'eval_steps_per_second': 11.993, 'epoch': 1.0}


 34%|███▍      | 8500/24654 [37:35<1:09:12,  3.89it/s] 

{'loss': 0.1506, 'learning_rate': 1.3104567210189016e-05, 'epoch': 1.03}


 37%|███▋      | 9000/24654 [39:44<1:07:15,  3.88it/s]

{'loss': 0.132, 'learning_rate': 1.2698953516670724e-05, 'epoch': 1.1}


 39%|███▊      | 9500/24654 [41:53<1:05:25,  3.86it/s]

{'loss': 0.1232, 'learning_rate': 1.229333982315243e-05, 'epoch': 1.16}


 41%|████      | 10000/24654 [44:01<1:02:51,  3.89it/s]

{'loss': 0.1337, 'learning_rate': 1.1887726129634138e-05, 'epoch': 1.22}


 43%|████▎     | 10500/24654 [46:10<1:00:38,  3.89it/s]

{'loss': 0.1092, 'learning_rate': 1.1482112436115844e-05, 'epoch': 1.28}


 45%|████▍     | 11000/24654 [48:19<58:41,  3.88it/s]  

{'loss': 0.1387, 'learning_rate': 1.107649874259755e-05, 'epoch': 1.34}


 47%|████▋     | 11500/24654 [50:27<56:18,  3.89it/s]

{'loss': 0.1161, 'learning_rate': 1.0670885049079258e-05, 'epoch': 1.4}


 49%|████▊     | 12000/24654 [52:36<54:29,  3.87it/s]

{'loss': 0.1406, 'learning_rate': 1.0265271355560964e-05, 'epoch': 1.46}


 51%|█████     | 12500/24654 [54:45<52:05,  3.89it/s]

{'loss': 0.1366, 'learning_rate': 9.859657662042672e-06, 'epoch': 1.52}


 53%|█████▎    | 13000/24654 [56:54<49:54,  3.89it/s]

{'loss': 0.1335, 'learning_rate': 9.454043968524378e-06, 'epoch': 1.58}


 55%|█████▍    | 13500/24654 [59:03<47:48,  3.89it/s]

{'loss': 0.1242, 'learning_rate': 9.048430275006085e-06, 'epoch': 1.64}


 57%|█████▋    | 14000/24654 [1:01:11<45:40,  3.89it/s]

{'loss': 0.1303, 'learning_rate': 8.642816581487791e-06, 'epoch': 1.7}


 59%|█████▉    | 14500/24654 [1:03:20<43:58,  3.85it/s]

{'loss': 0.1324, 'learning_rate': 8.2372028879695e-06, 'epoch': 1.76}


 61%|██████    | 15000/24654 [1:05:29<41:15,  3.90it/s]

{'loss': 0.143, 'learning_rate': 7.831589194451205e-06, 'epoch': 1.83}


 63%|██████▎   | 15500/24654 [1:07:37<39:31,  3.86it/s]

{'loss': 0.1263, 'learning_rate': 7.425975500932913e-06, 'epoch': 1.89}


 65%|██████▍   | 16000/24654 [1:09:46<37:03,  3.89it/s]

{'loss': 0.1211, 'learning_rate': 7.020361807414618e-06, 'epoch': 1.95}


                                                       
 67%|██████▋   | 16436/24654 [1:12:41<32:05,  4.27it/s]

{'eval_loss': 0.12113236635923386, 'eval_runtime': 62.952, 'eval_samples_per_second': 192.273, 'eval_steps_per_second': 12.025, 'epoch': 2.0}


 67%|██████▋   | 16500/24654 [1:12:58<34:54,  3.89it/s]   

{'loss': 0.1293, 'learning_rate': 6.614748113896325e-06, 'epoch': 2.01}


 69%|██████▉   | 17000/24654 [1:15:07<32:53,  3.88it/s]

{'loss': 0.0833, 'learning_rate': 6.209134420378032e-06, 'epoch': 2.07}


 71%|███████   | 17500/24654 [1:17:16<30:40,  3.89it/s]

{'loss': 0.1166, 'learning_rate': 5.80352072685974e-06, 'epoch': 2.13}


 73%|███████▎  | 18000/24654 [1:19:25<28:37,  3.87it/s]

{'loss': 0.0852, 'learning_rate': 5.397907033341447e-06, 'epoch': 2.19}


 75%|███████▌  | 18500/24654 [1:21:34<26:23,  3.89it/s]

{'loss': 0.1123, 'learning_rate': 4.992293339823153e-06, 'epoch': 2.25}


 77%|███████▋  | 19000/24654 [1:23:42<24:16,  3.88it/s]

{'loss': 0.0939, 'learning_rate': 4.5866796463048596e-06, 'epoch': 2.31}


 79%|███████▉  | 19500/24654 [1:25:51<22:09,  3.88it/s]

{'loss': 0.1001, 'learning_rate': 4.1810659527865665e-06, 'epoch': 2.37}


 81%|████████  | 20000/24654 [1:28:00<20:05,  3.86it/s]

{'loss': 0.0923, 'learning_rate': 3.775452259268273e-06, 'epoch': 2.43}


 83%|████████▎ | 20500/24654 [1:30:08<17:52,  3.87it/s]

{'loss': 0.0949, 'learning_rate': 3.36983856574998e-06, 'epoch': 2.49}


 85%|████████▌ | 21000/24654 [1:32:17<15:50,  3.84it/s]

{'loss': 0.1096, 'learning_rate': 2.9642248722316867e-06, 'epoch': 2.56}


 87%|████████▋ | 21500/24654 [1:34:26<13:32,  3.88it/s]

{'loss': 0.0967, 'learning_rate': 2.5586111787133936e-06, 'epoch': 2.62}


 89%|████████▉ | 22000/24654 [1:36:35<11:21,  3.89it/s]

{'loss': 0.1081, 'learning_rate': 2.1529974851951005e-06, 'epoch': 2.68}


 91%|█████████▏| 22500/24654 [1:38:44<09:16,  3.87it/s]

{'loss': 0.0912, 'learning_rate': 1.7473837916768072e-06, 'epoch': 2.74}


 93%|█████████▎| 23000/24654 [1:40:52<07:05,  3.89it/s]

{'loss': 0.0987, 'learning_rate': 1.341770098158514e-06, 'epoch': 2.8}


 95%|█████████▌| 23500/24654 [1:43:01<04:57,  3.88it/s]

{'loss': 0.0964, 'learning_rate': 9.361564046402208e-07, 'epoch': 2.86}


 97%|█████████▋| 24000/24654 [1:45:10<02:48,  3.89it/s]

{'loss': 0.1084, 'learning_rate': 5.305427111219275e-07, 'epoch': 2.92}


 99%|█████████▉| 24500/24654 [1:47:18<00:39,  3.89it/s]

{'loss': 0.0927, 'learning_rate': 1.249290176036343e-07, 'epoch': 2.98}


                                                       
100%|██████████| 24654/24654 [1:49:01<00:00,  4.29it/s]

{'eval_loss': 0.14294327795505524, 'eval_runtime': 62.9034, 'eval_samples_per_second': 192.422, 'eval_steps_per_second': 12.034, 'epoch': 3.0}


100%|██████████| 24654/24654 [1:49:02<00:00,  3.77it/s]

{'train_runtime': 6542.3196, 'train_samples_per_second': 60.291, 'train_steps_per_second': 3.768, 'train_loss': 0.13702084790512412, 'epoch': 3.0}





0,1
eval/loss,▁▃█
eval/runtime,█▃▁
eval/samples_per_second,▁▆█
eval/steps_per_second,▁▆█
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▂▂▁▁▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.14294
eval/runtime,62.9034
eval/samples_per_second,192.422
eval/steps_per_second,12.034
train/epoch,3.0
train/global_step,24654.0
train/learning_rate,0.0
train/loss,0.0927
train/total_flos,4.02621817931424e+16
train/train_loss,0.13702


## Inference

In a previous project, I fine-tuned DistilBERT for Sentiment Analysis on Sentiment140 (https://github.com/arth-shukla/sentiment140-bert-transfer-learning). Here, I found that only 1-2 epochs were needed for optimal fine-tuning.

Similarly, as seen above, after the first epoch validation loss only increases. It's likely the model starts overfitting here. So, I load the model from the end of the first epoch.

Feel free to play around with the model below! It seems to have the most trouble with 
1. questions using language it hasn't seen before (e.g. non-American names)
2. deducing "unanswerable" questions.
3. switching around context in a sentence

In [13]:
# TODO: add code to install my model checkpoints from gdrive

In [15]:
from transformers import pipeline

trained_dbert_qa = AutoModelForQuestionAnswering.from_pretrained('./checkpoints/checkpoint-8218')
question_answerer = pipeline('question-answering', model=trained_dbert_qa, tokenizer=tokenizer)

In [16]:
# Information from https://en.wikipedia.org/wiki/List_of_highest-grossing_films
c1 = 'With a worldwide box-office gross of over $2.9 billion, Avatar is proclaimed to be the "highest-grossing" film, but such claims usually refer to theatrical revenues only and do not take into account home video and television income, which can form a significant portion of a film\'s earnings.'

q1a = 'How much did Avatar gross worldwide?'
q1a_ans = question_answerer(question=q1a, context=c1)
print('q1a', q1a_ans)

q1b = 'What is proclaimed to be the "highest-grossing" film?'
q1b_ans = question_answerer(question=q1b, context=c1)
print('q1b', q1b_ans)

q1a {'score': 9.966933056659855e-09, 'start': 37, 'end': 54, 'answer': 'over $2.9 billion'}
q1b {'score': 7.906986732031385e-10, 'start': 0, 'end': 46, 'answer': 'With a worldwide box-office gross of over $2.9'}


In [17]:
q2 = 'How is Sandy feeling today?'
c2 = 'Jim said he feels absolutely awful.'
q2_ans = question_answerer(question=q2, context=c2)
print('q2', q2_ans)

q2 {'score': 0.0008811914012767375, 'start': 18, 'end': 34, 'answer': 'absolutely awful'}


In [18]:
# information from https://planetradio.co.uk/hits-radio/entertainment/music/taylor-swift-awards/
c3 = 'As one of the busiest women in music, it\'ll come as no surprise that Taylor has won pretty much every award there is to win in the biz - being the proud owner of no less than 12 GRAMMY Awards.'

q3 = 'How many Grammys has Taylor Swift won?'
q3_ans = question_answerer(question=q3, context=c3)
print('q3', q3_ans)

q3 {'score': 1.4922909485903801e-07, 'start': 162, 'end': 191, 'answer': 'no less than 12 GRAMMY Awards'}


In [19]:
# information from https://www.1057thepoint.com/music-news/hozier-announces-new-album-unreal-unearth/
c4 = 'Hozier has announced a new album called Unreal Unearth. The third full-length effort from the “Take Me to Church” artist — and his first since 2019\'s Wasteland, Baby! — arrives August 18.'

q4 = 'What is Hozier\'s new album?'
q4_ans = question_answerer(question=q4, context=c4)
print('q4', q4_ans)

q4 {'score': 3.425412273827533e-07, 'start': 40, 'end': 54, 'answer': 'Unreal Unearth'}


In [20]:
import wandb

# use to log in to wandb if needed
# API_KEY = # wandb api key
# wandb.login(key=API_KEY)

wandb.init(
    project='SQuAD2.0 with Fine-Tuned DistilBERT',
    notes='Solving Standford\'s SQuAD 2.0 Q&A dataset with DistilBERT transfer learning.',
)

wandb.config = {
    'epochs': EPOCHS, 
    'learning_rate': LR, 
    'batch_size': BATCH_SIZE,
    'weight_decay': WEIGHT_DECAY,
}

table = wandb.Table(columns=['Content', 'Question', 'Model Answer'], 
            data=[[c1, q1a, q1a_ans], [c1, q1b, q1b_ans], [c2, q2, q2_ans], [c3, q3, q3_ans], [c4, q4, q4_ans]])

wandb.log({ 'QA Inference': table })

wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33marth-shukla[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.003 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.661013…