In [93]:
import json

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, BertForQuestionAnswering

In [13]:
def load_json_data(path):
    """Load a SQuAD-format JSON file and return its top-level "data" entry.

    Args:
        path: path to the JSON file.

    Returns:
        The value stored under the file's "data" key (the list of articles
        that the preprocessing cell iterates over).

    Raises:
        KeyError: if the file has no top-level "data" key.
    """
    # Decode explicitly as UTF-8: JSON text is UTF-8, and the platform default
    # encoding (e.g. cp950/cp1252 on Windows) would garble or reject non-ASCII text.
    with open(path, encoding="utf-8") as f:
        json_data = json.load(f)
    return json_data["data"]

def find_target_sublist(my_list, target_sublist):
    """Find the first occurrence of ``target_sublist`` as a contiguous run in ``my_list``.

    Args:
        my_list: sequence to search.
        target_sublist: contiguous run to look for; compared with ``==`` against
            slices of ``my_list``, so both should be the same sequence type.

    Returns:
        ``(start, end)`` as a half-open span, so ``my_list[start:end] == target_sublist``.
        Returns ``(-1, -1)`` when there is no match or ``target_sublist`` is empty
        (an empty target would otherwise trivially "match" at index 0).
    """
    target_length = len(target_sublist)
    if target_length == 0:
        return -1, -1
    # Only start indices where the whole target still fits can match.
    for i in range(len(my_list) - target_length + 1):
        if my_list[i : i + target_length] == target_sublist:
            return i, i + target_length
    return -1, -1

In [5]:
# Inputs kept as named constants so they can be changed in one place.
DATA_PATH = "./dat/train-v2.0.json"  # presumably the SQuAD v2.0 training file, relative to the notebook
MODEL_NAME = "deepset/bert-base-cased-squad2"  # tokenizer checkpoint on the Hugging Face Hub

json_datas = load_json_data(DATA_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [94]:
# Build BERT question-answering training features from the loaded articles.
# Unanswerable questions are skipped. For each answerable one, the first gold answer
# is located by matching its token ids inside the encoded (context, question) pair.
# NOTE(review): Hugging Face extractive-QA checkpoints are usually fine-tuned on
# (question, context) order, but this cell encodes (context, question). Confirm this
# is intended before fine-tuning the pretrained squad2 head.
N_ARTICLES = 1   # number of articles to process; set to None to use every article
max_len = 512    # maximum encoded length (BERT's position-embedding limit)

input_data = {'input_ids': [], 'token_type_ids': [], 'attention_mask': [], 'start_positions': [], 'end_positions': []}

total_num = 0        # QA pairs visited (answerable and unanswerable)
impossible_num = 0   # answerable QA pairs whose processing raised (name kept for downstream cells)
not_found_num = 0    # answerable QA pairs whose answer tokens are absent from the (possibly truncated) encoding
error_string = []    # str() of every exception caught while processing

articles = json_datas if N_ARTICLES is None else json_datas[:N_ARTICLES]
for json_data in articles:
    for paragraphs in json_data['paragraphs']:
        context = paragraphs["context"]
        for qa in paragraphs['qas']:
            total_num += 1
            # Skip unanswerable QA pairs. Using .get() means a missing flag is treated
            # as answerable instead of raising again inside the error handler.
            if qa.get('is_impossible', False):
                continue
            try:
                question = qa['question']

                # Token ids of the first gold answer, without the leading/trailing
                # special tokens the tokenizer adds.
                answers = qa['answers'][0]['text']
                answers_ids = tokenizer(answers).input_ids[1:-1]

                # Encode the pair into token ids.
                inputs = tokenizer(context, question, return_tensors="pt", max_length=max_len, truncation=True)
                # .tolist() yields plain ints, so the sublist search compares lists
                # instead of running one tensor comparison per element.
                inputs_ids = inputs.input_ids[0].tolist()

                # An empty answer would trivially "match" at index 0, so treat it as not found.
                start, end_exclusive = find_target_sublist(inputs_ids, answers_ids) if answers_ids else (-1, -1)
                if start == -1:
                    not_found_num += 1
                    continue

                # find_target_sublist returns a half-open [start, end) span, but
                # BertForQuestionAnswering labels the index of the LAST answer token.
                input_data['start_positions'].append(torch.tensor([start]))
                input_data['end_positions'].append(torch.tensor([end_exclusive - 1]))
                input_data['input_ids'].append(inputs.input_ids[0])
                input_data['attention_mask'].append(inputs.attention_mask[0])
                input_data['token_type_ids'].append(inputs.token_type_ids[0])
            except Exception as e:
                # Best effort: record the failure and continue with the next QA pair.
                error_string.append(f"{e}")
                impossible_num += 1

error = sorted(set(error_string))
if error:
    print(error)

# Right-pad each feature to the longest example (0 is presumably [PAD] for this vocab;
# confirm via tokenizer.pad_token_id). Features with no examples become empty
# tensors and are then dropped.
input_data = {k: pad_sequence(v, padding_value=0, batch_first=True) if v else torch.tensor([]) for k, v in input_data.items()}
input_data = {k: v[:, :max_len] for k, v in input_data.items() if v.size(0) > 0}

In [99]:
class QADataset(Dataset):
    """Map-style dataset over the padded feature tensors built by the preprocessing cell.

    Each item is a dict of per-example tensors, intended to be passed as keyword
    arguments to a question-answering model such as BertForQuestionAnswering.
    """

    # Keys served per example. token_type_ids is included so the model can tell the
    # two segments of each encoded pair apart. Keys absent from input_data are skipped.
    FIELDS = ("input_ids", "token_type_ids", "attention_mask", "start_positions", "end_positions")

    def __init__(self, input_data):
        """
        Args:
            input_data: dict mapping feature names to tensors whose first dimension
                indexes examples. May be empty when no examples were produced.
        """
        self.input_data = input_data

    def __len__(self):
        # The preprocessing cell drops empty features, so an empty dataset arrives as {}.
        input_ids = self.input_data.get("input_ids")
        return 0 if input_ids is None else len(input_ids)

    def __getitem__(self, idx):
        return {key: self.input_data[key][idx] for key in self.FIELDS if key in self.input_data}

tensor([[  101, 24041,   144,  ...,     0,     0,     0],
        [  101, 24041,   144,  ...,     0,     0,     0],
        [  101, 24041,   144,  ...,     0,     0,     0],
        ...,
        [  101,  1130,  1382,  ...,     0,     0,     0],
        [  101,  1130,  1382,  ...,     0,     0,     0],
        [  101,  1130,  1382,  ...,     0,     0,     0]])