In [None]:
!pip install -qU bitsandbytes peft evaluate jiwer accelerate soundfile librosa

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset, Audio
from collections import defaultdict
import torchaudio
from tqdm import tqdm
import torch
# Load dataset
print("Loading dataset...")
dataset = load_dataset("sartifyllc/Sartify_ITU_Zindi_Testdataset", split="test")
# NOTE(review): Audio() is cast without an explicit `mono` setting; in datasets 2.x/3.x the
# default (mono=True) downmixes every file, so the channel counts below may always read
# mono — confirm against the installed datasets version before trusting them.
dataset = dataset.cast_column("audio", Audio())

# Initialize statistics collectors
stats = {
    'sample_rates': defaultdict(int),   # sampling rate (Hz) -> number of files
    'channels': defaultdict(int),       # channel count -> number of files
    'durations': [],                    # per-file duration in seconds
    'max_amplitudes': [],               # per-file peak |sample|
    'mean_amplitudes': []               # per-file mean |sample|
}

# Process files to collect statistics
print("\nAnalyzing audio files...")
for item in tqdm(dataset, desc="Processing audio files"):
    audio = item['audio']
    samples = np.asarray(audio['array'])

    # Basic metadata. Multi-channel arrays are assumed to be laid out as
    # (channels, samples) — TODO confirm for the installed datasets version.
    stats['sample_rates'][audio['sampling_rate']] += 1
    channels = samples.shape[0] if samples.ndim > 1 else 1
    stats['channels'][channels] += 1

    # Duration counts samples along the last axis; len() would return the
    # channel count for a 2-D (channels, samples) array.
    duration = samples.shape[-1] / audio['sampling_rate']
    stats['durations'].append(duration)

    # Amplitude analysis. torch.max raises on an empty tensor, so empty clips
    # are recorded as silent instead of aborting the whole scan.
    if samples.size == 0:
        stats['max_amplitudes'].append(0.0)
        stats['mean_amplitudes'].append(0.0)
        continue
    waveform = torch.from_numpy(samples).float()
    stats['max_amplitudes'].append(torch.max(torch.abs(waveform)).item())
    stats['mean_amplitudes'].append(torch.mean(torch.abs(waveform)).item())

# Tabulate the per-file measurements so pandas can summarize them.
df_stats = pd.DataFrame(
    {
        'duration': stats['durations'],
        'max_amplitude': stats['max_amplitudes'],
        'mean_amplitude': stats['mean_amplitudes'],
    }
)

# --- Console report ---
print("\n=== Audio Dataset Statistics ===")
print(f"\nTotal files: {len(dataset)}")

print("\nSample Rates Distribution:")
for rate, count in sorted(stats['sample_rates'].items()):
    share = count / len(dataset) * 100
    print(f"- {rate} Hz: {count} files ({share:.1f}%)")

print("\nChannels Distribution:")
for channels, count in sorted(stats['channels'].items()):
    # Any layout with more than one channel is reported as 'Stereo'.
    layout = 'Mono' if channels == 1 else 'Stereo'
    share = count / len(dataset) * 100
    print(f"- {layout}: {count} files ({share:.1f}%)")

print("\nDuration Statistics (seconds):")
print(df_stats['duration'].describe())

print("\nAmplitude Statistics:")
for label, key in (("Max amplitude", 'max_amplitudes'), ("Mean amplitude", 'mean_amplitudes')):
    values = stats[key]
    print(f"- {label}: mean={np.mean(values):.4f}, std={np.std(values):.4f}")

# Plot distributions on a 2x2 grid: three histograms plus a scatter plot.
plt.figure(figsize=(15, 10))

# (subplot position, DataFrame column, title, x-axis label)
histograms = [
    (1, 'duration', 'Duration Distribution (seconds)', 'Duration'),
    (2, 'max_amplitude', 'Max Amplitude Distribution', 'Max Amplitude'),
    (3, 'mean_amplitude', 'Mean Amplitude Distribution', 'Mean Amplitude'),
]
for position, column, title, xlabel in histograms:
    plt.subplot(2, 2, position)
    plt.hist(df_stats[column], bins=50)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Count')

