Lets see whether simple mt5 model overfits in small data samples

In [1]:
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          Seq2SeqTrainer,
                          Seq2SeqTrainingArguments,
                          DataCollatorForSeq2Seq
                          )
from datasets import load_dataset
import evaluate
import numpy as np
import torch
import warnings
import wandb
warnings.filterwarnings("ignore")
import random
# Set all seeds for reproducibility
random.seed(100)
np.random.seed(100)
torch.manual_seed(100)
torch.cuda.manual_seed_all(100)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load aryal's dataset from hf
# ds = load_dataset("sumitaryal/nepali_grammatical_error_correction")

In [3]:
# select randomly few samples from train 
# split further into train and valid dataset
# small_dataset = ds["train"].shuffle(seed=42).select(range(125))
# small_dataset = small_dataset.train_test_split(test_size=0.2, seed=42)
# small_dataset["valid"] = small_dataset["test"] # Rename the split in the DatasetDict
# del small_dataset["test"]
# small_dataset

In [4]:
from datasets import load_dataset

small_dataset = load_dataset(
    "csv",
    data_files="multi_seman.txt",
    sep=", ",         # use "," if comma-separated
    column_names=["incorrect_sentence", "correct_sentence"]
)
small_dataset["train"].shuffle(seed=42)



Dataset({
    features: ['incorrect_sentence', 'correct_sentence'],
    num_rows: 40
})

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_ckpt = "google/mt5-base"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [6]:
prefix = "वाक्य सच्याउनुहोस्: "

def preprocess(batch):
    
    inputs = [prefix + inp for inp in batch["incorrect_sentence"]]

    # tokenize input (incorrect)
    input_encodings = tokenizer(
        inputs, 
        max_length=128,
        truncation=True 
    )
    # tokenize target (correct)
    with tokenizer.as_target_tokenizer():
        target_encodings = tokenizer(
            batch["correct_sentence"], 
            max_length=128,
            truncation=True
        )

    # set labels for seq2seq training                           # for seq2deq models, the "labels" are the token IDs of the target sequence
    input_encodings["labels"] = target_encodings["input_ids"]   

    return input_encodings

dataset_encoded = small_dataset.map(preprocess, batched=True) 


In [7]:
#pytorch model expects in tensor format
dataset_encoded.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

In [8]:
import evaluate
import numpy as np

# Load metrics once
bleu_metric = evaluate.load("bleu")
chrf_metric = evaluate.load("chrf")
bertscore_metric = evaluate.load("bertscore")

