https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb#scrollTo=-gJOEe0Ye0di

In [None]:
# Install the third-party packages imported by the cells below.
!pip install nlp
!pip install transformers
!pip install datasets

In [5]:
## IMPORTS
import json

import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm.notebook import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import datasets

from transformers import (LongformerModel, LongformerTokenizer, LongformerPreTrainedModel,
                          LongformerConfig, Trainer, TrainingArguments, EarlyStoppingCallback)
from transformers.models.longformer.modeling_longformer import LongformerQuestionAnsweringModelOutput

In [6]:
import torch
import nlp
from transformers import LongformerTokenizerFast

# Load the pretrained fast tokenizer that matches the Longformer base checkpoint.
LONGFORMER_CHECKPOINT = 'allenai/longformer-base-4096'
tokenizer = LongformerTokenizerFast.from_pretrained(LONGFORMER_CHECKPOINT)

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/694 [00:00<?, ?B/s]

In [17]:
# Toy QA example used to walk through the preprocessing steps by hand.
question = "who is Ali?"
context = "Ali is an engineer."
answer = "an engineer"

# Character span of the answer in `context`: start is inclusive, end is exclusive.
start_idx = 7
end_idx = start_idx + len(answer)

In [18]:
# get_correct_alignement, unrolled on the toy example:
# try the annotated span first, then the same span moved 1 and 2 characters left.
gold_text = answer

for shift, trace in ((0, None), (1, "HERE"), (2, "THERE")):
    if context[start_idx - shift:end_idx - shift] == gold_text:
        if trace is not None:
            print(trace)  # shows which fallback offset matched
        start_idx, end_idx = start_idx - shift, end_idx - shift
        break
else:
    # None of the three candidate spans reproduces the gold text.
    raise ValueError()

In [19]:
# Character span of the answer inside `context` (end is exclusive).
start_idx, end_idx

(7, 18)

In [20]:
# Encode the context on its own, then the (question, context) pair as one input.
# The context-only encoding is used later to map answer characters to tokens.
context_encodings = tokenizer.encode_plus(context)
input_pairs = [question, context]
encodings = tokenizer.encode_plus(input_pairs)

In [21]:
# Compare the question+context pair encoding with the context-only encoding.
encodings, context_encodings

({'input_ids': [0, 8155, 16, 4110, 116, 2, 2, 37358, 16, 41, 8083, 4, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]},
 {'input_ids': [0, 37358, 16, 41, 8083, 4, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]})

In [22]:
# Map the answer's character span to token indices inside the context-only
# encoding, using the fast tokenizer's char_to_token alignment method.
# end_idx is exclusive, so the answer's last character sits at end_idx - 1.
start_positions_context, end_positions_context = (
    context_encodings.char_to_token(char_idx) for char_idx in (start_idx, end_idx - 1)
)

In [23]:
# Answer token span, indexed within the context-only encoding.
start_positions_context, end_positions_context

(3, 4)

In [24]:
# Shift the context-relative token span into the full question+context encoding.
# The pair is laid out as <s> question</s></s> context</s>: index() locates the
# first </s>, and the extra 1 accounts for the second one. Adding that offset to
# a position in the context-only encoding gives the position in the whole example.
last_sep_idx = 1 + encodings['input_ids'].index(tokenizer.sep_token_id)
start_positions, end_positions = (
    context_pos + last_sep_idx
    for context_pos in (start_positions_context, end_positions_context)
)

In [25]:
# Answer token span, indexed within the full question+context encoding.
start_positions, end_positions

(9, 10)

In [7]:
def get_correct_alignement(context, answer):
    """Return the character span ``(start, end)`` of the gold answer in ``context``.

    Some original examples in SQuAD have indices wrong by 1 or 2 characters.
    The annotated offset is tried first, then the offsets 1 and 2 characters
    before it; the first span whose text equals the gold answer is returned.

    Args:
        context: Passage text the answer was taken from.
        answer: Mapping with list-valued ``'text'`` and ``'answer_start'``
            entries; only the first element of each list is used.

    Returns:
        Tuple ``(start_idx, end_idx)`` with ``end_idx`` exclusive, such that
        ``context[start_idx:end_idx] == answer['text'][0]``.

    Raises:
        ValueError: If the gold text is not found at the annotated offset nor
            at 1 or 2 characters before it.
    """
    gold_text = answer['text'][0]
    annotated_start = answer['answer_start'][0]

    for shift in (0, 1, 2):
        start_idx = annotated_start - shift
        if start_idx < 0:
            # Python slicing wraps negative indices around to the end of the
            # string, so a negative start is never a usable character offset.
            break
        end_idx = start_idx + len(gold_text)
        if context[start_idx:end_idx] == gold_text:
            return start_idx, end_idx

    raise ValueError(
        f"Answer text {gold_text!r} not found in context at "
        f"answer_start={annotated_start} (also tried 1 and 2 characters earlier)."
    )

# Tokenize our training dataset
def convert_to_features(example, max_length=512):
    """Tokenize one QA example and attach the answer's token positions.

    The question and context are encoded as a pair laid out as
    ``<s> question</s></s> context</s>``, padded/truncated to ``max_length``.
    The answer's character span is mapped to token indices in a context-only
    encoding and then shifted into the pair encoding.

    Args:
        example: Mapping with ``'question'``, ``'context'`` and ``'answers'``
            entries; ``'answers'`` is passed to ``get_correct_alignement``.
        max_length: Length the pair encoding is padded/truncated to.
            Defaults to 512, the value that was previously hard-coded.

    Returns:
        The pair encoding, updated with ``'start_positions'`` and
        ``'end_positions'``. Both are set to 0 when the answer cannot be
        located inside the (possibly truncated) encoding.

    Raises:
        ValueError: Propagated from ``get_correct_alignement`` when the gold
            answer text cannot be found in the context.
    """
    # Tokenize contexts and questions (as pairs of inputs).
    # padding='max_length' replaces the deprecated pad_to_max_length flag, and
    # truncation is requested explicitly so over-long pairs are cut to max_length.
    input_pairs = [example['question'], example['context']]
    encodings = tokenizer.encode_plus(
        input_pairs, padding='max_length', truncation=True, max_length=max_length
    )
    context_encodings = tokenizer.encode_plus(example['context'])

    # Character span of the answer in the context (end exclusive), mapped to
    # token indices of the context-only encoding via the fast tokenizer.
    start_idx, end_idx = get_correct_alignement(example['context'], example['answers'])
    start_positions_context = context_encodings.char_to_token(start_idx)
    end_positions_context = context_encodings.char_to_token(end_idx - 1)

    if start_positions_context is None or end_positions_context is None:
        # char_to_token returns None when a character is not covered by any
        # token; fall back to (0, 0) instead of failing on `None + int` below.
        start_positions, end_positions = 0, 0
    else:
        # index() finds the first </s>; the extra 1 accounts for the second
        # </s>, so adding the offset to a context-encoding index (which starts
        # with its own <s>) gives the index in the whole example.
        context_offset = encodings['input_ids'].index(tokenizer.sep_token_id) + 1
        start_positions = start_positions_context + context_offset
        end_positions = end_positions_context + context_offset

        # Valid token indices are 0 .. max_length - 1, and a pair that had to be
        # truncated keeps the closing </s> in its last slot. An answer ending at
        # or beyond max_length - 1 was therefore cut off by truncation.
        if end_positions >= max_length - 1:
            start_positions, end_positions = 0, 0

    encodings.update({'start_positions': start_positions,
                      'end_positions': end_positions})
    return encodings