In [1]:
from datasets import load_dataset
import evaluate
import torch
import os

# Use CPU only
device = torch.device("cpu")
print(f"Using device: {device}")

# Load dataset
dataset = load_dataset("parquet", data_files={
    "train": "data/iamTangsang_dataset/train-00000-of-00001.parquet",
    "validation": "data/iamTangsang_dataset/validation-00000-of-00001.parquet", 
    "test": "data/iamTangsang_dataset/test-00000-of-00001.parquet"
})

# Subset the data
dataset["train"] = dataset["train"].select(range(5000))       
dataset["validation"] = dataset["validation"].select(range(500))
dataset["test"] = dataset["test"].select(range(500)) 

# Filter out empty strings and ensure data quality
def is_valid_pair(example):
    source = str(example["source"]).strip() if example["source"] is not None else ""
    target = str(example["target"]).strip() if example["target"] is not None else ""
    return len(source) > 0 and len(target) > 0 and len(source) < 1000 and len(target) < 1000

dataset = dataset.filter(is_valid_pair)

print("Dataset sizes after filtering:")
print(f"Train: {len(dataset['train'])}")
print(f"Validation: {len(dataset['validation'])}")
print(f"Test: {len(dataset['test'])}")
print("\nSample:")
print(dataset["train"][0])

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
Dataset sizes after filtering:
Train: 5000
Validation: 500
Test: 500

Sample:
{'source': '"कुनै पनि अन्य सरकारी एजेन्सीले यो जानकारी प्रयोग गर्न सक्दैन, केन्द्रीय सरकार अन्तर्गतका कसैले कुनै पनि हालतमा यो जानकारी पाउँदैनन् र राज्यका अधिकारीहरूमा पनि स्वास्थ्य अधिकारीहरूले मात्र यसलाई प्रयोग गर्न सक्दछन्," उनले भने।', 'target': '"No other government agency can use this information, no one in the commonwealth government at all, and in state authorities, only the health officer can use it.'}


In [2]:
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small", legacy=False)

def preprocess_function(examples):
    sources = [str(src).strip() for src in examples["source"]]
    targets = [str(tgt).strip() for tgt in examples["target"]]
    inputs = ["translate Nepali to English: " + src for src in sources]
    model_inputs = tokenizer(inputs, max_length=300, truncation=True, padding=True, return_tensors=None)
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=300, truncation=True, padding=True, return_tensors=None)

    model_inputs["labels"] = [
        [(token_id if token_id != tokenizer.pad_token_id else -100) for token_id in label]
        for label in labels["input_ids"]
    ]
    
    return model_inputs

In [3]:
print("Tokenizer vocab size:", tokenizer.vocab_size)
print("Pad token id:", tokenizer.pad_token_id)

lengths = []
for i in range(min(1000, len(dataset["train"]))):
    input_text = "translate Nepali to English: " + str(dataset["train"][i]["source"])
    tokens = tokenizer(input_text)["input_ids"]
    lengths.append(len(tokens))

print(f"Max length: {max(lengths)}")
print(f"Average length: {sum(lengths)/len(lengths):.1f}")

Tokenizer vocab size: 250100
Pad token id: 0
Max length: 272
Average length: 30.3


In [4]:
print("Preprocessing datasets...")
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing"
)

tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

sample = tokenized_dataset["train"][0]
for key, value in sample.items():
    if torch.is_tensor(value):
        print(f"{key}: shape {value.shape}, dtype {value.dtype}")
    else:
        print(f"{key}: {type(value)}")


Preprocessing datasets...


Tokenizing: 100%|██████████| 5000/5000 [00:01<00:00, 4462.85 examples/s]
Tokenizing: 100%|██████████| 500/500 [00:00<00:00, 5020.44 examples/s]
Tokenizing: 100%|██████████| 500/500 [00:00<00:00, 5010.28 examples/s]

input_ids: shape torch.Size([272]), dtype torch.int64
attention_mask: shape torch.Size([272]), dtype torch.int64
labels: shape torch.Size([219]), dtype torch.int64





