# Bangla Long-Form ASR Baseline

Minimal baseline using **wav2vec2-large-xlsr-53** with CTC decoding.

- **Input**: Long Bangla .wav files
- **Output**: Bangla text transcription
- **Method**: 25-second chunking, greedy CTC decoding

In [1]:
# Install dependencies
# !pip install -q transformers librosa soundfile torchaudio
import warnings
warnings.filterwarnings("ignore")


In [None]:
# Imports
import os
import glob
import numpy as np
import pandas as pd
import torch
import librosa
import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from torch.utils.data import Dataset, DataLoader
from difflib import SequenceMatcher
import unicodedata
import re

# Configuration
BASE_INPUT_DIR = "/kaggle/input/dl-sprint-4-0-bengali-long-form-speech-recognition/transcription/transcription"
BASE_OUTPUT_DIR = "/kaggle/working/"
TEST_AUDIO_DIR = os.path.join(BASE_INPUT_DIR, "test")
SUBMISSION_PATH = os.path.join(BASE_OUTPUT_DIR, "submission.csv")

MODEL_NAME = "arijitx/wav2vec2-xls-r-300m-bengali"  # Bengali-specific model

SAMPLE_RATE = 16000
CHUNK_LENGTH_SEC = 15
OVERLAP_SEC = 3  # Overlap duration in seconds (recommended: 2-5 seconds)

# Spectral Gating Configuration
ENABLE_DENOISING = True  # Toggle denoising on/off
NOISE_GATE_THRESHOLD_K = 2.0  # Conservative: 1.5-2.5 (higher = less aggressive)
STFT_WIN_LENGTH_MS = 25  # ~25ms window
STFT_HOP_LENGTH_MS = 10  # ~10ms hop
SOFT_MASK_MIN = 0.1  # Minimum mask value (no hard zeroing)

# Post-processing Configuration
ENABLE_LM_DECODING = False  # N-gram LM decoding (requires KenLM installation)
ENABLE_UNICODE_NORMALIZATION = True  # Unicode normalization
ENABLE_SENTENCE_END_CHAR = False  # Append period at end if missing
SENTENCE_END_CHAR = "।"  # Bengali sentence end character (dari)

# Create output directory
os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")
print(f"Denoising enabled: {ENABLE_DENOISING}")
print(f"Chunk length: {CHUNK_LENGTH_SEC}s, Overlap: {OVERLAP_SEC}s")
print(f"Post-processing: LM={ENABLE_LM_DECODING}, Unicode={ENABLE_UNICODE_NORMALIZATION}, SentenceEnd={ENABLE_SENTENCE_END_CHAR}")

2026-02-02 20:24:45.501633: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770063885.873101     102 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770063886.009890     102 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770063886.876627     102 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770063886.876668     102 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770063886.876670     102 computation_placer.cc:177] computation placer alr

Device: cuda
Denoising enabled: True


In [3]:
print(f"Loading processor: {MODEL_NAME}")
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

print(f"Loading model: {MODEL_NAME}")
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
model.to(DEVICE)
model.eval()
print("Model loaded successfully")

Loading processor: arijitx/wav2vec2-xls-r-300m-bengali


preprocessor_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/309 [00:00<?, ?B/s]

Loading model: arijitx/wav2vec2-xls-r-300m-bengali


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Model loaded successfully


