# Imports

In [1]:
import transformers
import torch
import random
import numpy as np
from torch.utils.data import random_split
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import GenerationConfig

import sys
if '..' not in sys.path: sys.path.append('..')
from src.data.make_dataset import load_detoxification_dataset, load_toxicity_dataset

# Load the pretrained T5

In [2]:
# One seed shared by every random number generator the notebook touches.
global_seed = 1984

# Seed each library in turn; all of them receive the same value.
for _seed_fn in (
    transformers.set_seed,
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(global_seed)

# Name of the pretrained checkpoint used for the tokenizer, the model and the
# generation config in the cells below.
model_checkpoint = "t5-small"

In [3]:
# Tokenizer and seq2seq model are both loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Share of the dataset to load, in [0..1].
dataset_portion = 1

# Keyword arguments shared by the dataset loaders used in this notebook.
dataset_kwargs = dict(
    path='../data/raw/filtered.tsv',                # path to raw data
    cache_path='../data/interim/tokenized.tsv',     # path to processed data (or where to store it)
    tokenizer=tokenizer,                            # tokenizer to tokenize texts
    portion=dataset_portion,                        # get only a portion of dataset [0..1]
)

# Dataset

In [4]:
dataset = load_detoxification_dataset(**dataset_kwargs)

# Fraction of the examples held out for validation; the rest is used for training.
val_ratio = 0.2

# Draw the split from a dedicated, explicitly seeded generator. Without it the
# split comes from torch's global RNG, so it changes whenever this cell is
# re-run or earlier cells consume a different amount of randomness.
split_generator = torch.Generator().manual_seed(global_seed)
train_dataset, val_dataset = random_split(
    dataset,
    [1 - val_ratio, val_ratio],
    generator=split_generator,
)

# Training

In [5]:
# Generation settings: start from the checkpoint's own generation config and
# set the limit on newly generated tokens to 64.
genConfig = GenerationConfig.from_pretrained(model_checkpoint)
genConfig.max_new_tokens = 64

batch_size = 32
postfix = "-10"

# Where the fine-tuned model is written by the save cell.
save_model_path = f'../models/t5_detoxifier{postfix}x10lr'

# Training configuration handed to the Seq2SeqTrainer.
args = Seq2SeqTrainingArguments(
    output_dir=f"../models/{model_checkpoint}-detoxification{postfix}x10lr",
    # --- optimisation ---
    num_train_epochs=20,
    learning_rate=5e-4,
    weight_decay=0.01,
    fp16=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # --- evaluation ---
    evaluation_strategy="epoch",
    predict_with_generate=True,
    generation_config=genConfig,
    # --- logging and checkpoints ---
    report_to='tensorboard',
    logging_steps=5000,
    save_steps=25000,
    save_total_limit=10,
)

In [6]:
# Collator that batches the seq2seq examples; it receives the model as well as
# the tokenizer.
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# No compute_metrics callback is supplied to the trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

In [7]:
# Fine-tune the model. As the cell's last expression, the returned TrainOutput
# is displayed after the training log.
trainer.train()

  0%|          | 0/288900 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 1.96, 'learning_rate': 0.00049134994807892, 'epoch': 0.35}
{'loss': 1.8517, 'learning_rate': 0.0004826998961578401, 'epoch': 0.69}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.6692404747009277, 'eval_runtime': 112.3317, 'eval_samples_per_second': 1028.694, 'eval_steps_per_second': 32.155, 'epoch': 1.0}
{'loss': 1.8016, 'learning_rate': 0.0004740515749394254, 'epoch': 1.04}
{'loss': 1.7309, 'learning_rate': 0.0004654015230183455, 'epoch': 1.38}
{'loss': 1.7211, 'learning_rate': 0.0004567532017999308, 'epoch': 1.73}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.6123136281967163, 'eval_runtime': 111.5584, 'eval_samples_per_second': 1035.826, 'eval_steps_per_second': 32.378, 'epoch': 2.0}
{'loss': 1.6984, 'learning_rate': 0.00044810488058151613, 'epoch': 2.08}
{'loss': 1.6463, 'learning_rate': 0.0004394548286604361, 'epoch': 2.42}
{'loss': 1.6487, 'learning_rate': 0.0004308047767393562, 'epoch': 2.77}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.582676649093628, 'eval_runtime': 111.5176, 'eval_samples_per_second': 1036.204, 'eval_steps_per_second': 32.39, 'epoch': 3.0}
{'loss': 1.6257, 'learning_rate': 0.0004221564555209415, 'epoch': 3.12}
{'loss': 1.588, 'learning_rate': 0.00041350640359986155, 'epoch': 3.46}
{'loss': 1.5958, 'learning_rate': 0.0004048580823814469, 'epoch': 3.81}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5665289163589478, 'eval_runtime': 111.5714, 'eval_samples_per_second': 1035.704, 'eval_steps_per_second': 32.374, 'epoch': 4.0}
{'loss': 1.5674, 'learning_rate': 0.0003962080304603669, 'epoch': 4.15}
{'loss': 1.5437, 'learning_rate': 0.000387557978539287, 'epoch': 4.5}
{'loss': 1.5541, 'learning_rate': 0.0003789096573208723, 'epoch': 4.85}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5518205165863037, 'eval_runtime': 111.2218, 'eval_samples_per_second': 1038.96, 'eval_steps_per_second': 32.476, 'epoch': 5.0}
{'loss': 1.5226, 'learning_rate': 0.0003702596053997923, 'epoch': 5.19}
{'loss': 1.5045, 'learning_rate': 0.0003616112841813776, 'epoch': 5.54}
{'loss': 1.5147, 'learning_rate': 0.0003529612322602977, 'epoch': 5.88}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5410844087600708, 'eval_runtime': 110.5182, 'eval_samples_per_second': 1045.574, 'eval_steps_per_second': 32.682, 'epoch': 6.0}
{'loss': 1.4796, 'learning_rate': 0.0003443111803392177, 'epoch': 6.23}
{'loss': 1.4691, 'learning_rate': 0.00033566285912080306, 'epoch': 6.58}
{'loss': 1.4806, 'learning_rate': 0.0003270128071997231, 'epoch': 6.92}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5386500358581543, 'eval_runtime': 111.3905, 'eval_samples_per_second': 1037.386, 'eval_steps_per_second': 32.426, 'epoch': 7.0}
{'loss': 1.4402, 'learning_rate': 0.0003183644859813084, 'epoch': 7.27}
{'loss': 1.4401, 'learning_rate': 0.00030971443406022846, 'epoch': 7.62}
{'loss': 1.4494, 'learning_rate': 0.0003010661128418138, 'epoch': 7.96}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5290231704711914, 'eval_runtime': 110.3255, 'eval_samples_per_second': 1047.401, 'eval_steps_per_second': 32.739, 'epoch': 8.0}
{'loss': 1.4062, 'learning_rate': 0.00029241606092073383, 'epoch': 8.31}
{'loss': 1.4092, 'learning_rate': 0.00028376600899965385, 'epoch': 8.65}
{'loss': 1.4206, 'learning_rate': 0.0002751159570785739, 'epoch': 9.0}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5234870910644531, 'eval_runtime': 110.4531, 'eval_samples_per_second': 1046.19, 'eval_steps_per_second': 32.702, 'epoch': 9.0}
{'loss': 1.3728, 'learning_rate': 0.00026646590515749395, 'epoch': 9.35}
{'loss': 1.3838, 'learning_rate': 0.0002578175839390793, 'epoch': 9.69}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5168672800064087, 'eval_runtime': 110.7633, 'eval_samples_per_second': 1043.26, 'eval_steps_per_second': 32.61, 'epoch': 10.0}
{'loss': 1.3847, 'learning_rate': 0.0002491675320179993, 'epoch': 10.04}
{'loss': 1.346, 'learning_rate': 0.00024051921079958465, 'epoch': 10.38}
{'loss': 1.3567, 'learning_rate': 0.0002318674281758394, 'epoch': 10.73}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5134544372558594, 'eval_runtime': 110.6393, 'eval_samples_per_second': 1044.43, 'eval_steps_per_second': 32.647, 'epoch': 11.0}
{'loss': 1.3543, 'learning_rate': 0.00022321737625475945, 'epoch': 11.08}
{'loss': 1.3214, 'learning_rate': 0.00021456732433367947, 'epoch': 11.42}
{'loss': 1.3333, 'learning_rate': 0.00020591727241259952, 'epoch': 11.77}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5121747255325317, 'eval_runtime': 111.8382, 'eval_samples_per_second': 1033.234, 'eval_steps_per_second': 32.297, 'epoch': 12.0}
{'loss': 1.3223, 'learning_rate': 0.0001972654897888543, 'epoch': 12.11}
{'loss': 1.2992, 'learning_rate': 0.00018861543786777432, 'epoch': 12.46}
{'loss': 1.3096, 'learning_rate': 0.0001799636552440291, 'epoch': 12.81}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5109244585037231, 'eval_runtime': 110.6252, 'eval_samples_per_second': 1044.563, 'eval_steps_per_second': 32.651, 'epoch': 13.0}
{'loss': 1.2923, 'learning_rate': 0.0001713136033229491, 'epoch': 13.15}
{'loss': 1.2769, 'learning_rate': 0.00016266355140186916, 'epoch': 13.5}
{'loss': 1.2873, 'learning_rate': 0.0001540117687781239, 'epoch': 13.85}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5143696069717407, 'eval_runtime': 111.7614, 'eval_samples_per_second': 1033.944, 'eval_steps_per_second': 32.319, 'epoch': 14.0}
{'loss': 1.2692, 'learning_rate': 0.00014536171685704396, 'epoch': 14.19}
{'loss': 1.2586, 'learning_rate': 0.0001367099342332987, 'epoch': 14.54}
{'loss': 1.2613, 'learning_rate': 0.00012805988231221876, 'epoch': 14.88}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5109548568725586, 'eval_runtime': 111.1046, 'eval_samples_per_second': 1040.056, 'eval_steps_per_second': 32.51, 'epoch': 15.0}
{'loss': 1.2446, 'learning_rate': 0.0001194115610938041, 'epoch': 15.23}
{'loss': 1.2351, 'learning_rate': 0.00011075977847005884, 'epoch': 15.58}
{'loss': 1.2438, 'learning_rate': 0.00010210972654897889, 'epoch': 15.92}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5105900764465332, 'eval_runtime': 111.1436, 'eval_samples_per_second': 1039.691, 'eval_steps_per_second': 32.498, 'epoch': 16.0}
{'loss': 1.2185, 'learning_rate': 9.345794392523364e-05, 'epoch': 16.27}
{'loss': 1.2178, 'learning_rate': 8.480962270681896e-05, 'epoch': 16.61}
{'loss': 1.2216, 'learning_rate': 7.615957078573901e-05, 'epoch': 16.96}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.517675757408142, 'eval_runtime': 112.2189, 'eval_samples_per_second': 1029.729, 'eval_steps_per_second': 32.187, 'epoch': 17.0}
{'loss': 1.199, 'learning_rate': 6.750951886465906e-05, 'epoch': 17.31}
{'loss': 1.1966, 'learning_rate': 5.88594669435791e-05, 'epoch': 17.65}
{'loss': 1.2049, 'learning_rate': 5.020768431983385e-05, 'epoch': 18.0}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5149683952331543, 'eval_runtime': 112.6942, 'eval_samples_per_second': 1025.385, 'eval_steps_per_second': 32.051, 'epoch': 18.0}
{'loss': 1.1824, 'learning_rate': 4.155936310141918e-05, 'epoch': 18.35}
{'loss': 1.1823, 'learning_rate': 3.290758047767393e-05, 'epoch': 18.69}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.518808364868164, 'eval_runtime': 92.0118, 'eval_samples_per_second': 1255.872, 'eval_steps_per_second': 39.256, 'epoch': 19.0}
{'loss': 1.1836, 'learning_rate': 2.4257528556593975e-05, 'epoch': 19.04}
{'loss': 1.1723, 'learning_rate': 1.5607476635514018e-05, 'epoch': 19.38}
{'loss': 1.1693, 'learning_rate': 6.95742471443406e-06, 'epoch': 19.73}


  0%|          | 0/3612 [00:00<?, ?it/s]

