# Text Style Transfer

#####  Perform a rule-based text-style transfer by first training a token classifer to predict toxic tokens, then using antonyms to substitute such tokens.


### Yakoob Khan '21
### Date: March 18, 2021

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd "drive/My Drive/toxic-comments-classification-challenge"

/content/drive/My Drive/toxic-comments-classification-challenge


In [3]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/2c/d8/5144b0712f7f82229a8da5983a8fbb8d30cec5fbd5f8d12ffe1854dcea67/transformers-4.4.1-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 16.0MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 62.7MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 53.9MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=6f4f

### 1. Train a BERT Sequence Labeling Classifier to detect toxic tokens

In [4]:
import time
import torch
import numpy as np
import pandas as pd 
import json
import random
import time
import ast
import string
import itertools
import argparse
import random
from collections import defaultdict
from ast import literal_eval
from transformers import BertForTokenClassification, Trainer, TrainingArguments, BertTokenizerFast 

In [11]:
def load_dataset(dataset_path):
    dataset = pd.read_csv(dataset_path)
    print(f"\n> Loading {dataset.shape[0]} examples located at '{dataset_path}'\n")

    dataset["spans"] = dataset.spans.apply(literal_eval)
    texts, spans = dataset["text"], dataset["spans"]
    texts = [text for text in texts]
    spans = [span for span in spans]
   
    return texts, spans

def load_asian_tweets_test_set(dataset_path):
  dataset = pd.read_csv(dataset_path)
  print(f"\n> Loading {dataset.shape[0]} test examples located at '{dataset_path}'\n")
  texts = dataset["text"]
  texts = [text for text in texts]
  categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
  labels = []
  for i in range(len(texts)):
    row_labels = [dataset[category][i] for category in categories]
    labels.append(row_labels)

  return texts, labels


def preserve_labels(text_encoding, span):
  labels = [0] * len(text_encoding.tokens)
  toxic_indices = set(span)
  for i, offset in enumerate(text_encoding.offsets):
    # labels for CLS, SEP and PAD tokens are set to -100.
    if offset == (0, 0):
      labels[i] = -100
    
    else:
      # check if any character indices of this sub-token has gold label toxic 
      for k in range(offset[0], offset[1]):
        if k in toxic_indices: 
          # toxic, so set label to 1.
          labels[i] = 1
          break
  
  return labels
      
def tokenize_data(tokenizer, texts, spans):
    text_encodings = tokenizer(texts, return_offsets_mapping=True, padding=True, truncation=True)
    labels = [preserve_labels(text_encodings[i], span) for i, span in enumerate(spans)]
    return text_encodings, labels


def tokenize_testset(tokenizer, texts):
    text_encodings = tokenizer(texts, return_offsets_mapping=True, padding=True, truncation=True)
    dummy_labels = [[0] * len(tokens) for i, tokens in enumerate(text_encodings.input_ids)]
    return text_encodings, dummy_labels


class ToxicSpansDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

import nltk

# For tokenizing sentences
nltk.download('punkt')
sentence_tokenizer = nltk.data.load('tokenizers/punkt/PY3/english.pickle')


def toxic_character_offsets_with_thresholding(post_num, tokens, offset_mapping, prediction, val_sentences_info, prediction_score, threshold):
  toxic_offsets = []
  scores = []
  n = len(tokens)
  i = 1           # start from 1 as 0th token is [CLS]
  while i < n:
    # stop looping after processing all post tokens
    if tokens[i] == '[SEP]':
      break

    cur_toxic = []
    # if previous token is also predicted toxic, then toxic phrase found
    if len(toxic_offsets) > 0 and toxic_offsets[-1] == offset_mapping[i-1][1] - 1:
      cur_toxic.extend([index for index in range(offset_mapping[i-1][1], offset_mapping[i][0])])
    
    # add the characters offsets of this head BPE
    cur_toxic.extend([index for index in range(offset_mapping[i][0], offset_mapping[i][1])])
    cur_score = [(tokens[i], prediction_score[i].max())]
    cur_labels = [prediction[i]]
    
    # process all sub-tokens of the current head BPE
    i += 1
    while i < n and '##' in tokens[i]:
      cur_toxic.extend([index for index in range(offset_mapping[i][0], offset_mapping[i][1])])
      cur_score.append((tokens[i], prediction_score[i].max()))
      cur_labels.append(prediction[i])
      i += 1
    
    # word is predicted toxic if any sub-token is predicted toxic by model
    prediction_label = True if max(cur_labels) == 1 else False
    # prediction_label = True if min(cur_labels) == 1 else False

    
    # include cur_toxic offsets if any of the sub-token confidence score is greater than threshold
    confidence_values = [score for _, score in cur_score]
    passed_threshold = True if max(confidence_values) >= threshold else False
    # passed_threshold = True if min(confidence_values) >= threshold else False

    # include to global toxic offsets list only if both predicted label and threshold criteria passes
    if prediction_label and passed_threshold:
      toxic_offsets.extend(cur_toxic)
      scores.extend(cur_score)
  

  return toxic_offsets, scores

