In [None]:
# Download the data from https://www.dropbox.com/s/npidmtadreo6df2/data.zip and make sure that
# ./2wikimultihopqa/train.json and ./2wikimultihopqa/dev.json exist before running this cell.
import json

# Context managers close the file handles deterministically instead of leaving
# them to the garbage collector; the explicit encoding keeps the read
# independent of the platform default.
with open('./2wikimultihopqa/train.json', 'r', encoding='utf-8') as train_file:
    train = json.load(train_file)
with open('./2wikimultihopqa/dev.json', 'r', encoding='utf-8') as dev_file:
    dev = json.load(dev_file)

In [None]:
from prepro_char_based_targets import process_file
from datasets.arrow_dataset import Dataset


# Turn the raw training split into a sequence of per-example records.
train_examples = process_file(train, with_special_seps=True, dataset_name='2wikimhqa')

# Transpose the records into one column per feature (feature names are taken
# from the first record) and wrap the columns in a `Dataset`.
train_examples = Dataset.from_dict(
    {
        feature: [example[feature] for example in train_examples]
        for feature in train_examples[0]
    }
)


In [None]:
import torch
# Persist the processed training split so later cells can reload it without re-running `process_file`.
torch.save(train_examples, '../2wikimhqa_train_examples_with_special_seps.pkl')

# Same pipeline for the dev split: build the records, transpose them into
# per-feature columns, wrap them in a `Dataset` and save it next to the train file.
val_examples = process_file(dev, with_special_seps=True, dataset_name='2wikimhqa')
val_examples = Dataset.from_dict(
    {
        feature: [example[feature] for example in val_examples]
        for feature in val_examples[0]
    }
)
torch.save(val_examples, '../2wikimhqa_val_examples_with_special_seps.pkl')

In [1]:
# Make the `gemformer` package importable: it is imported below as a top-level
# package, so the parent of the current working directory is put on sys.path.
import sys
from pathlib import Path

# Guarded so that re-running this cell does not keep appending duplicates.
if str(Path.cwd().parent) not in sys.path:
    sys.path.append(str(Path.cwd().parent))

In [2]:
import os
import json
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import RobertaTokenizerFast, default_data_collator
from torch.utils.data import DataLoader
from datasets.arrow_dataset import Dataset
from datasets import load_from_disk

from gemformer.utils import add_qa_evidence_tokens, pad_and_drop_duplicates, ROBERTA_BASE_SPECIAL_TOKENS

# Load the pretrained fast RoBERTa tokenizer and pass it through
# `add_qa_evidence_tokens` (from gemformer.utils), keeping the returned tokenizer.
tokenizer_name = 'roberta-base'
tokenizer = add_qa_evidence_tokens(RobertaTokenizerFast.from_pretrained(tokenizer_name))

2023-11-29 16:15:02.278868: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


initial vocab len = 50265
We have added 5 tokens
final vocab len = 50270


In [None]:
# The file is a pickled `datasets.Dataset` written with torch.save earlier in this
# notebook, not a tensor checkpoint. PyTorch >= 2.6 defaults to weights_only=True,
# which refuses to unpickle such objects, so the full unpickler is requested explicitly.
# NOTE: unpickling can execute arbitrary code - only load files you created yourself.
train_examples = torch.load('../2wikimhqa_train_examples_with_special_seps.pkl', weights_only=False)

# Windowing / label-size settings shared by the train and validation preprocessing.
stride = 20  # passed to the tokenizer as `stride` (overlap between consecutive windows)
max_num_answers = 64  # size argument for pad_and_drop_duplicates on answer positions
max_num_paragraphs = 10 # by dataset construction
max_num_sentences = 210 # from train and val data
eos = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) #'</s>'
TITLE_END_token = tokenizer.convert_tokens_to_ids('</t>') # indicating the end of the title of a paragraph
TITLE_START_token = tokenizer.convert_tokens_to_ids('<t>')
SENT_MARKER_END_token = tokenizer.convert_tokens_to_ids('[/sent]')
MAX_SEQ_LEN = 512  # passed to the tokenizer as `max_length` for every window

