In [1]:
# Changes v4:
# - Bugfixing in tokenize_samples: If sample has a single token answer and that token is the last token in context, but not the last token of the sample (i.e. there are further <pad> tokens), then token_start_index is greater than token_end_index
# - Revision of code in tokenize_samples: Outsourcing of code fragments into separate functions to reduce complexity and redundancy of code

# Minor changes v4.1:
# - Print statement in preparation for training has been modified: print("#Samples with answer:",total_samples-samples_with_null_answer) instead of print("#Samples with answer:",samples_with_null_answer)

# Minor changes v4.2:
# - Final network update (training), if there are batch samples left

# Install and Import Dependencies

## Install Dependencies

In [None]:
! pip install transformers datasets huggingface_hub
! pip install wandb -qqq

## Import Dependencies

In [None]:
# for loading the input dataset from disk
from datasets import load_dataset

from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering, create_optimizer

# we use Weights & Biases 
import wandb
from wandb.keras import WandbCallback

# ... as well as TensorFlow for training
import tensorflow as tf

from tqdm import tqdm
import random

## Check GPU Support

In [None]:
# Show the TensorFlow version and the GPUs visible to TensorFlow
# (at least one physical device should be listed if GPU support is working)
tf.__version__, tf.config.list_physical_devices("GPU")

# Settings

In [5]:
# Project and run name on Weights & Biases
project_name = "WANDB_PROJECT_NAME"
run_name = "WANDB_RUN_NAME"

# Path to the directory containing the QA-samples created in 'Data Preparation V3'
# (the loader below reads train/*.json, test/*.json and validation/*.json from it)
input_path = "/home/user/output_directory_of_data_preparation/"

# Path to the directory where checkpoints are stored (one checkpoint per epoch)
checkpoint_path= "/home/user/directory_for_checkpoints/"

# Path to the directory used for caching by the dataset loader
caching_path = "/home/user/directory_for_hf_cache/"

# Checkpoint of the base model on Hugging Face
model_checkpoint = "microsoft/codebert-base"

# Parameters for tokenizing:
max_length = 512  # Maximum number of tokens (question + context + special tokens) of a tokenized sample
doc_stride = 128  # Overlap (in tokens) between the contexts of consecutive tokenized samples if a QA-sample must be split due to the length limit

# A QA-sample may be split into multiple tokenized samples. In that case only one or a few of them contain the answer
# in their sub-context, while the remaining ones are not answerable.
# The following threshold limits the number of such unanswerable tokenized samples that are kept per QA-sample
# (set it to None to keep all of them):
max_no_answers_per_possible_answer = 3 # for each QA-sample keep at most X unanswerable tokenized samples

# Training parameters:
batch_size = 16
learning_rate = 2e-5
num_train_epochs = 10
weight_decay = 0.01  # weight decay rate; TODO(review): verify that this value is actually passed to the optimizer
xla = False

In [6]:
# State used to randomly pick X unanswerable tokenized samples per QA-sample
random.seed(42)  # fixed seed so that the random selection is reproducible
chosen_sample_indices = {}  # question_id -> indices of the samples picked during the first pass
first_preprocessing_iteration = True  # set to False once the first pass over the training data is done

# Input Pipeline

In [7]:
# Load the tokenizer that belongs to the base model
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# True if the tokenizer pads on the right-hand side
pad_on_right = "right" == tokenizer.padding_side

In [8]:
def expand_samples(examples):
    """
    Flattens a batch in which every row bundles several question-answer pairs for one shared context
    (the structure created by 'Data Preparation V3') into a batch with one row per question-answer pair.

    Parameters
    ----------
    examples : dict()
        Batch having the format:
        {
            'id': [string],
            'title': [string],
            'context': [string],
            'questions':
                [
                    [
                        {
                            'id': string
                            'question': string
                            'question_length': int
                            'answers':
                                {
                                    'text': [string],
                                    'answer_start': [int]
                                }
                        }
                    ]
                ]
        }

    Returns
    -------
    Python dictionary having the format:
    {
        'id': [string],
        'title': [string],
        'context': [string],
        'question': [string],
        'question_id': [string],
        'answers':
            [
                {
                    'text': [string],
                    'answer_start': [int]
                }
            ]
    }
    Rows without any question do not appear in the result.
    """
    expanded = {
        key: []
        for key in ("id", "title", "context", "question", "question_id", "answers")
    }

    for row in range(len(examples["id"])):
        for qa_pair in examples["questions"][row]:
            # values shared by all questions of this row are repeated per question
            for shared_key in ("id", "title", "context"):
                expanded[shared_key].append(examples[shared_key][row])
            # values that are specific to this question-answer pair
            expanded["question"].append(qa_pair["question"])
            expanded["question_id"].append(qa_pair["id"])
            expanded["answers"].append(qa_pair["answers"])

    return expanded

