In [None]:
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm


# Fix RNG seeds so stochastic steps (e.g. the QA head's random init below) are reproducible.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and CUDA generators

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
import wandb
# Authenticate with Weights & Biases so the later `wandb.init` / `wandb.log` calls can record runs.
# NOTE(review): may prompt interactively when no cached credentials exist — confirm for headless runs.
wandb.login()

In [None]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration, T5Config

# A *fast* tokenizer is used because preprocessing relies on return_offsets_mapping / sequence_ids.
TOKENIZER_CHECKPOINT = "t5-small"
tokenizer = T5TokenizerFast.from_pretrained(TOKENIZER_CHECKPOINT)

In [None]:
def preprocess_function(examples, tok=None, max_length=384):
    """Tokenize question/context pairs and map character answer spans to token spans.

    Args:
        examples: a batch (``batched=True`` map) with "question", "context" and
            "answers" columns; each answers entry has "answer_start" and "text"
            lists, of which only the first answer is used.
        tok: tokenizer to use; defaults to the module-level ``tokenizer``. It must
            support ``return_offsets_mapping`` and ``sequence_ids`` (a fast tokenizer).
        max_length: maximum encoded length; only the context is truncated.

    Returns:
        The tokenizer encoding without "offset_mapping", plus "start_positions"
        and "end_positions": inclusive token indices of the answer, or (0, 0)
        when the answer is not fully inside the (possibly truncated) context.
    """
    if tok is None:
        tok = tokenizer
    questions = [q.strip() for q in examples["question"]]
    inputs = tok(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        return_offsets_mapping=True,
        padding='do_not_pad',
    )

    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context (sequence id 1)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0).
        # Both answer ends must be checked: an answer cut off by truncation would
        # otherwise get an end position past the last context token.
        # NOTE(review): with a question-first pair there is presumably no leading
        # special token, so index 0 is the first question token — TODO confirm.
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [None]:
from datasets import load_dataset
# Download (or load from cache) the SQuAD dataset with its train/validation splits.
squad = load_dataset(path="squad")

In [None]:
# Drop the raw text columns: only the tokenizer outputs and span labels are kept.
squad_raw_columns = squad["train"].column_names
tokenized_squad = squad.map(
    preprocess_function,
    batched=True,
    remove_columns=squad_raw_columns,
)

In [None]:
from typing import Tuple


