## Task description
- Chinese Extractive Question Answering
  - Input: Paragraph + Question
  - Output: Answer

- Objective: Fine tune a pretrained model on downstream task using transformers

- Todo
    - Fine tune a pretrained chinese BERT model
    - Change hyperparameters (e.g. doc_stride)
    - Apply linear learning rate decay
    - Try other pretrained models
    - Improve preprocessing
    - Improve postprocessing
- Training tips
    - Automatic mixed precision
    - Gradient accumulation
    - Ensemble
  

In [1]:
import json
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, Dataset 
from transformers import AdamW, BertForQuestionAnswering, BertTokenizerFast

from tqdm.auto import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix random seed for reproducibility
def same_seeds(seed):
    
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
same_seeds(0)

  from .autonotebook import tqdm as notebook_tqdm
NOTE: Redirects are currently not supported in Windows or MacOs.


In [9]:
fp16_training = False

if fp16_training:
    from accelerate import Accelerator
    accelerate  = Accelerator(fp16 = True)
    device = accelerate.device

## Load Model and Tokenizer


In [2]:
model = BertForQuestionAnswering.from_pretrained("bert-base-chinese").to(device)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForQuestionAnswering: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-chinese a

## Read Data

- Training set: 31690 QA pairs
- Dev set: 4131  QA pairs
- Test set: 4957  QA pairs

- {train/dev/test}_questions:	
  - List of dicts with the following keys:
   - id (int)
   - paragraph_id (int)
   - question_text (string)
   - answer_text (string)
   - answer_start (int)
   - answer_end (int)
- {train/dev/test}_paragraphs: 
  - List of strings
  - paragraph_ids in questions correspond to indexs in paragraphs
  - A paragraph may be used by several questions 

In [13]:
import os
def read_data(file):
    with open(file, 'r', encoding='UTF-8') as reader:
        data = json.load(reader)
    return data["questions"], data["paragraphs"]

prefix = "D://Datasets/"

train_questions, train_paragraphs = read_data(os.path.join(prefix, "hw7_train.json"))
dev_questions, dev_paragraphs = read_data(os.path.join(prefix, "hw7_dev.json"))
test_questions, test_paragraphs = read_data(os.path.join(prefix, "hw7_test.json"))

## Tokenize Data

In [5]:
# Tokenize questions and paragraphs separately
# 「add_special_tokens」 is set to False since special tokens will be added when tokenized questions and paragraphs are combined in datset __getitem__ 
train_questions_tokenized = tokenizer([train_question["question_text"] for train_question in train_questions], add_special_tokens=False)
dev_questions_tokenized = tokenizer([dev_question["question_text"] for dev_question in dev_questions], add_special_tokens=False)
test_questions_tokenized = tokenizer([test_question["question_text"] for test_question in test_questions], add_special_tokens=False) 

train_paragraphs_tokenized = tokenizer(train_paragraphs, add_special_tokens=False)
dev_paragraphs_tokenized = tokenizer(dev_paragraphs, add_special_tokens=False)
test_paragraphs_tokenized = tokenizer(test_paragraphs, add_special_tokens=False)



Token indices sequence length is longer than the specified maximum sequence length for this model (570 > 512). Running this sequence through the model will result in indexing errors


## Dataset and Dataloader

