In [None]:
%%capture
# Environment setup; %%capture hides the (long) clone/install output.
# Clone the SALT repo into ./salt — a later cell does `import salt.constants`,
# presumably resolved from this checkout in the working directory.
# NOTE(review): `git clone` errors on re-run once ./salt exists (hidden by %%capture).
!git clone https://github.com/sunbirdai/salt.git
!pip install -r salt/requirements.txt
# NOTE(review): unpinned — pin the transformers version this run was validated with.
!pip install -q transformers

In [None]:
# `!export` runs in a throw-away subshell, so the variables never reached this kernel.
# Set them on the kernel process itself. They only take effect if set before CUDA is
# first used (e.g. on a fresh kernel, before the cell that imports torch runs).
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:2048"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [None]:
import os
import torch
import librosa
import salt.constants
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset, Audio
from transformers import WhisperProcessor, pipeline
from typing import Optional, Union, List
import numpy as np
import time
import gc
import warnings


In [None]:
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Throughput-oriented backend switches (the notebook targets an A100):
#   - flash scaled-dot-product attention kernel on,
#   - TF32 allowed for both matmuls and cuDNN convolutions,
#   - cuDNN autotuning of algorithms for the (fixed) input shapes.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

In [None]:
# Configuration
MODEL_NAME = "jq/whisper-large-v3-kin-track-b"
# Hugging Face token: read from the HF_TOKEN environment variable instead of hard-coding it.
# An unset/empty variable becomes None (not ""), so the libraries can fall back to a
# cached `huggingface-cli login`.
# NOTE(review): an explicit "" is treated as a literal token value, not "use my login".
HF_TOKEN = os.environ.get("HF_TOKEN") or None
OUTPUT_FILE = "/kin/whisper_transcriptions.csv"  # results CSV; appended to on every checkpoint
DATASET_NAME = "jq/kinyarwanda-speech-hackathon"
LANGUAGE = "kinyarwanda"  # SALT language name, resolved to a Whisper language token
BATCH_SIZE = 100  # samples per pipeline call
CHECKPOINT_EVERY = 500  # flush buffered rows to OUTPUT_FILE once this many accumulate

