[Reference](https://pub.towardsai.net/token-masking-strategies-for-llms-d2e6c926b22d)

In [None]:
pip install stanza

Collecting stanza
  Downloading stanza-1.8.1-py3-none-any.whl (970 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m970.4/970.4 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting emoji (from stanza)
  Downloading emoji-2.11.0-py2.py3-none-any.whl (433 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m433.8/433.8 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.3.0->stanza)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.3.0->stanza)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nv

In [None]:
import stanza
# Download the English ('en') stanza resources; the 'en' pipeline is built further down in this cell.
stanza.download('en')

# Text used in our examples.
# Written as adjacent string literals inside parentheses: a plain
# double-quoted literal cannot span several physical lines in Python.
# The pieces are joined with single spaces, so the text contains no newlines.
text = (
    "Huntington's disease is a neurodegenerative autosomal disease "
    "results due to expansion of polymorphic CAG repeats in the huntingtin gene. "
    "Phosphorylation of the translation initiation factor 4E-BP results in the "
    "alteration of the translation control leading to unwanted protein synthesis "
    "and neuronal function. Consequences of mutant huntington (mhtt) gene "
    "transcription are not well known. Variability of age of onset is an "
    "important factor of Huntington's disease separating adult and juvenile types. "
    "The factors which are taken into account are-genetic modifiers, maternal "
    "protection i.e excessive paternal transmission, superior ageing genes "
    "and environmental threshold. A major focus has been given to the molecular "
    "pathogenesis which includes-motor disturbance, cognitive disturbance and "
    "neuropsychiatric disturbance. The diagnosis part has also been taken care of. "
    "This includes genetic testing and both primary and secondary symptoms. "
    "The present review also focuses on the genetics and pathology of Huntington's "
    "disease."
)


# Run the example text through an English stanza pipeline (GPU disabled)
# and keep the raw text of every sentence it returns, one list entry each.
nlp = stanza.Pipeline('en', use_gpu=False)
doc = nlp(text)
sentences = [parsed_sentence.text for parsed_sentence in doc.sentences]

In [None]:
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
import torch

def load_dataset_mlm(sentences, tokenizer_class=AutoTokenizer,
                     collator_class=DataCollatorForLanguageModeling,
                     mlm=True, mlm_probability=0.20):
    """Tokenize ``sentences`` and pass the token ids through a masking collator.

    The tokenizer is loaded from 'google-bert/bert-base-uncased' through
    ``tokenizer_class.from_pretrained`` and is called with padding and
    truncation enabled, returning PyTorch tensors.

    Args:
        sentences: Text handed directly to the tokenizer.
        tokenizer_class: Class providing ``from_pretrained``.
        collator_class: Collator class; it is constructed with ``tokenizer``,
            ``mlm`` and ``mlm_probability`` keyword arguments.
        mlm: Forwarded unchanged to the collator.
        mlm_probability: Forwarded unchanged to the collator.

    Returns:
        ``(input_ids, attention_mask, labels)`` where ``input_ids`` and
        ``labels`` are read from the collator's batch and ``attention_mask``
        is the one produced by the tokenizer.
    """
    bert_tokenizer = tokenizer_class.from_pretrained('google-bert/bert-base-uncased')
    encoded = bert_tokenizer(sentences, truncation=True, padding=True,
                             return_tensors='pt')

    # Random masking configuration
    masking_collator = collator_class(tokenizer=bert_tokenizer, mlm=mlm,
                                      mlm_probability=mlm_probability)

    # The collator is fed a tuple with one tensor per sentence: cut the id
    # matrix into single-row chunks and drop the leading dimension of each.
    per_sentence_ids = tuple(
        chunk.squeeze(0)
        for chunk in torch.split(encoded['input_ids'], 1, dim=0)
    )

    collated = masking_collator(per_sentence_ids)
    return collated['input_ids'], encoded['attention_mask'], collated['labels']


# Run the loader defined in this cell on the example sentences.
input_ids, attention_mask, labels = load_dataset_mlm(sentences)

In [None]:
from transformers import BartTokenizer, DataCollatorForLanguageModeling
import torch

def load_dataset_mlm(sentences, tokenizer_class=BartTokenizer,
                     collator_class=DataCollatorForLanguageModeling,
                     mlm=True, mlm_probability=0.20):
    """Build masked inputs with a 'facebook/bart-base' tokenizer.

    Same flow as a regular masked-LM loader, except that the labels are the
    token ids exactly as the tokenizer produced them (padding positions
    included) instead of the labels generated by the collator.

    Args:
        sentences: Text handed directly to the tokenizer.
        tokenizer_class: Class providing ``from_pretrained``.
        collator_class: Collator class; it is constructed with ``tokenizer``,
            ``mlm`` and ``mlm_probability`` keyword arguments.
        mlm: Forwarded to the collator (True for Masked Language Modelling).
        mlm_probability: Forwarded to the collator (chance for every token
            to get masked).

    Returns:
        ``(input_ids, attention_mask, labels)``: the collator's ``input_ids``,
        the tokenizer's attention mask and the tokenizer's ``input_ids``.
    """
    bart_tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
    encoded = bart_tokenizer(sentences, truncation=True, padding=True,
                             return_tensors='pt')

    # Random masking configuration
    masking_collator = collator_class(tokenizer=bart_tokenizer, mlm=mlm,
                                      mlm_probability=mlm_probability)

    # The collator is fed a tuple holding one 1-D id tensor per sentence.
    per_sentence_ids = encoded['input_ids'].unbind(0)
    collated = masking_collator(per_sentence_ids)

    # Labels are the tokenizer output, not the labels built by the collator.
    return collated['input_ids'], encoded['attention_mask'], encoded['input_ids']

# Run the BART loader defined in this cell on the example sentences.
input_ids, attention_mask, labels = load_dataset_mlm(sentences)


In [None]:
def token_deletion(sentences, tokenizer_class=BartTokenizer, collator_class=DataCollatorForLanguageModeling,
                 mlm=True, mlm_probability=0.20):
    """Build token-deletion training tensors with a 'facebook/bart-base' tokenizer.

    The collator first marks tokens with the tokenizer's mask token; every
    position holding that mask token is then removed from the input row and
    the row is right-padded back to its original width.

    Args:
        sentences: Text handed directly to the tokenizer.
        tokenizer_class: Class providing ``from_pretrained``.
        collator_class: Collator class; it is constructed with ``tokenizer``,
            ``mlm`` and ``mlm_probability`` keyword arguments.
        mlm: Forwarded unchanged to the collator.
        mlm_probability: Forwarded unchanged to the collator.

    Returns:
        ``(input_ids, attention_mask, labels)``:
        ``input_ids`` are the rows with mask tokens deleted and padding
        appended, ``attention_mask`` is 0 on every padding id and 1 elsewhere,
        and ``labels`` are the token ids as produced by the tokenizer, before
        any masking or deletion.
    """
    # `torch` is imported by the notebook; `F` was never imported, so the pad
    # function is reached through the `torch` namespace.
    pad = torch.nn.functional.pad

    tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
    inputs = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)

    data_collator = collator_class(
        tokenizer=tokenizer,
        mlm=mlm,
        mlm_probability=mlm_probability
    )

    # We use the initial inputs as labels. They are cloned before the collator
    # runs so the labels can never contain mask tokens.
    labels = inputs['input_ids'].clone()

    # The collator is fed a tuple holding one 1-D id tensor per sentence.
    batch = data_collator(tuple(inputs['input_ids'].unbind(0)))

    # Take the special ids from the tokenizer instead of hard-coding them.
    mask_token_id = tokenizer.mask_token_id
    pad_token_id = tokenizer.pad_token_id

    masked_ids = batch['input_ids']
    row_width = masked_ids.size(1)

    # Drop every mask token, then fill the freed space at the end of the row
    # with the padding token so that all rows keep the same size.
    rows = []
    for row in masked_ids:
        kept = row[row != mask_token_id]
        rows.append(pad(kept, (0, row_width - kept.size(0)), value=pad_token_id))
    input_ids = torch.stack(rows)

    # Padding positions (old and newly added) must not be attended to.
    attention_mask = (input_ids != pad_token_id).to(inputs['attention_mask'].dtype)

    return input_ids, attention_mask, labels

