In [1]:
import json
import sys
from pathlib import Path

import torch
from datasets import Dataset, load_dataset
import pandas as pd
from tqdm.autonotebook import tqdm

def read_squad_files(path: str):
    """
    Load a SQuAD-formatted JSON file into a Huggingface datasets object.

    The nested ``data -> paragraphs -> qas -> answers`` layout is flattened so
    that every individual answer entry becomes one row with three columns:
    ``context``, ``question`` and ``answers`` (the answer entry as stored in
    the file). The result is similar to the one obtained from:

    ```
    from datasets import load_dataset
    load_dataset("squad")
    ```
    """
    with open(Path(path), 'rb') as handle:
        payload = json.load(handle)

    def _iter_rows():
        # Walk the file in document order, yielding one triple per answer.
        for article in payload['data']:
            for paragraph in article['paragraphs']:
                passage_text = paragraph['context']
                for qa in paragraph['qas']:
                    question_text = qa['question']
                    for answer in qa['answers']:
                        yield passage_text, question_text, answer

    triples = list(_iter_rows())

    # Build the columns explicitly so an empty file still yields all three.
    frame = pd.DataFrame({
        "context": [ctx for ctx, _, _ in triples],
        "question": [qst for _, qst, _ in triples],
        "answers": [ans for _, _, ans in triples],
    })
    return Dataset.from_pandas(frame)
    

    
from transformers import AutoTokenizer
# Load the tokenizer for the Longformer checkpoint; `tokenizer` is read as a
# module-level global by convert_to_features further down.
# NOTE(review): `do_lowercase` is not spelled like the usual `do_lower_case`
# tokenizer option — confirm whether lower-casing was intended and whether this
# keyword has any effect for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096', do_lowercase=True, use_fast=True)