In [9]:
def mask_offset_mapping(sequence_ids, context_index, offset_mapping):
    """
    Returns a copy of 'offset_mapping' in which every entry that does not belong to a context token is replaced by 'None'.

    Parameters
    ----------
    sequence_ids :
        Vector that holds one value per token; a token counts as part of the context
        if and only if its value equals 'context_index'
    context_index : int
        Value that marks a token as part of the context in 'sequence_ids'
    offset_mapping
        Vector that holds one entry per token (its start and end index on character level in the original input)

    Returns
    -------
    New list with the same length as 'offset_mapping'; entries of context tokens are kept unchanged,
    all other entries are 'None'
    """
    masked = []
    for position, offsets in enumerate(offset_mapping):
        if sequence_ids[position] == context_index:
            masked.append(offsets)
        else:
            masked.append(None)
    return masked

def calc_start_end_index(cls_index, offset_mapping, answers):
    """
    Calculates start and end index on token level of the passed answer based on the specified 'offset_mapping'.
    The method returns start and end index (token level) as well as flags indicating whether the answer is out of span (third return value) and
    whether the sample is erroneous because the calculated start index is greater than the end index (fourth return value).
    If the answer is out of span or the sample is erroneous, start and end index have the value of the passed 'cls_index'.

    Parameters
    ----------
    cls_index : int
        Index of the [CLS] token that is used as start and end index if the answer is out of span
    offset_mapping
        Vector that contains for each token its start and end index on character level in the original input (consisting of question, context, and special characters).
        All entries that represent a token that is not part of the context must be 'None' (use the method 'mask_offset_mapping(...)' to mask the 'offset_mapping' vector before using this method).
    answers: dict()
        Dictionary having the format:
            {
                'text': [string],
                'answer_start': [int]
            }
        Only the first entry of each list is used. An empty 'answer_start' list means that no answer is given.

    Returns
    -------
    start_index : int
        Calculated start index on token level (has the value of 'cls_index' if the answer is out of span)
    end_index : int
        Calculated end index on token level (has the value of 'cls_index' if the answer is out of span)
    answer_out_of_span : bool
        Flag that indicates whether the answer is completely or partially out of span (True) or within span (False).
        It is also True if no answer is given, if the sample contains no context token at all, or if the sample is erroneous.
    erroneous : bool
        Flag that indicates whether the result of the calculation is erroneous. This might be the case if the answer
        does not start or end with the first/last character of a token, but elsewhere within the span of a token.
    """
    no_answer_result = (cls_index, cls_index, True, False)

    # if no answer is given
    if not answers["answer_start"]:
        return no_answer_result

    # answer start and end on character level (end is exclusive)
    start_char = answers["answer_start"][0]
    end_char = start_char + len(answers["text"][0])

    # positions of all tokens that belong to the context
    context_positions = [k for k, offsets in enumerate(offset_mapping) if offsets is not None]

    # A sample without any context token cannot contain the answer.
    # (Scanning for the first/last context token would otherwise run past the bounds of the list.)
    if not context_positions:
        return no_answer_result

    # start at the first / last token of the context
    start_token = context_positions[0]
    end_token = context_positions[-1]

    # check whether the answer is within the context span:
    # start_char must not lie before the first character of the first context token AND
    # end_char must not lie behind the last character of the last context token
    if start_char < offset_mapping[start_token][0] or end_char > offset_mapping[end_token][1]:
        return no_answer_result

    # move start index (token level) forward to the token whose start index on character level matches the answer start;
    # the scan stops at the first non-context token if no token matches
    while start_token < len(offset_mapping) and offset_mapping[start_token] is not None and start_char != offset_mapping[start_token][0]:
        start_token += 1

    # move end index (token level) backward to the token whose end index on character level matches the answer end;
    # the scan stops at the first non-context token if no token matches
    while end_token > 0 and offset_mapping[end_token] is not None and end_char != offset_mapping[end_token][1]:
        end_token -= 1

    # If one of the scans found no matching token, start_token ends up behind end_token.
    # This happens if tokenization is not fine-grained enough, i.e. the answer does not start or end
    # with the first/last character of a token, but elsewhere within the span of a token.
    if start_token > end_token:
        # return error result
        return cls_index, cls_index, True, True

    return start_token, end_token, False, False