In [None]:
class Whisper:
    """Batched speech-to-text wrapper around a Hugging Face Whisper ASR ``pipeline``.

    Human-readable language names (e.g. ``"kinyarwanda"``) are resolved to the
    Whisper language string passed to ``generate`` via the SALT constant tables.
    """

    # Lower-cased SALT language name -> SALT language code (inverts SALT_LANGUAGE_NAMES).
    LANGUAGE_CODES = {v.lower(): k for k, v in salt.constants.SALT_LANGUAGE_NAMES.items()}
    # SALT language code -> Whisper language token (fed to tokenizer.decode); copied so
    # local edits never mutate salt's own table.
    LANGUAGE_ID_TOKENS = salt.constants.SALT_LANGUAGE_TOKENS_WHISPER.copy()

    # Decoding settings forwarded as ``generate_kwargs`` on every call:
    # greedy search (num_beams=1, no sampling) with mild repetition suppression.
    DEFAULT_GENERATION_CONFIG = {
        "prompt_condition_type": "first-segment",
        "condition_on_prev_tokens": True,
        "repetition_penalty": 1.1,
        "no_repeat_ngram_size": 5,
        "num_beams": 1,
        "max_length": 448,
        "do_sample": False,
    }

    # Sampling rate (Hz) declared for every audio array handed to the pipeline.
    # NOTE(review): callers must supply audio already at this rate.
    SAMPLE_RATE = 16000

    def __init__(self, model_name: str, hf_token: Optional[str], compile_model: bool = False):
        """Load the Whisper processor and build the ASR pipeline.

        Args:
            model_name: Hugging Face model id to load.
            hf_token: Hugging Face access token, or None to rely on a cached login.
            compile_model: If True, wrap the model's ``forward`` with ``torch.compile``.
                Off by default; compilation is lazy, so any failure surfaces on the
                first transcription call rather than here.
        """
        # Prefer the GPU; fall back to CPU (in float32) instead of crashing when no
        # CUDA device is visible.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            dtype = torch.float16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        print("Loading processor...")
        self.processor = WhisperProcessor.from_pretrained(
            model_name,
            token=hf_token,
            language=None,
            cache_dir="/tmp/whisper_cache"
        )

        print("Loading model...")
        self.model = pipeline(
            task="automatic-speech-recognition",
            model=model_name,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            device=self.device,
            torch_dtype=dtype,
            token=hf_token,
            model_kwargs={
                "attn_implementation": "sdpa",
                "use_cache": True,
                "low_cpu_mem_usage": True,
            },
            chunk_length_s=30,
            batch_size=BATCH_SIZE,
        )

        if compile_model:
            # Compile ``forward`` in place rather than wrapping the module. torch.compile(module)
            # returns an OptimizedModule whose ``generate`` is looked up on the original module,
            # so generation (the path the pipeline uses) would still run the uncompiled forward.
            hf_model = self.model.model
            hf_model.forward = torch.compile(hf_model.forward)
            print("✅ Model forward wrapped with torch.compile (compiles on first call)")

    def get_language_code(self, language: str) -> str:
        """Resolve a language name (case- and whitespace-insensitive) to Whisper's language string.

        Returns:
            The decoded language token with its first two and last two characters
            stripped (presumably ``<|xx|>`` -> ``xx``).

        Raises:
            ValueError: if the name is unknown to SALT or has no Whisper token mapping.
        """
        lang_key = language.strip().lower()
        if lang_key not in self.LANGUAGE_CODES:
            raise ValueError(f"Language '{language}' is not supported.")
        iso_code = self.LANGUAGE_CODES[lang_key]
        if iso_code not in self.LANGUAGE_ID_TOKENS:
            raise ValueError(f"No token mapping found for language code '{iso_code}'.")
        token = self.LANGUAGE_ID_TOKENS[iso_code]
        decoded_lang = self.processor.tokenizer.decode(token)[2:-2]
        return decoded_lang

    def transcribe_batch(self, audio_arrays: List[np.ndarray], language: str) -> List[str]:
        """Transcribe a batch of audio arrays, returning one text per input.

        Inference errors are deliberately swallowed (best-effort batch processing):
        the batch is reported and every element becomes ``"[ERROR] <message>"``.
        Unknown languages still raise ``ValueError`` from ``get_language_code``.
        """
        if not audio_arrays:
            return []

        language_code = self.get_language_code(language)
        batch_input = [
            {"array": audio_array, "sampling_rate": self.SAMPLE_RATE}
            for audio_array in audio_arrays
        ]
        generate_kwargs = {**self.DEFAULT_GENERATION_CONFIG, "language": language_code}

        try:
            # Autocast only applies on CUDA; on the CPU fallback it is disabled.
            with torch.amp.autocast(self.device.type, enabled=self.device.type == "cuda"):
                results = self.model(
                    batch_input,
                    generate_kwargs=generate_kwargs,
                    return_timestamps=False
                )

            if isinstance(results, list):
                return [result['text'] for result in results]
            return [results['text']]

        except Exception as e:
            print(f"Batch error: {e}")
            return [f"[ERROR] {e}"] * len(audio_arrays)

In [None]:
# Build the transcriber once; every later cell reuses this instance.
whisper_model = Whisper(model_name=MODEL_NAME, hf_token=HF_TOKEN)

In [None]:
# Stream the training split (no full download) and decode/resample audio to 16 kHz on the fly.
print("📁 Loading dataset...")
dataset = load_dataset(
    DATASET_NAME,
    split="train",
    token=HF_TOKEN,
    streaming=True,
).cast_column("audio", Audio(sampling_rate=16000))
    

In [None]:
def _result_rows(batch: List[dict], transcriptions: List[str]) -> List[dict]:
    """Build one output-CSV row per sample, pairing samples and transcriptions by position."""
    return [
        {
            "id": sample["id"],
            "language": LANGUAGE.lower(),
            "original_text": sample["text"],
            "transcription": transcription,
        }
        for sample, transcription in zip(batch, transcriptions)
    ]


