# Fine-tune Whisper on CDLI Non-Standard Speech Datasets



In [30]:
from huggingface_hub import login
HF_TOKEN = input()
login(token=HF_TOKEN)

 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


## Settings 

--> adapt for your scenario

### Directories

In [32]:
import os 

# storage in Volume that will persist
LOCAL_STORAGE_DIR = '/jupyter_kernel'

BASE_DIR = os.path.join(LOCAL_STORAGE_DIR, 'trained_models')
!mkdir -p {BASE_DIR}

# directory for model training
OUTPUT_DIR = os.path.join(BASE_DIR, 'en_nonstandard_tune_whisper_small_2')
#OUTPUT_DIR = os.path.join(BASE_DIR, 'en_nonstandard_tune_whisper_tiny_1')

print(f"Will write model to: {OUTPUT_DIR}")
if os.path.exists(OUTPUT_DIR):
    raise ValueError(f"Output directory already exists - if you continue this will overwrite data and may lead to strange results...")


Will write model to: /jupyter_kernel/trained_models/en_nonstandard_tune_whisper_small_2


### Model and Dataset settings

In [33]:

#WHISPER_MODEL_TYPE = "openai/whisper-tiny" 
WHISPER_MODEL_TYPE = "openai/whisper-small" 
# WHISPER_MODEL_TYPE = "openai/whisper-large-v3" 

LANGUAGE = 'en'
DATASET_NAME = "cdli/kenyan_english_nonstandard_speech_v0.9"

# LANGUAGE = 'sw'
#DATASET_NAME = "cdli/kenyan_swahili_nonstandard_speech_v0.9"


In [34]:
# which parts of the model to update
UPDATE_ENCODER = True
UPDATE_PROJ = True
UPDATE_DECODER = False

# Turn on SpecAugment
USE_SPECAUGMENT = True

In [35]:

#######################
## don't change these!
######################


TASK = "transcribe"

BASE_MODEL_NAME = WHISPER_MODEL_TYPE
print('Base model will be loaded from:', BASE_MODEL_NAME)

Base model will be loaded from: openai/whisper-small


### Trainer Settings

--> adjust as needed or keep defaults (these settings should be a good starting point)

In [36]:
LOGGING_STEPS = 10  # Increased slightly since training will be longer
# if save steps is 0, only last and best model will be written
SAVE_STEPS = 100    # Increased to reduce checkpoint frequency

# training duration
MAX_EPOCHS = 8      # Reduced from 10 for more reasonable training time
MAX_STEPS = 600     # Reduced from 1000 - sufficient for small model

# Learning Rate and LR Scheduler (LR_END and LR_DECAY_POWER only apply to polynomial)
LEARNING_RATE = 1e-4 #@param - Good for small model
LR_SCHEDULER_TYPE = 'polynomial' # constant_with_warmup or polynomial
LR_WARMUP_STEPS = 50
LR_END = 1e-8
LR_DECAY_POWER = 4

BATCH_SIZE = 12     # Reduced from 32 - small model needs smaller batches
EVAL_BATCH_SIZE = 8 # Reduced from 16

#@markdown other settings relevant for evaluation
MAX_GEN_LEN = 128 # increase if your data has long sequences!
EVAL_ON_START = True
EVAL_STEPS = 50    # Good frequency for monitoring

# for CPU, set both to false
USE_FP16 = True    # Keep enabled for memory efficiency
USE_BF16 = False   # only some GPUs support this, eg A100, A40

# checkpoints get huge for large models (~18 GB!)
NUM_CHECKPOINTS_TO_STORE = 2

## Imports and Prep

--> no need to change anything here, just run

In [37]:
import datasets
from huggingface_hub import hf_hub_download
import numpy as np
import pandas as pd
import os
import torch

# more efficient dataset handling
datasets.disable_caching()
print('cache:', datasets.is_caching_enabled())

torch.set_num_threads(1)
torch.get_num_threads()

# check if we have gpu
if torch.cuda.is_available():
    print("GPU is available")
else:
    print("GPU is not available, using CPU instead")

cache: False
GPU is available


In [38]:
from huggingface_hub import hf_hub_download

import random
import torchaudio
import librosa


import tarfile
import datasets
import matplotlib.pyplot as plt
import pandas as pd

import torch
import time


from dataclasses import dataclass
from typing import Any, Dict, List, Union

from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer

from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration
import os
import csv
import shutil
import numpy as np


import evaluate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

transcript_normalizer = BasicTextNormalizer()

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

In [39]:
def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params

In [40]:
def get_wer(references, predictions, normalize=True, verbose=True):
  rs = references
  ps = predictions
  if normalize:
    ps = [transcript_normalizer(x) for x in predictions]
    rs = [transcript_normalizer(x) for x in references]
  if verbose:
    for r, p in zip(rs, ps):
      print(r)
      print(p)
      print()

  return wer_metric.compute(references=rs, predictions=ps)

