In [1]:
# handle gemformer imports
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

In [2]:
import torch
import numpy as np
from tqdm import tqdm
from transformers import RobertaTokenizerFast
from datasets import load_from_disk
from datasets.arrow_dataset import Dataset

from gemformer.utils import add_qa_evidence_tokens, pad_and_drop_duplicates

# tokens of overlap between consecutive context windows (tokenizer `stride`)
stride = 20
# overall token budget per input; the preprocessing functions below subtract the
# memory length len(mem_tokens[0]) from it to size the question+context window
max_length = 512

In [None]:
def tokenize_mem(mem):
    """Tokenize a batch of memory strings (used with ``Dataset.map(batched=True)``).

    Uses the notebook-global ``tokenizer`` and truncates every string to 202
    tokens, i.e. at most 200 content tokens plus the two special tokens a
    RoBERTa tokenizer wraps a single sequence with (``<s>`` ... ``</s>``).

    Args:
        mem: batch mapping whose ``'mem'`` entry is a list of strings.

    Returns:
        The tokenizer output extended with ``'mem'``: every ``input_ids`` row with
        its first and last token removed.
    """
    encoded = tokenizer(mem['mem'], max_length=202, truncation=True)
    # strip the leading/trailing special tokens, keeping only the memory content ids
    stripped = []
    for ids in encoded['input_ids']:
        stripped.append(ids[1:-1])
    encoded.update({'mem': stripped})
    return encoded

## MuSiQue

In [None]:
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
# add the '[para]' paragraph marker; its id is matched against input_ids in the MuSiQue preprocessing
tokenizer = add_qa_evidence_tokens(tokenizer, tokens_to_add=['[para]'])

# pad length for answer positions: the MuSiQue preprocessing maps a single answer span per example
max_num_answers = 1
# pad length for the per-chunk supporting-paragraph ('[para]' marker) labels
max_num_paragraphs = 20
# NOTE(review): not referenced by the MuSiQue preprocessing functions in this notebook
max_num_sentences = 152 # from train and val data
# id of '</s>'; used to locate the context boundaries inside each chunk
eos = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
PARA_MARKER_token = tokenizer.convert_tokens_to_ids('[para]')

In [None]:
# Memory strings for the MuSiQue training set, tokenized with tokenize_mem.
# NOTE(review): torch.load unpickles its input (arbitrary code execution) — only load trusted files.
data = Dataset.from_dict({'mem': torch.load("../yake_mem_strings_musique_train.pkl")})
mem_list = data.map(tokenize_mem,
                    batched=True,
                    remove_columns=data.column_names
                   )
train_dataset = torch.load('../musique_train_examples_allenai_style_with_para_seps.pkl')
# full slice -> column dict (it is .update()d below); presumably train_dataset is a datasets.Dataset
list_train_dataset = train_dataset[:]
# NOTE(review): mem_list is a whole Dataset (columns input_ids / attention_mask / mem);
# confirm 'mem_tokens' ends up holding the intended per-example token lists, since the
# preprocessing sizes its window with len(mem_tokens[0]) — TODO confirm
list_train_dataset.update({'mem_tokens': mem_list})
mem_train_dataset = Dataset.from_dict(list_train_dataset)

