In [1]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import BertForMaskedLM, BertTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import Dataset

In [2]:
MAX_LEN = 128

In [3]:
# class DetoxDataset(Dataset):
#     def __init__(self, texts, labels, tokenizer):
#         self.texts = texts
#         self.labels = labels
#         self.tokenizer = tokenizer

#     def __len__(self):
#         return len(self.texts)

#     def __getitem__(self, idx):
#         text = self.texts[idx]
#         label = self.labels[idx]

#         inputs = self.tokenizer(text, padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors='pt')
#         inputs["labels"] = self.tokenizer(label, padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors='pt').input_ids
#         return inputs

In [51]:
df = pd.read_csv('../data/raw/filtered.tsv', sep='\t')
sents = df[(df['similarity'] < 0.7) & (df['ref_tox'] > df['trn_tox'])]
sents = sents[['reference', 'translation']]
toxic_sentences = sents['reference'].tolist()
non_toxic_sentences = sents['translation'].tolist()

In [5]:
offensive_words = open('../data/external/offensive_words.txt').read().split('\n')
toxic_words = open('../data/external/toxic_words.txt').read().split('\n')
toxic_words.extend(offensive_words)

toxic_words = [w for w in toxic_words if w.isalnum() and len(w) > 1]

print(len(toxic_words))

5038


In [41]:
def replace_toxic_words(text):
    text = text.lower()
    for word in toxic_words:
        text = text.replace(word, '[MASK]')
    return text

In [7]:
data = []
labels = []

for i in tqdm(range(len(toxic_sentences))):
    toxic_sentences[i] = replace_toxic_words(toxic_sentences[i])
    if '[MASK]' in toxic_sentences[i]:
        data.append(toxic_sentences[i])
        labels.append(non_toxic_sentences[i])

100%|██████████| 101535/101535 [01:04<00:00, 1585.06it/s]


In [8]:
import random 
random_toxic_sentences = random.sample(data, 10)
print('\n'.join(random_toxic_sentences))

I got the [MASK] kicked out of me, Albert. That's what happened.
All right, give me three grams of [MASK]man.
He put [MASK]ers in Sid's nursery and let him watch them [MASK]ing each other.
Don't let any of these [MASK]holes get th[MASK].
You're [MASK].
There's so much [MASK] in here.
Corner of Bum[MASK] and you Got a Purty Mouth.
John, get your t[MASK] out of here.
Would you rather she spent her life with Appleby [MASK] her?
I ain't going to [MASK] for your [MASK] [MASK].


In [9]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
def group_texts(examples):
    inputs = [ex for ex in examples['text']]
    target = [ex for ex in examples['labels']]

    batch = tokenizer(inputs, padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors='pt')
    batch["labels"] = tokenizer(target, padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors='pt').input_ids

    return batch

In [18]:
dataset = Dataset.from_dict({"text": data, "labels": labels})

In [19]:
dataset = dataset.map(group_texts, batched=True)

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset = dataset.select(range(train_size))
val_dataset = dataset.select(range(train_size, train_size + val_size))

Map:   0%|          | 0/84837 [00:00<?, ? examples/s]

In [65]:
def detoxificate_text(text, model, tokenizer):
    test_input = tokenizer(replace_toxic_words(text), padding='max_length', max_length=128, truncation=True, return_tensors='pt')
    input_ids = test_input.input_ids
    with torch.no_grad():
        output = model(**test_input)
    mask_idxs = torch.where(test_input['input_ids'][0] == tokenizer.mask_token_id)
    mask_token_logits = output.logits[0, mask_idxs[0]]
    top_tokens = torch.topk(mask_token_logits, 100, dim=1).indices.tolist()
    for i in range(len(top_tokens)):
        for token in top_tokens[i]:
            if tokenizer.decode([token]) not in toxic_words:
                input_ids[0][mask_idxs[0][i]] = token
                break

    non_toxic_text = tokenizer.decode(input_ids[0]).replace('[CLS]', '').replace('[SEP]', '').replace('[PAD]', '').strip()

    return non_toxic_text

In [66]:
text = "Don't worry about me"

print(text)

print(replace_toxic_words(text))

print(detoxificate_text(text, model, tokenizer))


Don't worry about me
don't [MASK] about me
don't talk about me


In [22]:
training_args = TrainingArguments(
    output_dir="../models",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator
)

trainer.train()



  0%|          | 0/14319 [00:00<?, ?it/s]

{'loss': 2.7494, 'learning_rate': 1.9301627208603956e-05, 'epoch': 0.1}
{'loss': 2.5851, 'learning_rate': 1.8603254417207907e-05, 'epoch': 0.21}
{'loss': 2.4498, 'learning_rate': 1.790488162581186e-05, 'epoch': 0.31}
{'loss': 2.4943, 'learning_rate': 1.7206508834415813e-05, 'epoch': 0.42}
{'loss': 2.4539, 'learning_rate': 1.6508136043019764e-05, 'epoch': 0.52}
{'loss': 2.4589, 'learning_rate': 1.5809763251623718e-05, 'epoch': 0.63}
{'loss': 2.4751, 'learning_rate': 1.5111390460227671e-05, 'epoch': 0.73}
{'loss': 2.3787, 'learning_rate': 1.4413017668831624e-05, 'epoch': 0.84}
{'loss': 2.3844, 'learning_rate': 1.3714644877435577e-05, 'epoch': 0.94}


  0%|          | 0/531 [00:00<?, ?it/s]