In [4]:
class SpectralGatingDenoiser:
    """
    Conservative spectral gating denoiser optimized for ASR (not audio quality).
    
    Uses soft masking with no hard zeroing to preserve phonetic content.
    Designed for Bangla speech where consonant articulation is critical.
    
    Args:
        sample_rate: Audio sample rate (default: 16000)
        threshold_k: Noise gate threshold multiplier (default: 2.0)
                     Higher = more conservative (less denoising)
        win_length_ms: STFT window length in milliseconds
        hop_length_ms: STFT hop length in milliseconds
        soft_mask_min: Minimum mask value to prevent hard zeroing
    """
    
    def __init__(
        self,
        sample_rate=16000,
        threshold_k=2.0,
        win_length_ms=25,
        hop_length_ms=10,
        soft_mask_min=0.1
    ):
        self.sample_rate = sample_rate
        self.threshold_k = threshold_k
        self.soft_mask_min = soft_mask_min
        
        # Convert ms to samples
        self.win_length = int(win_length_ms * sample_rate / 1000)
        self.hop_length = int(hop_length_ms * sample_rate / 1000)
        
        # Ensure win_length is valid
        if self.win_length % 2 == 1:
            self.win_length += 1
            
    def __call__(self, waveform):
        """
        Apply spectral gating to waveform.
        
        Args:
            waveform: numpy array or torch tensor of shape (n_samples,)
            
        Returns:
            Denoised waveform as numpy array
        """
        # Convert to numpy if tensor
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.cpu().numpy()
            
        # Ensure mono
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=0)
            
        # Skip if audio too short
        if len(waveform) < self.win_length:
            return waveform
            
        # Step 1: Compute STFT with Hann window
        stft = librosa.stft(
            waveform,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window='hann'
        )
        
        magnitude = np.abs(stft)
        phase = np.angle(stft)
        
        # Step 2: Estimate noise profile from low-energy frames
        # Use bottom 20% of frames by energy as noise estimate
        frame_energy = np.sum(magnitude ** 2, axis=0)
        noise_threshold_percentile = 20
        noise_frames_mask = frame_energy <= np.percentile(frame_energy, noise_threshold_percentile)
        
        # Ensure we have some noise frames
        if noise_frames_mask.sum() < 5:
            # Fallback: use lowest 5 frames
            noise_frame_indices = np.argsort(frame_energy)[:5]
            noise_frames_mask = np.zeros_like(noise_frames_mask, dtype=bool)
            noise_frames_mask[noise_frame_indices] = True
            
        noise_magnitude = magnitude[:, noise_frames_mask]
        
        # Step 3: Compute per-frequency noise statistics
        noise_mean = np.mean(noise_magnitude, axis=1, keepdims=True)
        noise_std = np.std(noise_magnitude, axis=1, keepdims=True)
        
        # Add small epsilon to prevent division by zero
        noise_std = np.maximum(noise_std, 1e-8)
        
        # Step 4: Compute gating threshold
        # threshold(f) = noise_mean(f) + k * noise_std(f)
        threshold = noise_mean + self.threshold_k * noise_std
        
        # Step 5: Apply soft spectral mask
        # Mask = min(1.0, max(soft_mask_min, magnitude / threshold))
        mask = magnitude / (threshold + 1e-8)
        mask = np.clip(mask, self.soft_mask_min, 1.0)
        
        # Apply mask to magnitude
        denoised_magnitude = magnitude * mask
        
        # Step 6: Reconstruct with original phase
        denoised_stft = denoised_magnitude * np.exp(1j * phase)
        
        # Inverse STFT
        denoised_waveform = librosa.istft(
            denoised_stft,
            hop_length=self.hop_length,
            window='hann',
            length=len(waveform)  # Ensure same length as input
        )
        
        return denoised_waveform


print("SpectralGatingDenoiser class defined")

SpectralGatingDenoiser class defined