# Fourth panel: does clip length correlate with peak level?
plt.subplot(2, 2, 4)
plt.scatter(df_stats['duration'], df_stats['max_amplitude'], alpha=0.3)
plt.title('Duration vs Max Amplitude')
plt.xlabel('Duration (s)')
plt.ylabel('Max Amplitude')

plt.tight_layout()
plt.show()

# Flag clips that may be problematic: very low peak level or very short.
print("\n=== Potential Audio Quality Issues ===")
quiet_threshold = 0.1
short_threshold = 1.0  # seconds

quiet_files = df_stats.loc[df_stats['max_amplitude'] < quiet_threshold]
short_files = df_stats.loc[df_stats['duration'] < short_threshold]

print(f"\nFiles with max amplitude < {quiet_threshold} (possibly too quiet): {len(quiet_files)}")
print(f"Files with duration < {short_threshold} seconds: {len(short_files)}")

# Describe each flagged subset, skipping the ones with no rows.
for heading, subset in (
    ("\nSample quiet files statistics:", quiet_files),
    ("\nSample short files statistics:", short_files),
):
    if not subset.empty:
        print(heading)
        print(subset.describe())

In [None]:
from datasets import load_dataset, Audio
import matplotlib.pyplot as plt
import numpy as np

# Reload the test split with the audio column decoded on access.
# NOTE(review): Audio() is cast with default settings; in datasets 2.x/3.x that default
# downmixes to mono (mono=True), which would make every file report as Mono below — confirm.
dataset = (
    load_dataset("sartifyllc/Sartify_ITU_Zindi_Testdataset", split="test")
    .cast_column("audio", Audio())
)

def inspect_channels(sample_index=0, plot_waveform=True):
    """Print channel/duration metadata for one sample and optionally plot it.

    Reads the module-level ``dataset``. The row is fetched once, because each
    ``dataset[i]`` access decodes that row's audio again.

    Multi-channel arrays are assumed to be laid out as (channels, samples)
    (TODO confirm for the installed datasets version). Every channel is
    plotted: the first two are labelled Left/Right, and any further channels
    are labelled by number.

    Args:
        sample_index: row of ``dataset`` to inspect.
        plot_waveform: when True, draw the waveform with matplotlib.
    """
    row = dataset[sample_index]
    audio = row['audio']
    waveform = audio['array']
    sample_rate = audio['sampling_rate']
    filename = row['filename']

    # Channel info; the sample axis is the last one for both 1-D and 2-D arrays.
    num_channels = waveform.shape[0] if waveform.ndim > 1 else 1
    duration = waveform.shape[-1] / sample_rate

    print(f"\nFile: {filename}")
    print(f"Sample Rate: {sample_rate}Hz")
    print(f"Channels: {'Stereo' if num_channels > 1 else 'Mono'}")
    print(f"Duration: {duration:.2f} seconds")
    print(f"Waveform shape: {waveform.shape}")

    if plot_waveform:
        plt.figure(figsize=(12, 4))

        if num_channels > 1:
            named = ['Left Channel', 'Right Channel']
            for ch in range(num_channels):
                label = named[ch] if ch < len(named) else f'Channel {ch + 1}'
                plt.plot(waveform[ch], label=label, alpha=0.7)
            plt.legend()
        else:
            plt.plot(waveform)

        plt.title(f"Audio Waveform\n{filename}")
        plt.xlabel("Samples")
        plt.ylabel("Amplitude")
        plt.show()

# Inspect the first few files (at most 5; fewer if the split is smaller)
for i in range(min(5, len(dataset))):
    inspect_channels(i)

# Get channel statistics for entire dataset
print("\nCalculating channel statistics for entire dataset...")
channel_counts = {'Mono': 0, 'Stereo': 0}

for item in dataset:
    waveform = item['audio']['array']
    # Anything with more than one channel is counted as 'Stereo'.
    channels = waveform.shape[0] if waveform.ndim > 1 else 1
    channel_counts['Mono' if channels == 1 else 'Stereo'] += 1