def preprocess_roberta_long_training_examples(train_examples):
    """Tokenize QA examples into overlapping windows and attach training labels.

    Every question/context pair is encoded with the module-level ``tokenizer``.
    Only the context is truncated (``truncation="only_second"``); a context
    that does not fit into ``MAX_SEQ_LEN`` tokens overflows into additional
    windows that overlap by ``stride`` tokens. The windows of one example are
    then grouped together, so every returned field has one entry per example
    and each entry is a list with one item per window.

    Args:
        train_examples: Batch of columns providing ``id``, ``question``,
            ``context``, ``char_answer_offsets``, ``supp_sent_char_offsets``
            and ``supp_title_char_offsets``. Each offsets column is read as,
            per example, a sequence of ``(start_char, end_char)`` pairs into
            the context string.

    Returns:
        The tokenizer output with ``offset_mapping`` and
        ``overflow_to_sample_mapping`` removed and these fields set:

        * ``input_ids`` / ``attention_mask``: regrouped per example.
        * ``id``: the example ids.
        * ``start_positions`` / ``end_positions``: per window, the token
          indices of the answer spans lying completely inside the window's
          context, passed through ``pad_and_drop_duplicates`` with
          ``max_num_answers``.
        * ``question_type``: per window, 0 if the first answer offset pair is
          ``(-1, -1)``, 1 if it is ``(-2, -2)``, 2 otherwise
          (yes = 0, no = 1, span = 2).
        * ``context_start_id``: per window, the index right after the second
          ``eos`` token.
        * ``context_end_id``: per window, the index of the last ``eos`` token.
        * ``supp_title_labels``: per window, one 0/1 flag per
          ``TITLE_START_token`` occurrence (1 if a supporting-title span
          starts at that token), passed through ``pad_and_drop_duplicates``
          with ``max_num_paragraphs``.
        * ``supp_sent_labels``: per window, one 0/1 flag per
          ``SENT_MARKER_END_token`` occurrence (1 if a supporting-sentence
          span ends at that token), passed through ``pad_and_drop_duplicates``
          with ``max_num_sentences``.
    """

    def _char_span_to_token_span(offset, context_start, context_end, start_char, end_char):
        """Return (start_token, end_token) for a char span, or None if it is not fully inside the window's context."""
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            return None
        # Last context token whose first character is at or before start_char.
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        token_start = idx - 1
        # First context token whose last character is at or after end_char.
        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        return token_start, idx + 1

    def _marker_labels(window_input_ids, positive_positions, marker_token, max_len):
        """Flag each occurrence of marker_token (1 if its index is in positive_positions), then pad the flags."""
        flags = [
            1 if position in positive_positions else 0
            for position in np.where(np.array(window_input_ids) == marker_token)[0]
        ]
        return pad_and_drop_duplicates(start_positions=flags, max_num_answers=max_len)

    questions = [q.strip() for q in train_examples["question"]]
    inputs = tokenizer(
        questions,
        train_examples['context'],
        max_length=MAX_SEQ_LEN,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Id of the source example of every window; converted to an array once so
    # the per-example lookups below do not rebuild it.
    raw_ids = np.array([train_examples['id'][kk] for kk in inputs['overflow_to_sample_mapping']])
    inputs.update({"id": train_examples['id']})
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = train_examples['char_answer_offsets']
    supp_sents = train_examples['supp_sent_char_offsets']
    supp_titles = train_examples['supp_title_char_offsets']

    batch_question_type = []  # yes = 0, no = 1, span = 2
    # Context boundaries per window, kept to filter the sequence for uncertainty-based topk.
    ex_context_start_id = []
    ex_context_end_id = []
    batch_start_positions_list = []
    batch_end_positions_list = []
    batch_title_positions = []
    batch_sent_positions = []
    rearranged_inps = []
    rearranged_masks = []

    for ex_id in train_examples["id"]:
        # Indices of all windows that were cut from this example.
        ex_indices = np.where(raw_ids == ex_id)[0].tolist()

        eos_positions = [np.where(np.array(inputs['input_ids'][i]) == eos)[0] for i in ex_indices]
        ex_context_start_id.append([elem[1] + 1 for elem in eos_positions])
        ex_context_end_id.append([elem[-1] for elem in eos_positions])

        start_positions_list = []
        end_positions_list = []
        question_type = []
        title_positions = []
        sent_positions = []

        for window_idx in ex_indices:
            offset = offset_mapping[window_idx]
            sample_idx = sample_map[window_idx]
            sequence_ids = inputs.sequence_ids(window_idx)

            # Find the first and last token that belong to the context (sequence id 1).
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # The first answer offset pair encodes the question type through sentinel values.
            ques_type_start_char = answers[sample_idx][0][0]
            ques_type_end_char = answers[sample_idx][0][1]
            if ques_type_start_char == -1 and ques_type_end_char == -1:
                question_type.append(0)
            elif ques_type_start_char == -2 and ques_type_end_char == -2:
                question_type.append(1)
            else:
                question_type.append(2)

            # Answer spans that lie completely inside this window.
            start_positions = []
            end_positions = []
            for answer in answers[sample_idx]:
                span = _char_span_to_token_span(offset, context_start, context_end, answer[0], answer[1])
                if span is not None:
                    start_positions.append(span[0])
                    end_positions.append(span[1])
            start_positions, end_positions = pad_and_drop_duplicates(start_positions, end_positions, max_num_answers)
            start_positions_list.append(start_positions)
            end_positions_list.append(end_positions)

            # Supporting titles are matched against TITLE_START_token through
            # their start token, so only the start index is kept.
            supp_title_starts = []
            for supp_title in supp_titles[sample_idx]:
                span = _char_span_to_token_span(offset, context_start, context_end, supp_title[0], supp_title[1])
                if span is not None:
                    supp_title_starts.append(span[0])

            # Supporting sentences are matched against SENT_MARKER_END_token
            # through their end token, so only the end index is kept.
            supp_sent_ends = []
            for supp_sent in supp_sents[sample_idx]:
                span = _char_span_to_token_span(offset, context_start, context_end, supp_sent[0], supp_sent[1])
                if span is not None:
                    supp_sent_ends.append(span[1])

            window_input_ids = inputs['input_ids'][window_idx]
            title_positions.append(
                _marker_labels(window_input_ids, supp_title_starts, TITLE_START_token, max_num_paragraphs)
            )
            sent_positions.append(
                _marker_labels(window_input_ids, supp_sent_ends, SENT_MARKER_END_token, max_num_sentences)
            )

        batch_question_type.append(question_type)
        batch_start_positions_list.append(start_positions_list)
        batch_end_positions_list.append(end_positions_list)
        batch_title_positions.append(title_positions)
        batch_sent_positions.append(sent_positions)

        # Group the flat window lists by example (one list of windows per example).
        rearranged_inps.append([inputs['input_ids'][i] for i in ex_indices])
        rearranged_masks.append([inputs['attention_mask'][i] for i in ex_indices])

    inputs.update({"start_positions": batch_start_positions_list})
    inputs.update({"end_positions": batch_end_positions_list})
    inputs.update({'question_type': batch_question_type})
    inputs.update({'context_start_id': ex_context_start_id})
    inputs.update({'context_end_id': ex_context_end_id})
    inputs.update({"supp_title_labels": batch_title_positions})
    inputs.update({"supp_sent_labels": batch_sent_positions})
    inputs.update({'input_ids': rearranged_inps})
    inputs.update({'attention_mask': rearranged_masks})

    return inputs


# One example per call (batch_size=1); the original columns are dropped so only
# the fields returned by the preprocessing function remain.
mem_train_dataset = train_examples.map(
    preprocess_roberta_long_training_examples,
    batched=True,
    batch_size=1,
    remove_columns=train_examples.column_names,
)

# The window length and stride in the directory name are taken from the settings
# actually used above, so the name cannot drift from the data it labels
# (with MAX_SEQ_LEN = 512 and stride = 20 this is the same path as before).
mem_train_dataset.save_to_disk(
    f'../2wikimhqa_preprocessed_train_examples_{MAX_SEQ_LEN}_multitask_stride{stride}'
    '_one_doc_batched_without_zero_answer_pos_without_CoT_triplets'
)

In [None]:
# The file is a pickled `datasets.Dataset` written with torch.save earlier in this
# notebook, not a tensor checkpoint. PyTorch >= 2.6 defaults to weights_only=True,
# which refuses to unpickle such objects, so the full unpickler is requested explicitly.
# NOTE: unpickling can execute arbitrary code - only load files you created yourself.
val_examples = torch.load('../2wikimhqa_val_examples_with_special_seps.pkl', weights_only=False)

def preprocess_longformer_validation_examples(examples):
    """Tokenize QA examples into overlapping windows and attach evaluation metadata.

    Every question/context pair is encoded with the module-level ``tokenizer``
    (the same tokenizer and window settings as the training preprocessing).
    Only the context is truncated; a context longer than ``MAX_SEQ_LEN``
    tokens overflows into additional windows overlapping by ``stride`` tokens.
    The windows of one example are grouped together, so every returned field
    has one entry per example.

    Args:
        examples: Batch of columns providing ``id``, ``question``,
            ``context``, ``supp_sent_char_offsets`` and
            ``supp_title_char_offsets``. Each offsets column is read as, per
            example, a sequence of ``(start_char, end_char)`` pairs into the
            context string.

    Returns:
        The tokenizer output with ``overflow_to_sample_mapping`` removed and
        these fields set:

        * ``input_ids`` / ``attention_mask`` / ``offset_mapping``: regrouped
          per example (a list of windows each). In ``offset_mapping`` every
          token outside the context is replaced by ``None``.
        * ``example_ids``: the example ids.
        * ``context_start_id``: per window, the index right after the second
          ``eos`` token.
        * ``context_end_id``: per window, the index of the last ``eos`` token.
        * ``supp_title_labels`` / ``supp_sent_labels``: per window, 0/1 flags
          for each ``TITLE_START_token`` / ``SENT_MARKER_END_token``
          occurrence, passed through ``pad_and_drop_duplicates`` with
          ``max_num_paragraphs`` / ``max_num_sentences``.
        * ``concat_titles_ids``: for every ``TITLE_START_token`` occurrence in
          the concatenation of the example's windows, the rank of its
          character offset among the sorted distinct title offsets (the same
          title seen in two overlapping windows gets the same rank).
        * ``titles_to_sents``: per window, one list per distinct title holding
          ``(local_sentence_index, global_sentence_index)`` pairs for the
          sentence markers located after that title and before the next one.
          The local index counts markers inside the window; the global index
          is the position in the example-wide list of distinct marker offsets.
    """

    def _char_span_to_token_span(offset, context_start, context_end, start_char, end_char):
        """Return (start_token, end_token) for a char span, or None if it is not fully inside the window's context."""
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            return None
        # Last context token whose first character is at or before start_char.
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        token_start = idx - 1
        # First context token whose last character is at or after end_char.
        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        return token_start, idx + 1

    def _marker_labels(window_input_ids, positive_positions, marker_token, max_len):
        """Flag each occurrence of marker_token (1 if its index is in positive_positions), then pad the flags."""
        flags = [
            1 if position in positive_positions else 0
            for position in np.where(np.array(window_input_ids) == marker_token)[0]
        ]
        return pad_and_drop_duplicates(start_positions=flags, max_num_answers=max_len)

    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=MAX_SEQ_LEN,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Store the raw sample id of every window to match the windows that were
    # cut from the same big context document.
    raw_ids = np.array([examples['id'][kk] for kk in inputs['overflow_to_sample_mapping']])
    inputs.update({"example_ids": examples['id']})
    offset_mapping = inputs["offset_mapping"]
    sample_map = inputs.pop("overflow_to_sample_mapping")
    supp_sents = examples['supp_sent_char_offsets']
    supp_titles = examples['supp_title_char_offsets']

    ex_context_start_id = []
    ex_context_end_id = []
    batch_title_positions = []
    batch_sent_positions = []
    batch_titles_to_sents = []
    batch_concat_titles_ids = []

    for ex_id in examples["id"]:
        # Indices of all windows that were cut from this example.
        ex_indices = np.where(raw_ids == ex_id)[0].tolist()

        eos_positions = [np.where(np.array(inputs['input_ids'][i]) == eos)[0] for i in ex_indices]
        ex_context_start_id.append([elem[1] + 1 for elem in eos_positions])
        ex_context_end_id.append([elem[-1] for elem in eos_positions])

        title_positions = []
        sent_positions = []

        for window_idx in ex_indices:
            # `offset` keeps the unmasked offsets; the masked copy stored just
            # below is what ends up in the output.
            offset = offset_mapping[window_idx]
            sample_idx = sample_map[window_idx]
            sequence_ids = inputs.sequence_ids(window_idx)

            inputs["offset_mapping"][window_idx] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

            # Find the first and last token that belong to the context (sequence id 1).
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # Supporting titles are matched against TITLE_START_token through
            # their start token, so only the start index is kept.
            supp_title_starts = []
            for supp_title in supp_titles[sample_idx]:
                span = _char_span_to_token_span(offset, context_start, context_end, supp_title[0], supp_title[1])
                if span is not None:
                    supp_title_starts.append(span[0])

            # Supporting sentences are matched against SENT_MARKER_END_token
            # through their end token, so only the end index is kept.
            supp_sent_ends = []
            for supp_sent in supp_sents[sample_idx]:
                span = _char_span_to_token_span(offset, context_start, context_end, supp_sent[0], supp_sent[1])
                if span is not None:
                    supp_sent_ends.append(span[1])

            window_input_ids = inputs['input_ids'][window_idx]
            title_positions.append(
                _marker_labels(window_input_ids, supp_title_starts, TITLE_START_token, max_num_paragraphs)
            )
            sent_positions.append(
                _marker_labels(window_input_ids, supp_sent_ends, SENT_MARKER_END_token, max_num_sentences)
            )

        # Concatenate all windows of the example (with the masked offsets).
        concat_input_ids = []
        concat_offset_mapping = []
        for ii in ex_indices:
            concat_input_ids += inputs['input_ids'][ii]
            concat_offset_mapping += inputs['offset_mapping'][ii]

        # The masked offset lists mix None with (start, end) tuples. Building a
        # NumPy array from such a ragged list raises ValueError on NumPy >= 1.24,
        # so the Python lists are indexed directly; the selected items are the
        # same tuples either way.
        titles = np.where(np.array(concat_input_ids) == TITLE_START_token)[0]
        titles_offsets = [concat_offset_mapping[position] for position in titles]
        titles_offsets_unique = sorted(np.unique(titles_offsets, axis=0).tolist())
        concat_titles_ids = [titles_offsets_unique.index(list(title_offset)) for title_offset in titles_offsets]

        # Distinct sentence-marker offsets of the whole example, in order of first appearance.
        global_sents = np.where(np.array(concat_input_ids) == SENT_MARKER_END_token)[0]
        global_sents_offsets_unique = []
        for position in global_sents:
            sent_offset = concat_offset_mapping[position]
            if sent_offset not in global_sents_offsets_unique:
                global_sents_offsets_unique.append(sent_offset)

        titles_to_sents = []
        for ii in ex_indices:
            window_offsets = inputs['offset_mapping'][ii]
            sents = np.where(np.array(inputs['input_ids'][ii]) == SENT_MARKER_END_token)[0].tolist()
            sents_offsets = [window_offsets[position] for position in sents]

            per_title = []
            if len(titles_offsets_unique) > 1:
                # Pairs of (local chunk sent_id, global doc sent_id) for the
                # markers between the end of one title and the start of the next.
                per_title = [
                    [
                        (sent_id, global_sents_offsets_unique.index(sent_offset))
                        for sent_id, sent_offset in enumerate(sents_offsets)
                        if (sent_offset[0] >= title_offset[-1])
                        and (sent_offset[-1] <= titles_offsets_unique[title_rank + 1][0])
                    ]
                    for title_rank, title_offset in enumerate(titles_offsets_unique[:-1])
                ]

            if len(titles_offsets_unique) > 0:
                # The last title owns every marker located after it.
                per_title.append(
                    [
                        (sent_id, global_sents_offsets_unique.index(sent_offset))
                        for sent_id, sent_offset in enumerate(sents_offsets)
                        if sent_offset[0] >= titles_offsets_unique[-1][-1]
                    ]
                )
            titles_to_sents.append(per_title)

        batch_concat_titles_ids.append(concat_titles_ids)
        batch_title_positions.append(title_positions)
        batch_sent_positions.append(sent_positions)
        batch_titles_to_sents.append(titles_to_sents)

    inputs.update({'context_start_id': ex_context_start_id})
    inputs.update({'context_end_id': ex_context_end_id})
    inputs.update({"supp_title_labels": batch_title_positions})
    inputs.update({"supp_sent_labels": batch_sent_positions})
    inputs.update({"titles_to_sents": batch_titles_to_sents})
    inputs.update({"concat_titles_ids": batch_concat_titles_ids})

    # Group the flat window lists by example (one list of windows per example).
    rearranged_inps = []
    rearranged_masks = []
    rearranged_offset_mapping = []

    for ex_id in examples["id"]:
        ex_indices = np.where(raw_ids == ex_id)[0].tolist()
        rearranged_inps.append([inputs['input_ids'][i] for i in ex_indices])
        rearranged_masks.append([inputs['attention_mask'][i] for i in ex_indices])
        rearranged_offset_mapping.append([inputs['offset_mapping'][i] for i in ex_indices])

    inputs.update({'input_ids': rearranged_inps})
    inputs.update({'attention_mask': rearranged_masks})
    inputs.update({'offset_mapping': rearranged_offset_mapping})

    return inputs

# One example per call (batch_size=1); the original columns are dropped so only
# the fields returned by the preprocessing function remain.
mem_val_dataset = val_examples.map(
    preprocess_longformer_validation_examples,
    batched=True,
    batch_size=1,
    remove_columns=val_examples.column_names,
)

# The window length and stride in the directory name are taken from the settings
# actually used by the preprocessing, so the name cannot drift from the data it
# labels (with MAX_SEQ_LEN = 512 and stride = 20 this is the same path as before).
mem_val_dataset.save_to_disk(
    f'../2wikimhqa_preprocessed_val_examples_{MAX_SEQ_LEN}_multitask_stride{stride}'
    '_one_doc_batched_without_zero_answer_pos_without_CoT_triplets'
)