def compute_lattescore(references, predictions, similarity_threshold=0.7):
    """
    LATTEScore calculation - percentage of transcripts that preserve meaning
    Based on the paper: Large Language Models As A Proxy For Human Evaluation
    This is a simplified version using semantic similarity
    """
    try:
        # Try to use sentence transformers for better semantic similarity
        from sentence_transformers import SentenceTransformer, util
        model = SentenceTransformer('all-MiniLM-L6-v2')
        
        preserved_count = 0
        for ref, pred in zip(references, predictions):
            # Skip empty strings
            if not ref.strip() or not pred.strip():
                continue
                
            # Get sentence embeddings
            emb_ref = model.encode(ref, convert_to_tensor=True)
            emb_pred = model.encode(pred, convert_to_tensor=True)
            
            # Calculate cosine similarity
            similarity = util.pytorch_cos_sim(emb_ref, emb_pred).item()
            
            # Consider meaning preserved if similarity > threshold
            if similarity > similarity_threshold:
                preserved_count += 1
        
        total_count = len([r for r in references if r.strip()])
        lattescore = (preserved_count / total_count) * 100 if total_count > 0 else 0
        
    except ImportError:
        # Fallback: use WER-based approximation if sentence-transformers not available
        print("Sentence transformers not available, using WER-based LATTEScore approximation")
        preserved_count = 0
        for ref, pred in zip(references, predictions):
            # Skip empty strings
            if not ref.strip() or not pred.strip():
                continue
                
            wer = wer_metric.compute(predictions=[pred], references=[ref])
            # Conservative threshold: meaning preserved if WER < 0.3 (30%)
            if wer < 0.3:
                preserved_count += 1
        
        total_count = len([r for r in references if r.strip()])
        lattescore = (preserved_count / total_count) * 100 if total_count > 0 else 0
    
    return lattescore

def compute_metrics(pred):
    """
    Compute metrics for ASR evaluation including WER, CER, and LATTEScore
    """
    # for training metrics
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_strs = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_strs = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # calculate a per-example average for WER and CER
    wers = []
    cers = []
    for pred_str, label_str in zip(pred_strs, label_strs):
        p = transcript_normalizer(pred_str)
        l = transcript_normalizer(label_str)
        # Skip empty strings for metric calculation
        if l.strip() and p.strip():
            wer = wer_metric.compute(predictions=[p], references=[l])
            cer = cer_metric.compute(predictions=[p], references=[l])
            wers.append(wer)
            cers.append(cer)

    wer = np.mean([min(1.0, x) for x in wers]) if wers else 1.0
    cer = np.mean([min(1.0, x) for x in cers]) if cers else 1.0
    
    # Calculate LATTEScore
    lattescore = compute_lattescore(label_strs, pred_strs)
    
    print('=== Metrics ===')
    print(f'Adjusted WER: {wer:.4f}')
    print(f'Adjusted CER: {cer:.4f}')
    print(f'LATTEScore: {lattescore:.2f}%')
    print(f'Un-adjusted WER: {np.mean(wers) if wers else 1.0:.4f}')
    print(f'Un-adjusted CER: {np.mean(cers) if cers else 1.0:.4f}')
    print('===============')
    
    return {
        "wer": wer, 
        "cer": cer, 
        "lattescore": lattescore
    }

# Optional: Add a function to analyze model quality based on LATTEScore
def analyze_model_deployment(lattescore, threshold=80.0):
    """
    Analyze if model meets quality standards for deployment based on LATTEScore
    Using the 80% threshold from the research paper
    """
    print(f"\n=== Model Deployment Analysis ===")
    print(f"LATTEScore: {lattescore:.2f}%")
    print(f"Deployment Threshold: {threshold}%")
    
    if lattescore >= threshold:
        print("✅ RECOMMENDATION: Model meets quality standards for deployment")
        print("   The ASR model preserves meaning in most transcripts")
    else:
        print("❌ RECOMMENDATION: Model does not meet quality standards")
        print("   Consider: More training data, hyperparameter tuning, or different architecture")
    
    return lattescore >= threshold

In [41]:
def load_dataset(dataset_name, split='test', limit_to_30_seconds=True):
    """
    Load a dataset from Hugging Face Hub.
    If limit_to_30_seconds is True, will only load examples with audio length <= 30 seconds.
    """
    if split not in ['train', 'test', 'validation']:
        raise ValueError("split must be one of 'train', 'test', or 'validation'")
    ds = datasets.load_dataset(dataset_name, split=split, streaming=False)
    orig_len = len(ds)
    if limit_to_30_seconds:
        ds = ds.filter(lambda example: example['audio_length'] <= 30)
        print(f"Filtered dataset from {orig_len} to {len(ds)} examples with audio length <= 30 seconds")
    return ds

In [42]:
# The following warning can be ignored:
# "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
# See: https://discuss.huggingface.co/t/finetuning-whisper-attention-mask-not-set-and-canot-be-inferred/97456
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

## Download datasets and prepare features

--> no need to change anything here, just run

### Optimizing some settings for dataset access

In [43]:
datasets.disable_caching()
print('cache:', datasets.is_caching_enabled())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is: ', device)