def _transcribe_rows(batch: List[dict]) -> List[dict]:
    """Transcribe ``batch`` into rows; on failure, record an ``[ERROR] ...`` row per sample."""
    try:
        audio_arrays = [s["audio"]["array"] for s in batch]
        transcriptions = whisper_model.transcribe_batch(audio_arrays, LANGUAGE)
        # zip() would silently drop samples if fewer texts than inputs came back.
        if len(transcriptions) != len(batch):
            raise ValueError(
                f"expected {len(batch)} transcriptions, got {len(transcriptions)}"
            )
        return _result_rows(batch, transcriptions)
    except Exception as e:
        print(f"Batch processing error: {e}")
        return _result_rows(batch, [f"[ERROR] {e}"] * len(batch))


def _output_has_data() -> bool:
    """True when OUTPUT_FILE exists and is non-empty (i.e. it already has a CSV header)."""
    return os.path.exists(OUTPUT_FILE) and os.path.getsize(OUTPUT_FILE) > 0


def _append_rows(rows: List[dict]) -> None:
    """Append ``rows`` to OUTPUT_FILE, writing the header only if the file has no content yet."""
    pd.DataFrame(rows).to_csv(
        OUTPUT_FILE,
        mode="a",
        index=False,
        header=not _output_has_data()
    )


def _load_completed_ids() -> set:
    """Return ids already present in OUTPUT_FILE, as strings (empty set on a fresh start)."""
    if not _output_has_data():
        return set()
    # Read ids as strings: an inferred dtype can turn e.g. "007" into 7, which would
    # then never match the dataset's id and defeat the resume filter.
    existing_ids = pd.read_csv(OUTPUT_FILE, usecols=["id"], dtype={"id": str})["id"]
    completed = set(existing_ids)
    print(f"📋 Resuming from {len(completed)} completed samples")
    return completed


def transcribe_dataset(expected_total: int = 300_000):
    """
    Single-GPU, resumable transcription of ``dataset`` into OUTPUT_FILE.

    Samples whose id is already in OUTPUT_FILE are skipped. The rest are
    transcribed BATCH_SIZE at a time, and rows are appended to the CSV every
    CHECKPOINT_EVERY rows and once more at the end. A failed batch is recorded
    with ``[ERROR] ...`` transcriptions, so it counts as completed on resume.

    Args:
        expected_total: Approximate dataset size, used only for the ETA printout.
    """
    print("Starting transcription...")

    # Create the output directory up front so the first checkpoint (possibly hours
    # in) cannot fail on a missing directory.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    completed_ids = _load_completed_ids()
    already_done = len(completed_ids)

    results: List[dict] = []
    batch_data: List[dict] = []
    processed_count = 0
    start_time = time.time()

    print("Starting batch processing...")

    # skip() assumes earlier rows were written in stream order; the id check guards the rest.
    for sample in tqdm(dataset.skip(already_done), desc="Processing"):
        if str(sample["id"]) in completed_ids:
            continue

        batch_data.append(sample)
        if len(batch_data) < BATCH_SIZE:
            continue

        results.extend(_transcribe_rows(batch_data))
        processed_count += len(batch_data)
        batch_data = []

        # Progress: the ETA accounts for samples completed in earlier runs.
        elapsed = time.time() - start_time
        rate = processed_count / elapsed if elapsed > 0 else 0
        remaining = max(expected_total - already_done - processed_count, 0)
        eta = remaining / rate / 3600 if rate > 0 else 0
        print(f"⚡ Processed: {processed_count} | Rate: {rate:.1f}/sec | ETA: {eta:.1f}h")

        # Release cached GPU blocks and Python garbage after every batch.
        torch.cuda.empty_cache()
        gc.collect()

        if len(results) >= CHECKPOINT_EVERY:
            print(f"Checkpointing at {processed_count} samples...")
            _append_rows(results)
            results = []

    # Final partial batch: same error handling as full batches, and counted in the stats.
    if batch_data:
        results.extend(_transcribe_rows(batch_data))
        processed_count += len(batch_data)

    if results:
        _append_rows(results)

    total_time = time.time() - start_time
    per_minute = processed_count / (total_time / 60) if total_time > 0 else 0.0
    print("Transcription complete!")
    print(f"Total time: {total_time/3600:.2f} hours")
    print(f"Average rate: {per_minute:.1f} samples/minute")
    print(f" Results saved to: {OUTPUT_FILE}")


In [None]:
if __name__ == "__main__":
    # Entry point: run the resumable, batched transcription over the streamed dataset.
    transcribe_dataset()