{'eval_loss': 1.5204483270645142, 'eval_runtime': 91.1479, 'eval_samples_per_second': 1267.775, 'eval_steps_per_second': 39.628, 'epoch': 20.0}
{'train_runtime': 33602.6877, 'train_samples_per_second': 275.11, 'train_steps_per_second': 8.598, 'train_loss': 1.406808580904941, 'epoch': 20.0}


TrainOutput(global_step=288900, training_loss=1.406808580904941, metrics={'train_runtime': 33602.6877, 'train_samples_per_second': 275.11, 'train_steps_per_second': 8.598, 'train_loss': 1.406808580904941, 'epoch': 20.0})

In [8]:
# Save the trainer's model to save_model_path (set in the training-arguments cell).
trainer.save_model(save_model_path)

In [64]:
# Load a fine-tuned model for inference and switch it to eval mode.
# NOTE(review): the active line reads from '../models backup/t5_detoxifier-10',
# not from save_model_path written by the save cell; the commented-out line
# below would load the freshly saved model instead. Confirm which checkpoint
# the reported results are meant to describe.
# model = AutoModelForSeq2SeqLM.from_pretrained(save_model_path)
model = AutoModelForSeq2SeqLM.from_pretrained('../models backup/t5_detoxifier-10')
model.eval()
# Turn off the use_cache flag in the loaded model's config.
model.config.use_cache = False

