## Multi-stage training
This notebook explores multi-stage training for the detoxification model: the first stage trains on a small fraction of the dataset, and the second stage continues training on the larger remaining part.

In [None]:
# reprepare the materials from previous notebook
import numpy as np
import pandas as pd
from datasets import load_metric
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
from sklearn.model_selection import train_test_split

# Load the preprocessed parallel corpus (columns 'toxic' and 'detoxified').
df = pd.read_csv("processed.csv")
df.head()

# Absolute row counts for the held-out splits; whatever is left after carving
# them off becomes the stage-two training set.
NUM_PRE = 20000    # stage-one ("pre") training subset
NUM_VAL = 5000     # validation set shared by both stages
NUM_TEST = 400000  # held-out test set

# Rename to the generic column names that preprocess_function reads.
df_text = df[['toxic', 'detoxified']].rename(
    columns={'toxic': 'input', 'detoxified': 'target'}
)

# Peel the splits off one after another. Each test_size is a fraction of the
# rows still remaining, so every split gets (approximately) the requested
# number of rows. The same seed is used for every split.
rest, pre = train_test_split(df_text, test_size=NUM_PRE / len(df_text), random_state=42)
rest, val = train_test_split(rest, test_size=NUM_VAL / len(rest), random_state=42)
train, test = train_test_split(rest, test_size=NUM_TEST / len(rest), random_state=42)


def _to_hf_dataset(frame):
    """Convert a DataFrame into a datasets.Dataset, one column per DataFrame column."""
    return Dataset.from_dict(frame.to_dict(orient='list'))


pretrain_dataset = _to_hf_dataset(pre)
train_dataset = _to_hf_dataset(train)
val_dataset = _to_hf_dataset(val)
test_dataset = _to_hf_dataset(test)

# Tokenization settings: both sources and targets are truncated to 128 tokens.
max_input_length = max_target_length = 128

# T5-style task prefix prepended to every source sentence before tokenizing.
prefix = "detoxify text: "