In [None]:
def preprocess_roberta_training_examples_musique(train_examples):
    """Tokenize a batch of MuSiQue training examples into labelled RoBERTa chunks.

    Each (question, context) pair is split by the tokenizer into overlapping
    windows of ``max_length - len(mem_tokens[0])`` tokens (``stride`` tokens of
    overlap), leaving room for the memory tokens stored alongside. Every label
    is computed per chunk, so it lines up with the per-example chunk lists
    written to ``input_ids`` / ``attention_mask``.

    Reads notebook globals: ``tokenizer``, ``max_length``, ``stride``, ``eos``,
    ``PARA_MARKER_token``, ``max_num_answers``, ``max_num_paragraphs`` and
    ``pad_and_drop_duplicates``.

    Args:
        train_examples: batch mapping with ``question``, ``context``, ``id``,
            ``mem_tokens``, ``char_answer_offsets`` (one ``[start_char, end_char]``
            answer span per example) and ``supp_para_char_offsets`` (list of
            ``[start_char, end_char]`` supporting-paragraph spans per example).

    Returns:
        The tokenizer output regrouped to one row per example: ``input_ids`` and
        ``attention_mask`` are lists of chunks; ``start_positions``,
        ``end_positions``, ``context_start_id``, ``context_end_id`` and
        ``supp_para_labels`` hold one entry per chunk; ``id`` holds the example
        ids. ``offset_mapping`` is dropped.
    """
    def _context_bounds(sequence_ids):
        # first and last token index whose sequence id is 1 (the context segment)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        return context_start, idx - 1

    def _span_in_chunk(offset, context_start, context_end, start_char, end_char):
        # True when the char span lies entirely inside this chunk's context window
        return offset[context_start][0] <= start_char and offset[context_end][1] >= end_char

    def _start_token(offset, context_start, context_end, start_char):
        # last context token whose char start is <= start_char
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        return idx - 1

    def _end_token(offset, context_start, context_end, end_char):
        # leftmost token of the trailing run of context tokens whose char end is >= end_char
        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        return idx + 1

    questions = [q.strip() for q in train_examples["question"]]
    inputs = tokenizer(
        questions,
        train_examples['context'],
        max_length=max_length - len(train_examples['mem_tokens'][0]),
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # example id of every produced chunk, used to regroup chunks per example
    raw_ids = np.array([train_examples['id'][kk] for kk in inputs['overflow_to_sample_mapping']])
    inputs.update({"id": train_examples['id']})
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = train_examples['char_answer_offsets']
    supp_paras = train_examples['supp_para_char_offsets']

    # context start/end ids are stored to filter the part of the sequence used for uncertainty-based top-k
    ex_context_start_id = []
    ex_context_end_id = []
    batch_start_positions_list = []
    batch_end_positions_list = []
    batch_title_positions = []
    chunks_per_example = []

    for ex_id in train_examples["id"]:
        ex_indices = np.where(raw_ids == ex_id)[0].tolist()
        chunks_per_example.append(ex_indices)
        # RoBERTa pair layout is <s> q </s></s> c </s>: the context starts after the 2nd </s>
        context_token_ids = [np.where(np.array(inputs['input_ids'][i]) == eos)[0] for i in ex_indices]
        ex_context_start_id.append([elem[1] + 1 for elem in context_token_ids])
        ex_context_end_id.append([elem[-1] for elem in context_token_ids])

        start_positions_list = []
        end_positions_list = []
        title_positions = []

        # every label is computed inside this loop so there is one entry per chunk
        for chunk_idx in ex_indices:
            offset = offset_mapping[chunk_idx]
            sample_idx = sample_map[chunk_idx]
            context_start, context_end = _context_bounds(inputs.sequence_ids(chunk_idx))

            # answers[sample_idx] is a single [start_char, end_char] span; keep it only if inside this chunk
            start_positions = []
            end_positions = []
            answer = answers[sample_idx]
            if _span_in_chunk(offset, context_start, context_end, answer[0], answer[1]):
                start_positions.append(_start_token(offset, context_start, context_end, answer[0]))
                end_positions.append(_end_token(offset, context_start, context_end, answer[1]))
            start_positions, end_positions = pad_and_drop_duplicates(start_positions,
                                                                     end_positions,
                                                                     max_num_answers)
            start_positions_list.append(start_positions)
            end_positions_list.append(end_positions)

            # token index of the first token of every supporting paragraph inside this chunk
            supp_para_starts = [
                _start_token(offset, context_start, context_end, para[0])
                for para in supp_paras[sample_idx]
                if _span_in_chunk(offset, context_start, context_end, para[0], para[1])
            ]
            # one label per '[para]' marker of this chunk: 1 if it opens a supporting paragraph
            para_markers = np.where(np.array(inputs['input_ids'][chunk_idx]) == PARA_MARKER_token)[0]
            title_positions.append(pad_and_drop_duplicates(
                start_positions=[1 if j in supp_para_starts else 0 for j in para_markers],
                max_num_answers=max_num_paragraphs))

        batch_start_positions_list.append(start_positions_list)
        batch_end_positions_list.append(end_positions_list)
        batch_title_positions.append(title_positions)

    inputs.update({"start_positions": batch_start_positions_list})
    inputs.update({"end_positions": batch_end_positions_list})
    inputs.update({'context_start_id': ex_context_start_id})
    inputs.update({'context_end_id': ex_context_end_id})
    inputs.update({"supp_para_labels": batch_title_positions})

    # regroup the flat chunk lists into one list of chunks per example
    rearranged_inps = [[inputs['input_ids'][i] for i in idxs] for idxs in chunks_per_example]
    rearranged_masks = [[inputs['attention_mask'][i] for i in idxs] for idxs in chunks_per_example]
    inputs.update({'input_ids': rearranged_inps})
    inputs.update({'attention_mask': rearranged_masks})

    return inputs


# batch_size=1: the preprocessing sizes its window from mem_tokens[0] only, so one
# example per batch keeps that window per-example. The raw columns are removed and
# the function returns one row per example (chunks grouped in lists).
mem_train_dataset = mem_train_dataset.map(
    preprocess_roberta_training_examples_musique,
    batched=True,
    batch_size=1,
    remove_columns=mem_train_dataset.column_names,
)

# re-attach the memory tokens (dropped by remove_columns above) and save
tmp = mem_train_dataset[:]
tmp.update({'mem_tokens': mem_list})
mem_train_dataset = Dataset.from_dict(tmp)
mem_train_dataset.save_to_disk('../musique_yake_mem_train_dataset')

In [None]:
# Memory strings for the MuSiQue validation set, tokenized with tokenize_mem.
data = Dataset.from_dict({'mem': torch.load("../yake_mem_strings_musique_val.pkl")})
mem_list = data.map(tokenize_mem,
              batched=True,
              remove_columns=data.column_names,
             )
val_dataset = torch.load('../musique_val_examples_allenai_style_with_para_seps_no_mem_seps.pkl')
list_val_dataset = val_dataset[:]
# only the first len(val_dataset) memory rows are attached
# NOTE(review): slicing a datasets.Dataset returns a column dict, not a list of rows, and
# this assumes memory rows are in example order — confirm 'mem_tokens' is as intended
list_val_dataset.update({'mem_tokens': mem_list[:len(val_dataset)]})
mem_val_dataset = Dataset.from_dict(list_val_dataset)

In [None]:
def preprocess_roberta_validation_examples_musique(examples):
    """Tokenize MuSiQue validation examples into RoBERTa chunks with evaluation metadata.

    Contexts are split into overlapping windows of ``max_length - len(mem_tokens[0])``
    tokens (``stride`` tokens of overlap). Besides per-chunk supporting-paragraph
    labels, it keeps ``offset_mapping`` with ``None`` for every non-context token
    (presumably to map predictions back to characters). It also builds
    ``concat_titles_ids``: for every ``[para]`` marker across all chunks of an
    example, the rank of its char offset among the example's distinct ``[para]``
    offsets. The same paragraph seen in two overlapping chunks therefore gets
    the same id.

    Reads notebook globals: ``tokenizer``, ``max_length``, ``stride``, ``eos``,
    ``PARA_MARKER_token``, ``max_num_paragraphs`` and ``pad_and_drop_duplicates``.

    Args:
        examples: batch mapping with ``question``, ``context``, ``id``,
            ``mem_tokens`` and ``supp_para_char_offsets`` (list of
            ``[start_char, end_char]`` supporting-paragraph spans per example).

    Returns:
        The tokenizer output regrouped to one row per example. ``input_ids``,
        ``attention_mask`` and ``offset_mapping`` are lists of chunks. It also
        holds ``example_ids``, ``context_start_id``, ``context_end_id``,
        ``supp_para_labels`` (one entry per chunk) and ``concat_titles_ids``.
    """
    def _context_bounds(sequence_ids):
        # first and last token index whose sequence id is 1 (the context segment)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        return context_start, idx - 1

    def _span_in_chunk(offset, context_start, context_end, start_char, end_char):
        # True when the char span lies entirely inside this chunk's context window
        return offset[context_start][0] <= start_char and offset[context_end][1] >= end_char

    def _start_token(offset, context_start, context_end, start_char):
        # last context token whose char start is <= start_char
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        return idx - 1

    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length - len(examples['mem_tokens'][0]),
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # raw example id of every chunk, to regroup chunks of the same context document
    raw_ids = np.array([examples['id'][kk] for kk in inputs['overflow_to_sample_mapping']])
    inputs.update({"example_ids": examples['id']})
    offset_mapping = inputs["offset_mapping"]
    sample_map = inputs.pop("overflow_to_sample_mapping")
    supp_paras = examples['supp_para_char_offsets']
    ex_context_start_id = []
    ex_context_end_id = []
    batch_title_positions = []
    batch_concat_titles_ids = []
    chunks_per_example = []

    for ex_id in examples["id"]:
        ex_indices = np.where(raw_ids == ex_id)[0].tolist()
        chunks_per_example.append(ex_indices)
        # RoBERTa pair layout is <s> q </s></s> c </s>: the context starts after the 2nd </s>
        context_token_ids = [np.where(np.array(inputs['input_ids'][i]) == eos)[0] for i in ex_indices]
        ex_context_start_id.append([elem[1] + 1 for elem in context_token_ids])
        ex_context_end_id.append([elem[-1] for elem in context_token_ids])

        title_positions = []
        for chunk_idx in ex_indices:
            # keep the raw offsets; the stored list is replaced by a masked copy just below
            offset = offset_mapping[chunk_idx]
            sample_idx = sample_map[chunk_idx]
            sequence_ids = inputs.sequence_ids(chunk_idx)
            inputs["offset_mapping"][chunk_idx] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]
            context_start, context_end = _context_bounds(sequence_ids)

            # token index of the first token of every supporting paragraph inside this chunk
            supp_para_starts = [
                _start_token(offset, context_start, context_end, para[0])
                for para in supp_paras[sample_idx]
                if _span_in_chunk(offset, context_start, context_end, para[0], para[1])
            ]
            # one label per '[para]' marker of this chunk: 1 if it opens a supporting paragraph
            para_markers = np.where(np.array(inputs['input_ids'][chunk_idx]) == PARA_MARKER_token)[0]
            title_positions.append(pad_and_drop_duplicates(
                start_positions=[1 if j in supp_para_starts else 0 for j in para_markers],
                max_num_answers=max_num_paragraphs))

        concat_input_ids = []
        concat_offset_mapping = []
        for chunk_idx in ex_indices:
            concat_input_ids += inputs['input_ids'][chunk_idx]
            concat_offset_mapping += inputs['offset_mapping'][chunk_idx]

        # Index the offsets as a plain list: they mix tuples and None, and NumPy >= 1.24
        # refuses to build an array from such ragged input.
        title_token_idx = np.where(np.array(concat_input_ids) == PARA_MARKER_token)[0]
        title_offsets = [concat_offset_mapping[k] for k in title_token_idx]
        titles_offsets_unique = sorted(np.unique(title_offsets, axis=0).tolist())
        concat_titles_ids = [titles_offsets_unique.index(list(off)) for off in title_offsets]

        batch_concat_titles_ids.append(concat_titles_ids)
        batch_title_positions.append(title_positions)

    inputs.update({'context_start_id': ex_context_start_id})
    inputs.update({'context_end_id': ex_context_end_id})
    inputs.update({"supp_para_labels": batch_title_positions})
    inputs.update({"concat_titles_ids": batch_concat_titles_ids})

    # regroup the flat chunk lists into one list of chunks per example
    rearranged_inps = [[inputs['input_ids'][i] for i in idxs] for idxs in chunks_per_example]
    rearranged_masks = [[inputs['attention_mask'][i] for i in idxs] for idxs in chunks_per_example]
    rearranged_offset_mapping = [[inputs['offset_mapping'][i] for i in idxs] for idxs in chunks_per_example]
    inputs.update({'input_ids': rearranged_inps})
    inputs.update({'attention_mask': rearranged_masks})
    inputs.update({'offset_mapping': rearranged_offset_mapping})

    return inputs