In [None]:
class ASRDataset(Dataset):
    """
    Memory-safe dataset for long-form ASR with overlap-aware chunking and on-the-fly denoising.
    
    Features:
    - Lazy audio loading (no preloading)
    - Overlap-aware chunk-wise processing
    - Optional spectral gating denoising
    - No intermediate file I/O
    """
    
    def __init__(
        self,
        audio_paths,
        chunk_length_sec=15,
        overlap_sec=3,
        sample_rate=16000,
        denoiser=None
    ):
        self.audio_paths = audio_paths
        self.chunk_length_sec = chunk_length_sec
        self.overlap_sec = overlap_sec
        self.sample_rate = sample_rate
        self.denoiser = denoiser
        
        # Calculate step size (chunk length minus overlap)
        self.step_sec = chunk_length_sec - overlap_sec
        
        # Precompute chunk indices for each audio file
        self.chunk_info = []  # [(audio_idx, chunk_idx, start_sec, end_sec, total_chunks)]
        
        for audio_idx, audio_path in enumerate(audio_paths):
            # Get audio duration without loading full file
            info = torchaudio.info(audio_path)
            duration_sec = info.num_frames / info.sample_rate
            
            # Calculate chunks with overlap
            chunk_idx = 0
            start_sec = 0.0
            
            while start_sec < duration_sec:
                end_sec = min(start_sec + chunk_length_sec, duration_sec)
                self.chunk_info.append((audio_idx, chunk_idx, start_sec, end_sec))
                start_sec += self.step_sec
                chunk_idx += 1
            
            # Store total chunks for this audio (update all entries)
            total_chunks = chunk_idx
            for i in range(len(self.chunk_info) - total_chunks, len(self.chunk_info)):
                audio_idx_stored, chunk_idx_stored, start_stored, end_stored = self.chunk_info[i]
                self.chunk_info[i] = (audio_idx_stored, chunk_idx_stored, start_stored, end_stored, total_chunks)
                
    def __len__(self):
        return len(self.chunk_info)
    
    def __getitem__(self, idx):
        """
        Load and process a single chunk with overlap information.
        
        Returns:
            dict with:
                - waveform: denoised audio chunk tensor
                - audio_path: original audio file path
                - chunk_idx: chunk index
                - total_chunks: total chunks for this audio
                - start_sec: chunk start time in seconds
                - end_sec: chunk end time in seconds
        """
        audio_idx, chunk_idx, start_sec, end_sec, total_chunks = self.chunk_info[idx]
        audio_path = self.audio_paths[audio_idx]
        
        # Load audio chunk on-the-fly
        start_frame = int(start_sec * self.sample_rate)
        num_frames = int((end_sec - start_sec) * self.sample_rate)
        
        # Load only the required chunk
        waveform, sr = torchaudio.load(
            audio_path,
            frame_offset=start_frame,
            num_frames=num_frames
        )
        
        # Convert to mono and resample if needed
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
            
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)
            
        # Normalize to [-1, 1]
        max_val = torch.max(torch.abs(waveform))
        if max_val > 0:
            waveform = waveform / max_val
            
        # Apply denoising if enabled
        if self.denoiser is not None:
            # Denoiser works on 1D numpy/tensor
            waveform_np = waveform.squeeze(0).numpy()
            denoised_np = self.denoiser(waveform_np)
            waveform = torch.from_numpy(denoised_np).unsqueeze(0)
            
        # Skip if chunk too short (< 0.5 sec)
        if waveform.shape[1] < self.sample_rate // 2:
            # Return empty marker
            return {
                'waveform': None,
                'audio_path': audio_path,
                'chunk_idx': chunk_idx,
                'total_chunks': total_chunks,
                'start_sec': start_sec,
                'end_sec': end_sec
            }
            
        return {
            'waveform': waveform.squeeze(0),  # Return 1D tensor
            'audio_path': audio_path,
            'chunk_idx': chunk_idx,
            'total_chunks': total_chunks,
            'start_sec': start_sec,
            'end_sec': end_sec
        }


def collate_fn(batch):
    """Custom collate function to handle variable-length chunks."""
    # Filter out None waveforms
    valid_batch = [item for item in batch if item['waveform'] is not None]
    
    if len(valid_batch) == 0:
        return None
        
    return valid_batch


print("ASRDataset with overlap-aware chunking and collate_fn defined")

ASRDataset and collate_fn defined