In [5]:
def find_invalid_tokens(dataset_split, split_name=""):
    invalid_count = 0
    print(f"\nChecking {split_name} set...")
    for i in range(min(100, len(dataset_split))):
        labels = dataset_split[i]["labels"]
        for token_id in labels:
            if token_id != -100 and (token_id >= tokenizer.vocab_size or token_id < 0):
                print(f"❌ Invalid token ID {token_id} at index {i}")
                invalid_count += 1
                break
    print(f"Invalid tokens in {split_name}: {invalid_count}")
    return invalid_count

find_invalid_tokens(tokenized_dataset["train"], "train")
find_invalid_tokens(tokenized_dataset["validation"], "validation")
find_invalid_tokens(tokenized_dataset["test"], "test")


Checking train set...
Invalid tokens in train: 0

Checking validation set...
Invalid tokens in validation: 0

Checking test set...
Invalid tokens in test: 0


0

In [6]:
print("Loading model...")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
model = model.to(device)

Loading model...


In [7]:
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./mt5-npi-en",
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    generation_max_length=300,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=False,
    use_cpu=True,
    dataloader_pin_memory=False,
    dataloader_num_workers=0,
    remove_unused_columns=False,
    gradient_checkpointing=False,
    report_to=None,
)

In [8]:
from transformers import Seq2SeqTrainer

print("Testing data collator...")
test_batch = [tokenized_dataset["train"][i] for i in range(2)]
collated = data_collator(test_batch)
print("Data collator working:")
for key, value in collated.items():
    print(f"  {key}: {value.shape}")

# %%
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)


Testing data collator...
Data collator working:
  input_ids: torch.Size([2, 272])
  attention_mask: torch.Size([2, 272])
  labels: torch.Size([2, 219])
  decoder_input_ids: torch.Size([2, 219])


  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


In [9]:
print("Starting training on CPU...")
trainer.train()
print("Training completed!")

# %%
trainer.save_model("./mt5-npi-en")
tokenizer.save_pretrained("./mt5-npi-en")

# %%
test_results = trainer.evaluate(eval_dataset=tokenized_dataset["test"])
print("Test Evaluation:", test_results)

Starting training on CPU...


  1%|▏         | 100/7500 [02:18<2:45:52,  1.34s/it]

{'loss': 18.1087, 'grad_norm': 584.6068115234375, 'learning_rate': 0.00019733333333333335, 'epoch': 0.04}


  3%|▎         | 200/7500 [04:34<2:47:07,  1.37s/it]

{'loss': 8.927, 'grad_norm': 129.20431518554688, 'learning_rate': 0.0001946666666666667, 'epoch': 0.08}


  4%|▍         | 300/7500 [06:50<2:50:27,  1.42s/it]

{'loss': 5.8163, 'grad_norm': 20.3961238861084, 'learning_rate': 0.000192, 'epoch': 0.12}


  5%|▌         | 400/7500 [09:04<2:27:52,  1.25s/it]

{'loss': 5.0245, 'grad_norm': 13.711023330688477, 'learning_rate': 0.00018933333333333335, 'epoch': 0.16}


  7%|▋         | 500/7500 [11:22<2:34:27,  1.32s/it]

{'loss': 4.5155, 'grad_norm': 9.968568801879883, 'learning_rate': 0.0001866666666666667, 'epoch': 0.2}


  8%|▊         | 600/7500 [13:42<2:39:13,  1.38s/it]

{'loss': 4.5668, 'grad_norm': 14.477134704589844, 'learning_rate': 0.00018400000000000003, 'epoch': 0.24}


  9%|▉         | 700/7500 [15:56<2:29:58,  1.32s/it]

{'loss': 4.2089, 'grad_norm': 9.634683609008789, 'learning_rate': 0.00018133333333333334, 'epoch': 0.28}


 11%|█         | 800/7500 [18:10<2:33:27,  1.37s/it]

{'loss': 3.9636, 'grad_norm': 8.57861328125, 'learning_rate': 0.00017866666666666668, 'epoch': 0.32}


 12%|█▏        | 900/7500 [20:24<2:26:35,  1.33s/it]

