In [None]:
!pip install -q --upgrade transformers datasets evaluate sacrebleu sentencepiece

In [None]:
import os
import torch
from datetime import datetime
from transformers import (
    MT5Tokenizer, MT5ForConditionalGeneration,
    Seq2SeqTrainer, Seq2SeqTrainingArguments,
    EarlyStoppingCallback, DataCollatorForSeq2Seq, TrainerCallback
)
from datasets import load_dataset
import evaluate
import numpy as np
from transformers.trainer_utils import get_last_checkpoint
# BLEU metric: sacreBLEU loaded through `evaluate`. compute_metrics() calls
# bleu.compute(...) and reads the corpus score from the returned "score" key.
bleu = evaluate.load("sacrebleu")


In [None]:
# BLEU logger callback
class BLEULoggerCallback(TrainerCallback):
    """Print each evaluation's BLEU score and append it to a CSV file.

    Whenever ``on_evaluate`` receives metrics containing ``"eval_bleu"``, a
    ``<epoch>,<bleu>`` row (no header) is appended to ``log_path``. The
    ``(epoch, bleu)`` pair is also kept in ``self.logs``.

    Args:
        log_path: CSV file the scores are appended to.
        overwrite: When True (the default), an existing file at ``log_path``
            is deleted on construction, so the log starts empty. Pass False
            when resuming training to keep rows written by earlier runs.
    """

    def __init__(self, log_path="bleu_log.csv", overwrite=True):
        self.log_path = log_path
        # In-memory copy of every (epoch, bleu) pair that was logged.
        self.logs = []
        if overwrite and os.path.exists(self.log_path):
            os.remove(self.log_path)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record ``metrics["eval_bleu"]`` for the current epoch, if present."""
        if not metrics or "eval_bleu" not in metrics:
            return
        bleu_score = metrics["eval_bleu"]
        epoch = state.epoch
        # state.epoch may be None (presumably when evaluating outside a
        # training loop); formatting None with ':.1f' would raise TypeError.
        epoch_label = "n/a" if epoch is None else f"{epoch:.1f}"
        print(f"Epoch {epoch_label} - BLEU: {bleu_score:.2f}")
        epoch_field = "" if epoch is None else epoch
        with open(self.log_path, "a") as f:
            f.write(f"{epoch_field},{bleu_score:.2f}\n")
        self.logs.append((epoch, bleu_score))

In [None]:
# Load tokenizer & model
# Both come from the same pretrained mT5 checkpoint so their vocabularies match.
model_name = "google/mt5-small"
tokenizer, model = (
    MT5Tokenizer.from_pretrained(model_name),
    MT5ForConditionalGeneration.from_pretrained(model_name),
)

In [None]:
# Load dataset
# Train/validation CSV files, keyed by the split names used later
# ("train" and "validation").
data_files = dict(
    train="/kaggle/input/data-train/train.csv",
    validation="/kaggle/input/data-train/val.csv",
)
raw_datasets = load_dataset(path="csv", data_files=data_files)

In [None]:
# Tokenization
max_source_length = 128
max_target_length = 128


def _as_text(values):
    """Return ``values`` ready for the tokenizer, mapping missing cells (None) to "".

    Accepts a single value or a batch (list) of values. Presumably empty CSV
    cells arrive as None, which the tokenizer would reject.
    """
    if values is None:
        return ""
    if isinstance(values, str):
        return values
    return ["" if v is None else str(v) for v in values]


def preprocess(example):
    """Tokenize source/target text for seq2seq fine-tuning.

    Args:
        example: Mapping with ``"source"`` and ``"target"`` entries. When used
            with ``Dataset.map(batched=True)``, each entry is a list of strings.

    Returns:
        The tokenizer encoding of the sources plus a ``labels`` entry holding
        the target token ids. Sequences are truncated to the max lengths
        above but NOT padded. DataCollatorForSeq2Seq pads each batch
        dynamically and fills label padding with its label_pad_token_id
        (-100 by default), so the loss ignores it. Padding labels to
        max_length here would store real pad-token ids in the labels, and
        the model would be trained to predict them.
    """
    sources = _as_text(example["source"])
    targets = _as_text(example["target"])

    model_inputs = tokenizer(
        sources,
        max_length=max_source_length,
        truncation=True,
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context.
    labels = tokenizer(
        text_target=targets,
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# Tokenize every split. The raw "source"/"target" text columns (already
# consumed by preprocess) are dropped, because a padding collator cannot
# batch string fields.
tokenized_datasets = raw_datasets.map(
    preprocess,
    batched=True,
    remove_columns=["source", "target"],
)

In [None]:
def compute_metrics(eval_preds):
    """Compute corpus-level sacreBLEU for a Seq2SeqTrainer evaluation.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by the Trainer.
            ``predictions`` may be a tuple, whose first element holds the
            generated token ids.

    Returns:
        The dict from ``bleu.compute`` with its ``"score"`` entry renamed to
        ``"eval_bleu"``. Returns ``{"eval_bleu": 0.0}`` when no example has
        a non-empty reference.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is the loss ignore index. It appears in labels, and in generated
    # predictions padded while batches of different lengths are concatenated.
    # The tokenizer cannot decode it, so map it back to the pad id (which
    # skip_special_tokens then drops).
    preds = np.where(np.asarray(preds) != -100, preds, pad_id)
    labels = np.where(np.asarray(labels) != -100, labels, pad_id)

    decoded_preds = [
        text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)
    ]
    decoded_labels = [
        text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)
    ]

    # Only examples without a reference are skipped. Empty predictions are
    # kept so they count as failures. Dropping them would inflate BLEU and
    # could make a checkpoint that emits blank outputs look best.
    valid_pairs = [
        (pred, label)
        for pred, label in zip(decoded_preds, decoded_labels)
        if label
    ]

    if not valid_pairs:
        return {"eval_bleu": 0.0}

    valid_preds, valid_labels = zip(*valid_pairs)

    result = bleu.compute(
        predictions=list(valid_preds),
        references=[[label] for label in valid_labels],
    )

    result["eval_bleu"] = result.pop("score")
    return result