# batch_size=1: the preprocessing sizes its window from mem_tokens[0] only, so one
# example per batch keeps that window per-example.
mem_val_dataset = mem_val_dataset.map(
    preprocess_roberta_validation_examples_musique,
    batched=True,
    batch_size=1,
    remove_columns=mem_val_dataset.column_names,
)
# re-attach the memory tokens (dropped by remove_columns above) and save
tmp = mem_val_dataset[:]
tmp.update({'mem_tokens': mem_list[:len(mem_val_dataset)]})
mem_val_dataset = Dataset.from_dict(tmp)
mem_val_dataset.save_to_disk('../musique_yake_mem_val_dataset')

## HotpotQA and 2WikiMHQA

In [None]:
# fresh tokenizer with add_qa_evidence_tokens' default token set (no tokens_to_add);
# the HotpotQA config below looks up '<t>', '</t>' and '[/sent]' in it
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
tokenizer = add_qa_evidence_tokens(tokenizer)

In [None]:
# Memory strings for the HotpotQA training set, tokenized with tokenize_mem.
# NOTE(review): torch.load unpickles its input — only load trusted files.
data = Dataset.from_dict({'mem': torch.load("../yake_mem_strings_hotpotqa_train.pkl")})
mem_list = data.map(tokenize_mem,
                    batched=True,
                    remove_columns=data.column_names
                   )