def prepare_no_answerable_collection(sample_mapping):
    """
    Creates the collection that is later filled with the indices of the unanswerable tokenized samples of each original QA-sample.
    The collection maps every 'sample_index' that occurs in 'sample_mapping' to its own empty list.
    If limiting unanswerable samples is disabled ('max_no_answers_per_possible_answer' is 'None'), no collection is created.

    Parameter
    ---------
    sample_mapping
        List of indices where each item represents a tokenized sample and its value is the index of the original QA-sample

    Returns
    -------
    Dictionary that maps each original sample index to an empty list,
    or 'None' if limiting unanswerable samples is disabled (see 'max_no_answers_per_possible_answer').
    """
    # feature disabled -> nothing to collect
    if max_no_answers_per_possible_answer is None:
        return None

    # one separate empty list per distinct original QA-sample
    return {original_index: [] for original_index in sample_mapping}
    
def pick_no_answerable_samples(no_answerable_collection, original_sample_index, tokenized_examples, question_id:str):
    """
    Selects up to 'max_no_answers_per_possible_answer' unanswerable tokenized samples of the original QA-sample
    'original_sample_index'. If more candidates exist than allowed, the selection is random; otherwise all candidates are taken.
    The 'ignore' flag of every selected sample is cleared in 'tokenized_examples', and the selection is remembered
    in the global 'chosen_sample_indices' dictionary under 'question_id'.

    Parameters
    ----------
    no_answerable_collection
        Collection that contains for each original sample a list with indices of its unanswerable tokenized samples
        (see prepare_no_answerable_collection(...)). Must not be 'None'.
    original_sample_index : int
        Index of the original QA-sample
    tokenized_examples
        Tokenized samples structure created by the tokenizer; must provide an 'ignore' list
    question_id : str
        Question ID of the original QA-sample
    """
    candidates = no_answerable_collection[original_sample_index]

    if len(candidates) >= max_no_answers_per_possible_answer:
        # enough candidates: draw the allowed number at random
        picked_indices = random.sample(candidates, max_no_answers_per_possible_answer)
    else:
        # fewer candidates than allowed: take all of them
        picked_indices = candidates

    # the selected samples must not be dropped later on
    for picked_index in picked_indices:
        tokenized_examples["ignore"][picked_index] = False

    # Remember the selection for the following passes over the data (epoch > 0); an existing entry is overwritten
    if question_id in chosen_sample_indices:
        print("Warning! Collision detected for: ",question_id,", sample index is ",original_sample_index)
    chosen_sample_indices[question_id] = picked_indices

def remove_ignore_flag_on_selected_answerable_samples(question_id: str, tokenized_examples):
    """
    Clears the 'ignore' flag in 'tokenized_examples' for every tokenized sample whose index has been stored
    in the global 'chosen_sample_indices' dictionary for the passed 'question_id'.

    Parameters
    ----------
    question_id : str
        Question ID of the original QA-sample; must be a key of 'chosen_sample_indices'
    tokenized_examples
        Tokenized samples structure created by the tokenizer; must provide an 'ignore' list
    """
    stored_indices = chosen_sample_indices[question_id]
    for stored_index in stored_indices:
        # keep this sample
        tokenized_examples["ignore"][stored_index] = False
                 
