In [None]:
from collections.abc import Callable

import torch

from masking_strategies import mask_random_positions, mask_end_positions

In [None]:
def assure_CLS_and_SEP_false(id_tensor, true_false_tensor):
    """
    Switch off the mask wherever id_tensor holds a CLS or SEP id.

    Args:
        id_tensor: tensor of token ids.
        true_false_tensor: masking tensor of True/False values,
                           combined element-wise with id_tensor.

    Returns:
        A new tensor equal to true_false_tensor, except that every
        position whose id is CLS or SEP is set to False.

    Neither input is modified, and the result shares no storage
    with them.
    """
    is_special_token = (id_tensor == CLS) | (id_tensor == SEP)
    # work on a private copy so the caller's mask is never touched
    private_mask = true_false_tensor.detach().clone()
    return private_mask & ~is_special_token

In [None]:
def mask(masking_strategy, tensor):
    """
    Mutates tensor according to the masking_strategy.

    Every position selected by the strategy is overwritten in place
    with MASKED, except positions holding a CLS or SEP id, which are
    always left untouched.

    Args:
        masking_strategy: callable that receives tensor and returns
                          a mask whose nonzero (True) entries mark
                          the positions to corrupt. The mask is
                          expected to have the same shape as tensor.
        tensor: tensor of token ids, one row per tokenized input.
                It is modified in place.

    Returns:
        None
    """
    mask_arr = masking_strategy(tensor)
    mask_arr = assure_CLS_and_SEP_false(tensor, mask_arr)

    # A single boolean-mask assignment replaces the per-row Python loop
    # (and its nonzero()/tolist() round trips). It also stays correct for
    # tensors with more than two dimensions, where flattening per-row
    # nonzero() coordinates would mix up the axes.
    # .bool() keeps the "nonzero means masked" rule for non-bool masks.
    tensor[mask_arr.bool()] = MASKED

In [None]:
def save_labels_and_ids(
    sentences: list[str],
    path_to_save: str,
    masking_strategy: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """
    Tokenizes the sentences, masks the input ids and saves the
    dictionary of input_ids, labels and attention_mask to a file
    at the specified path.

    Args:
        sentences: raw sentences; empty strings are dropped before
                   tokenization.
        path_to_save: destination handed to torch.save.
        masking_strategy: strategy forwarded to mask() to decide
                          which input_ids get corrupted.

    Returns:
        None

    Relies on the module-level tokenizer, which is defined outside
    this function.
    """

    # get rid of empty sentences
    sentences = [sent for sent in sentences if sent]

    inputs = tokenizer(
        sentences,
        return_tensors="pt",
        max_length=tokenizer.model_max_length,
        padding=True,
        truncation=True,
    )

    # "labels" is the answer key: an untouched copy taken before masking
    inputs["labels"] = inputs["input_ids"].detach().clone()

    # corrupt "input_ids" in place
    mask(masking_strategy, inputs["input_ids"])

    # save to the path the caller asked for (the original referenced an
    # undefined name `path` instead of the `path_to_save` parameter)
    torch.save(
        {
            "labels": inputs["labels"],
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        },
        path_to_save,
    )