# Config

In [None]:
! pip install transformers sentencepiece datasets
! pip install tqdm
! pip install torch
!pip install sacrebleu
!pip install evaluate



In [None]:
from google.colab import userdata
from huggingface_hub import login
# Read the Hugging Face access token from the Colab `userdata` store under the
# key 'HF_TOKEN' instead of hard-coding it in the notebook.
hf_token = userdata.get('HF_TOKEN')
# Authenticate this session against the Hub with that token.
# NOTE(review): add_to_git_credential=True presumably also stores the token in
# the git credential helper — confirm that is wanted on shared runtimes.
login(token=hf_token, add_to_git_credential=True)

# Preparation


In [None]:
import torch
import numpy as np
import math
import random
import os
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    AutoConfig,
    TrainerCallback
)
import evaluate
from typing import Dict, List, Optional, Union
import logging
import time
from datetime import datetime
from torch.optim import AdamW

# Logging: every record goes both to training.log and to the cell output
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("training.log"), logging.StreamHandler()],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Device: use CUDA when torch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Reproducibility: seed Python, NumPy and torch with the same value
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Checkpoints -----------------------------------------------------------
# The fine-tuned model is loaded from, saved under, and pushed to the same id.
base_model_name = "google/mt5-small"
model_name = repo_id = "JMwagunda/GIR-ENG-MODEL"
output_dir = repo_id

# --- Sequence length and batching -------------------------------------------
max_length = 128
batch_size = 16
gradient_accumulation_steps = 4

# --- Optimisation ------------------------------------------------------------
learning_rate = 3e-5
weight_decay = 0.03
warmup_ratio = 0.15
max_grad_norm = 0.5  # gradient-clipping threshold
num_epochs = 60
early_stopping_patience = 5
save_total_limit = 3

# --- Languages ---------------------------------------------------------------
# These codes are interpolated into the "translate <src> to <tgt>: " prompt
# prefix during preprocessing.
# NOTE(review): preprocessing reads the source text from the
# 'Giriama Translation' column, yet the source code here is "sw"; it is kept
# as-is because changing it changes the prompt the model sees — confirm.
source_lang, target_lang = "sw", "en"

# Special language-tag tokens, keyed by language code
lang_tokens = {code: f"<{code}>" for code in ("sw", "en")}

# repo_id = "Lingua-Connect/SWA_TrainerImproved"  # Your Hub repository ID

# Load the fine-tuned checkpoint from the Hub; if that fails, start from the
# base model instead.
try:
    # Load the checkpoint's config and raise dropout for regularization.
    # NOTE(review): `dropout_rate` is the T5-family attribute name; confirm it
    # is the right knob for the architecture stored under `model_name`.
    model_config = AutoConfig.from_pretrained(model_name)
    model_config.dropout_rate = 0.2

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Pass the edited config explicitly — otherwise from_pretrained re-reads
    # the stored config and the dropout override above is silently dropped.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=model_config)
    print("Successfully loaded model and tokenizer from Hub checkpoint")

except Exception as e:
    # Deliberately broad: any failure to fetch/parse the checkpoint should end
    # up in the base-model path rather than abort the notebook.
    print(f"No checkpoint found or error loading from Hub: {e}")
    print("Loading base model instead...")

    # The fallback must target the *base* model; retrying `model_name` would
    # just repeat the failure that brought us here.
    model_config = AutoConfig.from_pretrained(base_model_name)
    model_config.dropout_rate = 0.2  # Add dropout

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # The base tokenizer does not know the language tags yet: register them
    # and grow the embedding matrix to match the new vocabulary size.
    special_tokens = {'additional_special_tokens': list(lang_tokens.values())}
    tokenizer.add_special_tokens(special_tokens)

    model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name, config=model_config)
    model.resize_token_embeddings(len(tokenizer))

