In [1]:
from datasets import load_from_disk

In [2]:
import torch
from unsloth import FastModel, is_bfloat16_supported
from trl import SFTTrainer, SFTConfig
from datasets import load_from_disk

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
# Run configuration -- everything below reads these values.

# Where the trainer writes (and resumes) its checkpoints.
# Alternative used for the 16-bit experiment: "gemma_3n_kinyarwanda_checkpoints_16Bit"
xDIR_CHECKPOINT = "gemma_3n_kinyarwanda_checkpoints"

# Number of passes over the training set.
xEPOCHS = 3

# Text report (WER/CER + reference/prediction pairs) written after evaluation.
xFile_examples = '/workspace/work/test_epoch_3_samples10000.txt'


In [4]:
# 1. Load Model & Processor
# -------------------------
# The base checkpoint is loaded with 4-bit quantized weights; LoRA adapters are
# attached in the next cell. Set load_in_4bit=False to load unquantized weights
# instead (larger VRAM footprint).
xmodel = "unsloth/gemma-3n-E4B-it"

model, processor = FastModel.from_pretrained(
    model_name = xmodel,
    load_in_4bit = True,      # 4-bit base weights (quantized); adapters stay trainable
    max_seq_length = 2048,    # audio tokens take up space, don't go too small
)


==((====))==  Unsloth 2026.2.1: Fast Gemma3N patching. Transformers: 4.57.1. vLLM: 0.11.2.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.543 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# 2. Add LoRA Adapters
# --------------------
# Projection layers of the language decoder: attention first, then MLP.
_language_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
# Audio specific layers.
# NOTE(review): "post" is a very short module name; confirm which modules it actually matches.
_audio_targets = ["post", "linear_start", "linear_end", "embedding_projection"]

