In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset, load_metric

In [None]:
# Step 1: Load and Prepare Dataset
# The CSV files below are placeholders for a custom translation corpus; swap in
# your own loading logic as needed. Each file must provide a "source_text" and a
# "target_text" column, because the preprocessing step reads exactly those names.
dataset = load_dataset(
    "csv",
    data_files={
        "train": "train.csv",
        "validation": "validation.csv",
    },
)

# Preprocess dataset for T5 training
def mask_label_padding(label_ids, pad_token_id, ignore_index=-100):
    """Replace padding ids in tokenized targets with ``ignore_index``.

    Targets are padded to a fixed length, so without masking the padding
    positions would be scored by the loss like real tokens. The T5 loss skips
    positions whose label equals -100.

    Args:
        label_ids: Iterable of token-id sequences (one sequence per example).
        pad_token_id: Id the tokenizer uses for padding.
        ignore_index: Value written in place of padding ids.

    Returns:
        A new list of lists; the input sequences are not modified.
    """
    return [
        [ignore_index if token_id == pad_token_id else token_id for token_id in sequence]
        for sequence in label_ids
    ]


def preprocess_function(examples):
    """Tokenize a batch of source/target pairs into model inputs and labels.

    Reads the "source_text" and "target_text" columns of ``examples`` and adds
    "input_ids", "attention_mask" and "labels". Relies on the module-level
    ``tokenizer`` being defined before this function is called.

    Args:
        examples: Batch mapping with "source_text" and "target_text" entries.

    Returns:
        The same ``examples`` mapping with the three tokenized columns added.
    """
    # Both sides are padded/truncated to 512 tokens so every row has equal length.
    inputs = tokenizer(examples["source_text"], padding="max_length", truncation=True, max_length=512)
    targets = tokenizer(examples["target_text"], padding="max_length", truncation=True, max_length=512)

    examples["input_ids"] = inputs.input_ids
    examples["attention_mask"] = inputs.attention_mask
    # Mask padding in the labels so it does not contribute to the training loss.
    examples["labels"] = mask_label_padding(targets.input_ids, tokenizer.pad_token_id)

    return examples

# Step 2: Load Tokenizer and Model
# This must happen before the datasets are mapped: preprocess_function looks up
# the module-level `tokenizer` when it runs, so mapping first raises NameError.
model_name = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Tokenize both splits in batches now that the tokenizer is available.
train_dataset = dataset["train"].map(preprocess_function, batched=True)
valid_dataset = dataset["validation"].map(preprocess_function, batched=True)

# Step 3: Configure Training Arguments
training_args = Seq2SeqTrainingArguments(
    # Directory for checkpoints written every `save_steps`; transformers releases
    # that accept `evaluation_strategy` require this argument to be given.
    output_dir='./results',
    # Evaluate with generate() so the metric sees decoded sequences, not logits.
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=1e-4,
    num_train_epochs=3,
    logging_dir='./logs',
    logging_steps=100,
    save_steps=500,
    # Keep only the two most recent checkpoints on disk.
    save_total_limit=2,
)

# Step 4: Define Trainer and Train the Model
bleu_metric = load_metric("sacrebleu")


def restore_padding(sequences, pad_token_id, ignore_index=-100):
    """Return ``sequences`` as nested int lists with ``ignore_index`` replaced by ``pad_token_id``.

    Negative placeholder ids cannot be decoded by the tokenizer, so they are
    mapped back to the padding id, which decoding then drops as a special token.
    """
    return [
        [pad_token_id if int(token_id) == ignore_index else int(token_id) for token_id in sequence]
        for sequence in sequences
    ]


def compute_metrics(eval_preds):
    """Decode generated ids and reference ids and score them with SacreBLEU.

    Args:
        eval_preds: Pair of (predictions, label_ids) supplied by the trainer
            during evaluation.

    Returns:
        Dict with a single "bleu" entry holding the corpus-level score.
    """
    predictions, label_ids = eval_preds
    # Some models return extra outputs alongside the generated ids; keep the ids only.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    decoded_preds = tokenizer.batch_decode(
        restore_padding(predictions, pad_id), skip_special_tokens=True
    )
    decoded_labels = tokenizer.batch_decode(
        restore_padding(label_ids, pad_id), skip_special_tokens=True
    )

    # SacreBLEU expects one list of reference strings per prediction.
    result = bleu_metric.compute(
        predictions=[text.strip() for text in decoded_preds],
        references=[[text.strip()] for text in decoded_labels],
    )
    return {"bleu": result["score"]}


trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
    # Must be a callable returning a dict of metrics, not the metric object itself.
    compute_metrics=compute_metrics,
)

trainer.train()

# Step 5: Save the Trained Model
# The model weights and the tokenizer files go into the same directory so the
# pair can later be reloaded from a single path.
for artifact in (model, tokenizer):
    artifact.save_pretrained('./trained_t5_translation_model')