# Create custom callback for monitoring and debugging
class MonitorCallback(TrainerCallback):
    """Track step duration and training loss, and flag non-finite losses.

    The Trainer only hands the ``logs`` dict to ``on_log``; ``on_step_end`` is
    called without it. Timing is therefore recorded in ``on_step_end`` and all
    loss-dependent work (NaN/Inf detection, periodic reporting) happens in
    ``on_log``.

    Attributes:
        step_times: wall-clock seconds between consecutive optimizer steps
            (trimmed to the most recent entries once it grows past 100).
        step_loss: training-loss values seen in ``on_log`` (trimmed likewise).
        last_time: timestamp of the previous step end.
    """

    # Minimum number of optimizer steps between two progress reports
    REPORT_EVERY = 50

    def __init__(self):
        self.step_times = []
        self.last_time = time.time()
        self.step_loss = []
        self._last_report_step = 0

    def on_train_begin(self, args, state, control, **kwargs):
        """Restart the step timer so set-up time is not billed to step 1."""
        self.last_time = time.time()

    def on_step_end(self, args, state, control, logs=None, **kwargs):
        """Record how long the step that just finished took."""
        current_time = time.time()
        self.step_times.append(current_time - self.last_time)
        self.last_time = current_time

    def on_log(self, args, state, control, logs=None, model=None, **kwargs):
        """Inspect a training log entry; ignore entries without a ``loss``."""
        if not logs or "loss" not in logs:
            return

        current_loss = logs["loss"]
        self.step_loss.append(current_loss)

        # Check for NaN or Inf
        if math.isnan(current_loss) or math.isinf(current_loss):
            logger.warning(f"WARNING: Abnormal loss detected: {current_loss}")

            # Use the model handed to the callback rather than a notebook
            # global that may not exist (or may be stale) when this runs.
            if model is not None:
                for name, param in model.named_parameters():
                    if torch.isnan(param).any() or torch.isinf(param).any():
                        logger.warning(f"NaN or Inf found in parameter {name}")

        # Report once at least REPORT_EVERY steps have passed since the last
        # report; this does not require logging_steps to divide 50.
        if state.global_step - self._last_report_step < self.REPORT_EVERY:
            return
        self._last_report_step = state.global_step

        recent = self.step_times[-50:]
        if recent:
            avg_step_time = sum(recent) / len(recent)
            logger.info(f"Step {state.global_step}: Avg step time = {avg_step_time:.3f}s, Loss = {current_loss:.4f}")

        # Keep the histories bounded
        if len(self.step_times) > 100:
            self.step_times = self.step_times[-50:]
        if len(self.step_loss) > 100:
            self.step_loss = self.step_loss[-50:]

        # Report memory usage if on CUDA
        if torch.cuda.is_available():
            mem_allocated = torch.cuda.memory_allocated() / 1024**2
            mem_reserved = torch.cuda.memory_reserved() / 1024**2
            logger.info(f"GPU Memory: Allocated = {mem_allocated:.1f}MB, Reserved = {mem_reserved:.1f}MB")



Successfully loaded model and tokenizer from Hub checkpoint


