In [None]:
!pip install -q jiwer==3.1.0
!pip install -q accelerate
!pip install -q transformers 
!pip install -q soundfile
!git clone https://github.com/SunbirdAI/salt.git
!pip install -qr salt/requirements.txt
!pip install -q peft
!pip install -q evaluate
!pip install -q silero_vad

In [2]:
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import salt.dataset
import salt.metrics
import salt.constants
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub
import peft
import pandas as pd
import tqdm.notebook as tqdm
import jiwer
import string

In [3]:
# Translation table that deletes every character in string.punctuation
# (including the apostrophe). Built once instead of on every call.
_ALL_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def strip_punctuation(text: str) -> str:
    """Return `text` with every character in `string.punctuation` removed."""
    return text.translate(_ALL_PUNCTUATION_TABLE)


def normalise(texts: List[str]) -> List[str]:
    """Lower-case each string and strip all punctuation, for WER/CER scoring.

    Uses the module-level table directly rather than looking up the global
    `strip_punctuation`. A later cell in this notebook defines its own
    `strip_punctuation`, and a name lookup would make the metrics depend on
    cell execution order.

    NOTE(review): this strips apostrophes too, while the submission writer keeps
    them, so the local score may not match the scorer's normalisation — confirm.
    """
    return [t.lower().translate(_ALL_PUNCTUATION_TABLE) for t in texts]

In [4]:
# Hugging Face Hub id of the fine-tuned Whisper checkpoint that is evaluated below.
config = dict(pretrained_model='jq/whisper-large-v3-kin-track-b')

feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(config['pretrained_model'])

# Transcription task with no default language set on the processor.
processor = transformers.WhisperProcessor.from_pretrained(
    config['pretrained_model'], task="transcribe", language=None)

# device_map='auto' lets the weights be placed on whatever device(s) are available.
# The model is switched to inference mode (eval() returns the module itself).
model = transformers.WhisperForConditionalGeneration.from_pretrained(
    config['pretrained_model'],
    device_map='auto',
    # Optional speed-ups, currently disabled:
    # attn_implementation="flash_attention_2",
    # torch_dtype=torch.float16,
).eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Drop any forced decoder prompt stored with the checkpoint, on both the model
# config and the generation config, so it does not pin the decoding prefix.
model.config.forced_decoder_ids = model.generation_config.forced_decoder_ids = None

In [6]:
import torch
import numpy as np
from typing import List, Tuple
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

def trim_audio_with_silero_vad(audio: np.ndarray, sample_rate: int = 16000) -> np.ndarray:
    """
    Trim non-speech segments from audio using Silero VAD.

    Every detected speech segment is kept and the segments are concatenated,
    so internal silences are removed as well as leading/trailing ones.

    Args:
        audio: Mono audio array. It is converted to float32 and, if its peak
            exceeds 1.0, peak-normalised to [-1, 1]. The returned samples come
            from this converted/normalised copy.
        sample_rate: Sample rate of `audio` (default 16000).

    Returns:
        float32 array of the concatenated speech segments. Empty (float32) if
        no speech is detected or `audio` is empty.
    """
    audio = np.asarray(audio)
    if audio.dtype != np.float32:
        audio = audio.astype(np.float32)

    # np.max has no identity for an empty array; there is nothing to trim anyway.
    if audio.size == 0:
        return audio

    peak = np.max(np.abs(audio))
    if peak > 1.0:
        audio = audio / peak

    # Segment detection is delegated to get_speech_segments so both helpers share
    # one VAD configuration. Sample indices are unaffected by the normalisation.
    segments = get_speech_segments(audio, sample_rate)
    if not segments:
        return audio[:0]  # no speech detected; empty, but still float32

    return np.concatenate([audio[start:end] for start, end in segments])

# Silero VAD model, loaded on first use and then reused across calls instead of
# being reloaded for every clip.
_SILERO_VAD_MODEL = None


def _get_silero_vad_model():
    """Return the shared Silero VAD model, loading it on first use."""
    global _SILERO_VAD_MODEL
    if _SILERO_VAD_MODEL is None:
        _SILERO_VAD_MODEL = load_silero_vad()
    return _SILERO_VAD_MODEL