model = FastModel.get_peft_model(
    model,
    r = 8,
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    finetune_vision_layers     = False,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules       = True,
    target_modules = _language_targets + _audio_targets,
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


In [6]:
# 3. Load Your Pre-processed Data
# -------------------------------
# Pre-built training subsets (uncomment one):
#xFld = '/workspace/work/AI_Training_dset/audio_gemma/train/' #   2500 samples 
#xFld = '/workspace/work/AI_Training_dset/audio_gemma/train_samples_5000/'  # 5000 samples 
#xFld = '/workspace/work/AI_Training_dset/audio_gemma/train_samples_7500/'  # 7.500 samples 
xFld = '/workspace/work/AI_Training_dset/audio_gemma/train_samples_10000/'  # 10000

# Seed for the pre-shuffle. The training cell can resume from a checkpoint, and on
# resume the trainer fast-forwards past the batches it already consumed *by position*.
# An unseeded shuffle gives a different row order after every kernel restart, so the
# resumed run would skip a different set of rows than the ones actually trained on.
xSHUFFLE_SEED = 3407

dataset = load_from_disk(xFld)
dataset = dataset.shuffle(seed = xSHUFFLE_SEED)

print(f"Loaded dataset with {len(dataset)} rows.")

dataset

Loading dataset from disk:   0%|          | 0/26 [00:00<?, ?it/s]

Loaded dataset with 10000 rows.


Dataset({
    features: ['messages'],
    num_rows: 10000
})

In [7]:
# 4. Define the Custom Data Collator (CRITICAL STEP)
# --------------------------------------------------
# Tokenizer attributes whose ids are multimodal placeholder/marker tokens. They are
# masked out of the labels so the model is not trained to predict them. Attributes
# the tokenizer does not define are skipped.
_MASKED_TOKEN_ATTRS = (
    "image_token_id", "audio_token_id",
    "boi_token_id", "eoi_token_id",
    "boa_token_id", "eoa_token_id",  # begin/end-of-audio markers, if exposed
)


def _find_audio(msgs):
    """Return the first non-None audio payload in a chat message list, or None.

    The expected location is ``msgs[0]["content"][0]["audio"]``. If that slot is
    missing or None, every content block of every message is scanned for a
    ``{"type": "audio", "audio": ...}`` entry, stopping at the first hit.
    """
    # Fast path: the layout this dataset was built with.
    try:
        first_block = msgs[0]["content"][0]
    except (KeyError, IndexError, TypeError):
        first_block = None
    if isinstance(first_block, dict) and first_block.get("audio") is not None:
        return first_block["audio"]

    # Fallback: full scan. Plain-string content cannot carry audio, so skip it.
    for m in msgs:
        content = m.get("content") if isinstance(m, dict) else None
        if not isinstance(content, (list, tuple)):
            continue
        for c in content:
            if isinstance(c, dict) and c.get("type") == "audio" and c.get("audio") is not None:
                return c["audio"]
    return None


def collate_fn(examples):
    """Turn a list of ``{"messages": [...]}`` rows into a training batch.

    Each conversation is rendered to text with the processor's chat template, and
    its audio payload is located. The whole batch is then tokenized and
    audio-encoded in a single ``processor`` call. ``labels`` is a copy of
    ``input_ids`` with padding and multimodal placeholder tokens set to -100, so
    the loss ignores them.

    Relies on the module-level ``processor`` loaded earlier in the notebook.
    """
    texts = []
    audios = []

    for example in examples:
        msgs = example["messages"]

        # 4a. Normalise roles: 'model' -> 'assistant' for the chat template.
        cleaned_msgs = [
            {"role": "assistant" if m["role"] == "model" else m["role"], "content": m["content"]}
            for m in msgs
        ]

        # 4b. Render the conversation to one string (tokenized in the batch call below).
        text = processor.apply_chat_template(
            cleaned_msgs, tokenize = False, add_generation_prompt = False
        ).strip()
        texts.append(text)

        # 4c. Locate the audio (normally messages[0] -> content[0] -> 'audio').
        audio_data = _find_audio(msgs)
        if audio_data is None:
            print("Warning: Could not find audio in example.")
        audios.append(audio_data)  # None will likely crash the processor; keep the data clean

    # 5. Batch Process (Tokenize Text + Encode Audio)
    batch = processor(
        text = texts,
        audio = audios,
        return_tensors = "pt",
        padding = True,
    )

    # 6. Labels: copy of input_ids with padding and placeholder tokens ignored (-100).
    tokenizer = processor.tokenizer
    labels = batch["input_ids"].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    for attr in _MASKED_TOKEN_ATTRS:
        token_id = getattr(tokenizer, attr, None)
        if token_id is not None:
            labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch

In [8]:
# 7. Configure Trainer
# --------------------
# Exactly one mixed-precision mode is enabled: bf16 when the GPU supports it, else fp16.
_use_bf16 = is_bfloat16_supported()

_sft_args = SFTConfig(
    output_dir = xDIR_CHECKPOINT,
    # Optimisation: effective batch = 2 per device x 4 accumulation steps = 8.
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    num_train_epochs = xEPOCHS,
    #max_steps = 300,   # alternative: fixed step budget instead of epochs
    warmup_steps = 20,
    learning_rate = 2e-4,
    bf16 = _use_bf16,
    fp16 = not _use_bf16,
    # Logging / checkpoint cadence, in optimiser steps.
    logging_steps = 50,
    save_strategy = "steps",
    save_steps = 100,
    # Audio tuning requirements: the collator builds batches from raw rows, so the
    # columns must survive and TRL must not run its own dataset preparation.
    remove_unused_columns = False,
    dataset_text_field = "",
    dataset_kwargs = {"skip_prepare_dataset": True},
)

trainer = SFTTrainer(
    model = model,
    args = _sft_args,
    train_dataset = dataset,
    processing_class = processor.tokenizer,
    data_collator = collate_fn,
)

In [9]:
# @title Show current memory stats
# Baseline before training: GPU capacity and the peak memory already reserved, in GB.
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / (1024 ** 3), 3)
max_memory = round(gpu_stats.total_memory / (1024 ** 3), 3)
print(
    f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\n"
    f"{start_gpu_memory} GB of memory reserved."
)

GPU = NVIDIA GeForce RTX 4090. Max memory = 23.543 GB.
11.973 GB of memory reserved.


In [10]:
# 8. Run Training
# ---------------
import glob
import os

# Resume only if an earlier run left a checkpoint behind. Passing
# resume_from_checkpoint=True unconditionally makes a fresh run (empty or missing
# output_dir) raise instead of training from scratch.
_existing_ckpts = [
    path for path in glob.glob(os.path.join(xDIR_CHECKPOINT, "checkpoint-*"))
    if os.path.isdir(path)
]

print("Starting training...")
if _existing_ckpts:
    print(f"Resuming from the latest of {len(_existing_ckpts)} checkpoint(s) in {xDIR_CHECKPOINT}")

trainer_stats = trainer.train(resume_from_checkpoint = True if _existing_ckpts else None)

