In [4]:
!pip install datasets
#!pip install evaluate
#!pip install rouge_score
!pip install wandb
!pip install transformers==4.48.3
#!pip install datasets==2.15.0
#!pip install evaluate==0.4.1
!pip install accelerate==0.27.2
#!pip install tqdm==4.66.1
!pip install peft==0.10.0
!pip install "numpy<2.0"

import transformers
import torch
import os

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
from datasets import load_dataset, concatenate_datasets
from time import time
import random


def sentence_permutation(text: str) -> str:
    """
    Split a document into sentences on full stops and return them in a random order.
    **This function operates on text strings.**

    Every fragment is stripped of surrounding whitespace, and empty or whitespace-only
    fragments are dropped, so the result never contains doubled spaces or a dangling
    " ." produced by trailing whitespace. Each kept sentence is terminated with a
    full stop, including a last sentence that had none in the input.

    :param text: The document whose sentences are to be permuted.
    :return: The sentences in shuffled order, joined by single spaces.
    """
    fragments = (fragment.strip() for fragment in text.split("."))
    sentences = [sentence for sentence in fragments if sentence]
    # torch.randperm keeps the shuffle under the control of torch's RNG seed.
    order = torch.randperm(len(sentences)).tolist()
    return " ".join(sentences[index] + "." for index in order)


def token_infilling(
    tokenized_sequence: torch.Tensor,
    mask_token_id: int,
    mask_probability: float = 0.15,
    list_special_tokens: "list | None" = None,
) -> torch.Tensor:
    """
    Replace randomly chosen spans of tokens with a single mask token each.
    **This function operates on tokenized text.**

    Text infilling is inspired by SpanBERT (Joshi et al., 2019) and teaches the model to
    predict how many tokens are missing from a span. This implementation is a
    simplification of the BART scheme: ONE span length is drawn from a Poisson
    distribution (lambda = 3) per call, the sequence is cut into consecutive chunks of
    that length, and each chunk is independently replaced by one mask token with
    probability ``mask_probability``. A drawn length of 0 leaves the sequence untouched
    (no mask insertion is performed).

    Chunks that contain any special token are never masked; they are copied through
    unchanged so that BOS/EOS/padding tokens are not lost.

    :param tokenized_sequence: 1-D tensor of token ids.
    :param mask_token_id: Id of the mask token that replaces a selected span.
    :param mask_probability: Probability with which each span is replaced.
    :param list_special_tokens: Token ids that must never be masked (default: none).
    :return: 1-D tensor of token ids, never longer than the input.
    """
    protected_ids = set(list_special_tokens or ())
    span_length = int(torch.poisson(torch.tensor([3.0])))
    if span_length == 0:
        # A zero-length draw means the text is not perturbed.
        return tokenized_sequence

    mask_piece = torch.tensor([mask_token_id], dtype=torch.long)
    pieces = []
    for start in range(0, len(tokenized_sequence), span_length):
        span = tokenized_sequence[start : start + span_length]
        selected = bool(torch.rand(1) < mask_probability)
        if selected and protected_ids.isdisjoint(span.tolist()):
            pieces.append(mask_piece)
        else:
            # Either not selected, or selected but protected by a special token:
            # keep the span as it is instead of dropping it.
            pieces.append(span)

    if not pieces:
        return torch.empty(0, dtype=torch.long)
    # Concatenate once at the end rather than growing a tensor inside the loop.
    return torch.cat(pieces)


def token_masking(
    tokenized_sequence: torch.Tensor,
    mask_token_id: int,
    mask_probability: float = 0.15,
    list_special_tokens: "list | None" = None,
) -> torch.Tensor:
    """
    Replace random tokens with the mask token.
    **This function operates on tokenized text.**

    This task trains the model to predict the original value of the masked tokens.
    Each position is selected independently with probability ``mask_probability``;
    positions holding a special token are never masked. The input tensor is left
    untouched and a new tensor is returned.

    :param tokenized_sequence: 1-D tensor of token ids.
    :param mask_token_id: Id written into every selected position.
    :param mask_probability: Per-token probability of being masked.
    :param list_special_tokens: Token ids that must never be masked (default: none).
    :return: A new 1-D tensor of the same length with the selected tokens masked.
    """
    device = tokenized_sequence.device
    selected = torch.rand(len(tokenized_sequence), device=device) < mask_probability
    if list_special_tokens:
        special_ids = torch.tensor(
            list(list_special_tokens), dtype=tokenized_sequence.dtype, device=device
        )
        selected &= ~torch.isin(tokenized_sequence, special_ids)
    # masked_fill allocates a new tensor, so the caller's sequence is not mutated.
    return tokenized_sequence.masked_fill(selected, mask_token_id)