{'loss': 4.0121, 'grad_norm': 10.82735824584961, 'learning_rate': 0.00017600000000000002, 'epoch': 0.36}


 13%|█▎        | 1000/7500 [22:38<2:17:58,  1.27s/it]

{'loss': 4.0493, 'grad_norm': 22.591039657592773, 'learning_rate': 0.00017333333333333334, 'epoch': 0.4}


 15%|█▍        | 1100/7500 [24:50<2:14:37,  1.26s/it]

{'loss': 3.8865, 'grad_norm': 12.487613677978516, 'learning_rate': 0.00017066666666666668, 'epoch': 0.44}


 16%|█▌        | 1200/7500 [27:01<2:17:32,  1.31s/it]

{'loss': 3.907, 'grad_norm': 9.294699668884277, 'learning_rate': 0.000168, 'epoch': 0.48}


 17%|█▋        | 1300/7500 [29:13<2:06:53,  1.23s/it]

{'loss': 3.8486, 'grad_norm': 4.731324195861816, 'learning_rate': 0.00016533333333333333, 'epoch': 0.52}


 19%|█▊        | 1400/7500 [31:25<2:01:39,  1.20s/it]

{'loss': 3.8787, 'grad_norm': 6.88698148727417, 'learning_rate': 0.00016266666666666667, 'epoch': 0.56}


 20%|██        | 1500/7500 [33:34<2:09:02,  1.29s/it]

{'loss': 3.7106, 'grad_norm': 10.443482398986816, 'learning_rate': 0.00016, 'epoch': 0.6}


 21%|██▏       | 1600/7500 [35:47<2:16:22,  1.39s/it]

{'loss': 3.6737, 'grad_norm': 7.66888952255249, 'learning_rate': 0.00015733333333333333, 'epoch': 0.64}


 23%|██▎       | 1700/7500 [37:54<2:06:41,  1.31s/it]

{'loss': 3.7005, 'grad_norm': 9.183389663696289, 'learning_rate': 0.00015466666666666667, 'epoch': 0.68}


 24%|██▍       | 1800/7500 [40:06<2:06:25,  1.33s/it]

{'loss': 3.6089, 'grad_norm': 7.709086894989014, 'learning_rate': 0.000152, 'epoch': 0.72}


 25%|██▌       | 1900/7500 [42:16<2:01:54,  1.31s/it]

{'loss': 3.6886, 'grad_norm': 10.056556701660156, 'learning_rate': 0.00014933333333333335, 'epoch': 0.76}


 27%|██▋       | 2000/7500 [44:27<1:55:50,  1.26s/it]

{'loss': 3.3497, 'grad_norm': 8.410666465759277, 'learning_rate': 0.00014666666666666666, 'epoch': 0.8}


 28%|██▊       | 2100/7500 [46:41<2:08:54,  1.43s/it]

{'loss': 3.4985, 'grad_norm': 7.419857501983643, 'learning_rate': 0.000144, 'epoch': 0.84}


 29%|██▉       | 2200/7500 [48:50<1:48:43,  1.23s/it]

{'loss': 3.6096, 'grad_norm': 9.422581672668457, 'learning_rate': 0.00014133333333333334, 'epoch': 0.88}


 31%|███       | 2300/7500 [51:00<2:02:39,  1.42s/it]

{'loss': 3.4806, 'grad_norm': 8.98550033569336, 'learning_rate': 0.00013866666666666669, 'epoch': 0.92}


 32%|███▏      | 2400/7500 [53:09<1:48:36,  1.28s/it]

{'loss': 3.3922, 'grad_norm': 13.609254837036133, 'learning_rate': 0.00013600000000000003, 'epoch': 0.96}


 33%|███▎      | 2500/7500 [55:17<1:48:20,  1.30s/it]

{'loss': 3.576, 'grad_norm': 5.983497619628906, 'learning_rate': 0.00013333333333333334, 'epoch': 1.0}


                                                     
 33%|███▎      | 2500/7500 [56:34<1:48:20,  1.30s/it]