list_train_dataset = torch.load('../hotpotqa_train_examples_with_special_seps.pkl')

# full slice -> column dict, then attach the memory tokens
# NOTE(review): mem_list is a whole Dataset (columns input_ids / attention_mask / mem);
# confirm 'mem_tokens' holds the intended per-example token lists — TODO confirm
tmp = list_train_dataset[:]
tmp.update({'mem_tokens': mem_list})
mem_train_dataset = Dataset.from_dict(tmp)

In [None]:
# pad length for answer positions (char_answer_offsets holds a list of spans per example)
max_num_answers = 64
max_num_paragraphs = 10 # by hotpotqa construction
max_num_sentences = 150 # from train and val data
# id of '</s>'; used to locate the context boundaries inside each chunk
eos = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
# NOTE(review): TITLE_END_token is not referenced by the preprocessing functions in this notebook
TITLE_END_token = tokenizer.convert_tokens_to_ids('</t>')
# supporting-title labels are produced one per '<t>' marker
TITLE_START_token = tokenizer.convert_tokens_to_ids('<t>')
# supporting-sentence labels are produced one per '[/sent]' marker
SENT_MARKER_END_token = tokenizer.convert_tokens_to_ids('[/sent]')
    

def preprocess_roberta_training_examples(train_examples):
    """Tokenize a batch of HotpotQA / 2WikiMHQA training examples into labelled RoBERTa chunks.

    Contexts are split into overlapping windows of ``max_length - len(mem_tokens[0])``
    tokens (``stride`` tokens of overlap), leaving room for the memory tokens
    stored alongside. Every label is computed per chunk, so it lines up with the
    per-example chunk lists written to ``input_ids`` / ``attention_mask``.

    Reads notebook globals: ``tokenizer``, ``max_length``, ``stride``, ``eos``,
    ``TITLE_START_token``, ``SENT_MARKER_END_token``, ``max_num_answers``,
    ``max_num_paragraphs``, ``max_num_sentences`` and ``pad_and_drop_duplicates``.

    Args:
        train_examples: batch mapping with ``question``, ``context``, ``id``,
            ``mem_tokens`` and ``char_answer_offsets`` (list of
            ``[start_char, end_char]`` spans per example). A first span of
            ``[-1, -1]`` or ``[-2, -2]`` marks question type 0 or 1
            (yes = 0, no = 1, span = 2). The optional ``supp_title_char_offsets`` /
            ``supp_sent_char_offsets`` hold supporting title / sentence char spans.
            If either column is missing, the matching labels are all zero.

    Returns:
        The tokenizer output regrouped to one row per example. ``input_ids`` and
        ``attention_mask`` are lists of chunks. ``start_positions``,
        ``end_positions``, ``question_type``, ``context_start_id``,
        ``context_end_id``, ``supp_title_labels`` and ``supp_sent_labels`` hold
        one entry per chunk. ``offset_mapping`` is dropped.
    """
    def _context_bounds(sequence_ids):
        # first and last token index whose sequence id is 1 (the context segment)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        return context_start, idx - 1

    def _span_in_chunk(offset, context_start, context_end, start_char, end_char):
        # True when the char span lies entirely inside this chunk's context window
        return offset[context_start][0] <= start_char and offset[context_end][1] >= end_char

    def _start_token(offset, context_start, context_end, start_char):
        # last context token whose char start is <= start_char
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        return idx - 1

    def _end_token(offset, context_start, context_end, end_char):
        # leftmost token of the trailing run of context tokens whose char end is >= end_char
        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        return idx + 1

    questions = [q.strip() for q in train_examples["question"]]
    inputs = tokenizer(
        questions,
        train_examples['context'],
        max_length=max_length - len(train_examples['mem_tokens'][0]),
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # example id of every produced chunk, used to regroup chunks per example
    raw_ids = np.array([train_examples['id'][kk] for kk in inputs['overflow_to_sample_mapping']])
    inputs.update({"id": train_examples['id']})
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = train_examples['char_answer_offsets']
    num_examples = len(train_examples['id'])
    # optional annotations: a missing column used to raise NameError further down
    supp_sents = train_examples.get('supp_sent_char_offsets') or [[] for _ in range(num_examples)]
    supp_titles = train_examples.get('supp_title_char_offsets') or [[] for _ in range(num_examples)]

    batch_question_type = []
    # context start/end ids are stored to filter the part of the sequence used for uncertainty-based top-k
    ex_context_start_id = []
    ex_context_end_id = []
    batch_start_positions_list = []
    batch_end_positions_list = []
    batch_title_positions = []
    batch_sent_positions = []
    chunks_per_example = []

    for ex_id in train_examples["id"]:
        ex_indices = np.where(raw_ids == ex_id)[0].tolist()
        chunks_per_example.append(ex_indices)
        # RoBERTa pair layout is <s> q </s></s> c </s>: the context starts after the 2nd </s>
        context_token_ids = [np.where(np.array(inputs['input_ids'][i]) == eos)[0] for i in ex_indices]
        ex_context_start_id.append([elem[1] + 1 for elem in context_token_ids])
        ex_context_end_id.append([elem[-1] for elem in context_token_ids])

        question_type = []
        start_positions_list = []
        end_positions_list = []
        title_positions = []
        sent_positions = []

        # every label is computed inside this loop so there is one entry per chunk
        for chunk_idx in ex_indices:
            offset = offset_mapping[chunk_idx]
            sample_idx = sample_map[chunk_idx]
            chunk_input_ids = np.array(inputs['input_ids'][chunk_idx])
            context_start, context_end = _context_bounds(inputs.sequence_ids(chunk_idx))

            # question type from the first answer span: yes = 0, no = 1, span = 2
            first_start, first_end = answers[sample_idx][0][0], answers[sample_idx][0][1]
            if first_start == -1 and first_end == -1:
                question_type.append(0)
            elif first_start == -2 and first_end == -2:
                question_type.append(1)
            else:
                question_type.append(2)

            # token positions of the answer spans that lie fully inside this chunk
            start_positions = []
            end_positions = []
            for answer in answers[sample_idx]:
                if _span_in_chunk(offset, context_start, context_end, answer[0], answer[1]):
                    start_positions.append(_start_token(offset, context_start, context_end, answer[0]))
                    end_positions.append(_end_token(offset, context_start, context_end, answer[1]))
            start_positions, end_positions = pad_and_drop_duplicates(start_positions,
                                                                     end_positions,
                                                                     max_num_answers)
            start_positions_list.append(start_positions)
            end_positions_list.append(end_positions)

            # supporting titles are matched by their first token, sentences by their last
            supp_title_starts = [
                _start_token(offset, context_start, context_end, span[0])
                for span in supp_titles[sample_idx]
                if _span_in_chunk(offset, context_start, context_end, span[0], span[1])
            ]
            supp_sent_ends = [
                _end_token(offset, context_start, context_end, span[1])
                for span in supp_sents[sample_idx]
                if _span_in_chunk(offset, context_start, context_end, span[0], span[1])
            ]

            # one 0/1 label per '<t>' marker and per '[/sent]' marker of this chunk
            title_markers = np.where(chunk_input_ids == TITLE_START_token)[0]
            title_positions.append(pad_and_drop_duplicates(
                start_positions=[1 if j in supp_title_starts else 0 for j in title_markers],
                max_num_answers=max_num_paragraphs))
            sent_markers = np.where(chunk_input_ids == SENT_MARKER_END_token)[0]
            sent_positions.append(pad_and_drop_duplicates(
                start_positions=[1 if j in supp_sent_ends else 0 for j in sent_markers],
                max_num_answers=max_num_sentences))

        batch_question_type.append(question_type)
        batch_start_positions_list.append(start_positions_list)
        batch_end_positions_list.append(end_positions_list)
        batch_title_positions.append(title_positions)
        batch_sent_positions.append(sent_positions)

    inputs.update({"start_positions": batch_start_positions_list})
    inputs.update({"end_positions": batch_end_positions_list})
    inputs.update({'question_type': batch_question_type})
    inputs.update({'context_start_id': ex_context_start_id})
    inputs.update({'context_end_id': ex_context_end_id})
    inputs.update({"supp_title_labels": batch_title_positions})
    inputs.update({"supp_sent_labels": batch_sent_positions})

    # regroup the flat chunk lists into one list of chunks per example
    rearranged_inps = [[inputs['input_ids'][i] for i in idxs] for idxs in chunks_per_example]
    rearranged_masks = [[inputs['attention_mask'][i] for i in idxs] for idxs in chunks_per_example]
    inputs.update({'input_ids': rearranged_inps})
    inputs.update({'attention_mask': rearranged_masks})

    return inputs


# batch_size=1: the preprocessing sizes its window from mem_tokens[0] only, so one
# example per batch keeps that window per-example.
mem_train_dataset = mem_train_dataset.map(
    preprocess_roberta_training_examples,
    batched=True,
    batch_size=1,
    remove_columns=mem_train_dataset.column_names,
)

# re-attach the memory tokens (dropped by remove_columns above) and save
tmp = mem_train_dataset[:]
tmp.update({'mem_tokens': mem_list})
mem_train_dataset = Dataset.from_dict(tmp)
mem_train_dataset.save_to_disk('../hotpotqa_yake_mem_train_dataset')


In [None]:
# Memory strings for the HotpotQA validation set, tokenized with tokenize_mem.
data = Dataset.from_dict({'mem': torch.load("../yake_mem_strings_hotpotqa_val.pkl")})
mem_list = data.map(tokenize_mem,
                    batched=True,
                    remove_columns=data.column_names
                   )

In [None]:
# HotpotQA validation examples with their memory tokens attached
list_val_dataset = torch.load('../hotpotqa_val_examples_with_special_seps.pkl')

tmp = list_val_dataset[:]
# only the first len(list_val_dataset) memory rows are attached
# NOTE(review): slicing a datasets.Dataset returns a column dict, not a list of rows, and
# this assumes memory rows are in example order — confirm 'mem_tokens' is as intended
tmp.update({'mem_tokens': mem_list[:len(list_val_dataset)]})
mem_val_dataset = Dataset.from_dict(tmp)


def preprocess_roberta_validation_examples(examples):
    """Tokenize HotpotQA / 2WikiMHQA validation examples into RoBERTa chunks with evaluation metadata.

    Contexts are split into overlapping windows of ``max_length - len(mem_tokens[0])``
    tokens (``stride`` tokens of overlap). Besides per-chunk supporting-fact labels,
    it keeps ``offset_mapping`` with ``None`` for every non-context token. It also
    builds cross-chunk indices so the same title or sentence seen in overlapping
    chunks can be identified.

    Reads notebook globals: ``tokenizer``, ``max_length``, ``stride``, ``eos``,
    ``TITLE_START_token``, ``SENT_MARKER_END_token``, ``max_num_paragraphs``,
    ``max_num_sentences`` and ``pad_and_drop_duplicates``.

    Args:
        examples: batch mapping with ``question``, ``context``, ``id``,
            ``mem_tokens``, ``supp_title_char_offsets`` and
            ``supp_sent_char_offsets`` (lists of ``[start_char, end_char]`` spans).

    Returns:
        The tokenizer output regrouped to one row per example (``input_ids``,
        ``attention_mask`` and ``offset_mapping`` are lists of chunks), plus:

        * ``example_ids``, ``context_start_id``, ``context_end_id``;
        * ``supp_title_labels`` / ``supp_sent_labels``: per chunk, one 0/1 label
          per ``<t>`` / ``[/sent]`` marker, passed through ``pad_and_drop_duplicates``;
        * ``concat_titles_ids``: for every ``<t>`` marker across the example's
          chunks, the rank of its char offset among the distinct title offsets;
        * ``titles_to_sents``: per chunk, per distinct title, the
          ``(chunk-local sentence index, example-global sentence index)`` pairs
          of the sentences lying between that title and the next one.
    """
    def _context_bounds(sequence_ids):
        # first and last token index whose sequence id is 1 (the context segment)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        return context_start, idx - 1

    def _span_in_chunk(offset, context_start, context_end, start_char, end_char):
        # True when the char span lies entirely inside this chunk's context window
        return offset[context_start][0] <= start_char and offset[context_end][1] >= end_char

    def _start_token(offset, context_start, context_end, start_char):
        # last context token whose char start is <= start_char
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        return idx - 1

    def _end_token(offset, context_start, context_end, end_char):
        # leftmost token of the trailing run of context tokens whose char end is >= end_char
        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        return idx + 1

    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length - len(examples['mem_tokens'][0]),
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # raw example id of every chunk, to regroup chunks of the same context document
    raw_ids = np.array([examples['id'][kk] for kk in inputs['overflow_to_sample_mapping']])
    inputs.update({"example_ids": examples['id']})
    offset_mapping = inputs["offset_mapping"]
    sample_map = inputs.pop("overflow_to_sample_mapping")
    supp_sents = examples['supp_sent_char_offsets']
    supp_titles = examples['supp_title_char_offsets']
    ex_context_start_id = []
    ex_context_end_id = []
    batch_title_positions = []
    batch_sent_positions = []
    batch_titles_to_sents = []
    batch_concat_titles_ids = []
    chunks_per_example = []

    for ex_id in examples["id"]:
        ex_indices = np.where(raw_ids == ex_id)[0].tolist()
        chunks_per_example.append(ex_indices)
        # RoBERTa pair layout is <s> q </s></s> c </s>: the context starts after the 2nd </s>
        context_token_ids = [np.where(np.array(inputs['input_ids'][i]) == eos)[0] for i in ex_indices]
        ex_context_start_id.append([elem[1] + 1 for elem in context_token_ids])
        ex_context_end_id.append([elem[-1] for elem in context_token_ids])

        title_positions = []
        sent_positions = []
        for chunk_idx in ex_indices:
            # keep the raw offsets; the stored list is replaced by a masked copy just below
            offset = offset_mapping[chunk_idx]
            sample_idx = sample_map[chunk_idx]
            sequence_ids = inputs.sequence_ids(chunk_idx)
            inputs["offset_mapping"][chunk_idx] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]
            context_start, context_end = _context_bounds(sequence_ids)

            # supporting titles are matched by their first token, sentences by their last
            supp_title_starts = [
                _start_token(offset, context_start, context_end, span[0])
                for span in supp_titles[sample_idx]
                if _span_in_chunk(offset, context_start, context_end, span[0], span[1])
            ]
            supp_sent_ends = [
                _end_token(offset, context_start, context_end, span[1])
                for span in supp_sents[sample_idx]
                if _span_in_chunk(offset, context_start, context_end, span[0], span[1])
            ]

            # one 0/1 label per '<t>' marker and per '[/sent]' marker of this chunk
            chunk_input_ids = np.array(inputs['input_ids'][chunk_idx])
            title_markers = np.where(chunk_input_ids == TITLE_START_token)[0]
            title_positions.append(pad_and_drop_duplicates(
                start_positions=[1 if j in supp_title_starts else 0 for j in title_markers],
                max_num_answers=max_num_paragraphs))
            sent_markers = np.where(chunk_input_ids == SENT_MARKER_END_token)[0]
            sent_positions.append(pad_and_drop_duplicates(
                start_positions=[1 if j in supp_sent_ends else 0 for j in sent_markers],
                max_num_answers=max_num_sentences))

        concat_input_ids = []
        concat_offset_mapping = []
        for chunk_idx in ex_indices:
            concat_input_ids += inputs['input_ids'][chunk_idx]
            concat_offset_mapping += inputs['offset_mapping'][chunk_idx]
        concat_input_ids = np.array(concat_input_ids)

        # Index the offsets as a plain list: they mix tuples and None, and NumPy >= 1.24
        # refuses to build an array from such ragged input.
        title_offsets = [concat_offset_mapping[k]
                         for k in np.where(concat_input_ids == TITLE_START_token)[0]]
        titles_offsets_unique = sorted(np.unique(title_offsets, axis=0).tolist())
        concat_titles_ids = [titles_offsets_unique.index(list(off)) for off in title_offsets]

        # distinct sentence-end offsets across all chunks, in order of first appearance
        global_sents_offsets_unique = []
        for k in np.where(concat_input_ids == SENT_MARKER_END_token)[0]:
            sent_offset = concat_offset_mapping[k]
            if sent_offset not in global_sents_offsets_unique:
                global_sents_offsets_unique.append(sent_offset)

        titles_to_sents = []
        for chunk_idx in ex_indices:
            chunk_offsets = inputs['offset_mapping'][chunk_idx]
            sents = np.where(np.array(inputs['input_ids'][chunk_idx]) == SENT_MARKER_END_token)[0].tolist()
            sents_offsets = [chunk_offsets[k] for k in sents]

            chunk_titles_to_sents = []
            # sentences between consecutive titles: (chunk-local sent id, global sent id)
            if len(titles_offsets_unique) > 1:
                chunk_titles_to_sents = [
                    [(sent_id, global_sents_offsets_unique.index(sent_offset))
                     for sent_id, sent_offset in enumerate(sents_offsets)
                     if sent_offset[0] >= title_offset[-1] and sent_offset[-1] <= next_title_offset[0]]
                    for title_offset, next_title_offset in zip(titles_offsets_unique[:-1],
                                                               titles_offsets_unique[1:])
                ]
            # sentences after the last title
            if len(titles_offsets_unique) > 0:
                chunk_titles_to_sents.append(
                    [(sent_id, global_sents_offsets_unique.index(sent_offset))
                     for sent_id, sent_offset in enumerate(sents_offsets)
                     if sent_offset[0] >= titles_offsets_unique[-1][-1]])
            titles_to_sents.append(chunk_titles_to_sents)

        batch_concat_titles_ids.append(concat_titles_ids)
        batch_title_positions.append(title_positions)
        batch_sent_positions.append(sent_positions)
        batch_titles_to_sents.append(titles_to_sents)

    inputs.update({'context_start_id': ex_context_start_id})
    inputs.update({'context_end_id': ex_context_end_id})
    inputs.update({"supp_title_labels": batch_title_positions})
    inputs.update({"supp_sent_labels": batch_sent_positions})
    inputs.update({"titles_to_sents": batch_titles_to_sents})
    inputs.update({"concat_titles_ids": batch_concat_titles_ids})

    # regroup the flat chunk lists into one list of chunks per example
    rearranged_inps = [[inputs['input_ids'][i] for i in idxs] for idxs in chunks_per_example]
    rearranged_masks = [[inputs['attention_mask'][i] for i in idxs] for idxs in chunks_per_example]
    rearranged_offset_mapping = [[inputs['offset_mapping'][i] for i in idxs] for idxs in chunks_per_example]
    inputs.update({'input_ids': rearranged_inps})
    inputs.update({'attention_mask': rearranged_masks})
    inputs.update({'offset_mapping': rearranged_offset_mapping})

    return inputs