# IMPORTANT! need to set to 1 to avoid the mapping to hang!
torch.set_num_threads(1)
torch.get_num_threads()

num_proc = min(32, os.cpu_count())
print('# processors:', num_proc)



cache: False
device is:  cuda
# processors: 24


### Load feature extractor

--> for the model type you specified above

In [44]:

# Load processor
print('Using Language: ', LANGUAGE)
print('Using model:', WHISPER_MODEL_TYPE)
processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_TYPE, language=LANGUAGE, task=TASK)

# since this tokenizer isn't a FastTokenizer, so there is no point in running it with is_batched=True
# see: processor.tokenizer.is_fast
def prepare_features(example):
    example["input_features"] = processor.feature_extractor(example["audio"]["array"], sampling_rate=example["audio"]["sampling_rate"]).input_features[0]
    example["labels"] = processor.tokenizer(example["transcription"]).input_ids
    # also count number of tokens
    example["token_length"] = len(example["labels"])
    return example

Using Language:  en
Using model: openai/whisper-small


### Load non-standard speech dataset

We need to filter to 30 seconds, as Whisper can only train on that.

In [45]:
train_dataset = load_dataset(DATASET_NAME, split='train', limit_to_30_seconds=True)
train_dataset = train_dataset.map(prepare_features, remove_columns=['audio'], writer_batch_size=1, num_proc=num_proc)
print(f"Loaded TRAIN dataset with {len(train_dataset)} examples")

Filter:   0%|          | 0/4236 [00:00<?, ? examples/s]

Filtered dataset from 4236 to 3130 examples with audio length <= 30 seconds


Map (num_proc=24):   0%|          | 0/3130 [00:00<?, ? examples/s]

Loaded TRAIN dataset with 3130 examples


In [46]:
test_dataset = load_dataset(DATASET_NAME, split='test', limit_to_30_seconds=True)
test_dataset = test_dataset.map(prepare_features, remove_columns=['audio'], writer_batch_size=1, num_proc=num_proc)
print(f"Loaded TEST dataset with {len(test_dataset)} examples")

Filter:   0%|          | 0/993 [00:00<?, ? examples/s]

Filtered dataset from 993 to 705 examples with audio length <= 30 seconds


Map (num_proc=24):   0%|          | 0/705 [00:00<?, ? examples/s]

Loaded TEST dataset with 705 examples


In [47]:
dev_dataset = load_dataset(DATASET_NAME, split='validation', limit_to_30_seconds=True)
dev_dataset = dev_dataset.map(prepare_features, remove_columns=['audio'], writer_batch_size=1, num_proc=num_proc)
print(f"Loaded DEV dataset with {len(dev_dataset)} examples")

Filter:   0%|          | 0/572 [00:00<?, ? examples/s]

Filtered dataset from 572 to 342 examples with audio length <= 30 seconds


Map (num_proc=24):   0%|          | 0/342 [00:00<?, ? examples/s]

Loaded DEV dataset with 342 examples


## Prepare Trainer

--> no need to change anything here, just run

Whenever something is changed in the settings, you need to rerun this part.

In [48]:
base_model = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL_NAME)
_ = base_model.to(device)
print('Using Language: ', LANGUAGE)
print('Using model:', WHISPER_MODEL_TYPE)

# ensure task and language for training
base_model.generation_config.language = LANGUAGE
base_model.generation_config.task = TASK
base_model.generation_config.forced_decoder_ids = None
base_model.config.forced_decoder_ids = None
# to use gradient checkpointing
base_model.config.use_cache = False
print('language set to:', base_model.generation_config.language)

Using Language:  en
Using model: openai/whisper-small
language set to: en


In [49]:
# Add SpecAugment
if USE_SPECAUGMENT:
    base_model.config.apply_spec_augment = USE_SPECAUGMENT

    # Specaugment (use default settings, as per paper)
    # time masking
    base_model.config.mask_time_prob = 0.05
    base_model.config.mask_time_length = 10
    base_model.config.mask_time_min_masks = 2

    # feature masking
    base_model.config.mask_feature_prob = 0.05 # def: 0
    base_model.config.mask_feature_length = 10
    base_model.config.mask_feature_min_masks = 2 # def: 0

print('Using specaugment:', base_model.config.apply_spec_augment)

# Add Audio Augmentation functions
import torchaudio
import torchaudio.transforms as T
import numpy as np

