In [40]:
import torch
import re
from typing import Callable
from transformers import BertTokenizer
from masking_strategies import mask_random_positions, mask_end_positions

In [41]:
def assure_CLS_and_SEP_false(id_tensor, true_false_tensor, tokenizer):
    """
    Builds a new masking tensor from true_false_tensor in which
    every position that holds a [CLS] or [SEP] id in id_tensor is
    forced to False; all other positions keep their value.

    Neither input is modified: the `&` below allocates a fresh
    tensor, so the result never aliases true_false_tensor.
    """
    cls_id, sep_id = (
        tokenizer.convert_tokens_to_ids(token) for token in ("[CLS]", "[SEP]")
    )

    # positions occupied by one of the two special tokens
    is_special = (id_tensor == cls_id) | (id_tensor == sep_id)

    return true_false_tensor & ~is_special

In [42]:
def mask(masking_strategy, tensor, tokenizer):
    """
    Replaces, in place, the token ids of tensor chosen by
    masking_strategy with the id of the [MASK] token.

    Positions holding [CLS] or [SEP] are never replaced, even if
    the strategy selects them.

    Args:
        masking_strategy: callable that receives tensor and returns
                          a tensor of the same shape whose truthy
                          entries mark the positions to mask.
        tensor: tensor of token ids, one row per tokenized
                sentence. It is mutated.
        tokenizer: tokenizer used to look up the special token ids.

    Returns:
        None

    Raises:
        ValueError: if the tokenizer has no id for "[MASK]" (the
                    lookup returns None or the unknown-token id).
    """
    mask_arr = masking_strategy(tensor)
    mask_arr = assure_CLS_and_SEP_false(tensor, mask_arr, tokenizer)

    # BERT spells its mask token "[MASK]". Asking for a token that is not
    # in the vocabulary (e.g. "[MASKED]") silently yields the unknown-token
    # id, which would corrupt the inputs with [UNK] instead of [MASK].
    mask_id = tokenizer.convert_tokens_to_ids("[MASK]")
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if mask_id is None or (unk_id is not None and mask_id == unk_id):
        raise ValueError("the tokenizer does not define a [MASK] token")

    # one vectorised assignment instead of a Python loop over the rows;
    # .bool() guarantees mask semantics even if the strategy returned 0/1 ints
    tensor[mask_arr.bool()] = mask_id

In [43]:
def save_labels_and_ids(
    sentences: list[str],
    path_to_save: str,
    masking_strategy: Callable[[torch.Tensor], torch.Tensor],
    tokenizer,
) -> None:
    """
    Tokenizes the sentences, masks the resulting input_ids and
    saves the dictionary of input_ids, labels and attention_mask
    to disk at the specified path.

    Args:
        sentences: list of all the sentences to tokenize.
        path_to_save: path of the file written with torch.save.
        masking_strategy: callable that receives the input_ids
                          tensor and returns the tensor marking
                          the positions to mask; it is forwarded
                          to mask().
        tokenizer: tokenizer to call on the sentences; it is also
                   forwarded to mask().

    Raises:
        AssertionError: if masking did not change any input id.
    """
    inputs = tokenizer(
        sentences,
        return_tensors="pt",
        max_length=tokenizer.model_max_length,
        padding=True,
        truncation=True,
    )

    input_ids = inputs["input_ids"]

    # "labels" is the answer key: an untouched copy of the token ids
    labels = input_ids.detach().clone()

    # corrupt "input_ids" (in place)
    mask(masking_strategy, input_ids, tokenizer)

    # sanity check: at least one position must have been masked
    assert not torch.equal(labels, input_ids), "masking left input_ids unchanged"

    # finally save the object to disk
    torch.save(
        {
            "labels": labels,
            "input_ids": input_ids,
            "attention_mask": inputs["attention_mask"],
        },
        path_to_save,
    )

In [44]:
# load huge file full of sentences (one sentence per line)
# The encoding is stated explicitly so the accented Portuguese characters
# are decoded the same way on every platform instead of depending on the
# locale default (assumes the file is UTF-8 -- TODO confirm).
with open('../../data/portuguese_sentences.txt', 'r', encoding='utf-8') as f:
    # every element keeps its trailing newline, exactly as read
    sentences = list(f)

In [None]:
# get rid of empty sentences
# Lines read from the file still carry their "\n", so a blank line is the
# truthy string "\n"; strip before testing so blank (or whitespace-only)
# lines are actually dropped.
sentences = [sent for sent in sentences if sent.strip()]

In [45]:
# how many sentences are available for tokenization
n_sentences = len(sentences)
print(n_sentences)

2149297


# Tokenizer

In [46]:
# name of the pretrained checkpoint whose vocabulary the tokenizer loads
PRETRAINED_NAME = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)

# 15% of random positions

In [47]:
# Seed the RNG so a fresh "Restart & Run All" writes the same masked dataset.
# NOTE(review): assumes mask_random_positions draws from torch's global RNG;
# if it uses numpy or `random`, seed that library here instead -- TODO confirm.
RANDOM_MASKING_SEED = 42
torch.manual_seed(RANDOM_MASKING_SEED)

save_labels_and_ids(
    sentences,
    "../../data/training_tokenized_random_positions.pt",
    mask_random_positions,
    tokenizer,
)

# End tokens

I want to truncate some of the sentences at word boundaries so that the model also gets used to incomplete sentences, i.e. sentences that are cut off before their end.

In [None]:
def truncated_prefixes(sentence):
    """
    Returns the prefixes of sentence obtained by cutting it right
    before the words located at 1/3, 2/3 and 3/3 of its word count.

    A cut index that falls past the last word is clamped to the
    last word, repeated cut points are used only once and a cut
    before the first word is skipped, so every returned prefix
    contains at least one whole word. A sentence without words
    yields an empty list.
    """
    words = list(re.finditer(r'\w+', sentence))
    if not words:
        return []

    third_part_words = len(words)//3
    last_word = len(words) - 1

    # 3*third_part_words equals len(words) when the count is divisible by 3,
    # which is one past the last valid index: clamp to the last word instead
    cut_points = sorted(
        {min(k*third_part_words, last_word) for k in (1, 2, 3)}
    )

    return [sentence[:words[ix].start()] for ix in cut_points if ix > 0]


# only the first third of the sentences gets chopped versions
third_part = len(sentences)//3
new_sentences = [
    prefix
    for sentence in sentences[:third_part]
    for prefix in truncated_prefixes(sentence)
]

In [None]:
# the full sentences of the first third followed by their chopped versions
end_position_sentences = sentences[:third_part] + new_sentences

save_labels_and_ids(
    sentences=end_position_sentences,
    path_to_save="../../data/training_tokenized_end_positions.pt",
    masking_strategy=mask_end_positions,
    tokenizer=tokenizer,
)