{'eval_loss': 2.267263889312744, 'eval_runtime': 134.0451, 'eval_samples_per_second': 63.292, 'eval_steps_per_second': 3.961, 'epoch': 1.0}
{'loss': 2.3166, 'learning_rate': 1.301627208603953e-05, 'epoch': 1.05}
{'loss': 2.3021, 'learning_rate': 1.2317899294643482e-05, 'epoch': 1.15}
{'loss': 2.2953, 'learning_rate': 1.1619526503247433e-05, 'epoch': 1.26}
{'loss': 2.2953, 'learning_rate': 1.0921153711851386e-05, 'epoch': 1.36}
{'loss': 2.2989, 'learning_rate': 1.022278092045534e-05, 'epoch': 1.47}
{'loss': 2.2951, 'learning_rate': 9.524408129059293e-06, 'epoch': 1.57}
{'loss': 2.2183, 'learning_rate': 8.826035337663246e-06, 'epoch': 1.68}
{'loss': 2.2355, 'learning_rate': 8.127662546267197e-06, 'epoch': 1.78}
{'loss': 2.2557, 'learning_rate': 7.429289754871151e-06, 'epoch': 1.89}
{'loss': 2.2302, 'learning_rate': 6.730916963475103e-06, 'epoch': 1.99}


  0%|          | 0/531 [00:00<?, ?it/s]

{'eval_loss': 2.228304624557495, 'eval_runtime': 136.5881, 'eval_samples_per_second': 62.114, 'eval_steps_per_second': 3.888, 'epoch': 2.0}
{'loss': 2.2021, 'learning_rate': 6.032544172079057e-06, 'epoch': 2.1}
{'loss': 2.1959, 'learning_rate': 5.334171380683009e-06, 'epoch': 2.2}
{'loss': 2.2246, 'learning_rate': 4.635798589286962e-06, 'epoch': 2.3}
{'loss': 2.1593, 'learning_rate': 3.937425797890914e-06, 'epoch': 2.41}
{'loss': 2.1512, 'learning_rate': 3.239053006494867e-06, 'epoch': 2.51}
{'loss': 2.153, 'learning_rate': 2.5406802150988204e-06, 'epoch': 2.62}
{'loss': 2.1859, 'learning_rate': 1.8423074237027727e-06, 'epoch': 2.72}
{'loss': 2.2147, 'learning_rate': 1.1439346323067255e-06, 'epoch': 2.83}
{'loss': 2.1729, 'learning_rate': 4.4556184091067813e-07, 'epoch': 2.93}


  0%|          | 0/531 [00:00<?, ?it/s]

{'eval_loss': 2.180795431137085, 'eval_runtime': 133.9439, 'eval_samples_per_second': 63.34, 'eval_steps_per_second': 3.964, 'epoch': 3.0}
{'train_runtime': 9574.8345, 'train_samples_per_second': 23.923, 'train_steps_per_second': 1.495, 'train_loss': 2.311999336102545, 'epoch': 3.0}


TrainOutput(global_step=14319, training_loss=2.311999336102545, metrics={'train_runtime': 9574.8345, 'train_samples_per_second': 23.923, 'train_steps_per_second': 1.495, 'train_loss': 2.311999336102545, 'epoch': 3.0})

In [70]:
model.save_pretrained("../models/bert_maskedlm")
tokenizer.save_pretrained("../models/bert_maskedlm")

('../models/bert_maskedlm\\tokenizer_config.json',
 '../models/bert_maskedlm\\special_tokens_map.json',
 '../models/bert_maskedlm\\vocab.txt',
 '../models/bert_maskedlm\\added_tokens.json')

In [69]:
random.seed(1337)

model = BertForMaskedLM.from_pretrained("../models/bert_maskedlm")
tokenizer = BertTokenizer.from_pretrained("../models/bert_maskedlm")

random_toxic_sentences = random.sample(toxic_sentences, 10)

for sentence in random_toxic_sentences:
    print(sentence)
    print(detoxificate_text(sentence, model, tokenizer))
    print()

Suddenly, to the delight and outrage of the congregation, a raucous saxophone broke the solemnity, and a jazz rendering of "Fools Rush In" was blaring over the loudspeakers.
suddenly, to the delight and delight of the congregation, a raucous s of ophone filled the'ity, and a jazz rendering of "'s rush in " was blaring over the overhead speakers.

This place is such a dump.
this place is such a hole.

Doesn't mean a damn thing!
doesn't mean a god thing!

I’m just going to have to find someone to cover for my ass first.’
i ’ m just going to have to find someone to cover for my wife first. ’

He is a walking dead man with no will of his own.
he is a walking old man with no will of his own.

You're such a jerk.
you're such a baby.

I'm jacking off.
i'm jacking off.

I may puke.
i may puke.

You gotta be shitting me.
you gotta be telling ting me.

You're a great liar, Dad.
you're a great man, dad.