In [6]:
class QA_Dataset(Dataset):
    def __init__(self, split, questions, tokenized_questions, tokenized_paragraphs):
        self.split = split
        self.questions = questions
        self.tokenized_questions = tokenized_questions
        self.tokenized_paragraphs = tokenized_paragraphs
        self.max_question_len = 40
        self.max_paragraph_len = 150
        
        self.doc_stride = 150
        
        #Input sequence length = [CLS] + question + [SEP] + paragraph + [SEP]
        self.max_seq_len = 1 + self.max_question_len + 1 + self.max_paragraph_len + 1
        
    def __len__(self):
        return len(self.questions)
    
    def __getitem__(self, idx):
        question = self.questions[idx]
        tokenized_question = self.tokenized_questions[idx]
        tokenized_paragraph = self.tokenized_paragraphs[question['paragraph_id']]
        
        #### TODO: Preprocessing ####
        if self.split == "train":
            answer_start_token = tokenized_paragraph.char_to_token(question["answer_start"])
            answer_end_token = tokenized_paragraph.char_to_token(question["answer_end"])
            
            mid = (answer_start_token + answer_end_token) // 2
            paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))
            paragraph_end = paragraph_start + self.max_paragraph_len
            
            ## Slice question/paragraph and add special tokens (101: CLS, 102: SEP)
            input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]
            input_ids_paragraph = tokenized_paragraph.ids[paragraph_start:paragraph_end] + [102]
            
            #Convert answer's start/end positions in tokenized_paragraph to start/end positions in the window  
            answer_start_token += len(input_ids_question) - paragraph_start
            answer_end_token += len(input_ids_question) - paragraph_start
            
            # Pad sequence and obtain inputs to model 
            inputs_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)
            
            return torch.tensor(inputs_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), answer_start_token, answer_end_token
        else:
            input_ids_list, token_type_ids_list, attention_mask_list = [], [], []
            # Paragraph is split into several windows, each with start positions separated by step "doc_stride"
            for i in range(0, len(tokenized_paragraph), self.doc_stride):
                # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)
                input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]
                input_ids_paragraph = tokenized_paragraph.ids[i : (i + self.max_paragraph_len)] + [102]
                # Pad sequence and obtain inputs to model
                input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)
                
                input_ids_list.append(input_ids)
                token_type_ids_list.append(token_type_ids)
                attention_mask_list.append(attention_mask)
                
            return torch.tensor(input_ids_list), torch.tensor(token_type_ids_list), torch.tensor(attention_mask_list)
                    
    
    def padding(self, input_ids_question, input_ids_paragraph):
        padding_len = self.max_seq_len - len(input_ids_question) - len(input_ids_paragraph)
        inputs_ids = input_ids_question + input_ids_paragraph + [0]*padding_len
        # Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]
        token_type_ids = [0] * len(input_ids_question) + [1] * len(input_ids_paragraph) + [0] * padding_len
        # Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]
        attention_mask = [1] * (len(input_ids_question) + len(input_ids_paragraph)) + [0] * padding_len
        
        return inputs_ids, token_type_ids, attention_mask

    
train_set = QA_Dataset("train", train_questions, train_questions_tokenized, train_paragraphs_tokenized)
dev_set = QA_Dataset("dev", dev_questions, dev_questions_tokenized, dev_paragraphs_tokenized)
test_set = QA_Dataset("test", test_questions, test_questions_tokenized, test_paragraphs_tokenized)

train_batch_size = 32

train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, pin_memory=False)

# Note: Do NOT change batch size of dev_loader / test_loader !
# Although batch size=1, it is actually a batch consisting of several windows from the same QA pair
dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, pin_memory=False)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, pin_memory=False)

## Function for Evaluation

In [7]:
def evaluate(data, output):
    
    answer = ''
    max_prob = float('-inf')
    num_of_windows = data[0].shape[1]
    
    for k in range(num_of_windows):
        start_prob, start_index = torch.max(output.start_logits[k], dim=0)
        end_prob, end_index = torch.max(output.end_logits[k], dim=0)
        
        prob = start_prob + end_prob
        
        if prob > max_prob:
            max_prob = prob
            answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])
    
    return answer.replace(" ", '')


## Training

In [15]:
num_epoch = 1
validation = True
logging_step = 100
learning_rate = 1e-4

optimizer = AdamW(model.parameters(), lr=learning_rate)
## linear learning rate decay
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=20)

if fp16_training:
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

print('Start training\n')

