# Imports And Loading Data

In [None]:
# the notebook's main objective is to filter and prepare the dataset to train a summarizer on it.
import os, sys
from pathlib import Path
# Directory the notebook was started from.
HOME = os.getcwd()

# Walk upwards from HOME until a directory that contains 'src' is found;
# that directory is treated as the project root.
current = Path(HOME)
while 'src' not in os.listdir(current):
    if current.parent == current:
        # The filesystem root is its own parent: without this guard the loop
        # would spin forever when the notebook runs outside the project tree.
        raise FileNotFoundError(
            f"no parent directory of {HOME!r} contains a 'src' folder"
        )
    current = current.parent

PARENT_DIR = str(current)
DATA_FOLDER = os.path.join(PARENT_DIR, 'data')

# Make the project root and its helper packages importable.
sys.path.append(PARENT_DIR)
for package_folder in ('data_analysis', 'evaluation', 'text_processing'):
    sys.path.append(os.path.join(PARENT_DIR, package_folder))

# Custom Toxic Classifier

In [None]:
# read the data
import pandas as pd
# Load the raw training CSV, keeping every column except 'id'.
toxic_train_path = os.path.join(DATA_FOLDER, 'toxic_train.csv')
data = pd.read_csv(toxic_train_path, usecols=lambda column_name: column_name != 'id')
data.head()

In [None]:
# let's convert all the different sub toxicity-labels into a single label.
# A comment is flagged toxic (1) when the sum of its six sub-labels is positive.
sub_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
data['is_toxic'] = (sum(data[column] for column in sub_labels) > 0).astype(int)
def prepare_data(row):
    """Add an aggregated 'is_toxic' flag to a single record.

    Args:
        row: mapping-like record (e.g. a dict or a pandas Series) holding the
            six sub-label fields 'toxic', 'severe_toxic', 'obscene', 'threat',
            'insult' and 'identity_hate'.

    Returns:
        The same ``row`` object, updated in place with ``row['is_toxic']`` set
        to 1 when the sum of the sub-labels is positive and to 0 otherwise.
    """
    label_columns = ('toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate')
    row['is_toxic'] = int(sum(row[column] for column in label_columns) > 0)
    return row
# new_data = data.apply(prepare_data, axis='index')
# Keep only the comment text (renamed to 'text') and the aggregated flag.
sub_label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
without_sub_labels = data.drop(columns=sub_label_columns)
new_data = without_sub_labels.rename(columns={'comment_text': 'text'})
# Class distribution of the aggregated label.
new_data['is_toxic'].value_counts()

### Balancing the data manually
Since the data is heavily unbalanced, I had 2 options: either use the entire dataset (around 400k samples) and apply techniques such as a weighted loss, or balance it manually. The latter presented itself as a very attractive alternative due to the lack of computational resources.

In [None]:
# Split the comments according to the aggregated flag.
toxic = new_data[new_data['is_toxic'] == 1]
non_toxic = new_data[new_data['is_toxic'] == 0]

# Final dataset: every toxic comment plus the first 1.5x as many non-toxic ones.
num_samples = int(1.5 * len(toxic))
balanced_dataset = pd.concat([toxic, non_toxic.head(num_samples)])

# Persist the balanced dataset (without the index column).
balanced_path = os.path.join(DATA_FOLDER, 'toxicity_data.csv')
balanced_dataset.to_csv(balanced_path, index=False)

# Train Classifier

In [None]:
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AutoModel, AutoTokenizer
# load tokenizer and model weights
# Tokenizer and classifier are loaded from the same pretrained checkpoint.
toxic_checkpoint = 'SkolkovoInstitute/roberta_toxicity_classifier'
toxic_tokenizer = RobertaTokenizer.from_pretrained(toxic_checkpoint)
toxic_classifier = RobertaForSequenceClassification.from_pretrained(toxic_checkpoint)

In [None]:
import datasets
# Reload the balanced CSV as a `datasets` dataset (the csv loader's 'train' split).
toxicity_csv = os.path.join(DATA_FOLDER, 'toxicity_data.csv')
data = datasets.load_dataset('csv', data_files=toxicity_csv, split='train')

import torch
from transformers import AutoTokenizer, BartForSequenceClassification, AutoModelForSequenceClassification

# BART-base loaded for sequence classification with two output labels.
checkpoint = 'facebook/bart-base'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# Only the classification-head parameters listed here stay trainable;
# every other parameter is frozen.
trainable_parameters = {
    'classification_head.dense.weight',
    'classification_head.dense.bias',
    'classification_head.out_proj.weight',
    "classification_head.out_proj.bias",
}
for n, p in model.named_parameters():
    if n in trainable_parameters:
        # report the parameters that will be updated
        print(n)
        continue
    p.requires_grad = False

In [None]:
from typing import Dict
from torch.nn.functional import softmax
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

