In [1]:
import numpy as np

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

from transformers import BertForQuestionAnswering, AdamW, BertTokenizerFast, pipeline
from datasets import load_dataset
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
squad_train = load_dataset('squad_v2', split='train')
squad_val = load_dataset('squad_v2', split='validation')
mlqa_en = load_dataset("mlqa", "mlqa.en.en", split="test")
mlqa_zh = load_dataset("mlqa", "mlqa.zh.zh", split="test") 

# Preprocessing

In [4]:
squad_train[0]

{'id': '56be85543aeaaa14008c9063',
 'title': 'Beyoncé',
 'context': 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".',
 'question': 'When did Beyonce start becoming popular?',
 'answers': {'text': ['in the late 1990s'], 'answer_start': [269]}}

In [5]:
def find_end(example):

    if (len(example['answers']['text']) != 0):
        context = example['context']
        text = example['answers']['text'][0]
        start_idx = example['answers']['answer_start'][0]

        end_idx = start_idx + len(text)
        
        temp = example['answers'] # to change the value
        temp['answer_end']=end_idx 
        temp['answer_start'] = start_idx # [num]->num
        temp['text'] = text # ['text']->text
    
    else:
        temp = example['answers']
        temp['answer_end'] = 0 # []->0
        temp['answer_start'] = 0 # []->0
        temp['text'] = "" # []->""
        
    return example

squad_train = squad_train.map(find_end)

In [6]:
squad_train[0]

{'id': '56be85543aeaaa14008c9063',
 'title': 'Beyoncé',
 'context': 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".',
 'question': 'When did Beyonce start becoming popular?',
 'answers': {'answer_end': 286,
  'answer_start': 269,
  'text': 'in the late 1990s'}}

In [7]:
squad_train[-10]

{'id': '5a7e05ef70df9f001a875425',
 'title': 'Matter',
 'context': 'These quarks and leptons interact through four fundamental forces: gravity, electromagnetism, weak interactions, and strong interactions. The Standard Model of particle physics is currently the best explanation for all of physics, but despite decades of efforts, gravity cannot yet be accounted for at the quantum level; it is only described by classical physics (see quantum gravity and graviton). Interactions between quarks and leptons are the result of an exchange of force-carrying particles (such as photons) between quarks and leptons. The force-carrying particles are not themselves building blocks. As one consequence, mass and energy (which cannot be created or destroyed) cannot always be related to matter (which can be created out of non-matter particles such as photons, or even out of pure energy, such as kinetic energy). Force carriers are usually not considered matter: the carriers of the electric force (photons)

In [8]:
# Tokenization
tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
tokenized_train = tokenizer(squad_train['context'], squad_train['question'], truncation=True, padding=True)

In [9]:
def find_token_indexes(tokenized, dataset):
    start_token_list = []
    end_token_list = []
    answers = dataset['answers']
    for i in range(len(answers)):
        if (answers[i]['text'] != ''):
            start_token = tokenized.char_to_token(i, answers[i]['answer_start'])
            end_token = tokenized.char_to_token(i, answers[i]['answer_end'] - 1)
            
            # if start token is None, the answer passage has been truncated
            if start_token is None:
                start_token = tokenizer.model_max_length
            if end_token is None:
                end_token = tokenizer.model_max_length
        else:
            start_token = 0
            end_token = 0
            
        start_token_list.append(start_token)
        end_token_list.append(end_token)

    return start_token_list, start_token_list
    
s, e = find_token_indexes(tokenized_train, squad_train)
squad_train = squad_train.add_column("start_position", s)
squad_train = squad_train.add_column("end_position", e)

In [10]:
squad_train

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'start_position', 'end_position'],
    num_rows: 130319
})

In [11]:
batch_size = 8
train_data = TensorDataset(torch.tensor(tokenized_train['input_ids'], dtype=torch.int64), 
                           torch.tensor(tokenized_train['token_type_ids'], dtype=torch.int64), 
                           torch.tensor(tokenized_train['attention_mask'], dtype=torch.float), 
                           torch.tensor(squad_train['start_position'], dtype=torch.int64), 
                           torch.tensor(squad_train['start_position'], dtype=torch.int64))

train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Fine-Tune

In [12]:
model = BertForQuestionAnswering.from_pretrained("bert-base-multilingual-cased")
epochs = 3
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForQuestionAnswering: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-bas

In [13]:
for epoch in range(epochs):
    epoch_loss = []
    validation_loss = []
    
    total_loss = 0
    model.train()

    count=-1
    progress_bar = tqdm(train_dataloader, leave=True, position=0)
    progress_bar.set_description(f"Epoch {epoch+1}")
    for batch in progress_bar:
        count+=1
        input_ids, segment_ids, mask, start, end  = tuple(t.to(device) for t in batch)

        model.zero_grad()
        loss, start_logits, end_logits = model(input_ids = input_ids, 
                                                token_type_ids = segment_ids, 
                                                attention_mask = mask, 
                                                start_positions = start, 
                                                end_positions = end,
                                                return_dict = False)           

        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        if (count % 20 == 0 and count != 0):
            avg = total_loss/count
            progress_bar.set_postfix(Loss=avg)
            
    torch.save(model.state_dict(), "./bert_" + str(epoch) + ".h5") # save for later use
    avg_train_loss = total_loss / len(train_dataloader)
    epoch_loss.append(avg_train_loss)
    print(f"Epoch {epoch} Loss: {avg_train_loss}\n")