for epoch in range(num_epoch):
    model.train()
    step = 1
    train_loss = train_acc = 0
    bar = tqdm(train_loader, desc = f'Epoch:{epoch} ')
    for data in bar:
        data = [x.to(device) for x in data]
        
        output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2], start_positions=data[3], end_positions=data[4])
        
        start_index = torch.argmax(output.start_logits, dim=1)
        end_index = torch.argmax(output.end_logits, dim=1)
        
        train_acc += ((start_index == data[3]) & (end_index == data[4])).float().mean()
        train_loss += output.loss
        
        bar.set_postfix(loss = output.loss.item())
        
        optimizer.zero_grad()
        if fp16_training:
            accelerator.backward(output.loss)
        else:
            output.loss.backward()
        
        optimizer.step()
        lr_scheduler.step()
        step += 1
        
        if step % logging_step == 0:
            print(f'Epoch {epoch} | Step {step} | loss = {train_loss.item() / logging_step:.3f}, acc={train_acc.item()}')
    
    if validation:
        print('Start validation ...')
        model.eval()
        with torch.no_grad():
            dev_acc = 0
            for i, data in enumerate(tqdm(dev_loader)):
                data = [x.to(device) for x in data]
                output = model(input_ids=data[0][0], token_type_ids=data[1][0], attention_mask=data[2][0])
                
                dev_acc += evaluate(data, output) == dev_questions[i]['answer_text']
            
            print(f'Validation | Epoch {epoch + 1} | acc = {dev_acc / len(dev_loader)}')
            
    
        


Start training



Epoch:0 :  10%|█████▉                                                     | 99/991 [01:07<09:01,  1.65it/s, loss=0.607]

Epoch 0 | Step 100 | loss = 0.409, acc=80.625


Epoch:0 :  20%|███████████▋                                              | 199/991 [02:09<08:04,  1.64it/s, loss=0.743]

Epoch 0 | Step 200 | loss = 0.823, acc=161.09375


Epoch:0 :  30%|█████████████████▍                                        | 299/991 [03:27<07:27,  1.55it/s, loss=0.355]

Epoch 0 | Step 300 | loss = 1.259, acc=240.9375


Epoch:0 :  40%|███████████████████████▎                                  | 399/991 [04:32<06:12,  1.59it/s, loss=0.587]

Epoch 0 | Step 400 | loss = 1.689, acc=321.03125


Epoch:0 :  50%|█████████████████████████████▏                            | 499/991 [05:39<06:20,  1.29it/s, loss=0.407]

Epoch 0 | Step 500 | loss = 2.079, acc=402.4375


Epoch:0 :  60%|███████████████████████████████████                       | 599/991 [06:42<04:03,  1.61it/s, loss=0.525]

Epoch 0 | Step 600 | loss = 2.484, acc=483.90625


Epoch:0 :  71%|████████████████████████████████████████▉                 | 699/991 [07:43<02:59,  1.63it/s, loss=0.816]

Epoch 0 | Step 700 | loss = 2.920, acc=564.28125


Epoch:0 :  81%|██████████████████████████████████████████████▊           | 799/991 [08:55<02:55,  1.09it/s, loss=0.653]

Epoch 0 | Step 800 | loss = 3.313, acc=645.59375


Epoch:0 :  91%|████████████████████████████████████████████████████▌     | 899/991 [10:23<01:13,  1.25it/s, loss=0.456]

Epoch 0 | Step 900 | loss = 3.688, acc=728.21875


Epoch:0 : 100%|██████████████████████████████████████████████████████████| 991/991 [11:44<00:00,  1.41it/s, loss=0.437]


Start validation ...


100%|██████████████████████████████████████████████████████████████████████████████| 4131/4131 [02:22<00:00, 28.89it/s]

Validation | Epoch 1 | acc = 0.4279835390946502





## Testing 

In [22]:
result = []
model.eval()
with torch.no_grad():
    for data in tqdm(test_loader):
        data = [x.to(device) for x in data]
        output = model(input_ids=data[0][0], token_type_ids=data[1][0], attention_mask=data[2][0])
        result.append(evaluate(data, output))
    

100%|██████████████████████████████████████████████████████████████████████████████| 4957/4957 [03:05<00:00, 26.77it/s]
