### 0. Import libraries and load data

In [1]:
!pip install evaluate

In [2]:
!pip install transformers

In [3]:
!pip install accelerate

In [4]:
import os
import copy
import random
import itertools
import collections
from abc import ABC, abstractmethod

# utils
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# Hugging Face
import evaluate
import transformers
from datasets import load_dataset
from transformers import pipeline
from accelerate import Accelerator
from transformers import get_scheduler
from transformers import AutoTokenizer
from transformers import default_data_collator
from transformers import AutoModelForQuestionAnswering

# PyTorch
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

# seed the random number generators so that runs are repeatable
transformers.set_seed(42)

# request deterministic cuDNN behaviour and switch cuDNN benchmark mode off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# hyperparameters
# number of training epochs; used as the default of ModelTrainer's `num_train_epochs`
EPOCHS = 3

# if True, the hyperparameter search uses only the first 100 train and the
# first 100 validation examples (200 in total) for a quick smoke run
TEST = False
# directory the dataset JSON files are read from and the fine-tuned models,
# tokenizers and results.csv are written to
ROOT_PATH = "./data/"
# checkpoint the tokenizer and the model are initialised from
# NOTE: the name is misspelled ("PRTRAINED"); kept because other cells reference it
PRTRAINED_MODEL_CHECKPOINT = "prajjwal1/bert-tiny"

In [5]:
class DatasetLoader:
    """Store the SQuAD dataset as JSON files and load it back from disk.

    Each split is kept in ``<directory>/squad_data_<split>.json``.

    Args:
        directory: Folder that holds the JSON files. It may be given with or
            without a trailing path separator.
    """

    def __init__(self, directory="../data/"):
        self.directory = directory

    def _split_path(self, split):
        """Return the path of the JSON file that holds ``split``."""
        # os.path.join inserts the separator only when it is missing, so both
        # "data" and "data/" resolve to the same file
        return os.path.join(self.directory, f"squad_data_{split}.json")

    def store_data(self):
        """Download SQuAD and write one JSON file per split into ``directory``."""
        # download and cache the data
        squad_data = load_dataset("squad")

        # make sure the target folder exists before writing into it
        if self.directory:
            os.makedirs(self.directory, exist_ok=True)

        # store a json for train and validation data
        for split, dataset in squad_data.items():
            dataset.to_json(self._split_path(split))
        print(f"The dataset is stored at {self.directory}")

    def load_data(self):
        """Load the train and validation splits from the stored JSON files.

        Returns:
            The object returned by ``load_dataset("json", ...)`` with the
            ``"train"`` and ``"validation"`` splits.
        """
        data_files = {split: self._split_path(split) for split in ("train", "validation")}

        return load_dataset("json", data_files=data_files)

In [6]:
# read the train/validation JSON files from ROOT_PATH
# NOTE: load_data() reads the file names that DatasetLoader.store_data() writes,
# so those files are expected to exist in ROOT_PATH already
dataset_loader = DatasetLoader(ROOT_PATH)
raw_data = dataset_loader.load_data()
# bare expression: shown as the notebook cell output
raw_data


### 1. Preprocess the data

In [7]:
class DataPreprocessor(ABC):
    """Interface for objects that turn raw question/context batches into features.

    Stores the two settings shared by all implementations: ``max_length`` and
    ``stride``.
    """

    def __init__(self, max_length=384, stride=128) -> None:
        self.max_length, self.stride = max_length, stride

    @abstractmethod
    def preprocess_train_data(self, questions_contexts):
        """Preprocess a batch of training question/context pairs."""

    @abstractmethod
    def preprocess_dev_data(self, questions_contexts):
        """Preprocess a batch of validation question/context pairs."""