# batch_size=1: the preprocessing sizes its window from mem_tokens[0] only, so one
# example per batch keeps that window per-example.
mem_val_dataset = mem_val_dataset.map(
    preprocess_roberta_validation_examples,
    batched=True,
    batch_size=1,
    remove_columns=mem_val_dataset.column_names,
)
# re-attach the memory tokens (dropped by remove_columns above) and save
tmp = mem_val_dataset[:]
tmp.update({'mem_tokens': mem_list[:len(mem_val_dataset)]})
mem_val_dataset = Dataset.from_dict(tmp)
mem_val_dataset.save_to_disk('../hotpotqa_yake_mem_val_dataset')

### 2WikiMHQA

In [24]:
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
tokenizer = add_qa_evidence_tokens(tokenizer)

# Memory strings for the 2WikiMHQA training set, tokenized with tokenize_mem.
data = Dataset.from_dict({'mem': torch.load("../yake_mem_strings_2wikimhqa_train.pkl")})
mem_list = data.map(tokenize_mem,
                    batched=True,
                    remove_columns=data.column_names
                   )

list_train_dataset = torch.load('../2wikimhqa_train_examples_with_special_seps.pkl')

tmp = list_train_dataset[:]
tmp.update({'mem_tokens': mem_list})
mem_train_dataset = Dataset.from_dict(tmp)

