In [4]:
from transformers import pipeline

# Open and read the article
question = "What is the capital of the Netherlands?"
context = r"The four largest cities in the Netherlands are Amsterdam, Rotterdam, The Hague and Utrecht.[17] Amsterdam is the country's most populous city and nominal capital,[18] while The Hague holds the seat of the States General, Cabinet and Supreme Court.[19] The Port of Rotterdam is the busiest seaport in Europe, and the busiest in any country outside East Asia and Southeast Asia, behind only China and Singapore."

# Generating an answer to the question in context
qa = pipeline("question-answering")
answer = qa(question=question, context=context)

# Print the answer
print(f"Question: {question}")
print(f"Answer: '{answer['answer']}' with score {answer['score']}")

No model was supplied, defaulted to distilbert-base-cased-distilled-squad (https://huggingface.co/distilbert-base-cased-distilled-squad)


Question: What is the capital of the Netherlands?
Answer: 'Amsterdam' with score 0.37749919295310974


# Intro

https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/longformer/modeling_longformer.py \\
https://huggingface.co/docs/transformers/model_doc/longformer#transformers.LongformerTokenizer

In [None]:
!pip install datasets
!pip install transformers

In [5]:
## IMPORTS

import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import datasets

from transformers import (LongformerModel, LongformerTokenizer, LongformerPreTrainedModel,
                          LongformerConfig, Trainer, TrainingArguments, EarlyStoppingCallback)
from transformers.models.longformer.modeling_longformer import LongformerQuestionAnsweringModelOutput

In [6]:
## HELPER FUNCTIONS

def _get_question_end_index(input_ids, sep_token_id):
    """
    Computes the index of the first occurrence of `sep_token_id`.
    """

    sep_token_indices = (input_ids == sep_token_id).nonzero()
    batch_size = input_ids.shape[0]

    assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
    assert (
        sep_token_indices.shape[0] == 3 * batch_size
    ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
    return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]


def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
    """
    Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is
    True` else after `sep_token_id`.
    """
    question_end_index = _get_question_end_index(input_ids, sep_token_id)
    question_end_index = question_end_index.unsqueeze(dim=1)  # size: batch_size x 1
    # bool attention mask with True in locations of global attention
    attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
    if before_sep_token is True:
        attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8)
    else:
        # last token is separation token and should not be counted and in the middle are two separation tokens
        attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * (
            attention_mask.expand_as(input_ids) < input_ids.shape[-1]
        ).to(torch.uint8)

    return attention_mask

# Loading Data

In [8]:
from torch.utils.data import Dataset, DataLoader


class CoQADataset(Dataset):
    def __init__(self, data):
        self.data = data
 
    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


coqa_data = datasets.load_dataset("coqa")

dataset_train = CoQADataset(coqa_data["train"])
dataset_val = CoQADataset(coqa_data['validation'])

loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True) # , collate_fn=CoQADataset.collate_fn
loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=True) # , collate_fn=CoQADataset.collate_fn

