### Imports

In [1]:
import pandas as pd
import torch
from datasets import concatenate_datasets, load_dataset
from datasets.dataset_dict import DatasetDict
from tqdm import tqdm
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq, Trainer, TrainingArguments)

### Pretrained model

In [None]:
# Tokenizer comes from the base mT5-XL repo; the weights come from the
# TextDetox fine-tuned baseline.
# NOTE(review): assumes the baseline keeps the base mT5 vocabulary -- confirm.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="google/mt5-xl")
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path="textdetox/mt5-xl-detox-baseline")

### Load data

In [10]:
data = load_dataset("textdetox/multilingual_paradetox")
# Split every hub split (presumably one per language -- confirm) 80/20 with a
# fixed seed, then merge the pieces into one 'test' and one 'train' set.
tmp = [data[subset].train_test_split(test_size=0.2, seed=42) for subset in data]
dataset = DatasetDict({
    split_name: concatenate_datasets([pieces[split_name] for pieces in tmp])
    for split_name in ['test', 'train']
})

In [11]:
# NOTE(review): not used anywhere below; the submission reads a TSV instead.
dataset_test = load_dataset(path="textdetox/multilingual_paradetox_test")

In [12]:
def preprocess_fuction(dataset, max_length=None):
    """Tokenize a parallel toxic -> neutral corpus for seq2seq training.

    The (misspelled) name is kept so existing callers keep working. Uses the
    notebook-level ``tokenizer``.

    Args:
        dataset: a ``datasets.Dataset`` (or mapping) whose ``'toxic_sentence'``
            column holds model inputs and ``'neutral_sentence'`` holds targets,
            each as a list of strings.
        max_length: optional token limit. When given, sources and targets are
            both truncated to at most this many tokens. ``None`` (default)
            keeps full-length sequences, exactly as before.

    Returns:
        dict with per-example ``input_ids``, ``attention_mask`` (from the
        sources) and ``labels`` (target ``input_ids``). Sequences are not
        padded here; padding is left to the data collator.
    """
    # Only pass truncation options when a limit was requested, so the default
    # call stays exactly `tokenizer(texts)`.
    tok_kwargs = {} if max_length is None else {'truncation': True, 'max_length': max_length}
    source = tokenizer(dataset['toxic_sentence'], **tok_kwargs)
    target = tokenizer(dataset['neutral_sentence'], **tok_kwargs)
    return dict(
        input_ids=source['input_ids'],
        attention_mask=source['attention_mask'],
        labels=target['input_ids'],
    )

In [13]:
# Tokenize both halves of the merged corpus ('test' is the 20% hold-out).
train_encodings, val_encodings = (
    preprocess_fuction(dataset[split_name]) for split_name in ('train', 'test')
)

In [14]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized columns.

    ``encodings`` maps a field name (e.g. ``'input_ids'``) to a sequence with
    one entry per example. Item ``idx`` is a dict holding each field's
    ``idx``-th entry converted with ``torch.tensor``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {}
        for field, column in self.encodings.items():
            item[field] = torch.tensor(column[idx])
        return item

    def __len__(self):
        # Assumes an 'input_ids' column is present and all columns share its length.
        input_ids = self.encodings['input_ids']
        return len(input_ids)

# Wrap the tokenized splits so the Trainer can index them example by example.
train_dataset, val_dataset = CustomDataset(train_encodings), CustomDataset(val_encodings)

### Train

In [42]:
# Evaluation, checkpointing and logging share one step interval, so that
# load_best_model_at_end always finds a checkpoint for each evaluation.
EVAL_STEPS = 500

training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',
    num_train_epochs=30,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=20,
    # Evaluate every EVAL_STEPS optimizer steps -- NOT at the end of each epoch.
    # NOTE(review): recent transformers releases renamed this argument to
    # `eval_strategy`; switch if your installed version rejects it.
    evaluation_strategy='steps',
    eval_steps=EVAL_STEPS,
    # load_best_model_at_end needs save_steps to be a multiple of eval_steps.
    save_steps=EVAL_STEPS,
    logging_steps=EVAL_STEPS,
    load_best_model_at_end=True,
    # NOTE(review): no save_total_limit, so every checkpoint of this XL model
    # stays on disk; consider capping it.
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # Held-out 20% split (the 'test' key of `dataset`), used for validation.
    eval_dataset=val_dataset,
    # Pads inputs per batch; labels are padded with -100 so the loss ignores them.
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
)


In [None]:
# Train the model in place. With load_best_model_at_end=True, the Trainer
# reloads the best evaluated checkpoint into `model` once training ends.
# NOTE(review): training leaves the model in train mode (dropout on); switch
# it to eval mode before running inference.
trainer.train()


### Submit

In [15]:
# Submission template; its 'toxic_sentence' column holds the texts to rewrite.
df = pd.read_csv('/main/sample_submission_test.tsv', delimiter='\t')

In [None]:
# Inference mode: Trainer.train() switches the model to train mode (dropout
# active), and generate() does not switch it back.
model.eval()

# Explicit cap on generated tokens. Without it, generate() falls back to
# generation_config.max_length (20 unless the checkpoint overrides it), which
# silently truncates longer rewrites.
MAX_NEW_TOKENS = 128

detox = []
for text in tqdm(df['toxic_sentence'], total=len(df)):
    # Token ids + attention mask for a single sentence (batch of 1).
    inputs = tokenizer(text, return_tensors='pt').to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    detox.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))

In [17]:
# Attach the generated rewrites as the submission's output column.
df = df.assign(neutral_sentence=detox)

In [18]:
# Write the submission as TSV, without pandas' RangeIndex as an extra column.
df.to_csv('/main/mt5_sub_test.tsv', sep='\t', index=False)