In [1]:
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import torch
import transformers

from datasets import load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from sklearn.model_selection import train_test_split
from src.data import load_main_dataset

In [2]:
# Fix the global RNG seed before anything stochastic runs so later cells
# are reproducible.
transformers.set_seed(42)

# Training pairs and the BLEU metric used to score generated outputs.
df = load_main_dataset()
metric = load_metric("sacrebleu")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Tokenizer shared by `preprocess`, the data collator and the trainer.
# NOTE(review): loaded from the base checkpoint, while the model is loaded
# from SkolkovoInstitute/bart-base-detox; presumably both share the BART
# vocabulary — confirm.
tokenizer_checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [4]:
def preprocess(__input_text, __output_text, *, tok=None, prefix="Detoxify: ",
               max_input_length=1024, max_output_length=1000):
    """Tokenize one source/target sentence pair for seq2seq fine-tuning.

    Args:
        __input_text: Source sentence; ``prefix`` is prepended before tokenizing.
        __output_text: Target sentence; its token ids become ``labels``.
        tok: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.
        prefix: Task prefix prepended to the source sentence.
        max_input_length: Truncation limit (in tokens) for the source.
        max_output_length: Truncation limit (in tokens) for the target.

    Both limits are additionally capped at ``tok.model_max_length`` when the
    tokenizer defines it, so truncation can never produce a sequence longer
    than the model accepts.

    Returns:
        The tokenizer's encoding of the source with 1-D ``input_ids`` and
        ``attention_mask`` tensors plus a 1-D ``labels`` tensor. The leading
        batch dimension of size 1 created by ``return_tensors="pt"`` is
        stripped; examples are padded into batches later by the data collator.
    """
    tok = tokenizer if tok is None else tok

    # Cap at the tokenizer's model_max_length (1024 for facebook/bart-base).
    # A limit above that (the old hard-coded 1500) lets overlong inputs
    # through truncation, and they then overflow BART's position embeddings.
    model_limit = getattr(tok, "model_max_length", None)
    if model_limit is not None:
        max_input_length = min(max_input_length, model_limit)
        max_output_length = min(max_output_length, model_limit)

    source = tok(prefix + __input_text, return_tensors="pt",
                 max_length=max_input_length, truncation=True)
    target = tok(__output_text, return_tensors="pt",
                 max_length=max_output_length, truncation=True)

    # Drop the batch dimension: each example is a single sequence.
    source["input_ids"] = source["input_ids"][0]
    source["attention_mask"] = source["attention_mask"][0]
    source["labels"] = target["input_ids"][0]
    return source

In [5]:
# Tokenize every (reference, translation) row of the dataset.
# zip() walks the two columns positionally, so this stays correct even if
# `df` has a non-default index (df["col"][i] would be a label lookup).
dataset = [
    preprocess(reference, translation)
    for reference, translation in zip(df["reference"], df["translation"])
]

In [6]:
# Inspect one tokenized example: input_ids, attention_mask and labels.
sample_example = dataset[0]
sample_example

{'input_ids': tensor([    0, 43170,  4325,  4591,    35,   114,   726,  8722, 12530,    69,
           19,    69,  2536,  3844,     6,    24,    74,  3922,     5,   239,
         1389,     9, 44755, 44370,     4,     2]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1]), 'labels': tensor([    0,  1106,   726,  8722,    16,  5681,    69,    19, 39297,  3844,
            6,    14,  4529,     5,   239,   672,     9, 44755,  8974,  2696,
            4,     2])}

In [7]:
# Hold out 20% of the tokenized examples for evaluation.
# An explicit random_state makes this cell idempotent: re-running it on its
# own reproduces the same split instead of drawing from whatever state the
# global RNG is in at that moment.
split = 0.2
text_train, text_test = train_test_split(
    dataset, test_size=split, shuffle=True, random_state=42
)

In [8]:
# Seq2seq checkpoint to fine-tune, placed on the device selected earlier.
model_checkpoint = "SkolkovoInstitute/bart-base-detox"
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)

In [9]:
batch_size = 32


def compute_metrics(eval_preds):
    """Score generated sequences against their references with sacreBLEU.

    Called by the trainer at each evaluation. Because ``predict_with_generate``
    is on, the predictions are generated token ids, which are decoded and
    scored here. Uses the notebook-level ``tokenizer`` and ``metric``.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair of token-id arrays.

    Returns:
        ``{"bleu": <corpus BLEU score>}`` (logged by the trainer as ``eval_bleu``).
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):  # some models return (ids, extra outputs)
        preds = preds[0]
    # -100 marks padded/ignored positions (the collator's default label pad
    # value); replace it with the real pad id so the ids can be decoded.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    predictions = [text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)]
    # sacreBLEU takes a list of reference lists: one reference per example here.
    references = [[text.strip()] for text in tokenizer.batch_decode(labels, skip_special_tokens=True)]

    result = metric.compute(predictions=predictions, references=references)
    return {"bleu": result["score"]}


args = Seq2SeqTrainingArguments(
    "Model fine-tuning",
    evaluation_strategy="epoch",
    # Checkpoint on the same schedule as evaluation and restore the epoch with
    # the lowest validation loss at the end. Otherwise the default step-based
    # saving applies and the model saved after training is the last epoch,
    # not the best one.
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    # NOTE(review): 1e-3 is unusually high for fine-tuning a pretrained BART
    # (values around 1e-5 to 5e-5 are typical); watch for divergence.
    learning_rate=0.001,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
    report_to="tensorboard"
)

# Pads inputs and labels per batch; given the model it can also build
# decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=text_train,
    eval_dataset=text_test,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning and keep the summary the trainer returns for inspection.
train_output = trainer.train()

# Persist the model the trainer holds after training.
# NOTE(review): this is the best epoch only if the trainer was configured with
# load_best_model_at_end; otherwise it is the final-epoch weights.
best_model_dir = "../models/best"
trainer.save_model(best_model_dir)

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss


In [None]:
# Reload the weights written by the training cell and place them on the same
# device used for training, then switch to eval mode (e.g. dropout off) for
# inference. The trailing `;` hides the very long module repr eval() returns.
model = AutoModelForSeq2SeqLM.from_pretrained("../models/best").to(device)
model.eval();

In [None]:
# Smoke-test the reloaded model: generate a rewrite for one held-out example
# and show it next to its source and reference. The forward pass needs input
# ids, so the model cannot be called without inputs.
example = text_test[0]
with torch.no_grad():
    generated_ids = model.generate(
        # Add the batch dimension that `preprocess` stripped, and follow the
        # model onto whatever device it lives on.
        input_ids=example["input_ids"].unsqueeze(0).to(model.device),
        attention_mask=example["attention_mask"].unsqueeze(0).to(model.device),
        max_length=128,  # assumes short, sentence-level outputs; tune as needed
    )

{
    "input": tokenizer.decode(example["input_ids"], skip_special_tokens=True),
    "prediction": tokenizer.decode(generated_ids[0], skip_special_tokens=True),
    "reference": tokenizer.decode(example["labels"], skip_special_tokens=True),
}