Starting training...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 10,000 | Num Epochs = 3 | Total steps = 3,750
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 21,188,608 of 7,871,166,800 (0.27% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
2550,0.2251
2600,0.2187
2650,0.2394
2700,0.2487
2750,0.2328
2800,0.2312
2850,0.2315
2900,0.2318
2950,0.2487
3000,0.2342


In [11]:
# @title Show final memory and time stats
# Peak reserved GPU memory for the whole session, compared with the baseline
# (start_gpu_memory) and capacity (max_memory) captured before training.
# NOTE(review): when training resumed from a checkpoint, train_runtime presumably
# covers only this session's steps, not the full run -- confirm before comparing runs.
used_memory = round(torch.cuda.max_memory_reserved() / (1024 ** 3), 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

runtime_seconds = trainer_stats.metrics['train_runtime']
report_lines = [
    f"{runtime_seconds} seconds used for training.",
    f"{round(runtime_seconds/60, 2)} minutes used for training.",
    f"Peak reserved memory = {used_memory} GB.",
    f"Peak reserved memory for training = {used_memory_for_lora} GB.",
    f"Peak reserved memory % of max memory = {used_percentage} %.",
    f"Peak reserved memory for training % of max memory = {lora_percentage} %.",
]
print("\n".join(report_lines))


2987.2172 seconds used for training.
49.79 minutes used for training.
Peak reserved memory = 12.547 GB.
Peak reserved memory for training = 0.574 GB.
Peak reserved memory % of max memory = 53.294 %.
Peak reserved memory for training % of max memory = 2.438 %.


## Inference 

In [12]:
import torch
import numpy as np
import librosa  # We need this for resampling if your audio isn't 16k
from jiwer import wer, cer
from tqdm import tqdm
import time 

In [13]:
# DATA
# Evaluate on a fixed random subset of the validation split. The shuffle seed keeps
# the subset identical across runs, so WER/CER stay comparable between checkpoints.
N_EVAL_SAMPLES = 20

xFld_test = '/workspace/work/AI_Training_dset/audio_badrex_kinyarwanda-speech-1000h/validation'

# Optional on-disk cache for the shuffled indices:
#cache_file = "/tmp/shuffled_test.arrow"
#xdset = ds.shuffle(seed=50, indices_cache_file_name=cache_file).select(range(N_EVAL_SAMPLES))
ds = load_from_disk(xFld_test)
# min() guards against a split smaller than the requested sample count.
xdset = ds.shuffle(seed=50).select(range(min(N_EVAL_SAMPLES, len(ds))))

xdset

Dataset({
    features: ['audio_id', 'audio', 'transcription', 'sampling_rate'],
    num_rows: 20
})

In [14]:
# ------------------------------------------------------------------------
# 2. DEFINE PROMPTS & RESAMPLER
# ------------------------------------------------------------------------
system_prompt = "You are an assistant that transcribes speech accurately."
user_instruction = "Please transcribe this Kinyarwanda audio."

def ensure_16k(audio_data, source_sr, target_sr=16000):
    """Return ``audio_data`` as a float32 NumPy array sampled at ``target_sr`` Hz.

    Gemma-3N expects 16 kHz audio, hence the default target. If ``source_sr``
    already equals ``target_sr``, the samples are only converted to float32.
    Otherwise they are resampled with ``librosa.resample``.

    Args:
        audio_data: sequence or array of samples (presumably mono / 1-D -- confirm
            for new datasets).
        source_sr: sampling rate of ``audio_data`` in Hz.
        target_sr: desired output rate in Hz (default 16000).

    Returns:
        ``np.ndarray`` of dtype float32.
    """
    # asarray skips the copy when the input is already a float32 ndarray.
    audio_np = np.asarray(audio_data, dtype=np.float32)

    if source_sr != target_sr:
        audio_np = librosa.resample(audio_np, orig_sr=source_sr, target_sr=target_sr)

    return audio_np

In [15]:
# 1. CRITICAL: generation needs LEFT padding.
# Decoder-only models (like Gemma) generate from left to right. With right padding, a
# shorter prompt in the batch looks like [prompt, pad, pad], so its continuation would
# start after the pads. With left padding ([pad, pad, prompt]), every row ends at the
# same column, new tokens follow the prompt directly, and one `input_len` slices the
# generated part off for the whole batch.
# The padding side is switched only while `run_stable_batch_inference` runs and is
# restored afterwards. That way the training collator and the processor saved with the
# merged model keep the tokenizer's original setting.


def _build_inference_prompts(audio_lists, sampling_rates):
    """Resample each clip to 16 kHz and render its generation prompt.

    Returns ``(prompts_text, processed_audios)`` as two parallel lists.
    """
    prompts_text = []
    processed_audios = []
    for raw_audio_list, current_sr in zip(audio_lists, sampling_rates):
        audio_array = ensure_16k(raw_audio_list, current_sr)
        processed_audios.append(audio_array)

        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": [
                {"type": "audio", "audio": audio_array},
                {"type": "text", "text": user_instruction}
            ]}
        ]
        prompts_text.append(
            processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        )
    return prompts_text, processed_audios


