In [None]:
!apt-get update -y && apt-get install -y ffmpeg
!pip install -U "datasets>=2.17.0" "transformers>=4.44.0" "accelerate>=0.33.0" "evaluate>=0.4.2" jiwer librosa bitsandbytes peft gradio
!pip install torchcodec
!pip install --upgrade transformers accelerate

0% [Working]            Hit:1 https://cli.github.com/packages stable InRelease
0% [Connecting to archive.ubuntu.com (91.189.91.81)] [Connecting to security.ub                                                                               Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
0% [Connecting to archive.ubuntu.com (91.189.91.81)] [Connecting to security.ub                                                                               Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
0% [Waiting for headers] [Waiting for headers] [Connecting to r2u.stat.illinois0% [Waiting for headers] [Waiting for headers] [Connecting to r2u.stat.illinois                                                                               Get:4 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease [18.1 kB]
0% [Waiting for headers] [Waiting for headers] [Connecting to r2u.stat.illinois

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import librosa
from datasets import Dataset, DatasetDict, concatenate_datasets
from sklearn.model_selection import train_test_split
from transformers import (
    WhisperProcessor,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from dataclasses import dataclass
from typing import Any, Dict, List, Union, Tuple, Optional
import evaluate
from tqdm import tqdm
from datasets import load_from_disk
from google.colab import files
import time
import soundfile as sf
from tqdm.auto import tqdm
import os, math, shutil, hashlib
from multiprocessing import Pool
from transformers import AutoTokenizer
from pathlib import Path
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import defaultdict
from transformers import TrainerCallback
import random
from transformers import EarlyStoppingCallback
import sys
import re
from collections import Counter
import torchaudio

In [None]:
# Try to import forced alignment - fallback if not available
try:
    from torchaudio.pipelines import MMS_FA as FA_BUNDLE
    FORCED_ALIGNMENT_AVAILABLE = True
except ImportError:
    FORCED_ALIGNMENT_AVAILABLE = False

# For LUFS normalization
try:
    import pyloudnorm as pyln
    PYLOUDNORM_AVAILABLE = True
except ImportError:
    PYLOUDNORM_AVAILABLE = False

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
MODEL_NAME = "openai/whisper-large-v3"
MAX_AUDIO_SEC = 30
MAX_LABEL_TOK = 448
SR = 16000
N_MELS = 128

In [None]:
QUALITY_CONFIG = {
    "min_duration": 0.5,        # seconds - filter out very short clips
    "max_duration": 30.0,       # seconds - Whisper's limit
    "min_snr_clinical": 5.0,    # dB - lower threshold for db data
    "min_snr_standard": 10.0,   # dB - higher threshold for more data
    "max_tokens": 448,          # Whisper's token limit
}

# Dataset balancing config
BALANCE_CONFIG = {
    "dementia_oversample": 5,   # Repeat DementiaBank samples 5x
    "cv_undersample_ratio": 1.0,  # Keep all CommonVoice
    "seed": 42,
}

# Training format distribution (for timestamp preservation)
FORMAT_CONFIG = {
    "text_only_ratio": 0.63,           # Plain text, no timestamps
    "timestamps_ratio": 0.25,           # Text with timestamps
    "timestamps_with_prev_ratio": 0.12, # Text with timestamps and previous context
}

# Paths
DB_AUDIO_FOLDER = "/content/drive/MyDrive/Capstone/Data/dementiabank_v2/audio_data"
CV_AUDIO_FOLDER = "/content/drive/MyDrive/vin_capstone/new_data"
WAV_OUTPUT_DIR = "/content/drive/MyDrive/whisper_wavs_normalized"
CACHE_DIR = "/content/drive/MyDrive/vin-capstone/data/preprocessed_whisper_large_v3_features_v2"

In [None]:
def setup_tokenizer_with_fillers(model_name: str, filler_tokens: List[str]) -> Tuple[WhisperTokenizer, WhisperProcessor, int]:
    """
    Setup tokenizer with added filler tokens for dementia speech.

    CRITICAL: This must be done BEFORE any text encoding, and the model
    embeddings has to be resized to match.

    Returns:
        tokenizer: WhisperTokenizer with filler tokens added
        processor: WhisperProcessor with updated tokenizer
        new_vocab_size: Size to resize model embeddings to
    """

    tokenizer = WhisperTokenizer.from_pretrained(
        model_name,
        language="en",
        task="transcribe"
    )
    processor = WhisperProcessor.from_pretrained(
        model_name,
        language="en",
        task="transcribe"
    )

    original_vocab_size = len(tokenizer)

    # Add filler tokens
    num_added = tokenizer.add_tokens(filler_tokens, special_tokens=False)

    # Update processors tokenizer
    processor.tokenizer = tokenizer

    new_vocab_size = len(tokenizer)

    print(f"  Original vocab size: {original_vocab_size:,}")
    print(f"  Added {num_added} filler tokens: {filler_tokens}")
    print(f"  New vocab size: {new_vocab_size:,}")
    print(f"\n  IMPORTANT: Resize model embeddings to {new_vocab_size} when loading!")

    return tokenizer, processor, new_vocab_size

In [None]:
def clean_dementiabank_preserve_fillers(raw_transcript: str) -> str:
    """
    Clean DementiaBank CHAT transcripts while keeping disfluencies.

    These are clinically meaningful for dementia detection:
    - Filled pauses (uh, um) - converted to special tokens
    - Repetitions
    - False starts
    - Word-finding pauses
    """

    if pd.isna(raw_transcript) or raw_transcript is None:
        return ""

    text = str(raw_transcript)

    # Remove CHAT metadata lines
    text = re.sub(r'^@.*$', '', text, flags=re.MULTILINE)

    # Remove speaker labels (*PAR:, *INV:, etc.)
    text = re.sub(r'^\*\w+:\s*', '', text, flags=re.MULTILINE)

    # Expand repetition markers: word [x 3] - word word word
    def expand_repetition(match):
        word = match.group(1)
        count = int(match.group(2))
        return ' '.join([word] * count)
    text = re.sub(r'(\w+)\s*\[x\s*(\d+)\]', expand_repetition, text)

    # Normalize filled pauses to our special tokens (don't remove!!)
    text = re.sub(r'\b[Uu]h+\b', '[UH]', text)
    text = re.sub(r'\b[Uu]m+\b', '[UM]', text)
    text = re.sub(r'\b[Uu]mm+\b', '[UM]', text)
    text = re.sub(r'\b[Ee]r+\b', '[ER]', text)
    text = re.sub(r'\b[Aa]h+\b', '[AH]', text)
    text = re.sub(r'\b[Hh]mm*\b', '[HM]', text)

    # Keep partial words (false starts) - marked with & in CHAT
    text = re.sub(r'&\+(\w+)', r'\1', text)  # &+word - word (attempted word)
    text = re.sub(r'&(\w+)', r'\1', text)    # &word - word (fragment)

    # Remove annotation brackets but process content
    text = re.sub(r'\[:\s*[^\]]+\]', '', text)   # [: timing annotations]
    text = re.sub(r'\[=!\s*[^\]]+\]', '', text)  # [=! actions]
    text = re.sub(r'\[=\s*[^\]]+\]', '', text)   # [= comments]
    text = re.sub(r'\[\*[^\]]*\]', '', text)     # [* error codes]
    text = re.sub(r'\[/+\]', '', text)           # [/] [//] retracing markers
    text = re.sub(r'\[%[^\]]*\]', '', text)      # [% dependent tier]

    # Handle unintelligible speech
    text = re.sub(r'\bxxx\b', '[UNINTELLIGIBLE]', text)
    text = re.sub(r'\bwww\b', '', text)  # Untranscribed portions
    text = re.sub(r'\byyy\b', '', text)  # Phonological coding

    # Remove remaining CHAT symbols
    text = re.sub(r'[<>]', '', text)     # Overlap markers
    text = re.sub(r'\+["/.]', '', text)  # Utterance terminators
    text = re.sub(r'‡|„', '', text)      # Special markers

    # Clean up whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text

In [None]:
def clean_commonvoice_transcript(transcript: str) -> str:
    """
    Clean CommonVoice transcripts
    CommonVoice is already fairly clean, but normalize any fillers
    that might be present to match token format.
    """
    if pd.isna(transcript) or transcript is None:
        return ""

    text = str(transcript).strip()

    text = re.sub(r'\b[Uu]h+\b', '[UH]', text)
    text = re.sub(r'\b[Uu]m+\b', '[UM]', text)
    text = re.sub(r'\b[Ee]r+\b', '[ER]', text)
    text = re.sub(r'\b[Aa]h+\b', '[AH]', text)

    return text

In [None]:
def prepare_transcript_for_whisper(text: str) -> str:
    """
    Final prep of transcript for Whisper training

    Whisper was trained with:
    1. Leading space before text
    2. Preserved punctuation and capitalization
    """

    if pd.isna(text) or text is None:
        return ""

    text = str(text).strip()

    if not text:
        return ""

    # Add leading space (Whisper convention from pretraining)
    if not text.startswith(" "):
        text = " " + text

    # Do NOT lowercase - Whisper expects natural capitalization
    # Do NOT remove punctuation - Whisper expects it
    return text

In [None]:
def normalize_audio_rms(y: np.ndarray, target_db: float = -20.0) -> np.ndarray:
    """
    Normalize audio to target RMS loudness level
    """
    rms = np.sqrt(np.mean(y**2))

    if rms < 1e-10:  # Essentially silence
        return y

    current_db = 20 * np.log10(rms + 1e-10)
    gain_db = target_db - current_db
    gain_linear = 10 ** (gain_db / 20)

    y_normalized = y * gain_linear
    y_normalized = np.clip(y_normalized, -1.0, 1.0)

    return y_normalized.astype(np.float32)

In [None]:
def normalize_audio_lufs(y: np.ndarray, sr: int = 16000, target_lufs: float = -23.0) -> np.ndarray:
    """
    Normalize audio using LUFS loudness (broadcast standard).
    Requires pyloudnorm. Falls back to RMS if not available.
    """
    if not PYLOUDNORM_AVAILABLE:
        return normalize_audio_rms(y)

    try:
        meter = pyln.Meter(sr)
        current_lufs = meter.integrated_loudness(y)

        if current_lufs > -70:  # Not silence
            y = pyln.normalize.loudness(y, current_lufs, target_lufs)

        return np.clip(y, -1.0, 1.0).astype(np.float32)
    except:
        return normalize_audio_rms(y)

In [None]:
def remove_dc_offset(y: np.ndarray) -> np.ndarray:
    """Remove DC offset from audio signal """
    return (y - np.mean(y)).astype(np.float32)

In [None]:
def convert_audio_to_wav_normalized(
    input_path: str,
    output_dir: str,
    target_sr: int = 16000,
    use_lufs: bool = True
) -> Optional[str]:
    """
    Convert audio to WAV with normalization
    - Added loudness normalization
    - Added DC offset removal
    - Better error handling
    """
    os.makedirs(output_dir, exist_ok=True)

    # Create unique filename using hash
    file_hash = hashlib.md5(input_path.encode()).hexdigest()[:8]
    filename = f"{file_hash}_{Path(input_path).stem}.wav"
    output_path = os.path.join(output_dir, filename)

    # Skip if already converted
    if os.path.exists(output_path):
        return output_path

    try:
        # Load and resample
        y, sr = librosa.load(input_path, sr=target_sr, mono=True)

        # Remove DC offset
        y = remove_dc_offset(y)

        # Normalize loudness
        if use_lufs and PYLOUDNORM_AVAILABLE:
            y = normalize_audio_lufs(y, sr=target_sr)
        else:
            y = normalize_audio_rms(y, target_db=-20.0)

        # Save as 16-bit PCM WAV
        sf.write(output_path, y, target_sr, subtype='PCM_16')

        return output_path

    except Exception as e:
        print(f"\nError converting {input_path}: {e}")
        return None

In [None]:
def read_window_wav(
    path: str,
    sr: int = 16000,
    offset_sec: float = 0.0,
    win_sec: float = 30.0
) -> np.ndarray:
    """
    Fast reading of WAV file window
    """
    start_frame = int(offset_sec * sr)
    num_frames = int(win_sec * sr)

    try:
        y, file_sr = sf.read(
            path,
            start=start_frame,
            frames=num_frames,
            dtype="float32",
            always_2d=False
        )

        if y.ndim == 2:
            y = y.mean(axis=1)

    except Exception as e:
        y = np.zeros(num_frames, dtype=np.float32)

    # Pad or trim to exact length
    target_len = int(win_sec * sr)
    if len(y) < target_len:
        y = np.pad(y, (0, target_len - len(y)))
    elif len(y) > target_len:
        y = y[:target_len]

    return y

In [None]:
def estimate_snr(audio_path: str, sr: int = 16000) -> Optional[float]:
    """
    Estimate SNR using energy percentiles
    Assumes bottom 10% energy frames are noise, top 90% is signal
    """
    try:
        y, file_sr = sf.read(audio_path, dtype='float32')

        if file_sr != sr:
            y = librosa.resample(y, orig_sr=file_sr, target_sr=sr)

        if len(y.shape) > 1:
            y = y.mean(axis=1)

        # Frame-based energy calculation
        frame_length = int(0.025 * sr)  # 25ms frames
        hop_length = int(0.010 * sr)    # 10ms hop

        n_frames = (len(y) - frame_length) // hop_length
        if n_frames < 10:
            return None  # Too short to estimate

        energy = np.array([
            np.sum(y[i*hop_length:i*hop_length+frame_length]**2)
            for i in range(n_frames)
        ])

        energy_db = 10 * np.log10(energy + 1e-10)

        noise_floor = np.percentile(energy_db, 10)
        signal_level = np.percentile(energy_db, 90)

        return signal_level - noise_floor

    except Exception:
        return None

In [None]:
def get_audio_duration(audio_path: str) -> Optional[float]:
    """Get audio duration in seconds"""
    try:
        info = sf.info(audio_path)
        return info.frames / info.samplerate
    except:
        return None

In [None]:
def check_transcript_quality(
    transcript: str,
    tokenizer: WhisperTokenizer,
    max_tokens: int = 448
) -> Dict:
    """Check transcript quality metrics."""
    if pd.isna(transcript) or not transcript:
        return {"valid": False, "reason": "empty", "token_count": 0}

    transcript = str(transcript).strip()

    if len(transcript) < 2:
        return {"valid": False, "reason": "too_short", "token_count": 0}

    # Check token count
    tokens = tokenizer(transcript, add_special_tokens=True).input_ids
    token_count = len(tokens)

    if token_count > max_tokens:
        return {"valid": False, "reason": "too_many_tokens", "token_count": token_count}

    return {"valid": True, "reason": None, "token_count": token_count}

In [None]:
def quality_filter_dataset(
    df: pd.DataFrame,
    tokenizer: WhisperTokenizer,
    config: Dict
) -> pd.DataFrame:
    """
    Filter dataset by quality criteria
    Uses different thresholds for db vs standard speech
    """

    results = []

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Checking quality"):
        issues = []
        duration = None
        snr = None

        # Duration check
        duration = get_audio_duration(row['audio_path'])
        if duration is None:
            issues.append("unreadable_audio")
        elif duration < config["min_duration"]:
            issues.append(f"too_short")
        elif duration > config["max_duration"]:
            issues.append(f"too_long")

        # SNR check (only if audio readable)
        if duration is not None and not issues:
            snr = estimate_snr(row['audio_path'])
            # Use different threshold based on source
            is_clinical = row.get('source', '') == 'dementiabank'
            min_snr = config["min_snr_clinical"] if is_clinical else config["min_snr_standard"]

            if snr is not None and snr < min_snr:
                issues.append(f"low_snr")

        # Transcript check
        transcript_result = check_transcript_quality(
            row['transcript_clean'],
            tokenizer,
            config["max_tokens"]
        )
        if not transcript_result["valid"]:
            issues.append(f"transcript_{transcript_result['reason']}")

        results.append({
            "idx": idx,
            "valid": len(issues) == 0,
            "issues": "|".join(issues) if issues else None,
            "duration": duration,
            "snr": snr,
            "token_count": transcript_result["token_count"],
        })

    results_df = pd.DataFrame(results)

    # Print statistics
    n_valid = results_df['valid'].sum()
    n_total = len(df)
    print(f"\n  Results: {n_valid:,} / {n_total:,} valid ({100*n_valid/n_total:.1f}%)")

    # Breakdown of rejection reasons
    rejected = results_df[~results_df['valid']]
    if len(rejected) > 0:
        all_issues = []
        for issues_str in rejected['issues'].dropna():
            all_issues.extend(issues_str.split('|'))

        issue_counts = Counter(all_issues)
        print(f"\n  Rejection breakdown:")
        for issue, count in issue_counts.most_common():
            print(f"    {issue}: {count:,}")

    # Duration statistics for valid samples
    valid_results = results_df[results_df['valid']]
    if len(valid_results) > 0:
        print(f"\n  Valid sample statistics:")
        print(f"    Duration: {valid_results['duration'].mean():.1f}s avg, "
              f"{valid_results['duration'].min():.1f}s min, "
              f"{valid_results['duration'].max():.1f}s max")
        if valid_results['snr'].notna().any():
            print(f"    SNR: {valid_results['snr'].mean():.1f}dB avg")
        print(f"    Tokens: {valid_results['token_count'].mean():.0f} avg")

    # Filter and return
    valid_indices = results_df[results_df['valid']]['idx'].tolist()
    filtered_df = df.loc[valid_indices].reset_index(drop=True)

    return filtered_df

In [None]:
def check_missing_files(df: pd.DataFrame, audio_folder: str) -> pd.DataFrame:
    """
    Check which audio files exist
    """
    df_check = df.copy()
    df_check['full_audio_path'] = df_check['path'].apply(
        lambda x: os.path.join(audio_folder, x)
    )

    df_check['file_exists'] = df_check['full_audio_path'].apply(os.path.exists)

    total = len(df_check)
    existing = df_check['file_exists'].sum()
    missing = total - existing

    print(f"File Existence Check: {audio_folder}")
    print(f"Total files in DataFrame: {total:,}")
    print(f"Files that exist:         {existing:,} ({existing/total*100:.1f}%)")
    print(f"Files that are missing:   {missing:,} ({missing/total*100:.1f}%)")

    if missing > 0:
        missing_files = df_check[~df_check['file_exists']]
        print(f"\nFirst 10 missing files:")
        for _, row in missing_files.head(10).iterrows():
            print(f"  {row['path']}")

    return df_check

In [None]:
# Global variables for forced alignment model
_FA_MODEL = None
_FA_TOKENIZER = None
_FA_ALIGNER = None
_FA_DEVICE = None

In [None]:
def _init_forced_aligner():
    """Initialize forced alignment model (called once)."""
    global _FA_MODEL, _FA_TOKENIZER, _FA_ALIGNER, _FA_DEVICE

    if not FORCED_ALIGNMENT_AVAILABLE:
        return None, None, None, None

    if _FA_MODEL is None:
        _FA_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"  Loading forced alignment model on {_FA_DEVICE}...")
        _FA_MODEL = FA_BUNDLE.get_model().to(_FA_DEVICE)
        _FA_TOKENIZER = FA_BUNDLE.get_tokenizer()
        _FA_ALIGNER = FA_BUNDLE.get_aligner()

    return _FA_MODEL, _FA_TOKENIZER, _FA_ALIGNER, _FA_DEVICE

In [None]:
def get_word_timestamps(audio_path: str, transcript: str) -> Optional[List[Dict]]:
    """
    Get word-level timestamps using forced alignment
    Returns list of {"word": str, "start": float, "end": float}
    """
    model, fa_tokenizer, aligner, device = _init_forced_aligner()

    if model is None:
        return None

    try:
        waveform, sr = torchaudio.load(audio_path)

        # Resample if needed
        if sr != FA_BUNDLE.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, FA_BUNDLE.sample_rate)

        # Prepare transcript for alignment (lowercase, alphanumeric only)
        transcript_norm = transcript.lower()
        transcript_norm = ''.join(c for c in transcript_norm if c.isalnum() or c.isspace())
        words = transcript_norm.split()

        if not words:
            return None

        with torch.no_grad():
            emission, _ = model(waveform.to(device))
            tokens = fa_tokenizer(transcript_norm)

            if not tokens:
                return None

            token_spans = aligner(emission[0], tokens)

        # Convert token spans to word timestamps
        ratio = waveform.shape[1] / emission.shape[1] / FA_BUNDLE.sample_rate
        word_timestamps = []
        token_idx = 0

        for word in words:
            word_tokens = fa_tokenizer(word)
            n_tokens = len(word_tokens)

            if token_idx + n_tokens <= len(token_spans):
                start = token_spans[token_idx].start * ratio
                end = token_spans[token_idx + n_tokens - 1].end * ratio
                word_timestamps.append({
                    "word": word,
                    "start": float(start),
                    "end": float(end)
                })
                token_idx += n_tokens

        return word_timestamps if word_timestamps else None

    except Exception as e:
        # Silently fail -  use fallback
        return None

In [None]:
# Global tokenizer for chunking (set by _init_whisper_tokenizer)
_WHISPER_TOK = None

def _init_whisper_tokenizer(tokenizer_name: str, filler_tokens: List[str]):
    """
    Initialize Whisper tokenizer for chunking
    """
    global _WHISPER_TOK
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    _WHISPER_TOK = WhisperTokenizer.from_pretrained(
        tokenizer_name,
        language="en",
        task="transcribe"
    )

    # Add filler tokens to match main tokenizer
    if filler_tokens:
        _WHISPER_TOK.add_tokens(filler_tokens, special_tokens=False)

In [None]:
def chunk_with_alignment(
    audio_path: str,
    transcript: str,
    max_audio_sec: float = 30.0,
    max_label_tok: int = 448
) -> List[Dict]:
    """
    Chunk audio using forced alignment
    Falls back to simple duration-based chunking if alignment fails
    """
    rows = []

    # Get audio duration
    duration = get_audio_duration(audio_path)
    if not duration or duration <= 0:
        return rows

    # For short audio, no chunking needed
    if duration <= max_audio_sec:
        token_count = len(_WHISPER_TOK(transcript, add_special_tokens=False).input_ids)
        if token_count <= max_label_tok:
            return [{
                "audio_path": audio_path,
                "transcript_chunk": transcript,
                "chunk_idx": 0,
                "total_chunks": 1,
                "offset_sec": 0.0,
                "end_sec": duration,
                "alignment_method": "no_chunking_needed"
            }]

    # Try forced alignment
    word_times = get_word_timestamps(audio_path, transcript)

    if word_times and len(word_times) > 0:
        # Build chunks at natural word boundaries
        chunk_start_time = 0.0
        chunk_words = []
        chunk_start_idx = 0

        for i, wt in enumerate(word_times):
            word, w_start, w_end = wt["word"], wt["start"], wt["end"]

            # Would this word exceed max duration?
            potential_duration = w_end - chunk_start_time

            if potential_duration > max_audio_sec and chunk_words:
                # Finalize current chunk
                chunk_text = ' '.join(chunk_words)
                chunk_end_time = word_times[i-1]["end"] if i > 0 else w_start

                # Verify token count
                tokens = _WHISPER_TOK(chunk_text, add_special_tokens=False).input_ids
                if len(tokens) <= max_label_tok:
                    rows.append({
                        "audio_path": audio_path,
                        "transcript_chunk": chunk_text,
                        "chunk_idx": len(rows),
                        "total_chunks": -1,  # Updated later
                        "offset_sec": chunk_start_time,
                        "end_sec": chunk_end_time,
                        "alignment_method": "forced_alignment"
                    })

                # Start new chunk
                chunk_start_time = w_start
                chunk_words = [word]
            else:
                chunk_words.append(word)

        # last chunk
        if chunk_words:
            chunk_text = ' '.join(chunk_words)
            chunk_end_time = word_times[-1]["end"]

            tokens = _WHISPER_TOK(chunk_text, add_special_tokens=False).input_ids
            if len(tokens) <= max_label_tok:
                rows.append({
                    "audio_path": audio_path,
                    "transcript_chunk": chunk_text,
                    "chunk_idx": len(rows),
                    "total_chunks": -1,
                    "offset_sec": chunk_start_time,
                    "end_sec": chunk_end_time,
                    "alignment_method": "forced_alignment"
                })

        # Update total_chunks
        for row in rows:
            row["total_chunks"] = len(rows)

        if rows:
            return rows

    # FALLBACK: duration-based chunking
    # Used when forced alignment fails
    n_chunks = max(1, math.ceil(duration / max_audio_sec))

    # Split transcript by tokens
    all_tokens = _WHISPER_TOK(transcript, add_special_tokens=False).input_ids
    tokens_per_chunk = max(1, len(all_tokens) // n_chunks)

    for i in range(n_chunks):
        start_tok = i * tokens_per_chunk
        end_tok = min((i + 1) * tokens_per_chunk, len(all_tokens))

        if i == n_chunks - 1:  # Last chunk gets remainder
            end_tok = len(all_tokens)

        chunk_tokens = all_tokens[start_tok:end_tok]
        chunk_text = _WHISPER_TOK.decode(chunk_tokens, skip_special_tokens=True)

        rows.append({
            "audio_path": audio_path,
            "transcript_chunk": chunk_text,
            "chunk_idx": i,
            "total_chunks": n_chunks,
            "offset_sec": i * max_audio_sec,
            "end_sec": min((i + 1) * max_audio_sec, duration),
            "alignment_method": "fallback_duration"
        })

    return rows

In [None]:
def chunk_dataset_with_alignment(
    df: pd.DataFrame,
    tokenizer_name: str,
    filler_tokens: List[str],
    max_audio_sec: float = 30.0,
    max_label_tok: int = 448
) -> pd.DataFrame:
    """
    Chunk all samples in dataset using forced alignment.
    Note: Runs sequentially because forced alignment model is on GPU
    """

    # Initialize tokenizer
    _init_whisper_tokenizer(tokenizer_name, filler_tokens)

    # Initialize forced aligner
    if FORCED_ALIGNMENT_AVAILABLE:
        _init_forced_aligner()
        print("  Forced alignment model loaded")
    else:
        print("  Using fallback chunking (no forced alignment)")

    all_rows = []
    alignment_stats = {"forced_alignment": 0, "fallback_duration": 0, "no_chunking_needed": 0}

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Chunking"):
        chunks = chunk_with_alignment(
            row['audio_path'],
            row['transcript_clean'],
            max_audio_sec,
            max_label_tok
        )

        # Track alignment method and keep source
        for chunk in chunks:
            chunk['source'] = row.get('source', 'unknown')
            alignment_stats[chunk.get('alignment_method', 'unknown')] = \
                alignment_stats.get(chunk.get('alignment_method', 'unknown'), 0) + 1

        all_rows.extend(chunks)

    df_chunked = pd.DataFrame(all_rows)

    print(f"\n  Total chunks: {len(df_chunked):,}")
    print(f"  Alignment methods:")
    for method, count in alignment_stats.items():
        if count > 0:
            print(f"    {method}: {count:,} ({100*count/len(df_chunked):.1f}%)")

    return df_chunked

In [None]:
# Timestamp Training Format

def assign_training_format(df: pd.DataFrame, config: Dict, seed: int = 42) -> pd.DataFrame:
    """
    Assign training format to each sample for timestamp preservation
    Distribution (from research on preserving Whisper capabilities):
    - ~63% text only (no timestamps) - standard training
    - ~25% text + timestamps - teaches timestamp generation
    - ~12% text + timestamps + previous context - enables long-form transcription
    """

    random.seed(seed)
    n = len(df)

    # Shuffle indices
    indices = list(range(n))
    random.shuffle(indices)

    # Calculate split points
    n_with_prev = int(config["timestamps_with_prev_ratio"] * n)
    n_with_timestamps = int(config["timestamps_ratio"] * n)

    # Assign formats
    format_map = {}
    for i, idx in enumerate(indices):
        if i < n_with_prev:
            format_map[idx] = "timestamps_and_prev"
        elif i < n_with_prev + n_with_timestamps:
            format_map[idx] = "timestamps_only"
        else:
            format_map[idx] = "text_only"

    df = df.copy()
    df["training_format"] = df.index.map(format_map)

    # Print distribution
    format_counts = df["training_format"].value_counts()
    print(f"  Format distribution:")
    for fmt, count in format_counts.items():
        print(f"    {fmt}: {count:,} ({100*count/n:.1f}%)")

    return df

In [None]:
def format_transcript_with_timestamps(
    text: str,
    offset_sec: float,
    end_sec: float,
    include_timestamps: bool = False,
    include_prev: bool = False,
    prev_text: Optional[str] = None
) -> str:
    """
    Format transcript with optional Whisper-style timestamps.
    Timestamp format: <|0.00|> text <|2.40|>
    Previous context format: <|startofprev|>prev text<|startoftranscript|>current text
    """
    formatted = text

    if include_timestamps:
        start_ts = f"<|{offset_sec:.2f}|>"
        end_ts = f"<|{end_sec:.2f}|>"
        formatted = f"{start_ts}{text}{end_ts}"

    if include_prev and prev_text:
        formatted = f"<|startofprev|>{prev_text}<|startoftranscript|>{formatted}"

    return formatted

In [None]:
def encode_labels_with_format(
    df: pd.DataFrame,
    tokenizer: WhisperTokenizer,
    max_tokens: int = 448
) -> pd.DataFrame:
    """
    Encode labels with appropriate formatting based on training_format
    """

    # Sort by audio file and chunk index to enable previous context lookup
    df = df.sort_values(["audio_path", "chunk_idx"]).reset_index(drop=True)

    encoded_labels = []
    formatted_texts = []

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Encoding"):
        fmt = row["training_format"]

        # Get previous chunk text if needed
        prev_text = None
        if fmt == "timestamps_and_prev" and row["chunk_idx"] > 0:
            # Find previous chunk from same file
            prev_rows = df[
                (df["audio_path"] == row["audio_path"]) &
                (df["chunk_idx"] == row["chunk_idx"] - 1)
            ]
            if len(prev_rows) > 0:
                prev_text = prev_rows.iloc[0]["transcript_chunk"]
                # Truncate previous text if too long
                if prev_text and len(prev_text) > 100:
                    prev_text = prev_text[-100:]  # Last 100 chars

        # Format the text
        include_timestamps = fmt in ["timestamps_only", "timestamps_and_prev"]
        include_prev = fmt == "timestamps_and_prev" and prev_text is not None

        formatted_text = format_transcript_with_timestamps(
            text=row["transcript_chunk"],
            offset_sec=row.get("offset_sec", 0.0),
            end_sec=row.get("end_sec", 30.0),
            include_timestamps=include_timestamps,
            include_prev=include_prev,
            prev_text=prev_text
        )

        # Encode
        tokens = tokenizer(
            formatted_text,
            truncation=True,
            max_length=max_tokens,
            add_special_tokens=True
        ).input_ids

        encoded_labels.append(tokens)
        formatted_texts.append(formatted_text)

    df["labels"] = encoded_labels
    df["formatted_text"] = formatted_texts

    # Stats
    avg_len = np.mean([len(l) for l in encoded_labels])
    max_len = max(len(l) for l in encoded_labels)
    print(f"\n  Average token length: {avg_len:.1f}")
    print(f"  Max token length: {max_len}")

    return df

In [None]:
def balance_and_combine_datasets(
    db_df: pd.DataFrame,
    cv_df: pd.DataFrame,
    config: Dict
) -> pd.DataFrame:
    """
    Balance datasets by oversampling db data
    """

    # make sure source column exists
    db_df = db_df.copy()
    cv_df = cv_df.copy()

    if 'source' not in db_df.columns:
        db_df['source'] = 'dementiabank'
    if 'source' not in cv_df.columns:
        cv_df['source'] = 'commonvoice'

    print(f"  Original sizes:")
    print(f"    DementiaBank: {len(db_df):,}")
    print(f"    CommonVoice:  {len(cv_df):,}")
    print(f"    Ratio: 1:{len(cv_df)/max(1,len(db_df)):.1f}")

    # Oversample DementiaBank
    oversample_factor = config.get("dementia_oversample", 5)
    db_oversampled = pd.concat([db_df] * oversample_factor, ignore_index=True)

    # undersample CommonVoice (optioonal)
    undersample_ratio = config.get("cv_undersample_ratio", 1.0)
    if undersample_ratio < 1.0:
        cv_sampled = cv_df.sample(
            frac=undersample_ratio,
            random_state=config.get("seed", 42)
        ).reset_index(drop=True)
    else:
        cv_sampled = cv_df

    print(f"\n  After balancing:")
    print(f"    DementiaBank: {len(db_oversampled):,} ({oversample_factor}x oversample)")
    print(f"    CommonVoice:  {len(cv_sampled):,}")
    print(f"    New ratio: 1:{len(cv_sampled)/max(1,len(db_oversampled)):.1f}")

    # Combine and shuffle
    combined = pd.concat([db_oversampled, cv_sampled], ignore_index=True)
    combined = combined.sample(frac=1, random_state=config.get("seed", 42)).reset_index(drop=True)

    print(f"\n  Final dataset: {len(combined):,} samples")
    print(f"  Source distribution:")
    for source, count in combined['source'].value_counts().items():
        print(f"    {source}: {count:,} ({100*count/len(combined):.1f}%)")

    return combined

In [None]:
def process_dataset_for_features(
    dataset: Dataset,
    feature_extractor: WhisperFeatureExtractor,
    split_name: str,
    chunk_size: int = 5000,
    sr: int = 16000,
    max_audio_sec: float = 30.0
) -> Dataset:
    """
    Process dataset to extract mel spectrogram features
    """
    n_mels = feature_extractor.feature_size
    dummy_t = 3000

    total = len(dataset)
    all_chunks = []
    errors = 0

    for chunk_start in range(0, total, chunk_size):
        chunk_end = min(chunk_start + chunk_size, total)
        chunk_num = chunk_start // chunk_size + 1
        total_chunks = (total + chunk_size - 1) // chunk_size

        print(f"\n  {split_name} - Chunk {chunk_num}/{total_chunks} "
              f"(samples {chunk_start:,}-{chunk_end:,})")

        processed = []

        for i in tqdm(range(chunk_start, chunk_end), desc="Processing"):
            sample = dataset[i]

            try:
                # Read audio window
                y = read_window_wav(
                    sample["audio_path"],
                    sr=sr,
                    offset_sec=float(sample.get("offset_sec", 0.0)),
                    win_sec=max_audio_sec,
                )

                # Extract mel features
                feats = feature_extractor(
                    [y],
                    sampling_rate=sr,
                    return_tensors=None
                ).input_features[0]

                processed.append({
                    "input_features": np.asarray(feats, dtype=np.float16),
                    "labels": sample["labels"],
                    "input_length": feats.shape[-1],
                    "audio_path": sample["audio_path"],
                    "offset_sec": sample.get("offset_sec"),
                    "end_sec": sample.get("end_sec"),
                    "source": sample.get("source"),
                })


            except Exception as e:
                print(f"\n  Error on sample {i}: {e}")
                errors += 1

                # Add dummy features for failed samples
                processed.append({
                    "input_features": np.zeros((n_mels, dummy_t), dtype=np.float16),
                    "labels": sample["labels"],
                    "input_length": dummy_t,
                    "audio_path": sample["audio_path"],
                    "offset_sec": sample.get("offset_sec"),
                    "end_sec": sample.get("end_sec"),
                    "source": sample.get("source"),
                })


        chunk_dataset = Dataset.from_list(processed)
        all_chunks.append(chunk_dataset)

        print(f"    Chunk {chunk_num} complete: {len(processed):,} samples")

    print(f"\n  Concatenating {len(all_chunks)} chunks")
    final_dataset = concatenate_datasets(all_chunks)

    print(f"  {split_name}: {len(final_dataset):,} samples | Errors: {errors}")

    return final_dataset


In [None]:
def split_by_file(
    df: pd.DataFrame,
    test_size: float = 0.2,
    random_state: int = 42
) -> Tuple[pd.DataFrame, pd.DataFrame]:

    """
    Split dataset by FILE, not by chunk.
    Ensures all chunks from the same audio file stay together,
    preventing data leakage
    """
    unique_files = df['audio_path'].unique()
    train_files, test_files = train_test_split(
        unique_files,
        test_size=test_size,
        random_state=random_state
    )

    train_df = df[df['audio_path'].isin(train_files)].reset_index(drop=True)
    test_df = df[df['audio_path'].isin(test_files)].reset_index(drop=True)

    print(f"  Train: {len(train_files):,} files → {len(train_df):,} chunks")
    print(f"  Test:  {len(test_files):,} files → {len(test_df):,} chunks")

    # Show source distribution in each split
    if 'source' in train_df.columns:
        print(f"\n  Train source distribution:")
        for source, count in train_df['source'].value_counts().items():
            print(f"    {source}: {count:,}")
        print(f"\n  Test source distribution:")
        for source, count in test_df['source'].value_counts().items():
            print(f"    {source}: {count:,}")

    return train_df, test_df

In [None]:
def run_full_pipeline(
    cv_train_path: str,
    db_train_path: str,
    cv_audio_folder: str,
    db_audio_folder: str,
    wav_output_dir: str,
    cache_dir: str,
    model_name: str = "openai/whisper-large-v3",
    run_feature_extraction: bool = True,
    filler_tokens: List[str] = None,
    quality_config: Dict = None,
    balance_config: Dict = None,
    format_config: Dict = None,
    max_audio_sec: float = 30.0,
    max_label_tok: int = 448,
    sample_rate: int = 16000
):
    """
    Run the complete data curation pipeline
    This is the main entry point that ties everything together

    Args:
        cv_train_path: Path to CommonVoice training CSV
        db_train_path: Path to DementiaBank training CSV
        cv_audio_folder: Folder containing CommonVoice audio files
        db_audio_folder: Folder containing DementiaBank audio files
        wav_output_dir: Directory to save converted WAV files
        cache_dir: Directory to save final preprocessed dataset
        model_name: Whisper model name (default: whisper-large-v3)
        run_feature_extraction: Whether to extract mel features (default: True)
        filler_tokens: List of filler tokens to add (default: [UH], [UM], etc.)
        quality_config: Quality filtering config (default: sensible defaults)
        balance_config: Dataset balancing config (default: 5x oversample dementia)
        format_config: Training format distribution (default: 63%/25%/12% split)
        max_audio_sec: Maximum audio segment length (default: 30.0)
        max_label_tok: Maximum label tokens (default: 448)
        sample_rate: Audio sample rate (default: 16000)
    """

    # Set defaults for mutable arguments
    if filler_tokens is None:
        filler_tokens = ["[UH]", "[UM]", "[ER]", "[AH]", "[HM]", "[UNINTELLIGIBLE]"]

    if quality_config is None:
        quality_config = {
            "min_duration": 0.5,
            "max_duration": 30.0,
            "min_snr_clinical": 5.0,
            "min_snr_standard": 10.0,
            "max_tokens": 448,
        }

    if balance_config is None:
        balance_config = {
            "dementia_oversample": 5,
            "cv_undersample_ratio": 1.0,
            "seed": 42,
        }

    if format_config is None:
        format_config = {
            "text_only_ratio": 0.63,
            "timestamps_ratio": 0.25,
            "timestamps_with_prev_ratio": 0.12,
        }

    start_time = time.time()

    # ========================================
    # Setup tokenizer with filler tokens
    tokenizer, processor, new_vocab_size = setup_tokenizer_with_fillers(
        model_name,
        filler_tokens
    )
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)

    # ========================================
    # Load csv files

    cv_train = pd.read_csv(cv_train_path)
    db_train = pd.read_csv(db_train_path)

    print(f"  Loaded {len(cv_train):,} CommonVoice samples")
    print(f"  Loaded {len(db_train):,} DementiaBank samples")

    # ========================================
    # Check files

    cv_check = check_missing_files(cv_train, cv_audio_folder)
    db_check = check_missing_files(db_train, db_audio_folder)

    cv_valid = cv_check[cv_check['file_exists']].copy()
    db_valid = db_check[db_check['file_exists']].copy()

    # ========================================
    # Clean transcripts (keeping fillers)

    # CommonVoice
    cv_valid['transcript_clean'] = cv_valid['ground_truth'].apply(
        clean_commonvoice_transcript
    ).apply(
        prepare_transcript_for_whisper
    )
    cv_valid['audio_path'] = cv_valid['full_audio_path']
    cv_valid['source'] = 'commonvoice'

    # DementiaBank - check for raw transcripts first
    if 'transcript_raw' in db_valid.columns:
        print("  Using raw DementiaBank transcripts (preserving disfluencies)")
        db_valid['transcript_clean'] = db_valid['transcript_raw'].apply(
            clean_dementiabank_preserve_fillers
        ).apply(
            prepare_transcript_for_whisper
        )
    else:
        print(" Using pre-cleaned transcripts. Fillers may be lost")
        db_valid['transcript_clean'] = db_valid['transcript_text_clean_full'].apply(
            clean_dementiabank_preserve_fillers
        ).apply(
            prepare_transcript_for_whisper
        )

    db_valid['audio_path'] = db_valid['full_audio_path']
    db_valid['source'] = 'dementiabank'

    # ========================================
    # Convert audio with normalization

    # Combine for conversion
    all_audio_paths = list(cv_valid['audio_path'].unique()) + list(db_valid['audio_path'].unique())
    print(f"  Found {len(all_audio_paths):,} unique audio files")

    path_map = {}
    for path in tqdm(all_audio_paths, desc="Converting"):
        new_path = convert_audio_to_wav_normalized(
            path,
            wav_output_dir,
            target_sr=sample_rate,
            use_lufs=PYLOUDNORM_AVAILABLE
        )
        if new_path:
            path_map[path] = new_path

    print(f"  Converted {len(path_map):,} files")

    # Update paths
    cv_valid['audio_path'] = cv_valid['audio_path'].map(path_map)
    db_valid['audio_path'] = db_valid['audio_path'].map(path_map)

    # Remove rows where conversion failed
    cv_valid = cv_valid[cv_valid['audio_path'].notna()].reset_index(drop=True)
    db_valid = db_valid[db_valid['audio_path'].notna()].reset_index(drop=True)

    # ========================================
    # Quality filtering
    cv_subset = cv_valid[['audio_path', 'transcript_clean', 'source']].copy()
    db_subset = db_valid[['audio_path', 'transcript_clean', 'source']].copy()

    cv_filtered = quality_filter_dataset(cv_subset, tokenizer, quality_config)
    db_filtered = quality_filter_dataset(db_subset, tokenizer, quality_config)

    # ========================================
    # Balance and combine datasets
    df_combined = balance_and_combine_datasets(db_filtered, cv_filtered, balance_config)

    # ========================================
    # Chunk with forced alignment

    df_chunked = chunk_dataset_with_alignment(
        df_combined,
        tokenizer_name=model_name,
        filler_tokens=filler_tokens,
        max_audio_sec=max_audio_sec,
        max_label_tok=max_label_tok
    )

    # ========================================
    # Assign training formats
    df_chunked = assign_training_format(df_chunked, format_config)

    # ========================================
    # Encode labels

    df_chunked = encode_labels_with_format(df_chunked, tokenizer, max_label_tok)

    # ========================================
    # Train/test split by file

    train_df, test_df = split_by_file(df_chunked, test_size=0.2)

    # ========================================
    # Create HuggingFace Dataset

    cols_for_dataset = ["audio_path", "labels", "offset_sec", "end_sec", "source", "training_format"]

    dem_bank = DatasetDict({
        "train": Dataset.from_list(train_df[cols_for_dataset].to_dict(orient="records")),
        "test": Dataset.from_list(test_df[cols_for_dataset].to_dict(orient="records")),
    })

    print(f"\n  Dataset created:")
    print(f"    Train: {len(dem_bank['train']):,} samples")
    print(f"    Test: {len(dem_bank['test']):,} samples")

    # ========================================
    # Feature extraction
    if run_feature_extraction:
        print("\n>>> Extracting features")

        train_processed = process_dataset_for_features(
            dem_bank["train"],
            feature_extractor,
            split_name="Train",
            chunk_size=5000,
            sr=sample_rate,
            max_audio_sec=max_audio_sec
        )

        test_processed = process_dataset_for_features(
            dem_bank["test"],
            feature_extractor,
            split_name="Test",
            chunk_size=2000,
            sr=sample_rate,
            max_audio_sec=max_audio_sec
        )

        # Rebuild dataset with features
        dem_bank = DatasetDict({
            "train": train_processed,
            "test": test_processed,
        })

        # ========================================
        # Save to disk

        dem_bank.save_to_disk(cache_dir)
        print(f"  Saved to: {cache_dir}")


    elapsed = time.time() - start_time
    print(f"  PIPELINE COMPLETE in {elapsed/60:.1f} minutes")
    print(f"  Final dataset: {len(dem_bank['train']):,} train, {len(dem_bank['test']):,} test")
    print(f"  New vocab size (for model resize): {new_vocab_size}")

    return dem_bank, tokenizer, processor, new_vocab_size

In [None]:
CV_TRAIN_PATH = "train_cv_master.csv"
DB_TRAIN_PATH = "train_dementia_master.csv"
CV_AUDIO_FOLDER = "/content/drive/MyDrive/vin_capstone/new_data"
DB_AUDIO_FOLDER = "/content/drive/MyDrive/Capstone/Data/dementiabank_v2/audio_data"
WAV_OUTPUT_DIR = "/content/drive/MyDrive/whisper_wavs_normalized"
CACHE_DIR = "/content/drive/MyDrive/vin-capstone/data/preprocessed_whisper_large_v3_features_v3_final"

# Run the pipeline
dem_bank, tokenizer, processor, new_vocab_size = run_full_pipeline(
    cv_train_path=CV_TRAIN_PATH,
    db_train_path=DB_TRAIN_PATH,
    cv_audio_folder=CV_AUDIO_FOLDER,
    db_audio_folder=DB_AUDIO_FOLDER,
    wav_output_dir=WAV_OUTPUT_DIR,
    cache_dir=CACHE_DIR,
)

print(f"\n New model base_model.resize_token_embeddings({new_vocab_size})")

  Original vocab size: 51,866
  Added 6 filler tokens: ['[UH]', '[UM]', '[ER]', '[AH]', '[HM]', '[UNINTELLIGIBLE]']
  New vocab size: 51,872

  IMPORTANT: Resize model embeddings to 51872 when loading!
  Loaded 48,984 CommonVoice samples
  Loaded 2,895 DementiaBank samples
File Existence Check: /content/drive/MyDrive/vin_capstone/new_data
Total files in DataFrame: 48,984
Files that exist:         31,384 (64.1%)
Files that are missing:   17,600 (35.9%)

First 10 missing files:
  common_voice_en_22525318.mp3
  common_voice_en_28441998.mp3
  common_voice_en_24066239.mp3
  common_voice_en_24066241.mp3
  common_voice_en_27257344.mp3
  common_voice_en_27257366.mp3
  common_voice_en_39607994.mp3
  common_voice_en_30627006.mp3
  common_voice_en_36736392.mp3
  common_voice_en_30544061.mp3
File Existence Check: /content/drive/MyDrive/Capstone/Data/dementiabank_v2/audio_data
Total files in DataFrame: 2,895
Files that exist:         2,826 (97.6%)
Files that are missing:   69 (2.4%)

First 10 missi

Converting:   0%|          | 0/33468 [00:00<?, ?it/s]

  Converted 33,468 files


Checking quality:   0%|          | 0/31384 [00:00<?, ?it/s]


  Results: 31,366 / 31,384 valid (99.9%)

  Rejection breakdown:
    low_snr: 15
    transcript_too_many_tokens: 3

  Valid sample statistics:
    Duration: 5.3s avg, 1.3s min, 10.5s max
    SNR: 52.1dB avg
    Tokens: 20 avg


Checking quality:   0%|          | 0/2826 [00:00<?, ?it/s]


  Results: 991 / 2,826 valid (35.1%)

  Rejection breakdown:
    too_long: 1,824
    transcript_too_many_tokens: 284
    low_snr: 7
    too_short: 4

  Valid sample statistics:
    Duration: 19.3s avg, 1.0s min, 30.0s max
    SNR: 22.3dB avg
    Tokens: 70 avg
  Original sizes:
    DementiaBank: 991
    CommonVoice:  31,366
    Ratio: 1:31.7

  After balancing:
    DementiaBank: 4,955 (5x oversample)
    CommonVoice:  31,366
    New ratio: 1:6.3

  Final dataset: 36,321 samples
  Source distribution:
    commonvoice: 31,366 (86.4%)
    dementiabank: 4,955 (13.6%)
  Loading forced alignment model on cuda...
Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt


100%|██████████| 1.18G/1.18G [00:14<00:00, 87.5MB/s]


  Forced alignment model loaded


Chunking:   0%|          | 0/36321 [00:00<?, ?it/s]


  Total chunks: 36,321
  Alignment methods:
    no_chunking_needed: 36,321 (100.0%)
  Format distribution:
    text_only: 22,883 (63.0%)
    timestamps_only: 9,080 (25.0%)
    timestamps_and_prev: 4,358 (12.0%)


Encoding:   0%|          | 0/36321 [00:00<?, ?it/s]


  Average token length: 28.1
  Max token length: 387
  Train: 25,763 files → 28,982 chunks
  Test:  6,441 files → 7,339 chunks

  Train source distribution:
    commonvoice: 25,092
    dementiabank: 3,890

  Test source distribution:
    commonvoice: 6,274
    dementiabank: 1,065

  Dataset created:
    Train: 28,982 samples
    Test: 7,339 samples

>>> Extracting features

  Train - Chunk 1/6 (samples 0-5,000)


Processing:   0%|          | 0/5000 [00:00<?, ?it/s]

    Chunk 1 complete: 5,000 samples

  Train - Chunk 2/6 (samples 5,000-10,000)


Processing:   0%|          | 0/5000 [00:00<?, ?it/s]

    Chunk 2 complete: 5,000 samples

  Train - Chunk 3/6 (samples 10,000-15,000)


Processing:   0%|          | 0/5000 [00:00<?, ?it/s]

    Chunk 3 complete: 5,000 samples

  Train - Chunk 4/6 (samples 15,000-20,000)


Processing:   0%|          | 0/5000 [00:00<?, ?it/s]

    Chunk 4 complete: 5,000 samples

  Train - Chunk 5/6 (samples 20,000-25,000)


Processing:   0%|          | 0/5000 [00:00<?, ?it/s]

    Chunk 5 complete: 5,000 samples

  Train - Chunk 6/6 (samples 25,000-28,982)


Processing:   0%|          | 0/3982 [00:00<?, ?it/s]

    Chunk 6 complete: 3,982 samples

  Concatenating 6 chunks
  Train: 28,982 samples | Errors: 0

  Test - Chunk 1/4 (samples 0-2,000)


Processing:   0%|          | 0/2000 [00:00<?, ?it/s]

    Chunk 1 complete: 2,000 samples

  Test - Chunk 2/4 (samples 2,000-4,000)


Processing:   0%|          | 0/2000 [00:00<?, ?it/s]

    Chunk 2 complete: 2,000 samples

  Test - Chunk 3/4 (samples 4,000-6,000)


Processing:   0%|          | 0/2000 [00:00<?, ?it/s]

    Chunk 3 complete: 2,000 samples

  Test - Chunk 4/4 (samples 6,000-7,339)


Processing:   0%|          | 0/1339 [00:00<?, ?it/s]

    Chunk 4 complete: 1,339 samples

  Concatenating 4 chunks
  Test: 7,339 samples | Errors: 0


Saving the dataset (0/45 shards):   0%|          | 0/28982 [00:00<?, ? examples/s]

Saving the dataset (0/12 shards):   0%|          | 0/7339 [00:00<?, ? examples/s]

  Saved to: /content/drive/MyDrive/vin-capstone/data/preprocessed_whisper_large_v3_features_v3_final
  PIPELINE COMPLETE in 324.3 minutes
  Final dataset: 28,982 train, 7,339 test
  New vocab size (for model resize): 51872

 New model base_model.resize_token_embeddings(51872)
