## Functions and parts of the code are taken from the HuggingFace question-answering Jupyter notebook on GitHub: https://github.com/huggingface/notebooks/blob/main/examples/question_answering.ipynb

In [None]:
# Install the Hugging Face `datasets` and `transformers` libraries used below (notebook shell command).
! pip install datasets transformers

In [None]:
import transformers
from datasets import load_dataset, load_metric
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
from transformers import default_data_collator
import shutil
import pandas as pd

print(transformers.__version__)

# Hugging Face checkpoint to fine-tune; its last path component also names the output directory.
model_checkpoint = "bert-base-uncased"

# Per-device batch size, shared by training and evaluation.
batch_size = 16

# When True, load "squad_v2" instead of "squad".
squad_v2 = False

# Feature geometry for tokenization:
max_length = 384  # maximum length of one feature (question + context); longer contexts are split
doc_stride = 128  # overlap allowed between two consecutive chunks when a context has to be split

In [None]:
# Load the Hugging Face SQuAD dataset (v2 when `squad_v2` is set). Its validation split is used
# for evaluation, and its train split is used as-is for the 'base' training type.
if squad_v2:
    datasets = load_dataset("squad_v2")
else:
    datasets = load_dataset("squad")

In [None]:
# SQuAD training set with precomputed heuristic values and flag columns.
# The file path may need some tweaks depending on where the data is stored.
data = pd.read_json(
    path_or_buf='../input/squad-train-supersampled-all-heuristics/squad_train_with_heuristics_flags.json'
)