def _generate_and_decode(inputs):
    """Greedy-decode one padded batch and return only the newly generated text per row."""
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            use_cache=True,
        )
    # With left padding all prompts end at column `input_len`; everything after it is new.
    input_len = inputs.input_ids.shape[1]
    return processor.batch_decode(
        generated_ids[:, input_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )


def run_stable_batch_inference(dataset, batch_size=4):
    """Transcribe every row of ``dataset`` in batches.

    Rows are sorted by audio length (longest first) to minimise padding within a
    batch. If a batch fails with a "setStorage" RuntimeError, its items are retried
    one at a time via ``run_single_inference``. Any other RuntimeError propagates.

    Args:
        dataset: ``datasets.Dataset`` with ``audio``, ``transcription`` and
            ``sampling_rate`` columns.
        batch_size: rows per ``generate`` call.

    Returns:
        ``(predictions, references)``: two equal-length lists in the length-sorted
        order (NOT the original row order of ``dataset``).

    Uses the module-level ``model``, ``processor``, ``system_prompt`` and
    ``user_instruction``. Temporarily sets ``processor.tokenizer.padding_side`` to
    "left" and restores the previous value on exit, even on error.
    """
    tokenizer = processor.tokenizer
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        # 2. Sort by audio length, longest first, so memory problems surface on the
        # first batch. The helper "len" column only lives on this local copy.
        dataset = dataset.map(lambda x: {"len": len(x["audio"])})
        dataset = dataset.sort("len", reverse=True)

        all_predictions = []
        all_references = []

        for i in tqdm(range(0, len(dataset), batch_size), desc="Processing Batches"):
            batch = dataset[i : i + batch_size]
            batch_references = batch['transcription']
            prompts_text, processed_audios = _build_inference_prompts(
                batch['audio'], batch['sampling_rate']
            )

            # 3. Build model inputs (padding=True now pads on the LEFT).
            inputs = processor(
                text=prompts_text,
                audio=processed_audios,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            ).to("cuda")

            # 4. Generate; on a storage error fall back to one item at a time.
            try:
                decoded_output = _generate_and_decode(inputs)
            except RuntimeError as e:
                if "setStorage" not in str(e):
                    raise  # unrelated failure: keep the original traceback
                print(f"\n[Warning] Batch {i//batch_size} failed with storage error. Retrying with batch_size=1...")
                for prompt, audio, reference in zip(prompts_text, processed_audios, batch_references):
                    all_predictions.append(run_single_inference(prompt, audio, model, processor))
                    all_references.append(reference)
                continue

            all_predictions.extend(decoded_output)
            all_references.extend(batch_references)

        return all_predictions, all_references
    finally:
        tokenizer.padding_side = previous_padding_side

# Helper for the fallback mechanism
def run_single_inference(text, audio, model, processor):
    """Transcribe a single (prompt, audio) pair without batching.

    Used when a batched ``generate`` call fails. Decoding is greedy with at most
    256 new tokens. Only the newly generated text is returned, with special
    tokens stripped.
    """
    single_inputs = processor(
        text=[text],
        audio=[audio],
        sampling_rate=16000,
        return_tensors="pt",
        padding=False,  # a batch of one never needs padding
    ).to("cuda")

    with torch.no_grad():
        output_ids = model.generate(**single_inputs, max_new_tokens=256, do_sample=False)

    prompt_len = single_inputs.input_ids.shape[1]
    continuation = output_ids[:, prompt_len:]
    decoded = processor.batch_decode(continuation, skip_special_tokens=True)
    return decoded[0]

In [16]:
# Run batched inference over the evaluation subset `xdset` and time it.
print('start:', time.asctime())
_t_start = time.perf_counter()

predictions, references = run_stable_batch_inference(xdset, batch_size=2)

print('end:', time.asctime())
print(f"elapsed: {time.perf_counter() - _t_start:.1f} s for {len(predictions)} samples")

start: Sat Feb 14 10:53:09 2026


Processing Batches: 100%|██████████| 10/10 [01:20<00:00,  8.04s/it]

end: Sat Feb 14 10:54:29 2026





In [17]:
# ------------------------------------------------------------------------
# 5. METRICS
# ------------------------------------------------------------------------
# Word and character error rates (jiwer) over all evaluated samples.
print("\n--- Evaluation ---")
wer_score, cer_score = wer(references, predictions), cer(references, predictions)

for _metric_name, _metric_value in (("WER", wer_score), ("CER", cer_score)):
    print(f"{_metric_name}: {_metric_value:.4f}")


--- Evaluation ---
WER: 0.1975
CER: 0.0508


In [18]:
# Save the metrics and up to 20 reference/prediction pairs to xFile_examples, then show them.
N_SHOW = min(20, len(predictions))
shown_pairs = list(zip(references[:N_SHOW], predictions[:N_SHOW]))

# Explicit UTF-8: transcripts may contain non-ASCII characters, and open()'s default
# encoding is platform-dependent.
with open(xFile_examples, 'w', encoding='utf-8') as xff:
    xff.write(f"WER: {wer_score:.4f}")
    xff.write('\n')
    xff.write(f"CER: {cer_score:.4f}")
    xff.write('\n#############################')
    for ref, pred in shown_pairs:
        xff.write(f"\nRef:  {ref}")
        xff.write('\n\n')
        xff.write(f"Pred: {pred}")
        xff.write('\n--------')

for ref, pred in shown_pairs:
    print(f"\nRef:  {ref}")
    print(f"Pred: {pred}")




Ref:  Igikoresho cyifashishwa n'abaganga bavura amagufwa ndetse n'imvune mu gihe barimo kuvura abarwayi bagize imvune zo mu ivi. Aho iki gikoresho bacyambikwa kugira ngo kibafashe gukanda ndetse no gusubiranya amagufwa n'imikaya byagize ikibazo igihe umuntu agize imvune. Ibi bikagaragaza iterambere mu buvuzi aho ibikoresho bigezweho bikoreshwa mu kunoza iyi serivisi.
Pred: Igikoresho kifashishwa n'abaganga bavura amagufwa ndetse n'imvune mu gihe barimo kuvura abarwayi bagize imvune zo mu ivi, aho iki gikoresho cyandikwa kugira ngo kibafashe gukanda ndetse no gusubiranya amagufwa n'ibikaya byagize ikibazo igihe umuntu ageze imvune, ibikagaragaza iterambere mu buvuzi aho ibikoresho byegizweho bikoreshwa mu kunoza iyi serivisi.

Ref:  Umugore wambaye agapira k'umukara imbere ye hicaye undi wambaye ishati y'umutuku ndetse n'iy'umweru ari kumwambika agakoresho ku kuboko kifashishwa mu kumenya umuvuduko w'amaraso, imbere yabo hari ameza manini ateretseho indangururamajwi hirya yabo hari aba

In [19]:
# Single summary number: weighted error (40% WER, 60% CER) mapped so that
# 100 means a perfect transcript; higher is better.
CombinedError = 0.4 * wer_score + 0.6 * cer_score
Score = 100 * (1 - CombinedError)

Score

89.0483989151917

In [20]:
# Merge the LoRA weights into the base model (16-bit) and save it with the processor.
# The folder name is derived from the run configuration so it cannot drift out of sync
# (with 10,000 training rows and 3 epochs: /workspace/work/gemma_3n_kin_10000_epochs_3).
xDIR_MERGED = f"/workspace/work/gemma_3n_kin_{len(dataset)}_epochs_{xEPOCHS}"
model.save_pretrained_merged(xDIR_MERGED, processor)

Found HuggingFace hub cache directory: /workspace/.cache/huggingface/hub


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Checking cache directory for required files...
Cache check failed: model-00001-of-00004.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.
Checking cache directory for required files...
Cache check failed: tokenizer.model not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Preparing safetensor model files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.08G [00:00<?, ?B/s]

Unsloth: Preparing safetensor model files:  25%|██▌       | 1/4 [00:24<01:13, 24.64s/it]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Unsloth: Preparing safetensor model files:  50%|█████     | 2/4 [01:03<01:05, 32.77s/it]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Unsloth: Preparing safetensor model files:  75%|███████▌  | 3/4 [01:43<00:36, 36.32s/it]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.66G [00:00<?, ?B/s]

Unsloth: Preparing safetensor model files: 100%|██████████| 4/4 [02:06<00:00, 31.61s/it]
Unsloth: Merging weights into 16bit: 100%|██████████| 4/4 [00:29<00:00,  7.37s/it]


Unsloth: Merge process complete. Saved to `/workspace/work/gemma_3n_kin_10000_epochs_3`