class QuestionAnswering(nn.Module):
    """Extractive question-answering head on top of a T5 encoder.

    Only ``model.encoder`` is run; a linear layer maps every encoder hidden
    state to two scores (answer-start and answer-end logits).
    """

    def __init__(self, model):
        # model: T5 with encoder and decoder; only ``model.encoder`` and
        # ``model.config.hidden_size`` are used here.
        super().__init__()

        self.model = model
        self.qa_outputs = nn.Linear(model.config.hidden_size, 2)

    def question_answering(self, batch):
        """
        For each token in sequence predict its score of being the answer's
        starting position and ending position, and compute the span loss.

        Args:
            batch: dict with 'input_ids', 'attention_mask', 'start_positions'
                and 'end_positions' tensors on the model's device.

        Returns:
            (total_loss, output): total_loss is the mean of the start and end
            cross-entropy losses; output holds detached CPU copies of the
            start/end logits and the clamped target positions.
        """
        outputs = self.model.encoder(
            batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions = batch['start_positions'].clamp(0, ignored_index)
        end_positions = batch['end_positions'].clamp(0, ignored_index)

        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

        output = {
            'start_logits': start_logits.detach().cpu(),
            'end_logits': end_logits.detach().cpu(),
            'start_positions': start_positions.cpu(),
            'end_positions': end_positions.cpu(),
        }
        return total_loss, output

    def em(self, predicted: Tuple[int, int], ground_truth: Tuple[int, int]):
        """
        calculates exact match metric: True iff the predicted (start, end)
        pair equals the ground-truth pair
        """
        return list(predicted) == list(ground_truth)

    def f1(self, predicted: Tuple[int, int], ground_truth: Tuple[int, int]):
        """
        calculates f1 metric over the token indices covered by the inclusive
        predicted and ground-truth spans; 0 for an inverted prediction or no overlap
        """

        if predicted[1] < predicted[0]:
            return 0

        predicted = set(range(predicted[0], predicted[1] + 1))
        ground_truth = set(range(ground_truth[0], ground_truth[1] + 1))

        tp = len(predicted & ground_truth)
        if tp == 0:
            return 0

        fp = len(predicted - ground_truth)
        fn = len(ground_truth - predicted)

        precision = tp / (tp + fp)
        recall = tp / (tp + fn)

        return 2 * precision * recall / (precision + recall)

    def _batch_metrics(self, output):
        """Per-example EM and F1 lists for one ``question_answering`` output dict."""
        predicted = torch.stack([
            output['start_logits'].argmax(-1),
            output['end_logits'].argmax(-1)
        ]).T
        ground_truth = torch.stack([
            output['start_positions'],
            output['end_positions']
        ]).T
        # plain Python ints avoid per-element tensor ops inside em/f1
        pairs = list(zip(predicted.tolist(), ground_truth.tolist()))
        ems = [self.em(pred, truth) for pred, truth in pairs]
        f1s = [self.f1(pred, truth) for pred, truth in pairs]
        return ems, f1s

    def train_one_epoch(self, dataloader, optimizer, scaler=None):
        """Run one training epoch and return its mean (EM, F1).

        Args:
            dataloader: yields batch dicts accepted by ``question_answering``.
            optimizer: optimizer over this module's parameters.
            scaler: optional ``GradScaler``; pass one to keep its scale state
                across epochs. A fresh one is created when omitted.
        """
        self.train()

        # fp16 autocast is only enabled on CUDA; on CPU the step runs in full precision.
        use_amp = device.type == 'cuda'
        if scaler is None:
            # fp16 gradients can underflow to zero; the scaler multiplies the loss
            # before backward and unscales before the optimizer step (no-op if disabled).
            scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        ems = []
        f1s = []
        for batch in tqdm(dataloader):
            for k, v in batch.items():
                batch[k] = v.to(device)

            # NOTE(review): T5 checkpoints are reported to overflow in fp16; consider
            # bfloat16 if NaN losses appear.
            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                loss, output = self.question_answering(batch)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_ems, batch_f1s = self._batch_metrics(output)
            ems.extend(batch_ems)
            f1s.extend(batch_f1s)

            # running metrics over the most recent 32 examples
            wandb.log({
                'em': np.mean(ems[-32:]),
                'f1': np.mean(f1s[-32:])
            })

        return np.mean(ems), np.mean(f1s)

    @torch.no_grad()
    def evaluate(self, dataloader):
        """Evaluate without gradients and return the mean (EM, F1) over ``dataloader``."""
        self.eval()

        ems = []
        f1s = []
        for batch in tqdm(dataloader):
            for k, v in batch.items():
                batch[k] = v.to(device)

            _, output = self.question_answering(batch)

            batch_ems, batch_f1s = self._batch_metrics(output)
            ems.extend(batch_ems)
            f1s.extend(batch_f1s)

        return np.mean(ems), np.mean(f1s)

In [None]:
from functools import partial


def collate_batch(pad_id, batch):
    """Pad a list of tokenized samples into one batch of tensors.

    Args:
        pad_id: token id used to right-pad ``input_ids``; also used to derive
            the attention mask.
        batch: list of dicts with 'input_ids' (list of ints), 'start_positions'
            and 'end_positions' (ints).

    Returns:
        dict with 'input_ids' (batch, max_len) long tensor, 'start_positions' and
        'end_positions' (batch,) long tensors, and 'attention_mask' (1 for real
        tokens, 0 for padding).
    """
    input_ids = [torch.tensor(sample['input_ids'], dtype=torch.long) for sample in batch]
    start_positions = [sample['start_positions'] for sample in batch]
    end_positions = [sample['end_positions'] for sample in batch]

    padded = pad_sequence(input_ids, padding_value=pad_id, batch_first=True)
    return {
        'input_ids': padded,
        'start_positions': torch.tensor(start_positions, dtype=torch.long),
        'end_positions': torch.tensor(end_positions, dtype=torch.long),
        # integer mask, matching the dtype convention of tokenizer-produced masks
        'attention_mask': (padded != pad_id).long(),
    }


# One collate function shared by both loaders (pads with the tokenizer's pad id).
qa_collate = partial(collate_batch, tokenizer.pad_token_id)

# Shuffle training data: DataLoader defaults to shuffle=False, which would feed
# the examples in the same fixed dataset order every epoch.
qa_train_loader = torch.utils.data.DataLoader(
    tokenized_squad['train'],
    collate_fn=qa_collate,
    batch_size=64,
    shuffle=True,
)

# Validation order is kept fixed.
qa_val_loader = torch.utils.data.DataLoader(
    tokenized_squad['validation'],
    collate_fn=qa_collate,
    batch_size=32,
)

In [None]:
model = T5ForConditionalGeneration(T5Config.from_pretrained('t5-small'))
# map_location='cpu' lets a checkpoint saved on a GPU load on any machine; the
# model is moved to `device` when the QA wrapper is built.
# torch.load unpickles the file — only load checkpoints from trusted sources.
# TODO: replace the placeholder path with the real checkpoint location.
model.load_state_dict(torch.load('your/pretrained/model.pt', map_location='cpu'))

In [None]:
# Wrap the pretrained T5 with the span-prediction head and place it on the target device.
qa = QuestionAnswering(model)
qa = qa.to(device)

# AdamW over both the T5 weights and the new head.
optimizer = torch.optim.AdamW(params=qa.parameters(), lr=5e-5)

In [None]:
# Start a W&B run; project and run names are placeholders.
wandb.init(
    project='project',
    name='name',
)

In [None]:
NUM_EPOCHS = 3

for epoch in range(NUM_EPOCHS):
    # train_one_epoch returns the epoch-average (EM, F1); report it instead of discarding it.
    train_em, train_f1 = qa.train_one_epoch(qa_train_loader, optimizer)
    print(f'epoch {epoch + 1}/{NUM_EPOCHS}: train EM={train_em:.4f}, F1={train_f1:.4f}')

In [None]:
# Measure generalization on the held-out validation split, not on the training data.
qa.evaluate(qa_val_loader)