In [None]:
# Training config
training_args = Seq2SeqTrainingArguments(
    output_dir="./mt5_nom_translate",
    run_name=f"mt5_nom_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    # Evaluate, checkpoint and log once per epoch, so early stopping and
    # load_best_model_at_end compare the same per-epoch checkpoints.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=2,
    # Upper bound; early stopping on eval_bleu may end training sooner.
    num_train_epochs=100,
    # Generate during evaluation so compute_metrics gets token ids to score.
    predict_with_generate=True,
    # Without this, evaluation generation falls back to the generation
    # config's default max_length (20 unless the checkpoint overrides it).
    # That truncates translations and deflates eval BLEU.
    generation_max_length=max_target_length,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="eval_bleu",
    greater_is_better=True,
)

In [None]:
# Pads each batch dynamically: inputs with the tokenizer's pad id, labels
# with -100 so padded label positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
)

# Trainer with BLEU log + early stopping
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # The `tokenizer=` keyword is deprecated (removed in transformers v5, which
    # the install cell may pull in via --upgrade); `processing_class` replaces it.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        # Stops after 2 evaluations without improvement in metric_for_best_model.
        EarlyStoppingCallback(early_stopping_patience=2),
        BLEULoggerCallback(log_path="bleu_log.csv"),
    ],
)

In [None]:
# Resume from the newest "checkpoint-*" folder under conti_train_dir when one
# exists; otherwise train from scratch.
conti_train_dir = "/kaggle/input/data-train"
# get_last_checkpoint() lists the folder's contents and raises if the folder
# is missing, so only search it when it actually exists.
last_checkpoint = (
    get_last_checkpoint(conti_train_dir) if os.path.isdir(conti_train_dir) else None
)

if last_checkpoint is not None:
    print(f"Found checkpoint at {last_checkpoint}, resuming training")
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    print(" No checkpoint found. Starting fresh training")
    train_result = trainer.train()

# Save the final model (the best checkpoint, since load_best_model_at_end=True)
# together with its tokenizer.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)