# Testing

In [65]:
def translate(model, inference_request, tokenizer=tokenizer):
    """Generate the model's rewrite of a single text and print it.

    Args:
        model: seq2seq model used for generation; its `generate` method and
            `device` attribute are used.
        inference_request: the input text to rewrite.
        tokenizer: tokenizer used to encode the request and decode the output;
            defaults to the notebook-level `tokenizer`.

    Returns:
        None. The decoded text (special tokens skipped) is printed.
    """
    # Encode the request, then move the ids to the device the model lives on so
    # the call keeps working after the model has been moved to the GPU.
    tokenized = tokenizer.encode(inference_request, return_tensors="pt").to(model.device)
    # Generation settings come from the notebook-level `genConfig`.
    outputs = model.generate(tokenized, generation_config=genConfig)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [76]:
# Run the loaded model on a single hand-written example; translate() prints the result.
inference_request = """
this guy is a con man. He's an actor. He's just a character.
"""
translate(model, inference_request)

This guy is a con, he's an actor, he's just a character.


# Validation

In [13]:
from src.models.t5_toxicity_evaluator import T5TEModel

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Toxicity evaluator built from the stored regressor file, placed on `device`.
evalutator = T5TEModel('../models/last_toxic_regressor/model.pt')
evalutator = evalutator.to(device)