# Apply token deletion to the example sentences.
input_ids, attention_mask, labels = token_deletion(sentences)

In [None]:
import numpy as np
from transformers import BartTokenizer

def text_infilling(sentence, probability=0.2, poisson_lambda=3):
    """Replace random spans of tokens with a single "<mask>" each.

    Every token position is chosen as a span start with ``probability``. For
    each chosen start a span length is drawn from a Poisson distribution with
    mean ``poisson_lambda`` and that many tokens (clipped at the end of the
    sentence) are masked. A drawn length of 0 masks nothing. Runs of
    consecutive masked tokens are collapsed into one "<mask>".

    Args:
        sentence: List of string tokens, or a plain string which is split on
            whitespace. The argument itself is never modified.
        probability: Chance for each position to start a masked span.
        poisson_lambda: Mean of the Poisson distribution for span lengths.

    Returns:
        The infilled tokens joined by single spaces.
    """
    # Work on a private copy so the caller's token list stays intact.
    tokens = sentence.split() if isinstance(sentence, str) else list(sentence)
    n_tokens = len(tokens)

    # Binary mask telling which positions start a span
    span_starts = np.random.choice([0, 1], size=n_tokens,
                                   p=[1 - probability, probability])

    # Mask a Poisson-distributed number of tokens from every chosen start
    for start in np.flatnonzero(span_starts):
        start = int(start)
        span_length = np.random.poisson(poisson_lambda)
        for position in range(start, min(start + span_length, n_tokens)):
            tokens[position] = "<mask>"

    # Keep only one "<mask>" per run of consecutive masked tokens
    infilled_sentence = []
    for token in tokens:
        if token == "<mask>" and infilled_sentence and infilled_sentence[-1] == "<mask>":
            continue
        infilled_sentence.append(token)
    return " ".join(infilled_sentence)

