In [None]:
# ------------------------------
# Cell 1 - Imports & setup
# ------------------------------
import os
import glob
import torch
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig,
    TrainerCallback
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import evaluate
import numpy as np

# ------------------------------
# Cell 2 - Paths & global params
# ------------------------------
TRAIN_GLOB = "dataset/toxic_nontoxic/*.parquet"
OUTPUT_DIR = "mt0_base_qlora_detox_t5"
os.makedirs(OUTPUT_DIR, exist_ok=True)
TEST_OUTPUT_DIR = "dataset/test_outputs"
os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)

MODEL_NAME = "bigscience/mt0-base"
SEED = 42

# Seq2seq lengths (will overwrite after dataset analysis)
MAX_SOURCE_LENGTH = 256
MAX_TARGET_LENGTH = 256

# Training hyperparams
PER_DEVICE_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
NUM_EPOCHS = 3
LEARNING_RATE = 1e-4
LOGGING_STEPS = 50

# BitsAndBytes / QLoRA
BNB_QUANT_TYPE = "nf4"
COMPUTE_DTYPE = torch.bfloat16

# LoRA hyperparams
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q", "v"]

# Inference settings
INFERENCE_BATCH = 4
NUM_BEAMS = 4
MAX_GEN_LEN = 128

# ------------------------------
# Cell 3 - Load tokenizer
# ------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded. pad_token_id:", tokenizer.pad_token_id)

# ------------------------------
# Cell 4 - Load train dataset
# ------------------------------
train_files = sorted(glob.glob(TRAIN_GLOB))
if len(train_files) == 0:
    raise FileNotFoundError(f"No train parquet files found with pattern: {TRAIN_GLOB}")
print("Found train files:", train_files)

raw_train = load_dataset("parquet", data_files={"train": train_files})["train"]
print("Loaded dataset. Examples:", len(raw_train))
print("Example row:", raw_train[0])

# Filter empty rows
raw_train = raw_train.filter(
    lambda x: x["toxic_sentence"].strip() != "" and x["neutral_sentence"].strip() != ""
)
print("Filtered empty rows. Remaining examples:", len(raw_train))

# ------------------------------
# Cell 5 - Compute max token lengths
# ------------------------------
def get_max_token_lengths(dataset, source_col="toxic_sentence", target_col="neutral_sentence"):
    max_src, max_tgt = 0, 0
    for row in dataset:
        src_len = len(tokenizer(str(row[source_col]), truncation=False)["input_ids"])
        tgt_len = len(tokenizer(str(row[target_col]), truncation=False)["input_ids"])
        max_src = max(max_src, src_len)
        max_tgt = max(max_tgt, tgt_len)
    return max_src, max_tgt

max_source_len, max_target_len = get_max_token_lengths(raw_train)
print(f"Max source length: {max_source_len}, Max target length: {max_target_len}")

MAX_SOURCE_LENGTH = max_source_len
MAX_TARGET_LENGTH = max_target_len
print(f"Updated MAX_SOURCE_LENGTH={MAX_SOURCE_LENGTH}, MAX_TARGET_LENGTH={MAX_TARGET_LENGTH}")

# Shuffle & split
raw_train = raw_train.shuffle(seed=SEED)
split = raw_train.train_test_split(test_size=0.01, seed=SEED)
datasets = DatasetDict({
    "train": split["train"],
    "validation": split["test"]
})
print("Train size:", len(datasets["train"]), "Validation size:", len(datasets["validation"]))

# ------------------------------
# Cell 6 - Preprocess and tokenize (fixed)
# ------------------------------
SOURCE_COL = "toxic_sentence"
TARGET_COL = "neutral_sentence"

def preprocess_function(examples):
    # Tokenize source
    model_inputs = tokenizer(
        examples[SOURCE_COL],
        max_length=MAX_SOURCE_LENGTH,
        truncation=True,
        padding="max_length"
    )

    # Tokenize target / labels
    labels = tokenizer(
        examples[TARGET_COL],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
        padding="max_length"
    )["input_ids"]

    # Replace pad token id with -100
    labels = [[(l if l != tokenizer.pad_token_id else -100) for l in seq] for seq in labels]
    model_inputs["labels"] = labels
    return model_inputs

tokenized = datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=datasets["train"].column_names
)

