In [1]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
import warnings

warnings.filterwarnings("ignore")

In [2]:
df = pd.read_csv('../data/raw/filtered.tsv', sep='\t')
sents = df[(df['similarity'] < 0.7) & (df['ref_tox'] > df['trn_tox'])]
sents = sents[['reference', 'translation']]
toxic_sentences = sents['reference'].tolist()
non_toxic_sentences = sents['translation'].tolist()

In [3]:
model_name = "t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [4]:
# model_name = "bert-base-uncased"
# encoder = BertGenerationEncoder.from_pretrained(model_name, bos_token_id=101, eos_token_id=102)
# decoder = BertGenerationDecoder.from_pretrained(model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
# bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
# bert2bert.config.decoder_start_token_id = 101
# bert2bert.config.pad_token_id = 0
# tokenizer = BertTokenizer.from_pretrained(model_name)
# data_collator = DataCollatorForSeq2Seq(tokenizer, model=bert2bert)

In [5]:
MAX_LEN = 128
prefix = 'Detoxify text: '

def preprocess_text(examples):
    inputs = [prefix + ex for ex in examples['text']]
    target = [ex for ex in examples['labels']]

    batch = tokenizer(inputs, padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors='pt')
    batch["labels"] = tokenizer(target, padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors='pt').input_ids

    return batch

In [6]:
dataset = Dataset.from_dict({"text": toxic_sentences, "labels": non_toxic_sentences})
dataset = dataset.map(preprocess_text, batched=True)

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset = dataset.select(range(train_size))
val_dataset = dataset.select(range(train_size, train_size + val_size))

Map:   0%|          | 0/101535 [00:00<?, ? examples/s]

In [69]:
def detoxificate_text(text, model, tokenizer):
    text = prefix + text
    test_input = tokenizer(text, return_tensors='pt')
    input_ids = test_input.input_ids
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, max_length=MAX_LEN)
    non_toxic_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return non_toxic_text

In [8]:
text = "You gotta be shitting me."

print(text)

print(detoxificate_text(text, model.to("cpu"), tokenizer))


You gotta be shitting me.
tensor([[    0, 10747,     7,    15,     1]])
<pad> False</s>


