# Dataset

In [None]:
import torch


class CivilCommentsDataset(torch.utils.data.Dataset):
    """
    Torch dataset over one tokenized split of the `civil_comments` dataset:
    https://huggingface.co/datasets/civil_comments.

    `encodings` is a mapping from feature name to an indexable collection with
    one entry per example; `labels` is an indexable collection holding one
    label per example. Indexing returns a dict of tensors keyed by the feature
    names plus a `'labels'` entry.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        # The label collection defines how many examples the split has.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for name, column in self.encodings.items():
            sample[name] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample


# Model

In [None]:
import numpy as np
from datasets import load_dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments


class CivilCommentsModel:
    """
    Trains distil-bert Hugging Face model on the `civil_comments` dataset.

    The constructor loads the tokenizer and dataset, builds the train /
    validation / test splits, creates a single-output DistilBERT sequence
    classification model with its base encoder frozen, and wires everything
    into a `Trainer` available as `self.trainer`.
    """

    def __init__(self, num_train_epochs, num_training_points="all"):
        """
        Args:
            num_train_epochs: Number of epochs passed to `TrainingArguments`.
            num_training_points: Number of training examples to sample, or
                `"all"` to use the complete train split.

        Raises:
            AssertionError: If `num_training_points` exceeds the size of the
                train split.
        """
        # Loading tokenizer and dataset
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        self.dataset = load_dataset("civil_comments")

        # Build dataset splits
        if num_training_points != "all":
            # Row count of the split; avoids materialising the text column
            # just to measure its length.
            train_size = len(self.dataset["train"])
            assert num_training_points <= train_size, (
                f"num_training_points={num_training_points} exceeds the "
                f"{train_size} rows of the train split"
            )
        encodings, labels = self.build_data_split("train", num_training_points)
        self.train_dataset = CivilCommentsDataset(encodings, labels)
        encodings, labels = self.build_data_split("validation", 500)
        self.val_dataset = CivilCommentsDataset(encodings, labels)
        encodings, labels = self.build_data_split("test", 500)
        self.test_dataset = CivilCommentsDataset(encodings, labels)

        # Building model and freezing layers of base
        self.model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased',
            num_labels=1
        )

        # Only parameters outside `base_model` stay trainable.
        for param in self.model.base_model.parameters():
            param.requires_grad = False

        # Building trainer
        self.training_args = TrainingArguments(
            output_dir='./results/1',
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            weight_decay=0.01,
            logging_dir='./logs',
            evaluation_strategy="steps",
            eval_steps=10,
            load_best_model_at_end=True
        )
        self.trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
            compute_metrics=self.compute_metrics,
        )

    def build_data_split(self, split, num_data_points="all"):
        """
        Tokenize one split of the dataset and return it with its labels.

        Args:
            split: Name of the split inside `self.dataset`.
            num_data_points: `"all"` to use every row of the split, or a
                number of rows to sample. Sampled rows are balanced between
                toxicity below 0.5 and toxicity above 0.5 (as far as the
                split allows) and returned in shuffled order.

        Returns:
            A `(encodings, labels)` tuple: the tokenizer output for the
            selected texts and the matching `toxicity` values.
        """
        import random

        print(f"Generating {num_data_points} data points for {split} split...", end="", flush=True)

        split_data = self.dataset[split]
        if num_data_points == "all":
            # Handle "all" before any arithmetic on `num_data_points`;
            # dividing the string would raise TypeError.
            texts = split_data["text"]
            labels = split_data["toxicity"]
        else:
            indexes = self._balanced_indexes(split_data, num_data_points)
            random.shuffle(indexes)
            subset = split_data[indexes]
            texts = subset["text"]
            labels = subset["toxicity"]

        encodings = self.tokenizer(texts, truncation=True, padding=True)
        print("done")
        return encodings, labels

    @staticmethod
    def _balanced_indexes(split_data, num_data_points):
        """Pick row indexes: the first rows below / above 0.5 toxicity, half each."""
        # Integer quotas that add up to `num_data_points`; true division
        # would give fractional quotas for odd counts that never reach
        # exactly zero, disabling the early exit below.
        num_civil = num_data_points // 2
        num_uncivil = num_data_points - num_civil

        civil_idx = []
        uncivil_idx = []
        for i, row in enumerate(split_data):
            if num_civil <= 0 and num_uncivil <= 0:
                break
            toxicity = row["toxicity"]
            # Rows with toxicity exactly 0.5 belong to neither group and
            # are skipped.
            if toxicity < 0.5 and num_civil > 0:
                civil_idx.append(i)
                num_civil -= 1
            elif toxicity > 0.5 and num_uncivil > 0:
                uncivil_idx.append(i)
                num_uncivil -= 1

        return civil_idx + uncivil_idx

    @staticmethod
    def compute_metrics(model_output):
        """
        Compute evaluation metrics from a `(predictions, labels)` pair.

        Args:
            model_output: Iterable unpacking to `(pred, label)`, each
                convertible to a NumPy array with one value per example.

        Returns:
            Dict with `"MSE"` (mean squared error) and `"Threshold Accuracy"`
            (fraction of examples where prediction and label fall on the same
            side of 0.5).
        """
        pred, label = model_output
        # Flatten both arrays: a (N, 1) prediction column against (N,)
        # labels would otherwise broadcast to an (N, N) matrix and compare
        # every prediction with every label.
        pred = np.asarray(pred).reshape(-1)
        label = np.asarray(label).reshape(-1)

        mse = np.mean(np.square(pred - label))
        acc = np.mean((pred > 0.5) == (label > 0.5))
        return {
            "MSE": mse,
            "Threshold Accuracy": acc
        }


# Main Program

In [None]:
# Fine-tune on 500 sampled training points for 5 epochs, score the result on
# the held-out test split, then persist the trained model.
model = CivilCommentsModel(num_train_epochs=5, num_training_points=500)

trainer = model.trainer
trainer.train()
trainer.evaluate(eval_dataset=model.test_dataset)
trainer.save_model()

print("Program training complete")