def apply_audio_augmentation(audio_array, sample_rate=16000, augmentation_prob=0.7):
    """
    Apply audio augmentations to training data
    augmentation_prob: probability of applying any augmentation to an example
    """
    if np.random.random() > augmentation_prob:
        return audio_array  # No augmentation applied
    
    # Convert to tensor for torchaudio transforms
    audio_tensor = torch.from_numpy(audio_array).float()
    original_length = len(audio_tensor)
    
    # Track which augmentations were applied for debugging
    applied_augmentations = []
    
    # 1. Add background noise (30% chance if augmenting)
    if np.random.random() < 0.3:
        noise_level = np.random.uniform(0.001, 0.01)
        noise = torch.randn_like(audio_tensor) * noise_level
        audio_tensor = audio_tensor + noise
        applied_augmentations.append(f"noise({noise_level:.3f})")
    
    # 2. Time masking (25% chance if augmenting)
    if np.random.random() < 0.25 and original_length > 1000:
        mask_length = np.random.randint(50, 300)
        mask_start = np.random.randint(0, max(1, original_length - mask_length))
        audio_tensor[mask_start:mask_start + mask_length] = 0
        applied_augmentations.append(f"time_mask({mask_length})")
    
    # 3. Pitch shift (20% chance if augmenting)
    if np.random.random() < 0.2:
        try:
            n_steps = np.random.choice([-2, -1, 1, 2])
            pitch_shift = T.PitchShift(sample_rate, n_steps=n_steps)
            audio_tensor = pitch_shift(audio_tensor.unsqueeze(0)).squeeze(0)
            applied_augmentations.append(f"pitch({n_steps})")
        except Exception as e:
            # Fallback to simple resampling if pitch shift fails
            try:
                speed_factor = 1.0 + (n_steps * 0.1)
                new_length = int(original_length / speed_factor)
                audio_tensor = torch.nn.functional.interpolate(
                    audio_tensor.unsqueeze(0).unsqueeze(0), 
                    size=new_length, 
                    mode='linear'
                ).squeeze(0).squeeze(0)
                if len(audio_tensor) > original_length:
                    audio_tensor = audio_tensor[:original_length]
                else:
                    audio_tensor = torch.nn.functional.pad(
                        audio_tensor, 
                        (0, original_length - len(audio_tensor))
                    )
                applied_augmentations.append(f"pitch_approx({n_steps})")
            except:
                pass
    
    # 4. Time stretching - speed up/slow down (25% chance if augmenting)
    if np.random.random() < 0.25 and original_length > 500:
        rate = np.random.uniform(0.85, 1.15)  # 15% speed variation
        new_length = int(original_length * rate)
        
        if new_length > 100:  # Ensure reasonable length
            try:
                # High-quality time stretching using interpolation
                audio_tensor_stretched = torch.nn.functional.interpolate(
                    audio_tensor.unsqueeze(0).unsqueeze(0), 
                    size=new_length, 
                    mode='linear',
                    align_corners=False
                ).squeeze(0).squeeze(0)
                
                # Trim or pad to original length to maintain consistency
                if len(audio_tensor_stretched) > original_length:
                    audio_tensor = audio_tensor_stretched[:original_length]
                else:
                    padding = torch.zeros(original_length - len(audio_tensor_stretched))
                    audio_tensor = torch.cat([audio_tensor_stretched, padding])
                
                applied_augmentations.append(f"time_stretch({rate:.2f})")
            except Exception as e:
                pass
    
    # 5. Volume change (35% chance if augmenting)
    if np.random.random() < 0.35:
        gain = np.random.uniform(0.5, 1.5)
        audio_tensor = audio_tensor * gain
        # Clip to prevent distortion
        audio_tensor = torch.clamp(audio_tensor, -1.0, 1.0)
        applied_augmentations.append(f"volume({gain:.2f})")
    
    # 6. Low-pass filter - room simulation (15% chance if augmenting)
    if np.random.random() < 0.15 and original_length > 1000:
        try:
            cutoff_freq = np.random.uniform(2000, 4000)  # Simulate telephone quality
            lowpass = T.LowpassBiquad(sample_rate, cutoff_freq=cutoff_freq)
            audio_tensor = lowpass(audio_tensor.unsqueeze(0)).squeeze(0)
            applied_augmentations.append(f"lowpass({cutoff_freq:.0f}Hz)")
        except Exception as e:
            pass
    
    # 7. High-pass filter (10% chance if augmenting) - remove low frequencies
    if np.random.random() < 0.1 and original_length > 1000:
        try:
            cutoff_freq = np.random.uniform(100, 500)
            highpass = T.HighpassBiquad(sample_rate, cutoff_freq=cutoff_freq)
            audio_tensor = highpass(audio_tensor.unsqueeze(0)).squeeze(0)
            applied_augmentations.append(f"highpass({cutoff_freq:.0f}Hz)")
        except Exception as e:
            pass
    
    # Debug: print augmentations for first few examples
    if len(applied_augmentations) > 0 and np.random.random() < 0.01:  # 1% of augmented examples
        print(f"Applied augmentations: {', '.join(applied_augmentations)}")
    
    return audio_tensor.numpy()

# Audio augmentation settings (add these to your settings cell)
USE_AUDIO_AUGMENTATION = True
AUGMENTATION_PROB = 0.7  # 70% of training examples get augmentation

print('Using audio augmentation:', USE_AUDIO_AUGMENTATION)
print('Audio augmentation probability:', AUGMENTATION_PROB)
print('Using specaugment:', base_model.config.apply_spec_augment)
print('=== Using BOTH augmentation types: SpecAugment (feature-level) + Audio (signal-level) ===')