# Dataset-specific limits and marker ids read by preprocess_roberta_training_examples.
max_num_answers = 64
max_num_paragraphs = 10 # same paragraph budget as the HotpotQA setup
max_num_sentences = 210 # checked from train and val data
eos = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) #'</s>'
TITLE_END_token = tokenizer.convert_tokens_to_ids('</t>')   # indicating the end of the title of a paragraph
# The preprocessing reads these two ids; derive them from this cell's tokenizer instead
# of silently reusing values left over from the HotpotQA cell (hidden notebook state).
TITLE_START_token = tokenizer.convert_tokens_to_ids('<t>')
SENT_MARKER_END_token = tokenizer.convert_tokens_to_ids('[/sent]')

# batch_size=1: the window size is taken from mem_tokens[0] of each batch
mem_train_dataset = mem_train_dataset.map(
    preprocess_roberta_training_examples,
    batched=True,
    batch_size=1,
    remove_columns=mem_train_dataset.column_names,
)

# re-attach the memory tokens (dropped by remove_columns above) and save
tmp = mem_train_dataset[:]
tmp.update({'mem_tokens': mem_list})
mem_train_dataset = Dataset.from_dict(tmp)
mem_train_dataset.save_to_disk('../2wikimhqa_yake_mem_train_dataset')


  0%|          | 0/167454 [00:00<?, ?ba/s]

