In [None]:
import re
import json
import random
from utils_data import infer_sentence_breaks
from transformers import BertTokenizer

In [None]:
# JSON text is UTF-8 by specification (RFC 8259); pin the encoding so that
# loading does not depend on the platform's default locale encoding. This
# also matches the explicit utf-8 used when the processed data is written.
with open('data/SQuAD/train.json', encoding='utf-8') as data_file:
    train_json = json.load(data_file)
with open('data/SQuAD/dev.json', encoding='utf-8') as data_file:
    dev_json = json.load(data_file)

In [None]:
def generate_qa_pairs_train(data):
    """Build aligned question / answer-sentence lists from SQuAD-style data.

    For every answer of every question entry, the paragraph sentence whose
    span contains the answer's start offset is looked up (spans come from
    ``infer_sentence_breaks``). Each distinct sentence is emitted once per
    question entry, paired with the question text.

    Args:
        data: Parsed SQuAD-style JSON. ``data["data"]`` is a list of passages,
            each with a ``"paragraphs"`` list; every paragraph has a
            ``"context"`` string and a ``"qas"`` list whose items carry a
            ``"question"`` and an ``"answers"`` list (each answer has an
            ``"answer_start"`` offset into the context).

    Returns:
        A ``(questions, answers)`` tuple of equally long lists, where
        ``answers[i]`` is a sentence that answers ``questions[i]``. Answers
        whose start offset falls inside no sentence span are skipped, so the
        result never contains ``None``.
    """
    questions = []
    answers = []
    for passage in data["data"]:
        for paragraph in passage["paragraphs"]:
            paragraph_text = paragraph["context"]
            sentence_breaks = list(infer_sentence_breaks(paragraph_text))
            for qas in paragraph["qas"]:
                question = qas["question"]
                seen_sentences = set()
                for answer in qas["answers"]:
                    answer_start = answer["answer_start"]
                    # Map the answer fragment back to its enclosing sentence.
                    sentence = next(
                        (
                            paragraph_text[start:end]
                            for start, end in sentence_breaks
                            if start <= answer_start < end
                        ),
                        None,
                    )
                    # Skip offsets that match no sentence (previously these
                    # leaked ``None`` into the answers and defeated the
                    # duplicate check) and sentences already emitted for
                    # this question entry.
                    if sentence is None or sentence in seen_sentences:
                        continue
                    seen_sentences.add(sentence)
                    questions.append(question)
                    answers.append(sentence)
    return questions, answers

def generate_qa_pairs_test(data):
    """Build a retrieval-style test set from SQuAD-style data.

    Every sentence of every paragraph becomes a retrieval candidate, and each
    unique question text is mapped to the candidate indices of the sentences
    that contain its answers.

    Args:
        data: Parsed SQuAD-style JSON. ``data["data"]`` is a list of passages,
            each with a ``"paragraphs"`` list; every paragraph has a
            ``"context"`` string and a ``"qas"`` list whose items carry a
            ``"question"`` and an ``"answers"`` list (each answer has an
            ``"answer_start"`` offset into the context).

    Returns:
        A ``(questions, candidates, ground_truth)`` tuple:

        * ``questions``: unique question texts in first-seen order.
        * ``candidates``: all paragraph sentences in document order;
          duplicate sentence texts are kept.
        * ``ground_truth``: for each question, the candidate indices of its
          answer sentences in first-seen order. An index always refers to the
          first occurrence of the sentence text in ``candidates``.

        Answers whose start offset falls inside no sentence span are skipped.
    """
    # question text -> answer sentences; dict keys act as an insertion-ordered
    # set so the resulting index lists are the same on every run.
    question_to_sentences = {}
    candidates = []
    for passage in data["data"]:
        for paragraph in passage["paragraphs"]:
            paragraph_text = paragraph["context"]
            sentence_breaks = list(infer_sentence_breaks(paragraph_text))
            candidates.extend(
                paragraph_text[start:end] for start, end in sentence_breaks
            )
            for qas in paragraph["qas"]:
                answer_sentences = question_to_sentences.setdefault(
                    qas["question"], {}
                )
                for answer in qas["answers"]:
                    answer_start = answer["answer_start"]
                    # Map the answer fragment back to its enclosing sentence.
                    sentence = next(
                        (
                            paragraph_text[start:end]
                            for start, end in sentence_breaks
                            if start <= answer_start < end
                        ),
                        None,
                    )
                    # An offset outside every sentence has no candidate to
                    # point at; recording it used to make the index lookup
                    # below fail.
                    if sentence is not None:
                        answer_sentences[sentence] = None

    # Candidates intentionally keep duplicates; map each sentence text to its
    # first position once instead of scanning the list for every answer.
    first_index = {}
    for position, candidate in enumerate(candidates):
        first_index.setdefault(candidate, position)

    questions = list(question_to_sentences)
    ground_truth = [
        [first_index[sentence] for sentence in question_to_sentences[question]]
        for question in questions
    ]

    return questions, candidates, ground_truth

In [None]:
# Flatten the SQuAD training split into aligned question / answer-sentence lists.
(ori_train_questions,
 ori_train_answers) = generate_qa_pairs_train(data=train_json)

In [None]:
# Turn the SQuAD dev split into the retrieval test set: unique questions,
# the sentence candidate pool, and per-question candidate indices.
(test_questions,
 test_candidates,
 test_ground_truth) = generate_qa_pairs_test(data=dev_json)

In [None]:
# Sanity check: number of test questions and size of the candidate pool.
print(*(len(items) for items in (test_questions, test_candidates)))

