# Imports

In [1]:
import sys
sys.path.insert(1, '../')

from datasets import load_dataset, load_metric
import transformers
import datasets
import random
from IPython.display import display, HTML
import torch
from tqdm.auto import tqdm
from math import sqrt
from torch.utils.data import random_split

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# my imports
from src.data.load_data import load_tokenized_data

## Loading the tokenizer and the dataset

In [2]:
# Fix the seed first so that everything random below is repeatable
# (transformers.set_seed seeds Python's `random`, NumPy and PyTorch).
global_seed = 1984
transformers.set_seed(global_seed)

# Checkpoint name shared by the tokenizer here and the model loaded later on.
model_checkpoint = "t5-small"

# Tokenizer matching the checkpoint; reused by the data loader, the collate
# function and the inference helper.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [3]:
# Load the filtered corpus through the project loader, handing it the tokenizer.
# NOTE(review): `cache_path` presumably lets the loader reuse a previously
# tokenized copy instead of re-tokenizing — confirm in src/data/load_data.py.
df = load_tokenized_data(path='../data/raw/filtered.tsv',
                         cache_path='../data/processed/tokenized.tsv',
                         tokenizer=tokenizer, 
                         flatten=True)

In [4]:
class ToxicDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each row's ``text`` with its ``toxicity``.

    Every item is a dict of the form
    ``{'input_ids': <value from 'text'>, 'labels': <value from 'toxicity'>}``.

    Attributes:
        raw_data: the dataframe passed to the constructor (kept by reference).
        texts: the ``text`` column as a Python list.
        targets: the ``toxicity`` column as a Python list.
        inputs: the pre-built list of item dicts served by ``__getitem__``.
    """

    def __init__(self, dataframe):
        self.raw_data = dataframe

        self.texts = dataframe['text'].tolist()
        self.targets = dataframe['toxicity'].tolist()

        # All items are materialised up front, in row order.
        self.inputs = [
            {'input_ids': token_ids, 'labels': score}
            for token_ids, score in zip(self.texts, self.targets)
        ]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx]

## Preprocessing the data
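As usual, we need to preprocess and tokenize the data before passing it to the model. The text loaded above is already tokenized, so this section only wraps it in a `Dataset`, splits it into training and validation sets, and defines the collate function that pads each batch.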

In [5]:
# Wrap the loaded dataframe so it can be indexed item by item.
dataset = ToxicDataset(df)

# Hold out a fraction of the examples for validation; random_split is given
# fractions (which must sum to 1) rather than absolute lengths.
val_ratio = 0.2
train_dataset, val_dataset = random_split(dataset, [1 - val_ratio, val_ratio])

# Debug alternative: work on a small slice of the data and discard the rest.
# train_dataset, val_dataset, temp = random_split(dataset, [0.01, 0.05, 0.94])

In [6]:
def collate_batch(batch):
    """Collate a list of dataset items into one padded batch of PyTorch tensors.

    Padding is delegated to the notebook-level ``tokenizer`` via its ``pad``
    method with ``return_tensors='pt'``.
    """
    padded = tokenizer.pad(batch, return_tensors='pt')
    return padded

## Fine-tuning the model

In [7]:
# Load the pretrained seq2seq model for the chosen checkpoint and keep a handle
# on its encoder; the regressors defined below are built on the encoder only.
t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
encoder = t5_model.encoder

In [8]:
from torch import nn

class RegressorMk1(nn.Module):
    """Encoder followed by a three-layer, per-token MLP scoring head.

    Each token's encoder state is mapped to one scalar; the scalars of the
    positions kept by ``attention_mask`` are summed and passed through a
    sigmoid, giving one value in [0, 1] per sequence.

    Args:
        encoder: module called as ``encoder(input_ids=..., attention_mask=...)``
            whose result exposes ``last_hidden_state``.
        h1: width of the first hidden layer.
        h2: width of the second hidden layer.
        hidden_size: width of the encoder output. When ``None`` (default) it is
            read from ``encoder.config.d_model`` if present, otherwise 512 is
            used (the value that was previously hard-coded).
    """

    def __init__(self, encoder, h1=128, h2=32, hidden_size=None):
        super().__init__()
        self.encoder = encoder
        if hidden_size is None:
            # Follow the encoder's own width so other checkpoints work too;
            # 512 keeps the old behaviour when no config is available.
            encoder_config = getattr(encoder, 'config', None)
            hidden_size = getattr(encoder_config, 'd_model', None) or 512
        self.h1 = nn.Linear(hidden_size, h1)
        self.h2 = nn.Linear(h1, h2)
        self.h3 = nn.Linear(h2, 1)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return one sigmoid score per sequence.

        ``labels`` is accepted so a batch dict containing it can be splatted
        into the call, but it is not used here.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        x1 = nn.functional.relu(self.h1(encoded))
        x2 = nn.functional.relu(self.h2(x1))
        x3 = self.h3(x2)

        # Zero out masked positions before summing the per-token scores.
        token_scores = x3[:, :, 0] * attention_mask
        return torch.sigmoid(torch.sum(token_scores, dim=1))

class Regressor(nn.Module):
    """Encoder followed by a single linear per-token scoring layer.

    Each token's encoder state is projected to one scalar; the scalars of the
    positions kept by ``attention_mask`` are summed and passed through a
    sigmoid, giving one value in [0, 1] per sequence.

    Args:
        encoder: module called as ``encoder(input_ids=..., attention_mask=...)``
            whose result exposes ``last_hidden_state``.
        hidden_size: width of the encoder output. When ``None`` (default) it is
            read from ``encoder.config.d_model`` if present, otherwise 512 is
            used (the value that was previously hard-coded).
    """

    def __init__(self, encoder, hidden_size=None):
        super().__init__()
        self.encoder = encoder
        if hidden_size is None:
            # Follow the encoder's own width so other checkpoints work too;
            # 512 keeps the old behaviour when no config is available.
            encoder_config = getattr(encoder, 'config', None)
            hidden_size = getattr(encoder_config, 'd_model', None) or 512
        self.h1 = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return one sigmoid score per sequence.

        ``labels`` is accepted so a batch dict containing it can be splatted
        into the call, but it is not used here.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        x = self.h1(encoded)

        # Zero out masked positions before summing the per-token scores.
        token_scores = x[:, :, 0] * attention_mask
        return torch.sigmoid(torch.sum(token_scores, dim=1))

In [9]:
# Train the single-linear-layer variant on top of the pretrained encoder.
model = Regressor(encoder)

In [10]:
from transformers import Trainer, TrainingArguments

# Shared loss module used by RegressorTrainer.compute_loss.
mse = nn.MSELoss()
class RegressorTrainer(Trainer):
    """Trainer that optimises mean squared error between model output and ``labels``."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the MSE between the model's predictions and ``inputs["labels"]``.

        Args:
            model: the model being trained; called as ``model(**inputs)``.
            inputs: batch mapping that must contain a ``"labels"`` entry.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: accepted for compatibility with Trainer
                versions that pass it to ``compute_loss``; unused here.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        outputs = model(**inputs)
        # Read (rather than pop) the labels so the caller's batch is left intact,
        # and align dtype/device with the predictions before comparing.
        labels = inputs["labels"].to(device=outputs.device, dtype=outputs.dtype)
        loss = mse(outputs, labels)
        # NOTE(review): during evaluation Trainer.prediction_step appears to slice
        # non-dict outputs as outputs[1:] (treating element 0 as the loss), which
        # would drop the first prediction of a bare tensor — confirm against the
        # installed transformers version before enabling compute_metrics.
        return (loss, outputs) if return_outputs else loss

In [11]:
import numpy as np

def compute_metrics(eval_preds):
    """Compute mean squared error and its square root for an evaluation run.

    Args:
        eval_preds: a pair ``(preds, labels)`` of array-likes holding the same
            number of values.

    Returns:
        ``{'MSE': float, 'RMSE': float}``.
    """
    preds, labels = eval_preds

    # Flatten both sides so a (N, 1) prediction array cannot broadcast against
    # (N,) labels into an (N, N) matrix.
    preds = np.asarray(preds, dtype=np.float64).reshape(-1)
    labels = np.asarray(labels, dtype=np.float64).reshape(-1)

    # Average the squared errors into one scalar; without the mean this is a
    # per-example array and math.sqrt below would raise.
    mse = float(np.mean((preds - labels) ** 2))

    return {
        'MSE': mse,
        'RMSE': sqrt(mse)
    }

In [12]:
# Per-device batch size, shared by training and evaluation.
batch_size = 64
training_args = TrainingArguments(
    # Output directory for checkpoints and logs.
    "../models/toxic_regressor_output",
    # Evaluate on the validation split at the end of every epoch.
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` — confirm against the installed version.
    evaluation_strategy = "epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    # Checkpointing: keep at most 10 checkpoints, written every 5000 steps.
    save_total_limit=10,
    num_train_epochs=10,
    save_steps=5000,
    # Mixed precision is left disabled.
    # fp16=True,
    report_to='tensorboard',
)

trainer = RegressorTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # Pads each batch with the tokenizer (see collate_batch).
    data_collator=collate_batch,
    # Metric computation is currently disabled; only eval_loss is reported.
    #compute_metrics=compute_metrics
)

In [13]:
# Run the full fine-tuning loop. This is a long-running cell: the recorded run
# below reports a train_runtime of about 50,535 s (roughly 14 hours).
trainer.train()

  0%|          | 0/144450 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 0.1603, 'learning_rate': 0.0001993077189338872, 'epoch': 0.03}
{'loss': 0.1194, 'learning_rate': 0.00019861543786777431, 'epoch': 0.07}
{'loss': 0.1084, 'learning_rate': 0.0001979231568016615, 'epoch': 0.1}
{'loss': 0.1023, 'learning_rate': 0.00019723087573554865, 'epoch': 0.14}
{'loss': 0.1002, 'learning_rate': 0.0001965385946694358, 'epoch': 0.17}
{'loss': 0.0954, 'learning_rate': 0.00019584631360332295, 'epoch': 0.21}
{'loss': 0.0914, 'learning_rate': 0.00019515403253721013, 'epoch': 0.24}
{'loss': 0.0901, 'learning_rate': 0.00019446175147109726, 'epoch': 0.28}
{'loss': 0.0883, 'learning_rate': 0.00019376947040498444, 'epoch': 0.31}
{'loss': 0.0872, 'learning_rate': 0.0001930771893388716, 'epoch': 0.35}
{'loss': 0.0863, 'learning_rate': 0.00019238490827275874, 'epoch': 0.38}
{'loss': 0.0854, 'learning_rate': 0.00019169262720664592, 'epoch': 0.42}
{'loss': 0.0833, 'learning_rate': 0.00019100034614053305, 'epoch': 0.45}
{'loss': 0.082, 'learning_rate': 0.00019030806507442023,

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.06895655393600464, 'eval_runtime': 103.0339, 'eval_samples_per_second': 2243.049, 'eval_steps_per_second': 35.056, 'epoch': 1.0}
{'loss': 0.0713, 'learning_rate': 0.0001799238490827276, 'epoch': 1.0}
{'loss': 0.0669, 'learning_rate': 0.00017923156801661475, 'epoch': 1.04}
{'loss': 0.0657, 'learning_rate': 0.0001785392869505019, 'epoch': 1.07}
{'loss': 0.067, 'learning_rate': 0.00017784700588438908, 'epoch': 1.11}
{'loss': 0.0691, 'learning_rate': 0.00017715472481827623, 'epoch': 1.14}
{'loss': 0.0659, 'learning_rate': 0.00017646244375216338, 'epoch': 1.18}
{'loss': 0.0648, 'learning_rate': 0.00017577016268605056, 'epoch': 1.21}
{'loss': 0.0655, 'learning_rate': 0.0001750778816199377, 'epoch': 1.25}
{'loss': 0.0646, 'learning_rate': 0.00017438560055382487, 'epoch': 1.28}
{'loss': 0.0648, 'learning_rate': 0.00017369331948771202, 'epoch': 1.32}
{'loss': 0.0635, 'learning_rate': 0.00017300103842159917, 'epoch': 1.35}
{'loss': 0.0664, 'learning_rate': 0.00017230875735548633,

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.06197770684957504, 'eval_runtime': 102.3296, 'eval_samples_per_second': 2258.486, 'eval_steps_per_second': 35.298, 'epoch': 2.0}
{'loss': 0.0606, 'learning_rate': 0.00015984769816545518, 'epoch': 2.01}
{'loss': 0.0565, 'learning_rate': 0.00015915541709934233, 'epoch': 2.04}
{'loss': 0.0562, 'learning_rate': 0.0001584631360332295, 'epoch': 2.08}
{'loss': 0.0562, 'learning_rate': 0.00015777085496711666, 'epoch': 2.11}
{'loss': 0.0559, 'learning_rate': 0.00015707857390100382, 'epoch': 2.15}
{'loss': 0.0554, 'learning_rate': 0.00015638629283489097, 'epoch': 2.18}
{'loss': 0.0561, 'learning_rate': 0.00015569401176877815, 'epoch': 2.22}
{'loss': 0.0561, 'learning_rate': 0.00015500173070266527, 'epoch': 2.25}
{'loss': 0.0556, 'learning_rate': 0.00015430944963655245, 'epoch': 2.28}
{'loss': 0.0562, 'learning_rate': 0.0001536171685704396, 'epoch': 2.32}
{'loss': 0.0564, 'learning_rate': 0.00015292488750432676, 'epoch': 2.35}
{'loss': 0.0562, 'learning_rate': 0.000152232606438213

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.05781667307019234, 'eval_runtime': 121.506, 'eval_samples_per_second': 1902.046, 'eval_steps_per_second': 29.727, 'epoch': 3.0}
{'loss': 0.052, 'learning_rate': 0.00013977154724818276, 'epoch': 3.01}
{'loss': 0.0474, 'learning_rate': 0.00013907926618206992, 'epoch': 3.05}
{'loss': 0.0495, 'learning_rate': 0.0001383869851159571, 'epoch': 3.08}
{'loss': 0.0482, 'learning_rate': 0.00013769470404984425, 'epoch': 3.12}
{'loss': 0.0493, 'learning_rate': 0.0001370024229837314, 'epoch': 3.15}
{'loss': 0.0495, 'learning_rate': 0.00013631014191761855, 'epoch': 3.18}
{'loss': 0.0502, 'learning_rate': 0.0001356178608515057, 'epoch': 3.22}
{'loss': 0.0493, 'learning_rate': 0.00013492557978539289, 'epoch': 3.25}
{'loss': 0.0498, 'learning_rate': 0.00013423329871928004, 'epoch': 3.29}
{'loss': 0.0485, 'learning_rate': 0.0001335410176531672, 'epoch': 3.32}
{'loss': 0.0496, 'learning_rate': 0.00013284873658705434, 'epoch': 3.36}
{'loss': 0.0491, 'learning_rate': 0.00013215645552094152, 

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.056825902312994, 'eval_runtime': 105.3199, 'eval_samples_per_second': 2194.362, 'eval_steps_per_second': 34.296, 'epoch': 4.0}
{'loss': 0.0463, 'learning_rate': 0.00011969539633091036, 'epoch': 4.02}
{'loss': 0.0431, 'learning_rate': 0.0001190031152647975, 'epoch': 4.05}
{'loss': 0.0431, 'learning_rate': 0.00011831083419868467, 'epoch': 4.08}
{'loss': 0.0433, 'learning_rate': 0.00011761855313257183, 'epoch': 4.12}
{'loss': 0.0439, 'learning_rate': 0.00011692627206645899, 'epoch': 4.15}
{'loss': 0.043, 'learning_rate': 0.00011623399100034615, 'epoch': 4.19}
{'loss': 0.0434, 'learning_rate': 0.0001155417099342333, 'epoch': 4.22}
{'loss': 0.0443, 'learning_rate': 0.00011484942886812047, 'epoch': 4.26}
{'loss': 0.0441, 'learning_rate': 0.00011415714780200761, 'epoch': 4.29}
{'loss': 0.0445, 'learning_rate': 0.00011346486673589478, 'epoch': 4.33}
{'loss': 0.0448, 'learning_rate': 0.00011277258566978193, 'epoch': 4.36}
{'loss': 0.0443, 'learning_rate': 0.0001120803046036691, 

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.05378337204456329, 'eval_runtime': 167.7644, 'eval_samples_per_second': 1377.586, 'eval_steps_per_second': 21.53, 'epoch': 5.0}
{'loss': 0.0413, 'learning_rate': 9.961924541363793e-05, 'epoch': 5.02}
{'loss': 0.0394, 'learning_rate': 9.89269643475251e-05, 'epoch': 5.05}
{'loss': 0.0385, 'learning_rate': 9.823468328141225e-05, 'epoch': 5.09}
{'loss': 0.0384, 'learning_rate': 9.75424022152994e-05, 'epoch': 5.12}
{'loss': 0.0401, 'learning_rate': 9.685012114918659e-05, 'epoch': 5.16}
{'loss': 0.0387, 'learning_rate': 9.615784008307374e-05, 'epoch': 5.19}
{'loss': 0.0403, 'learning_rate': 9.546555901696089e-05, 'epoch': 5.23}
{'loss': 0.0394, 'learning_rate': 9.477327795084804e-05, 'epoch': 5.26}
{'loss': 0.0396, 'learning_rate': 9.408099688473521e-05, 'epoch': 5.3}
{'loss': 0.0401, 'learning_rate': 9.338871581862236e-05, 'epoch': 5.33}
{'loss': 0.0403, 'learning_rate': 9.269643475250951e-05, 'epoch': 5.37}
{'loss': 0.0403, 'learning_rate': 9.200415368639668e-05, 'epoch': 5

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.054195500910282135, 'eval_runtime': 163.2207, 'eval_samples_per_second': 1415.936, 'eval_steps_per_second': 22.13, 'epoch': 6.0}
{'loss': 0.038, 'learning_rate': 7.954309449636553e-05, 'epoch': 6.02}
{'loss': 0.0356, 'learning_rate': 7.885081343025269e-05, 'epoch': 6.06}
{'loss': 0.0369, 'learning_rate': 7.815853236413985e-05, 'epoch': 6.09}
{'loss': 0.0355, 'learning_rate': 7.7466251298027e-05, 'epoch': 6.13}
{'loss': 0.0365, 'learning_rate': 7.677397023191416e-05, 'epoch': 6.16}
{'loss': 0.0361, 'learning_rate': 7.608168916580132e-05, 'epoch': 6.2}
{'loss': 0.0367, 'learning_rate': 7.538940809968848e-05, 'epoch': 6.23}
{'loss': 0.0362, 'learning_rate': 7.469712703357563e-05, 'epoch': 6.27}
{'loss': 0.0363, 'learning_rate': 7.40048459674628e-05, 'epoch': 6.3}
{'loss': 0.0364, 'learning_rate': 7.331256490134996e-05, 'epoch': 6.33}
{'loss': 0.0376, 'learning_rate': 7.262028383523711e-05, 'epoch': 6.37}
{'loss': 0.0368, 'learning_rate': 7.192800276912426e-05, 'epoch': 6.4

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.054357439279556274, 'eval_runtime': 159.1459, 'eval_samples_per_second': 1452.189, 'eval_steps_per_second': 22.696, 'epoch': 7.0}
{'loss': 0.0344, 'learning_rate': 5.946694357909312e-05, 'epoch': 7.03}
{'loss': 0.033, 'learning_rate': 5.877466251298027e-05, 'epoch': 7.06}
{'loss': 0.033, 'learning_rate': 5.808238144686743e-05, 'epoch': 7.1}
{'loss': 0.0337, 'learning_rate': 5.739010038075458e-05, 'epoch': 7.13}
{'loss': 0.0334, 'learning_rate': 5.669781931464174e-05, 'epoch': 7.17}
{'loss': 0.0331, 'learning_rate': 5.600553824852891e-05, 'epoch': 7.2}
{'loss': 0.0333, 'learning_rate': 5.531325718241607e-05, 'epoch': 7.23}
{'loss': 0.034, 'learning_rate': 5.4620976116303226e-05, 'epoch': 7.27}
{'loss': 0.0336, 'learning_rate': 5.392869505019038e-05, 'epoch': 7.3}
{'loss': 0.0346, 'learning_rate': 5.323641398407754e-05, 'epoch': 7.34}
{'loss': 0.0325, 'learning_rate': 5.25441329179647e-05, 'epoch': 7.37}
{'loss': 0.0335, 'learning_rate': 5.185185185185185e-05, 'epoch': 7.

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.05384266749024391, 'eval_runtime': 104.3855, 'eval_samples_per_second': 2214.005, 'eval_steps_per_second': 34.603, 'epoch': 8.0}
{'loss': 0.0319, 'learning_rate': 3.93907926618207e-05, 'epoch': 8.03}
{'loss': 0.0306, 'learning_rate': 3.869851159570786e-05, 'epoch': 8.07}
{'loss': 0.0305, 'learning_rate': 3.8006230529595015e-05, 'epoch': 8.1}
{'loss': 0.0308, 'learning_rate': 3.7313949463482174e-05, 'epoch': 8.13}
{'loss': 0.0318, 'learning_rate': 3.6621668397369333e-05, 'epoch': 8.17}
{'loss': 0.0313, 'learning_rate': 3.592938733125649e-05, 'epoch': 8.2}
{'loss': 0.0304, 'learning_rate': 3.523710626514365e-05, 'epoch': 8.24}
{'loss': 0.0309, 'learning_rate': 3.4544825199030805e-05, 'epoch': 8.27}
{'loss': 0.0318, 'learning_rate': 3.385254413291797e-05, 'epoch': 8.31}
{'loss': 0.0321, 'learning_rate': 3.316026306680512e-05, 'epoch': 8.34}
{'loss': 0.0315, 'learning_rate': 3.246798200069228e-05, 'epoch': 8.38}
{'loss': 0.0296, 'learning_rate': 3.177570093457944e-05, 'epoc

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.05295313149690628, 'eval_runtime': 104.3652, 'eval_samples_per_second': 2214.436, 'eval_steps_per_second': 34.609, 'epoch': 9.0}
{'loss': 0.0298, 'learning_rate': 1.9314641744548288e-05, 'epoch': 9.03}
{'loss': 0.0295, 'learning_rate': 1.8622360678435447e-05, 'epoch': 9.07}
{'loss': 0.0303, 'learning_rate': 1.7930079612322603e-05, 'epoch': 9.1}
{'loss': 0.0287, 'learning_rate': 1.7237798546209763e-05, 'epoch': 9.14}
{'loss': 0.0296, 'learning_rate': 1.654551748009692e-05, 'epoch': 9.17}
{'loss': 0.0292, 'learning_rate': 1.5853236413984078e-05, 'epoch': 9.21}
{'loss': 0.0304, 'learning_rate': 1.5160955347871237e-05, 'epoch': 9.24}
{'loss': 0.0289, 'learning_rate': 1.4468674281758396e-05, 'epoch': 9.28}
{'loss': 0.0298, 'learning_rate': 1.3776393215645552e-05, 'epoch': 9.31}
{'loss': 0.029, 'learning_rate': 1.308411214953271e-05, 'epoch': 9.35}
{'loss': 0.0298, 'learning_rate': 1.2391831083419869e-05, 'epoch': 9.38}
{'loss': 0.0288, 'learning_rate': 1.1699550017307027e-05

  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 0.05349328741431236, 'eval_runtime': 104.8192, 'eval_samples_per_second': 2204.845, 'eval_steps_per_second': 34.459, 'epoch': 10.0}
{'train_runtime': 50535.0806, 'train_samples_per_second': 182.931, 'train_steps_per_second': 2.858, 'train_loss': 0.04704653177099601, 'epoch': 10.0}


TrainOutput(global_step=144450, training_loss=0.04704653177099601, metrics={'train_runtime': 50535.0806, 'train_samples_per_second': 182.931, 'train_steps_per_second': 2.858, 'train_loss': 0.04704653177099601, 'epoch': 10.0})

In [112]:
# saving model
# The model is written twice: once through the Trainer's own save routine and
# once as a plain state_dict (model.pt), which is the file the reload cell reads.
trainer.save_model('../models/last_toxic_regressor')
torch.save(trainer.model.state_dict(), '../models/last_toxic_regressor/model.pt')

In [120]:
# Rebuild the regressor and restore the fine-tuned weights from model.pt.
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Regressor(encoder)
# map_location='cpu' lets a checkpoint saved from GPU tensors deserialize on a
# CPU-only machine; the model is moved to the target device afterwards.
weights = torch.load('../models/last_toxic_regressor/model.pt', map_location='cpu')
model.load_state_dict(weights)
model = model.to(device)

In [115]:
def prompt(model, inference_request, tokenizer=tokenizer):
    """Score one piece of text with the regressor and print the result.

    Args:
        model: regressor called as ``model(**encoded)``; it must return a
            single-element tensor for a single input text.
        inference_request: the text to score.
        tokenizer: tokenizer used to encode the text (defaults to the
            notebook-level tokenizer).

    Side effects:
        Switches ``model`` to evaluation mode and prints the score.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    encoded = tokenizer(inference_request, return_tensors="pt").to(device)

    # Inference only: disable dropout and skip building the autograd graph.
    model.eval()
    with torch.no_grad():
        outputs = model(**encoded).item()
    print(outputs)

In [116]:
from src.data.load_data import load_data
# Load the same source file through `load_data` so that readable sentences can
# be printed and passed to prompt().
# NOTE(review): presumably this returns the untokenized text — confirm in
# src/data/load_data.py.
text_df = load_data('../data/raw/filtered.tsv', flatten=True)
# Placeholder value; the next cell overwrites idx with a random row index.
idx = 0

In [123]:
# Score a randomly chosen sentence. randrange excludes its upper bound, whereas
# randint(0, n) could return n, which is one past the last row.
idx = random.randrange(len(text_df))
# Positional lookup, so the choice does not depend on the dataframe's index labels.
ptext = text_df['text'].iloc[idx]
print(ptext)
prompt(model, ptext)

A black car wouldn't be pretty.
0.1153540089726448