# Read the SQuAD-formatted train/dev JSON files into Dataset objects
# (one row per answer, see read_squad_files).
# NOTE(review): the absolute /workspace paths tie this cell to one machine.
train_dataset = read_squad_files('/workspace/data/trivia_squad/squad-wikipedia-train-4096.json')
valid_dataset = read_squad_files('/workspace/data/trivia_squad/squad-wikipedia-dev-4096.json')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=694.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355863.0, style=ProgressStyle(descript…




In [2]:
def get_correct_alignement(context: str, answer):
    """Return the corrected ``(start, end)`` character span of the first gold answer.

    Some original examples in SQuAD have indices wrong by 1 or 2 characters.
    The annotated start is tried as-is, then shifted left by one and by two
    characters; the first shift whose slice of ``context`` equals the gold
    text is returned.

    Args:
        context: Passage text the answer was annotated on.
        answer: Mapping with list-valued ``"text"`` and ``"answer_start"``
            entries; only the first element of each list is used.

    Returns:
        Tuple ``(start_idx, end_idx)`` such that
        ``context[start_idx:end_idx] == gold_text``.

    Raises:
        ValueError: If the gold text is found neither at the annotated offset
            nor one or two characters to its left.
    """
    gold_text = answer["text"][0]
    start_idx = answer["answer_start"][0]
    end_idx = start_idx + len(gold_text)

    for shift in (0, 1, 2):
        shifted_start = start_idx - shift
        # A negative start would make the slice count from the END of the
        # context and could "match" unrelated text (e.g. context[-2:-1]),
        # yielding negative indices. Larger shifts only go further negative.
        if shifted_start < 0:
            break
        shifted_end = end_idx - shift
        if context[shifted_start:shifted_end] == gold_text:
            return shifted_start, shifted_end

    raise ValueError(
        f"Answer text {gold_text!r} not found at offset {start_idx} "
        "(or up to two characters before it) in the context"
    )


def add_triviaqa_end_idx(context, answer):
    """Compute the ``(start, end)`` character span of an answer inside ``context``.

    Sometimes SQuAD-style answer offsets are off by a character or two. The
    annotated start is tried as-is, then shifted left by one and by two
    characters; the first shift whose slice of ``context`` equals the answer
    text is returned.

    Args:
        context: Passage text the answer refers to.
        answer: Mapping with a scalar ``'text'`` string and a scalar
            ``'answer_start'`` character offset.

    Returns:
        Tuple ``(start_idx, end_idx)``. When no shift matches, the annotated
        start and ``start + len(text)`` are returned unchanged (best effort;
        no exception is raised).
    """
    gold_text = answer['text']
    start_idx = answer['answer_start']
    end_idx = start_idx + len(gold_text)

    for shift in (0, 1, 2):
        shifted_start = start_idx - shift
        # A negative start would make the slice count from the END of the
        # context and could "match" unrelated text (e.g. context[-2:-1]),
        # yielding negative indices. Larger shifts only go further negative.
        if shifted_start < 0:
            break
        shifted_end = end_idx - shift
        if context[shifted_start:shifted_end] == gold_text:
            return shifted_start, shifted_end

    # Nothing matched: fall back to the annotated span as given.
    return start_idx, end_idx


# Tokenize our training dataset
def convert_to_features(example):
    """Tokenize one example and attach answer-span token positions.

    The question and context are encoded as a pair, and the context is encoded
    once more on its own so the answer's character offsets can be mapped to
    context token indices with ``char_to_token``. Those indices are then
    shifted by the position of the first separator token so they refer to the
    pair encoding, which is laid out as ``<s> question</s></s> context</s>``.

    Reads the module-level ``tokenizer`` and calls ``add_triviaqa_end_idx``.

    Args:
        example: Mapping with ``"question"`` and ``"context"`` strings and an
            ``"answers"`` mapping holding scalar ``"text"`` / ``"answer_start"``.

    Returns:
        The pair encoding, updated in place with ``"start_positions"`` and
        ``"end_positions"``:

        * the answer's token indices when the span could be located;
        * ``0, 0`` when the span falls outside the tokens kept after truncation;
        * ``None`` for both when the character offsets could not be mapped to
          tokens or no separator token is present.
    """
    # Tokenize contexts and questions (as pairs of inputs). No max_length is
    # given, so truncation is left to the tokenizer's own configured limit.
    encodings = tokenizer.encode_plus(
        example["question"],
        example["context"],
        padding=True,
        truncation=True,
    )
    context_encodings = tokenizer.encode_plus(example["context"])

    # Locate the answer in the context text, then map its first and last
    # characters to token indices within the context-only encoding.
    start_idx, end_idx = add_triviaqa_end_idx(example["context"], example["answers"])
    start_positions_context = context_encodings.char_to_token(start_idx)
    end_positions_context = context_encodings.char_to_token(end_idx - 1)

    input_ids = encodings["input_ids"]
    try:
        # The context-only encoding starts with one special token while the
        # context in the pair encoding starts right after the two separator
        # tokens, so adding (index of first separator + 1) converts between them.
        sep_idx = input_ids.index(tokenizer.sep_token_id)
        start_positions = start_positions_context + sep_idx + 1
        end_positions = end_positions_context + sep_idx + 1
    except (TypeError, ValueError):
        # TypeError: char_to_token returned None for start or end.
        # ValueError: no separator token in the encoded input.
        # Anything else (including KeyboardInterrupt) now propagates.
        start_positions = None
        end_positions = None
    else:
        # The last kept token is the closing separator, so an answer ending at
        # or beyond it was cut off by truncation. Compare against the actual
        # encoded length instead of a fixed 512, which would discard valid
        # answers whenever the tokenizer keeps more than 512 tokens.
        if end_positions >= len(input_ids) - 1:
            start_positions, end_positions = 0, 0

    # NOTE(review): rows labelled None may need filtering before the dataset is
    # formatted as torch tensors — confirm downstream handling.
    encodings.update(
        {
            "start_positions": start_positions,
            "end_positions": end_positions,
        }
    )
    return encodings

In [None]:
import time

# Run convert_to_features over every training row and report the wall-clock
# duration in seconds.
# Note: `train_dataset` is rebound to the mapped result, so re-running this
# cell maps the already-mapped dataset again.
start = time.time()
train_dataset = train_dataset.map(convert_to_features)
print(f"Took: {time.time() - start}")

HBox(children=(FloatProgress(value=0.0, max=1851594.0), HTML(value='')))

In [None]:
# Same pass for the validation split; relies on `time` having been imported by
# the training cell and reuses the `start` variable.
start = time.time()
valid_dataset = valid_dataset.map(convert_to_features)
print(f"Took: {time.time() - start}")

In [None]:
# Set the tensor type and restrict the columns the dataset returns to the
# model inputs and answer-position labels.
# NOTE(review): convert_to_features can store None for start/end positions —
# confirm the torch format handles such rows or filter them first.
columns = ['input_ids', 'attention_mask', 'start_positions', 'end_positions']
train_dataset.set_format(type='torch', columns=columns)

In [None]:
# Apply the same torch format to the validation split; relies on `columns`
# defined in the training-format cell.
valid_dataset.set_format(type='torch', columns=columns) 

In [None]:
# Scratch experiment with tqdm progress bars; not used by the data pipeline.
# Note: this import rebinds `tqdm`, which the first cell imported from
# tqdm.autonotebook.
from tqdm import tqdm, trange

a=range(int(1e8))

#method 3
# `map` is lazy: nothing is iterated (and the bar does not advance) until `b`
# is consumed.
b = map(str, tqdm(a))

In [None]:
# or: build a tqdm-wrapped range directly. The object is only constructed and
# displayed here, never iterated.
trange(int(1e7), leave=True)

In [12]:
# Number of rows in the flattened training set (one row per answer entry).
len(train_dataset)

1851594

In [13]:
import datasets

In [18]:
# Load the reference SQuAD training split for a size comparison.
train = datasets.load_dataset("squad", split="train")

Reusing dataset squad (/.cache/huggingface/datasets/squad/plain_text/1.0.0/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41)


In [19]:
# Display the reference dataset's features and row count.
train

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
})

In [20]:
# Ratio of rows in the flattened training set (1851594) to rows in the SQuAD
# train split (87599).
1851594/87599

21.137159099989727