# The Outline

1. Using the pre-trained BART-ParaDetox model, we run every comment in the Jigsaw dataset through it to generate a parallel detoxified dataset.
2. We then use the generated parallel dataset to train our own BART model.
3. The trained BART model is then used to generate a detoxified version of any input sentence.

## Setup the python environment
Ideally, run the following cells in a virtual environment created with Python 3 (in our testing, any version earlier than 3.10 works) that has the packages specified in `requirements.txt` installed.

In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
# for cleaning later
from transformers import BertTokenizer
import spacy
# from emoji import demojize
import string
# spacy.cli.download("en_core_web_sm")
# nlp = spacy.load("en_core_web_sm")

  from .autonotebook import tqdm as notebook_tqdm


Read and load the Jigsaw dataset with the relevant columns:
- `comment_text` : the column containing the input toxic sentence
- `toxicity` : the column containing the toxicity score of the input toxic sentence

## 1. Generating the parallel dataset from Jigsaw
We use `cleaned_balanced_data.csv` as it is a smaller dataset

In [2]:
# Columns used by the cleaning and detox-generation cells below:
# the raw comment plus its two toxicity scores.
target_columns = ['comment_text', 'toxicity', 'severe_toxicity']

# Read only the target columns so the remaining Jigsaw columns are never loaded into memory.
# NOTE(review): if a later cell needs another column, add it to `target_columns`.
df = pd.read_csv('all_data.csv', usecols=target_columns)

### Data Cleaning

In [4]:
import re
import nltk
from nltk.corpus import stopwords
# https://mlwhiz.com/blog/2019/01/17/deeplearning_nlp_preprocess/

# Contraction expansions as (pattern, replacement) pairs. Order matters: the
# specific forms ("won't", "can't") must run before the generic "n't" / "'t" rules.
_CONTRACTION_RULES = [
    (r"won't", "will not"),
    (r"can\'t", "can not"),
    (r"n\'t", " not"),
    (r"\'re", " are"),
    (r"\'s", " is"),
    (r"\'d", " would"),
    (r"\'ll", " will"),
    (r"\'t", " not"),
    (r"\'ve", " have"),
    (r"\'m", " am"),
]

# Translation table that deletes every character in string.punctuation.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

# Cache for the NLTK English stop words; filled on first use so the corpus is
# not re-read for every row passed through clean_text.
_DEFAULT_STOP_WORDS = None


def _english_stop_words():
    """Return the NLTK English stop words as a frozenset, loading them only once."""
    global _DEFAULT_STOP_WORDS
    if _DEFAULT_STOP_WORDS is None:
        _DEFAULT_STOP_WORDS = frozenset(stopwords.words('english'))
    return _DEFAULT_STOP_WORDS


def clean_text(text, stop_words=None):
    """Normalise one raw comment into a lowercase, punctuation-free string.

    Stages, in order: lowercase, strip URLs, strip HTML tags, expand
    contractions, drop '*' self-censoring, strip punctuation, drop stop
    words, collapse whitespace.

    Contractions are expanded *before* punctuation is removed: once the
    apostrophes are stripped, none of the contraction patterns can match.

    Args:
        text: The raw comment. Anything that is not a ``str`` (e.g. NaN from
            pandas) yields an empty string.
        stop_words: Optional container of words to drop (anything supporting
            ``in``). Defaults to the NLTK English stop-word list.

    Returns:
        The cleaned text, or ``''`` for non-string input.
    """
    if not isinstance(text, str):
        return ''

    if stop_words is None:
        stop_words = _english_stop_words()

    # Convert to lowercase
    text = text.lower()

    # Convert emojis to text descriptions
    # text = demojize(text)

    # Remove URLs
    text = re.sub(r'http\S+|www\S+', '', text)

    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)

    # Treat typographic apostrophes like ASCII ones so the contraction rules see them
    text = text.replace('\u2019', "'")

    # Handle contractions (must happen while the apostrophes are still present)
    for pattern, replacement in _CONTRACTION_RULES:
        text = re.sub(pattern, replacement, text)

    # Handle self-censored text, e.g. "f**k" -> "fk"
    text = re.sub(r'\*+', '', text)

    # Remove punctuation
    text = text.translate(_PUNCTUATION_TABLE)

    # Remove stopwords
    text = ' '.join(word for word in text.split() if word not in stop_words)

    # Remove extra whitespace
    return re.sub(r'\s+', ' ', text).strip()

# Clean every raw comment element-wise and store the result next to the original text.
df['cleaned_text'] = df['comment_text'].map(clean_text)

# Persist the frame (without the index column) so later cells can reload it from disk.
df.to_csv('clean_data.csv', index=False)
# cleaned