# Modified prepare_features function with augmentation support
def prepare_features(example, is_training=True):
    """
    Prepare features with optional augmentation for training data
    """
    audio_array = example["audio"]["array"]
    sample_rate = example["audio"]["sampling_rate"]
    
    # Apply augmentation ONLY to training data
    if is_training and USE_AUDIO_AUGMENTATION:
        audio_array = apply_audio_augmentation(audio_array, sample_rate, AUGMENTATION_PROB)
    
    # Extract features as before
    example["input_features"] = processor.feature_extractor(
        audio_array, 
        sampling_rate=sample_rate
    ).input_features[0]
    
    example["labels"] = processor.tokenizer(example["transcription"]).input_ids
    example["token_length"] = len(example["labels"])
    
    return example

# Create wrapper functions for dataset mapping
def prepare_train_features(example):
    return prepare_features(example, is_training=True)

def prepare_eval_features(example):
    return prepare_features(example, is_training=False)

print("Audio augmentation functions defined. Use prepare_train_features for training data and prepare_eval_features for test/dev data.")

Using specaugment: True
Using audio augmentation: True
Audio augmentation probability: 0.7
Using specaugment: True
=== Using BOTH augmentation types: SpecAugment (feature-level) + Audio (signal-level) ===
Audio augmentation functions defined. Use prepare_train_features for training data and prepare_eval_features for test/dev data.


In [50]:
# which layers to tune

print("Updating encoder:", UPDATE_ENCODER)
print("Updating projection layer:", UPDATE_PROJ)
print("Updating decoder:", UPDATE_DECODER)


base_model.model.encoder.requires_grad_(UPDATE_ENCODER)
base_model.model.decoder.requires_grad_(UPDATE_DECODER)
base_model.proj_out.requires_grad_(UPDATE_PROJ)

print("Overview to number of model parameters to be updated:")
print('* encoder params to update/total:', count_trainable_parameters(base_model.model.encoder), base_model.model.encoder.num_parameters())
print('* decoder parans to update/total:', count_trainable_parameters(base_model.model.decoder), base_model.model.decoder.num_parameters())

print('* overall # trainable parameters:', count_trainable_parameters(base_model))
print('*     overall # model parameters:', base_model.model.num_parameters())

Updating encoder: True
Updating projection layer: True
Updating decoder: False
Overview to number of model parameters to be updated:
* encoder params to update/total: 88154112 88154112
* decoder parans to update/total: 39832320 153580800
* overall # trainable parameters: 127986432
*     overall # model parameters: 241734912


In [51]:
# Training Hyper Parameters
# don't change settings here, but instead at very top!
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    logging_dir=os.path.join(OUTPUT_DIR, 'logs'),
    logging_steps=LOGGING_STEPS,
    report_to=["tensorboard"],
    include_num_input_tokens_seen=True,
    ### on GPU, can either do fp16 or bf16 depending on specific GPU
    fp16=USE_FP16, 
    bf16=USE_BF16, 
    push_to_hub=False,
    remove_unused_columns=False,
    #
    num_train_epochs=MAX_EPOCHS,
    max_steps=MAX_STEPS,
    #
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    #
    per_device_train_batch_size=BATCH_SIZE,
    #
    eval_on_start=EVAL_ON_START,
    predict_with_generate=True,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    eval_steps=EVAL_STEPS,
    eval_strategy="steps",
    generation_max_length=MAX_GEN_LEN,
    #
    metric_for_best_model="wer",
    greater_is_better=False,
    #
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    #
    # only applies to polynomial schedule (constant ignores args)
    lr_scheduler_kwargs={
        "lr_end": LR_END, # The final LR.  Crucial for polynomial decay.
        "power": LR_DECAY_POWER, # for decay
        # we don't need to set the other arguments as they are already set in the args outside
        #"num_warmup_steps": WARMUP_STEPS, # The number of steps for the warmup phase.
        #"num_training_steps": MAX_STEPS, # The total number of training steps.
        #"lr_init": 1e-5 # we take the LR setting
    },

    learning_rate=LEARNING_RATE,
    warmup_steps=LR_WARMUP_STEPS, # what happens if we have this and the LR schedule args ?
    #
    save_steps=SAVE_STEPS,
    save_strategy="steps",
    save_total_limit=NUM_CHECKPOINTS_TO_STORE,
    load_best_model_at_end=True,
    # group_by_length=True
    # auto_find_batch_size=True
)

print('trainer args set, writing to:', OUTPUT_DIR)

trainer args set, writing to: /jupyter_kernel/trained_models/en_nonstandard_tune_whisper_small_2


In [52]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=base_model.config.decoder_start_token_id,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor
)


Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


## Run the training

Note: tensorboard doesn't show properly in jupyter notebooks, use the tensorboard_server.py tool to host a tensorboard instance on Modal, using below model training dir:

In [53]:
print('model training dir:', OUTPUT_DIR)

model training dir: /jupyter_kernel/trained_models/en_nonstandard_tune_whisper_small_2


In [54]:
# train from scratch
trainer.train()

