In [12]:
# --- CONFIGURATION ---
model_name = "google/mt5-small"
dataset_name = "ai4bharat/samanantar"
dataset_config = "as"
max_length = 128
batch_size = 2  # <-- REDUCED FURTHER to 2
gradient_accumulation_steps = 8  # <-- NEW: Accumulate gradients over 8 steps
num_epochs = 1

# --- 1. SETUP & INSTALLS ---
print("Installing required libraries...")
!pip install transformers[sentencepiece] datasets sacrebleu wandb -q > /dev/null
print("Libraries installed.")

from transformers import MT5ForConditionalGeneration, MT5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
import torch
import wandb

# --- 2. LOGIN TO W&B ---
wandb.login(key='100fa93408aa5a13bfb4acdc7d19060ef991b61a')

# --- 3. LOAD DATA AND MODEL ---
print("Loading dataset and model...")
dataset = load_dataset(dataset_name, dataset_config)

# --- NEW LINE: SUBSAMPLE THE DATA FOR A QUICK TEST ---
dataset["train"] = dataset["train"].select(range(10000))  # Use only the first 10,000 examples

tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name, use_cache=False)
print("Dataset and model loaded successfully!")

# --- 4. PREPROCESS DATA ---
def preprocess_function(examples):
    inputs = ["translate English to Assamese: " + en for en in examples['src']]
    targets = examples['tgt']
    model_inputs = tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True, padding="max_length")
    return model_inputs

print("Tokenizing dataset...")
tokenized_datasets = dataset.map(preprocess_function, batched=True)
print("Dataset tokenized.")

# --- 5. SET UP TRAINING ---
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    eval_strategy="no",
    save_strategy="epoch",
    learning_rate=3e-4,
    per_device_train_batch_size=batch_size, # Very small batch size
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps, # <-- NEW: Accumulate gradients
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=num_epochs,
    predict_with_generate=True,
    fp16=True,  # Mixed precision
    report_to="wandb",
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

# --- 6. START TRAINING! ---
print("Starting training...")
train_result = trainer.train()
print("Training finished!")

# --- 7. SAVE MODEL ---
trainer.save_model("my_english_assamese_model")
tokenizer.save_pretrained("my_english_assamese_model")
print("Model saved!")

# --- 8. LOG METRICS ---
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

wandb.finish()
print("Check your W&B dashboard for results!")

Installing required libraries...




Libraries installed.
Loading dataset and model...


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.


Dataset and model loaded successfully!
Tokenizing dataset...


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset tokenized.


  trainer = Seq2SeqTrainer(


Starting training...




Step,Training Loss


Training finished!
Model saved!
***** train metrics *****
  epoch                    =        1.0
  total_flos               =  1231091GF
  train_loss               =     6.7027
  train_runtime            = 0:16:30.06
  train_samples_per_second =       10.1
  train_steps_per_second   =      0.316


0,1
train/epoch,▁
train/global_step,▁

0,1
total_flos,1321874227200000.0
train/epoch,1.0
train/global_step,313.0
train_loss,6.70275
train_runtime,990.0644
train_samples_per_second,10.1
train_steps_per_second,0.316


Check your W&B dashboard for results!