def tokenize_training_samples(examples):
    """
    Tokenizes a batch of QA-samples and labels every resulting tokenized sample with the start and end position
    (token level) of the answer. A QA-sample is split into several tokenized samples with overlapping sub-contexts
    if question and context together exceed 'max_length'.

    Every tokenized sample additionally receives an 'ignore' flag that is evaluated by remove_flagged_samples(...):
    - samples that contain the answer are kept
    - unanswerable samples are kept unless 'max_no_answers_per_possible_answer' limits them; in that case only the
      samples selected for the respective QA-sample are kept
    - erroneous samples (see calc_start_end_index(...)) are always ignored

    The function reads the globals 'tokenizer', 'pad_on_right', 'max_length', 'doc_stride',
    'max_no_answers_per_possible_answer' and 'first_preprocessing_iteration', and it increases the global counter
    'total_qa_samples', which must be defined before the dataset is iterated.

    Parameters
    ----------
    examples : dict()
        Batch with the columns 'question', 'context', 'answers' and 'question_id' (see expand_samples(...))

    Returns
    -------
    Tokenized samples structure created by the tokenizer, extended by the columns 'start_positions',
    'end_positions' and 'ignore'. The column 'offset_mapping' contains the masked offset mappings.
    """

    global total_qa_samples
    total_qa_samples+=len(examples["question"])

    # Tokenizes the input dataset consisting of multiple QA-samples with truncation, padding, and overflow (stride).
    # More precisely, each QA-sample may result in multiple tokenized samples if the combination of tokenized question and context exceeds 'max_length'.
    # In this case, the sub-contexts of consecutive tokenized samples originating from the same QA-sample overlap by 'doc_stride' tokens.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # List of indices, where each item represents a tokenized sample and its value is the index of the original QA-sample, e.g. [0,0,1,2] which means that the first two tokenized samples belong to the first original QA-sample
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # prepare a collection that contains for each original sample an empty list (we will fill these lists with indices of 'no answerable samples')
    no_answerable_collection = prepare_no_answerable_collection(sample_mapping)

    # Prepare output dataset structure
    # 'start_positions' (int) contains the token index of the start of the answer for each tokenized sample. Its value is the 'cls_index' for samples without answer.
    tokenized_examples["start_positions"] = []
    # 'end_positions' (int) contains the token index of the end of the answer for each tokenized sample. Its value is the 'cls_index' for samples without answer.
    tokenized_examples["end_positions"] = []
    # Flag that indicates whether sample should be removed in the subsequent processing step (see remove_flagged_samples(...))
    tokenized_examples["ignore"] = []

    # Iterate over all tokenized samples and extract its offset_mapping
    for i, offset_mapping in enumerate(tokenized_examples["offset_mapping"]):

        # Mask offset mapping
        masked_offset_mapping = mask_offset_mapping(
            sequence_ids = tokenized_examples.sequence_ids(i),
            context_index = 1 if pad_on_right else 0,
            offset_mapping = offset_mapping)

        # Load the index of the original QA-sample
        sample_index = sample_mapping[i]

        # Load the answer
        answers = examples["answers"][sample_index]

        # Load cls_index
        cls_index = tokenized_examples["input_ids"][i].index(tokenizer.cls_token_id)

        # Calculate start and end index (on token level)
        start_index, end_index, answer_out_of_span, erroneous = calc_start_end_index(cls_index,masked_offset_mapping,answers)

        if erroneous:
            # Erroneous samples are never used for training. They still get a row in every column, because
            # position i of 'ignore', 'start_positions' and 'end_positions' must describe tokenized sample i:
            # the indices stored in 'no_answerable_collection' and 'chosen_sample_indices' rely on it.
            # They are not added to the 'no answerable samples' collection, so their flag is never cleared.
            ignore = True
        elif answer_out_of_span and max_no_answers_per_possible_answer is not None:
            # Number of 'no answerable samples' is limited: preliminarily flag the sample as 'ignore'
            # and register it as a candidate for the selection below
            ignore = True
            no_answerable_collection[sample_index].append(i)
        else:
            # answerable sample, or unanswerable sample while the limit is disabled: keep it
            ignore = False

        # append meta data to tokenized examples (start/end index equal 'cls_index' for unanswerable and erroneous samples)
        tokenized_examples["ignore"].append(ignore)
        tokenized_examples["start_positions"].append(start_index)
        tokenized_examples["end_positions"].append(end_index)
        tokenized_examples["offset_mapping"][i] = masked_offset_mapping


    # If number of 'no answerable samples' should be limited ...
    if max_no_answers_per_possible_answer is not None:
        # For each original QA-sample
        for sample_index in no_answerable_collection.keys():
            # If it is the first iteration (epoch), we will pick 'no answerable samples' randomly
            if first_preprocessing_iteration:
                pick_no_answerable_samples(no_answerable_collection,sample_index,tokenized_examples,examples["question_id"][sample_index])
            # If it is the second, third, ... iteration (epoch > 0), select 'no answerable samples' by their unique 'question_id' key
            else:
                remove_ignore_flag_on_selected_answerable_samples(examples["question_id"][sample_index],tokenized_examples)

    return tokenized_examples