# # alternatively, you can continue training if a previous job was interrupted
#trainer.train(resume_from_checkpoint = True)


Step,Training Loss,Validation Loss,Wer,Cer,Lattescore,Input Tokens Seen
0,No log,1.416936,0.278586,0.166302,48.245614,0
50,0.999600,0.796919,0.234155,0.132961,60.526316,144000000
100,0.958800,0.756722,0.24074,0.144147,61.988304,288000000
150,0.802600,0.724311,0.22736,0.1357,64.619883,432000000
200,0.820200,0.702249,0.208163,0.121648,67.836257,576000000
250,0.882900,0.675111,0.213495,0.128683,66.959064,720000000
300,0.484300,0.662475,0.20184,0.119898,69.298246,863520000
350,0.706600,0.664432,0.201682,0.122105,69.298246,1007520000
400,0.489400,0.663069,0.202475,0.121708,67.836257,1151520000
450,0.466500,0.660234,0.200163,0.11913,68.128655,1295520000


Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2786
Adjusted CER: 0.1663
LATTEScore: 48.25%
Un-adjusted WER: 0.2915
Un-adjusted CER: 0.1788
Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2342
Adjusted CER: 0.1330
LATTEScore: 60.53%
Un-adjusted WER: 0.2394
Un-adjusted CER: 0.1330
Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2407
Adjusted CER: 0.1441
LATTEScore: 61.99%
Un-adjusted WER: 0.2798
Un-adjusted CER: 0.1715




Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2274
Adjusted CER: 0.1357
LATTEScore: 64.62%
Un-adjusted WER: 0.2512
Un-adjusted CER: 0.1427
Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2082
Adjusted CER: 0.1216
LATTEScore: 67.84%
Un-adjusted WER: 0.2476
Un-adjusted CER: 0.1524
Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2135
Adjusted CER: 0.1287
LATTEScore: 66.96%
Un-adjusted WER: 0.2769
Un-adjusted CER: 0.1776
Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2018
Adjusted CER: 0.1199
LATTEScore: 69.30%
Un-adjusted WER: 0.2672
Un-adjusted CER: 0.1653
Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.2017
Adjusted CER: 0.1221
LATTEScore: 69.30%
Un-adjusted WER: 0.2659
Un-adjust

There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=600, training_loss=0.6770224614938101, metrics={'train_runtime': 2553.761, 'train_samples_per_second': 2.819, 'train_steps_per_second': 0.235, 'total_flos': 2.07666054070272e+18, 'train_loss': 0.6770224614938101, 'epoch': 2.2988505747126435, 'num_input_tokens_seen': 1727040000})

## Post-Training Evaluation

when you run this after your training has finished it will use the best checkpoint (because we set "load_best_model_at_end=True" in the trainer args)

### On DEV set

In [55]:
# (should give the same result shown in trainig progress on dev set)
trainer.evaluate(dev_dataset, language=LANGUAGE)

Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.1972
Adjusted CER: 0.1169
LATTEScore: 69.30%
Un-adjusted WER: 0.2498
Un-adjusted CER: 0.1640


{'eval_loss': 0.6597035527229309,
 'eval_wer': 0.19719239518466153,
 'eval_cer': 0.1169473428980371,
 'eval_lattescore': 69.2982456140351,
 'eval_runtime': 116.9009,
 'eval_samples_per_second': 2.926,
 'eval_steps_per_second': 0.368,
 'epoch': 2.2988505747126435,
 'num_input_tokens_seen': 1727040000}

### On TEST set

In [56]:
# run on dev-set 
# (should give the same result shown in trainig progress on dev set)
trainer.evaluate(test_dataset, language=LANGUAGE)

Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.1228
Adjusted CER: 0.0681
LATTEScore: 77.73%
Un-adjusted WER: 0.1358
Un-adjusted CER: 0.0717


{'eval_loss': 0.5866317749023438,
 'eval_wer': 0.12280530354136088,
 'eval_cer': 0.06807830036387313,
 'eval_lattescore': 77.73049645390071,
 'eval_runtime': 242.9586,
 'eval_samples_per_second': 2.902,
 'eval_steps_per_second': 0.366,
 'epoch': 2.2988505747126435,
 'num_input_tokens_seen': 1727040000}

In [60]:
from IPython.display import FileLink, display
import pandas as pd
from datasets import load_dataset

# --- Load external speaker metadata directly ---
metadata_url = "https://huggingface.co/datasets/cdli/kenyan_english_nonstandard_speech_v0.9/resolve/main/speaker_metadata.tsv"
speaker_metadata_ds = load_dataset("csv", data_files=metadata_url, sep="\t")["train"]
speaker_metadata = pd.DataFrame(speaker_metadata_ds)

# Drop unwanted columns
speaker_metadata = speaker_metadata.drop(columns=["comments", "slp_id"], errors="ignore")

print("Speaker metadata loaded:", speaker_metadata.shape)
print("Columns:", speaker_metadata.columns.tolist())
print(speaker_metadata.head())

# --- Function to calculate individual WER/accuracy ---
def calculate_individual_wer(prediction, reference):
    return wer_metric.compute(predictions=[prediction], references=[reference])

