# Imports

In [1]:
import transformers
import torch
import random
import numpy as np
from torch.utils.data import random_split
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import GenerationConfig

import sys
if '../' not in sys.path: sys.path.insert(1, '../')
from src.data.load_dataset import load_detoxification_dataset, load_toxicity_dataset

# Load the pretrained T5

In [None]:
# Name of the pretrained checkpoint that the tokenizer and model are loaded from
model_checkpoint = "t5-small"

# Single seed shared by every random number generator used in this notebook
global_seed = 1984

# Seed each library in turn: transformers helper, Python, NumPy, torch (CPU), torch (all CUDA devices)
for seed_fn in (
    transformers.set_seed,
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(global_seed)

In [None]:
# Share of the dataset to load, in the range [0..1]
dataset_portion = 0.05

# Tokenizer and seq2seq model both come from the same pretrained checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Keyword arguments shared by the dataset loaders used in this notebook
dataset_kwargs = dict(
    path='../data/raw/filtered.tsv',  # raw data location
    cache_path='../data/processed/tokenized.tsv',  # processed data location (or where to store it)
    tokenizer=tokenizer,  # tokenizer applied to the texts
    portion=dataset_portion,  # only this share of the dataset is used
)

# Dataset

In [None]:
# Fraction of the examples held out for validation
val_ratio = 0.2

dataset = load_detoxification_dataset(**dataset_kwargs)

# Fractional lengths: the first part is used for training, the second for validation
split_fractions = [1 - val_ratio, val_ratio]
train_dataset, val_dataset = random_split(dataset, split_fractions)

# Training

In [None]:
# Training hyper-parameters and generation settings for the Seq2SeqTrainer.

# Generation config starts from the checkpoint's defaults; generated sequences are capped at 128 tokens
genConfig = GenerationConfig.from_pretrained(model_checkpoint)
genConfig.max_length = 128

batch_size = 32  # used for both the training and the evaluation batch size per device
postfix = "-10"  # suffix appended to the output directory and the saved model path
args = Seq2SeqTrainingArguments(
    f"../models/{model_checkpoint}-detoxification{postfix}",  # output directory
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
    # fp16 mixed precision needs a CUDA device; fall back to full precision on
    # CPU-only machines, matching the cuda/cpu fallback used later in this notebook
    fp16=torch.cuda.is_available(),
    report_to='tensorboard',
    logging_steps=5000,
    save_steps=10000,
    generation_config=genConfig,
)

In [None]:
# Collator that assembles seq2seq batches for this tokenizer/model pair
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Trainer wired to the train/validation split and the training arguments defined earlier
# TODO: pass compute_metrics=compute_metrics once a metric function is available
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning with the trainer configured in the previous cell
trainer.train()

  0%|          | 0/144450 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Location the fine-tuned model is saved to and later reloaded from
model_path = f'../models/t5_detoxifier{postfix}'

In [None]:
# Save the trainer's model to model_path so it can be reloaded for inference
trainer.save_model(model_path)

In [None]:
# Reload the saved model from disk and put it in evaluation mode for inference
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).eval()
# Switch off the use_cache flag in the reloaded model's config
model.config.use_cache = False

# Testing: quick inference sanity check

In [None]:
def translate(model, inference_request, tokenizer=tokenizer, print_ids=False):
    """Generate the model's output for one input string and print it.

    Args:
        model: Seq2seq model exposing ``generate`` and a ``device`` attribute
            (such as the model loaded with ``AutoModelForSeq2SeqLM``).
        inference_request: Raw input text to encode and feed to the model.
        tokenizer: Tokenizer used for both encoding and decoding. Defaults to
            the notebook-level ``tokenizer`` as bound when this cell was run.
        print_ids: When True, also print the encoded input ids and the raw
            generated ids.

    Prints the size of the first generated sequence and its decoded text with
    special tokens removed. Nothing is returned.
    """
    # Place the inputs on the model's device so the call still works after the
    # model has been moved to the GPU with model.to(device)
    input_ids = tokenizer.encode(inference_request, return_tensors="pt").to(model.device)
    if print_ids:
        print(input_ids)
    outputs = model.generate(input_ids=input_ids)
    if print_ids:
        print(outputs)
    print(outputs[0].size())
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Sample request: prints the generated sequence size and the decoded text
inference_request = "Hello, world!"
translate(model, inference_request)

torch.Size([6])
hello, world!




# Validation: toxicity before and after detoxification

In [None]:
from src.models.t5_toxicity_evaluator import T5TEModel

