# First Hypothesis: Mask toxic words and use MaskedLM to find appropriate alternatives

In [1]:
import pandas as pd
from tqdm import tqdm
from transformers import BertForMaskedLM, BertTokenizer, DataCollatorForLanguageModeling
from datasets import Dataset
import torch

import sys
sys.path.append('..')
from src.data.preprocess import put_mask
from src.models.predict import detoxificate_text
from src.models.train import train

import warnings
# Silences every Python warning for the rest of the notebook.
# NOTE(review): this is a blanket filter; consider narrowing it to specific
# categories so genuine deprecation/runtime warnings stay visible.
warnings.filterwarnings('ignore')

# Seed used for torch here and forwarded to train() via its `seed` argument.
RANDOM_SEED = 1337
# The trailing semicolon keeps the returned torch Generator repr out of the
# cell output.
torch.manual_seed(RANDOM_SEED);

Some weights of the model checkpoint at SkolkovoInstitute/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<torch._C.Generator at 0x2656fbae410>

### Loading bert-base-uncased model for MaskedLM 

In [2]:
# Checkpoint shared by the tokenizer and the masked-LM head.
model_name = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)

# Collator settings spelled out explicitly (these are the library defaults).
# NOTE(review): with mlm=True the collator applies its own random masking when
# batches are built -- confirm this is intended alongside the put_mask-based
# masking done when the dataset is created.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Creating dataset for the model

In [3]:
# Load the parallel corpus: the 'reference' column is used as the toxic side
# and the 'translation' column as the paired target sentence.
df = pd.read_csv('../data/interim/train.csv')
toxic_sentences = df['reference'].tolist()
non_toxic_sentences = df['translation'].tolist()

# Read the toxic vocabulary inside a context manager so the file handle is
# closed, and drop empty entries (e.g. produced by a trailing newline) so an
# empty string is never handed to put_mask as a "toxic word".
with open('../data/interim/toxic_words.txt') as toxic_words_file:
    toxic_words = [word for word in toxic_words_file.read().split('\n') if word]

data = []
labels = []

for i in tqdm(range(len(toxic_sentences))):
    # toxic_sentences is overwritten in place with the masked text, exactly as
    # before, so any later use of this list sees the masked sentences.
    toxic_sentences[i] = put_mask(toxic_sentences[i], toxic_words)
    # Keep only pairs where at least one mask was actually inserted.
    if '[MASK]' in toxic_sentences[i]:
        data.append(toxic_sentences[i])
        labels.append(non_toxic_sentences[i])

dataset = Dataset.from_dict({"text": data, "labels": labels})

100%|██████████| 97006/97006 [00:50<00:00, 1920.72it/s]


### Tokenizing the masked inputs and their target sentences

In [4]:
# Fixed sequence length every example is padded/truncated to.
MAX_LEN = 128


def group_texts(examples):
    """Tokenize one batch of masked sentences together with their targets.

    Args:
        examples: batch mapping with a 'text' column (masked sentences) and a
            'labels' column (target sentences), as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the 'text' column, with an extra "labels"
        entry holding the ``input_ids`` of the tokenized 'labels' column.
        Both sides are padded/truncated to ``MAX_LEN``.
    """
    tokenize_kwargs = dict(
        padding='max_length',
        max_length=MAX_LEN,
        truncation=True,
        return_tensors='pt',
    )

    encoded = tokenizer(list(examples['text']), **tokenize_kwargs)
    # NOTE(review): label ids include padding positions, and the target
    # sentence is tokenized independently of the masked input -- confirm how
    # these labels interact with the data collator handed to train().
    encoded["labels"] = tokenizer(list(examples['labels']), **tokenize_kwargs).input_ids

    return encoded


dataset = dataset.map(group_texts, batched=True)

Map:   0%|          | 0/86397 [00:00<?, ? examples/s]

### Split the data into train and validation

In [5]:
# Sequential 90/10 split: the first 90% of rows train the model and the
# remaining rows are held out for validation. No shuffling is applied, so the
# split follows the dataset's existing row order.
total_rows = len(dataset)
train_size = int(0.9 * total_rows)
val_size = total_rows - train_size

train_dataset = dataset.select(range(train_size))
val_dataset = dataset.select(range(train_size, total_rows))

### Train using the Hugging Face trainer

In [6]:
# Fine-tune through the project's train() helper (src.models.train).
# Positional arguments are passed in the same order as before; the first one
# is presumably a run/model identifier -- verify against train()'s signature.
train(
    'maskedlm', model, tokenizer, train_dataset, val_dataset, data_collator,
    batch_size=16, epochs=3, seed=RANDOM_SEED,
)

  0%|          | 0/14580 [00:00<?, ?it/s]

{'loss': 2.476, 'learning_rate': 1.9314128943758576e-05, 'epoch': 0.1}
{'loss': 2.2942, 'learning_rate': 1.862825788751715e-05, 'epoch': 0.21}
{'loss': 2.2701, 'learning_rate': 1.7942386831275723e-05, 'epoch': 0.31}
{'loss': 2.245, 'learning_rate': 1.7256515775034294e-05, 'epoch': 0.41}
{'loss': 2.1805, 'learning_rate': 1.6570644718792868e-05, 'epoch': 0.51}
{'loss': 2.1607, 'learning_rate': 1.588477366255144e-05, 'epoch': 0.62}
{'loss': 2.2014, 'learning_rate': 1.5198902606310016e-05, 'epoch': 0.72}
{'loss': 2.1563, 'learning_rate': 1.4513031550068588e-05, 'epoch': 0.82}
{'loss': 2.1493, 'learning_rate': 1.3827160493827162e-05, 'epoch': 0.93}


  0%|          | 0/540 [00:00<?, ?it/s]

