In [1]:
import re

import numpy as np
import pandas as pd
import torch
from joblib import dump
from scipy.spatial.distance import cosine
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from transformers import BertForMaskedLM, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the ParaNMT training split (reference/translation pairs plus their
# similarity, length-difference and ref_tox/trn_tox columns).
TRAIN_PATH = '../data/interim/02_ParaNMT_train.csv'
df = pd.read_csv(TRAIN_PATH)
df.head()

Unnamed: 0,reference,translation,similarity,lenght_diff,ref_tox,trn_tox
0,"Well, if you ask me, Family is a pain in the b...","if you ask me, then the family is a pain in th...",0.81907,0.087912,0.310097,0.994614
1,"Well, you are screwing up my life, which I'm u...","well, you ruined my life, which I got used to,...",0.745561,0.113402,0.591161,0.076705
2,"I mean, what kind of loser has his bachelor pa...",what kind of poor guy has ten yards away from ...,0.64362,0.126316,0.986372,0.000157
3,I asked him to confine his salacious acts to t...,I asked him to carry out his filthy activities...,0.714307,0.081395,0.013184,0.933126
4,Found on the beach - some sort of driftwood. -...,we found him on the beach - like a piece of ju...,0.69971,0.195876,7.3e-05,0.938116


## Training a logistic-regression toxicity classifier

In [3]:
TOXICITY_THRESHOLD = 0.5  # reference sentences with ref_tox above this are labelled toxic

# Bag-of-words token counts over the reference sentences
vectorizer = CountVectorizer()
X = vectorizer.fit_transform(df['reference'])

# Binary target: 1 = toxic, 0 = non-toxic
y = df['ref_tox'].gt(TOXICITY_THRESHOLD).astype(int)

In [4]:
# With the default iteration budget (max_iter=100) the lbfgs solver stops
# before converging on this sparse count matrix, leaving under-fitted
# coefficients (which the toxic-word ranking below relies on). Give it room.
classifier = LogisticRegression(max_iter=1000)
classifier.fit(X, y)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


## Generating substitutions with BERT

In [5]:
N_TOXIC_WORDS = 10  # how many top-weighted vocabulary entries to treat as toxic

# Identify the toxic words: for a binary LogisticRegression, coef_[0] holds the
# weights for the positive class (y == 1, i.e. toxic), so the largest weights
# mark the words that push predictions towards "toxic".
feature_names = np.array(vectorizer.get_feature_names_out())
weights = classifier.coef_[0]

# Dividing by the L2 norm rescales the weights but leaves their ranking unchanged.
normalized_weights = weights / np.linalg.norm(weights)
toxic_indices = np.argsort(normalized_weights)[-N_TOXIC_WORDS:]  # ascending: most toxic last

# Pretrained BERT masked language model, used to propose replacement tokens
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def get_substitutes(sentence, toxic_words, top_k=5, mlm_tokenizer=None, mlm_model=None):
    """Propose masked-LM replacement tokens for each toxic word found in ``sentence``.

    Every whole-word, case-insensitive occurrence of a toxic word is replaced by
    mask token(s). The masked sentence is run through the masked language model,
    and the ``top_k`` highest-scoring tokens at each mask position are collected.

    Args:
        sentence: Input sentence to detoxify.
        toxic_words: Iterable of words to mask (e.g. vocabulary entries with the
            largest logistic-regression weights).
        top_k: Number of candidate tokens kept per mask position (default 5).
        mlm_tokenizer: Tokenizer exposing ``mask_token``, ``mask_token_id``,
            ``__call__(text, return_tensors='pt')`` and ``decode``. Defaults to
            the notebook-level ``tokenizer``.
        mlm_model: Masked LM whose call result has a ``.logits`` attribute.
            Defaults to the notebook-level ``model``.

    Returns:
        dict mapping each toxic word that occurs in ``sentence`` to a flat list
        of decoded candidate tokens (``top_k`` per mask position, in position
        order). Words that do not occur in the sentence are omitted.
    """
    tok = mlm_tokenizer if mlm_tokenizer is not None else tokenizer
    mlm = mlm_model if mlm_model is not None else model

    substitutes = {}
    for word in toxic_words:
        # Whole-word, case-insensitive match. CountVectorizer lowercases and
        # tokenizes on word boundaries, so a plain str.replace would also mask
        # substrings (e.g. "ass" inside "classic") and miss "Bullshit".
        pattern = re.compile(r"\b" + re.escape(word) + r"\b", flags=re.IGNORECASE)
        mask = " ".join([tok.mask_token] * len(word.split()))
        masked_sentence, n_replaced = pattern.subn(lambda _match: mask, sentence)

        # Word not present: nothing to substitute, so skip the model call.
        if n_replaced == 0:
            continue

        input_ids = tok(masked_sentence, return_tensors='pt')['input_ids']
        # Inference only: disable autograd so no computation graph is built.
        with torch.no_grad():
            token_logits = mlm(input_ids).logits

        # Sequence positions of every mask token in the (single) encoded sentence.
        masked_token_indices = (input_ids == tok.mask_token_id).nonzero(as_tuple=True)[1]
        if len(masked_token_indices) == 0:
            continue

        substitutes_for_word = []
        for idx in masked_token_indices:
            top_tokens = torch.topk(token_logits[0, idx], top_k).indices.tolist()
            substitutes_for_word.extend(tok.decode([token]) for token in top_tokens)

        substitutes[word] = substitutes_for_word

    return substitutes

In [7]:
# Top-weighted vocabulary entries (argsort order: the most toxic word is last)
toxic_words = feature_names[toxic_indices]
substitutes = get_substitutes('This is a bullshit', toxic_words)
# End the cell with the result so the proposed substitutions are displayed.
substitutes