In [None]:
# Improved early stopping callback with more detailed logging
class DetailedEarlyStoppingCallback(EarlyStoppingCallback):
    """Early stopping that logs the monitored metric around each decision.

    The stopping logic itself is entirely the parent's; this subclass only
    adds log lines before and after delegating to it.
    """

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log the monitored metric, delegate to the parent, log the outcome.

        Args:
            args: training arguments; ``metric_for_best_model`` names the metric.
            state: trainer state; ``best_metric`` is read for logging only.
            control: trainer control; ``should_training_stop`` is read after
                the parent has run.
            metrics: evaluation metrics dict. When it is missing there is
                nothing to compare, so the call is logged and skipped.
        """
        if metrics is None:
            logger.warning("Early stopping check skipped: no evaluation metrics were provided.")
            return

        metric_to_check = args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"

        metric_value = metrics.get(metric_to_check)

        logger.info(f"Early stopping metric '{metric_to_check}' value: {metric_value}")
        logger.info(f"Best value so far: {state.best_metric}")

        # Call the parent class method
        super().on_evaluate(args, state, control, metrics, **kwargs)

        # Log the counter only after the parent has updated it; logging it
        # beforehand would always show the value from the previous evaluation.
        logger.info(f"No improvement counter: {self.early_stopping_patience_counter}")

        if control.should_training_stop:
            logger.warning("Early stopping triggered! Training will stop.")

# Preprocessing

In [None]:
# Function to load the parallel corpus and tokenize it for seq2seq training
def load_or_preprocess_data():
    """Load the English–Giriama dataset, split it 90/10 and tokenize both splits.

    Reads the module-level ``seed``, ``model_name``, ``source_lang``,
    ``target_lang`` and ``max_length``.

    Returns:
        tuple: ``(train_dataset, validation_dataset)`` whose original columns
        have been replaced by the tokenizer outputs plus ``labels``. Both
        inputs and labels are padded to ``max_length``; label padding
        positions hold -100 so they are ignored by the loss.
    """
    # Load the dataset and carve a validation split out of "train"
    ds = load_dataset('Lingua-Connect/English-Giriama-Dataset')
    split_datasets = ds["train"].train_test_split(train_size=0.9, seed=seed)
    split_datasets["validation"] = split_datasets.pop("test")

    logger.info(f"Dataset loaded: {len(split_datasets['train'])} train, {len(split_datasets['validation'])} validation")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    source_prefix = f"translate {source_lang} to {target_lang}: "

    def preprocess_function(examples):
        """Tokenize one batch, dropping rows where either side is missing."""
        # Filter source and target TOGETHER. Filtering the two columns
        # independently and then truncating to the shorter list shifts every
        # row after a missing value and pairs sentences with the wrong
        # translation.
        pairs = [
            (source_prefix + str(src), str(tgt))
            for src, tgt in zip(examples['Giriama Translation'], examples['English Sentence'])
            if src is not None and tgt is not None
        ]
        inputs = [src for src, _ in pairs]
        targets = [tgt for _, tgt in pairs]

        # Tokenize inputs
        model_inputs = tokenizer(
            inputs,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors=None
        )

        # Tokenize targets
        # NOTE(review): targets go through the same call as the inputs; for
        # tokenizers with a separate target-side vocabulary `text_target=` may
        # be the intended API — confirm against how the checkpoint was trained.
        labels = tokenizer(
            targets,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors=None
        )

        # Replace pad token id with -100 in labels so it's ignored in loss computation
        pad_id = tokenizer.pad_token_id
        model_inputs["labels"] = [
            [-100 if token == pad_id else token for token in sequence]
            for sequence in labels["input_ids"]
        ]

        return model_inputs

    def tokenize_split(split_name, description):
        """Run preprocess_function over one split, replacing its columns."""
        split = split_datasets[split_name]
        return split.map(
            preprocess_function,
            batched=True,
            batch_size=16,
            remove_columns=split.column_names,
            desc=description
        )

    # Process datasets
    logger.info("Processing datasets...")
    train_dataset = tokenize_split("train", "Preprocessing training dataset")
    validation_dataset = tokenize_split("validation", "Preprocessing validation dataset")

    return train_dataset, validation_dataset

In [None]:
# Load data
train_dataset, validation_dataset = load_or_preprocess_data()

# Load model and tokenizer
logger.info(f"Loading model: {model_name}")
model_config = AutoConfig.from_pretrained(model_name)
# NOTE(review): `dropout_rate` is the T5-family attribute name; confirm it is
# the right knob for the architecture stored under `model_name`.
model_config.dropout_rate = 0.2  # Add dropout for regularization
# Hand the edited config to from_pretrained; without `config=` the override
# above never reaches the model.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure pad_token_id is set correctly
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# # Move model to device
# model = model.to(device)
# logger.info(f"Model loaded with {model.num_parameters():,} parameters")

# Re-initialise parameters whose name contains both "decoder" and "dense".
# The match is case-sensitive, and any parameter it does match has its loaded
# weights overwritten with fresh N(0, 0.02) noise.
# NOTE(review): verify which parameters (if any) this selects for the loaded
# architecture before relying on it with a fine-tuned checkpoint.
for name, param in model.named_parameters():
    if "decoder" in name and "dense" in name:
        logger.info(f"Initializing {name} with small values")
        torch.nn.init.normal_(param, mean=0.0, std=0.02)
# Custom optimizer setup: AdamW with weight decay switched off for layer norms
def get_optimizer(model, lr, decay=None):
    """Build an AdamW optimizer with two parameter groups.

    Trainable parameters that belong to a layer-normalisation module get no
    weight decay; every other trainable parameter is decayed. Both groups use
    the same learning rate (there is no layer-wise learning-rate decay).

    Args:
        model: module whose ``named_parameters()`` are optimised.
        lr: learning rate applied to both groups.
        decay: weight-decay coefficient for the decayed group. When omitted,
            the module-level ``weight_decay`` setting is used.

    Returns:
        torch.optim.AdamW: optimizer whose ``param_groups[0]`` is the decayed
        group and ``param_groups[1]`` the non-decayed group.
    """
    if decay is None:
        decay = weight_decay

    # Architectures name their norm layers differently ("LayerNorm",
    # "layer_norm", "layernorm_embedding", ...), so compare case-insensitively
    # against both spellings instead of the single literal "LayerNorm".
    norm_markers = ("layernorm", "layer_norm")

    decay_parameters = []
    no_decay_parameters = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        lowered = name.lower()
        if any(marker in lowered for marker in norm_markers):
            no_decay_parameters.append(param)
        else:
            decay_parameters.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_parameters, "weight_decay": decay, "lr": lr},
        {"params": no_decay_parameters, "weight_decay": 0.0, "lr": lr}
    ]

    return AdamW(optimizer_grouped_parameters)

# Prepare data collator: batches examples for the trainer, padding every
# sequence to the fixed `max_length` and returning PyTorch tensors.
# The model is passed in alongside the tokenizer.
# NOTE(review): presumably the collator uses the model to build decoder inputs
# from the labels — confirm against the collator documentation.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt"
)


In [None]:
# Load the sacreBLEU metric once; compute_metrics() reads this module-level
# `metric` object each time it scores a set of predictions.
metric = evaluate.load("sacrebleu")

# Evaluation metrics

In [None]:
def compute_metrics(eval_preds):
    """Decode generated ids and references, then score them with sacreBLEU.

    Reads the module-level ``tokenizer``, ``metric`` and ``logger``.

    Args:
        eval_preds: pair ``(predictions, labels)`` of integer id arrays.
            ``predictions`` may also be a tuple whose first item is the id
            array. Neither array is modified.

    Returns:
        dict: ``{"bleu": <corpus BLEU score>, "gen_len": <mean number of
        non-pad tokens per prediction>}``, both rounded to 4 decimals. If
        anything fails, the error is printed and logged and both values are
        reported as 0.0 so that training is not interrupted.
    """
    import traceback  # error path only; not imported at file level

    preds, labels = eval_preds

    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    # Debug information - use print in addition to logger
    print(f"Prediction shape: {preds.shape}, Labels shape: {labels.shape}")
    logger.info(f"Prediction shape: {preds.shape}, Labels shape: {labels.shape}")

    try:
        pad_id = tokenizer.pad_token_id

        # len(tokenizer) includes tokens added after pre-training (e.g. the
        # language tags); tokenizer.vocab_size does not, so using it as the
        # upper bound would wipe perfectly valid added-token ids.
        vocab_size = len(tokenizer)
        print(f"Tokenizer vocabulary size: {vocab_size}")
        logger.info(f"Tokenizer vocabulary size: {vocab_size}")

        # -100 marks padding in labels and can appear in predictions too once
        # batches are combined. Map it to the pad id quietly in both arrays so
        # that the warnings below only fire for ids that are really invalid.
        # np.where returns new arrays, leaving the caller's data untouched.
        preds = np.where(preds != -100, preds, pad_id)
        labels = np.where(labels != -100, labels, pad_id)

        # Replace token IDs that are out of vocabulary range with pad token ID
        invalid_preds = (preds >= vocab_size) | (preds < 0)
        n_invalid_preds = int(np.count_nonzero(invalid_preds))
        if n_invalid_preds:
            print(f"Found {n_invalid_preds} token IDs outside vocab range. Replacing with pad token.")
            logger.warning(f"Found {n_invalid_preds} token IDs outside vocab range. Replacing with pad token.")
            preds = np.where(invalid_preds, pad_id, preds)

        invalid_labels = (labels >= vocab_size) | (labels < 0)
        n_invalid_labels = int(np.count_nonzero(invalid_labels))
        if n_invalid_labels:
            print(f"Found {n_invalid_labels} label IDs outside vocab range. Replacing with pad token.")
            logger.warning(f"Found {n_invalid_labels} label IDs outside vocab range. Replacing with pad token.")
            labels = np.where(invalid_labels, pad_id, labels)

        # Decode both sides
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Post-processing: sacreBLEU expects a list of references per prediction
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [[label.strip()] for label in decoded_labels]

        # Debug output - print some examples with both print and logger
        print("\n===== PREDICTION EXAMPLES =====")
        for i in range(min(3, len(decoded_preds))):
            print(f"Pred[{i}]: {decoded_preds[i][:100]}...")
            print(f"Label[{i}]: {decoded_labels[i][0][:100]}...")
            # flush so the examples show up promptly in the cell output
            print("-" * 50, flush=True)

            logger.info(f"Pred[{i}]: {decoded_preds[i][:100]}...")
            logger.info(f"Label[{i}]: {decoded_labels[i][0][:100]}...")

        # Compute BLEU score
        result = metric.compute(predictions=decoded_preds, references=decoded_labels)

        # Generation length = number of non-pad ids per prediction
        prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
        gen_len = float(np.mean(prediction_lens))

        # Format results
        formatted_result = {
            "bleu": round(result["score"], 4),
            "gen_len": round(gen_len, 4)
        }

        # Print final metrics
        print(f"\nMetrics: BLEU = {formatted_result['bleu']}, Gen Length = {formatted_result['gen_len']}")

        return formatted_result

    except Exception as e:
        # Deliberate best-effort: a metric failure is reported loudly but must
        # not abort a long training run.
        error_msg = f"Error in compute_metrics: {e}"
        print(error_msg)
        logger.error(error_msg)

        tb = traceback.format_exc()
        print(f"Traceback: {tb}")
        logger.error(f"Traceback: {tb}")

        # Return zeros to prevent training from crashing
        return {"bleu": 0.0, "gen_len": 0.0}

In [None]:
# Training arguments: evaluate and checkpoint once per epoch, keep the best
# checkpoint by BLEU, and push results to the Hub repo.
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    # eval_steps=100,
    save_strategy="epoch",
    # save_steps=100,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=weight_decay,
    save_total_limit=save_total_limit,
    num_train_epochs=num_epochs,
    predict_with_generate=True,
    fp16=False,  # Disable mixed precision initially for stability
    push_to_hub=True,
    hub_model_id=repo_id,
    load_best_model_at_end=True,
    metric_for_best_model="bleu",
    greater_is_better=True,
    # NOTE(review): no checkpoint path is given to trainer.train() in this
    # notebook; confirm whether this flag alone resumes a previous run.
    resume_from_checkpoint=True,
    max_grad_norm=max_grad_norm,
    gradient_accumulation_steps=gradient_accumulation_steps,
    logging_dir="./logs",
    logging_steps=10,
    generation_max_length=max_length,
    generation_num_beams=4,
    label_smoothing_factor=0.15,
    lr_scheduler_type="cosine",
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    report_to="tensorboard"
)

# Create optimizer
optimizer = get_optimizer(model, learning_rate)

# Initialize the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Actually hand the optimizer built above to the trainer; leaving this out
    # discards it and the trainer falls back to its own default optimizer.
    # The scheduler slot is None so the trainer still builds the cosine
    # schedule requested in training_args.
    optimizers=(optimizer, None),
    callbacks=[
        # Use the logging subclass defined in this notebook; the stopping
        # rule is the parent's, so training behaviour is unchanged.
        DetailedEarlyStoppingCallback(early_stopping_patience=early_stopping_patience),
        MonitorCallback()
    ]
)


  trainer = Seq2SeqTrainer(


In [None]:
# Initial evaluation: score the loaded model on the validation set before any
# further training, so the per-epoch metrics have a baseline to compare with.
print("\nRunning initial evaluation...")
# NOTE(review): max_length is presumably forwarded to generation as the cap on
# generated sequence length — confirm against the evaluate() documentation.
initial_eval_results = trainer.evaluate(max_length=max_length)
print(f"Initial evaluation results: {initial_eval_results}")


Running initial evaluation...




Prediction shape: (782, 126), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 27872 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: Jesus said to them A prophet was not born in his own country...
Label[0]: Jesus had said before that a prophet is not respected in his own country...
--------------------------------------------------
Pred[1]: Peter baptized them again He said to them I dont know him...
Label[1]: Again Peter said he was never with Jesus He said I swear to God I dont know the man...
--------------------------------------------------
Pred[2]: Peter said No If all of you leave you I will not leave you...
Label[2]: Peter answered All the other followers may lose their faith in you But my faith will never be shaken...
--------------------------------------------------

Metrics: BLEU = 14.8574, Gen Length = 47.1547
Initial evaluation results: {'eval_loss': 3.2401461601257324, 'eval_model_preparation_time': 0.0063

In [None]:
# Start training. No checkpoint path is passed here, so train() runs with its
# default arguments.
# NOTE(review): to continue from a saved checkpoint, a path (or flag) would
# presumably have to be supplied to this call — confirm.
trainer.train()

Epoch,Training Loss,Validation Loss,Model Preparation Time,Bleu,Gen Len
1,3.0741,3.191496,0.0063,15.4638,47.7455
2,3.0865,3.185565,0.0063,15.4828,47.7992
3,3.0531,3.181058,0.0063,15.3824,47.6854
4,3.0783,3.17979,0.0063,15.1285,47.1471
5,3.0436,3.177676,0.0063,15.1449,46.555
6,3.0518,3.177946,0.0063,15.3923,47.4335
7,3.0429,3.178887,0.0063,15.1473,47.6138




Prediction shape: (782, 118), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 22688 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: I want you to know how to live with me against this But you see me fighting behind me and this is wh...
Label[0]: You saw the difficulties I had to face and you hear that I am still having troubles Now you must fac...
--------------------------------------------------
Pred[1]: They began to accuse him They said to him Are you the king of the Jews...
Label[1]: Then they began shouting Welcome king of the Jews...
--------------------------------------------------
Pred[2]: Come to me about all you and that you have given to me is a burden and I will give you it...
Label[2]: Come to me all of you who are tired from the heavy burden you have been forced to carry I will give ...
--------------------------------------------------

Metrics: BLEU = 15.4638, Gen Length = 47.7455




Prediction shape: (782, 128), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 29248 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: It is the same with those who are raised from death The body that is buried is the body that will di...
Label[0]: It will be the same when those who have died are raised to life The body that is planted in the grav...
--------------------------------------------------
Pred[1]: So we sent Jesus to the place where we can find our own guard in his own rock and helped us...
Label[1]: So we should go to Jesus outside the camp and accept the same shame that he had...
--------------------------------------------------
Pred[2]: There are heavens bodies and earthly bodies The goodness of heavens is something else and the goodne...
Label[2]: Also there are heavenly bodies and earthly bodies But the beauty of the heavenly bodies is one kind ...
--------------------------------------------------

Metri



Prediction shape: (782, 128), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 30920 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: God has the power to give you more than that you will always have to have everything you need He wil...
Label[0]: And God can give you more blessings than you need and you will always have plenty of everything You ...
--------------------------------------------------
Pred[1]: My servant is a very sick and very small servant That is why he is sent and my servant is very sick...
Label[1]: The officer said Lord my servant is very sick at home in bed He cant move his body and has much pain...
--------------------------------------------------
Pred[2]: You have heard that when everyone was told When you swear before the Lord dont obey everything you h...
Label[2]: You have heard that it was said to our people long ago When you make a vow you must not break your p...
----------------------------



Prediction shape: (782, 128), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 31156 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: David himself calls the Lord He can be his son The meeting was very happy and very happy...
Label[0]: David himself calls the Messiah Lord So how can the Messiah be Davids son Many people listened to Je...
--------------------------------------------------
Pred[1]: To the Jews I became like a Jew so that I could obey the Jews I did not obey their law as if I were ...
Label[1]: To the Jews I became like a Jew so that I could help save Jews I myself am not ruled by the law but ...
--------------------------------------------------
Pred[2]: If you did not come back I would not be able to hit those who follow the teaching and the sword that...
Label[2]: So change your hearts If you dont change I will come to you quickly and fight against these people w...
---------------------------------------



Prediction shape: (782, 125), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 28948 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: But Jesus went to the Mount of Olives...
Label[0]: Jesus went to the Mount of Olives...
--------------------------------------------------
Pred[1]: They were very afraid of this So they threw away the four anchors and threw the ships into the water...
Label[1]: The sailors were afraid that we would hit the rocks so they threw four anchors into the water Then t...
--------------------------------------------------
Pred[2]: Some people said He is a good man But others said He is a follower...
Label[2]: There was a large group of people there Many of them were talking secretly to each other about Jesus...
--------------------------------------------------

Metrics: BLEU = 15.1449, Gen Length = 46.555




Prediction shape: (782, 128), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 27868 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: The second servant came and said Sir everythird of the golden feet had a golden bag...
Label[0]: The second servant said Sir with your one bag of money I earned five bags...
--------------------------------------------------
Pred[1]: But now I tell you this is not anyone who has the name of the brother Then he is a follower or a fol...
Label[1]: I wrote to you in my letter that you should not associate with people who sin sexually...
--------------------------------------------------
Pred[2]: This means Abraham had two sons One of his sons married a slave woman and another woman married her ...
Label[2]: The Scriptures say that Abraham had two sons The mother of one son was a slave woman and the mother ...
--------------------------------------------------

Metrics: BLEU = 15.3923, Gen Leng



Prediction shape: (782, 128), Labels shape: (782, 128)
Tokenizer vocabulary size: 58905
Found 29668 token IDs outside vocab range. Replacing with pad token.

===== PREDICTION EXAMPLES =====
Pred[0]: So the people were afraid of this because of Jesus...
Label[0]: So the people did not agree with each other about Jesus...
--------------------------------------------------
Pred[1]: But you teach that people will give their father or mother And they have helped you understand what ...
Label[1]: But you teach that a person can say to their father or mother I have something I could use to help y...
--------------------------------------------------
Pred[2]: God is the one who is our Savior and glory and power and power because of our Lord Jesus Christ...
Label[2]: He is the only God the one who saves us To him be glory greatness power and authority through Jesus ...
--------------------------------------------------

Metrics: BLEU = 15.1473, Gen Length = 47.6138


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].


TrainOutput(global_step=770, training_loss=3.0636291528677013, metrics={'train_runtime': 1556.0814, 'train_samples_per_second': 271.335, 'train_steps_per_second': 4.241, 'total_flos': 1669799557988352.0, 'train_loss': 3.0636291528677013, 'epoch': 7.0})