# Prefer the GPU when one is available, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move the detoxification model to the chosen device
model.to(device)

# Build the toxicity evaluator from its saved weights on the same device and
# put its wrapped model in evaluation mode (result assigned to `_` to suppress cell output)
evalutator = T5TEModel('../models/last_toxic_regressor/model.pt').to(device)
_ = evalutator.model.eval()

In [None]:
# Toxicity dataset loaded with the same arguments as the training data
eval_dataset = load_toxicity_dataset(**dataset_kwargs)

# Fixed-order batches of 128, assembled by the evaluator's own collate function
eval_loader = torch.utils.data.DataLoader(
    eval_dataset,
    collate_fn=evalutator.collate_batch,
    batch_size=128,
    shuffle=False,
)

In [None]:
from tqdm.auto import tqdm
# Generated token-id sequences for every evaluation example, kept on the CPU
transformed = []

progress = tqdm(eval_loader, total=len(eval_loader), desc='Translating')
for batch in progress:
    generated = model.generate(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        generation_config=genConfig,
    )
    # Extending a list with a tensor iterates its first dimension, adding one entry per sequence
    transformed.extend(generated.detach().cpu())

Translating:   0%|          | 0/9028 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


KeyboardInterrupt: 

In [None]:
# Evaluator outputs for the original (reference) texts, collected on the CPU
ref_evaluations = []

# Release cached GPU memory before the evaluation pass
torch.cuda.empty_cache()
# Inference only: disable autograd so no graph is kept alive for a backward pass
with torch.no_grad():
    for batch in tqdm(eval_loader, total=len(eval_loader), desc='Evaluation'):
        output = evalutator(batch)
        # += with a tensor iterates its first dimension, adding one entry per example
        ref_evaluations += output.detach().cpu()

Evaluation:   0%|          | 0/9028 [00:00<?, ?it/s]

In [None]:
# Wrap each generated sequence in a dict keyed by 'input_ids' before handing it to the loader
transformed_keys = list(map(lambda ids: {'input_ids': ids}, transformed))

# Same batching as the reference loader: fixed order, 128 per batch, evaluator's collate function
trn_loader = torch.utils.data.DataLoader(
    transformed_keys,
    collate_fn=evalutator.collate_batch,
    batch_size=128,
    shuffle=False,
)

In [None]:
# Evaluator outputs for the model-generated (transformed) texts, collected on the CPU
trn_evaluations = []

# Release cached GPU memory before the evaluation pass
torch.cuda.empty_cache()
# Inference only: disable autograd so no graph is kept alive for a backward pass
with torch.no_grad():
    for batch in tqdm(trn_loader, total=len(trn_loader), desc='Evaluation'):
        output = evalutator(batch)
        # += with a tensor iterates its first dimension, adding one entry per example
        trn_evaluations += output.detach().cpu()

Evaluation:   0%|          | 0/9028 [00:00<?, ?it/s]

In [None]:
# Release cached GPU memory now that both evaluation passes are finished
torch.cuda.empty_cache()

In [None]:
# Evaluator outputs strictly above this value are counted as toxic
threshold = 0.5

# Stack the per-example evaluator outputs into arrays
refevs, trnevs = np.array(ref_evaluations), np.array(trn_evaluations)

# Boolean masks: True where the reference / transformed text is counted as toxic
ref_toxs, trn_toxs = refevs > threshold, trnevs > threshold

In [None]:
# Neutral masks are the negation of the toxic masks; toxic masks are rebuilt as fresh arrays
ref_neutrals, trn_neutrals = np.logical_not(ref_toxs), np.logical_not(trn_toxs)
ref_toxics, trn_toxics = np.logical_not(ref_neutrals), np.logical_not(trn_neutrals)

# Totals per reference class
n_ref_neutral = ref_neutrals.sum()
n_ref_toxic = ref_toxics.sum()

# Each line shows: reference-class total -> how many of those ended up in the target class
print(f'Neutral -> neutral: {n_ref_neutral} -> {(ref_neutrals & trn_neutrals).sum()}')
print(f'Neutral -> toxic: {n_ref_neutral} -> {(ref_neutrals & trn_toxics).sum()}')
print(f'Toxic -> neutral: {n_ref_toxic} -> {(ref_toxics & trn_neutrals).sum()}')
print(f'Toxic -> toxic: {n_ref_toxic} -> {(ref_toxics & trn_toxics).sum()}')

Neutral -> neutral: 558246 -> 557953
Neutral -> toxic: 558246 -> 293
Toxic -> neutral: 597308 -> 592276
Toxic -> toxic: 597308 -> 5032