{'eval_loss': 2.728579521179199, 'eval_runtime': 74.315, 'eval_samples_per_second': 6.728, 'eval_steps_per_second': 3.364, 'epoch': 1.0}


 35%|███▍      | 2600/7500 [58:46<1:44:15,  1.28s/it] 

{'loss': 3.0705, 'grad_norm': 6.1676483154296875, 'learning_rate': 0.00013066666666666668, 'epoch': 1.04}


 36%|███▌      | 2700/7500 [1:00:56<1:46:18,  1.33s/it]

{'loss': 3.2107, 'grad_norm': 8.644936561584473, 'learning_rate': 0.00012800000000000002, 'epoch': 1.08}


 37%|███▋      | 2800/7500 [1:03:12<1:52:09,  1.43s/it]

{'loss': 3.0126, 'grad_norm': 13.18329906463623, 'learning_rate': 0.00012533333333333334, 'epoch': 1.12}


 39%|███▊      | 2900/7500 [1:05:25<1:45:20,  1.37s/it]

{'loss': 3.0937, 'grad_norm': 9.856952667236328, 'learning_rate': 0.00012266666666666668, 'epoch': 1.16}


 40%|████      | 3000/7500 [1:07:31<1:37:08,  1.30s/it]

{'loss': 3.1901, 'grad_norm': 9.18731689453125, 'learning_rate': 0.00012, 'epoch': 1.2}


 41%|████▏     | 3100/7500 [1:09:46<1:36:25,  1.31s/it]

{'loss': 3.2024, 'grad_norm': 8.975611686706543, 'learning_rate': 0.00011733333333333334, 'epoch': 1.24}


 43%|████▎     | 3200/7500 [1:11:58<1:42:19,  1.43s/it]

{'loss': 3.007, 'grad_norm': 6.989584922790527, 'learning_rate': 0.00011466666666666667, 'epoch': 1.28}


 44%|████▍     | 3300/7500 [1:14:12<1:27:24,  1.25s/it]

{'loss': 3.0731, 'grad_norm': 12.792412757873535, 'learning_rate': 0.00011200000000000001, 'epoch': 1.32}


 45%|████▌     | 3400/7500 [1:16:27<1:32:56,  1.36s/it]

{'loss': 2.9865, 'grad_norm': 7.648661136627197, 'learning_rate': 0.00010933333333333333, 'epoch': 1.36}


 47%|████▋     | 3500/7500 [1:18:43<1:31:28,  1.37s/it]

{'loss': 3.1524, 'grad_norm': 8.548654556274414, 'learning_rate': 0.00010666666666666667, 'epoch': 1.4}


 48%|████▊     | 3600/7500 [1:20:57<1:14:08,  1.14s/it]

{'loss': 2.9438, 'grad_norm': 12.487067222595215, 'learning_rate': 0.00010400000000000001, 'epoch': 1.44}


 49%|████▉     | 3700/7500 [1:23:04<1:15:57,  1.20s/it]

{'loss': 3.0524, 'grad_norm': 9.303947448730469, 'learning_rate': 0.00010133333333333335, 'epoch': 1.48}


 51%|█████     | 3800/7500 [1:25:14<1:27:44,  1.42s/it]

{'loss': 2.8603, 'grad_norm': 6.337976932525635, 'learning_rate': 9.866666666666668e-05, 'epoch': 1.52}


 52%|█████▏    | 3900/7500 [1:27:27<1:12:54,  1.22s/it]

{'loss': 2.8768, 'grad_norm': 6.475668430328369, 'learning_rate': 9.6e-05, 'epoch': 1.56}


 53%|█████▎    | 4000/7500 [1:29:41<1:19:37,  1.37s/it]

{'loss': 2.8738, 'grad_norm': 7.498484134674072, 'learning_rate': 9.333333333333334e-05, 'epoch': 1.6}


 55%|█████▍    | 4100/7500 [1:31:56<1:16:08,  1.34s/it]

{'loss': 2.9442, 'grad_norm': 10.44591999053955, 'learning_rate': 9.066666666666667e-05, 'epoch': 1.64}


 56%|█████▌    | 4200/7500 [1:34:07<1:04:45,  1.18s/it]