def get_speech_segments(audio: np.ndarray, sample_rate: int = 16000) -> List[Tuple[int, int]]:
    """
    Get speech segment timestamps without trimming.

    Args:
        audio: Mono audio array. Non-float32 input is converted to float32, and
            input whose peak exceeds 1.0 is peak-normalised before running VAD.
            The caller's array is not modified.
        sample_rate: Sample rate of `audio` (default 16000).

    Returns:
        List of (start_sample, end_sample) tuples for each detected speech
        segment, as sample indices into `audio`. Empty if there is no speech or
        `audio` is empty.
    """
    audio = np.asarray(audio)
    # np.max below has no identity for an empty array; nothing to detect anyway.
    if audio.size == 0:
        return []

    if audio.dtype != np.float32:
        audio = audio.astype(np.float32)

    peak = np.max(np.abs(audio))
    if peak > 1.0:
        audio = audio / peak

    # Reusing one model is safe: silero-vad's get_speech_timestamps resets the
    # model's recurrent state at the start of each call.
    # `window_size_samples` is not passed: silero-vad >= 5 documents it as
    # deprecated and ignores it (the window size is fixed by the sample rate).
    speech_timestamps = get_speech_timestamps(
        torch.from_numpy(audio),
        _get_silero_vad_model(),
        sampling_rate=sample_rate,
        threshold=0.5,
        min_speech_duration_ms=250,
        min_silence_duration_ms=100,
        speech_pad_ms=30,
    )
    return [(ts['start'], ts['end']) for ts in speech_timestamps]

def trim_audio_start_end(
    audio: np.ndarray,
    sample_rate: int = 16000,
    buffer_seconds: float = 0.5,
    minimum_seconds_removed: float = 3.0,
) -> Tuple[np.ndarray, float]:
    """
    Trim silence from the beginning and end, keeping internal silences and a buffer.

    Trimming is only applied when it would remove more than
    `minimum_seconds_removed` seconds. Otherwise the audio is left as-is.

    Args:
        audio: Mono audio array (single channel, float32 expected).
        sample_rate: Sample rate (default 16000).
        buffer_seconds: Audio kept before the first and after the last detected
            speech sample, clipped to the array bounds (default 0.5s).
        minimum_seconds_removed: Trim only if strictly more than this many
            seconds would be removed (default 3.0s).

    Returns:
        Tuple ``(audio_out, seconds_removed)``:
        - ``audio_out`` is the trimmed slice of ``audio``, or ``audio`` itself
          when no trimming is applied.
        - ``seconds_removed`` is the duration cut off, or 0.0 when untouched.

        Empty input, and input in which VAD detects no speech, are returned
        unchanged with 0.0, so callers can always unpack two values.
    """
    if len(audio) == 0:
        return audio, 0.0

    segments = get_speech_segments(audio, sample_rate)
    if not segments:
        # Return a pair like every other path (callers unpack two values), and
        # keep the audio rather than discarding it on a possible VAD miss.
        return audio, 0.0

    # Overall extent of detected speech, in samples.
    speech_start = segments[0][0]
    speech_end = segments[-1][1]

    # Widen by the buffer while staying within the array bounds.
    buffer_samples = int(buffer_seconds * sample_rate)
    trim_start = max(0, speech_start - buffer_samples)
    trim_end = min(len(audio), speech_end + buffer_samples)

    seconds_removed = (len(audio) - (trim_end - trim_start)) / sample_rate

    if seconds_removed > minimum_seconds_removed:
        return audio[trim_start:trim_end], seconds_removed
    return audio, 0.0

In [7]:
# True: transcribe the whole test split (no reference text; nothing is scored).
# False: transcribe the first 300 dev_test examples, whose reference text is
# used to compute WER/CER as a sanity check.
predict_full_test_set = False

test_repo = 'jq/kinyarwanda-speech-hackathon'

# Audio is cast to 16 kHz, the sampling rate used when calling the processor.
test_ds = (
    datasets.load_dataset(test_repo, split='test', num_proc=10)
    if predict_full_test_set
    else datasets.load_dataset(test_repo, split='dev_test[:300]')
).cast_column("audio", datasets.Audio(sampling_rate=16000))

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

In [None]:
# --- Inference settings ---
TRIM_LEADING_TRAILING_SILENCE = False  # VAD-trim long leading/trailing silence first
NUM_BEAMS = 5  # beam width, used by both decoding paths below
NORMALISE_VOLUME = False  # peak-normalise each clip before feature extraction
SAMPLE_RATE = 16000  # rate test_ds audio was cast to
MAX_SINGLE_PASS_SECONDS = 30.0  # clips at least this long go through the long-form pipeline

test_ids = []
test_transcriptions = []
test_labels = []