# --- LATTEScore calculation function ---
def calculate_lattescore(prediction, reference, similarity_threshold=0.7):
    """
    Calculate whether meaning is preserved for a single transcript pair
    Returns 1 if meaning preserved, 0 if meaning lost
    """
    try:
        from sentence_transformers import SentenceTransformer, util
        model = SentenceTransformer('all-MiniLM-L6-v2')
        
        # Skip empty strings
        if not reference.strip() or not prediction.strip():
            return 0
            
        # Get sentence embeddings
        emb_ref = model.encode(reference, convert_to_tensor=True)
        emb_pred = model.encode(prediction, convert_to_tensor=True)
        
        # Calculate cosine similarity
        similarity = util.pytorch_cos_sim(emb_ref, emb_pred).item()
        
        # Consider meaning preserved if similarity > threshold
        return 1 if similarity > similarity_threshold else 0
        
    except ImportError:
        # Fallback: use WER-based approximation
        wer = wer_metric.compute(predictions=[prediction], references=[reference])
        # Conservative threshold: meaning preserved if WER < 0.3 (30%)
        return 1 if wer < 0.3 else 0

# --- You need to get predictions first! Add this: ---
print("Generating predictions...")

# Generate predictions for dev set
preds_dev = trainer.predict(dev_dataset)
dev_predictions = processor.tokenizer.batch_decode(preds_dev.predictions, skip_special_tokens=True)
dev_references = [transcript_normalizer(x) for x in dev_dataset["transcription"]]

# Generate predictions for test set  
preds_test = trainer.predict(test_dataset)
test_predictions = processor.tokenizer.batch_decode(preds_test.predictions, skip_special_tokens=True)
test_references = [transcript_normalizer(x) for x in test_dataset["transcription"]]

# Create prediction dictionaries
preds_dev_dict = {
    "speaker_id": dev_dataset["speaker_id"],
    "transcription": dev_references,
    "prediction": dev_predictions
}

preds_test_dict = {
    "speaker_id": test_dataset["speaker_id"], 
    "transcription": test_references,
    "prediction": test_predictions
}

# Calculate WER for entire sets
wer_dev = wer_metric.compute(predictions=dev_predictions, references=dev_references)
wer_test = wer_metric.compute(predictions=test_predictions, references=test_references)

# Calculate metrics for individual examples
dev_wer_individual = [calculate_individual_wer(p, r) for p, r in zip(dev_predictions, dev_references)]
dev_acc_individual = [(1 - wer) * 100 for wer in dev_wer_individual]
dev_lattescore_individual = [calculate_lattescore(p, r) for p, r in zip(dev_predictions, dev_references)]

test_wer_individual = [calculate_individual_wer(p, r) for p, r in zip(test_predictions, test_references)]
test_acc_individual = [(1 - wer) * 100 for wer in test_wer_individual]
test_lattescore_individual = [calculate_lattescore(p, r) for p, r in zip(test_predictions, test_references)]

# Calculate overall LATTEScore percentages
dev_lattescore_percent = (sum(dev_lattescore_individual) / len(dev_lattescore_individual)) * 100
test_lattescore_percent = (sum(test_lattescore_individual) / len(test_lattescore_individual)) * 100

acc_dev = (1 - wer_dev) * 100
acc_test = (1 - wer_test) * 100

print(f"Dev WER: {wer_dev:.3f} | Word Accuracy: {acc_dev:.1f}% | LATTEScore: {dev_lattescore_percent:.1f}%")
print(f"Test WER: {wer_test:.3f} | Word Accuracy: {acc_test:.1f}% | LATTEScore: {test_lattescore_percent:.1f}%")

# --- Enhanced DataFrame creation with metadata merge ---
def create_enhanced_dataframe(dataset, wer_individual, acc_individual, lattescore_individual, dataset_name):
    base_data = {
        "speaker_id": dataset["speaker_id"],
        "reference": dataset["transcription"],
        "prediction": dataset["prediction"],
        "wer": wer_individual,
        "word_accuracy": acc_individual,
        "lattescore_meaning_preserved": lattescore_individual,  # 1=preserved, 0=lost
    }

    df = pd.DataFrame(base_data)
    df = df.merge(speaker_metadata, on="speaker_id", how="left")

    print(f"\n{dataset_name} DataFrame shape: {df.shape}")
    print("Columns:", df.columns.tolist())
    return df

# Build enhanced DataFrames (THIS WAS MISSING!)
df_dev_enhanced = create_enhanced_dataframe(preds_dev_dict, dev_wer_individual, dev_acc_individual, dev_lattescore_individual, "Dev")
df_test_enhanced = create_enhanced_dataframe(preds_test_dict, test_wer_individual, test_acc_individual, test_lattescore_individual, "Test")