def token_deletion(
    tokenized_sequence: torch.Tensor,
    mask_token_id: int,
    mask_probability: float = 0.15,
    list_special_tokens: "list | None" = None,
) -> torch.Tensor:
    """
    Delete random tokens from the input.
    **This function operates on tokenized text.**

    In contrast to token masking, the model must decide which positions are missing
    inputs. Each position is selected for deletion independently with probability
    ``mask_probability``; positions holding a special token are never deleted, so
    sequence delimiters survive the perturbation.

    :param tokenized_sequence: 1-D tensor of token ids.
    :param mask_token_id: Unused; accepted so that all token-level perturbations
        share the same call signature.
    :param mask_probability: Per-token probability of being deleted.
    :param list_special_tokens: Token ids that must never be deleted (default: none).
    :return: A new 1-D tensor containing the surviving tokens in their original order.
    """
    device = tokenized_sequence.device
    delete_mask = torch.rand(len(tokenized_sequence), device=device) < mask_probability
    if list_special_tokens:
        special_ids = torch.tensor(
            list(list_special_tokens), dtype=tokenized_sequence.dtype, device=device
        )
        delete_mask &= ~torch.isin(tokenized_sequence, special_ids)
    return tokenized_sequence[~delete_mask]


def document_rotation(text: str) -> str:
    """
    Rotate a document so that it starts at a randomly chosen word.
    **This function operates on text strings.**

    The text is split on single spaces, one of the resulting pieces is picked uniformly
    at random, and everything before it is moved to the end. This task trains the model
    to identify the start of the document.

    :param text: The text to be rotated.
    :return: The rotated text, re-joined with single spaces.
    """
    words = text.split(" ")
    pivot = random.randint(0, len(words) - 1)
    head, tail = words[:pivot], words[pivot:]
    return " ".join([*tail, *head])


# MODEL HYPER-PARAMETERS (labelled "BART BASE" by the author)
# All of these are passed to BartConfig below.
# ==============================================================================
# NOTE(review): presumably has to match the vocabulary size of the tokenizer loaded below — confirm.
VOCAB_SIZE = 52000
# Also used by collate_fn as the padding/truncation length of every example.
MAX_POSITION_EMBEDDINGS = 1024
ENCODER_LAYERS = 6
ENCODER_FFN_DIM = 3072
ENCODER_ATTENTION_HEADS = 12
DECODER_LAYERS = 6
DECODER_FFN_DIM = 3072
DECODER_ATTENTION_HEADS = 12
D_MODEL = 768
DROPOUT = 0.1
# ==============================================================================
# END OF MODEL HYPER-PARAMETERS



# Load the tokenizer from Google Drive (requires the drive to be mounted first).
tokenizer = BartTokenizer.from_pretrained("/content/drive/MyDrive/Colab Notebooks/tokenizer_bart_la_m")


# Build the model from the configuration above. It is constructed from a BartConfig
# rather than from_pretrained, so no pretrained weights are loaded here.
model = BartForConditionalGeneration(
    BartConfig(
        vocab_size=VOCAB_SIZE,
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        encoder_layers=ENCODER_LAYERS,
        encoder_ffn_dim=ENCODER_FFN_DIM,
        encoder_attention_heads=ENCODER_ATTENTION_HEADS,
        decoder_layers=DECODER_LAYERS,
        decoder_ffn_dim=DECODER_FFN_DIM,
        decoder_attention_heads=DECODER_ATTENTION_HEADS,
        d_model=D_MODEL,
        dropout=DROPOUT,
        # Special-token ids are taken from the tokenizer so model and tokenizer agree.
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        is_encoder_decoder=True,
        # The decoder is started from the EOS token.
        decoder_start_token_id=tokenizer.eos_token_id,
    )
)

# Load the JSON corpus as a single "train" split and expose it in torch format.
dataset = load_dataset("json", split='train', data_files="/content/drive/MyDrive/Colab Notebooks/dataset/cc100_latin_cleaned_medium.json").with_format(type="torch")
# 80/20 shuffled train/test split. No seed is passed here.
# NOTE(review): without a seed the split presumably differs between runs — confirm whether that is intended.
dataset_split = dataset.train_test_split(test_size=0.2, train_size=0.8, shuffle=True)
print(dataset_split)
# Despite the "streaming" in the names, load_dataset above is called without streaming=True.
train_streaming_dataset = dataset_split['train']
eval_streaming_dataset = dataset_split['test']
print(train_streaming_dataset)
print(eval_streaming_dataset)

# Every perturbation collate_fn may pick from (one is chosen at random per example).
# perturbation in string: document_rotation, sentence_permutation
# perturbation in token : token_infilling, token_masking, token_deletion
perturbations = [
    document_rotation,
    sentence_permutation,
    token_infilling,
    token_masking,
    token_deletion,
]

# Perturbations that take and return a plain string.
perturbations_text_domain = [
    document_rotation,
    sentence_permutation,
]

# Perturbations that take and return a tensor of token ids.
# Listed for completeness; collate_fn only tests membership in perturbations_text_domain.
perturbations_token_domain = [
    token_infilling,
    token_masking,
    token_deletion,
]