# Decode SALT's Kinyarwanda language-token id into the string passed as
# `language=` to generation. Computed once instead of once per example.
kin_language = processor.tokenizer.decode(
    salt.constants.SALT_LANGUAGE_TOKENS_WHISPER['kin'])

# Long-form fallback for clips longer than a single Whisper pass.
pipe = transformers.pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    return_timestamps=True,
    generate_kwargs={
        "language": kin_language,
        "num_beams": NUM_BEAMS,
    },
    device_map="auto",
)

for i in tqdm.tqdm(range(len(test_ds))):
    example = test_ds[i]
    audio_array = example['audio']['array']

    if NORMALISE_VOLUME and len(audio_array) > 0:
        peak = np.max(np.abs(audio_array))
        # An all-zero clip has peak 0; dividing would turn it into NaNs.
        if peak > 0:
            audio_array = audio_array / peak

    if TRIM_LEADING_TRAILING_SILENCE:
        audio_array, _ = trim_audio_start_end(audio_array)

    if len(audio_array) / SAMPLE_RATE < MAX_SINGLE_PASS_SECONDS:
        input_features = processor(
            audio_array, sampling_rate=SAMPLE_RATE, return_tensors="pt").input_features
        # Follow the model's device and dtype (loaded with device_map='auto')
        # instead of hard-coding 'cuda'.
        input_features = input_features.to(model.device, dtype=model.dtype)

        predicted_ids = model.generate(
            input_features,
            num_beams=NUM_BEAMS,
            max_length=400,
            language=kin_language,
        )
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    else:
        transcription = pipe(audio_array)['text']

    if not predict_full_test_set:
        test_labels.append(example['text'])

    test_transcriptions.append(transcription)
    test_ids.append(example['id'])

if not predict_full_test_set:
    references = normalise(test_labels)
    hypotheses = normalise(test_transcriptions)
    total_wer = jiwer.wer(references, hypotheses)
    total_cer = jiwer.cer(references, hypotheses)
    # Combined score 1 - (0.6 * CER + 0.4 * WER).
    # NOTE(review): presumably mirrors the hackathon leaderboard metric — confirm.
    score = 1 - (0.6 * total_cer + 0.4 * total_wer)

    print(f"Word Error Rate (WER): {total_wer:.4f}")
    print(f"Character Error Rate (CER): {total_cer:.4f}")
    print(f"Score: {score:.4f}")

In [None]:
# Submission metadata: every key becomes one row (id) of the submission file.
# JSON text is UTF-8, so read it explicitly as such rather than with the
# platform's locale encoding (the submission file is also written as UTF-8).
with open('test.json', encoding='utf-8') as f:
    test_metadata = json.load(f)

test_keys = test_metadata.keys()

In [None]:
# Map each example id to its transcription; the two lists were appended in
# lockstep, and a repeated id keeps its last transcription.
predictions = dict(zip(test_ids, test_transcriptions))

In [None]:
import string

# Submission shard number: the file is named submission-{PART}.csv, and only
# shard 1 gets the CSV header row. Change this when producing later shards.
PART = 1

# Characters removed from submitted transcriptions: all of string.punctuation
# except the apostrophe, which is kept.
SUBMISSION_PUNCTUATION = string.punctuation.replace("'", "")
_SUBMISSION_PUNCTUATION_TABLE = str.maketrans('', '', SUBMISSION_PUNCTUATION)


def strip_submission_punctuation(text: str) -> str:
    """Remove SUBMISSION_PUNCTUATION characters from `text`, keeping apostrophes.

    Deliberately not named `strip_punctuation`: redefining that name here would
    silently change the evaluation helper `normalise`, which strips apostrophes,
    whenever the scoring cell is re-run after this one.
    """
    return text.translate(_SUBMISSION_PUNCTUATION_TABLE)


with open(f'submission-{PART}.csv', "w", encoding="utf-8") as f:
    if PART == 1:
        f.write('id,transcription\n')
    for k in test_keys:
        pred = predictions.get(k) or ''
        # Lower-case, strip punctuation, and collapse all whitespace (including
        # newlines, which would otherwise break the one-row-per-id layout).
        normalised_pred = ' '.join(strip_submission_punctuation(pred.lower()).split())
        if not normalised_pred:
            # Missing, or empty after normalisation: write the non-empty
            # placeholder 'a' so the id still has a transcription.
            print('No prediction for key ', k)
            normalised_pred = 'a'
        f.write(f"{k},{normalised_pred}\n")

In [None]:
!wc -l submission.csv

In [None]:
!head submission.csv