def compute_metrics(eval_pred):
    """
    Compute BLEU, chrF, Correction Accuracy, and BERTScore for Nepali GEC.
    Handles both token IDs and plain text predictions.
    """
    predictions, labels = eval_pred

    # --- Handle tuple outputs (e.g., logits + labels) ---
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # --- If preds/labels are lists of strings, skip decoding ---
    if isinstance(predictions[0], str) and isinstance(labels[0], str):
        preds_clean = [p.strip() for p in predictions]
        refs_clean = [r.strip() for r in labels]
    else:
        # Convert to numpy arrays
        predictions = np.array(predictions)
        labels = np.array(labels)

        # Handle logits (vocab dimension)
        if predictions.ndim == 3:
            predictions = predictions.argmax(axis=-1)

        # Replace -100 with pad_token_id
        predictions = np.where(predictions == -100, tokenizer.pad_token_id, predictions)
        labels = np.where(labels == -100, tokenizer.pad_token_id, labels)

        # Decode
        preds = tokenizer.batch_decode(predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        refs = tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        preds_clean = [p.strip() for p in preds]
        refs_clean = [r.strip() for r in refs]

    # --- Format for metrics ---
    references = [[r] for r in refs_clean]
    metrics = {}

    # --- BLEU ---
    try:
        non_empty_indices = [i for i, (p, r) in enumerate(zip(preds_clean, refs_clean)) if p and r]
        if non_empty_indices:
            preds_bleu = [preds_clean[i] for i in non_empty_indices]
            refs_bleu = [[refs_clean[i]] for i in non_empty_indices]
            bleu_result = bleu_metric.compute(predictions=preds_bleu, references=refs_bleu)
            metrics["bleu"] = bleu_result["bleu"]
        else:
            metrics["bleu"] = 0.0
    except Exception as e:
        print(f"BLEU computation failed: {e}")
        metrics["bleu"] = 0.0

    # --- chrF ---
    try:
        chrf_result = chrf_metric.compute(predictions=preds_clean, references=refs_clean)
        metrics["chrf"] = chrf_result["score"]
    except Exception as e:
        print(f"chrF computation failed: {e}")
        metrics["chrf"] = 0.0

    # --- Correction Accuracy ---
    try:
        exact_matches = np.mean([p == r for p, r in zip(preds_clean, refs_clean)])
        metrics["correction_accuracy"] = exact_matches
    except Exception as e:
        print(f"Correction accuracy computation failed: {e}")
        metrics["correction_accuracy"] = 0.0

    # --- BERTScore ---
    try:
        non_empty_indices_bert = [i for i, (p, r) in enumerate(zip(preds_clean, refs_clean)) if p and r]
        if non_empty_indices_bert:
            preds_bert = [preds_clean[i] for i in non_empty_indices_bert]
            refs_bert = [refs_clean[i] for i in non_empty_indices_bert]
            bertscore_result = bertscore_metric.compute(
                predictions=preds_bert,
                references=refs_bert,
                lang="ne",
                model_type="microsoft/mdeberta-v3-base"
            )
            metrics["bertscore_f1"] = float(np.mean(bertscore_result["f1"]))
        else:
            metrics["bertscore_f1"] = 0.0
    except Exception as e:
        print(f"BERTScore computation failed: {e}")
        metrics["bertscore_f1"] = 0.0

    # --- Print one sample for sanity ---
    if len(preds_clean) > 0:
        print(f"🔍 Sample - Pred: '{preds_clean[0][:50]}...' | Ref: '{refs_clean[0][:50]}...' | Match: {preds_clean[0] == refs_clean[0]}")

    return metrics


In [9]:
# import torch
# import evaluate
# import numpy as np

# # Load metrics once
# bleu_metric = evaluate.load("bleu")
# chrf_metric = evaluate.load("chrf")
# bertscore_metric = evaluate.load("bertscore")

# # Minimum GPU memory (bytes) to safely run BERTScore
# MIN_BERTSCORE_GPU_FREE = 3 * 1024**3  # 3 GB, adjust if needed

# def compute_metrics(eval_pred):
#     """
#     Compute BLEU, chrF, Correction Accuracy, and BERTScore for Nepali GEC.
#     BERTScore is skipped if GPU RAM is insufficient.
#     Handles both token IDs and plain text predictions.
#     """
#     predictions, labels = eval_pred

#     # --- Handle tuple outputs (e.g., logits + labels) ---
#     if isinstance(predictions, tuple):
#         predictions = predictions[0]

#     # --- If preds/labels are lists of strings, skip decoding ---
#     if isinstance(predictions[0], str) and isinstance(labels[0], str):
#         preds_clean = [p.strip() for p in predictions]
#         refs_clean = [r.strip() for r in labels]
#     else:
#         if tokenizer is None:
#             raise ValueError("Tokenizer must be provided for decoding token IDs.")
#         predictions = np.array(predictions)
#         labels = np.array(labels)

#         if predictions.ndim == 3:
#             predictions = predictions.argmax(axis=-1)

#         predictions = np.where(predictions == -100, tokenizer.pad_token_id, predictions)
#         labels = np.where(labels == -100, tokenizer.pad_token_id, labels)

#         preds = tokenizer.batch_decode(predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
#         refs = tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)

#         preds_clean = [p.strip() for p in preds]
#         refs_clean = [r.strip() for r in refs]

#     metrics = {}

#     # --- BLEU ---
#     try:
#         non_empty_indices = [i for i, (p, r) in enumerate(zip(preds_clean, refs_clean)) if p and r]
#         if non_empty_indices:
#             preds_bleu = [preds_clean[i] for i in non_empty_indices]
#             refs_bleu = [[refs_clean[i]] for i in non_empty_indices]
#             bleu_result = bleu_metric.compute(predictions=preds_bleu, references=refs_bleu)
#             metrics["bleu"] = bleu_result["bleu"]
#         else:
#             metrics["bleu"] = 0.0
#     except Exception as e:
#         print(f"BLEU computation failed: {e}")
#         metrics["bleu"] = 0.0

#     # --- chrF ---
#     try:
#         chrf_result = chrf_metric.compute(predictions=preds_clean, references=refs_clean)
#         metrics["chrf"] = chrf_result["score"]
#     except Exception as e:
#         print(f"chrF computation failed: {e}")
#         metrics["chrf"] = 0.0

#     # --- Correction Accuracy ---
#     try:
#         exact_matches = np.mean([p == r for p, r in zip(preds_clean, refs_clean)])
#         metrics["correction_accuracy"] = exact_matches
#     except Exception as e:
#         print(f"Correction accuracy computation failed: {e}")
#         metrics["correction_accuracy"] = 0.0

#     # --- BERTScore (skip if GPU memory low) ---
#     try:
#         free_mem = torch.cuda.mem_get_info()[0] if torch.cuda.is_available() else 0
#         if free_mem >= MIN_BERTSCORE_GPU_FREE:
#             non_empty_indices_bert = [i for i, (p, r) in enumerate(zip(preds_clean, refs_clean)) if p and r]
#             if non_empty_indices_bert:
#                 preds_bert = [preds_clean[i] for i in non_empty_indices_bert]
#                 refs_bert = [refs_clean[i] for i in non_empty_indices_bert]
#                 bertscore_result = bertscore_metric.compute(
#                     predictions=preds_bert,
#                     references=refs_bert,
#                     lang="ne",
#                     model_type="microsoft/mdeberta-v3-base"
#                 )
#                 metrics["bertscore_f1"] = float(np.mean(bertscore_result["f1"]))
#             else:
#                 metrics["bertscore_f1"] = 0.0
#         else:
#             print(f"⚠️ Skipping BERTScore: free GPU memory {free_mem / 1024**3:.2f} GB < required {MIN_BERTSCORE_GPU_FREE / 1024**3:.1f} GB")
#             metrics["bertscore_f1"] = None
#     except Exception as e:
#         print(f"BERTScore computation failed: {e}")
#         metrics["bertscore_f1"] = None

#     # --- Print one sample for sanity ---
#     if len(preds_clean) > 0:
#         print(f"🔍 Sample - Pred: '{preds_clean[0][:50]}...' | Ref: '{refs_clean[0][:50]}...' | Match: {preds_clean[0] == refs_clean[0]}")

#     return metrics


In [10]:
preds = ["मेरो नाम सन्तोष हो ।", "म स्कुल जान्छु ।", "म खाना खान्छु ।"]
refs  = ["मेरो नाम सन्तोष हो ।", "म स्कुल जान्छु ।", "म खाना खान्छु ।"]
compute_metrics((preds, refs))


🔍 Sample - Pred: 'मेरो नाम सन्तोष हो ।...' | Ref: 'मेरो नाम सन्तोष हो ।...' | Match: True


{'bleu': 1.0,
 'chrf': 100.0,
 'correction_accuracy': np.float64(1.0),
 'bertscore_f1': 1.0}

In [11]:
import gc
import torch

# del model       # or del comet_model
gc.collect()
# torch.cuda.empty_cache()

1327

In [12]:

seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
wandb.finish()
wandb.init(project="gec_overfit")
run_id = wandb.run.id

training_args = Seq2SeqTrainingArguments(
    output_dir="../outputs",
    num_train_epochs=30,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=5e-5,
    logging_steps=1,
    eval_strategy="epoch",
    save_strategy="no",
    report_to="wandb"
)

[34m[1mwandb[0m: Currently logged in as: [33mlsumit008[0m ([33mlsumit008-khwopa-college-of-engineering[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["train"].select(range(15)),  # same dataset for overfitting here
    tokenizer=tokenizer,
    data_collator=seq2seq_data_collator,
    compute_metrics=compute_metrics
      
)
trainer.evaluate()

trainer.train()



🔍 Sample - Pred: '<extra_id_0>  कोन))्द <extra_id_56>ई <extra_id_40>...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False


Epoch,Training Loss,Validation Loss,Model Preparation Time,Bleu,Chrf,Correction Accuracy,Bertscore F1
1,12.0815,7.84936,0.0035,0.0,6.192697,0.0,0.422561
2,12.3766,6.568001,0.0035,0.0,7.896298,0.0,0.458113
3,12.5288,5.910247,0.0035,0.0,10.242493,0.0,0.464395
4,8.6122,4.559451,0.0035,0.0,17.159604,0.0,0.480479
5,6.1368,3.623879,0.0035,0.0,22.11168,0.0,0.508076
6,4.0513,2.754323,0.0035,0.013539,31.526172,0.0,0.578632
7,4.134,2.143264,0.0035,0.022071,40.162586,0.0,0.637022
8,4.0628,1.976315,0.0035,0.03327,44.938007,0.0,0.658077
9,1.8604,1.72937,0.0035,0.043778,47.942412,0.0,0.686681
10,3.4461,1.614933,0.0035,0.057083,51.747946,0.0,0.698935


🔍 Sample - Pred: '<extra_id_0>....बाट),्द... <extra_id_2> <extra_id_...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False
🔍 Sample - Pred: '<extra_id_0>.ईकोबीच  ्दितचीुर <extra_id_44>!नीन्जि...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False
🔍 Sample - Pred: '<extra_id_0>य ई य।्दीरी <extra_id_43>लोुर ध्वनीन्ज...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False
🔍 Sample - Pred: '<extra_id_0>य ईरहेको्दछभरछान्तको <extra_id_11>मधुर...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False
🔍 Sample - Pred: '<extra_id_0>र ः नै्दछमाान्तरमधुर ध्वनी गुन्जिन्छने...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False
🔍 Sample - Pred: '<extra_id_0>यी दुई र मिल ्दछमा शान्तको मधुर ध्वनी ...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False
🔍 Sample - Pred: '<extra_id_0>यी दुई हात मिल मिल्दछमा शान्तको सुमधुर...' | Ref: 'यी दुई हातहरू

TrainOutput(global_step=600, training_loss=4.047622843682766, metrics={'train_runtime': 582.8035, 'train_samples_per_second': 2.059, 'train_steps_per_second': 1.03, 'total_flos': 91019779209216.0, 'train_loss': 4.047622843682766, 'epoch': 30.0})

In [14]:
import gc
import torch

# del model       # or del comet_model
gc.collect()
torch.cuda.empty_cache()

Testing with metrics

In [15]:
# testing with same train dataset for overfit model in order to check whether metrics function works. 
trainer.evaluate(dataset_encoded["train"])



🔍 Sample - Pred: 'बियी दुई हात मिल मिल्द विश्वमा शान्तिको सुमधुर ध्व...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False


{'eval_loss': 0.9358439445495605,
 'eval_model_preparation_time': 0.0035,
 'eval_bleu': 0.23839540672661175,
 'eval_chrf': 75.09730452317531,
 'eval_correction_accuracy': 0.025,
 'eval_bertscore_f1': 0.8259538620710373,
 'eval_runtime': 13.4613,
 'eval_samples_per_second': 2.971,
 'eval_steps_per_second': 1.486,
 'epoch': 30.0}

Inference

In [16]:
def correct_grammar_simple(text):
    # Add task prefix (use the same format as during training)
    input_text = f"वाक्य सुधार्नुहोस्: {text}"
    
    # Tokenize
    inputs = tokenizer(
        input_text,
        return_tensors = "pt",
        truncation = True,
        padding=False
    ).to(device)
    
    # Generate correction
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=128,
            # num_beams=5,
            # repetition_penalty=2.5,
            # length_penalty=1.0,
            # temperature=0.8
        )
        
    # Decode output
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected_text

# Test
test_sentence = "नगरपालिका कस्तो किसिमको पर्यटक ल्याउन सक्छे "
corrected = correct_grammar_simple(test_sentence)
print(f"Original: {test_sentence}")
print(f"Corrected: {corrected}")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Original: नगरपालिका कस्तो किसिमको पर्यटक ल्याउन सक्छे 
Corrected: <extra_id_0> नगरपालिका कस्ता किसिमको पर्यटक ल्याउन सक्छ ?


In [17]:
def correct_batch(texts, batch_size=8):
    """
    Correct grammar for multiple sentences
    """
    corrected_texts = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        
        # Add prefix to each text
        input_texts = [f"वाक्य सुधार्नुहोस्: {text}" for text in batch_texts]
        
        
    
        # Tokenize
        inputs = tokenizer(
            input_texts,
            return_tensors = "pt",
            truncation = True,
            padding=True
        ).to(device)
        
        # Generate correction
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                # attention_mask=inputs.attention_mask,
                max_length=128,
                num_beams=5,
                repetition_penalty=2.5
            )
            
        # Decode batch
        batch_corrected = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        corrected_texts.extend(batch_corrected)
        
    return corrected_texts
        
    
test_sentences = small_dataset["train"]["incorrect_sentence"][:]
labels = small_dataset["train"]["correct_sentence"][:]
corrected_sentences = correct_batch(test_sentences)
for orig, corr, lab in zip(test_sentences, corrected_sentences, labels):
    print(f"Original:  {orig}")
    print(f"Corrected: {corr}")
    print(f"label:     {lab}")
    print("---")

Original:  यी दुई हात मिल्दछ विश्वमा शान्तीको सुमधुर ध्वनी गुन्जिन्छे
Corrected: यी दुई हात मिल्दा विश्वमा शान्तिको सुमधुर ध्वनी गुञ्जिनेछ।
label:     यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि गुञ्जिनेछ ।
---
Original:  फुटबल छुनै दिन्न थिएन।
Corrected: <extra_id_0> थिएन।
label:     फुटबल छुनै दिँदैन थिए ।
---
Original:  मलाई भन्न त धेरै मन हो
Corrected: बिहान मलाई भन्न त धेरै मन छ ।
label:     मलाई भन्न त धेरै मन छ।
---
Original:  नेक्सनमा ६ स्पिड गेअर बम्स रहको हो।
Corrected: ६ स्पिड गेअर बम्स रहन्छ।
label:     नेक्सनमा ६ स्पीड गेयर बम्स रहेको छ ।
---
Original:  सिलाइ कटाई तालिम लिएको व्यवसाय चलाउ योजना बनाइ
Corrected: <extra_id_0> व्यवसाय चलाउने योजना बनाइएको छ ।
label:     सिलाइकटाइ तालिम लिएर व्यवसाय चलाउने योजना बनाइन् ।
---
Original:  काठमाण्डौले दिएको ९० रनको लक्ष्य ललितपुर १६.३ ओभरमा ३ विकेट हराएर पुरा गर्यो
Corrected: ललितपुरले दिएको ९० रनको लक्ष्य काठमाण्डौले १६.३ ओभरमा ३ विकेट गुमाएर पुरा गर्यो।
label:     काठमाडौंले दिएको ९० रनको लक्ष्य ललितपुरले १६.३ ओभरमा ३ विके

In [18]:
compute_metrics((corrected_sentences, labels))

🔍 Sample - Pred: 'यी दुई हात मिल्दा विश्वमा शान्तिको सुमधुर ध्वनी गु...' | Ref: 'यी दुई हातहरू मिल्दा विश्वमै शान्तिको सुमधुर ध्वनि...' | Match: False


{'bleu': 0.315693833959732,
 'chrf': 62.606904147812315,
 'correction_accuracy': np.float64(0.05),
 'bertscore_f1': 0.8399195969104767}