In [None]:
!pip install transformers

In [None]:
import os
import torch
import collections
import json
import string
import re
import numpy as np
import torch.nn as nn
from tqdm.auto import tqdm
from transformers import BertTokenizerFast,BertModel,BertPreTrainedModel,AdamW
from torch.utils.data import DataLoader


In [None]:
def read_data(train=True, dataset=None):
    """Flatten a SQuAD-style dict into parallel context/question/answer lists.

    The first 90% of the articles form the training split and the remaining
    articles form the validation split.

    Args:
        train: If True return the training split, otherwise the validation split.
        dataset: Parsed SQuAD JSON (a dict holding a ``'data'`` list of
            articles). When omitted, the notebook-level ``data`` global is used.

    Returns:
        A ``(contexts, questions, answers)`` tuple of equally long lists with
        one entry per question. Each answer is a dict with ``'text'`` and
        ``'answer_start'``; unanswerable questions are represented as
        ``{'answer_start': 0, 'text': ''}``. Only the first listed answer of an
        answerable question is kept.
    """
    if dataset is None:
        # Backward-compatible default: fall back to the notebook-level global.
        dataset = data
    articles = dataset['data']
    split_at = 9 * len(articles) // 10
    selected = articles[:split_at] if train else articles[split_at:]

    contexts = []
    questions = []
    answers = []
    for article in selected:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                contexts.append(context)
                questions.append(qa['question'])
                # `.get` tolerates files without the 'is_impossible' flag; an
                # empty answer list is treated as unanswerable as well.
                if qa.get('is_impossible', False) or not qa['answers']:
                    answers.append({'answer_start': 0, 'text': ''})
                else:
                    # Copy so later in-place edits (e.g. add_end_idx) do not
                    # modify the loaded JSON, keeping re-runs idempotent.
                    answers.append(dict(qa['answers'][0]))
    return contexts, questions, answers