In [None]:
def transcribe_batch(batch, processor, model, device):
    """
    Transcribe a batch of audio chunks.
    
    Args:
        batch: List of chunk dictionaries
        processor: Wav2Vec2Processor
        model: Wav2Vec2ForCTC model
        device: torch device
        
    Returns:
        List of transcriptions
    """
    if batch is None or len(batch) == 0:
        return []
        
    # Extract waveforms
    waveforms = [item['waveform'].numpy() for item in batch]
    
    # Process batch
    inputs = processor(
        waveforms,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=True
    )
    
    input_values = inputs.input_values.to(device)
    
    # Inference
    with torch.no_grad():
        logits = model(input_values).logits
        
    # Greedy CTC decoding
    predicted_ids = torch.argmax(logits, dim=-1)
    transcriptions = processor.batch_decode(predicted_ids)
    
    return transcriptions


def clean_text(text):
    """Minimal text cleanup: remove extra spaces and empty tokens."""
    # Remove special tokens that might appear
    text = text.replace("<s>", "").replace("</s>", "")
    text = text.replace("<pad>", "").replace("<unk>", "")
    # Normalize whitespace
    text = " ".join(text.split())
    return text.strip()


def find_overlap_match(text1, text2, min_overlap_words=2, max_overlap_words=15):
    """
    Find the best overlap between the end of text1 and beginning of text2.
    
    Uses longest common subsequence matching to handle slight variations
    in transcription at chunk boundaries.
    
    Args:
        text1: First transcription (previous chunk)
        text2: Second transcription (current chunk)
        min_overlap_words: Minimum words to consider for overlap
        max_overlap_words: Maximum words to look back/forward for overlap
        
    Returns:
        Tuple of (overlap_end_idx_in_text1, overlap_start_idx_in_text2, confidence)
    """
    words1 = text1.split()
    words2 = text2.split()
    
    if len(words1) < min_overlap_words or len(words2) < min_overlap_words:
        return len(words1), 0, 0.0
    
    # Look at the end of text1 and beginning of text2
    end_words1 = words1[-max_overlap_words:] if len(words1) > max_overlap_words else words1
    start_words2 = words2[:max_overlap_words] if len(words2) > max_overlap_words else words2
    
    best_overlap = 0
    best_confidence = 0.0
    best_i = len(end_words1)
    best_j = 0
    
    # Try different overlap lengths
    for overlap_len in range(min_overlap_words, min(len(end_words1), len(start_words2)) + 1):
        # Get candidate overlap regions
        end_region = " ".join(end_words1[-overlap_len:])
        start_region = " ".join(start_words2[:overlap_len])
        
        # Calculate similarity using SequenceMatcher
        similarity = SequenceMatcher(None, end_region, start_region).ratio()
        
        # Weight longer overlaps slightly more
        weighted_score = similarity * (1 + overlap_len * 0.05)
        
        if similarity > 0.6 and weighted_score > best_confidence:
            best_overlap = overlap_len
            best_confidence = weighted_score
            best_i = len(words1) - overlap_len
            best_j = overlap_len
    
    # If no good overlap found, try word-by-word exact matching
    if best_confidence < 0.6:
        for i in range(max(0, len(end_words1) - max_overlap_words), len(end_words1)):
            for j in range(min(max_overlap_words, len(start_words2))):
                if end_words1[i].lower() == start_words2[j].lower():
                    # Found matching word, extend match
                    match_len = 1
                    while (i + match_len < len(end_words1) and 
                           j + match_len < len(start_words2) and 
                           end_words1[i + match_len].lower() == start_words2[j + match_len].lower()):
                        match_len += 1
                    
                    if match_len >= min_overlap_words:
                        # Calculate position in original words1
                        original_i = len(words1) - len(end_words1) + i
                        return original_i, j + match_len, 0.8
    
    if best_confidence > 0.6:
        # Calculate position in original words1
        original_i = len(words1) - len(end_words1) + (len(end_words1) - best_overlap)
        return original_i, best_j, best_confidence
    
    return len(words1), 0, 0.0