In [None]:
# Sanity check: the two training lists are expected to have the same length.
print(*(len(items) for items in (ori_train_questions, ori_train_answers)))

In [None]:
# Test-set sizes with duplicate candidate sentences kept in the pool.
print(*(len(items) for items in (test_questions, test_candidates)))

In [None]:
# Split the original training set into a new training set and a new dev set:
# 10% of the pairs are sampled (without replacement) as dev indices.
random.seed(12345)  # fixed seed so the split is reproducible
all_idx = list(range(len(ori_train_questions)))
num_dev_samples = int(len(all_idx) * 0.1)
dev_idx = random.sample(all_idx, num_dev_samples)

In [None]:
# `dev_idx` is a list, so every `i in dev_idx` test scans it linearly and the
# four comprehensions become quadratic in the dataset size. A set yields the
# exact same partition (original order preserved) in linear time.
_dev_idx_set = set(dev_idx)
dev_questions = [ori_train_questions[i] for i in all_idx if i in _dev_idx_set]
dev_answers = [ori_train_answers[i] for i in all_idx if i in _dev_idx_set]
train_questions = [ori_train_questions[i] for i in all_idx if i not in _dev_idx_set]
train_answers = [ori_train_answers[i] for i in all_idx if i not in _dev_idx_set]
print(len(train_questions), len(dev_questions))

In [None]:
# Distinct question and answer-sentence counts for the train and dev splits.
print(*(len(set(items)) for items in (train_questions, train_answers)))
print(*(len(set(items)) for items in (dev_questions, dev_answers)))

In [None]:
# Bundle the three splits into one JSON document. Train and dev hold aligned
# question/answer lists; test holds questions, the candidate pool and the
# per-question candidate indices.
saved_data = dict(
    train=dict(questions=train_questions, answers=train_answers),
    dev=dict(questions=dev_questions, answers=dev_answers),
    test=dict(
        questions=test_questions,
        candidates=test_candidates,
        ground_truth=test_ground_truth,
    ),
)
with open('data/squad.json', 'w', encoding='utf-8') as f:
    json.dump(saved_data, f)

In [None]:
# Load the tokenizer for the 'bert-base-uncased' checkpoint; the cells below
# call its `tokenize` method to split questions and answer sentences into tokens.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
def tokenize(questions, answers):
    """Tokenize question/answer pairs with the module-level ``tokenizer``.

    The inputs are walked pairwise, so only as many items as the shorter of
    the two sequences are tokenized.

    Args:
        questions: Sequence of question strings.
        answers: Sequence of answer strings, paired by position with
            ``questions``.

    Returns:
        A ``(question_insts, answer_insts)`` tuple of lists holding the
        tokenizer output for each question and each answer.
    """
    # Tokenize each pair together so the question of a pair is always
    # processed right before its answer.
    tokenized_pairs = [
        (tokenizer.tokenize(question), tokenizer.tokenize(answer))
        for question, answer in zip(questions, answers)
    ]
    question_insts = [question_tokens for question_tokens, _ in tokenized_pairs]
    answer_insts = [answer_tokens for _, answer_tokens in tokenized_pairs]
    return question_insts, answer_insts

In [None]:
# Tokenize the new (post-split) training pairs.
train_question_insts, train_answer_insts = tokenize(
    questions=train_questions,
    answers=train_answers,
)

In [None]:
# Tokenize the full original training set (before the train/dev split).
# NOTE(review): this rebinds the names assigned from the split training set in
# the previous cell — confirm which of the two the statistics below should use.
train_question_insts, train_answer_insts = tokenize(
    questions=ori_train_questions,
    answers=ori_train_answers,
)

In [None]:
dev_question_insts, dev_answer_insts = tokenize(dev_questions, dev_answers)
# The test split has no aligned answer list (`test_answers` is never defined,
# so the previous call raised NameError). Its answer side is the candidate
# pool, which is not paired one-to-one with the questions, so each side is
# tokenized on its own rather than through the pairwise `tokenize` helper.
test_question_insts = [tokenizer.tokenize(question) for question in test_questions]
test_answer_insts = [tokenizer.tokenize(candidate) for candidate in test_candidates]

In [None]:
def calculate_average_token_number(list):
    """Print and return the mean length of the items in ``list``.

    Args:
        list: Sized collection of sized items (e.g. token lists). The
            parameter name shadows the builtin inside this function and is
            kept only for backward compatibility with existing callers.

    Returns:
        The average ``len()`` of the items as a float, or ``0.0`` when the
        collection is empty (instead of raising ``ZeroDivisionError``).
    """
    if len(list) == 0:
        average_token_number = 0.0
    else:
        average_token_number = sum(len(s) for s in list) / len(list)
    print(average_token_number)
    return average_token_number



In [None]:
# Average token count per question (a) and per answer sentence (b).
a, b = map(
    calculate_average_token_number,
    (train_question_insts, train_answer_insts),
)

In [None]:
import seaborn as sns
%matplotlib inline
def plot_dist(len_list):
    """Plot the distribution of the values in ``len_list`` with seaborn.

    Args:
        len_list: Values (e.g. token counts) to pass to ``sns.displot``.

    Returns:
        Whatever ``sns.displot`` returns for the given values.
    """
    distribution_plot = sns.displot(len_list)
    return distribution_plot
# ax.savefig('response_length.png', dpi=200)

In [None]:
# Distribution of question lengths, measured in tokens.
ax = plot_dist(list(map(len, train_question_insts)))