def preprocess_function(examples):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Intended for ``Dataset.map(preprocess_function, batched=True)``.

    Args:
        examples: a batch mapping with list-valued ``'input'`` and ``'target'``
            columns.

    Returns:
        The tokenizer output for the prefixed sources (truncated to
        ``max_input_length``), with an added ``"labels"`` entry holding the
        target token ids (truncated to ``max_target_length``).
    """
    prefixed_sources = [prefix + text for text in examples['input']]
    encoded = tokenizer(prefixed_sources, max_length=max_input_length, truncation=True)

    # Targets are tokenized inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(
            examples['target'], max_length=max_target_length, truncation=True
        )

    encoded["labels"] = target_encoding["input_ids"]
    return encoded

# sacreBLEU metric used by compute_metrics during evaluation (same metric as the
# Hugging Face example this notebook follows).
metric = load_metric(path="sacrebleu")

def postprocess_text(preds, labels):
    """Normalize decoded strings into the layout sacreBLEU expects.

    Every prediction and label is stripped of surrounding whitespace, and each
    label is wrapped in its own one-element list (one reference per prediction).

    Returns:
        A tuple ``(predictions, references)``: a list of str and a list of
        single-element lists of str.
    """
    def _clean(text):
        return text.strip()

    return [_clean(p) for p in preds], [[_clean(ref)] for ref in labels]

def compute_metrics(eval_preds):
    """Compute corpus BLEU and mean generation length for a Seq2SeqTrainer evaluation.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by the Trainer. With
            ``predict_with_generate=True`` the predictions are generated token
            ids; if they arrive as a tuple, its first element is used.

    Returns:
        dict with ``"bleu"`` (the sacreBLEU score) and ``"gen_len"`` (mean number
        of non-padding tokens per prediction), each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # Predictions gathered across evaluation batches may be padded with -100.
    # That is not a valid token id: it cannot be decoded and would be counted
    # as a real token in gen_len. Map it to the pad token first.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Labels use -100 for ignored positions; same treatment before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace and wrap each reference in a list for sacreBLEU.
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the tokenizer and model that stage one starts from. Both come from the
# local "T5Small-detoxification" directory.
# NOTE(review): presumably the checkpoint produced by the previous notebook,
# i.e. an already fine-tuned model rather than the stock t5-small; confirm.
base_checkpoint = "T5Small-detoxification"
tokenizer = T5Tokenizer.from_pretrained(base_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(base_checkpoint)

In [None]:
# Tokenize the splits used for training and validation. The test split stays
# as raw text; the qualitative check at the end tokenizes it sentence by sentence.
pretrain_processed, train_processed, val_processed = (
    split.map(preprocess_function, batched=True)
    for split in (pretrain_dataset, train_dataset, val_dataset)
)

In [None]:
# Stage one: fine-tune on the small `pre` subset, evaluating on `val` each epoch,
# then save the result for stage two.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

batch_size = 64
args = Seq2SeqTrainingArguments(
    # Checkpoint directory used only by this stage. A fresh Trainer restarts its
    # step counter at 0, so if both stages shared one directory, stage two's
    # checkpoint-<step> folders would overwrite these and both stages would be
    # mixed in the save_total_limit rotation.
    "T5-Small-finetuned-detoxification-step1",
    evaluation_strategy="epoch",
    # NOTE(review): 1e-2 is 50x the stage-two rate (2e-4); confirm intended.
    learning_rate=1e-2,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=10,
    num_train_epochs=4,
    # Evaluate via generation so compute_metrics receives generated token ids.
    predict_with_generate=True,
    fp16=True,
    #generation_max_length=64,
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=pretrain_processed,
    eval_dataset=val_processed,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model('T5Small-detoxification-step1')

In [None]:
# Reload the stage-one result (tokenizer and weights) saved by the previous cell.
stage_one_dir = "T5Small-detoxification-step1"
tokenizer = T5Tokenizer.from_pretrained(stage_one_dir)
model = T5ForConditionalGeneration.from_pretrained(stage_one_dir)

In [None]:
# Stage two: continue training the stage-one model on the much larger `train`
# split with a lower learning rate, evaluating on `val` each epoch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

batch_size = 64
args = Seq2SeqTrainingArguments(
    # Checkpoint directory used only by this stage. This Trainer restarts its
    # step counter at 0, so sharing stage one's directory would overwrite that
    # stage's checkpoint-<step> folders and mix both stages in the
    # save_total_limit rotation.
    "T5-Small-finetuned-detoxification-step2",
    evaluation_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=10,
    num_train_epochs=2,
    # Evaluate via generation so compute_metrics receives generated token ids.
    predict_with_generate=True,
    fp16=True,
    #generation_max_length=64,
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_processed,
    eval_dataset=val_processed,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model('T5Small-detoxification-step2')

In [None]:
# Reload the final (stage-two) tokenizer and weights under separate names so
# they do not replace the `tokenizer`/`model` objects used for training.
stage_two_dir = "T5Small-detoxification-step2"
tokenizer_trained = T5Tokenizer.from_pretrained(stage_two_dir)
model_trained = T5ForConditionalGeneration.from_pretrained(stage_two_dir)

In [None]:
# Qualitative check: run the stage-two model on the first N test sentences and
# show each output next to its reference detoxification.
N = 50

# Slice the rows once, up front. Indexing test_dataset['input'][i] inside the
# loop can materialize the entire (400k-row) column on every access, three
# times per iteration. Slicing also tolerates a split shorter than N instead of
# raising IndexError.
test_sample = test_dataset[:N]
inputs = test_sample['input']   # not `input`, to avoid shadowing the builtin
target = test_sample['target']

result = []
for text in inputs:
    # Same task prefix as used during training.
    input_ids = tokenizer_trained(prefix + text, return_tensors="pt").input_ids
    outputs = model_trained.generate(input_ids, max_new_tokens=128)
    result.append(tokenizer_trained.decode(outputs[0], skip_special_tokens=True))

pd.DataFrame({'input': inputs, 'result': result, 'target': target})