def process_labels(batch: Dict, device: str):
    """Tokenize a batch for the pretrained toxicity classifier and attach its probabilities.

    Args:
        batch: mapping whose 'text' entry holds the raw strings of the batch.
        device: torch device identifier the inputs and the classifier are moved to.

    Returns:
        A dict with the tokenizer outputs (tensors on ``device``) plus a
        'label' entry holding the softmax of the classifier logits, one row
        per input text.
    """
    model_input = toxic_tokenizer(batch['text'], return_tensors='pt', truncation=True, padding=True)
    model_input = {k: v.to(device) for k, v in model_input.items()}
    toxic_classifier.to(device)
    # Pure inference: disable autograd so no computation graph is built (and
    # kept alive through the returned tensor) for every mapped batch.
    with torch.no_grad():
        logits = toxic_classifier(**model_input).logits
    model_input['label'] = softmax(logits, dim=1)
    return model_input

# start by saving the outputs of the original classifier as logits after softmax
def process_data(batch: Dict):
    """Tokenize the batch texts with the module-level ``tokenizer`` and keep the labels.

    Args:
        batch: mapping with a 'text' entry (raw strings) and a 'label' entry.

    Returns:
        The tokenizer output (truncated, unpadded) with ``batch['label']``
        copied under the 'label' key.
    """
    encoded = tokenizer(batch['text'], truncation=True)
    encoded['label'] = batch['label']
    return encoded

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# First pass: attach the pretrained classifier's probabilities as 'label'.
d = data.map(lambda b: process_labels(b, device=DEVICE), batched=True, batch_size=4)
# Second pass: re-tokenize with `tokenizer`. `process_data` only accepts the
# batch, so it is passed directly (a `device` keyword would raise TypeError).
d = d.map(process_data, batched=True, batch_size=32).remove_columns(['is_toxic', 'text'])

In [None]:
import src.data_preparation.prepare_data as pdr
# Collator built around `tokenizer`; it is handed to the trainer further down.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Project helper that yields three datasets from `d`
# (presumably train / validation / test partitions — verify in src.data_preparation.prepare_data).
train_data, val_data, test_data = pdr.data_split(d)

## Customize the HF Trainer API

In [None]:
from torch import nn
from transformers import Trainer