def add_end_idx(answers, contexts):
    """Add an exclusive ``'answer_end'`` character offset to every answer, in place.

    The recorded ``'answer_start'`` is sometimes off by one or two characters,
    so the gold text is looked up at the recorded start and at one and two
    characters earlier; when a shifted position matches, ``'answer_start'`` is
    corrected too.

    Args:
        answers: List of dicts with ``'text'`` and ``'answer_start'``; mutated
            in place.
        contexts: List of context strings, aligned with ``answers``.

    Returns:
        None. Every answer is guaranteed to carry ``'answer_end'`` afterwards.
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        span_len = len(gold_text)

        matched_start = None
        for shift in (0, 1, 2):
            candidate = start_idx - shift
            # Skip negative starts: a negative slice bound would wrap around
            # to the end of the context instead of failing to match.
            if candidate >= 0 and context[candidate:candidate + span_len] == gold_text:
                matched_start = candidate
                break

        if matched_start is None:
            # No alignment found: keep the recorded start so that downstream
            # code reading 'answer_end' does not hit a KeyError.
            matched_start = start_idx

        answer['answer_start'] = matched_start
        answer['answer_end'] = matched_start + span_len
            
def preprocess(contexts,questions,answers,train=True):
    """Tokenize contexts and questions separately with the global ``tokenizer``.

    Args:
        contexts: List of context strings.
        questions: List of question strings.
        answers: List of answer dicts; only read when ``train`` is True, in
            which case each needs ``'text'``, ``'answer_start'`` and
            ``'answer_end'``.
        train: When True, also compute answer token positions.

    Returns:
        ``(context_list, question_list, start_position, end_position)`` when
        ``train`` is True, otherwise ``(context_list, question_list)``. Each
        context token list is prefixed with ``'[UNK]'``.
    """
    def _tokens_without_specials(encoded):
        # Token strings for one encoding, with special tokens dropped.
        return tokenizer.convert_ids_to_tokens(encoded['input_ids'],skip_special_tokens=True)

    context_list, start_position, end_position = [], [], []
    for idx in tqdm(range(len(contexts))):
        encoded = tokenizer(contexts[idx])
        # NOTE(review): the '[UNK]' prefix presumably stands in for the special
        # token at index 0 so list indices match char_to_token() - confirm.
        context_list.append(['[UNK]'] + _tokens_without_specials(encoded))
        if not train:
            continue
        gold = answers[idx]
        if gold['text'] == '':
            # Empty gold text (unanswerable) is labelled with position 0.
            first_tok = last_tok = 0
        else:
            first_tok = encoded.char_to_token(gold['answer_start'])
            # 'answer_end' is exclusive, so the last answer character is end - 1.
            last_tok = encoded.char_to_token(gold['answer_end']-1)
        start_position.append(first_tok)
        end_position.append(last_tok)

    question_list = []
    for idx in tqdm(range(len(questions))):
        question_list.append(_tokens_without_specials(tokenizer(questions[idx])))

    if train:
        return context_list,question_list,start_position,end_position
    return context_list,question_list
    
def remove_articles(text):
    """Replace each standalone lowercase article (a, an, the) with one space."""
    article_pattern = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return article_pattern.sub(' ', text)

def white_space_fix(text):
    """Collapse every whitespace run to one space and trim both ends."""
    words = text.split()
    return ' '.join(words)

def remove_punc(text):
    """Drop every character listed in ``string.punctuation`` from ``text``."""
    punctuation = frozenset(string.punctuation)
    return ''.join(filter(lambda ch: ch not in punctuation, text))

def lower(text):
    """Return ``text`` converted to lowercase."""
    lowered = text.lower()
    return lowered

def normalize_answer(s):
    """Normalize an answer string for metric comparison.

    Applies, in order: lowercasing, punctuation removal, article removal and
    whitespace collapsing.
    """
    text = lower(s)
    text = remove_punc(text)
    text = remove_articles(text)
    return white_space_fix(text)

def get_tokens(s):
    """Return the whitespace-separated tokens of the normalized answer.

    A falsy input (e.g. the empty string) yields an empty list.
    """
    return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
    """Return 1 if gold and predicted answers are equal after normalization, else 0."""
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold answer and a predicted answer.

    Both strings are normalized and tokenized with ``get_tokens``. If either
    side has no tokens, the score is 1 when both are empty and 0 otherwise.
    """
    gold_tokens = get_tokens(a_gold)
    pred_tokens = get_tokens(a_pred)

    # Empty-answer case: only "both empty" counts as a match.
    if not gold_tokens or not pred_tokens:
        return int(gold_tokens == pred_tokens)

    # Multiset intersection counts each shared token at most as often as it
    # occurs on both sides.
    shared = collections.Counter(gold_tokens) & collections.Counter(pred_tokens)
    num_shared = sum(shared.values())
    if num_shared == 0:
        return 0

    precision = 1.0 * num_shared / len(pred_tokens)
    recall = 1.0 * num_shared / len(gold_tokens)
    return (2.0 * precision * recall) / (precision + recall)