# --- Model deployment analysis using LATTEScore ---
def analyze_model_deployment(lattescore, threshold=80.0):
    """Analyze if model meets quality standards based on LATTEScore"""
    print(f"\n=== Model Deployment Analysis ===")
    print(f"LATTEScore: {lattescore:.1f}%")
    print(f"Deployment Threshold: {threshold}%")
    
    if lattescore >= threshold:
        print("✅ RECOMMENDATION: Model meets quality standards for deployment")
        print("   The ASR model preserves meaning in most transcripts")
    else:
        print("❌ RECOMMENDATION: Model does not meet quality standards")
        print("   Consider: More training data, hyperparameter tuning, or different architecture")
    
    return lattescore >= threshold

# Run deployment analysis
deployment_ready = analyze_model_deployment(test_lattescore_percent)

# --- Save only the main analysis files ---
df_dev_enhanced.to_csv("dev_predictions.csv", index=False)
df_test_enhanced.to_csv("test_predictions.csv", index=False)

print("\n=== FILES SAVED ===")
print(f"dev_predictions.csv: {len(df_dev_enhanced)} samples")
print(f"test_predictions.csv: {len(df_test_enhanced)} samples")

# --- Preview data ---
print("\n=== DATA PREVIEW ===")
print(df_dev_enhanced[['speaker_id', 'reference', 'prediction', 'wer', 'lattescore_meaning_preserved']].head(10))

# --- Download links ---
print("\n=== DOWNLOAD LINKS ===")
display(FileLink("dev_predictions.csv"))
display(FileLink("test_predictions.csv"))

# --- Updated next steps ---
print("\n=== NEXT STEPS ===")
print("1. Analyze LATTEScore by speaker metadata (etiology, severity, gender)")
print("2. Compare LATTEScore with WER to see if meaning preservation differs from word accuracy")
print("3. Use LATTEScore for model deployment decisions")
print("4. Calculate LATTEScore breakdown by speaker characteristics")

speaker_metadata.tsv:   0%|          | 0.00/9.27k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Speaker metadata loaded: (52, 6)
Columns: ['speaker_id', 'gender', 'age', 'severity_speech_impairment', 'type_nonstandard_speech', 'etiology']
  speaker_id  gender    age  \
0     KES001  Female  30-40   
1     KES002  Female  30-40   
2     KES003    Male  25-30   
3     KES004    Male  25-30   
4     KES005    Male  18-24   

                          severity_speech_impairment  \
0                       Severe (frequent breakdowns)   
1                       Severe (frequent breakdowns)   
2  Profound (communication very difficult or impo...   
3                       Severe (frequent breakdowns)   
4           Moderate (requires effort to understand)   

             type_nonstandard_speech               etiology  
0                         Dysarthria         Cerebral Palsy  
1                         Dysarthria         Cerebral Palsy  
2  Stuttering (Disfluency Disorders)         Cerebral Palsy  
3  Stuttering (Disfluency Disorders)  Neurological disorder  
4  Stuttering (Disfluen

Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.1972
Adjusted CER: 0.1169
LATTEScore: 69.30%
Un-adjusted WER: 0.2498
Un-adjusted CER: 0.1640


Sentence transformers not available, using WER-based LATTEScore approximation
=== Metrics ===
Adjusted WER: 0.1228
Adjusted CER: 0.0681
LATTEScore: 77.73%
Un-adjusted WER: 0.1358
Un-adjusted CER: 0.0717
Dev WER: 0.361 | Word Accuracy: 63.9% | LATTEScore: 50.6%
Test WER: 0.274 | Word Accuracy: 72.6% | LATTEScore: 64.4%

Dev DataFrame shape: (342, 11)
Columns: ['speaker_id', 'reference', 'prediction', 'wer', 'word_accuracy', 'lattescore_meaning_preserved', 'gender', 'age', 'severity_speech_impairment', 'type_nonstandard_speech', 'etiology']

Test DataFrame shape: (705, 11)
Columns: ['speaker_id', 'reference', 'prediction', 'wer', 'word_accuracy', 'lattescore_meaning_preserved', 'gender', 'age', 'severity_speech_impairment', 'type_nonstandard_speech', 'etiology']

=== Model Deployment Analysis ===
LATTEScore: 64.4%
Deployment Threshold: 80.0%
❌ RECOMMENDATION: Model does not meet quality standards
   Consider: More training data, hyperparameter tuning, or different architecture

=== FILES


=== NEXT STEPS ===
1. Analyze LATTEScore by speaker metadata (etiology, severity, gender)
2. Compare LATTEScore with WER to see if meaning preservation differs from word accuracy
3. Use LATTEScore for model deployment decisions
4. Calculate LATTEScore breakdown by speaker characteristics


## Store Model

--> save best model

### Save to your volume

In [61]:
# with "load_best_model_at_end=True" set in the settings (this is the default, so don't change that), after training is completed the best model is loaded and then saved
best_model_dir = os.path.join(OUTPUT_DIR, 'best_model')
print(f"Saving to: {best_model_dir}")
trainer.model.save_pretrained(best_model_dir, safe_serialization=True)
trainer.tokenizer.save_pretrained(best_model_dir)

Saving to: /jupyter_kernel/trained_models/en_nonstandard_tune_whisper_small_2/best_model


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


[]