In [8]:
class DefaultDataPreprocessor(DataPreprocessor):
    """Tokenize question/context pairs into fixed-size model features.

    Contexts longer than ``max_length`` tokens are split into several
    overlapping features (``stride`` tokens of overlap), so one raw example
    can produce more than one feature.

    Args:
        tokenizer: Callable tokenizer; must support offset mappings and
            overflowing tokens and return an object exposing ``sequence_ids``.
        train_data_raw: Raw training dataset (needs ``map``/``column_names``).
        dev_data_raw: Raw validation dataset (needs ``map``/``column_names``).
        max_length: Maximum number of tokens per feature.
        stride: Number of overlapping tokens between consecutive features.
    """

    def __init__(self, tokenizer, train_data_raw, dev_data_raw, max_length=384, stride=128):
        super().__init__(max_length=max_length, stride=stride)
        self.tokenizer = tokenizer
        self.train_data_raw = train_data_raw
        self.dev_data_raw = dev_data_raw

    def _tokenize(self, questions_contexts):
        """Tokenize a batch of question/context pairs into overlapping windows."""
        questions = [q.strip() for q in questions_contexts["question"]]

        return self.tokenizer(
            questions,
            questions_contexts["context"],
            max_length=self.max_length,
            # only the context (second sequence) may be truncated, never the question
            truncation="only_second",
            stride=self.stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

    @staticmethod
    def _context_token_range(sequence_ids):
        """Return ``(first, last)`` token index of the context, or ``None`` if absent."""
        first = last = None

        for idx, sequence_id in enumerate(sequence_ids):
            if sequence_id == 1:
                if first is None:
                    first = idx
                last = idx
            elif first is not None:
                # the context is one contiguous run; stop at its end
                break

        return None if first is None else (first, last)

    @staticmethod
    def _locate_answer(offsets, context_start, context_end, start_char, end_char):
        """Map a character span to ``(start_token, end_token)`` inside one feature.

        Returns ``(0, 0)`` when the answer is not fully contained in the
        context tokens of this feature.
        """
        if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
            return 0, 0

        # last context token that starts at or before the answer start
        idx = context_start
        while idx <= context_end and offsets[idx][0] <= start_char:
            idx += 1
        start_token = idx - 1

        # first context token that ends at or after the answer end
        idx = context_end
        while idx >= context_start and offsets[idx][1] >= end_char:
            idx -= 1

        return start_token, idx + 1

    def preprocess_train_data(self, questions_contexts):
        """Tokenize a training batch and attach answer token positions.

        Args:
            questions_contexts: Batch with ``"question"``, ``"context"`` and
                ``"answers"`` columns; each answer holds ``"answer_start"``
                and ``"text"`` lists, of which the first entry is used.

        Returns:
            The tokenizer output without ``offset_mapping`` and
            ``overflow_to_sample_mapping`` and with ``start_positions`` /
            ``end_positions`` added. Features that do not contain the full
            answer (or whose example has no answer) are labelled ``(0, 0)``.
        """
        features = self._tokenize(questions_contexts)

        offset_mapping = features.pop("offset_mapping")
        sample_map = features.pop("overflow_to_sample_mapping")
        answers = questions_contexts["answers"]
        start_positions = []
        end_positions = []

        for i, (offsets, sample_idx) in enumerate(zip(offset_mapping, sample_map)):
            answer = answers[sample_idx]
            context_range = self._context_token_range(features.sequence_ids(i))

            # no answer annotated or no context tokens: label the feature (0, 0)
            # instead of failing on the missing first element
            if context_range is None or len(answer["answer_start"]) == 0:
                start_positions.append(0)
                end_positions.append(0)
                continue

            context_start, context_end = context_range
            start_char = answer["answer_start"][0]
            end_char = start_char + len(answer["text"][0])

            start_token, end_token = self._locate_answer(
                offsets, context_start, context_end, start_char, end_char
            )
            start_positions.append(start_token)
            end_positions.append(end_token)

        features["start_positions"] = start_positions
        features["end_positions"] = end_positions

        return features

    def preprocess_dev_data(self, questions_contexts):
        """Tokenize a validation batch and keep what is needed to decode answers.

        Args:
            questions_contexts: Batch with ``"question"``, ``"context"`` and
                ``"id"`` columns.

        Returns:
            The tokenizer output with an ``example_id`` per feature and with
            the ``offset_mapping`` entries of non-context tokens set to
            ``None``.
        """
        features = self._tokenize(questions_contexts)

        sample_map = features.pop("overflow_to_sample_mapping")
        ids = questions_contexts["id"]
        offset_mapping = features["offset_mapping"]
        example_ids = []

        for i, offsets in enumerate(offset_mapping):
            example_ids.append(ids[sample_map[i]])

            # keep offsets of context tokens only, so that positions outside
            # the context can be recognised (None) when decoding predictions
            sequence_ids = features.sequence_ids(i)
            offset_mapping[i] = [
                offset if sequence_id == 1 else None for offset, sequence_id in zip(offsets, sequence_ids)
            ]

        features["example_id"] = example_ids

        return features

    def preprocess_data(self, dataset_name, verbose=True):
        """Preprocess one of the two raw datasets.

        Args:
            dataset_name: ``"train"`` selects the training data and the
                training preprocessing; any other value selects the
                validation data and the validation preprocessing.
            verbose: If True, print the number of pairs before and after
                preprocessing.

        Returns:
            The mapped dataset, containing only the generated feature columns.
        """
        if dataset_name == "train":
            raw_dataset = self.train_data_raw
            preprocess_batch = self.preprocess_train_data
        else:
            raw_dataset = self.dev_data_raw
            preprocess_batch = self.preprocess_dev_data

        preprocessed_data = raw_dataset.map(
            preprocess_batch,
            batched=True,
            remove_columns=raw_dataset.column_names,
        )
        initial_pairs_no = len(raw_dataset)

        if verbose:
            print(
                f"Number of {dataset_name} question - context pairs:\nInitially: {initial_pairs_no}\nAfter preprocessing:{len(preprocessed_data)}"
            )
        return preprocessed_data


### 2. Postprocess the predictions

In [9]:
class PredictionsPostprocessor(ABC):
    """Interface for objects that turn model logits into answer predictions.

    Stores the two settings shared by all implementations: ``n_best`` and
    ``max_answer_length``.
    """

    def __init__(self, n_best=32, max_answer_length=128) -> None:
        self.n_best, self.max_answer_length = n_best, max_answer_length

    @abstractmethod
    def postprocess_predictions(self, start_logits, end_logits, preprocessed_data, questions_contexts):
        """Convert start/end logits into predictions for the given examples."""


In [10]:
class DefaultPredictionsPostprocessor(PredictionsPostprocessor):
    """Pick, for every example, the best-scoring answer span over all its features.

    Args:
        n_best: Number of top start and top end positions combined per feature.
        max_answer_length: Maximum answer length, counted in tokens.
    """

    def __init__(self, n_best=32, max_answer_length=128):
        super().__init__(n_best=n_best, max_answer_length=max_answer_length)

    def postprocess_predictions(self, start_logits, end_logits, preprocessed_data, questions_contexts):
        """Turn start/end logits into text predictions.

        Args:
            start_logits: Per-feature start logits, indexable by feature index.
            end_logits: Per-feature end logits, indexable by feature index.
            preprocessed_data: Features with ``example_id`` and
                ``offset_mapping`` (``None`` for tokens outside the context),
                in the same order as the logits.
            questions_contexts: Raw examples with ``id``, ``context`` and
                ``answers``.

        Returns:
            A pair ``(predicted_answers, correct_answers)``: lists of
            ``{"id", "prediction_text"}`` and ``{"id", "answers"}`` dicts, one
            entry per example in ``questions_contexts``. The prediction text
            is ``""`` when no valid span was found.
        """
        # feature indexes grouped by the example they were generated from
        features_by_example = collections.defaultdict(list)

        for feature_index, feature in enumerate(preprocessed_data):
            features_by_example[feature["example_id"]].append(feature_index)

        predicted_answers = []
        correct_answers = []

        for example in questions_contexts:
            example_id = example["id"]
            context = example["context"]
            best_score = None
            best_text = ""

            for feature_index in features_by_example.get(example_id, ()):
                start_logit = start_logits[feature_index]
                end_logit = end_logits[feature_index]
                offsets = preprocessed_data[feature_index]["offset_mapping"]

                # the n_best highest-scoring positions, best first
                start_indexes = np.argsort(start_logit)[-1 : -self.n_best - 1 : -1].tolist()
                end_indexes = np.argsort(end_logit)[-1 : -self.n_best - 1 : -1].tolist()

                for start_index in start_indexes:
                    # skip positions that are not part of the context
                    if offsets[start_index] is None:
                        continue

                    for end_index in end_indexes:
                        if offsets[end_index] is None:
                            continue

                        # skip reversed spans and spans that are too long
                        if end_index < start_index or end_index - start_index + 1 > self.max_answer_length:
                            continue

                        score = start_logit[start_index] + end_logit[end_index]

                        # strict comparison keeps the first of equally scored
                        # spans, like max() over the candidate list would
                        if best_score is None or score > best_score:
                            best_score = score
                            best_text = context[offsets[start_index][0] : offsets[end_index][1]]

            predicted_answers.append({"id": example_id, "prediction_text": best_text})
            correct_answers.append({"id": example_id, "answers": example["answers"]})

        return predicted_answers, correct_answers


### 3. Evaluate performance

In [11]:
class MetricsEvaluator:
    """Thin wrapper around a metric loaded with ``evaluate.load``."""

    def __init__(self, dataset_name="squad"):
        # the metric is looked up by name once and reused for every call
        loaded_metric = evaluate.load(dataset_name)
        self.metric = loaded_metric

    def compute_metric(self, predicted_answers, correct_answers):
        """Score the predictions against the references and return the result."""
        return self.metric.compute(
            references=correct_answers,
            predictions=predicted_answers,
        )


### 4. Fine-tune BERT-Tiny

In [12]:
class ModelTrainer:
    """Fine-tune a question-answering model and evaluate it after every epoch.

    The model is loaded from ``PRTRAINED_MODEL_CHECKPOINT``; ``model_name`` is
    only used to name the folders under ``ROOT_PATH`` that the fine-tuned
    model and the tokenizer are saved to.

    Args:
        model_name: Prefix of the output folders.
        train_data: Preprocessed training features. Its format is switched to
            ``"torch"`` in place.
        val_data: Preprocessed validation features, including the
            ``example_id`` and ``offset_mapping`` columns.
        tokenizer: Tokenizer saved next to the model.
        batch_size: Batch size of both data loaders.
        learning_rate: Initial learning rate of the AdamW optimizer.
        num_train_epochs: Number of training epochs.
        val_examples: Raw validation examples (``id``, ``context``,
            ``answers``) that ``val_data`` was built from. Defaults to the
            global ``raw_data["validation"]``. Examples without a feature in
            ``val_data`` are ignored.
    """

    def __init__(
        self,
        model_name,
        train_data,
        val_data,
        tokenizer,
        batch_size,
        learning_rate,
        num_train_epochs=EPOCHS,
        val_examples=None,
    ):
        self.model_name = model_name
        # keep a copy that still has the example_id / offset_mapping columns;
        # they are needed to decode predictions but removed for the data loader
        self.val_data = copy.deepcopy(val_data)
        self.val_examples = self.__select_val_examples(val_examples)
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_train_epochs = num_train_epochs
        self.metrics_evaluator = MetricsEvaluator()
        self.default_prediction_postprocessor = DefaultPredictionsPostprocessor()

        (
            self.model,
            self.optimizer,
            self.train_data_loader,
            self.val_data_loader,
            self.accelerator,
            self.lr_scheduler,
            self.num_training_steps,
        ) = self.__initialize_components(train_data, val_data)

    def __select_val_examples(self, val_examples):
        """Return the raw examples that have at least one validation feature.

        Scoring against examples without features would count each of them as
        an empty prediction and distort the metrics, e.g. when the validation
        features were built from a subset of the raw validation split.
        """
        if val_examples is None:
            val_examples = raw_data["validation"]

        feature_example_ids = set(self.val_data["example_id"])

        return [example for example in val_examples if example["id"] in feature_example_ids]

    def __format_data(self, train_data, val_data):
        """Switch both datasets to torch format; drop the decode-only validation columns."""
        train_data.set_format("torch")
        val_data = val_data.remove_columns(["example_id", "offset_mapping"])
        val_data.set_format("torch")

        return train_data, val_data

    def __generate_data_loaders(self, train_data, val_data):
        """Build a shuffled training loader and an in-order validation loader."""
        train_data_loader = DataLoader(
            train_data,
            shuffle=True,
            collate_fn=default_data_collator,
            batch_size=self.batch_size,
        )

        val_data_loader = DataLoader(val_data, collate_fn=default_data_collator, batch_size=self.batch_size)

        return train_data_loader, val_data_loader

    def __set_learning_rate_decay(self, train_data_loader, optimizer):
        """Create a linear scheduler without warmup spanning all training steps."""
        num_update_steps_per_epoch = len(train_data_loader)
        num_training_steps = self.num_train_epochs * num_update_steps_per_epoch

        lr_scheduler = get_scheduler(
            "linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps,
        )

        return lr_scheduler, num_training_steps

    def __initialize_components(self, train_data, val_data):
        """Create the model, optimizer, data loaders, accelerator and scheduler."""
        train_data, val_data = self.__format_data(train_data, val_data)
        train_data_loader, val_data_loader = self.__generate_data_loaders(train_data, val_data)
        model = AutoModelForQuestionAnswering.from_pretrained(PRTRAINED_MODEL_CHECKPOINT)
        optimizer = AdamW(model.parameters(), lr=self.learning_rate)

        accelerator = Accelerator()
        model, optimizer, train_data_loader, val_data_loader = accelerator.prepare(
            model, optimizer, train_data_loader, val_data_loader
        )

        # the scheduler length is taken from the prepared training loader
        lr_scheduler, num_training_steps = self.__set_learning_rate_decay(train_data_loader, optimizer)

        return model, optimizer, train_data_loader, val_data_loader, accelerator, lr_scheduler, num_training_steps

    def __store_artifacts(self):
        """Save the model and the tokenizer under ``ROOT_PATH``."""
        self.accelerator.wait_for_everyone()
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.save_pretrained(f"{ROOT_PATH}{self.model_name}_model", save_function=self.accelerator.save)

        if self.accelerator.is_main_process:
            self.tokenizer.save_pretrained(f"{ROOT_PATH}{self.model_name}_tokenizer")

    def train(self):
        """Train for ``num_train_epochs`` epochs, evaluating and saving after each.

        Returns:
            Dict with the keys ``"epoch <i> EM"`` and ``"epoch <i> F1"`` for
            every epoch ``i``.
        """
        epochs_metrics = {}
        progress_bar = tqdm(range(self.num_training_steps))

        for epoch in range(self.num_train_epochs):
            self.model.train()

            for batch in self.train_data_loader:
                outputs = self.model(**batch)
                loss = outputs.loss
                self.accelerator.backward(loss)

                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                progress_bar.update(1)

            self.model.eval()
            start_logits = []
            end_logits = []
            self.accelerator.print("Evaluation!")

            for batch in self.val_data_loader:
                with torch.no_grad():
                    outputs = self.model(**batch)

                start_logits.append(self.accelerator.gather(outputs.start_logits).cpu().numpy())
                end_logits.append(self.accelerator.gather(outputs.end_logits).cpu().numpy())

            start_logits = np.concatenate(start_logits)
            end_logits = np.concatenate(end_logits)
            # keep exactly one row of logits per validation feature
            start_logits = start_logits[: len(self.val_data)]
            end_logits = end_logits[: len(self.val_data)]

            # score only the examples the validation features were built from
            predicted_answers, correct_answers = self.default_prediction_postprocessor.postprocess_predictions(
                start_logits, end_logits, self.val_data, self.val_examples
            )

            metrics = self.metrics_evaluator.compute_metric(predicted_answers, correct_answers)
            print(f"epoch {epoch}:", metrics)
            epochs_metrics[f"epoch {epoch} EM"] = metrics["exact_match"]
            epochs_metrics[f"epoch {epoch} F1"] = metrics["f1"]

            # the artifacts of the latest epoch replace those of the previous one
            self.__store_artifacts()

        return epochs_metrics

### 5. Hyperparameter search

In [None]:
# grid of hyperparameters; every (batch size, learning rate) pair is trained once
batch_sizes = [8, 16, 32, 64, 128]
learning_rates = [5e-4, 1e-3, 3e-3, 5e-3, 1e-2]

# one column per hyperparameter plus an EM and an F1 column per epoch; the
# epoch columns follow EPOCHS so they match the keys ModelTrainer.train() returns
results = {"batch size": [], "learning rate": []}
results.update({f"epoch {epoch} {score}": [] for epoch in range(EPOCHS) for score in ("EM", "F1")})

# the tokenizer and the preprocessed data do not depend on the batch size or
# the learning rate, so they are prepared once instead of once per configuration
tokenizer = AutoTokenizer.from_pretrained(PRTRAINED_MODEL_CHECKPOINT)

if TEST:
    default_data_preprocessor = DefaultDataPreprocessor(
        tokenizer, raw_data["train"].select(range(100)), raw_data["validation"].select(range(100))
    )
else:
    default_data_preprocessor = DefaultDataPreprocessor(tokenizer, raw_data["train"], raw_data["validation"])

train_dataset = default_data_preprocessor.preprocess_data("train")
validation_dataset = default_data_preprocessor.preprocess_data("validation")

for batch_size, learning_rate in itertools.product(batch_sizes, learning_rates):
    print(64 * "-")
    # reset the seed and the cuDNN flags so every configuration starts from
    # the same random state
    transformers.set_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model_name = f"fine_tuned_tiny_bert_{batch_size}_{learning_rate}"

    model_trainer = ModelTrainer(model_name, train_dataset, validation_dataset, tokenizer, batch_size, learning_rate)
    epochs_metrics = model_trainer.train()

    results["batch size"].append(batch_size)
    results["learning rate"].append(learning_rate)

    for metric, value in epochs_metrics.items():
        results[metric].append(value)

    # rewrite the file after every configuration so partial results survive
    # an interrupted search
    pd.DataFrame(results).to_csv(ROOT_PATH + "results.csv", index=False)