# BERT for Question Answering

## Setup

In [None]:
from datasets import load_dataset

In [None]:
# Load the dataset registered under the name 'squad_v2' via the `datasets`
# library; the result is indexed by split name later on (e.g. squadv2['train']).
squadv2 = load_dataset('squad_v2')

In [None]:
# Display the loaded dataset object to inspect what it contains.
print(squadv2)

In [None]:
from transformers import AutoTokenizer

# Tokenizer for the 'distilbert-base-uncased' checkpoint, the same checkpoint
# name the QA model is loaded from further down.
# NOTE(review): the preprocessing below requests offset mappings, which (per the
# transformers docs) needs a "fast" tokenizer -- confirm AutoTokenizer returns one.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

## Preprocessing Data

Our sequences will look like

```
[CLS] ...question tokens... [SEP] ...context tokens... [SEP]
```

In cases where the context is too long, we'll split into multiple sequences, like

```
[CLS] ...question tokens... [SEP] ...some context tokens... [SEP]
[CLS] ...question tokens... [SEP] ...overlap from prev sequence... ...more context tokens... [SEP]
...
```

Based on the question tokens, the model needs to select a contiguous span of the context tokens as the answer. Our dataset contains the start position of the answer in the original context string.

The HuggingFace tokenizer is able to map each item in the tokenized sequence to the start and end indices in the original context string.

We need to find which indices in the tokenized sequence map to the start and end of the answer so that our model knows how to predict the contiguous answer section.

If there is no answer available in a sequence, we will set the answer start and end to the `[CLS]` token.

Additionally, when a context is split across multiple tokenized sequences, any sequence that does not contain the answer (or contains only part of it) is treated the same as a 'no answer' sequence.

In [None]:
def map_answer(offset, ans_start, ans_end, sequence_ids):
    """Map a character-level answer span to token indices in one sequence.

    Args:
        offset: per-token ``(char_start, char_end)`` pairs for this tokenized
            sequence, giving each token's position in the original context.
        ans_start: character index in the context where the answer starts.
        ans_end: character index just past the last character of the answer.
        sequence_ids: per-token sequence id for this tokenized sequence;
            context tokens are the ones marked ``1``.

    Returns:
        ``(start, end)``: inclusive token indices of the answer, or ``(0, 0)``
        (the ``[CLS]`` position) when the answer is not fully contained in
        this sequence's context window or the sequence has no context tokens.
    """
    n_tokens = len(sequence_ids)

    # find the first and last context tokens in the tokenized sequence
    idx = 0
    while idx < n_tokens and sequence_ids[idx] != 1:
        idx += 1
    if idx == n_tokens:
        # no context tokens at all -> nothing to point at, use [CLS]
        return 0, 0
    context_start = idx
    while idx < n_tokens and sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # if the answer is not fully inside this context window, map to [CLS]:
    # either it begins before the first context token or it finishes after
    # the last one
    if offset[context_start][0] > ans_start or offset[context_end][1] < ans_end:
        return 0, 0

    # last token whose start is at or before the answer start
    idx = context_start
    while idx <= context_end and offset[idx][0] <= ans_start:
        idx += 1
    start = idx - 1

    # first token whose end is at or after the answer end
    idx = context_end
    while idx >= context_start and offset[idx][1] >= ans_end:
        idx -= 1
    end = idx + 1

    return start, end