In [15]:
training_args = Seq2SeqTrainingArguments(
    "../models/",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    predict_with_generate=True,
    save_total_limit=3,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

  0%|          | 0/17136 [00:00<?, ?it/s]

{'loss': 0.2574, 'learning_rate': 1.9416433239962655e-05, 'epoch': 0.09}
{'loss': 0.2506, 'learning_rate': 1.8832866479925304e-05, 'epoch': 0.18}
{'loss': 0.2499, 'learning_rate': 1.8249299719887958e-05, 'epoch': 0.26}
{'loss': 0.2437, 'learning_rate': 1.7665732959850607e-05, 'epoch': 0.35}
{'loss': 0.2455, 'learning_rate': 1.708216619981326e-05, 'epoch': 0.44}
{'loss': 0.2495, 'learning_rate': 1.649859943977591e-05, 'epoch': 0.53}
{'loss': 0.2455, 'learning_rate': 1.5915032679738563e-05, 'epoch': 0.61}
{'loss': 0.2471, 'learning_rate': 1.5331465919701216e-05, 'epoch': 0.7}
{'loss': 0.2423, 'learning_rate': 1.4747899159663868e-05, 'epoch': 0.79}
{'loss': 0.2438, 'learning_rate': 1.416433239962652e-05, 'epoch': 0.88}
{'loss': 0.2413, 'learning_rate': 1.358076563958917e-05, 'epoch': 0.96}


  0%|          | 0/635 [00:00<?, ?it/s]

{'eval_loss': 0.2224476933479309, 'eval_runtime': 83.8485, 'eval_samples_per_second': 121.099, 'eval_steps_per_second': 7.573, 'epoch': 1.0}
{'loss': 0.2396, 'learning_rate': 1.2997198879551822e-05, 'epoch': 1.05}
{'loss': 0.2406, 'learning_rate': 1.2413632119514474e-05, 'epoch': 1.14}
{'loss': 0.2403, 'learning_rate': 1.1830065359477125e-05, 'epoch': 1.23}
{'loss': 0.2376, 'learning_rate': 1.1246498599439776e-05, 'epoch': 1.31}
{'loss': 0.2382, 'learning_rate': 1.0662931839402428e-05, 'epoch': 1.4}
{'loss': 0.2425, 'learning_rate': 1.007936507936508e-05, 'epoch': 1.49}
{'loss': 0.2376, 'learning_rate': 9.49579831932773e-06, 'epoch': 1.58}
{'loss': 0.2413, 'learning_rate': 8.912231559290384e-06, 'epoch': 1.66}
{'loss': 0.2389, 'learning_rate': 8.328664799253035e-06, 'epoch': 1.75}
{'loss': 0.2362, 'learning_rate': 7.745098039215687e-06, 'epoch': 1.84}
{'loss': 0.2381, 'learning_rate': 7.161531279178339e-06, 'epoch': 1.93}


  0%|          | 0/635 [00:00<?, ?it/s]

{'eval_loss': 0.21875521540641785, 'eval_runtime': 83.8575, 'eval_samples_per_second': 121.086, 'eval_steps_per_second': 7.572, 'epoch': 2.0}
{'loss': 0.2409, 'learning_rate': 6.5779645191409904e-06, 'epoch': 2.01}
{'loss': 0.2385, 'learning_rate': 5.994397759103642e-06, 'epoch': 2.1}
{'loss': 0.2358, 'learning_rate': 5.410830999066293e-06, 'epoch': 2.19}
{'loss': 0.2354, 'learning_rate': 4.8272642390289456e-06, 'epoch': 2.28}
{'loss': 0.237, 'learning_rate': 4.243697478991597e-06, 'epoch': 2.36}
{'loss': 0.2349, 'learning_rate': 3.6601307189542484e-06, 'epoch': 2.45}
{'loss': 0.238, 'learning_rate': 3.0765639589169007e-06, 'epoch': 2.54}
{'loss': 0.2359, 'learning_rate': 2.492997198879552e-06, 'epoch': 2.63}
{'loss': 0.2343, 'learning_rate': 1.9094304388422036e-06, 'epoch': 2.71}
{'loss': 0.2396, 'learning_rate': 1.3258636788048554e-06, 'epoch': 2.8}
{'loss': 0.235, 'learning_rate': 7.42296918767507e-07, 'epoch': 2.89}
{'loss': 0.2361, 'learning_rate': 1.5873015873015874e-07, 'epoch':

  0%|          | 0/635 [00:00<?, ?it/s]

{'eval_loss': 0.21781419217586517, 'eval_runtime': 84.768, 'eval_samples_per_second': 119.786, 'eval_steps_per_second': 7.491, 'epoch': 3.0}
{'train_runtime': 7031.5609, 'train_samples_per_second': 38.988, 'train_steps_per_second': 2.437, 'train_loss': 0.24080748106139938, 'epoch': 3.0}


TrainOutput(global_step=17136, training_loss=0.24080748106139938, metrics={'train_runtime': 7031.5609, 'train_samples_per_second': 38.988, 'train_steps_per_second': 2.437, 'train_loss': 0.24080748106139938, 'epoch': 3.0})

In [23]:
model.save_pretrained("../models/dotixificator")
tokenizer.save_pretrained("../models/dotixificator")

('../models/dotixificator\\tokenizer_config.json',
 '../models/dotixificator\\special_tokens_map.json',
 '../models/dotixificator\\spiece.model',
 '../models/dotixificator\\added_tokens.json',
 '../models/dotixificator\\tokenizer.json')

In [79]:
import random

random_toxic_sentences = random.sample(toxic_sentences, 5)

for text in random_toxic_sentences:
    print(text)
    print(detoxificate_text(text, model.to("cpu"), tokenizer))
    print()

If you wanna be there, you're gonna have to go my fucking way.
if you want to be there, you'll have to go.

Come on, get up. Don't be silly.
come on, get up, don't be silly.

I live in hope that it may pass, but that wretched man persists in writing to her.
I live in hope that it will pass, but the wretched man is still writing to her.

I'll pray for your soul, sue your ass into the dirt, and wait for the day I can beat it bloody.
I'll pray for your soul, I'll sue you in the dirt, and I'll wait for the day I'll beat it.

Fuck. No one wants to see Proctor go down more than her.
no one wants to see Proctor go down more than her.

