In [None]:
import warnings

# Silence all warnings for the rest of the kernel session. This runs before the
# heavy imports below so that import-time warnings are hidden as well.
# NOTE(review): a blanket "ignore" also hides deprecation/misuse warnings from the
# libraries used below — consider narrowing it (e.g. by category) instead.
warnings.filterwarnings("ignore")

import transformers
import torch

from datasets import load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from sklearn.model_selection import train_test_split
# Project-local loader for the main (toxic -> detoxified) dataset.
from src.data import load_main_dataset

In [None]:
# Seed the random number generators first, so everything below is reproducible.
transformers.set_seed(42)
# Presumably a pandas DataFrame with "reference" (toxic) and "translation"
# (detoxified) text columns — those are the columns read later; confirm schema.
df = load_main_dataset()
# SacreBLEU metric, used at the end of the notebook to score a generated detoxification.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases —
# confirm it is available in the pinned version.
metric = load_metric("sacrebleu")
# Train on the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Tokenizer from the base BART checkpoint, shared by preprocessing and inference.
# NOTE(review): the model is loaded from "SkolkovoInstitute/bart-base-detox"; this
# assumes that checkpoint uses the same vocabulary as facebook/bart-base — confirm.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="facebook/bart-base")

First, we prepare the data for training: for each pair, we tokenize the (prefixed) toxic input and store the tokenized target text under the `labels` key of the input dict.
Note that the tokenizer returns 2-D tensors (a batch containing a single sequence), but training expects a 1-D tensor per example, so we keep only the first row.

In [None]:
def preprocess(__input_text, __output_text, *, prefix="Detoxify: ",
               max_input_length=1500, max_output_length=1000, tokenizer=None):
    """Tokenize one (toxic, detoxified) text pair into a seq2seq training example.

    Args:
        __input_text: Source (toxic) text; ``prefix`` is prepended before tokenizing.
        __output_text: Target (detoxified) text; its token ids become ``labels``.
        prefix: Task prefix prepended to the source text. Text fed to the trained
            model at inference time should carry the same prefix.
        max_input_length: Upper bound on source tokens (prefix included).
        max_output_length: Upper bound on target tokens.
        tokenizer: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer's encoding of the source text, with ``input_ids``,
        ``attention_mask`` and ``labels`` each stored as a 1-D tensor.
    """
    # Only reading the notebook-level tokenizer, so no `global` is needed; an
    # explicit tokenizer can be injected (e.g. for testing or another checkpoint).
    tok = tokenizer if tokenizer is not None else globals()["tokenizer"]

    # Never truncate to more tokens than the tokenizer says its model accepts:
    # the hard-coded limits may exceed the model's positional capacity.
    model_max = getattr(tok, "model_max_length", None)
    if model_max is not None:
        max_input_length = min(max_input_length, model_max)
        max_output_length = min(max_output_length, model_max)

    encoded = tok(prefix + __input_text, return_tensors="pt",
                  max_length=max_input_length, truncation=True)
    target = tok(__output_text, return_tensors="pt",
                 max_length=max_output_length, truncation=True)

    # With return_tensors="pt" a single string comes back as a batch of one (2-D);
    # training expects one 1-D sequence per example, so keep row 0.
    encoded["labels"] = target["input_ids"][0]
    encoded["input_ids"] = encoded["input_ids"][0]
    encoded["attention_mask"] = encoded["attention_mask"][0]
    return encoded

In [None]:
# Tokenize every (toxic reference, detoxified translation) pair.
# Zipping the columns is positional, so this works for any DataFrame index;
# a label-based lookup like df["reference"][i] would assume a 0..n-1 index.
dataset = [
    preprocess(reference, translation)
    for reference, translation in zip(df["reference"], df["translation"])
]

Next, let's split the data into training and evaluation sets.

In [None]:
# Hold out 20% of the tokenized examples for evaluation.
split = 0.2
# Pin random_state so the split is identical whenever this cell runs, not only
# right after the set_seed(...) cell; otherwise re-running it reshuffles the data
# and previously-trained-on examples can end up in the evaluation set.
text_train, text_test = train_test_split(dataset, test_size=split, shuffle=True, random_state=42)

Finally, create the model, the training arguments, and the trainer.

In [None]:
# Start from the detox-pretrained BART checkpoint, then place it on the chosen device.
model = AutoModelForSeq2SeqLM.from_pretrained("SkolkovoInstitute/bart-base-detox")
model = model.to(device)

In [None]:
# Training configuration: evaluate and checkpoint once per epoch, so the trainer
# can tell which epoch did best on the held-out split.
batch_size = 32
args = Seq2SeqTrainingArguments(
    "Model fine-tuning",  # output directory for checkpoints
    evaluation_strategy="epoch",
    # Save on the same schedule as evaluation (required by load_best_model_at_end)
    # and reload the best checkpoint when training ends. Selection uses eval loss
    # unless metric_for_best_model is set. This way a later
    # trainer.save_model("../models/best") writes the best epoch, not the last one.
    save_strategy="epoch",
    load_best_model_at_end=True,
    # NOTE(review): 1e-3 is unusually high for fine-tuning a pretrained BART
    # checkpoint — confirm this value is intended.
    learning_rate=0.001,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    # Generate sequences during evaluation instead of only computing loss.
    predict_with_generate=True,
    report_to="tensorboard"
)

# Pads the variable-length examples produced by preprocess() (including labels) per batch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=text_train,
    eval_dataset=text_test,
    data_collator=data_collator,
    tokenizer=tokenizer
)

Unfortunately, I did not finish the training run, but you can run the cell below to complete it yourself :)

In [None]:
# Run the fine-tuning loop configured in the trainer, then write the resulting model
# to ../models/best (presumably alongside the tokenizer passed to the trainer).
# Whether this is the best or the final epoch depends on load_best_model_at_end.
trainer.train()
trainer.save_model("../models/best")

All of the training logic is also packaged in the `train` function.
To use it, you only need to load a model, a tokenizer, and the dataset.

Let's check whether it works :)

In [None]:
# Reload the fine-tuned weights written by the training cell (not moved to `device`).
model = AutoModelForSeq2SeqLM.from_pretrained("../models/best")
# Switch to inference mode (disables dropout); as the last expression it also displays the model.
model.eval()

In [None]:
# Detoxify one sample from the dataset with the fine-tuned model and score it with SacreBLEU.
sample_idx = 3
# Training inputs were built as "Detoxify: " + text in preprocess(); feed the model
# the same prefix so that inference matches what it was trained on.
detox_prefix = "Detoxify: "

# Positional lookup, independent of the DataFrame's index labels.
toxic_text = df["reference"].iloc[sample_idx]
reference_text = df["translation"].iloc[sample_idx]
print(f"Toxic text: {toxic_text}")

input_ids = tokenizer(detox_prefix + toxic_text, return_tensors="pt", truncation=True).input_ids
# Put the inputs on whatever device the model currently lives on.
outputs = model.generate(input_ids=input_ids.to(model.device))
# decode() only maps token ids back to text; sampling options such as
# temperature belong to generate(), not here.
nontoxic = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"It's translation: {nontoxic}")
print("Metric score: ", metric.compute(predictions=[nontoxic], references=[[reference_text]]))

You can also detoxify arbitrary text with the `detox` function.