# The detoxification model goes to the same device.
model.to(device)

# Put the evaluator's underlying model in eval mode; assigning to `_` keeps the
# returned object from being shown as the cell output.
_ = evalutator.model.eval()

In [14]:
# NOTE(review): dataset_kwargs is the same configuration used to load the
# training data, so this evaluation set presumably overlaps with examples the
# model was trained on — confirm.
eval_dataset = load_toxicity_dataset(**dataset_kwargs)

# Batches are built by the evaluator's own collate function, in dataset order.
eval_loader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=evalutator.collate_batch,
)

In [15]:
from tqdm.auto import tqdm

# Generated token-id sequences, one list entry per input example.
transformed = []

for batch in tqdm(eval_loader, total=len(eval_loader), desc='Translating'):
    generated = model.generate(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        generation_config=genConfig,
    )
    # Extending a list with a tensor adds one entry per row of the batch.
    transformed.extend(generated.detach().cpu())

Translating:   0%|          | 0/452 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [16]:
# Evaluator outputs for the original inputs served by eval_loader.
ref_evaluations = []

torch.cuda.empty_cache()
# Scoring is inference only, so autograd is switched off: no graph is recorded
# for the evaluator's forward pass.
with torch.no_grad():
    for batch in tqdm(eval_loader, total=len(eval_loader), desc='Evaluation'):
        output = evalutator(batch)
        # `+=` with a tensor extends the list with one entry per row of the batch.
        ref_evaluations += output.detach().cpu()

