In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from datasets import load_dataset, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "facebook/nllb-200-distilled-600M"

model = AutoModelForSeq2SeqLM.from_pretrained(
                                            model_name,
                                            device_map='auto',
                                            use_cache=False
                                            )

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
train_ds = load_dataset("amaniopia/merged_train", split="train")
valid_ds = load_dataset("amaniopia/flores-merged", split="train")

In [5]:
train_ds, valid_ds

(Dataset({
     features: ['source', 'target', 'src_lang', 'tgt_lang'],
     num_rows: 3145570
 }),
 Dataset({
     features: ['source', 'target', 'src_lang', 'tgt_lang'],
     num_rows: 29910
 }))

In [6]:
small_train = train_ds.filter(lambda example : example["src_lang"] in ["eng_Latn", "amh_Ethi", "swh_Latn"] and example["tgt_lang"] in ["eng_Latn", "amh_Ethi", "swh_Latn"])
small_train = small_train.shuffle(seed=42).select(range(0,4000))
small_valid = valid_ds.filter(lambda example : example["src_lang"] in ["eng_Latn", "amh_Ethi", "swh_Latn"] and example["tgt_lang"] in ["eng_Latn", "amh_Ethi", "swh_Latn"])

In [7]:
small_train, small_valid

(Dataset({
     features: ['source', 'target', 'src_lang', 'tgt_lang'],
     num_rows: 4000
 }),
 Dataset({
     features: ['source', 'target', 'src_lang', 'tgt_lang'],
     num_rows: 3988
 }))

In [8]:
def tokenize_fn(examples):
    inputs = examples["source"]
    targets = examples["target"]
    src_langs = examples["src_lang"]
    tgt_langs = examples["tgt_lang"]

    input_ids = []
    labels = []

    for src, tgt, src_lang, tgt_lang in zip(inputs, targets, src_langs, tgt_langs):
        tokenizer.src_lang = src_lang
        tokenizer.tgt_lang = tgt_lang


        tokenized = tokenizer(
            src,
            text_target=tgt,
            max_length=512,
            padding="max_length",
            truncation=True,
        )

        input_ids.append(tokenized["input_ids"])
        labels.append(tokenized["labels"])

    return {"input_ids": input_ids, "labels": labels}

In [9]:
# Tokenize datasets
tokenized_train = small_train.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names)
tokenized_valid = small_valid.map(tokenize_fn, batched=True, remove_columns=valid_ds.column_names)

Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1747.89 examples/s]


In [10]:
set(small_valid["tgt_lang"])

{'amh_Ethi', 'eng_Latn', 'swh_Latn'}

In [11]:
tokenizer.decode(tokenized_valid[0]["input_ids"])

'amh_Ethi ሰኞ እለት፣ በስታንፎርድ ዩኒቨርሲቲ የህክምና ትምህርት ቤት ህዋሶችን በአይነት የሚያስቀምጥ አዲስ የምርመራ መሳሪያ እንደተፈጠረ አስታውቋል፡ እያንዳንዱን በአንደ የዩ.ኤስ ሳንቲም የሚሆን መደበኛ የኢንክጄት አታሚዎችን በመጠቀም ሊፈበረክ የሚችል ትንሽ መታተም የሚችል ቺፕ።</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [12]:
tokenizer.decode(tokenized_valid[0]["labels"])

'eng_Latn On Monday, scientists from the Stanford University School of Medicine announced the invention of a new diagnostic tool that can sort cells by type: a tiny printable chip that can be manufactured using standard inkjet printers for possibly about one U.S. cent each.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p

In [13]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, logging
from huggingface_hub import HfFolder
import os

epochs = 2
learning_rate = 5e-5
batch_size = 4

hf_id = "amaniopia"  # change to your huggingface username
output_dir = os.path.join(hf_id, f"nllb-200-3.3B-finetuned")



# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    eval_accumulation_steps=4,
    #gradient_checkpointing=True,

    fp16=True,
    fp16_full_eval=True,

    learning_rate=learning_rate,
    lr_scheduler_type='constant',  # "constant", "linear", "cosine"

    eval_strategy="steps",  # or "epoch"
    eval_steps=100,
    save_strategy="epoch",
    logging_steps=100,
    report_to="none", # "tensorboard", "wandb", or "none"

    # push to hub parameters
    push_to_hub=True,
    hub_private_repo=True,
    hub_strategy="every_save",
    hub_token=HfFolder.get_token(),
)

# Initialize the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
)

In [14]:
trainer.train()

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Step,Training Loss,Validation Loss
100,7.4278,4.746602
200,3.1146,1.174518
300,0.4467,0.106117
400,0.096,0.088193
500,0.0929,0.085897




TrainOutput(global_step=500, training_loss=2.2355946826934816, metrics={'train_runtime': 1299.3975, 'train_samples_per_second': 6.157, 'train_steps_per_second': 0.385, 'total_flos': 8668418408448000.0, 'train_loss': 2.2355946826934816, 'epoch': 2.0})

In [25]:
import torch, gc

# Delete model, tokenizer, and trainer objects
# del model
del tokenizer
try:
    del trainer
except:
    pass

# Run garbage collection
gc.collect()

# Clear CUDA cache
torch.cuda.empty_cache()

In [26]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Mon Oct 13 09:01:37 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.247.01             Driver Version: 535.247.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A10G                    Off | 00000000:00:1E.0 Off |                    0 |
|  0%   30C    P0              55W / 300W |  19078MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    