def stitch_transcriptions(transcriptions_with_timing):
    """
    Stitch overlapping transcriptions together intelligently.
    
    Args:
        transcriptions_with_timing: List of tuples (chunk_idx, start_sec, end_sec, transcription)
        
    Returns:
        Merged transcription string
    """
    if not transcriptions_with_timing:
        return ""
    
    # Sort by chunk index
    sorted_transcriptions = sorted(transcriptions_with_timing, key=lambda x: x[0])
    
    # Clean all transcriptions first
    cleaned = [(idx, start, end, clean_text(text)) for idx, start, end, text in sorted_transcriptions]
    
    # Filter out empty transcriptions
    cleaned = [(idx, start, end, text) for idx, start, end, text in cleaned if text.strip()]
    
    if not cleaned:
        return ""
    
    if len(cleaned) == 1:
        return cleaned[0][3]
    
    # Stitch transcriptions with overlap detection
    result_words = cleaned[0][3].split()
    
    for i in range(1, len(cleaned)):
        prev_text = " ".join(result_words)
        curr_text = cleaned[i][3]
        
        # Find overlap between accumulated result and current chunk
        overlap_end, overlap_start, confidence = find_overlap_match(prev_text, curr_text)
        
        curr_words = curr_text.split()
        
        if confidence > 0.5:
            # Good overlap found - merge at overlap point
            result_words = result_words[:overlap_end] + curr_words[overlap_start:]
        else:
            # No clear overlap - just concatenate (may cause some duplication)
            # But prefer current chunk's beginning as it's fresher
            result_words.extend(curr_words)
    
    return " ".join(result_words)


print("Transcription helper functions with overlap-aware stitching defined")

Transcription helper functions defined


In [None]:
# Post-processing Functions

def lm_decode_with_ngram(text, lm_path=None, alpha=0.5, beta=1.5):
    """
    Apply n-gram language model decoding to improve transcription.
    
    This function would use KenLM or pyctcdecode for beam search with LM.
    For now, it's a placeholder that returns the text as-is.
    
    Args:
        text: Input transcription text
        lm_path: Path to KenLM language model file (.arpa or .bin)
        alpha: LM weight
        beta: Word insertion bonus
        
    Returns:
        LM-corrected text
    """
    # Placeholder implementation
    # To use this, you would need:
    # 1. Install: pip install pyctcdecode
    # 2. Download/train a Bengali language model
    # 3. Use pyctcdecode.BeamSearchDecoderCTC with the model
    
    # For now, just return original text
    return text


def normalize_unicode_bangla(text):
    """
    Normalize Unicode characters for Bengali text.
    
    Handles:
    - NFC/NFD normalization
    - Bengali-specific character normalization
    - Zero-width joiners and non-joiners
    - Common Unicode variants
    
    Args:
        text: Input Bengali text
        
    Returns:
        Normalized text
    """
    if not text:
        return text
    
    # Step 1: Apply NFC (Canonical Composition) normalization
    # This is the standard for Bengali text
    text = unicodedata.normalize('NFC', text)
    
    # Step 2: Bengali-specific normalizations
    replacements = {
        # Normalize various forms of zero
        '০': '০',  # Ensure Bengali zero
        '\u09e6': '০',  # Alternative Bengali zero
        
        # Normalize anusvara and chandrabindu variants
        '\u0981': 'ঁ',  # Chandrabindu
        '\u0982': 'ং',  # Anusvara
        
        # Normalize nukta forms
        '\u09bc': '়',  # Bengali nukta
        
        # Normalize virama (hasant)
        '\u09cd': '্',  # Hasant/Virama
        
        # Normalize common punctuation
        '|': '।',  # Replace pipe with dari
        '॥': '।।',  # Double dari normalization
    }
    
    for old, new in replacements.items():
        text = text.replace(old, new)
    
    # Step 3: Remove zero-width characters (except necessary ones)
    # Keep Zero Width Joiner (ZWJ) and Zero Width Non-Joiner (ZWNJ) as they're important for Bengali
    # But remove other zero-width spaces
    text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
    
    # Step 4: Normalize multiple spaces to single space
    text = re.sub(r'\s+', ' ', text)
    
    # Step 5: Remove leading/trailing whitespace
    text = text.strip()
    
    return text