In [5]:
df = pd.read_csv('clean_data.csv')
# Comments that cleaned down to '' were written as blank CSV fields, and read_csv
# parses blank fields back as NaN (a float). Restore them to empty strings so every
# entry of `cleaned_text` is a str again before it is handed to the tokenizer.
df['cleaned_text'] = df['cleaned_text'].fillna('')

In [6]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the pretrained detox model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("s-nlp/bart-base-detox")
model = AutoModelForSeq2SeqLM.from_pretrained("s-nlp/bart-base-detox")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)  # Move the model to the desired device (CPU or GPU)

# Row records collected during generation:
#   jigsaw_parallel_data        - every (input, output) pair
#   jigsaw_strict_parallel_data - only pairs where the model changed the text
jigsaw_parallel_data = []
jigsaw_strict_parallel_data = []

# Set the batch size
batch_size = 64

# Number of leading rows to detoxify: use len(df) for the whole dataset, e.g. 200 for
# the first 200 rows. Clamped to len(df) so that no empty batch reaches the tokenizer.
num_rows_to_process = min(1000, len(df))

# Upper bound on input length, taken from the model config. Longer comments are
# truncated by the tokenizer instead of overflowing the model's position embeddings.
max_input_tokens = model.config.max_position_embeddings

# Iterate over the rows in the dataframe in batches.
# If num_rows_to_process is not divisible by batch_size, the last batch contains fewer
# rows than batch_size (end_idx is clamped so no row past the limit is processed).
for start_idx in tqdm(range(0, num_rows_to_process, batch_size), desc="Detoxifying texts"):
    end_idx = min(start_idx + batch_size, num_rows_to_process)
    batch = df.iloc[start_idx:end_idx]

    # Blank cleaned comments come back from CSV as NaN; the tokenizer only accepts strings.
    batch_texts = batch['cleaned_text'].fillna('').astype(str).tolist()
    batch_toxicity = batch['toxicity'].tolist()
    batch_severe_toxicity = batch['severe_toxicity'].tolist()

    # Tokenize the input texts (padded to the longest text in the batch)
    tokens = tokenizer(batch_texts, return_tensors='pt', padding=True,
                       truncation=True, max_length=max_input_tokens)
    tokens = tokens.to(device)  # Move the tensors to the desired device

    # Generate the detoxified texts (deterministic beam search, one output per input)
    output_tokens = model.generate(**tokens, num_return_sequences=1, do_sample=False,
                                   temperature=1.0, repetition_penalty=10.0,
                                   max_length=128, num_beams=10)

    # Decode the output tokens to get the detoxified texts
    detoxed_texts = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)

    # Append the data to the respective lists
    for text, detoxed_text, toxicity, severe_toxicity in zip(batch_texts, detoxed_texts, batch_toxicity, batch_severe_toxicity):
        record = {'pre-detoxed': text, 'detoxed_text': detoxed_text, 'toxicity': toxicity, 'severe_toxicity': severe_toxicity}
        jigsaw_parallel_data.append(record)
        if text != detoxed_text:
            # Separate dict so the two lists never share a mutable record
            jigsaw_strict_parallel_data.append(dict(record))

# Create dataframes from the lists
jigsaw_parallel = pd.DataFrame(jigsaw_parallel_data)
jigsaw_strict_parallel = pd.DataFrame(jigsaw_strict_parallel_data)

# Save the dataframes to CSV files
jigsaw_parallel.to_csv('jigsaw_parallel.csv', index=False)
jigsaw_strict_parallel.to_csv('jigsaw_strict_parallel.csv', index=False)

Detoxifying texts:   6%|▋         | 1/16 [24:41<6:10:15, 1481.06s/it]


KeyboardInterrupt: 

In [6]:
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# tokenizer = AutoTokenizer.from_pretrained("s-nlp/bart-base-detox")
# model = AutoModelForSeq2SeqLM.from_pretrained("s-nlp/bart-base-detox")
# # 'that sick fuck is going to be out in 54 years.', 
# toxics = ["cant fix stupid but we are electing them to public office due in large part to the silent majority who choose not to vote"]
# tokens = tokenizer(toxics, return_tensors='pt', padding=True)
# tokens = model.generate(**tokens, num_return_sequences=5, do_sample=False,
#                         temperature=1.0, repetition_penalty=10.0,
#                         max_length=128, num_beams=10)
# neutrals = tokenizer.decode(tokens[0, ...], skip_special_tokens=True)
# print(neutrals) # stdout: She is going to be out in 54 years.


cant fix bad but we are electing them to public office due in large part to the silent majority who choose not to vote