print("\nChannel Distribution:")
total_files = len(dataset)
for channel_type, count in channel_counts.items():
    # Guard the percentage against an empty split (division by zero).
    pct = count / total_files * 100 if total_files else 0.0
    print(f"{channel_type}: {count} files ({pct:.1f}%)")

In [None]:
from huggingface_hub import login
from tqdm import tqdm
import time
import os

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"

# Hugging Face login. The token is read from the HF_TOKEN environment variable
# instead of being hard-coded in the notebook (a secret in source, and an empty
# placeholder that login() cannot use). Without a token the login is skipped
# and later downloads proceed without authentication.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("[WARN] HF_TOKEN is not set; skipping Hugging Face login.")

In [None]:
import os
import random
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import numpy as np
import torch
import pandas as pd
import datasets
from datasets import load_dataset, Audio, Dataset
import evaluate
import librosa
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    TrainerCallback,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from tqdm import tqdm

# ==============================================
# 1. CONFIGURATION
# ==============================================
# --- Output locations (directory is created up front) ---
OUTPUT_DIR = "whisper_turbo_finetuned"
LOG_CSV, EVAL_CSV = (
    os.path.join(OUTPUT_DIR, filename) for filename in ("training_log.csv", "eval_log.csv")
)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Base checkpoint to fine-tune ---
model_name_or_path = "Abdoul27/whisper-turbo-v3-model"
task = "transcribe"

# --- Training corpora: the two Swahili subsets of Sunbird/salt (train split) ---
SWAHILI_DATASETS = [
    {"path": "Sunbird/salt", "config": subset, "split": "train"}
    for subset in ("multispeaker-swa", "studio-swa")
]

# --- Background-noise corpus used for on-the-fly augmentation ---
NOISE_DATASET_SPEC = dict(path="Sunbird/urban-noise-uganda-61k", config="small", split="train")

# --- Language / audio settings ---
language, language_abbr = "swahili", "sw"
TARGET_SAMPLING_RATE = 16_000

# --- Augmentation hyperparameters ---
MAX_REL_NOISE_AMP = 0.5  # upper bound of the noise scale, as a fraction of the clip's peak
P_NOISE_AUG = 0.5        # probability that a training batch gets noise mixed in

# ==============================================
# 2. LOAD MODELS & PROCESSORS
# ==============================================
from transformers import BitsAndBytesConfig

print("Loading model & processor:", model_name_or_path)
processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)

# Load the base model in 8-bit. Passing a bare `load_in_8bit=True` to
# from_pretrained is deprecated; an explicit BitsAndBytesConfig is the
# supported form.
model = WhisperForConditionalGeneration.from_pretrained(
    model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Drop the legacy forced decoder prompt. Language and task are pinned on the
# generation config instead, so evaluation-time generation
# (predict_with_generate) transcribes Swahili rather than auto-detecting
# the language.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.language = language
model.generation_config.task = task
model.generation_config.forced_decoder_ids = None

# Prepare the 8-bit model for training (PEFT helper for k-bit quantized models).
model = prepare_model_for_kbit_training(model)

# LoRA adapters on the attention query/value projections only.
config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none"
)

# Wrap the base model with the LoRA adapters and report the trainable share.
model = get_peft_model(model, config)
model.print_trainable_parameters()

# ==============================================
# 3. LOAD & PREPARE DATASETS
# ==============================================
def safe_load_dataset(spec):
    """Load the dataset described by ``spec``; return None instead of raising.

    ``spec`` is a dict with a ``"path"`` key and optional ``"config"`` and
    ``"split"`` keys. The split defaults to ``"train"``, and a falsy config is
    treated as "no config". Any exception raised while loading is reported as
    a ``[WARN]`` line and turned into a ``None`` result.
    """
    try:
        positional = [spec["path"]]
        if spec.get("config"):
            positional.append(spec["config"])
        loaded = load_dataset(*positional, split=spec.get("split", "train"))
    except Exception as err:
        print(f"[WARN] Could not load dataset {spec!r}: {err}")
        return None
    print(f"[OK] Loaded dataset {spec['path']} config={spec.get('config')} split={spec.get('split')}")
    return loaded

