## More fine-tuning and using a prefix
The T5 model showed some ability on the detoxification task, but one epoch of fine-tuning with the default settings was not enough to achieve good performance. In this notebook, I try to improve on that by tweaking several training parameters and adding a task prefix to the input data.

In [45]:
# re-create the datasets and helper functions from the previous notebook
import numpy as np
import pandas as pd
from datasets import load_metric
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
from sklearn.model_selection import train_test_split

df = pd.read_csv("processed.csv")
df.head()

# prepare the dataset
NUM_VAL = 3000
NUM_TEST = 550000  # hold out most of the data as test so that the training split stays small

df_text = df[['toxic','detoxified']].rename(columns={'toxic':'input','detoxified':'target'})
# an int test_size is taken as an absolute number of rows
train, val = train_test_split(df_text, test_size=NUM_VAL, random_state=42)
train, test = train_test_split(train, test_size=NUM_TEST, random_state=42)

train_dataset = Dataset.from_dict(train.to_dict(orient='list'))
val_dataset = Dataset.from_dict(val.to_dict(orient='list'))
test_dataset = Dataset.from_dict(test.to_dict(orient='list'))

# preprocess the dataset
max_input_length = 128
max_target_length = 128

# add the prefix
prefix = "detoxify text: "

def preprocess_function(examples):
    inputs = [prefix + ex for ex in examples['input']]
    targets = examples['target']
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Tokenize the targets as labels (text_target replaces the deprecated as_target_tokenizer())
    labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# set up the metrics for the training process.
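# sacreBLEU, the metric used in the example this code is based on.
# Note: load_metric was removed in datasets 3.0; newer versions need evaluate.load("sacrebleu").
metric = load_metric("sacrebleu")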

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
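    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Generated ids can be padded with -100 when evaluation batches are concatenated;
    # map it to the pad id so batch_decode (and the gen_len count below) work on valid ids.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)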

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
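During training, `DataCollatorForSeq2Seq` pads each batch's `labels` with -100 (its default `label_pad_token_id`) so that padded positions do not contribute to the cross-entropy loss. `compute_metrics` therefore has to turn -100 back into a real token id before decoding. A minimal NumPy sketch of both directions and of the `gen_len` count, using made-up token ids and T5's pad id of 0; `pad_labels`, `restore_pad`, and `generated_lengths` are illustrative helpers, not library functions:

```python
import numpy as np

PAD_ID = 0            # T5's pad_token_id (also its decoder start token)
LABEL_PAD_ID = -100   # label positions with this value are ignored by the loss

def pad_labels(label_lists, pad_value=LABEL_PAD_ID):
    """Right-pad variable-length label sequences into one rectangular array."""
    width = max(len(seq) for seq in label_lists)
    return np.array([seq + [pad_value] * (width - len(seq)) for seq in label_lists])

def restore_pad(ids, pad_id=PAD_ID):
    """Replace -100 with the pad id so the ids can be passed to batch_decode."""
    return np.where(ids != LABEL_PAD_ID, ids, pad_id)

def generated_lengths(preds, pad_id=PAD_ID):
    """Count non-pad tokens per row, as compute_metrics does for gen_len."""
    return [int(np.count_nonzero(row != pad_id)) for row in preds]

# Made-up token ids; 1 is T5's </s> (end-of-sequence) id.
labels = pad_labels([[37, 1782, 1], [363, 1]])
print(labels.tolist())               # [[37, 1782, 1], [363, 1, -100]]
print(restore_pad(labels).tolist())  # [[37, 1782, 1], [363, 1, 0]]

# Generated sequences start with the decoder start token (the pad id), so it is not counted.
preds = np.array([[0, 37, 1782, 1], [0, 363, 1, 0]])
print(generated_lengths(preds))      # [3, 2]
```

Because T5 uses the pad id as its decoder start token, `gen_len` excludes that first position as well as any trailing padding.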

In [57]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# load the model and tokenizer fine-tuned in the previous notebook
tokenizer = T5Tokenizer.from_pretrained("T5Small-detoxification")
model = T5ForConditionalGeneration.from_pretrained('T5Small-detoxification')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [47]:
train_processed = train_dataset.map(preprocess_function, batched=True)
val_processed = val_dataset.map(preprocess_function, batched=True)
#test_processed = test_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/44777 [00:00<?, ? examples/s]



Map:   0%|          | 0/3000 [00:00<?, ? examples/s]
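With `batched=True`, `Dataset.map` calls `preprocess_function` with a dict that maps each column name to a list of values, and stores every key of the returned dict as a column. The sketch below mimics that contract with a hypothetical whitespace tokenizer (`toy_tokenize`, a stand-in for `T5Tokenizer` whose ids are not real T5 ids) to show where the prefix goes and how truncation and the `labels` column behave, without downloading a model. The two example rows are copied from this dataset.

```python
PREFIX = "detoxify text: "
MAX_LEN = 8
VOCAB = {}  # hypothetical toy vocabulary: word -> id, grown on the fly

def toy_tokenize(texts, max_length):
    """Hypothetical stand-in for T5Tokenizer: whitespace split plus truncation.
    Unlike T5Tokenizer, it does not append an </s> token."""
    batch = []
    for text in texts:
        ids = [VOCAB.setdefault(word, len(VOCAB)) for word in text.split()]
        batch.append(ids[:max_length])
    return {"input_ids": batch}

def toy_preprocess(examples):
    # examples is one batch of rows: {"input": [...], "target": [...]}
    inputs = [PREFIX + ex for ex in examples["input"]]
    model_inputs = toy_tokenize(inputs, max_length=MAX_LEN)
    # the target sentences become the "labels" column
    model_inputs["labels"] = toy_tokenize(examples["target"], max_length=MAX_LEN)["input_ids"]
    return model_inputs

batch = {
    "input": ["Rex is violent and he hates us.", "Hypocrites?"],
    "target": ["Rex is aggressive and he hates us.", "hypocritical?"],
}
out = toy_preprocess(batch)
print(out["input_ids"])  # [[0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 9]]  (first row truncated to 8 ids)
print(out["labels"])     # [[2, 3, 10, 5, 6, 7, 8], [11]]
```

Every input row starts with the same prefix ids, which is how the model can tell this task apart from anything else it was trained on; note that the prefix also consumes part of the `max_input_length` budget.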

In [58]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

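batch_size = 64
args = Seq2SeqTrainingArguments(
    "T5-Small-finetuned-detoxification",  # output directory for checkpoints
    evaluation_strategy="epoch",          # evaluate on val_processed after every epoch (eval_strategy in newer transformers)
    learning_rate=1e-2,                   # far above the TrainingArguments default of 5e-5
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=10,                  # keep at most 10 checkpoints on disk
    num_train_epochs=5,
    predict_with_generate=True,           # run generate() during evaluation so BLEU can be computed
    fp16=True,                            # mixed-precision training
    #generation_max_length=64,
)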

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_processed,
    eval_dataset=val_processed,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

trainer.train()
trainer.save_model('T5Small-detoxification-prefix')

Detected kernel version 5.4.254, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


| Epoch | Training Loss | Validation Loss | BLEU | Gen Len (tokens) |
|---|---|---|---|---|
| 1 | 2.461 | 2.136438 | 20.4103 | 13.4513 |
| 2 | 2.2042 | 2.046599 | 21.5165 | 13.178 |
| 3 | 1.8823 | 1.97745 | 21.9587 | 13.2163 |
| 4 | 1.7042 | 1.943634 | 22.6424 | 13.1313 |
| 5 | 1.4199 | 1.968257 | 22.6733 | 13.163 |
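The arguments and the log can be cross-checked with a little arithmetic. The sketch below assumes a single GPU, no gradient accumulation, and the default linear learning-rate schedule with no warmup (the `Seq2SeqTrainingArguments` defaults, since none of these are overridden here); all other numbers are copied from this notebook. It shows how many optimizer steps the run took, the learning rate at the start of each epoch, and which epoch each validation criterion would favor. `steps_per_epoch` and `linear_lr` are illustrative helpers.

```python
import math

# Copied from this notebook: 44,777 training examples (map output), batch size 64,
# 5 epochs, peak learning rate 1e-2, and the per-epoch training log.
NUM_TRAIN, BATCH_SIZE, EPOCHS, PEAK_LR = 44777, 64, 5, 1e-2
LOG = [
    # (epoch, train_loss, val_loss, bleu, gen_len)
    (1, 2.4610, 2.136438, 20.4103, 13.4513),
    (2, 2.2042, 2.046599, 21.5165, 13.1780),
    (3, 1.8823, 1.977450, 21.9587, 13.2163),
    (4, 1.7042, 1.943634, 22.6424, 13.1313),
    (5, 1.4199, 1.968257, 22.6733, 13.1630),
]

def steps_per_epoch(num_examples, batch_size):
    # the final partial batch still counts as one optimizer step
    return math.ceil(num_examples / batch_size)

def linear_lr(step, total_steps, peak_lr):
    """Learning rate under linear decay from peak_lr to 0 with no warmup."""
    return peak_lr * max(0.0, 1 - step / total_steps)

per_epoch = steps_per_epoch(NUM_TRAIN, BATCH_SIZE)
total_steps = per_epoch * EPOCHS
print(per_epoch, total_steps)  # 700 3500

# learning rate at the first optimizer step of each epoch
for epoch in range(EPOCHS):
    print(epoch + 1, round(linear_lr(epoch * per_epoch, total_steps, PEAK_LR), 6))

best_by_loss = min(LOG, key=lambda row: row[2])[0]
best_by_bleu = max(LOG, key=lambda row: row[3])[0]
rising = [cur[0] for prev, cur in zip(LOG, LOG[1:]) if cur[2] > prev[2]]
print(best_by_loss, best_by_bleu, rising)  # 4 5 [5]
```

Validation loss is lowest after epoch 4 and rises at epoch 5 while training loss keeps falling, whereas BLEU improves only marginally at epoch 5; this suggests the model is starting to overfit. To keep the best checkpoint automatically, `Seq2SeqTrainingArguments` accepts `load_best_model_at_end=True` together with `metric_for_best_model` (for example `"eval_loss"`, or `"bleu"` with `greater_is_better=True`); this also requires `save_strategy="epoch"` so that checkpoints line up with the per-epoch evaluations.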




In [59]:
# load the model from file
tokenizer_trained = T5Tokenizer.from_pretrained("T5Small-detoxification-prefix")
model_trained = T5ForConditionalGeneration.from_pretrained('T5Small-detoxification-prefix')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [60]:
# sanity check: detoxify the first N examples of the test split
N = 50
# slice once; test_dataset['input'][i] would rebuild the whole 550k-row column on every access
sample = test_dataset[:N]
inputs, results, targets = [], [], []

for text, ref in zip(sample['input'], sample['target']):
    input_ids = tokenizer_trained(prefix + text, return_tensors="pt").input_ids
    outputs = model_trained.generate(input_ids, max_new_tokens=128)
    inputs.append(text)
    results.append(tokenizer_trained.decode(outputs[0], skip_special_tokens=True))
    targets.append(ref)

pd.DataFrame({'input': inputs, 'result': results, 'target': targets})
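The loop above generates one sentence at a time, which is fine for 50 examples but slow for a full evaluation. A sketch of batched generation, assuming a hypothetical `generate_fn` that maps a list of prefixed strings to a list of output strings; keeping the model call behind that parameter makes the batching logic testable on its own.

```python
from typing import Callable, List

def detoxify_in_batches(texts: List[str],
                        generate_fn: Callable[[List[str]], List[str]],
                        prefix: str = "detoxify text: ",
                        batch_size: int = 16) -> List[str]:
    """Add the task prefix and run generate_fn over consecutive chunks of texts."""
    results: List[str] = []
    for start in range(0, len(texts), batch_size):
        chunk = [prefix + t for t in texts[start:start + batch_size]]
        outputs = generate_fn(chunk)
        if len(outputs) != len(chunk):
            raise ValueError("generate_fn must return one output per input")
        results.extend(outputs)
    return results
```

In this notebook, `generate_fn` could tokenize the chunk with `tokenizer_trained(batch, return_tensors="pt", padding=True)`, pass the encoding to `model_trained.generate(**enc, max_new_tokens=128)` so the attention mask hides the padding, and decode with `tokenizer_trained.batch_decode(outputs, skip_special_tokens=True)`.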

| | input | result | target |
|---|---|---|---|
| 0 | besides, I don't think I want two scoundrels, ... | Besides, I don't think I want two pricks, my n... | Besides, I don't think I’d be willing to take ... |
| 1 | take a look at yourself and then call me crazy. | Take a look at yourself and then call me crazy. | See for yourself and then call me crazy. |
| 2 | and the compilation of culture is a tyranny ag... | And the compilation of culture is a tyranny ag... | And the commodification of culture is uniquely... |
| 3 | For example, watches. Electronic shit, anythin... | for example, it's electronic, anything you thi... | a watch, an electronics, everything we might l... |
| 4 | I can't believe someone would want to kill me. | I can't believe someone would try to kill me. | I can't believe someone's trying to kill me. |
| 5 | your mother was a hamster... and your father s... | Your mother was a hamster... and your father s... | Thy mother's son, and thy father's shadow! |
| 6 | Know what I'd fucking love ? | you know what I'd love? | you know what I like? |
| 7 | Rex is violent and he hates us. | Rex is violent and he hates us. | Rex is aggressive and he hates us. |
| 8 | Hypocrites? | hypocrites? | hypocritical? |
| 9 | Everybody's fucking except for me, and it's dr... | everyone is except me, and it drives me crazy. | everyone's sleeping together except for me, an... |