Downloading:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/766 [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset coqa/default (download: 55.40 MiB, generated: 18.35 MiB, post-processed: Unknown size, total: 73.75 MiB) to /root/.cache/huggingface/datasets/coqa/default/1.0.0/553ce70bfdcd15ff4b5f4abc4fc2f37137139cde1f58f4f60384a53a327716f0...


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/49.0M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

  0%|          | 0/2 [00:00<?, ?it/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset coqa downloaded and prepared to /root/.cache/huggingface/datasets/coqa/default/1.0.0/553ce70bfdcd15ff4b5f4abc4fc2f37137139cde1f58f4f60384a53a327716f0. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

NameError: ignored

# Model

In [None]:
class MyLongformerForQuestionAnswering(LongformerPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

        self.longformer = LongformerModel.from_pretrained('allenai/longformer-base-4096', gradient_checkpointing=True)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        head_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if global_attention_mask is None:
            if input_ids is None:
                logger.warning(
                    "It is not possible to automatically generate the `global_attention_mask` because input_ids is None. Please make sure that it is correctly set."
                )
            else:
                # set global attention on question tokens automatically
            global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
 
        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            head_mask=head_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return LongformerQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            global_attentions=outputs.global_attentions,
        )

    def predict_answer(question, text):
        encoding = self.tokenizer(question, text, return_tensors="pt")
        start_positions, end_positions = torch.tensor([1]), torch.tensor([3])

        outputs = self(**encoding, start_positions=start_positions, end_positions=end_positions)

        start_logits = outputs.start_logits
        end_logits = outputs.end_logits
        all_tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())

        answer_tokens = all_tokens[torch.argmax(start_logits) :torch.argmax(end_logits)+1]
        answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepending space token
        print(outputs)

In [1]:
def test_my_longformer_for_question_answering():
    tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    model = MyLongformerForQuestionAnswering.from_pretrained('allenai/longformer-base-4096', gradient_checkpointing=True)

    question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    encoding = tokenizer(question, text, return_tensors="pt")
    start_positions, end_positions = torch.tensor([1]), torch.tensor([3])

    outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)

    start_logits = outputs.start_logits
    end_logits = outputs.end_logits
    all_tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())

    answer_tokens = all_tokens[torch.argmax(start_logits) :torch.argmax(end_logits)+1]
    answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepending space token
    print(outputs)

# Train

In [None]:
def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    _ = precision_recall_fscore_support(y_true=labels, y_pred=pred)
    return {"accuracy": accuracy, "precision": _["precision"], "recall": _["recall"], "f1": _["fbeta_score"]}

In [None]:
class Trainer:
    def __init__(self, x):
        pass

In [None]:
## TRAINING

training_args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    seed=0,
    load_best_model_at_end=True,
    disable_tqdm = False, 
    warmup_steps=200,
    # dataloader_num_workers = 0,
    # run_name = 'longformer-classification-updated-rtx3090_paper_replication_2_warm'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Train pre-trained model
trainer.train()

In [None]:
# save the best model
trainer.save_model('/media/data_files/github/website_tutorials/results/paper_replication_lr_warmup200')
trainer.evaluate()

# Predicting

In [None]:
import numpy as np


def get_top_answers(possible_starts, possible_ends, input_ids):
  answers = []
  for start,end in zip(possible_starts, possible_ends):
    #+1 for end
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[start:end+1]))
    answers.append( answer )
  return answers


def answer_question(question, context, top_n):

    inputs = tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="pt")

    input_ids = inputs["input_ids"].tolist()[0]

    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    model_out = model(**inputs)

    answer_start_scores = model_out["start_logits"]
    answer_end_scores = model_out["end_logits"]

    possible_starts = np.argsort(answer_start_scores.cpu().detach().numpy()).flatten()[::-1][:top_n]
    possible_ends = np.argsort(answer_end_scores.cpu().detach().numpy()).flatten()[::-1][:top_n]
    
    #get best answer
    answer_start = torch.argmax(answer_start_scores)  
    answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score

    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

    answers = get_top_answers(possible_starts, possible_ends, input_ids )

    return { "answer":answer, "answer_start":answer_start, "answer_end":answer_end, "input_ids":input_ids,
            "answer_start_scores":answer_start_scores, "answer_end_scores":answer_end_scores, "inputs":inputs, "answers":answers,
            "possible_starts":possible_starts, "possible_ends":possible_ends}

In [None]:
answer

''

In [None]:
outputs

LongformerQuestionAnsweringModelOutput([('start_logits',
                                         tensor([[-0.0883, -0.4564, -0.1797, -0.1296, -0.2541, -0.2960, -0.2346, -0.0994,
                                                  -0.1706, -0.2088, -0.2648, -0.2898, -0.1920, -0.1234, -0.0849, -0.1439,
                                                  -0.1699]], grad_fn=<CloneBackward0>)),
                                        ('end_logits',
                                         tensor([[0.3568, 0.4643, 0.3713, 0.4937, 0.5418, 0.3480, 0.6213, 0.3679, 0.2930,
                                                  0.4053, 0.4651, 0.3043, 0.3991, 0.2815, 0.4168, 0.3727, 0.2794]],
                                                grad_fn=<CloneBackward0>))])