def character_offsets_with_thresholding(val_text_encodings, val_offset_mapping, predictions, val_sentences_info, prediction_scores, threshold=-float('inf')):
  return [toxic_character_offsets_with_thresholding(i, val_text_encodings[i].tokens, offset_mapping, prediction, val_sentences_info, prediction_scores[i], threshold) for i, (offset_mapping, prediction) in enumerate(zip(val_offset_mapping, predictions))]

def _contiguous_ranges(span_list):
    """Extracts continguous runs [1, 2, 3, 5, 6, 7] -> [(1,3), (5,7)].
       Credit: https://github.com/ipavlopoulos/toxic_spans/blob/master/evaluation/fix_spans.py
    """
    output = []
    for _, span in itertools.groupby(
        enumerate(span_list), lambda p: p[1] - p[0]):
        span = list(span)
        output.append((span[0][1], span[-1][1]))
    return output



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [31]:
# Load the BERT base cased tokenizer and pre-trained model
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=2)

# Load the train, val and test csv files
training_texts, training_spans = load_dataset('./data/tsd_train.csv')
val_texts, val_spans = load_dataset('./data/tsd_trial.csv')
test_texts, labels  = load_asian_tweets_test_set('./data/asian_y_pred.csv')

print(f"> Total number of Scraped Tweets in Test Set : {len(test_texts)} \n")

filtered_texts = []
# Filter the test tweets that contain at least one toxicity predicted label
for i in range(len(test_texts)):
  if max(labels[i]) == 1:
    filtered_texts.append(test_texts[i])

test_texts = filtered_texts

print(f"> Filtered test set that contains at least one offensive language label prediction: {len(test_texts)} \n")

val_sentences_info = {}
print('\n> Tokenizing text and generating word embeddings.. \n')
train_text_encodings, train_labels_encodings = tokenize_data(tokenizer, training_texts, training_spans)
val_text_encodings, val_labels_encodings = tokenize_data(tokenizer, val_texts, val_spans)
test_text_encodings, test_labels_encodings = tokenize_testset(tokenizer, test_texts)

# Create Torch Dataset Objects for train / valid sets
print('> Creating Tensor Datasets.. \n')
train_dataset = ToxicSpansDataset(train_text_encodings, train_labels_encodings)
val_dataset = ToxicSpansDataset(val_text_encodings, val_labels_encodings)
test_dataset = ToxicSpansDataset(test_text_encodings, test_labels_encodings)

print(f"> Training examples: {len(train_dataset)}")
print(f"> Validation examples: {len(val_dataset)}")
print(f"> Test examples: {len(test_dataset)}\n")

# We don't want to pass offset mappings to the model
train_offset_mapping = train_text_encodings.pop("offset_mapping") 
val_offset_mapping = val_text_encodings.pop("offset_mapping")
test_offset_mapping = test_text_encodings.pop("offset_mapping")

# Training Argument Object with hyper-parameter configuration.
training_args = TrainingArguments(
  output_dir='./logs',                    # output directory
  num_train_epochs=2,                     # total number of training epochs
  per_device_train_batch_size=8,         # batch size per device during training
  per_device_eval_batch_size=8,          # batch size for evaluation
  warmup_steps=500,                       # number of warmup steps for learning rate scheduler
  weight_decay=0.01,                      # strength of weight decay
  logging_dir='./logs',                   # directory for storing logs
  logging_steps=100,                      # log after every x steps
  do_eval=True,                           # whether to run evaluation on the val set
  evaluation_strategy="steps",            # evaluation is done (and logged) every logging_steps 
  learning_rate=5e-5,                     # 5e-5 is default learning rate
  disable_tqdm=True,                      # remove tqdm statements to reduce clutter
)

# Trainer Object
trainer = Trainer(
  model=model,                 # the instantiated 🤗 Transformers model to be trained
  args=training_args,          # training arguments, defined above
  train_dataset=train_dataset,       
  eval_dataset=val_dataset,         
)

print('> Started Toxic Spans Detection training! \n')
trainer.train()

print('> Making toxic token predictions on Scraped Asian Tweets \n')
# use trained model to make toxic token predictions on test datasets
test_pred = trainer.predict(test_dataset)

# retrieve the predictions
test_predictions = test_pred.predictions.argmax(-1)
test_prediction_scores = test_pred.predictions
test_toxic_char_preds = character_offsets_with_thresholding(test_text_encodings, test_offset_mapping, test_predictions, val_sentences_info,  test_prediction_scores, threshold=-float('inf'))