def get_answer_mapped_data(batch):
    """Tokenize a batch of QA examples and attach answer token positions.

    Each (question, context) pair is tokenized into one or more fixed-length
    sequences (long contexts overflow into extra, overlapping sequences). For
    every resulting sequence the answer's character span is converted to
    start/end token indices; sequences without a (complete) answer point at
    index 0, the ``[CLS]`` token.

    Args:
        batch: mapping with ``'question'``, ``'context'`` and ``'answers'``
            columns; each answer has ``'text'`` and ``'answer_start'`` lists.

    Returns:
        The tokenizer output with ``'start_positions'`` and
        ``'end_positions'`` added (one entry per tokenized sequence) and the
        offset / overflow mappings removed.
    """
    questions = batch['question']
    contexts = batch['context']
    answers = batch['answers']

    inputs = tokenizer(
        # add data for tokenizing and padding
        questions, contexts,        # data to tokenize
        max_length=400,             # max_length per sequence
        padding='max_length',       # pad til max_length

        # handling truncation
        truncation='only_second',   # only truncate context
        stride=128,                 # overlap size
        return_overflowing_tokens=True, # tokenizer automatically
                                        # makes extra sequences

        # get mappings to original sentence
        return_offsets_mapping=True,# used to map answer to sequence
    )

    offset_mapping = inputs.pop('offset_mapping')
    sample_map = inputs.pop('overflow_to_sample_mapping')
    starts = []
    ends = []

    for i, offset in enumerate(offset_mapping):

        # one sample can yield several sequences; look up the sample that
        # produced sequence i to get its answer
        answer = answers[sample_map[i]]
        text = answer['text']

        # SQuAD v2 has some adversarial examples with 'unanswerable' questions
        # in this case, map to [CLS]
        if not text:
            starts.append(0)
            ends.append(0)
            continue

        ans_start = answer['answer_start'][0]
        ans_end = ans_start + len(text[0])

        # sequence ids are stored per tokenized sequence, so they must be
        # looked up with the sequence index i, not the originating sample index
        sequence_ids = inputs.sequence_ids(i)

        start, end = map_answer(offset, ans_start, ans_end, sequence_ids)

        starts.append(start)
        ends.append(end)

    inputs['start_positions'] = starts
    inputs['end_positions'] = ends

    return inputs

In [None]:
# Run the preprocessing over the dataset in batches. The columns named in the
# 'train' split are dropped so that only what get_answer_mapped_data returns
# is kept -- needed because one example can expand into several sequences,
# so the output rows no longer line up with the input rows.
tokenized_squadv2 = squadv2.map(get_answer_mapped_data,
                                batched=True,
                                remove_columns=squadv2['train'].column_names)

In [None]:
# Sanity check: list the fields present on the first preprocessed training row.
print(tokenized_squadv2['train'][0].keys())

## Train

### Set Up HuggingFace Training

In [None]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import DefaultDataCollator

We will use DistilBERT for lower memory usage and thus faster training (from larger batch sizes).

In [None]:
# Question-answering model built from the 'distilbert-base-uncased' checkpoint
# (same checkpoint name as the tokenizer above); this is the model passed to
# the Trainer below for fine-tuning.
dbert_qa = AutoModelForQuestionAnswering.from_pretrained('distilbert-base-uncased')

In [None]:
# Training hyperparameters, shared by TrainingArguments and the wandb run config.
BATCH_SIZE = 16             # per-device batch size (used for both train and eval)
LR = 2e-5                   # learning rate
EPOCHS = 3                  # number of training epochs
WEIGHT_DECAY = 0.01         # weight decay
CHKPT_DIR = 'checkpoints'   # output directory handed to TrainingArguments

In [None]:
# Sequences were already padded to a fixed max_length during preprocessing,
# so the default collator is used to assemble batches.
data_collator = DefaultDataCollator()

train_args = TrainingArguments(
    # where the Trainer writes its output
    output_dir=CHKPT_DIR,

    # metrics are reported to wandb
    report_to='wandb',

    # schedule: fixed number of epochs, evaluating once per epoch
    num_train_epochs=EPOCHS,
    evaluation_strategy='epoch',

    # optimisation hyperparameters
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,

    # same batch size for training and evaluation
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
)

trainer = Trainer(
    args=train_args,
    model=dbert_qa,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_squadv2['train'],
    eval_dataset=tokenized_squadv2['validation'],
)

### Run Training

In [None]:
import wandb

# use to log in to wandb if needed
# API_KEY = # wandb api key
# wandb.login(key=API_KEY)

# Start the wandb run before training so the Trainer's wandb reporting logs
# into it. Hyperparameters are passed through `config=` so they are recorded
# on the run; rebinding `wandb.config` to a plain dict after init() only
# replaces the module attribute and does not attach the values to the run.
wandb.init(
    project='SQuAD2.0 with Fine-Tuned DistilBERT',
    notes='Solving Standford\'s SQuAD 2.0 Q&A dataset with DistilBERT transfer learning.',
    config={
        'epochs': EPOCHS,
        'learning_rate': LR,
        'batch_size': BATCH_SIZE,
        'weight_decay': WEIGHT_DECAY,
    },
)

trainer.train()

# close the run once training has completed
wandb.finish()