In [None]:
# 2WikiMHQA validation: tokenize memory strings, attach them, preprocess and save.
# NOTE(review): preprocess_roberta_validation_examples reads the marker ids / limits
# (TITLE_START_token, SENT_MARKER_END_token, max_num_sentences, ...) set by earlier cells.
data = Dataset.from_dict({'mem': torch.load("../yake_mem_strings_2wikimhqa_val.pkl")})
mem_list = data.map(tokenize_mem,
                    batched=True,
                    remove_columns=data.column_names
                   )
list_val_dataset = torch.load('../2wikimhqa_val_examples_with_special_seps_with_mem_seps.pkl')

tmp = list_val_dataset[:]
# NOTE(review): slicing a datasets.Dataset returns a column dict, not a list of rows — confirm
tmp.update({'mem_tokens': mem_list[:len(list_val_dataset)]})
mem_val_dataset = Dataset.from_dict(tmp)

# batch_size=1: the window size is taken from mem_tokens[0] of each batch
mem_val_dataset = mem_val_dataset.map(
    preprocess_roberta_validation_examples,
    batched=True,
    batch_size=1,
    remove_columns=mem_val_dataset.column_names,
)
# re-attach the memory tokens (dropped by remove_columns above) and save
tmp = mem_val_dataset[:]
tmp.update({'mem_tokens': mem_list[:len(mem_val_dataset)]})
mem_val_dataset = Dataset.from_dict(tmp)
mem_val_dataset.save_to_disk('../2wikimhqa_yake_mem_val_dataset')