def add_sentence_ending(text, end_char="।"):
    """
    Append sentence-ending character if not present.
    
    For Bengali, the dari (।) is the standard sentence-ending punctuation.
    
    Args:
        text: Input text
        end_char: Character to append (default: Bengali dari ।)
        
    Returns:
        Text with sentence ending
    """
    if not text:
        return text
    
    text = text.strip()
    
    # Check if text already ends with common Bengali punctuation
    bengali_punctuation = ['।', '॥', '?', '!', '.']
    
    if not any(text.endswith(p) for p in bengali_punctuation):
        text = text + end_char
    
    return text


def apply_post_processing(text, enable_lm=False, enable_unicode=True, enable_end_char=True, end_char="।"):
    """
    Apply all post-processing steps to transcription.
    
    Args:
        text: Input transcription
        enable_lm: Enable n-gram LM decoding
        enable_unicode: Enable Unicode normalization
        enable_end_char: Enable sentence ending character
        end_char: Sentence ending character
        
    Returns:
        Post-processed text
    """
    if not text:
        return text
    
    # Step 1: N-gram LM decoding (if enabled and available)
    if enable_lm:
        text = lm_decode_with_ngram(text)
    
    # Step 2: Unicode normalization
    if enable_unicode:
        text = normalize_unicode_bangla(text)
    
    # Step 3: Add sentence ending
    if enable_end_char:
        text = add_sentence_ending(text, end_char)
    
    return text


print("Post-processing functions defined (LM decoding, Unicode normalization, sentence ending)")

In [None]:
# Initialize denoiser (if enabled)
denoiser = None
if ENABLE_DENOISING:
    denoiser = SpectralGatingDenoiser(
        sample_rate=SAMPLE_RATE,
        threshold_k=NOISE_GATE_THRESHOLD_K,
        win_length_ms=STFT_WIN_LENGTH_MS,
        hop_length_ms=STFT_HOP_LENGTH_MS,
        soft_mask_min=SOFT_MASK_MIN
    )
    print(f"Denoiser initialized (threshold_k={NOISE_GATE_THRESHOLD_K})")
else:
    print("Denoising disabled")

# Get test audio files
test_files = sorted(glob.glob(os.path.join(TEST_AUDIO_DIR, "audio", "*.wav")))
print(f"\nFound {len(test_files)} test audio files")

# Create dataset and dataloader with overlap-aware chunking
dataset = ASRDataset(
    audio_paths=test_files,
    chunk_length_sec=CHUNK_LENGTH_SEC,
    overlap_sec=OVERLAP_SEC,
    sample_rate=SAMPLE_RATE,
    denoiser=denoiser
)

dataloader = DataLoader(
    dataset,
    batch_size=4,  # Process 4 chunks at a time
    shuffle=False,
    num_workers=2,  # Set to 0 for Kaggle compatibility
    collate_fn=collate_fn
)

print(f"Dataset created: {len(dataset)} total chunks (with {OVERLAP_SEC}s overlap)")
print(f"Processing with batch_size=4\n")

# Process all chunks and aggregate by file with timing info
from collections import defaultdict
file_transcriptions = defaultdict(list)