# create a custom trainer for which the loss function is overridden: Use the Knowledge distillation loss
class CustomTrainer(Trainer):
    """Trainer whose loss is the cross-entropy between the model logits and the batch labels.

    The labels are presumably the class probabilities produced by the
    pretrained toxicity classifier (soft targets) — TODO confirm against the
    dataset fed to this trainer.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the cross-entropy loss for one batch.

        Args:
            model: the model being trained; called as ``model(**inputs)``.
            inputs: batch mapping; its "labels" entry is removed before the
                forward pass and used as the loss target.
            return_outputs: when True, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: accepted (and ignored) because recent
                ``transformers`` releases pass it to ``compute_loss``; omitting
                it makes the override fail with a TypeError there.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        labels = inputs.pop("labels")
        # forward pass without the labels so the model does not compute its own loss
        outputs = model(**inputs)
        loss = nn.CrossEntropyLoss()(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss

In [None]:
from transformers import TrainingArguments

# Hyperparameters of the classifier training run. These names are the single
# source of truth: every one of them is forwarded to TrainingArguments below.
# num_epochs and weight_decay carry the values the call previously hard-coded.
batch_size = 246
num_epochs = 10
learning_rate = 5e-5
warmup_steps = 500
weight_decay = 0.001

training_args = TrainingArguments(
    os.path.join(os.getcwd(), "toxic_classifier_checkpoints"),
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    warmup_steps=warmup_steps,
    weight_decay=weight_decay,
    learning_rate=learning_rate,
    report_to='none',
    save_steps=100,
)

trainer = CustomTrainer(
    model,
    training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
from src.training_utilities.pytorch_utilities import cleanup
# Project helper called right before training
# (presumably releases cached memory — verify in src.training_utilities.pytorch_utilities).
cleanup()
# Launch the training loop of the trainer configured above.
trainer.train()

## Training Seq2Seq

In [None]:
# let's load the data
from datasets import load_dataset
# Each split lives in its own CSV file; the csv loader exposes a single file
# under the 'train' split, hence split='train' for all three.
train_data, val_data, test_data = (
    load_dataset("csv", data_files=os.path.join(DATA_FOLDER, csv_name), split='train')
    for csv_name in ('train_split.csv', 'val_split.csv', 'test_split.csv')
)

In [None]:
def prepare_labeled_data(batch):
    """Tokenize a batch of source/target text pairs.

    Args:
        batch: mapping with a 'source' entry (input texts) and a 'target'
            entry (reference texts).

    Returns:
        The tokenized sources, with the token ids of the tokenized targets
        stored under the "labels" key.
    """
    # encode the inputs
    source_encoding = tokenizer(batch['source'], truncation=True)
    # encode the references in target mode
    target_encoding = tokenizer(text_target=batch["target"], truncation=True)
    # only the target ids are kept; the target attention mask is not forwarded
    source_encoding["labels"] = target_encoding["input_ids"]
    return source_encoding

# Tokenize every split in batches, then drop the raw text columns.
train_data, val_data, test_data = (
    split.map(prepare_labeled_data, batched=True).remove_columns(['source', 'target'])
    for split in (train_data, val_data, test_data)
)

In [None]:
# prepare the data collator
from transformers import DataCollatorForSeq2Seq
# Seq2seq collator built around the shared `tokenizer`; it replaces the
# collator defined earlier under the same name.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# define the custom trainer
from src.training_utilities.pytorch_utilities import get_module_device
from torch import nn
from transformers import Trainer
from torch.nn.functional import softmax

class CustomTrainer(Trainer):
    """Trainer that adds a toxicity penalty on generated text to the seq2seq loss.

    Args:
        toxic_classifier: sequence-classification model called on the
            generated token ids; column 1 of its logits is used as the
            "toxic" score. It is fed ids produced by the trained model, so it
            presumably has to share that model's vocabulary — TODO confirm.
        *args, **kwargs: forwarded unchanged to ``Trainer``. A ``tokenizer``
            must be supplied because ``compute_loss`` reads its ``pad_token_id``.
    """

    def __init__(self, toxic_classifier, *args, **kwargs):
        self.toxic_classifier = toxic_classifier
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the seq2seq loss plus 0.1 times the mean toxic probability of the generations.

        Args:
            model: the seq2seq model being trained; must expose ``generate``
                and return an output with a ``loss`` attribute.
            inputs: batch mapping containing at least 'input_ids' and 'labels'.
            return_outputs: when True, return ``(loss, model_output)``.
            num_items_in_batch: accepted (and ignored) because recent
                ``transformers`` releases pass it to ``compute_loss``.

        Returns:
            The combined loss tensor, or ``(loss, model_output)`` when
            ``return_outputs`` is True.
        """
        model_output = model(**inputs)
        # extract the sequence to sequence loss
        s2s_loss = model_output.loss

        # generations are capped at the (padded) label length of this batch
        max_sentence_length = inputs['labels'].shape[1]

        # generate the model's own predictions for the batch sources
        prediction_ids = model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            max_length=max_sentence_length,
        )

        # 1 for real tokens, 0 for padding; derived from the ids so it is an
        # integer mask that already lives on the same device
        attention_mask = (prediction_ids != self.tokenizer.pad_token_id).long()

        # the Trainer only places `model` on the training device; make sure the
        # scorer sits next to the ids it receives (no-op once it is there)
        self.toxic_classifier.to(prediction_ids.device)
        toxic_output = self.toxic_classifier(input_ids=prediction_ids, attention_mask=attention_mask)

        # NOTE(review): the ids come out of generate(), a discrete decoding
        # step, so this term presumably contributes no gradient to `model` —
        # confirm that a non-differentiable penalty is what is intended.
        toxic_loss = torch.mean(softmax(toxic_output.logits, dim=1)[:, 1])
        loss = s2s_loss + 0.1 * toxic_loss
        return (loss, model_output) if return_outputs else loss


In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hyperparameters of the seq2seq training run.
batch_size = 16
num_epochs = 5
learning_rate = 5e-5
warmup_steps = 500
weight_decay = 0.01

sc_training_args = Seq2SeqTrainingArguments(
    output_dir='seq_2_seq_checkpoints',
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    do_train=True,
    do_eval=True,
    logging_steps=100,
    save_steps=100,
    eval_steps=10,
    overwrite_output_dir=True,
    warmup_steps=warmup_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    num_train_epochs=num_epochs,
    report_to="none",
)

# NOTE(review): compute_loss reads `model_output.loss` and calls
# `model.generate`, so `model` has to be a seq2seq model here; the last
# visible assignment to `model` is a sequence-classification model — confirm
# it is reloaded before this cell runs.
# NOTE(review): the scorer is fed token ids generated with `tokenizer`'s
# vocabulary — confirm `toxic_classifier` is the classifier meant to be used.
trainer = CustomTrainer(
    # required by CustomTrainer.__init__; omitting it raises a TypeError
    toxic_classifier=toxic_classifier,
    model=model,
    args=sc_training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    # compute_loss reads self.tokenizer.pad_token_id
    tokenizer=tokenizer,
)

In [None]:
from src.training_utilities.pytorch_utilities import cleanup
# Project helper called right before training
# (presumably releases cached memory — verify in src.training_utilities.pytorch_utilities).
cleanup()
# Launch the training loop of the seq2seq trainer configured above.
trainer.train()