In [10]:
def remove_flagged_samples(examples):
    """
    Drops every tokenized sample whose 'ignore' flag is set and reduces the batch to the columns needed for training.

    Parameters
    ----------
    examples : dict()
        Tokenized samples with the columns 'input_ids', 'attention_mask', 'start_positions', 'end_positions' and 'ignore'

    Returns
    -------
    Dictionary with the columns 'start_positions', 'end_positions', 'input_ids' and 'attention_mask'
    that contains only the samples that are not flagged as ignored (in their original order)
    """
    kept_columns = ("start_positions", "end_positions", "input_ids", "attention_mask")

    # positions of all samples that are not flagged
    kept_rows = [
        row
        for row in range(len(examples["input_ids"]))
        if not examples["ignore"][row]
    ]

    return {
        column: [examples[column][row] for row in kept_rows]
        for column in kept_columns
    }

In [11]:
# Data source (i.e. the input) of the input pipeline.
# The JSON files of all three splits are streamed instead of being loaded into memory.
dataset = load_dataset(
    'json',
    data_files={
        'train': input_path+"train/*.json",
        'test': input_path+"test/*.json",
        'validation': input_path+"validation/*.json",
    },
    streaming=True,
    keep_in_memory=None,
    cache_dir=caching_path,
)

Using custom data configuration default-423d6e444cb4ec4d


In [12]:
# Pipeline stage 1: one row per question-answer pair (see expand_samples)
expanded_dataset = dataset.map(
    expand_samples,
    batched=True,
    remove_columns=["questions"],
    batch_size=1024,
)

In [13]:
# Pipeline stage 2: tokenize the expanded rows and label them (see tokenize_training_samples)
tokenized_dataset = expanded_dataset.map(
    tokenize_training_samples,
    batched=True,
    remove_columns=["id","title","context","question","answers","question_id"],
    batch_size=32,
)

In [14]:
# Pipeline stage 3: drop all tokenized samples that are flagged as 'ignore' (see remove_flagged_samples)
final_dataset = tokenized_dataset.map(
    remove_flagged_samples,
    batched=True,
    remove_columns=["ignore","offset_mapping"],
    batch_size=1024,
)

In [15]:
# Shuffle dataset
#final_dataset = removed_flag_dataset.shuffle(buffer_size=1024*8, seed=42)

# Training

## Preparation

In [None]:
# Counters filled by one full pass over the training split
# ('total_qa_samples' is increased inside tokenize_training_samples)
total_qa_samples = 0
total_qa_samples_d = 0
total_tokenized_samples = 0
unanswerable_tokenized_samples = 0
print_counter = 0

for element in final_dataset["train"]:
    total_tokenized_samples += 1
    print_counter += 1

    # a sample is unanswerable if start and end position both point to the [CLS] token
    cls_index = element["input_ids"].index(tokenizer.cls_token_id)
    answer_span = (element["start_positions"], element["end_positions"])
    if answer_span == (cls_index, cls_index):
        unanswerable_tokenized_samples += 1

    # progress output after every 1000 tokenized samples
    if print_counter == 1000:
        print_counter = 0
        print(str(total_tokenized_samples/1000)+"K", end='\r', flush=True)

# From now on, tokenize_training_samples reuses the unanswerable samples that were selected during this first pass
first_preprocessing_iteration = False

In [None]:
# Summary of the counting pass (absolute numbers and share of all tokenized samples)
print("#QA samples:", total_qa_samples)
print("#Tokenized samples:", total_tokenized_samples)
print(
    "#Unanswerable tokenized samples:",
    unanswerable_tokenized_samples,
    " (",
    (unanswerable_tokenized_samples/total_tokenized_samples)*100,
    "%)",
)
print(
    "#Answerable tokenized samples:",
    total_tokenized_samples-unanswerable_tokenized_samples,
    " (",
    ((total_tokenized_samples-unanswerable_tokenized_samples)/total_tokenized_samples)*100,
    "%)",
)