for batch in tqdm(dataloader, desc="Processing audio chunks"):
    if batch is None:
        continue
        
    # Transcribe batch
    transcriptions = transcribe_batch(batch, processor, model, DEVICE)
    
    # Aggregate transcriptions by file with timing information
    for item, transcription in zip(batch, transcriptions):
        audio_path = item['audio_path']
        filename = os.path.basename(audio_path)
        chunk_idx = item['chunk_idx']
        start_sec = item['start_sec']
        end_sec = item['end_sec']
        
        if transcription.strip():
            file_transcriptions[filename].append((chunk_idx, start_sec, end_sec, transcription))

# Merge transcriptions for each file using overlap-aware stitching
print("\nApplying overlap-aware stitching and post-processing...")
results = []
for filename in sorted(file_transcriptions.keys()):
    # Use the stitching function
    full_text = stitch_transcriptions(file_transcriptions[filename])
    
    # Apply post-processing
    processed_text = apply_post_processing(
        full_text,
        enable_lm=ENABLE_LM_DECODING,
        enable_unicode=ENABLE_UNICODE_NORMALIZATION,
        enable_end_char=ENABLE_SENTENCE_END_CHAR,
        end_char=SENTENCE_END_CHAR
    )
    
    results.append({
        "filename": filename,
        "transcript": processed_text
    })

print(f"Processed {len(results)} files with overlap-aware stitching + post-processing")

Denoiser initialized (threshold_k=2.0)

Found 24 test audio files
Dataset created: 5341 total chunks
Processing with batch_size=4



Processing audio chunks: 100%|██████████| 1336/1336 [11:41<00:00,  1.90it/s]


Processed 24 files





In [8]:
SUBMISSION_PATH = "/kaggle/working/"
# Create submission DataFrame
submission_df = pd.DataFrame(results)
submission_df = submission_df[["filename", "transcript"]]

# Fill any empty transcriptions
submission_df["transcript"] = submission_df["transcript"].fillna("")

# Save submission
submission_df.to_csv(SUBMISSION_PATH + "submission.csv", index=False, encoding="utf-8")
print(f"Submission saved to: {SUBMISSION_PATH}submission.csv")

# Display preview
print(f"\nSubmission preview ({len(submission_df)} rows):")
print(submission_df.head(10))

Submission saved to: /kaggle/working/submission.csv

Submission preview (24 rows):
       filename                                         transcript
0  test_001.wav  এআক্সক্ষির এটেনিকেপকা আপনাকেদে ালোমানার মিডিগি...
1  test_002.wav  মিন্তু রচ্ছাধারীনা আগিন সৈনাআমি সব দিল পমলা এব...
2  test_003.wav  গল্পুটির সত্য আনন্দ পাবলিশাস প্রাইভেটলিমেটে কো...
3  test_004.wav  যে কোনো জায়গায় যেতে রাতের ট্রেনি আমাদের প্সব...
4  test_005.wav  বেচি নিবেদন ফ্রাইডেইক্লাসেক্স পরে বকিমেগ গাকছে...
5  test_006.wav  আাদের খুব প্রন্দ হইছে আা ছেলে পছন্দমি ইলে আপনা...
6  test_008.wav  বাদরির পোচাগর এই রকম াঙি দুগডবাটা পড সইে যাওয়...
7  test_009.wav  দ এব দি কেকদ মৃত্যাগত গলতে সরায কেবে িতঅমিবভাজ...
8  test_010.wav  আাই একটা তাড়াতালি ক ইে ফোন জলে আসছে আরে এতারা...
9  test_011.wav  বেচি নিবেদন ফ্রাইডে ক্লাসেক্স ই জুই মন্টু গিয়...


In [9]:
# Verify submission file
final_df = pd.read_csv(SUBMISSION_PATH)
print("Submission verification:")
print(f"  - Total rows: {len(final_df)}")
print(f"  - Columns: {list(final_df.columns)}")
print(f"  - Empty transcriptions: {(final_df['transcription'] == '').sum()}")
print("\nDone!")

IsADirectoryError: [Errno 21] Is a directory: '/kaggle/working/'