# How to Train BERT for Q&A in Any Language

In [1]:
pip install -qqq tqdm datasets transformers torch

In [2]:
from tqdm.auto import tqdm  # for showing progress bar
from datasets import load_dataset
from transformers import TrainingArguments, Trainer, DefaultDataCollator
from transformers import BertTokenizerFast, BertTokenizer, BertForQuestionAnswering
from transformers import AdamW
import torch
import pandas as pd

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


Moving 0 files to the new cache system


0it [00:00, ?it/s]

In [3]:
# Load the SQuAD question-answering dataset through the `datasets` library.
squad = load_dataset(path="squad")



  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# The "fast" tokenizer class is used because preprocess_function asks for
# character offset mappings and per-token sequence ids.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [5]:
def preprocess_function(examples, max_length=384):
    """Tokenize a batch of question/context pairs and label the answer span.

    Each question is paired with its context and encoded with the module-level
    ``tokenizer``. Only the context may be truncated; every example is padded
    to ``max_length``. The character-level answer span is then converted into
    start/end *token* indices using the tokenizer's offset mapping.

    Args:
        examples: A batch in SQuAD format: a mapping with parallel lists under
            ``"question"``, ``"context"`` and ``"answers"``. Each answers entry
            is a mapping with list-valued ``"answer_start"`` (character offset
            into the context) and ``"text"``; only the first answer is used.
        max_length: Maximum encoded sequence length (question + context +
            special tokens). Defaults to 384.

    Returns:
        The tokenizer's encoding with ``"offset_mapping"`` removed and two
        added lists, ``"start_positions"`` and ``"end_positions"``, holding one
        token index per example. Examples whose answer is missing, or is not
        fully contained in the (possibly truncated) context, are labelled
        ``(0, 0)``.
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Offsets are only needed for labelling; they must not reach the model.
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]

        # Unanswerable examples (empty answer lists, e.g. SQuAD v2-style data)
        # have no span to locate; label them (0, 0) instead of raising.
        if len(answer["answer_start"]) == 0 or len(answer["text"]) == 0:
            start_positions.append(0)
            end_positions.append(0)
            continue

        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the first and last token that belong to the context
        # (sequence id 1; question tokens are 0, special/padding tokens None).
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0).
        # Both ends must be checked: an answer that is cut in half by
        # truncation would otherwise get an end position one past the last
        # context token (i.e. on the separator token).
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions.
            # Walk forward to the last token starting at or before start_char.
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            # Walk backward to the first token ending at or after end_char.
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [6]:
# Display the loaded dataset: its splits, column names, and row counts.
squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [7]:
# Work on the first 1,000 training examples only (presumably to keep the run short).
squad_train = squad["train"].select(range(1_000))

In [8]:
# Work on the first 1,000 validation examples only (presumably to keep the run short).
squad_val = squad["validation"].select(range(1_000))

In [9]:
# Tokenize and label the training subset in batches; the raw text columns are
# dropped so only the encoded fields and span labels remain.
tokenized_squad_train = squad_train.map(
    preprocess_function,
    batched=True,
    remove_columns=squad_train.column_names,
)



In [10]:
# Tokenize and label the validation subset in batches; the raw text columns are
# dropped so only the encoded fields and span labels remain.
tokenized_squad_val = squad_val.map(
    preprocess_function,
    batched=True,
    remove_columns=squad_val.column_names,
)

  0%|          | 0/1 [00:00<?, ?ba/s]

In [11]:
# Collator with default settings; preprocess_function already pads every
# example to a fixed length, so no padding is requested here.
data_collator = DefaultDataCollator()

## Configure training and create the model

In [12]:
# Hyperparameters for fine-tuning; checkpoints are written under `output_dir`.
training_args = TrainingArguments(
    output_dir="bert-base-uncased-finetuned-squad",
    # Schedule and batch sizes
    num_train_epochs=10,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=20,
    # Optimizer settings
    learning_rate=2e-5,
    weight_decay=0.01,
    # Evaluate on a step interval rather than per epoch.
    # NOTE(review): the recorded run shows a single evaluation at step 500,
    # which is also its final step; confirm that interval is what is wanted.
    evaluation_strategy="steps",
)

In [13]:
# Load the pretrained checkpoint into a question-answering model class.
# The accompanying warning reports that some checkpoint weights are unused and
# some model weights are newly initialized, which is why fine-tuning is needed.
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [14]:
# Bundle the model, hyperparameters, tokenized splits, tokenizer, and collator
# into a single Trainer object.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad_train,
    eval_dataset=tokenized_squad_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

In [15]:
# Run fine-tuning. The returned TrainOutput (shown as the cell result) carries
# the global step count, the training loss, and runtime metrics.
trainer.train()

***** Running training *****
  Num examples = 1000
  Num Epochs = 10
  Instantaneous batch size per device = 20
  Total train batch size (w. parallel, distributed & accumulation) = 20
  Gradient Accumulation steps = 1
  Total optimization steps = 500


Step,Training Loss,Validation Loss
500,1.569,2.284388


***** Running Evaluation *****
  Num examples = 1000
  Batch size = 20
Saving model checkpoint to bert-base-uncased-finetuned-squad/checkpoint-500
Configuration saved in bert-base-uncased-finetuned-squad/checkpoint-500/config.json
Model weights saved in bert-base-uncased-finetuned-squad/checkpoint-500/pytorch_model.bin
tokenizer config file saved in bert-base-uncased-finetuned-squad/checkpoint-500/tokenizer_config.json
Special tokens file saved in bert-base-uncased-finetuned-squad/checkpoint-500/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=500, training_loss=1.5690496826171876, metrics={'train_runtime': 697.2397, 'train_samples_per_second': 14.342, 'train_steps_per_second': 0.717, 'total_flos': 1959725675520000.0, 'train_loss': 1.5690496826171876, 'epoch': 10.0})