{'loss': 3.0005, 'grad_norm': 6.114720821380615, 'learning_rate': 8.800000000000001e-05, 'epoch': 1.68}


 57%|█████▋    | 4300/7500 [1:36:15<1:04:00,  1.20s/it]

{'loss': 2.9291, 'grad_norm': 7.085155010223389, 'learning_rate': 8.533333333333334e-05, 'epoch': 1.72}


 59%|█████▊    | 4400/7500 [1:38:25<1:09:44,  1.35s/it]

{'loss': 2.9226, 'grad_norm': 5.681437015533447, 'learning_rate': 8.266666666666667e-05, 'epoch': 1.76}


 60%|██████    | 4500/7500 [1:40:37<1:08:52,  1.38s/it]

{'loss': 3.0284, 'grad_norm': 7.417647361755371, 'learning_rate': 8e-05, 'epoch': 1.8}


 61%|██████▏   | 4600/7500 [1:42:52<59:39,  1.23s/it]  

{'loss': 3.0275, 'grad_norm': 5.647931098937988, 'learning_rate': 7.733333333333333e-05, 'epoch': 1.84}


 63%|██████▎   | 4700/7500 [1:45:02<1:03:29,  1.36s/it]

{'loss': 2.947, 'grad_norm': 8.14724349975586, 'learning_rate': 7.466666666666667e-05, 'epoch': 1.88}


 64%|██████▍   | 4800/7500 [1:47:10<57:40,  1.28s/it]  

{'loss': 2.9376, 'grad_norm': 7.083330154418945, 'learning_rate': 7.2e-05, 'epoch': 1.92}


 65%|██████▌   | 4900/7500 [1:49:17<55:24,  1.28s/it]  

{'loss': 2.9183, 'grad_norm': 8.063425064086914, 'learning_rate': 6.933333333333334e-05, 'epoch': 1.96}


 67%|██████▋   | 5000/7500 [1:51:24<52:25,  1.26s/it]

{'loss': 2.891, 'grad_norm': 7.5943603515625, 'learning_rate': 6.666666666666667e-05, 'epoch': 2.0}


                                                     
 67%|██████▋   | 5000/7500 [1:52:38<52:25,  1.26s/it]

{'eval_loss': 2.527980089187622, 'eval_runtime': 70.8227, 'eval_samples_per_second': 7.06, 'eval_steps_per_second': 3.53, 'epoch': 2.0}


 68%|██████▊   | 5100/7500 [1:54:44<50:56,  1.27s/it]   

{'loss': 2.6576, 'grad_norm': 6.736379146575928, 'learning_rate': 6.400000000000001e-05, 'epoch': 2.04}


 69%|██████▉   | 5200/7500 [1:56:53<45:04,  1.18s/it]

{'loss': 2.8857, 'grad_norm': 6.556136608123779, 'learning_rate': 6.133333333333334e-05, 'epoch': 2.08}


 71%|███████   | 5300/7500 [1:59:01<50:18,  1.37s/it]

{'loss': 2.7414, 'grad_norm': 7.3391923904418945, 'learning_rate': 5.866666666666667e-05, 'epoch': 2.12}


 72%|███████▏  | 5400/7500 [2:01:17<44:26,  1.27s/it]  

{'loss': 2.6971, 'grad_norm': 9.54862117767334, 'learning_rate': 5.6000000000000006e-05, 'epoch': 2.16}


 73%|███████▎  | 5500/7500 [2:03:27<39:21,  1.18s/it]

{'loss': 2.741, 'grad_norm': 10.134439468383789, 'learning_rate': 5.333333333333333e-05, 'epoch': 2.2}


 75%|███████▍  | 5600/7500 [2:05:44<39:11,  1.24s/it]  

{'loss': 2.6255, 'grad_norm': 8.989924430847168, 'learning_rate': 5.0666666666666674e-05, 'epoch': 2.24}


 76%|███████▌  | 5700/7500 [2:07:55<39:05,  1.30s/it]