In [None]:
def supersample_dataset(data_lower, data_higher, random_state=None):
    """Super-sample the training dataset so the two heuristic subsets are closer in size.

    The smaller subset is re-appended ``len(larger) // len(smaller)`` times, and each copy is
    shuffled with ``DataFrame.sample(frac=1)``. The smaller subset therefore appears
    ``1 + len(larger) // len(smaller)`` times in total. When both subsets have the same size,
    ``data_higher`` is treated as the smaller one and duplicated once.

    Args:
        data_lower (Pandas Dataframe): dataset split with values lower or equal than the threshold
        data_higher (Pandas Dataframe): dataset split with values higher than the threshold
        random_state (optional): forwarded to ``DataFrame.sample`` so the shuffling of the appended
            copies is reproducible. An int seeds every copy identically. Defaults to None (unseeded).

    Returns:
        Pandas Dataframe: ``data_lower``, ``data_higher`` and the appended copies, concatenated
        in that order.

    Raises:
        ValueError: if either subset is empty (an empty subset cannot be super-sampled).
    """
    if len(data_lower) == 0 or len(data_higher) == 0:
        raise ValueError(
            "cannot super-sample: subset sizes are "
            f"{len(data_lower)} (lower) and {len(data_higher)} (higher)"
        )

    # On a tie, data_higher counts as the smaller subset and gets duplicated.
    if len(data_higher) > len(data_lower):
        smaller, larger = data_lower, data_higher
    else:
        smaller, larger = data_higher, data_lower

    copies = [
        smaller.sample(frac=1, random_state=random_state)
        for _ in range(len(larger) // len(smaller))
    ]
    return pd.concat([data_lower, data_higher, *copies])

In [None]:
# Tokenizer matching the checkpoint that is fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
# Whether the tokenizer pads on the right. prepare_train_features uses it to decide whether the
# question or the context is passed first, and which of the two gets truncated.
pad_on_right = (tokenizer.padding_side == "right")

In [None]:
# Function from the HuggingFace question-answering notebook on GitHub,
# kept with its original comments (lightly edited for grammar).
def prepare_train_features(examples):
    """Tokenize a batch of SQuAD-style examples into model features with answer-span labels.

    Meant to be used with ``Dataset.map(..., batched=True)``. Long contexts are split into several
    overlapping features (at most ``max_length`` tokens, ``doc_stride`` tokens of overlap), so the
    output can contain more features than the input batch has examples.

    Reads the module-level ``tokenizer``, ``pad_on_right``, ``max_length`` and ``doc_stride``.

    Args:
        examples (dict): batch with "question", "context" and "answers" columns; each "answers"
            entry has "answer_start" (character offsets) and "text" lists. Only the first
            answer is used for labelling. Note: examples["question"] is overwritten in place
            with left-stripped questions.

    Returns:
        The tokenizer output (without "overflow_to_sample_mapping" and "offset_mapping") plus
        "start_positions" and "end_positions" token indices. Features without an answer, or whose
        span does not fully contain the answer, are labelled with the CLS token index.
    """
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
    # left whitespace.
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possibly giving several features when a context is long, each of those features having a
    # context that overlaps a bit with the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans; this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text (first answer only).
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [None]:
# Workaround for a logging problem on Kaggle: turn off Weights & Biases reporting via its env var.
import os
os.environ.update(WANDB_DISABLED="true")

In [None]:
# Heuristic-split training types: training_type -> (heuristic column, threshold).
# Rows whose value is <= threshold form the "lower" subset; all other rows form the "higher" subset.
# ('answer_lenght' is the column's actual name in the data file.)
_HEURISTIC_SPLITS = {
    'distance': ('distances', 7),
    'similar': ('similar_words', 4),
    'cosine': ('cosine_similarity', 0.1),
    'answer': ('answer_lenght', 3),
    'entities': ('max_sim_ents', 0),
    'position': ('answer_subject_positions', 1),
}

# Binary flag columns of `data`, in the order used by the 'all' training type.
_FLAG_COLUMNS = ['dist_flag', 'sim_flag', 'ans_flag', 'cos_flag', 'pos_flag', 'ents_flag']

# Flag-combination training types: the full dataset plus one extra copy of the rows with
# `<flag> == 1` for every occurrence of that flag in the list.
_FLAG_OVERSAMPLING = {
    'top_3': ['ans_flag'] + ['cos_flag'] * 3 + ['dist_flag'] * 2,
    'dist_cos': ['cos_flag'] * 4 + ['dist_flag'] * 3,
    'dist_ans': ['ans_flag'] + ['dist_flag'] * 3,
    'ans_cos': ['ans_flag'] + ['cos_flag'] * 4,
}

# Heuristic value and flag columns dropped before converting to a Hugging Face Dataset.
_DROPPED_COLUMNS = [
    'distances', 'similar_words', 'answer_lenght', 'cosine_similarity',
    'answer_subject_positions', 'max_sim_ents',
] + _FLAG_COLUMNS

_TRAINING_TYPES = frozenset({'base', 'all', *_HEURISTIC_SPLITS, *_FLAG_OVERSAMPLING})


def _select_training_data(training_type):
    """Return the (super-sampled, shuffled) training DataFrame built from `data` for a non-'base' type."""
    if training_type in _HEURISTIC_SPLITS:
        column, threshold = _HEURISTIC_SPLITS[training_type]
        is_lower = data[column] <= threshold
        return supersample_dataset(data[is_lower], data[~is_lower])
    if training_type == 'all':
        # Rows with no flag set, plus one copy of the rows for each flag that is set.
        no_flags = ' and '.join(f'{flag} == 0' for flag in _FLAG_COLUMNS)
        subsets = [data.query(no_flags)] + [data.query(f'{flag} == 1') for flag in _FLAG_COLUMNS]
    else:
        subsets = [data] + [data.query(f'{flag} == 1') for flag in _FLAG_OVERSAMPLING[training_type]]
    return pd.concat(subsets).sample(frac=1)


def training(training_type, type_):
    """Dataset split, preprocessing and model fine-tuning.

    For 'base', the Hugging Face SQuAD train split is used unchanged. For every other type, the
    train split is replaced (for this run only) by a super-sampled version of `data`.

    Args:
        training_type (str): name of the heuristic for the dataset split and super-sampling.
            One of 'base', 'distance', 'similar', 'cosine', 'answer', 'entities', 'position',
            'all', 'top_3', 'dist_cos', 'dist_ans', 'ans_cos'.
        type_ (str): suffix for the zipped fine-tuned model.

    Raises:
        ValueError: if training_type is not one of the supported values.
    """
    if training_type not in _TRAINING_TYPES:
        raise ValueError(
            f"unknown training_type {training_type!r}; expected one of {sorted(_TRAINING_TYPES)}"
        )

    # Work on a copy of the split mapping so the global `datasets` keeps the original train split.
    # Replacing datasets['train'] in place would make a later 'base' run in the same session
    # train on a previous run's super-sampled data.
    splits = dict(datasets)
    if training_type != 'base':
        new_data = _select_training_data(training_type).drop(columns=_DROPPED_COLUMNS)
        # preserve_index=False: the concatenated frame has a duplicated index that would
        # otherwise be stored as an extra "__index_level_0__" column.
        splits['train'] = Dataset.from_pandas(new_data, preserve_index=False)

    # Data preprocessing; each split removes its own raw columns.
    tokenized_datasets = {
        name: split.map(prepare_train_features, batched=True, remove_columns=split.column_names)
        for name, split in splits.items()
    }
    model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

    # Training arguments.
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`;
    # confirm against the installed version.
    model_name = model_checkpoint.split("/")[-1]
    args = TrainingArguments(
        f"{model_name}-finetuned-squad_with_callbacks",
        evaluation_strategy="steps",
        eval_steps=200,  # evaluation and checkpoint saving both happen every 200 steps
        save_steps=200,
        logging_steps=200,
        save_total_limit=5,  # at most 5 checkpoints are kept; older ones are deleted
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=3,
        weight_decay=0.01,
        report_to="none",
        load_best_model_at_end=True,
    )

    # Stop after 10 evaluations without improvement (metric_for_best_model is left at its default).
    trainer = Trainer(
        model,
        args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=default_data_collator,
        tokenizer=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
    )

    trainer.train()

    # NOTE(review): every run saves into the same "test-squad-trained" directory; only the
    # zip archive name below is specific to type_.
    trainer.save_model("test-squad-trained")

    # Save the model as a zip archive whose name carries the type_ suffix.
    shutil.make_archive(f"{model_name}-finetuned-squad_with_callbacks_{type_}", 'zip', './test-squad-trained')

In [None]:
# All supported heuristics and combinations for super-sampling the data and fine-tuning BERT.
# Uncomment the (training_type, model-name suffix) pairs you want to run. Keeping each type
# next to its suffix avoids mismatched lists, which zip() would silently truncate.
#--------------------------------

training_runs = [
    # ('base', 'base'),
    # ('distance', 'supersampled_distances_7'),
    # ('similar', 'supersampled_similar_4'),
    # ('cosine', 'supersampled_cosine_01'),
    # ('answer', 'supersampled_answer_3'),
    # ('entities', 'supersampled_entities_0'),
    # ('position', 'supersampled_position_1'),
    # ('all', 'supersampled_all'),
    # ('top_3', 'supersampled_top_3'),
    # ('dist_cos', 'supersampled_dist_cos'),
    # ('dist_ans', 'supersampled_dist_ans'),
    # ('ans_cos', 'supersampled_ans_cos'),
]

# Parallel lists of the selected training types and suffixes (same content as training_runs).
training_types = [run_type for run_type, _ in training_runs]
types_ = [suffix for _, suffix in training_runs]

for training_type, type_ in training_runs:
    print(f"{training_type}, {type_}")
    training(training_type, type_)