Epoch 1: 100%|██████████| 16290/16290 [1:02:05<00:00,  4.37it/s, Loss=1.5]


Epoch 0 Loss: 1.4995357166955154



Epoch 2: 100%|██████████| 16290/16290 [1:02:21<00:00,  4.35it/s, Loss=1.18]


Epoch 1 Loss: 1.1804199021019184



Epoch 3: 100%|██████████| 16290/16290 [1:01:54<00:00,  4.39it/s, Loss=1.03]


Epoch 2 Loss: 1.0265117677540374



# Evaluation

In [14]:
# model.load_state_dict(torch.load("./bert_2.h5"))

In [15]:
from collections import Counter

def f1_score(prediction, truth):
    pred_tokens = prediction.split()
    truth_tokens = truth.split()
    common = Counter(pred_tokens) & Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_tokens)
    recall = 1.0 * num_same / len(truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def exact_match_score(prediction, truth):
    return prediction == truth

def evaluate(validation_dataset):
    # preprocess
    tokenized_validation = tokenizer(validation_dataset['context'], 
                                     validation_dataset['question'], 
                                     truncation=True, 
                                     padding=True, 
                                     return_offsets_mapping=True)

    batch_size = 8
    val_data = TensorDataset(torch.tensor(tokenized_validation['input_ids'], dtype=torch.int64), 
                            torch.tensor(tokenized_validation['token_type_ids'], dtype=torch.int64), 
                            torch.tensor(tokenized_validation['attention_mask'], dtype=torch.float))
    val_sampler = SequentialSampler(val_data)
    val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)
    
    # evaluate
    threshold = 1.0

    model.eval()
    total_f1 = 0
    total_exact_match = 0
    num_evaluated = 0

    for test_batch in tqdm(val_dataloader):
        input_ids, segment_ids, masks = tuple(t.to(device) for t in test_batch)

        with torch.no_grad():
            # prediction logits
            start_logits, end_logits = model(input_ids=input_ids,
                                            token_type_ids=segment_ids,
                                            attention_mask=masks,
                                            return_dict=False)

        # to cpu
        start_logits = start_logits.detach().cpu()
        end_logits = end_logits.detach().cpu()

        # for every sequence in batch 
        for bidx in range(len(start_logits)):
            # apply softmax to logits to get scores
            start_scores = np.array(F.softmax(start_logits[bidx], dim = 0))
            end_scores = np.array(F.softmax(end_logits[bidx], dim = 0))

            # find max for start<=end
            size = len(start_scores)
            scores = np.zeros((size, size))

            for j in range(size):
                for i in range(j+1): # include j
                    scores[i,j] = start_scores[i] + end_scores[j]

            # find best i and j
            start_pred, end_pred = np.unravel_index(scores.argmax(), scores.shape)
            answer_pred = ""
            if scores[start_pred, end_pred] > threshold:
                offsets = tokenized_validation.offset_mapping[num_evaluated]
                pred_char_start = offsets[start_pred][0]

                if end_pred < len(offsets):
                    pred_char_end = offsets[end_pred][1]
                    answer_pred = validation_dataset[num_evaluated]['context'][pred_char_start:pred_char_end]
                else:
                    answer_pred = validation_dataset[num_evaluated]['context'][pred_char_start:]

            ground_truths = validation_dataset[num_evaluated]['answers']['text']
            if ground_truths:
                best_f1 = max(f1_score(answer_pred, truth) for truth in ground_truths)
                total_f1 += best_f1

                exact_match = any(exact_match_score(answer_pred, truth) for truth in ground_truths)
                total_exact_match += int(exact_match)

            num_evaluated += 1

    avg_f1 = total_f1 / num_evaluated
    avg_exact_match = total_exact_match / num_evaluated
    print("Average F1 Score: ", avg_f1)
    print("Exact Match Score: ", avg_exact_match)

In [16]:
evaluate(squad_val)

100%|██████████| 1485/1485 [09:22<00:00,  2.64it/s]


Average F1 Score:  0.11357357518658164
Exact Match Score:  0.06333698307083298


In [17]:
evaluate(mlqa_en)

100%|██████████| 1449/1449 [09:12<00:00,  2.62it/s]


Average F1 Score:  0.2040851551647495
Exact Match Score:  0.12174288179465056


In [18]:
evaluate(mlqa_zh)

100%|██████████| 643/643 [04:02<00:00,  2.65it/s]


Average F1 Score:  0.00546084892979968
Exact Match Score:  0.004477321393809616