{'loss': 2.7299, 'grad_norm': 6.944118022918701, 'learning_rate': 4.8e-05, 'epoch': 2.28}


 77%|███████▋  | 5800/7500 [2:10:06<40:32,  1.43s/it]

{'loss': 2.5993, 'grad_norm': 8.804169654846191, 'learning_rate': 4.5333333333333335e-05, 'epoch': 2.32}


 79%|███████▊  | 5900/7500 [2:12:20<35:23,  1.33s/it]

{'loss': 2.5976, 'grad_norm': 8.313344955444336, 'learning_rate': 4.266666666666667e-05, 'epoch': 2.36}


 80%|████████  | 6000/7500 [2:14:32<32:36,  1.30s/it]

{'loss': 2.7149, 'grad_norm': 8.762036323547363, 'learning_rate': 4e-05, 'epoch': 2.4}


 81%|████████▏ | 6100/7500 [2:16:44<27:38,  1.18s/it]

{'loss': 2.6366, 'grad_norm': 8.457252502441406, 'learning_rate': 3.733333333333334e-05, 'epoch': 2.44}


 83%|████████▎ | 6200/7500 [2:18:53<26:47,  1.24s/it]

{'loss': 2.6428, 'grad_norm': 9.981436729431152, 'learning_rate': 3.466666666666667e-05, 'epoch': 2.48}


 84%|████████▍ | 6300/7500 [2:21:02<26:46,  1.34s/it]

{'loss': 2.7028, 'grad_norm': 5.470043182373047, 'learning_rate': 3.2000000000000005e-05, 'epoch': 2.52}


 85%|████████▌ | 6400/7500 [2:23:14<24:29,  1.34s/it]

{'loss': 2.7358, 'grad_norm': 9.647879600524902, 'learning_rate': 2.9333333333333336e-05, 'epoch': 2.56}


 87%|████████▋ | 6500/7500 [2:25:23<22:57,  1.38s/it]

{'loss': 2.7969, 'grad_norm': 6.494241714477539, 'learning_rate': 2.6666666666666667e-05, 'epoch': 2.6}


 88%|████████▊ | 6600/7500 [2:27:41<20:49,  1.39s/it]

{'loss': 2.5131, 'grad_norm': 5.925939559936523, 'learning_rate': 2.4e-05, 'epoch': 2.64}


 89%|████████▉ | 6700/7500 [2:29:58<17:42,  1.33s/it]

{'loss': 2.7586, 'grad_norm': 2.963557720184326, 'learning_rate': 2.1333333333333335e-05, 'epoch': 2.68}


 91%|█████████ | 6800/7500 [2:32:10<15:04,  1.29s/it]

{'loss': 2.6477, 'grad_norm': 7.630964279174805, 'learning_rate': 1.866666666666667e-05, 'epoch': 2.72}


 92%|█████████▏| 6900/7500 [2:34:42<15:55,  1.59s/it]

{'loss': 2.6822, 'grad_norm': 11.243441581726074, 'learning_rate': 1.6000000000000003e-05, 'epoch': 2.76}


 93%|█████████▎| 7000/7500 [2:37:17<11:17,  1.35s/it]

{'loss': 2.6746, 'grad_norm': 5.420026779174805, 'learning_rate': 1.3333333333333333e-05, 'epoch': 2.8}


 95%|█████████▍| 7100/7500 [2:39:39<09:33,  1.43s/it]

{'loss': 2.7894, 'grad_norm': 9.089927673339844, 'learning_rate': 1.0666666666666667e-05, 'epoch': 2.84}


 96%|█████████▌| 7200/7500 [2:42:00<07:00,  1.40s/it]

{'loss': 2.8648, 'grad_norm': 8.861496925354004, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.88}


 97%|█████████▋| 7300/7500 [2:44:13<04:18,  1.29s/it]

{'loss': 2.5712, 'grad_norm': 9.662216186523438, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.92}


 99%|█████████▊| 7400/7500 [2:46:29<02:19,  1.39s/it]

{'loss': 2.6449, 'grad_norm': 11.283933639526367, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.96}