def collate_fn(examples):
    """
    Collate function to be used in the dataloader.

    For every example one perturbation is drawn at random from ``perturbations`` and
    applied to build the (noisy) encoder input; the clean text is used as the target.

    * Text-domain perturbations receive the text after it has been truncated to the
      model window (tokenize -> decode), and their output is tokenized again.
    * Token-domain perturbations receive the truncated, un-padded token ids, so they do
      not spend time iterating over padding.

    All inputs are right-padded to ``MAX_POSITION_EMBEDDINGS``. Padding positions in the
    labels are set to -100 so that the loss ignores them instead of rewarding the model
    for predicting padding.
    NOTE(review): this assumes the pad token id is distinct from the EOS id — confirm
    for the custom tokenizer.

    :param examples: list of examples, each a mapping with a "text" entry
    :return: dict with "input_ids", "attention_mask" and "labels" tensors, each of
        shape (len(examples), MAX_POSITION_EMBEDDINGS)
    """
    pad_token_id = tokenizer.pad_token_id
    original_texts = [example["text"] for example in examples]

    perturbed_rows = []
    for text in original_texts:
        perturbation_function = random.choice(perturbations)
        if perturbation_function in perturbations_text_domain:
            # Round-trip through the tokenizer so the perturbation only sees text
            # that fits inside the model window.
            truncated_ids = tokenizer(
                text, truncation=True, max_length=MAX_POSITION_EMBEDDINGS
            )["input_ids"]
            text_truncated = tokenizer.decode(truncated_ids, skip_special_tokens=True)
            perturbed_input_ids = tokenizer(
                perturbation_function(text_truncated),
                return_tensors="pt",
                truncation=True,
                max_length=MAX_POSITION_EMBEDDINGS,
            )["input_ids"][0]
        else:
            original_input_ids = tokenizer(
                text, return_tensors="pt", truncation=True, max_length=MAX_POSITION_EMBEDDINGS
            )["input_ids"][0]
            perturbed_input_ids = perturbation_function(
                tokenized_sequence=original_input_ids,
                mask_token_id=tokenizer.mask_token_id,
                mask_probability=0.15,
                list_special_tokens=tokenizer.all_special_ids,
            )

        # Right-pad every row to the same fixed length so the rows can be stacked.
        missing = MAX_POSITION_EMBEDDINGS - perturbed_input_ids.shape[-1]
        if missing > 0:
            padding = torch.full((missing,), pad_token_id, dtype=torch.long)
            perturbed_input_ids = torch.cat((perturbed_input_ids, padding))
        perturbed_rows.append(perturbed_input_ids)

    input_ids = torch.stack(perturbed_rows)

    labels = tokenizer(
        original_texts, padding="max_length", truncation=True, max_length=MAX_POSITION_EMBEDDINGS, return_tensors="pt"
    )["input_ids"]
    # -100 is the index ignored by the loss; without this most of the loss comes from padding.
    labels[labels == pad_token_id] = -100

    return {
        "input_ids": input_ids,
        # Attend to everything except padding; computed in one vectorised comparison.
        "attention_mask": (input_ids != pad_token_id).long(),
        "labels": labels,
    }


# Step-budget notes carried over from the it5 reference setup:
# total_steps (1 epoch, see it5) = 103_000_000 / 64 = 1_609_375 -- 1_700_000
# warmup_steps = 1_700_000 * 0.01 = 17_000
# NOTE(review): the values configured below (3 epochs, warmup_steps=2000) do not follow
# from these figures — confirm which budget is intended for this dataset.

# Prepare training arguments
training_args = transformers.TrainingArguments(
    output_dir="/content/drive/MyDrive/Colab Notebooks/BART/bart-la-size-m",
    overwrite_output_dir=True,
    # Let the Trainer retry with a smaller batch size when one does not fit in memory.
    auto_find_batch_size=True,
    num_train_epochs=3,
    warmup_steps=2000,
    weight_decay=0.01,
    # Saving and evaluating on the same schedule is required by load_best_model_at_end.
    save_strategy="epoch",
    # `evaluation_strategy` is deprecated in the pinned transformers release;
    # `eval_strategy` is its replacement.
    eval_strategy="epoch",
    logging_dir="/content/drive/MyDrive/Colab Notebooks/BART/logs-bart-it-size-m",
    logging_steps=300,
    save_total_limit=10,
    # Keep the checkpoint with the lowest evaluation loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # The dataset only has a raw "text" column, which collate_fn needs;
    # it must not be dropped before collation.
    remove_unused_columns=False,
    fp16=True,
    #tpu_num_cores=8,
    dataloader_num_workers=12,
    learning_rate=1e-4,
    report_to="none"
)

# Initialize the trainer; collate_fn applies the perturbations and tokenization per batch.
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=train_streaming_dataset,
    eval_dataset=eval_streaming_dataset,
    data_collator=collate_fn,
)

# Train the model
print("train started")
trainer.train()

# Evaluate the model
print(trainer.evaluate(eval_streaming_dataset))

# Save the model
trainer.save_model("/content/drive/MyDrive/Colab Notebooks/BART/bart-la-m")

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2400000
    })
    test: Dataset({
        features: ['text'],
        num_rows: 600000
    })
})
Dataset({
    features: ['text'],
    num_rows: 2400000
})
Dataset({
    features: ['text'],
    num_rows: 600000
})
train started


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