Evaluation:   0%|          | 0/452 [00:00<?, ?it/s]

In [17]:
# Wrap every generated id sequence in a dict under the 'input_ids' key before
# handing it to the evaluator's collate function.
# NOTE(review): the rows come straight from batched generate() calls, so they
# may still carry padding / special tokens — confirm collate_batch handles them.
transformed_keys = [{'input_ids': generated_ids} for generated_ids in transformed]

trn_loader = torch.utils.data.DataLoader(
    dataset=transformed_keys,
    batch_size=128,
    shuffle=False,
    collate_fn=evalutator.collate_batch,
)

In [18]:
# Evaluator outputs for the transformed (model-generated) sequences served by trn_loader.
trn_evaluations = []

torch.cuda.empty_cache()
# Scoring is inference only, so autograd is switched off: no graph is recorded
# for the evaluator's forward pass.
with torch.no_grad():
    for batch in tqdm(trn_loader, total=len(trn_loader), desc='Evaluation'):
        output = evalutator(batch)
        # `+=` with a tensor extends the list with one entry per row of the batch.
        trn_evaluations += output.detach().cpu()

Evaluation:   0%|          | 0/452 [00:00<?, ?it/s]

In [19]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [20]:
# Evaluator outputs strictly above this value are counted as toxic.
threshold = 0.5

# Turn the per-example outputs collected above into NumPy arrays.
refevs, trnevs = (np.array(scores) for scores in (ref_evaluations, trn_evaluations))

# Boolean masks: True where the output exceeds the threshold.
ref_toxs = np.greater(refevs, threshold)
trn_toxs = np.greater(trnevs, threshold)

In [21]:
# Masks for the original (ref_*) and transformed (trn_*) examples. The toxic
# masks are the thresholded arrays themselves; the neutral masks are their
# logical negation (instead of comparing boolean arrays with True / False).
ref_toxics = ref_toxs
ref_neutrals = np.logical_not(ref_toxs)
trn_toxics = trn_toxs
trn_neutrals = np.logical_not(trn_toxs)

# (label, mask before the transformation, mask after the transformation)
transitions = [
    ('Neutral -> neutral', ref_neutrals, trn_neutrals),
    ('Neutral -> toxic', ref_neutrals, trn_toxics),
    ('Toxic -> neutral', ref_toxics, trn_neutrals),
    ('Toxic -> toxic', ref_toxics, trn_toxics),
]

# Each line reads "<label>: <examples in the source class> -> <of those, how
# many fall in the target class after the transformation>".
for label, before, after in transitions:
    print(f'{label}: {np.sum(before)} -> {np.sum(np.logical_and(before, after))}')

Neutral -> neutral: 27871 -> 27857
Neutral -> toxic: 27871 -> 14
Toxic -> neutral: 29906 -> 28996
Toxic -> toxic: 29906 -> 910