toxic_char_offsets = [span[0] for span in test_toxic_char_preds]
test_set_toxic_tokens = {'text': [], 'spans': []}

for text, pred in zip(test_texts, toxic_char_offsets):
  test_set_toxic_tokens['text'].append(text)
  test_set_toxic_tokens['spans'].append(pred)

# Save the toxic span predictions in a CSV file
df = pd.DataFrame(test_set_toxic_tokens)
df.to_csv('./style-transfer/asian_tweet_toxic_token_predictions.csv', index=False)


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas


> Loading 7939 examples located at './data/tsd_train.csv'


> Loading 690 examples located at './data/tsd_trial.csv'


> Loading 28408 test examples located at './data/asian_y_pred.csv'

> Total number of Scraped Tweets in Test Set : 28408 

> Filtered test set that contains at least one offensive language label prediction: 2288 


> Tokenizing text and generating word embeddings.. 

> Creating Tensor Datasets.. 

> Training examples: 7939
> Validation examples: 690
> Test examples: 2288

> Started Toxic Spans Detection training! 

{'loss': 0.4037, 'learning_rate': 1e-05, 'epoch': 0.1}
{'eval_loss': 0.22229667007923126, 'eval_runtime': 3.3458, 'eval_samples_per_second': 206.229, 'epoch': 0.1}
{'loss': 0.2761, 'learning_rate': 2e-05, 'epoch': 0.2}
{'eval_loss': 0.20020772516727448, 'eval_runtime': 3.3487, 'eval_samples_per_second': 206.05, 'epoch': 0.2}
{'loss': 0.2494, 'learning_rate': 3e-05, 'epoch': 0.3}
{'eval_loss': 0.20238102972507477, 'eval_runtime': 3.3444, 'eval_samples_per_se

In [32]:
from nltk.tokenize import TreebankWordTokenizer as twt
nltk.download('wordnet')
from nltk.corpus import wordnet 

def neutralize(text, span):
  toxic_offsets = set(span)
  # tokenize the text using NLTK
  try:
    token_indices = twt().span_tokenize(text)
    # loop through the tokens and convert any toxic tokens to antonym
    neutral = ''
    for i, j in token_indices:
      toxic = False
      for k in range(i, j+1):
        if k in toxic_offsets:
          toxic = True
          break
      
      token = text[i:j+1]
      if not toxic:
        neutral += f"{token} "
      else:
        antonyms = []
        # find all the antonyms of this word
        # Credit: https://www.geeksforgeeks.org/get-synonymsantonyms-nltk-wordnet-python/
        for syn in wordnet.synsets(token): 
          for l in syn.lemmas(): 
            if l.antonyms(): 
                antonyms.append(l.antonyms()[0].name()) 

        # pick a random antonym if there is any
        
        substituted_token = random.choice(antonyms) if antonyms else "***"
        neutral += f"{substituted_token} "
    
    return neutral.strip()

  except:
    # twt().span_tokenize(text) raises error sometimes, so catch all statement to return original text
    return text

asian_test_texts, test_spans = load_dataset('./style-transfer/asian_tweet_toxic_token_predictions.csv')
neutralized_dict = {'original': [], 'neutral': []}

for text, span in zip(asian_test_texts, test_spans):
  
  neutralized_dict['original'].append(text)
  neutralized_text = neutralize(text, span)
  neutralized_dict['neutral'].append(neutralized_text)

# Save the neutralized posts in a CSV file
df = pd.DataFrame(neutralized_dict)
df.to_csv('./style-transfer/asian_toxic_tweet_neutralized.csv', index=False)
df

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!

> Loading 2288 examples located at './style-transfer/asian_tweet_toxic_token_predictions.csv'



Unnamed: 0,original,neutral
0,These morons attack an Asian business owner fo...,These *** attack an Asian business owner ...
1,"Last tweet of the night... in 2012, I truley t...","Last tweet of the night. ... in 2012, , ..."
2,@THR Could a closeted gay asian man be a chink...,@T THR Could a closeted *** asian man be...
3,"When you’re Muslim, and only want/need one wif...","When you’ ’r re Muslim, , and only want/n..."
4,@JGreenblattADL Is all trumps fault. His vicio...,@J JGreenblattADL Is all trumps fault. Hi...
...,...,...
2283,"Be afraid, be very afraid taxpayers! Thru the ...","Be afraid, , be very afraid taxpayers! ! ..."
2284,It shouldn’t surprise anyone that the previous...,It shouldn’ ’t t surprise anyone that the...
2285,@guardian Meghan is country girl turned duches...,@g guardian Meghan is country girl turned...
2286,"The source article, with warnings from Chines ...","The source article, , with warnings from ..."