def normalize_dataset(ds, keep_text_name="text"):
    """Reduce ``ds`` to an ``audio`` column plus a single transcript column.

    1. Audio: if there is no ``audio`` column, rename the first column whose
       feature is ``datasets.Audio``. Failing that, rename a ``source`` column.
    2. Transcript: keep ``keep_text_name`` if it already exists. Otherwise
       rename the first of ``transcription``/``text``/``sentence``/
       ``transcript``/``target`` that is present, or add an all-empty column.
    3. Drop every other column and cast ``audio`` to TARGET_SAMPLING_RATE.

    Args:
        ds: the dataset to normalize (a ``datasets.Dataset``).
        keep_text_name: name the transcript column should have afterwards.

    Returns:
        The normalized dataset.

    Raises:
        ValueError: if no audio column can be identified.
    """
    if "audio" not in ds.column_names:
        audio_col = next(
            (c for c in ds.column_names if isinstance(ds.features.get(c), datasets.features.Audio)),
            None,
        )
        if audio_col is None and "source" in ds.column_names:
            audio_col = "source"
        if audio_col is not None:
            ds = ds.rename_column(audio_col, "audio")
    if "audio" not in ds.column_names:
        raise ValueError(f"No audio column found; available columns: {ds.column_names}")

    # Prefer an existing `keep_text_name` column. Renaming another candidate
    # onto a name that already exists would make rename_column raise.
    if keep_text_name not in ds.column_names:
        text_col = next(
            (c for c in ["transcription", "text", "sentence", "transcript", "target"] if c in ds.column_names),
            None,
        )
        if text_col is None:
            ds = ds.add_column(keep_text_name, [""] * len(ds))
        else:
            ds = ds.rename_column(text_col, keep_text_name)

    ds = ds.remove_columns([c for c in ds.column_names if c not in ("audio", keep_text_name)])
    return ds.cast_column("audio", Audio(sampling_rate=TARGET_SAMPLING_RATE))

print("\nLoading and preparing datasets...")
train_parts = []
val_parts = []  # unused below: validation data comes from a split of the training data
for spec in SWAHILI_DATASETS:
    ds = safe_load_dataset(spec)
    if not ds:
        continue  # load failed (None) or the dataset has no rows
    try:
        ds = normalize_dataset(ds)
    except Exception as e:
        print(f"[WARN] Failed to normalize dataset {spec}: {e}")
    else:
        train_parts.append(ds)

if not train_parts:
    raise RuntimeError("No Swahili training datasets loaded.")

# Hold out 2% of the combined corpus (fixed seed) as the validation set.
split = datasets.concatenate_datasets(train_parts, axis=0).train_test_split(test_size=0.02, seed=42)
train_ds, val_ds = split["train"], split["test"]

print(f"Combined dataset: Train examples: {len(train_ds)}, Val examples: {len(val_ds)}")

# Load noise dataset. This is best effort: on failure the pool stays empty, and
# the collator skips augmentation when its noise pool is empty.
noise_ds = safe_load_dataset(NOISE_DATASET_SPEC)
noise_audio_arrays = []
if noise_ds is not None:
    try:
        noise_ds = normalize_dataset(noise_ds, keep_text_name="text")
        # Cap the pool at 2000 clips; every selected clip is decoded into memory here.
        noise_audio_arrays = [
            np.asarray(ex["audio"]["array"], dtype=np.float32)
            for ex in noise_ds.select(range(min(len(noise_ds), 2000)))
        ]
        # Zero-length clips carry no noise and cannot be cropped or tiled to a
        # target length, so keep only non-empty ones.
        noise_audio_arrays = [clip for clip in noise_audio_arrays if clip.size > 0]
        print(f"Loaded {len(noise_audio_arrays)} noise clips for augmentation.")
    except Exception as e:
        print(f"[WARN] Failed processing noise dataset: {e}")
        noise_audio_arrays = []