print("Tokenization done. Sample:", tokenized["train"][0])

# ------------------------------
# Cell 7 - Load model with 4-bit QLoRA
# ------------------------------
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type=BNB_QUANT_TYPE,
)

model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto"
)

# Prepare for k-bit training
model = prepare_model_for_kbit_training(model)
try:
    model.gradient_checkpointing_enable()
except Exception:
    pass
print("Model ready for k-bit training.")

# ------------------------------
# Cell 8 - PEFT / LoRA
# ------------------------------
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)
peft_model = get_peft_model(model, lora_config)
trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_model.parameters())
print(f"PEFT model ready — trainable params: {trainable} / {total}")

# ------------------------------
# Cell 9 - Metrics & collator
# ------------------------------
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=peft_model,
    padding="longest"
)

# ------------------------------
# Cell 10 - Save best loss callback (fixed)
# ------------------------------
class SaveBestLossCallback(TrainerCallback):
    def __init__(self, output_dir):
        self.best_loss = float("inf")
        self.output_dir = os.path.join(output_dir, "best_eval")
        os.makedirs(self.output_dir, exist_ok=True)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None or "eval_loss" not in metrics:
            return
        loss = metrics["eval_loss"]
        if loss < self.best_loss:
            print(f"eval_loss improved {self.best_loss} -> {loss}, saving model")
            self.best_loss = loss
            kwargs["model"].save_pretrained(self.output_dir)
            # Do not save tokenizer here to avoid KeyError

save_best_cb = SaveBestLossCallback(OUTPUT_DIR)

# ------------------------------
# Cell 11 - TrainingArguments & Trainer
# ------------------------------
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    predict_with_generate=False,
    bf16=True,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=LOGGING_STEPS,
    eval_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=None,
    callbacks=[save_best_cb],
)

# ------------------------------
# Cell 12 - Train & save
# ------------------------------
train_result = trainer.train()
trainer.save_model(OUTPUT_DIR)
peft_model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Training finished. Saved adapters to", OUTPUT_DIR)

# ------------------------------
# Cell 13 - Test inference
# ------------------------------
def detoxify(text):
    inputs = tokenizer(text, return_tensors="pt", max_length=88, truncation=True).to("cuda")
    outputs = peft_model.generate(**inputs, max_length=102)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test
test_text = "your toxic text here"
print(detoxify(test_text))


Tokenizer loaded. pad_token_id: 0
Found train files: ['dataset/toxic_nontoxic\\multilingual_paradetox_am.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_ar.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_de.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_en.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_es.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_hi.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_ru.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_uk.parquet', 'dataset/toxic_nontoxic\\multilingual_paradetox_zh.parquet']
Loaded dataset. Examples: 3600
Example row: {'toxic_sentence': 'ገልቱዬ ስንቴ ነው ሚሞቱት ግን? መሞታቸውን የዛሬ ወርም አርድተኘን ነበር ??', 'neutral_sentence': 'እሳቸው ስንቴ ነው ሞቱ ብለህ የምትነግረን ?'}
Filtered empty rows. Remaining examples: 3600
Max source length: 88, Max target length: 102
Updated MAX_SOURCE_LENGTH=88, MAX_TARGET_LENGTH=102
Train size: 3564 Validation size: 36


Map: 100%|██████████| 3564/3564 [00:01<00:00, 2539.83 examples/s]
Map: 100%|██████████| 36/36 [00:00<00:00, 1155.29 examples/s]


Tokenization done. Sample: {'input_ids': [904, 1121, 22256, 456, 81213, 1040, 894, 58657, 1773, 4134, 1285, 41312, 324, 3013, 80602, 1964, 261, 259, 52164, 1040, 309, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [904, 9599, 102361, 152617, 10884, 558, 894, 58657, 1773, 4134, 1285, 41312, 324, 3013, 80602, 1964, 261, 259, 52164, 1040, 309, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -

KeyboardInterrupt: 

In [22]:

# ------------------------------
# Cell 12 - Train & save
# ------------------------------
train_result = trainer.train()
trainer.save_model(OUTPUT_DIR)
peft_model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Training finished. Saved adapters to", OUTPUT_DIR)

Step,Training Loss,Validation Loss
20,No log,


KeyboardInterrupt: 