def text_infilling_input(masked_sentences, sentences, tokenizer_class=BartTokenizer):
    """Tokenize masked sentences as model inputs and the originals as labels.

    Both calls use a 'facebook/bart-base' tokenizer with padding and
    truncation enabled and return PyTorch tensors; the two batches are
    tokenized (and therefore padded) independently of each other.

    Returns:
        ``(input_ids, attention_mask, labels)``: ids and attention mask of
        ``masked_sentences`` and the ids of ``sentences``.
    """
    bart_tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
    tokenizer_options = dict(return_tensors='pt', padding=True, truncation=True)

    encoded_inputs = bart_tokenizer(masked_sentences, **tokenizer_options)
    encoded_labels = bart_tokenizer(sentences, **tokenizer_options)

    return (encoded_inputs['input_ids'],
            encoded_inputs['attention_mask'],
            encoded_labels['input_ids'])

# `masked_sentences` was never defined in the visible cells: build it by
# masking the whitespace-split tokens of every example sentence.
# NOTE(review): whitespace splitting is an assumption; confirm the intended tokenization.
masked_sentences = [text_infilling(sentence.split()) for sentence in sentences]
input_ids, attention_mask, labels = text_infilling_input(masked_sentences, sentences)

In [None]:
# It selects the first "number_sentences" within a given set of "sentences"
# and returns those sentences in a random order.
def sentence_permutation(sentences, number_sentences):
    """Return the first ``number_sentences`` sentences shuffled and joined.

    Args:
        sentences: Sequence of sentence strings; it is not modified.
        number_sentences: How many leading sentences to take.

    Returns:
        The selected sentences in random order, separated by single spaces.
    """
    # Local import: `random` is not imported in the visible notebook cells.
    import random

    # Shuffle a private copy so the caller's sequence keeps its order.
    new_sentences = list(sentences[:number_sentences])
    random.shuffle(new_sentences)
    # Joined inline so this cell does not depend on a helper defined in a later cell.
    return ' '.join(new_sentences)

def permuted_data_generation(sentences: list, total_sentences: int):
    """Create (permuted, original) text pairs from a sliding sentence window.

    The window of ``total_sentences`` sentences moves forward one sentence at
    a time, which yields ``len(sentences) - total_sentences + 1`` examples
    (none when that number is not positive).

    Args:
        sentences: List of sentence strings; it is not modified.
        total_sentences: Number of sentences in every window.

    Returns:
        ``(training_sentences, training_labels)``: for every window, the
        shuffled sentences joined into one string and the same sentences
        joined in their original order.
    """
    training_sentences = []
    training_labels = []
    for start in range(len(sentences) - total_sentences + 1):
        # Everything from the current window start onwards; the oldest
        # sentence is dropped on every step.
        remaining = sentences[start:]
        training_sentences.append(sentence_permutation(remaining, total_sentences))
        # Joined inline so this cell does not depend on a helper defined in a later cell.
        training_labels.append(' '.join(remaining[:total_sentences]))

    return training_sentences, training_labels