In [None]:
# Load the base model with a question-answering head and print an overview of its layers
model = TFAutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
model.summary()

In [19]:
# Number of network updates per epoch. It is rounded up, because the training loop performs one additional
# update for the remaining samples if the number of samples is not a multiple of the batch size.
# With a rounded-down value the learning rate schedule would end before the training does.
steps_per_epoch = (total_tokenized_samples + batch_size - 1) // batch_size
total_train_steps = steps_per_epoch * num_train_epochs

# Optimizer with a learning rate schedule over all training steps (no warm-up).
# The configured 'weight_decay' is passed on explicitly; the default rate of create_optimizer is 0.0.
optimizer, schedule = create_optimizer(
    init_lr=learning_rate,
    num_warmup_steps=0,
    num_train_steps=total_train_steps,
    weight_decay_rate=weight_decay,
)

In [None]:
# No loss is passed here; NOTE(review): presumably the model's built-in loss is used — confirm
model.compile(optimizer=optimizer)

In [None]:
# Log in to Weights & Biases and start a run that records all settings and dataset statistics
# NOTE(review): do not commit a real API key in the notebook
wandb.login(key="")
wandb.init(
    project=project_name,
    name=run_name,
    config=dict(
        input_path=input_path,
        checkpoint_path=checkpoint_path,
        model_checkpoint=model_checkpoint,
        max_length=max_length,
        doc_stride=doc_stride,
        max_no_answers_per_possible_answer=max_no_answers_per_possible_answer,
        batch_size=batch_size,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        weight_decay=weight_decay,
        xla=xla,
        total_qa_samples=total_qa_samples,
        total_tokenized_samples=total_tokenized_samples,
        unanswerable_tokenized_samples=unanswerable_tokenized_samples,
        answerable_tokenized_samples=total_tokenized_samples-unanswerable_tokenized_samples,
    ),
)


## Training

In [None]:
# Columns of a tokenized sample that are fed into the model
batch_keys = ("attention_mask", "input_ids", "start_positions", "end_positions")

# Number of network updates per epoch (rounded up: a final partial batch is trained as well)
steps_per_epoch = (total_tokenized_samples + batch_size - 1) // batch_size

for i in range(num_train_epochs):
    with tqdm(total=steps_per_epoch, desc="Epoch "+str(i+1)+"/"+str(num_train_epochs)) as pbar:
        total_step_counter = 0        # number of network updates in this epoch
        batch_counter = 0             # number of samples collected for the current batch
        loss = 0                      # loss of the most recent network update
        sample_counter_per_epoch = 0  # number of samples seen in this epoch

        # Samples of the current batch are collected as plain Python lists and converted into tensors
        # only once per batch (instead of re-copying growing tensors with tf.concat for every sample)
        batch_rows = {key: [] for key in batch_keys}

        for example in final_dataset["train"]:
            for key in batch_keys:
                batch_rows[key].append(example[key])

            batch_counter+=1
            sample_counter_per_epoch+=1

            if batch_counter == batch_size:
                batch = {key: tf.constant(rows) for key, rows in batch_rows.items()}
                loss = model.train_on_batch(x=batch)
                batch_counter = 0
                total_step_counter+=1
                pbar.update(1)
                batch_rows = {key: [] for key in batch_keys}

        # (added in V4.2) final network update if there are untrained samples left
        if batch_counter > 0:
            print("Final network update with remaining '",batch_counter,"' samples")
            batch = {key: tf.constant(rows) for key, rows in batch_rows.items()}
            loss = model.train_on_batch(x=batch)
            total_step_counter+=1
            pbar.update(1)

        # the logged loss is the loss of the last network update of the epoch
        print("Loss ",loss)
        wandb.log({"loss": loss, "epoch":(i+1), "samples": sample_counter_per_epoch})

        # store one checkpoint per epoch; the directory name consists of the zero-based epoch index and the loss
        model.save_pretrained(checkpoint_path+str(i)+"_"+str(loss))
        print("Saving model at "+checkpoint_path+str(i)+"_"+str(loss))

In [None]:
# Close the Weights & Biases run that was started with wandb.init(...)
wandb.finish()