Whisper Fine-tuning Quick Experiment

In [1]:
import os
import json
import pandas as pd
import numpy as np
import torch
from pathlib import Path
import soundfile as sf
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from datasets import Dataset, DatasetDict
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import mlflow
import mlflow.pytorch

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")

Using device: cuda
GPU available: True
GPU: NVIDIA GeForce RTX 4060 Ti


In [2]:
def load_labelstudio_export(export_path):
    """
    Load Label Studio JSONL export and extract audio-text pairs
    
    Expected format per line:
    {"audio_filepath": "audio/chunk_xxx.wav", "duration": 12.277875, "text": "transcription"}
    """
    processed_data = []
    
    with open(export_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            try:
                item = json.loads(line.strip())
                
                # Extract required fields
                audio_path = item.get('audio_filepath')
                transcription = item.get('text')
                duration = item.get('duration')
                
                # Validate required fields
                if audio_path and transcription:
                    transcription = transcription.strip()
                    
                    # Skip very short or empty transcriptions
                    if len(transcription) < 3:
                        continue
                    
                    # Skip very long transcriptions (might be errors)
                    if len(transcription) > 1000:
                        continue
                    
                    processed_data.append({
                        'audio_path': audio_path,
                        'transcription': transcription,
                        'duration': duration
                    })
                    
            except json.JSONDecodeError:
                continue
            except Exception:
                continue
    
    return processed_data

# Load your Label Studio export
LABELSTUDIO_EXPORT_PATH = "label-studio-export/project-11-at-2025-07-08-08-18-a315a9a2/manifest.jsonl"  # UPDATE THIS PATH (.jsonl extension)
AUDIO_BASE_PATH = "label-studio-export/project-11-at-2025-07-08-08-18-a315a9a2/"  # UPDATE THIS PATH (should contain the 'audio/' folder)

print("Loading Label Studio JSONL export...")
labelstudio_data = load_labelstudio_export(LABELSTUDIO_EXPORT_PATH)
print(f"Loaded {len(labelstudio_data)} samples from Label Studio")

# Quick data inspection
df = pd.DataFrame(labelstudio_data)
print("\nDataset overview:")
print(df.head())
print(f"\nTranscription length stats (characters):")
df['transcription_length'] = df['transcription'].str.len()
print(df['transcription_length'].describe())

print(f"\nDuration stats (seconds):")
if 'duration' in df.columns:
    print(df['duration'].describe())

# Sample transcriptions
print(f"\nSample transcriptions:")
for i in range(min(3, len(df))):
    print(f"Sample {i+1}: {df.iloc[i]['transcription'][:100]}...")
    print(f"Duration: {df.iloc[i].get('duration', 'Unknown')} seconds")
    print()

Loading Label Studio JSONL export...
Loaded 73 samples from Label Studio

Dataset overview:
                         audio_path  \
0  audio/chunk_3489d71581224dfa.wav   
1  audio/chunk_8e7e6b99099516b0.wav   
2  audio/chunk_106196e04d812c87.wav   
3  audio/chunk_f48d4ca17284401d.wav   
4  audio/chunk_f20fbdcc562302c1.wav   

                                       transcription   duration  
0  Nami buridi. Najichanganya taki bali kwamba la...  12.277875  
1  Na... Naona kama vile juhudi zangu zinagonga m...  12.490875  
2  Unaeza pia mkajaribu kuongea na mama, sawa. Bi...  11.961875  
3  Aaaaah, huko kwote sipo. Mimi ukinipigia 116 u...  12.418875  
4  Ni mtendaji wa kijiji. Okay, ni mtendaji wa ki...  11.711875  

Transcription length stats (characters):
count     73.000000
mean     175.424658
std       36.083975
min       48.000000
25%      154.000000
50%      177.000000
75%      195.000000
max      266.000000
Name: transcription_length, dtype: float64

Duration stats (seconds):
count

AUDIO PROCESSING AND QUALITY FILTERING

In [3]:
def load_and_validate_audio(audio_path, target_sr=16000):
    """Load preprocessed audio file (already cleaned and normalized)"""
    try:
        # Handle relative paths from Label Studio
        if not os.path.isabs(audio_path):
            audio_path = os.path.join(AUDIO_BASE_PATH, audio_path)
        
        # Check if file exists
        if not os.path.exists(audio_path):
            return None, None, f"File not found: {audio_path}"
        
        # Check file size
        file_size = os.path.getsize(audio_path)
        if file_size == 0:
            return None, None, f"Empty file: {audio_path}"
        
        # Load audio (already preprocessed, so minimal validation needed)
        audio, sr = sf.read(audio_path)
        
        # Convert to mono if stereo
        if len(audio.shape) > 1:
            audio = np.mean(audio, axis=1)
        
        # Audio should already be 16kHz, but check
        if sr != target_sr:
            return None, None, f"Unexpected sample rate {sr}Hz (expected {target_sr}Hz): {audio_path}"
        
        # Check if audio was loaded
        if audio is None or len(audio) == 0:
            return None, None, f"Empty audio data: {audio_path}"
        
        # Simple duration check (already preprocessed, so should be reasonable)
        duration = len(audio) / sr
        if duration < 0.1 or duration > 60:  # Very permissive since it's preprocessed
            return None, None, f"Unexpected duration ({duration:.2f}s): {audio_path}"
        
        return audio, sr, "OK"
    
    except Exception as e:
        return None, None, f"Error loading {audio_path}: {str(e)}"

# Process and filter audio files
print("\nProcessing preprocessed audio files...")
print(f"Base audio path: {AUDIO_BASE_PATH}")
print(f"Audio files should already be:")
print(f"  - 16kHz sample rate")
print(f"  - Silence removed") 
print(f"  - Quality filtered")

processed_samples = []
failed_samples = []

for i, sample in enumerate(labelstudio_data):
    audio_path = sample['audio_path']
    
    # Show first few attempts in detail
    if i < 3:
        print(f"\nSample {i+1}:")
        print(f"  Original path: {audio_path}")
        if not os.path.isabs(audio_path):
            full_path = os.path.join(AUDIO_BASE_PATH, audio_path)
            print(f"  Full path: {full_path}")
            print(f"  File exists: {os.path.exists(full_path)}")
        else:
            print(f"  File exists: {os.path.exists(audio_path)}")
    
    audio, sr, status = load_and_validate_audio(audio_path)
    
    if audio is not None:
        processed_samples.append({
            'audio': audio,
            'transcription': sample['transcription'],
            'audio_path': sample['audio_path'],
            'duration': len(audio) / sr
        })
        if i < 3:
            print(f"  Status: SUCCESS - Duration: {len(audio)/sr:.2f}s")
    else:
        failed_samples.append({'path': audio_path, 'reason': status})
        if i < 3:
            print(f"  Status: FAILED - {status}")
    
    if i % 20 == 0:
        print(f"Processed {i}/{len(labelstudio_data)} samples, {len(processed_samples)} valid")

print(f"\nFinal dataset: {len(processed_samples)} valid samples")
if failed_samples:
    print(f"Failed samples: {len(failed_samples)}")
    
    # Show common failure reasons
    print(f"\nFailure analysis:")
    failure_reasons = {}
    for fail in failed_samples:
        reason = fail['reason'].split(':')[0]  # Get main reason
        failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
    
    for reason, count in failure_reasons.items():
        print(f"  {reason}: {count} files")
else:
    print("All samples loaded successfully!")

# Quick stats on loaded audio
if processed_samples:
    durations = [s['duration'] for s in processed_samples]
    print(f"\nAudio duration stats:")
    print(f"  Mean: {np.mean(durations):.2f}s")
    print(f"  Min: {np.min(durations):.2f}s") 
    print(f"  Max: {np.max(durations):.2f}s")
    print(f"  Total: {np.sum(durations)/60:.1f} minutes")


Processing preprocessed audio files...
Base audio path: label-studio-export/project-11-at-2025-07-08-08-18-a315a9a2/
Audio files should already be:
  - 16kHz sample rate
  - Silence removed
  - Quality filtered

Sample 1:
  Original path: audio/chunk_3489d71581224dfa.wav
  Full path: label-studio-export/project-11-at-2025-07-08-08-18-a315a9a2/audio/chunk_3489d71581224dfa.wav
  File exists: True
  Status: SUCCESS - Duration: 12.28s
Processed 0/73 samples, 1 valid

Sample 2:
  Original path: audio/chunk_8e7e6b99099516b0.wav
  Full path: label-studio-export/project-11-at-2025-07-08-08-18-a315a9a2/audio/chunk_8e7e6b99099516b0.wav
  File exists: True
  Status: SUCCESS - Duration: 12.49s

Sample 3:
  Original path: audio/chunk_106196e04d812c87.wav
  Full path: label-studio-export/project-11-at-2025-07-08-08-18-a315a9a2/audio/chunk_106196e04d812c87.wav
  File exists: True
  Status: SUCCESS - Duration: 11.96s
Processed 20/73 samples, 21 valid
Processed 40/73 samples, 41 valid
Processed 60/73 

PREPARE DATASET FOR WHISPER TRAINING

In [4]:

# Initialize Whisper processor and model
model_name = "openai/whisper-small"  # Start with small for quick experiments
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# Move model to GPU if available
model = model.to(device)
print(f"Loaded {model_name}")

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator for speech-to-text models - Fixed version from working code
    """
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have to be of different lengths and need different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # Pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

def prepare_dataset(samples, processor, test_size=0.2):
    """Convert processed samples to Hugging Face dataset format"""
    
    def process_sample(sample):
        # Process audio to log-mel spectrogram
        input_features = processor.feature_extractor(
            sample['audio'], 
            sampling_rate=16000, 
            return_tensors="pt"
        ).input_features[0]
        
        # Tokenize transcription
        labels = processor.tokenizer(
            sample['transcription'], 
            return_tensors="pt", 
            padding=True, 
            truncation=True,
            max_length=448
        ).input_ids[0]
        
        return {
            "input_features": input_features,
            "labels": labels
        }
    
    # Process all samples
    processed = [process_sample(sample) for sample in samples]
    
    # Create dataset
    dataset = Dataset.from_list(processed)
    
    # Train/test split
    dataset = dataset.train_test_split(test_size=test_size, seed=42)
    
    return dataset

# Prepare dataset
print("\nPreparing dataset for training...")
dataset = prepare_dataset(processed_samples, processor)

print(f"Training samples: {len(dataset['train'])}")
print(f"Validation samples: {len(dataset['test'])}")

Loaded openai/whisper-small

Preparing dataset for training...
Training samples: 58
Validation samples: 15


EVALUATION METRICS

In [5]:
import evaluate
wer_metric = evaluate.load("wer")

def compute_metrics(eval_pred):
    """Compute WER metric during training"""
    pred_ids = eval_pred.predictions
    label_ids = eval_pred.label_ids

    # Replace -100 with pad token id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode predictions and labels
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Compute WER
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

TRAINING SETUP

In [None]:
# Data collator - Standard configuration
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# =============================================================================
# MODEL CONFIGURATION - STANDARD WHISPER SETUP
# =============================================================================

# Standard Whisper model configuration
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

print("Model configuration set for Whisper training")

# =============================================================================
# TRAINING ARGUMENTS - BASELINE CONFIGURATION
# =============================================================================

# Standard training arguments for RTX 4060 Ti (16GB VRAM)
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-finetuned-experiment",
    
    # Batch sizes optimized for 16GB VRAM
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,      # Effective batch size = 4*4 = 16
    
    # Learning configuration
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=1000,
    
    # Memory and performance settings
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
    fp16=True,
    
    # Evaluation settings
    eval_strategy="steps",
    eval_steps=100,
    predict_with_generate=True,
    generation_max_length=225,
    
    # Saving and logging
    save_steps=100,
    logging_steps=25,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    
    # Integration
    report_to=["mlflow"],
    dataloader_pin_memory=False,
)

print("Training arguments configured with standard settings")

# =============================================================================
# TRAINER INITIALIZATION 
# =============================================================================

print("Initializing trainer...")

# Initialize trainer with standard configuration
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor,
)

print("Trainer initialized successfully")

Model configuration set for Whisper training
Training arguments configured with proper gradient checkpointing
Initializing trainer...
Trainer initialized successfully with evaluation enabled


In [None]:
# =============================================================================
# BASELINE EVALUATION
# =============================================================================

print("\n" + "="*50)
print("BASELINE EVALUATION")
print("="*50)

def evaluate_baseline(test_samples, model, processor):
    """Evaluate baseline Whisper model"""
    predictions = []
    references = []
    
    model.eval()
    with torch.no_grad():
        for sample in test_samples[:10]:
            # Process audio
            input_features = processor(
                sample['audio'], 
                sampling_rate=16000, 
                return_tensors="pt"
            ).input_features.to(device)
            
            # Generate transcription
            predicted_ids = model.generate(input_features)[0]
            prediction = processor.decode(predicted_ids, skip_special_tokens=True)
            
            predictions.append(prediction)
            references.append(sample['transcription'])
    
    # Calculate WER
    baseline_wer = wer_metric.compute(predictions=predictions, references=references)
    
    print(f"Baseline WER: {baseline_wer:.4f}")
    
    # Show examples
    print("\nBaseline Examples:")
    for i in range(min(3, len(predictions))):
        print(f"\nExample {i+1}:")
        print(f"Reference: {references[i][:100]}...")
        print(f"Prediction: {predictions[i][:100]}...")
    
    return baseline_wer, predictions, references

# Evaluate baseline
baseline_wer, baseline_preds, baseline_refs = evaluate_baseline(
    processed_samples[-20:], model, processor
)

Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.



BASELINE EVALUATION


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Baseline WER: 2.8528

Baseline Examples:

Example 1:
Reference: Mwanangu kwa sababu Mungu alijaliwa alikuwa kashapona. Nikaona hayana haja, nikayaacha. Sasa nimekaa juzi natumiwa tena mtoto, tena sio mtoto wangu, w...
Prediction:  naka...

Example 2:
Reference: Ndoa za utotoni, ni pale ambapo binti anabeba mimba akiwa chini ya miaka kumi na nane....
Prediction:  Doza ututu ni pari ya npako binti ana beba mi ba kiyo, siri wa mi ya kakumina nane....

Example 3:
Reference: Uhhhh. Mmmmh. Badala ya kumpa maziwa ya wanyama....
Prediction:  Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa....


In [None]:
from transformers import TrainerCallback

# Clear any existing MLflow runs
try:
    mlflow.end_run()
except:
    pass

# MLflow Configuration
mlflow.set_tracking_uri("http://localhost:5000")  # Update to your MLflow server URI
mlflow.set_experiment("whisper-quick-experiment-2")

print(f"MLflow tracking URI: {mlflow.get_tracking_uri()}")
print(f"MLflow experiment: {mlflow.get_experiment_by_name('whisper-quick-experiment-2')}")

class MLflowCallback(TrainerCallback):
    """Custom callback for real-time MLflow logging"""
    
    def __init__(self, run_name="whisper-small-quick-test"):
        self.run_name = run_name
        self.mlflow_run = None
    
    def on_train_begin(self, args, state, control, **kwargs):
        """Start MLflow run at beginning of training"""
        # Check for active run and end it if necessary
        if mlflow.active_run():
            print(f"Active run detected: {mlflow.active_run().info.run_id}")
            mlflow.end_run()
        
        # Start new MLflow run
        self.mlflow_run = mlflow.start_run(run_name=self.run_name)
        
        # Log training configuration
        mlflow.log_param("model_name", "openai/whisper-small")
        mlflow.log_param("dataset_size", len(processed_samples))
        mlflow.log_param("train_size", len(dataset["train"]))
        mlflow.log_param("val_size", len(dataset["test"]))
        mlflow.log_param("learning_rate", args.learning_rate)
        mlflow.log_param("batch_size", args.per_device_train_batch_size)
        mlflow.log_param("effective_batch_size", args.per_device_train_batch_size * args.gradient_accumulation_steps)
        mlflow.log_param("max_steps", args.max_steps)
        mlflow.log_param("fp16", args.fp16)
        mlflow.log_param("gradient_checkpointing", args.gradient_checkpointing)
        mlflow.log_param("eval_steps", args.eval_steps)
        mlflow.log_param("warmup_steps", args.warmup_steps)
        
        print("MLflow logging started!")
    
    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics in real-time during training"""
        if logs and self.mlflow_run:
            step = state.global_step
            
            # Log all metrics from the logs
            for key, value in logs.items():
                if isinstance(value, (int, float)):
                    mlflow.log_metric(key, value, step=step)
            
            # Log epoch
            mlflow.log_metric("epoch", state.epoch, step=step)
            
            # Print key metrics for monitoring
            if "train_loss" in logs:
                print(f"Step {step}: Loss = {logs['train_loss']:.4f}")
            if "eval_wer" in logs:
                print(f"Step {step}: WER = {logs['eval_wer']:.2f}%")
    
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log evaluation metrics"""
        if metrics and self.mlflow_run:
            step = state.global_step
            
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    mlflow.log_metric(key, value, step=step)
            
            print(f"Evaluation at step {step}: {metrics}")
    
    def on_train_end(self, args, state, control, **kwargs):
        """Log final training completion"""
        if self.mlflow_run:
            print("Training completed! MLflow run ready for final evaluation.")

# Initialize MLflow callback
mlflow_callback = MLflowCallback()

# Add callback to trainer
trainer.add_callback(mlflow_callback)

# =============================================================================
# BASELINE EVALUATION
# =============================================================================

print("\n" + "="*50)
print("BASELINE EVALUATION")
print("="*50)

def evaluate_baseline(test_samples, model, processor):
    """Evaluate baseline Whisper model - simplified approach"""
    predictions = []
    references = []
    
    model.eval()
    with torch.no_grad():
        for sample in test_samples[:10]:  # Evaluate on first 10 samples for speed
            # Process audio - simplified approach like reference code
            input_features = processor(
                sample['audio'], 
                sampling_rate=16000, 
                return_tensors="pt"
            ).input_features.to(device)
            
            # Generate transcription - simplified like reference code
            predicted_ids = model.generate(input_features)[0]
            prediction = processor.decode(predicted_ids, skip_special_tokens=True)
            
            predictions.append(prediction)
            references.append(sample['transcription'])
    
    # Calculate WER
    baseline_wer = wer_metric.compute(predictions=predictions, references=references)
    
    print(f"Baseline WER: {baseline_wer:.4f}")
    
    # Show some examples
    print("\nBaseline Examples:")
    for i in range(min(3, len(predictions))):
        print(f"\nExample {i+1}:")
        print(f"Reference: {references[i][:150]}...")  # Truncate for readability
        print(f"Prediction: {predictions[i][:150]}...")
    
    return baseline_wer, predictions, references

# Evaluate baseline
baseline_wer, baseline_preds, baseline_refs = evaluate_baseline(
    processed_samples[-20:], model, processor
)

# =============================================================================
# TRAINING EXECUTION WITH MLflow TRACKING
# =============================================================================

print("\n" + "="*50)
print("STARTING TRAINING")
print("="*50)

# Training will automatically start MLflow run and log metrics
trainer.train()

# =============================================================================
# FINAL EVALUATION AND MODEL SAVING
# =============================================================================

print("\n" + "="*50)
print("FINAL EVALUATION")
print("="*50)

final_eval = trainer.evaluate()
final_wer = final_eval['eval_wer']

print(f"Final WER: {final_wer:.2f}%")
print(f"Baseline WER: {baseline_wer*100:.2f}%")
improvement = ((baseline_wer*100 - final_wer) / (baseline_wer*100)) * 100
print(f"WER Improvement: {improvement:.2f}%")

# Log final metrics to MLflow (ensure we're in the run context)
if mlflow.active_run():
    mlflow.log_metric("final_wer", final_wer)
    mlflow.log_metric("baseline_wer", baseline_wer * 100)  # Convert to percentage
    mlflow.log_metric("wer_improvement_percent", improvement)

    # Save model
    trainer.save_model()

    # Log model to MLflow
    try:
        # Method 1: Use trainer's unwrapped model
        unwrapped_model = trainer.accelerator.unwrap_model(trainer.model)
        
        mlflow.pytorch.log_model(
            pytorch_model=unwrapped_model,
            artifact_path="model",
            registered_model_name="whisper-finetuned-quick"
        )
        print("Model saved to MLflow!")
        
    except Exception as e:
        print(f"MLflow model logging failed: {e}")
        print("Model was saved locally with trainer.save_model()")
        
        # Alternative: Log just the model state dict
        try:
            model_state = {
                'model_state_dict': trainer.model.state_dict(),
                'model_config': trainer.model.config,
            }
            mlflow.pytorch.log_state_dict(model_state, artifact_path="model_state")
            print("Model state dict saved to MLflow as fallback")
        except:
            print("All MLflow model logging attempts failed - model saved locally only")

# End MLflow run safely
try:
    mlflow.end_run()
except:
    pass

print("\n" + "="*60)
print("EXPERIMENT COMPLETE!")
print("="*60)
print(f"Baseline WER: {baseline_wer*100:.2f}%")
print(f"Fine-tuned WER: {final_wer:.2f}%")
print(f"Improvement: {improvement:.2f}%")
print("\nView results in MLflow UI:")
print("Run: mlflow ui")
print("Open: http://localhost:5000")

MLflow tracking URI: http://localhost:5000
MLflow experiment: <Experiment: artifact_location='/mlflow/artifacts/8', creation_time=1752065869356, experiment_id='8', last_update_time=1752065869356, lifecycle_stage='active', name='whisper-quick-experiment-2', tags={}>

BASELINE EVALUATION
Baseline WER: 2.8528

Baseline Examples:

Example 1:
Reference: Mwanangu kwa sababu Mungu alijaliwa alikuwa kashapona. Nikaona hayana haja, nikayaacha. Sasa nimekaa juzi natumiwa tena mtoto, tena sio mtoto wangu, w...
Prediction:  naka...

Example 2:
Reference: Ndoa za utotoni, ni pale ambapo binti anabeba mimba akiwa chini ya miaka kumi na nane....
Prediction:  Doza ututu ni pari ya npako binti ana beba mi ba kiyo, siri wa mi ya kakumina nane....

Example 3:
Reference: Uhhhh. Mmmmh. Badala ya kumpa maziwa ya wanyama....
Prediction:  Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa. Mwaa....

STARTING TRAINING
A

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss,Wer
100,2.1201,2.860643,1.668269
200,0.196,2.754059,1.769231
300,0.0028,3.130589,1.360577
400,0.0009,3.240578,1.728365
500,0.0005,3.312905,1.733173
600,0.0003,3.359595,1.733173
700,0.0002,3.387628,1.348558
800,0.0002,3.404866,1.334135
900,0.0002,3.414341,1.346154
1000,0.0002,3.418957,1.334135


Step 100: WER = 1.67%
Evaluation at step 100: {'eval_loss': 2.860643148422241, 'eval_wer': 1.6682692307692308}




Step 200: WER = 1.77%
Evaluation at step 200: {'eval_loss': 2.754058837890625, 'eval_wer': 1.7692307692307692}
Step 300: WER = 1.36%
Evaluation at step 300: {'eval_loss': 3.1305885314941406, 'eval_wer': 1.3605769230769231}
Step 400: WER = 1.73%
Evaluation at step 400: {'eval_loss': 3.2405784130096436, 'eval_wer': 1.7283653846153846}
Step 500: WER = 1.73%
Evaluation at step 500: {'eval_loss': 3.3129053115844727, 'eval_wer': 1.7331730769230769}
Step 600: WER = 1.73%
Evaluation at step 600: {'eval_loss': 3.3595945835113525, 'eval_wer': 1.7331730769230769}
Step 700: WER = 1.35%
Evaluation at step 700: {'eval_loss': 3.387627601623535, 'eval_wer': 1.3485576923076923}
Step 800: WER = 1.33%
Evaluation at step 800: {'eval_loss': 3.4048657417297363, 'eval_wer': 1.3341346153846154}
Step 900: WER = 1.35%
Evaluation at step 900: {'eval_loss': 3.4143412113189697, 'eval_wer': 1.3461538461538463}
Step 1000: WER = 1.33%
Evaluation at step 1000: {'eval_loss': 3.418957233428955, 'eval_wer': 1.33413461538

There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


Step 1000: Loss = 0.4374
🏃 View run whisper-small-quick-test at: http://localhost:5000/#/experiments/8/runs/1ccc5f4f71974256a644dd0fe7e8bdff
🧪 View experiment at: http://localhost:5000/#/experiments/8
Training completed! MLflow run ready for final evaluation.

FINAL EVALUATION


Step 1000: WER = 1.33%
Evaluation at step 1000: {'eval_loss': 3.4048657417297363, 'eval_wer': 1.3341346153846154, 'eval_runtime': 20.7909, 'eval_samples_per_second': 0.721, 'eval_steps_per_second': 0.721, 'epoch': 250.0}
Final WER: 1.33%
Baseline WER: 285.28%
WER Improvement: 99.53%




PicklingError: Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it.

In [9]:
print("\n" + "="*50)
print("TESTING FINE-TUNED MODEL")
print("="*50)

def test_finetuned_model(test_samples, trainer, processor, num_samples=5):
    """Test the fine-tuned model on new samples with improved generation"""
    model = trainer.model
    model.eval()
    
    predictions = []
    references = []
    
    with torch.no_grad():
        for i, sample in enumerate(test_samples[-num_samples:]):  # Test on last few samples
            # Process audio
            input_features = processor.feature_extractor(
                sample['audio'], 
                sampling_rate=16000, 
                return_tensors="pt"
            ).input_features.to(device)
            
            # Generate transcription with improved parameters
            predicted_ids = model.generate(
                input_features,
                max_length=448,  # Increased max length
                num_beams=5,     # Beam search for better quality
                early_stopping=True,
                do_sample=False,  # Deterministic generation
                pad_token_id=processor.tokenizer.eos_token_id,
                forced_decoder_ids=processor.get_decoder_prompt_ids(language="sw", task="transcribe")
            )
            prediction = processor.tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
            
            predictions.append(prediction)
            references.append(sample['transcription'])
            
            print(f"\nTest Sample {i+1}:")
            print(f"Reference:  {sample['transcription'][:150]}...")  # Truncate for readability
            print(f"Prediction: {prediction[:150]}...")
    
    # Calculate final WER on test samples
    test_wer = wer_metric.compute(predictions=predictions, references=references)
    print(f"\nTest WER on {num_samples} samples: {test_wer:.4f}")
    
    return test_wer

# Test fine-tuned model
test_wer = test_finetuned_model(processed_samples, trainer, processor)

print("\n" + "="*60)
print("NEXT STEPS:")
print("="*60)
print("1. If results look good, increase max_steps for longer training")
print("2. Experiment with different learning rates and batch sizes")
print("3. Try larger Whisper models (medium, large)")
print("4. Implement the full MLOps structure we discussed")
print("5. Add more sophisticated data preprocessing")
print("6. Implement cross-validation and more robust evaluation")


TESTING FINE-TUNED MODEL

Test Sample 1:
Reference:  Ni maeneo gani? Je, anaishi wapi huyo mtu? Ni maeneo , ni moja wa mwanafamilia, kama ni mwanafamilia, je ana mahusiana yapi, ni baba, ni mama?...
Prediction:  Anahishi wapi huyo mtu. Ni maeneo, ni moja wa mwanafamilia, kama ni mwanafamilia je ana mahusiana yapi ni baba, ni mama...

Test Sample 2:
Reference:  Democracy. Katika democracy tukasoma moja, mbili, tatu. Sasa unaeza kumbuka, aaaaha, mwalimu hapa alivyoelezea, akatoa mfano hivi. Kuwa unaeza kukumbu...
Prediction:  demokrasi, katika demokrasi utukasuma moja mbilitato kwafana kukumboka hapa mwalimu ya walezea, akatowa mfano hivi, kwao kukumboka piawa kupitia mfan...

Test Sample 3:
Reference:  Ambaye mtoto huyo ataelekezwa kuenda kukaa sawa. Eeeeh? Eeeeh. Mtoto ana haki ya kulindwa dhidi ya ukatili wa aina yoyote ile. Iwe ukatili wa kimwili,...
Prediction:  kwa mbaye mtotowe ataelekeza kuenda kukaa sawa. Mtoto ana haki ya kulindwa dhidi ya ukatili wa aina yoyote ile. Iwe ukati