def permutation_training(sentences: list, sentences_labels: list,
                         tokenizer_class=BartTokenizer,
                         collator_class=DataCollatorForLanguageModeling,
                         mlm=True, mlm_probability=0.0):
    """Turn permuted sentences and their originals into training tensors.

    Both lists go through ``load_dataset_mlm`` with the same tokenizer class,
    collator class, ``mlm`` flag and ``mlm_probability``.

    Returns:
        ``(input_ids, attention_mask, labels)``: the first two values that
        ``load_dataset_mlm`` returns for ``sentences`` and the first value it
        returns for ``sentences_labels``, each passed through ``squeeze(0)``.
    """
    loader_args = (tokenizer_class, collator_class, mlm, mlm_probability)

    # Input ids and attention mask come from the permuted sentences
    permuted_ids, permuted_mask, _ = load_dataset_mlm(sentences, *loader_args)

    # Labels come from the sentences in their original order
    original_ids, _, _ = load_dataset_mlm(sentences_labels, *loader_args)

    return permuted_ids.squeeze(0), permuted_mask.squeeze(0), original_ids.squeeze(0)

# `training_sentences` / `training_labels_sentences` were never defined in the
# visible cells: generate them from the example sentences first.
# NOTE(review): the window of 3 sentences is an assumption; adjust as needed.
training_sentences, training_labels_sentences = permuted_data_generation(sentences, 3)
input_ids, attention_mask, labels = permutation_training(training_sentences, training_labels_sentences)

In [None]:
def sentence_joiner(sentences: list):
  """Concatenate the given sentences into one string, separated by single spaces."""
  separator = ' '
  return separator.join(sentences)

# With this function we gather as many sentences as we want to form the input data to the tokenizer.
def rotated_data_generation(sentences: list, total_sentences: int):
  """Join every run of ``total_sentences`` consecutive sentences into one text.

  The run moves forward one sentence at a time, so the result has
  ``len(sentences) - total_sentences + 1`` entries (none when that number
  is not positive). ``sentences`` itself is left untouched.
  """
  pool = sentences.copy()
  window_count = len(sentences) - total_sentences + 1
  return [
    sentence_joiner(pool[offset:][:total_sentences])
    for offset in range(window_count)
  ]

# Apply this function over the rotated sentences from previous function
def document_rotation_training(sentences, tokenizer_class=BartTokenizer):
  """Build document-rotation training tensors with a 'facebook/bart-base' tokenizer.

  For every row, the tokens between the first token and the last non-padding
  token are rotated so that a randomly chosen token comes first. The first
  token, the last non-padding token and all padding stay where they are.

  Args:
    sentences: Text handed directly to the tokenizer.
    tokenizer_class: Class providing ``from_pretrained``.

  Returns:
    ``(input_ids, attention_mask, labels)``: the rotated ids, the tokenizer's
    attention mask and a copy of the ids taken before rotation. All three
    keep the batch dimension, also for a single document.
  """
  tokenizer = tokenizer_class.from_pretrained('facebook/bart-base')
  tokens = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)
  input_ids = tokens['input_ids']
  attention_mask = tokens['attention_mask']
  # Labels are the ids before any rotation
  labels = input_ids.clone()

  for i in range(input_ids.size(0)):
    row = input_ids[i]
    # Number of non-padding tokens; with padding appended on the right this
    # is also the position where padding starts (it equals the row length
    # when the row has no padding).
    # NOTE(review): assumes the tokenizer pads on the right — confirm.
    real_length = int(attention_mask[i].sum())
    # Rows without at least one token between the first and the last
    # non-padding token have nothing to rotate.
    if real_length < 3:
      continue
    last_real = real_length - 1
    # The pivot is drawn per row from the positions strictly between the
    # first and the last non-padding token (upper bound is exclusive).
    pivot = torch.randint(1, last_real, (1,)).item()
    input_ids[i] = torch.cat((row[:1],               # initial token
                              row[pivot:last_real],  # from pivot to the last content token
                              row[1:pivot],          # from 1 to pivot
                              row[last_real:]), 0)   # final token and padding, unchanged
  return input_ids, attention_mask, labels

# Group the example sentences into runs of 3 consecutive sentences,
# then build the rotation tensors from those texts.
data = rotated_data_generation(sentences, 3)
input_ids, attention_mask, labels = document_rotation_training(data)