100%|██████████| 7500/7500 [2:48:42<00:00,  1.41s/it]

{'loss': 2.6371, 'grad_norm': 9.602225303649902, 'learning_rate': 0.0, 'epoch': 3.0}


                                                     
100%|██████████| 7500/7500 [2:49:56<00:00,  1.36s/it]


{'eval_loss': 2.466721534729004, 'eval_runtime': 71.5197, 'eval_samples_per_second': 6.991, 'eval_steps_per_second': 3.496, 'epoch': 3.0}
{'train_runtime': 10196.5879, 'train_samples_per_second': 1.471, 'train_steps_per_second': 0.736, 'train_loss': 3.472575948079427, 'epoch': 3.0}
Training completed!


100%|██████████| 250/250 [01:18<00:00,  3.18it/s]

Test Evaluation: {'eval_loss': 2.435408592224121, 'eval_runtime': 78.9327, 'eval_samples_per_second': 6.335, 'eval_steps_per_second': 3.167, 'epoch': 3.0}





In [10]:
trainer.save_model("./mt5-npi-en")
tokenizer.save_pretrained("./mt5-npi-en")

('./mt5-npi-en/tokenizer_config.json',
 './mt5-npi-en/special_tokens_map.json',
 './mt5-npi-en/spiece.model',
 './mt5-npi-en/added_tokens.json')

In [11]:
test_results = trainer.evaluate(eval_dataset=tokenized_dataset["test"])
print("Test Evaluation:", test_results)


100%|██████████| 250/250 [01:06<00:00,  3.77it/s]

Test Evaluation: {'eval_loss': 2.435408592224121, 'eval_runtime': 66.5632, 'eval_samples_per_second': 7.512, 'eval_steps_per_second': 3.756, 'epoch': 3.0}





In [12]:
from tqdm import tqdm
import torch

bleu = evaluate.load("bleu")

# For single-GPU CPU/accelerator inference
model.eval()

batch_size = 8  # You can tune this
sources = dataset["test"]["source"]
references = dataset["test"]["target"]

predictions = []

for i in tqdm(range(0, len(sources), batch_size), desc="Generating translations"):
    batch_src = sources[i:i + batch_size]
    batch_inputs = ["translate Nepali to English: " + s for s in batch_src]
    
    inputs = tokenizer(
        batch_inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=300
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=300,
            num_beams=4
        )
    
    batch_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    predictions.extend(batch_preds)

# Print a few samples to check translation quality
for i in range(10):
    print(f"\nSource    : {sources[i]}")
    print(f"Reference : {references[i]}")
    print(f"Predicted : {predictions[i]}")

bleu_score = bleu.compute(
    predictions=[p.strip() for p in predictions],
    references=[[r.strip()] for r in references]
)
print("Test BLEU:", bleu_score["bleu"])



Generating translations: 100%|██████████| 63/63 [10:00<00:00,  9.54s/it]


Source    : यसबाट युवकहरु पनि पीडित भएको देखिन्छन ।
Reference : Young people also suffer from it.
Predicted : It has not been affected by young people.

Source    : बजाऊने लिस्ट
Reference : Playlist
Predicted : The Preview of Music

Source    : त्यहाँ केही नियमहरू छन् जुन तपाईंले पछ्याउनु पर्छः
Reference : There are a few rules that you should follow:
Predicted : There are some rules that you need to follow:

Source    : सबै गीतहरु उनी आफैंले लेखेका हुन् ।
Reference : All songs were written by himself.
Predicted : He wrote all songs he wrote.

Source    : मेरो एउटा जर्मन साथी छ ।
Reference : I had a German friend.
Predicted : I am a French man.

Source    : तपाईँ वास्तवमै फाइल मेट्न चाहनुहुन्छ?

Reference : Do you really want to delete file ?

Predicted : You want to delete a file?

Source    : यस कुरालाई लिएर आक्रोशित हुनुपर्ने कुनै आवश्यकता छैन ।
Reference : There is no need to get upset about this.
Predicted : There is no need to understand this.

Source    : पुरुष र महिला दुबैले य