def add_token_positions(encodings, answers):
    """Attach ``'start_positions'`` / ``'end_positions'`` token labels to ``encodings``.

    Args:
        encodings: Batch encoding exposing ``char_to_token(batch_index, char_index)``
            and ``update``; modified in place.
        answers: List of dicts with ``'answer_start'`` and an exclusive
            ``'answer_end'`` character offset, aligned with the batch.

    An answer whose start and end offsets are both 0 is labelled ``(0, 0)``.
    When ``char_to_token`` returns None, the label falls back to the global
    ``max_length - 1``.
    """
    start_positions, end_positions = [], []
    for idx, answer in enumerate(answers):
        char_start = answer['answer_start']
        char_end = answer['answer_end']
        if char_start == 0 and char_end == 0:
            start_tok = end_tok = 0
        else:
            start_tok = encodings.char_to_token(idx, char_start)
            # The end offset is exclusive; the last answer character is end - 1.
            end_tok = encodings.char_to_token(idx, char_end - 1)
            # None means the answer passage has been truncated.
            if start_tok is None:
                start_tok = max_length-1
            if end_tok is None:
                end_tok = max_length-1
        start_positions.append(start_tok)
        end_positions.append(end_tok)
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    
def evaluate(model, dataset):
    """Compute exact-match and token-level F1 of ``model`` over ``dataset``.

    Args:
        model: Module exposing ``get_scores(input_ids, attention_mask=...,
            token_type_ids=...)`` that returns per-token start and end scores.
            It is switched to eval mode and moved to the selected device;
            callers that keep training must call ``model.train()`` afterwards.
        dataset: Indexable dataset whose items hold ``'input_ids'``,
            ``'attention_mask'``, ``'token_type_ids'`` tensors and a
            ``'gold_text'`` string.

    Returns:
        ``(em, f1)`` averaged over the dataset; ``(0.0, 0.0)`` when it is empty.

    A prediction is the empty string when either argmax lands on the first or
    last position, when the end precedes the start, or when the span covers
    more than 20 token steps. Decoding relies on the global ``tokenizer``.
    """
    with torch.no_grad():
        model.eval()
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        model.to(device)
        em = 0.0
        f1 = 0.0
        for i in tqdm(range(len(dataset))):
            # Index the dataset once: every access rebuilds the item's tensors.
            example = dataset[i]
            input_ids = example['input_ids'].to(device).unsqueeze(0)
            attention_mask = example['attention_mask'].to(device).unsqueeze(0)
            token_type_ids = example['token_type_ids'].to(device).unsqueeze(0)
            start_score,end_score = model.get_scores(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            start_score = start_score.squeeze(0).cpu()
            end_score = end_score.squeeze(0).cpu()
            answer_start = torch.argmax(start_score).item()
            answer_end = torch.argmax(end_score).item()
            last_index = start_score.size(0) - 1
            no_answer = (
                answer_start in (0, last_index)
                or answer_end in (0, last_index)
                or answer_end < answer_start
                or answer_end - answer_start > 20
            )
            if no_answer:
                pred = ''
            else:
                # Convert the predicted span to plain Python ints before
                # decoding (the former bare `input_ids.cpu()` discarded its
                # result and therefore did nothing).
                span_ids = input_ids[0, answer_start:answer_end + 1].tolist()
                pred = tokenizer.decode(span_ids)
            gold_text = example['gold_text']
            em += compute_exact(gold_text,pred)
            f1 += compute_f1(gold_text,pred)
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        num_examples = max(len(dataset), 1)
        em /= num_examples
        f1 /= num_examples
        print('EM: %.5f'%em)
        print('F1: %.5f'%f1)
        return em, f1
    
def online_predict(question, context, model=None):
    """Predict an answer span for a single question/context pair.

    Args:
        question: Question string.
        context: Passage string that may contain the answer.
        model: Optional module exposing ``get_scores``; defaults to the
            notebook-level ``model_bert`` global when omitted.

    Returns:
        ``(pred, True)`` where ``pred`` is the decoded answer text, or the
        empty string when no valid span is selected. The second element is
        always True.

    The pair is encoded context-first. A prediction is empty when either
    argmax lands on the first or last position, when the end precedes the
    start, or when the span covers more than 20 token steps. Encoding and
    decoding rely on the global ``tokenizer``.
    """
    if model is None:
        # Backward-compatible default: use the notebook-level global model.
        model = model_bert
    # Place inputs on the model's own device so both always agree.
    device = next(model.parameters()).device
    with torch.no_grad():
        temp_encoding = tokenizer(context,question,truncation=True,padding=True)
        input_ids = torch.LongTensor(temp_encoding['input_ids']).unsqueeze(0).to(device)
        token_type_ids = torch.LongTensor(temp_encoding['token_type_ids']).unsqueeze(0).to(device)
        attention_mask = torch.LongTensor(temp_encoding['attention_mask']).unsqueeze(0).to(device)
        start_score, end_score = model.get_scores(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        start_score = start_score.squeeze(0).cpu()
        end_score = end_score.squeeze(0).cpu()
        answer_start = torch.argmax(start_score).item()
        answer_end = torch.argmax(end_score).item()
        last_index = start_score.size(0) - 1
        no_answer = (
            answer_start in (0, last_index)
            or answer_end in (0, last_index)
            or answer_end < answer_start
            or answer_end - answer_start > 20
        )
        if no_answer:
            pred = ''
        else:
            # Convert the predicted span to plain Python ints before decoding
            # (the former bare `input_ids.cpu()` discarded its result).
            span_ids = input_ids[0, answer_start:answer_end + 1].tolist()
            pred = tokenizer.decode(span_ids)
        return pred, True

In [None]:
train_path = '../input/squad-20/train-v2.0.json'
# Use a context manager so the handle is closed even if parsing fails, and pin
# the encoding so the result does not depend on the platform default.
with open(train_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# Shared tokenizer loaded from the 'bert-base-uncased' checkpoint. preprocess(),
# evaluate() and online_predict() read this module-level global directly.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [None]:
# Training split: read_data() defaults to train=True.
train_contexts, train_questions, train_answers = read_data()
# Adds 'answer_end' to the answer dicts in place (and may move 'answer_start' back by 1-2 chars).
add_end_idx(train_answers, train_contexts)
# Tokenize contexts and questions separately and map answer char offsets to token positions.
train_context_list,train_question_list,train_start_position,train_end_position = preprocess(train_contexts,train_questions,train_answers)

In [None]:
# Validation split: the articles not used for training (train=False).
val_contexts, val_questions, val_answers = read_data(train=False)
# add_end_idx() is not called for validation answers; preprocess(train=False)
# only tokenizes and never reads the answer offsets.
val_context_list,val_question_list = preprocess(val_contexts,val_questions,val_answers,train=False)

In [None]:
# Report how many training questions have an empty gold answer text.
total_count = len(train_answers)
na_count = sum(1 for answer in train_answers if answer['text']=='')
print('total questions: %d'%(total_count))
print('un-answered questions: %d'%na_count)
print('un-answered percent: %.2f'%(100*(na_count/total_count)))

In [None]:
# Fixed sequence length for the encoded pairs; also read as a global by
# add_token_positions() when an answer cannot be mapped to a token.
max_length = 384
# Encode (context, question) pairs - context first - truncated/padded to max_length.
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding='max_length', max_length=max_length)
# Adds 'start_positions' / 'end_positions' to train_encodings in place.
add_token_positions(train_encodings, train_answers)

In [None]:
class SquadDataset(torch.utils.data.Dataset):
    """Dataset that serves one encoded example per index as a dict of tensors."""

    def __init__(self, encodings):
        # Mapping from field name to a per-example sequence; it must also
        # expose an `input_ids` attribute, which __len__ reads.
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        example = {}
        for field, column in self.encodings.items():
            example[field] = torch.tensor(column[idx])
        return example

In [None]:
class SquadDataset_Validation(torch.utils.data.Dataset):
    """Validation dataset: encoded tensors plus the gold answer text per example."""

    def __init__(self, encodings, answers):
        # `encodings` maps field names to per-example sequences and exposes an
        # `input_ids` attribute; `answers` is an aligned list of dicts with 'text'.
        self.encodings = encodings
        self.answers = answers

    def __len__(self):
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        example = {}
        for field, column in self.encodings.items():
            example[field] = torch.tensor(column[idx])
        # The gold text stays a plain string; it is only used for scoring.
        example['gold_text'] = self.answers[idx]['text']
        return example

In [None]:
# Encode validation pairs exactly like the training pairs (context first, same max_length).
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding='max_length', max_length=max_length)
# Pair the encodings with the gold answers so evaluate() can read 'gold_text'.
val_dataset = SquadDataset_Validation(val_encodings, val_answers)

In [None]:
class Bert_QA(BertPreTrainedModel):
    """BERT encoder with a linear head that scores answer start and end tokens.

    Positions where ``attention_mask - token_type_ids`` is 0 receive a large
    negative bias before the loss / softmax, so they are effectively excluded
    from span selection.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        # Two outputs per token: start logit and end logit.
        self.fc = nn.Linear(config.hidden_size, 2)
        self.criterion = nn.CrossEntropyLoss()

    def _masked_span_logits(self, input_ids, attention_mask, token_type_ids):
        """Return masked ``(start_logits, end_logits)`` with the trailing size-1 dim removed."""
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            # Explicit so `.last_hidden_state` does not depend on the config default.
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.fc(sequence_output)
        # 1 where a token is attended and has token type 0, else 0.
        # NOTE(review): assumes attention_mask >= token_type_ids element-wise - confirm.
        context_mask = (attention_mask-token_type_ids).unsqueeze(-1)
        # log(1) = 0 keeps selected positions unchanged; the log of the tiny
        # constant adds a large negative bias everywhere else.
        logits = logits + (context_mask + 1e-45).log()
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        start_positions,
        end_positions,
    ):
        """Return the training loss: start cross-entropy plus end cross-entropy.

        ``start_positions`` / ``end_positions`` are the target token indices.
        """
        start_logits, end_logits = self._masked_span_logits(
            input_ids, attention_mask, token_type_ids
        )
        start_loss = self.criterion(start_logits, start_positions)
        end_loss = self.criterion(end_logits, end_positions)
        loss = start_loss + end_loss
        return loss

    def get_scores(self,
        input_ids,
        attention_mask,
        token_type_ids
    ):
        """Return ``(start_score, end_score)``: softmax over dim 1 of the masked logits."""
        start_logits, end_logits = self._masked_span_logits(
            input_ids, attention_mask, token_type_ids
        )
        start_score = start_logits.softmax(dim=1)
        end_score = end_logits.softmax(dim=1)
        return start_score,end_score

In [None]:
# Fine-tuning setup and loop: 3 epochs, batch size 24, lr 3e-5.
torch.cuda.empty_cache()
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = Bert_QA.from_pretrained('bert-base-uncased')
model.to(device)
model.train()
train_dataset = SquadDataset(train_encodings)
# drop_last=True discards the final incomplete batch of every epoch.
train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, drop_last=True)
optim = AdamW(model.parameters(), lr=3e-5)
# iter_counter counts optimizer steps across all epochs (it is never reset).
iter_counter = 0
# Best validation F1 seen so far; a checkpoint is written only when it improves.
best_f1 = 0
# Running loss sum, averaged and reset every 100 iterations.
avg_loss = 0

for epoch in range(3):
    print('epoch %d start!'%(epoch+1))
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        # Bert_QA.forward returns the summed start/end loss directly.
        loss = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions, token_type_ids=token_type_ids)
        loss.backward()
        optim.step()
        iter_counter += 1
        avg_loss += loss.item()
        # Log the mean loss of the last 100 iterations.
        if iter_counter%100 == 0:
            avg_loss /= 100
            print('iter %d'%iter_counter)
            print('loss %.5f'%avg_loss)
            avg_loss = 0
            print()
        # Validate every 2000 iterations.
        # NOTE(review): steps after the last multiple of 2000 are never
        # evaluated or saved - confirm this is intended.
        if iter_counter%2000 == 0:
            em,f1 = evaluate(model,val_dataset)
            # evaluate() switches the model to eval mode; restore training mode.
            model.train()
            if f1>best_f1:
                best_f1 = f1
                model.save_pretrained('../custom/bert_qa_model_uncased')
                print('best model!')
            print()
    print('epoch %d finish!'%(epoch+1))

In [None]:
# Reload the checkpoint written by the training loop and score it on the validation set.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# online_predict() reads this `model_bert` global.
model_bert = Bert_QA.from_pretrained('../custom/bert_qa_model_uncased')
model_bert.to(device)
model_bert.eval()
evaluate(model_bert,val_dataset)

In [None]:
# Score answerable and unanswerable validation questions separately.
hasans_em, hasans_f1 = [], []
noans_em, noans_f1 = [], []
for idx in tqdm(range(len(val_questions))):
    prediction = online_predict(val_questions[idx],val_contexts[idx])[0]
    reference = val_answers[idx]['text']
    # An empty gold text marks an unanswerable question.
    if reference == '':
        em_scores, f1_scores = noans_em, noans_f1
    else:
        em_scores, f1_scores = hasans_em, hasans_f1
    em_scores.append(compute_exact(reference,prediction))
    f1_scores.append(compute_f1(reference,prediction))

print('has-ans em: %.2f'%(100*np.mean(hasans_em)))
print('has-ans f1: %.2f'%(100*np.mean(hasans_f1)))
print('no-ans em: %.2f'%(100*np.mean(noans_em)))
print('no-ans f1: %.2f'%(100*np.mean(noans_f1)))

In [None]:
# Print 20 randomly chosen validation examples with the gold and predicted answers.
for i in range(20):
    # Sample over the actual validation size: the former hard-coded bound of
    # 10000 could index past the end of a smaller set and could never reach
    # examples beyond it in a larger one.
    index = np.random.randint(len(val_questions))
    c = val_contexts[index]
    q = val_questions[index]
    p_bert,_ = online_predict(q,c)
    print('context:')
    print(c)
    print('question:')
    print(q)
    print('gold answer: %s'%val_answers[index]['text'])
    print('bert answer: %s'%p_bert)
    print()