# ==============================================
# 4. DEFINE DATA COLLATOR AND METRICS
# ==============================================
def add_noise(x: np.ndarray, noise_pool: List[np.ndarray], max_rel_amp=MAX_REL_NOISE_AMP):
    """Return waveform ``x`` with a random background-noise segment mixed in.

    A clip is drawn uniformly from ``noise_pool`` and brought to ``len(x)``
    samples. If it is long enough, a random window is taken; otherwise the
    clip is repeated end-to-end and truncated.

    The segment is multiplied by ``rel * peak(x)``, where ``rel`` is drawn
    uniformly from ``[0, max_rel_amp)``. It is then added to ``x``, and the
    sum is clipped to [-1, 1] and returned as float32. The scale is relative
    to the signal peak, but the noise clip itself is not normalized, so the
    effective noise level also depends on the clip's own amplitude.

    Args:
        x: waveform to augment (assumed 1-D/mono — TODO confirm for multi-channel input).
        noise_pool: candidate noise waveforms.
        max_rel_amp: upper bound for ``rel``.

    Returns:
        The augmented float32 waveform. ``x`` itself is returned unchanged when
        the pool is empty, ``x`` is empty, or the selected clip is empty.
    """
    if not noise_pool:
        return x
    noise_clip = random.choice(noise_pool)
    L = len(x)
    # np.max below raises on an empty signal, and an empty clip cannot be tiled
    # (division by zero), so leave those inputs untouched.
    if L == 0 or len(noise_clip) == 0:
        return x
    if len(noise_clip) >= L:
        start = random.randint(0, len(noise_clip) - L)
        noise_seg = noise_clip[start:start + L]
    else:
        repeats = math.ceil(L / len(noise_clip))
        noise_seg = np.tile(noise_clip, repeats)[:L]
    audio_amp = np.max(np.abs(x)) + 1e-9  # epsilon keeps silent clips from scaling by exactly 0
    rel = random.random() * max_rel_amp
    noise_scaled = noise_seg.astype(np.float32) * (rel * audio_amp)
    y = x + noise_scaled
    return np.clip(y, -1.0, 1.0).astype(np.float32)

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate raw audio/transcript rows into Whisper training batches.

    Each feature must provide ``feature["audio"]["array"]``. The transcript is
    the first non-empty value among ``"text"``, ``"transcription"`` and
    ``"sentence"``, or ``""`` if none is set. With probability ``add_noise_p``
    the whole batch is augmented with clips from ``noise_pool`` via
    ``add_noise``.

    Attributes are documented inline below.
    """

    processor: Any  # exposes `.feature_extractor` and `.tokenizer` (a WhisperProcessor here)
    noise_pool: List[np.ndarray] = None  # background-noise waveforms; None/empty disables augmentation
    add_noise_p: float = 0.5  # per-batch probability of mixing in noise
    # Token the model prepends when it shifts labels into decoder inputs.
    # None -> look up "<|startoftranscript|>" in the tokenizer.
    decoder_start_token_id: Union[int, None] = None

    @staticmethod
    def _strip_leading_token(labels: torch.Tensor, token_id: int) -> torch.Tensor:
        """Drop the first label column when every row begins with ``token_id``."""
        if labels.dim() == 2 and labels.size(1) > 0 and bool((labels[:, 0] == token_id).all()):
            return labels[:, 1:]
        return labels

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Extract audio arrays and texts
        audio_arrays = [feature["audio"]["array"] for feature in features]
        texts = [feature.get("text") or feature.get("transcription") or feature.get("sentence") or "" for feature in features]

        # One coin flip per batch: either every clip gets noise or none does.
        if self.noise_pool and random.random() < self.add_noise_p:
            audio_arrays = [add_noise(arr, self.noise_pool) for arr in audio_arrays]

        # Model input features computed by the processor's feature extractor.
        input_features = self.processor.feature_extractor(
            audio_arrays, sampling_rate=TARGET_SAMPLING_RATE, return_tensors="pt"
        ).input_features

        # Tokenize and pad labels; padded positions become -100 so the loss ignores them.
        tokenizer_output = self.processor.tokenizer(texts, padding=True, return_tensors="pt")
        labels = tokenizer_output["input_ids"].masked_fill(tokenizer_output["attention_mask"].ne(1), -100)

        # When the tokenizer's prefix starts with the decoder start token, the
        # model would see it twice: it prepends that token again when shifting
        # labels into decoder inputs. Drop the copy from the labels.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        labels = self._strip_leading_token(labels, start_id)

        return {"input_features": input_features, "labels": labels}

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    noise_pool=noise_audio_arrays,
    add_noise_p=P_NOISE_AUG,
)

# Word- and character-error-rate scorers from `evaluate`, used by compute_metrics.
wer_metric, cer_metric = evaluate.load("wer"), evaluate.load("cer")

def compute_metrics(eval_pred):
    """Compute corpus WER and CER (both in percent) for a Seq2SeqTrainer evaluation.

    Args:
        eval_pred: the trainer's EvalPrediction. ``predictions`` holds generated
            token ids (possibly wrapped in a tuple); ``label_ids`` holds the
            reference ids, with ``-100`` marking padding.

    Returns:
        ``{"wer": ..., "cer": ...}``. Text is decoded with the module-level
        ``tokenizer``, lower-cased and stripped before scoring. Pairs with an
        empty reference are skipped; if none remain, both values are NaN.
    """
    pred_ids = eval_pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    # Map the -100 ignore index to the pad token before decoding, since
    # negative ids may fail to decode. Predictions may also carry -100 padding
    # when batches of different lengths are concatenated. np.where returns new
    # arrays, so the trainer's buffers are not mutated.
    pred_ids = np.where(pred_ids == -100, tokenizer.pad_token_id, pred_ids)
    label_ids = np.where(eval_pred.label_ids == -100, tokenizer.pad_token_id, eval_pred.label_ids)

    pred_str = [p.lower().strip() for p in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)]
    label_str = [ref.lower().strip() for ref in tokenizer.batch_decode(label_ids, skip_special_tokens=True)]

    # jiwer rejects empty reference strings, and rows without a transcript are
    # filled with "" by normalize_dataset, so score only non-empty references.
    pairs = [(p, ref) for p, ref in zip(pred_str, label_str) if ref]
    if not pairs:
        return {"wer": float("nan"), "cer": float("nan")}
    preds, refs = (list(column) for column in zip(*pairs))

    wer = 100 * wer_metric.compute(predictions=preds, references=refs)
    cer = 100 * cer_metric.compute(predictions=preds, references=refs)

    return {"wer": wer, "cer": cer}

# ==============================================
# 5. DEFINE TRAINING ARGUMENTS
# ==============================================
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    # Effective per-device train batch: 4 samples x 2 accumulation steps.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    # NOTE(review): 1e-5 is a full-fine-tuning learning rate; LoRA adapters are
    # usually trained with a much larger one — confirm this is intentional.
    learning_rate=1e-5,
    warmup_steps=500,
    num_train_epochs=3,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    # Evaluate with generate() so compute_metrics receives generated token ids.
    predict_with_generate=True,
    generation_max_length=225,
    # With load_best_model_at_end, save_steps must be a multiple of eval_steps.
    save_steps=400,
    eval_steps=200,
    logging_steps=200,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    # The collator reads the raw "audio"/"text" columns, so they must not be pruned.
    remove_unused_columns=False,
    # PeftModel hides the base model's forward() signature, so Trainer cannot
    # infer the label column on its own (it falls back to an empty list).
    label_names=["labels"],
)

# ==============================================
# 6. INITIALIZE AND START TRAINING
# ==============================================
import inspect

# Newer transformers releases renamed Trainer's `tokenizer` argument to
# `processing_class` and deprecated the old name. Pass whichever keyword the
# installed version accepts.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Seq2SeqTrainer.__init__).parameters
    else "tokenizer"
)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: processor.tokenizer},
)

# Save the processor next to the checkpoints so the output dir is self-contained.
processor.save_pretrained(training_args.output_dir)

print("\nStarting PEFT training...")
trainer.train()

print("\nTraining complete.")
# NOTE(review): `model` is a PEFT wrapper, so this presumably writes only the
# LoRA adapter (plus its config) — confirm the inference step can load it.
trainer.save_model(training_args.output_dir)
print("Artifacts saved to:", training_args.output_dir)

In [None]:
import torch
import pandas as pd
from datasets import load_dataset, Audio
from transformers import pipeline, WhisperProcessor
from transformers.utils import is_flash_attn_2_available
from tqdm import tqdm
import warnings
import os

# Silence every warning category so the transcription log stays readable.
warnings.filterwarnings("ignore")

# --- 1. Configuration ---
MODEL_ID = "./whisper_turbo_finetuned"  # local output directory of the fine-tuning run
OUTPUT_FILE = "submission_turbo_finetuned_beam_search.csv"
SAMPLE_RATE = 16000

print("Loading the newly finetuned Swahili ASR model...")
print(f"Model: {MODEL_ID}")
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16  # half precision on GPU
else:
    device, torch_dtype = "cpu", torch.float32
print(f"Using device: {device}")

# Legacy decoder-prompt ids that pin decoding to Swahili transcription.
processor = WhisperProcessor.from_pretrained(MODEL_ID)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="swahili", task="transcribe")

# --- 2. Pipeline Creation ---
# pipeline() loads both the model and its processor from MODEL_ID. FlashAttention-2
# is requested only when it is installed; otherwise the library default is used.
print("Setting up ASR pipeline with beam search...")
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_ID,
    device=device,
    torch_dtype=torch_dtype,
    model_kwargs=(
        {"attn_implementation": "flash_attention_2"}
        if is_flash_attn_2_available()
        else {}
    ),
)

# --- Data: the test split, resampled to SAMPLE_RATE when decoded ---
print("Loading dataset...")
dataset = (
    load_dataset("sartifyllc/Sartify_ITU_Zindi_Testdataset", split="test")
    .cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))
)

print(f"Total files to process: {len(dataset)}")
print(f"\nStarting transcription with {MODEL_ID} model... 🎤")

# Decoding settings shared by every file. Language and task are passed
# directly to Whisper's generate(); the `forced_decoder_ids` mechanism is
# deprecated for Whisper generation.
generate_kwargs = {
    "language": "swahili",
    "task": "transcribe",
    "num_beams": 3,
    "max_new_tokens": 440,
    "early_stopping": True,
    "repetition_penalty": 1.2,
}

results = []
n_failed = 0
for i, row in enumerate(tqdm(dataset, desc=f"Transcribing with {MODEL_ID}")):
    try:
        audio_array = row["audio"]["array"]

        # Pass the sampling rate with the samples so the pipeline does not have
        # to assume its feature extractor's rate. A fresh kwargs dict per call
        # protects the shared settings from any in-place changes.
        transcription_result = asr_pipeline(
            {"raw": audio_array, "sampling_rate": SAMPLE_RATE},
            generate_kwargs=dict(generate_kwargs),
        )
        transcription = transcription_result["text"].strip().lower()

        # An empty string is recorded when the model produced no text.
        results.append({'filename': row['filename'], 'text': transcription})

        # Print details for first 5 files
        if i < 5:
            duration = len(audio_array) / SAMPLE_RATE
            print(f"\n[{i+1}/5] File: {row['filename']}")
            print(f"Audio duration: {duration:.2f}s")
            print(f"Transcription: '{transcription}'")
            print("-" * 60)

    except Exception as e:
        # Best effort: keep going, but still emit a row (empty text) for this file.
        n_failed += 1
        print(f"\nError processing {row['filename']}: {str(e)}")
        results.append({'filename': row['filename'], 'text': ""})

# Failed files are still written (with empty text), so report them separately.
print(f"\nProcessed {len(results) - n_failed} files successfully ({n_failed} failed)")
print("Saving results to CSV...")

# Explicit columns keep the CSV header even when there are no rows.
submission_df = pd.DataFrame(results, columns=["filename", "text"])

submission_df.to_csv(OUTPUT_FILE, index=False)
print(f"\nSubmission file saved to {OUTPUT_FILE}")

# Show sample results
print("\nSample transcriptions:")
print(submission_df.head(10))