{'eval_loss': 1.9825509786605835, 'eval_runtime': 137.9325, 'eval_samples_per_second': 62.639, 'eval_steps_per_second': 3.915, 'epoch': 1.0}
{'loss': 2.0801, 'learning_rate': 1.3141289437585736e-05, 'epoch': 1.03}
{'loss': 2.0414, 'learning_rate': 1.2455418381344308e-05, 'epoch': 1.13}
{'loss': 2.019, 'learning_rate': 1.1769547325102882e-05, 'epoch': 1.23}
{'loss': 2.0071, 'learning_rate': 1.1083676268861454e-05, 'epoch': 1.34}
{'loss': 2.0173, 'learning_rate': 1.039780521262003e-05, 'epoch': 1.44}
{'loss': 1.9853, 'learning_rate': 9.711934156378602e-06, 'epoch': 1.54}
{'loss': 1.9916, 'learning_rate': 9.026063100137174e-06, 'epoch': 1.65}
{'loss': 1.9981, 'learning_rate': 8.340192043895748e-06, 'epoch': 1.75}
{'loss': 1.9776, 'learning_rate': 7.654320987654322e-06, 'epoch': 1.85}
{'loss': 1.9767, 'learning_rate': 6.968449931412895e-06, 'epoch': 1.95}


  0%|          | 0/540 [00:00<?, ?it/s]

{'eval_loss': 1.8876408338546753, 'eval_runtime': 137.6567, 'eval_samples_per_second': 62.765, 'eval_steps_per_second': 3.923, 'epoch': 2.0}
{'loss': 1.9605, 'learning_rate': 6.282578875171468e-06, 'epoch': 2.06}
{'loss': 1.9216, 'learning_rate': 5.596707818930042e-06, 'epoch': 2.16}
{'loss': 1.9267, 'learning_rate': 4.910836762688615e-06, 'epoch': 2.26}
{'loss': 1.8795, 'learning_rate': 4.224965706447189e-06, 'epoch': 2.37}
{'loss': 1.9247, 'learning_rate': 3.5390946502057617e-06, 'epoch': 2.47}
{'loss': 1.9298, 'learning_rate': 2.8532235939643347e-06, 'epoch': 2.57}
{'loss': 1.9279, 'learning_rate': 2.167352537722908e-06, 'epoch': 2.67}
{'loss': 1.9018, 'learning_rate': 1.4814814814814815e-06, 'epoch': 2.78}
{'loss': 1.9098, 'learning_rate': 7.956104252400549e-07, 'epoch': 2.88}
{'loss': 1.8872, 'learning_rate': 1.0973936899862826e-07, 'epoch': 2.98}


  0%|          | 0/540 [00:00<?, ?it/s]

{'eval_loss': 1.8983267545700073, 'eval_runtime': 137.3143, 'eval_samples_per_second': 62.921, 'eval_steps_per_second': 3.933, 'epoch': 3.0}
{'train_runtime': 9772.7218, 'train_samples_per_second': 23.87, 'train_steps_per_second': 1.492, 'train_loss': 2.047203666781202, 'epoch': 3.0}


### Some examples from test data

In [13]:
# Reload the fine-tuned weights and tokenizer from the local checkpoint folder.
# This cell also relies on `toxic_words` created in the dataset-building cell.
checkpoint_dir = "../models/bert_maskedlm"
best_model = BertForMaskedLM.from_pretrained(checkpoint_dir)
tokenizer = BertTokenizer.from_pretrained(checkpoint_dir)

# Only the first ten test references are used as a qualitative smoke check.
test_toxic_sentences = pd.read_csv('../data/interim/test.csv')['reference'].head(10).to_list()

detoxified = detoxificate_text(test_toxic_sentences, toxic_words, tokenizer, best_model)

# Show each sentence next to its masked form and the model's rewrite,
# separated by a blank line.
for sentence, detoxified_sentence in zip(test_toxic_sentences, detoxified):
    print(
        f'Original: {sentence}',
        f'Masked: {put_mask(sentence, toxic_words)}',
        f'Detoxified: {detoxified_sentence}',
        sep='\n',
        end='\n\n',
    )

Original: Hurt you and I want to find out who that was.
Masked: [MASK] you and i want to find out who that was.
Detoxified: both you and i want to find out who that was.

Original: I can't believe we haven't fucked for two years, nine months, three weeks and... 69 hours.
Masked: i can't [MASK]ve we haven't [MASK]ed for two years, nine months, three weeks and... 69 hours.
Detoxified: i can't've we haven't been ed for two years, nine months, three weeks and... 69 hours.

Original: So forgive me for being a little fidgety, but if it's lunch we're talking, I'm gonna eat a fat pork sandwich, and I'm sure as shit not gonna eat it here.
Masked: so forgive me for being a little [MASK]y, but if it's lunch we're talking, i'm gonna eat a [MASK] pork sandwich, and i'm sure as [MASK] not gonna eat it here.
Detoxified: so forgive me for being a little co y, but if it's lunch we're talking, i'm gonna eat a roast pork sandwich, and i'm sure as well not gonna eat it here.

Original: There is no fucking