# Doctor-Patient ASR Baseline Model

This notebook implements a baseline ASR model for doctor-patient conversations using the Hugging Face dataset `Shamus/United-Syn-Med`.

## Pipeline Overview:
1. **Data Loading**: Stream subset from HF dataset
2. **Preprocessing**: Audio resampling + feature extraction
3. **Training**: Fine-tune Whisper Small/Tiny
4. **Evaluation**: Compute WER metrics
5. **Inference**: Test on validation samples

Target: ~2000 samples (~500MB) for baseline prototype

## Setup and Dependencies

In [8]:
# Install required packages
!pip install datasets transformers torch torchaudio accelerate evaluate jiwer tensorboard
!pip install --upgrade huggingface_hub
# Install additional audio dependencies if needed
!pip install soundfile librosa



In [9]:
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
import evaluate
from typing import Dict, List, Optional, Tuple
import warnings
warnings.filterwarnings('ignore')

# Configuration
CONFIG = {
    'dataset_name': 'Shamus/United-Syn-Med',
    'model_name': 'openai/whisper-tiny',  # Change to 'openai/whisper-tiny' for faster testing
    'num_samples': 2000,
    'target_sample_rate': 16000,
    'train_split_ratio': 0.8,
    'output_dir': './whisper-medical-asr',
    'max_audio_length': 30.0,  # seconds
    'batch_size': 8,
    'num_epochs': 3,
    'learning_rate': 1e-5,
    'warmup_steps': 500,
    'eval_steps': 500,
    'save_steps': 1000,
    'gradient_accumulation_steps': 2
}

print(f"Using model: {CONFIG['model_name']}")
print(f"Target samples: {CONFIG['num_samples']}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

Using model: openai/whisper-tiny
Target samples: 2000
Device: cpu


## Module 1: Data Loading

In [17]:
class DataLoader:
    """Handles loading and streaming of the medical conversation dataset."""

    def __init__(self, dataset_name: str, num_samples: int, train_split_ratio: float = 0.8):
        self.dataset_name = dataset_name
        self.num_samples = num_samples
        self.train_split_ratio = train_split_ratio
        self.raw_data = None

    def load_dataset_subset(self) -> List[Dict]:
        """Load a subset of the dataset using direct parquet access to avoid audio decoding issues."""
        print(f"Loading {self.num_samples} samples from {self.dataset_name}...")

        try:
            # Method 1: Direct parquet loading to avoid audio decoding
            print("Attempting direct parquet loading...")
            samples = self._load_from_parquet()
            if samples:
                self.raw_data = samples
                return samples
        except Exception as e:
            print(f"Parquet loading failed: {e}")

        try:
            # Method 2: Use datasets library but process samples carefully
            print("Attempting careful streaming with manual audio handling...")
            samples = self._load_with_manual_processing()
            if samples:
                self.raw_data = samples
                return samples
        except Exception as e:
            print(f"Manual processing failed: {e}")

        raise Exception("All loading methods failed")

    def _load_from_parquet(self) -> List[Dict]:
        """Load dataset directly from parquet files."""
        import pandas as pd
        from huggingface_hub import list_repo_files

        # Get parquet files
        files = list_repo_files(self.dataset_name, repo_type="dataset")
        parquet_files = [f for f in files if f.endswith('.parquet') and 'train' in f]

        if not parquet_files:
            raise Exception("No parquet files found")

        print(f"Found {len(parquet_files)} parquet files")

        samples = []
        samples_per_file = self.num_samples // len(parquet_files) + 1

        for file_idx, file_path in enumerate(parquet_files):
            if len(samples) >= self.num_samples:
                break

            try:
                print(f"Loading from {file_path}...")
                file_url = f"https://huggingface.co/datasets/{self.dataset_name}/resolve/main/{file_path}"
                df = pd.read_parquet(file_url, engine='pyarrow')

                # Take only the samples we need from this file
                remaining_samples = self.num_samples - len(samples)
                df_subset = df.head(min(samples_per_file, remaining_samples))

                for _, row in df_subset.iterrows():
                    if len(samples) >= self.num_samples:
                        break

                    # Convert row to dict and handle audio separately
                    sample = {}
                    for col, value in row.items():
                        if col == 'audio':
                            # Keep audio as raw data structure
                            if isinstance(value, dict):
                                sample[col] = value
                            else:
                                # If it's a different format, create a dict structure
                                sample[col] = {'bytes': value, 'path': None, 'sampling_rate': 16000}
                        elif col == 'transcription':
                            # Map transcription to text field
                            sample['text'] = str(value) if value is not None else ""
                        else:
                            sample[col] = value

                    # Ensure we have a text field
                    if 'text' not in sample:
                        sample['text'] = ""

                    samples.append(sample)

                print(f"Loaded {len(samples)} samples so far...")

            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                continue

        return samples

    def _load_with_manual_processing(self) -> List[Dict]:
        """Fallback method using datasets library with careful processing."""
        # Try loading without any schema constraints
        dataset_stream = load_dataset(
            self.dataset_name,
            split='train',
            streaming=True
        )

        samples = []
        error_count = 0
        max_errors = 10  # Allow some errors before giving up

        iterator = iter(dataset_stream)

        while len(samples) < self.num_samples and error_count < max_errors:
            try:
                sample = next(iterator)

                # Process sample carefully
                processed_sample = {}

                for key, value in sample.items():
                    if key == 'audio':
                        # Try to keep audio without triggering decoding
                        if hasattr(value, 'keys') and callable(getattr(value, 'keys')):
                            # It's dict-like, extract raw data
                            processed_sample[key] = {
                                'bytes': getattr(value, 'bytes', None),
                                'path': getattr(value, 'path', None),
                                'sampling_rate': getattr(value, 'sampling_rate', 16000)
                            }
                        else:
                            # Store as-is and hope for the best
                            processed_sample[key] = value
                    elif key == 'transcription':
                        processed_sample['text'] = str(value) if value is not None else ""
                    else:
                        processed_sample[key] = value

                # Ensure text field exists
                if 'text' not in processed_sample:
                    processed_sample['text'] = ""

                samples.append(processed_sample)

                if (len(samples) + 1) % 100 == 0:
                    print(f"Loaded {len(samples)} samples...")

            except StopIteration:
                print("Reached end of dataset")
                break
            except Exception as e:
                error_count += 1
                print(f"Error processing sample {len(samples) + error_count}: {e}")
                if error_count >= max_errors:
                    print(f"Too many errors ({error_count}), stopping...")
                    break

        return samples

    def create_train_val_split(self, samples: List[Dict]) -> Tuple[Dataset, Dataset]:
        """Split the data into train and validation sets."""
        split_idx = int(len(samples) * self.train_split_ratio)

        train_samples = samples[:split_idx]
        val_samples = samples[split_idx:]

        # Convert to HF datasets
        train_dataset = Dataset.from_list(train_samples)
        val_dataset = Dataset.from_list(val_samples)

        print(f"Train samples: {len(train_dataset)}")
        print(f"Validation samples: {len(val_dataset)}")

        return train_dataset, val_dataset

    def inspect_sample(self, idx: int = 0) -> None:
        """Inspect a sample to understand the data structure."""
        if self.raw_data is None:
            print("No data loaded. Run load_dataset_subset() first.")
            return

        sample = self.raw_data[idx]
        print(f"Sample {idx} structure:")
        for key, value in sample.items():
            if key == 'audio':
                if isinstance(value, dict):
                    print(f"  {key}: dict with keys {list(value.keys())}")
                    for sub_key, sub_value in value.items():
                        if sub_key == 'bytes' and sub_value is not None:
                            print(f"    {sub_key}: {len(sub_value)} bytes")
                        else:
                            print(f"    {sub_key}: {sub_value}")
                else:
                    print(f"  {key}: {type(value)}")
            else:
                print(f"  {key}: {str(value)[:100]}..." if len(str(value)) > 100 else f"  {key}: {value}")

# Initialize data loader
data_loader = DataLoader(
    dataset_name=CONFIG['dataset_name'],
    num_samples=CONFIG['num_samples'],
    train_split_ratio=CONFIG['train_split_ratio']
)

In [16]:
# Test cell to examine the original dataset structure without triggering audio decoding
def check_dataset_structure_safe():
    """Check the actual structure of the dataset without triggering audio decoding."""
    print("Examining dataset structure safely...")

    try:
        # Load dataset info without streaming to avoid audio decoding
        from datasets import get_dataset_config_names, get_dataset_split_names
        from datasets.utils.info_utils import get_dataset_infos

        print("Dataset configs:", get_dataset_config_names(CONFIG['dataset_name']))
        print("Dataset splits:", get_dataset_split_names(CONFIG['dataset_name']))

        # Try to get dataset info
        dataset_infos = get_dataset_infos(CONFIG['dataset_name'])
        print("Dataset info keys:", list(dataset_infos.keys()))

        # If we have default config, print its features
        if 'default' in dataset_infos:
            features = dataset_infos['default'].features
            print("Dataset features:")
            for name, feature in features.items():
                print(f"  {name}: {feature}")

        return True

    except Exception as e:
        print(f"Error getting dataset info: {e}")

        # Alternative: try loading parquet files directly
        try:
            print("Trying direct parquet approach...")
            import pandas as pd
            from huggingface_hub import list_repo_files

            # List files in the repository
            files = list_repo_files(CONFIG['dataset_name'], repo_type="dataset")
            parquet_files = [f for f in files if f.endswith('.parquet') and 'train' in f]
            print(f"Found {len(parquet_files)} training parquet files")

            if parquet_files:
                # Load just the first few rows of the first parquet file to check structure
                first_file = parquet_files[0]
                print(f"Examining first file: {first_file}")

                # Load parquet file directly
                file_url = f"https://huggingface.co/datasets/{CONFIG['dataset_name']}/resolve/main/{first_file}"
                df_sample = pd.read_parquet(file_url, engine='pyarrow').head(3)

                print("Parquet file structure:")
                print(f"Columns: {list(df_sample.columns)}")
                print(f"Shape: {df_sample.shape}")

                for col in df_sample.columns:
                    print(f"  {col}: {df_sample[col].dtype}")
                    if col != 'audio':  # Skip audio column to avoid issues
                        print(f"    Sample value: {str(df_sample[col].iloc[0])[:100]}...")

                return df_sample

        except Exception as e2:
            print(f"Parquet approach also failed: {e2}")
            return None

# Run the safe structure check
print("Checking dataset structure safely...")
sample_structure = check_dataset_structure_safe()

Checking dataset structure safely...
Examining dataset structure safely...
Error getting dataset info: cannot import name 'get_dataset_infos' from 'datasets.utils.info_utils' (/usr/local/lib/python3.11/dist-packages/datasets/utils/info_utils.py)
Trying direct parquet approach...
Found 33 training parquet files
Examining first file: data/train-00000-of-00033.parquet
Parquet file structure:
Columns: ['audio', 'transcription']
Shape: (3, 2)
  audio: object
  transcription: object
    Sample value: Durysta is a medication used to reduce eye pressure in patients with open-angle glaucoma or ocular h...


In [18]:
# Load the dataset subset
samples = data_loader.load_dataset_subset()

# Inspect a sample to understand the structure
data_loader.inspect_sample(0)

# Test audio decoding on the first sample
print("\n" + "="*50)
print("Testing audio decoding...")
try:
    first_sample = samples[0]
    audio_info = first_sample['audio']

    # Try to decode the audio manually
    import io
    if 'bytes' in audio_info and audio_info['bytes'] is not None:
        print("Audio stored as bytes - decoding...")
        audio_bytes = audio_info['bytes']
        audio_file = io.BytesIO(audio_bytes)
        waveform, sample_rate = torchaudio.load(audio_file)
        print(f"Successfully decoded audio: shape={waveform.shape}, sample_rate={sample_rate}")
    elif 'path' in audio_info:
        print(f"Audio path: {audio_info['path']}")
        waveform, sample_rate = torchaudio.load(audio_info['path'])
        print(f"Successfully loaded audio: shape={waveform.shape}, sample_rate={sample_rate}")
    else:
        print(f"Unexpected audio format: {audio_info}")

except Exception as e:
    print(f"Error decoding audio: {e}")
    print("This might indicate an issue with the audio format.")

print("="*50)

Loading 2000 samples from Shamus/United-Syn-Med...
Attempting direct parquet loading...
Found 33 parquet files
Loading from data/train-00000-of-00033.parquet...
Loaded 61 samples so far...
Loading from data/train-00001-of-00033.parquet...
Loaded 122 samples so far...
Loading from data/train-00002-of-00033.parquet...
Loaded 183 samples so far...
Loading from data/train-00003-of-00033.parquet...
Loaded 244 samples so far...
Loading from data/train-00004-of-00033.parquet...
Loaded 305 samples so far...
Loading from data/train-00005-of-00033.parquet...
Loaded 366 samples so far...
Loading from data/train-00006-of-00033.parquet...
Loaded 427 samples so far...
Loading from data/train-00007-of-00033.parquet...
Loaded 488 samples so far...
Loading from data/train-00008-of-00033.parquet...
Loaded 549 samples so far...
Loading from data/train-00009-of-00033.parquet...
Loaded 610 samples so far...
Loading from data/train-00010-of-00033.parquet...
Loaded 671 samples so far...
Loading from data/tra

## Module 2: Audio Preprocessing

In [20]:
class AudioPreprocessor:
    """Handles audio preprocessing for Whisper model."""

    def __init__(self, model_name: str, target_sample_rate: int = 16000, max_audio_length: float = 30.0):
        self.model_name = model_name
        self.target_sample_rate = target_sample_rate
        self.max_audio_length = max_audio_length

        # Initialize Whisper components
        print(f"Loading Whisper processor for {model_name}...")
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.feature_extractor = self.processor.feature_extractor
        self.tokenizer = self.processor.tokenizer

        print(f"Processor loaded. Target sample rate: {self.target_sample_rate}Hz")

    def decode_audio_bytes(self, audio_info: Dict) -> Tuple[np.ndarray, int]:
        """Decode audio from bytes using torchaudio."""
        import io

        if 'bytes' in audio_info and audio_info['bytes'] is not None:
            # Load audio from bytes
            audio_bytes = audio_info['bytes']
            audio_file = io.BytesIO(audio_bytes)

            # Use torchaudio to load the audio
            waveform, sample_rate = torchaudio.load(audio_file)

            # Convert to numpy and handle multi-channel audio
            audio_array = waveform.numpy()
            if len(audio_array.shape) > 1:
                # Convert to mono by averaging channels
                audio_array = np.mean(audio_array, axis=0)

            return audio_array, sample_rate

        elif 'path' in audio_info and audio_info['path'] is not None:
            # Load audio from file path
            waveform, sample_rate = torchaudio.load(audio_info['path'])

            # Convert to numpy and handle multi-channel audio
            audio_array = waveform.numpy()
            if len(audio_array.shape) > 1:
                # Convert to mono by averaging channels
                audio_array = np.mean(audio_array, axis=0)

            return audio_array, sample_rate

        elif 'array' in audio_info:
            # Already decoded audio
            return np.array(audio_info['array']), audio_info.get('sampling_rate', 16000)

        else:
            raise ValueError(f"Unsupported audio format: {audio_info.keys()}")

    def resample_audio(self, audio_array: np.ndarray, original_sr: int) -> np.ndarray:
        """Resample audio to target sample rate."""
        if original_sr == self.target_sample_rate:
            return audio_array

        # Convert to tensor for resampling
        audio_tensor = torch.from_numpy(audio_array).float()
        if len(audio_tensor.shape) == 1:
            audio_tensor = audio_tensor.unsqueeze(0)  # Add channel dimension

        # Resample
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sr,
            new_freq=self.target_sample_rate
        )
        resampled = resampler(audio_tensor)

        return resampled.squeeze().numpy()

    def trim_or_pad_audio(self, audio_array: np.ndarray) -> np.ndarray:
        """Trim or pad audio to max length."""
        max_samples = int(self.max_audio_length * self.target_sample_rate)

        if len(audio_array) > max_samples:
            # Trim to max length
            return audio_array[:max_samples]
        elif len(audio_array) < max_samples:
            # Pad with zeros
            padding = max_samples - len(audio_array)
            return np.pad(audio_array, (0, padding), mode='constant')
        else:
            return audio_array

    def preprocess_batch(self, batch: Dict) -> Dict:
        """Preprocess a batch of audio samples."""
        # Extract audio data
        audio_data = []
        texts = []

        for i in range(len(batch['audio'])):
            try:
                # Handle audio - decode from bytes/path
                audio_info = batch['audio'][i]
                audio_array, sampling_rate = self.decode_audio_bytes(audio_info)

                # Preprocess audio
                audio_array = self.resample_audio(audio_array, sampling_rate)
                audio_array = self.trim_or_pad_audio(audio_array)
                audio_data.append(audio_array)

                # Handle text (check multiple possible field names)
                text_field = ""
                for field in ['text', 'transcription', 'sentence', 'transcript']:
                    if field in batch and i < len(batch[field]):
                        text_field = batch[field][i]
                        break

                if text_field is None:
                    text_field = ""  # Empty fallback
                texts.append(str(text_field))

            except Exception as e:
                print(f"Error processing sample {i}: {e}")
                # Skip this sample or use empty data
                audio_data.append(np.zeros(int(self.max_audio_length * self.target_sample_rate)))
                texts.append("")

        # Process with Whisper feature extractor
        features = self.feature_extractor(
            audio_data,
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
            padding=True
        )

        # Tokenize texts - IMPORTANT: Use the correct format for Whisper
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).input_ids

        # Replace padding token id with -100 so it's ignored in loss computation
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_features": features["input_features"],
            "labels": labels
        }

# Initialize preprocessor
preprocessor = AudioPreprocessor(
    model_name=CONFIG['model_name'],
    target_sample_rate=CONFIG['target_sample_rate'],
    max_audio_length=CONFIG['max_audio_length']
)

Loading Whisper processor for openai/whisper-tiny...
Processor loaded. Target sample rate: 16000Hz


In [None]:
# Test preprocessed data structure
def test_preprocessed_data():
    """Test that the preprocessed data has the correct structure for training."""
    print("Testing preprocessed data structure...")

    # Get a sample from the training dataset
    sample = train_dataset[0]
    print(f"Sample keys: {list(sample.keys())}")

    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: tensor with shape {value.shape}")
        else:
            print(f"{key}: {type(value)}")

    # Test data collator
    data_collator = whisper_trainer.create_data_collator()

    # Try to collate a small batch
    batch_samples = [train_dataset[i] for i in range(min(2, len(train_dataset)))]
    try:
        collated_batch = data_collator(batch_samples)
        print("\nCollated batch structure:")
        for key, value in collated_batch.items():
            if isinstance(value, torch.Tensor):
                print(f"{key}: tensor with shape {value.shape}")
            else:
                print(f"{key}: {type(value)}")
        print("✅ Data collation successful!")
        return True
    except Exception as e:
        print(f"❌ Data collation failed: {e}")
        return False

# Run the test
test_success = test_preprocessed_data()

In [21]:
# Create train/val splits
train_dataset, val_dataset = data_loader.create_train_val_split(samples)

# Apply preprocessing to datasets
print("Preprocessing training data...")
train_dataset = train_dataset.map(
    preprocessor.preprocess_batch,
    batched=True,
    batch_size=8,
    remove_columns=train_dataset.column_names
)

print("Preprocessing validation data...")
val_dataset = val_dataset.map(
    preprocessor.preprocess_batch,
    batched=True,
    batch_size=8,
    remove_columns=val_dataset.column_names
)

print("Preprocessing complete!")
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")

Train samples: 1600
Validation samples: 400
Preprocessing training data...


Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Preprocessing validation data...


Map:   0%|          | 0/400 [00:00<?, ? examples/s]

Preprocessing complete!
Train dataset: 1600 samples
Val dataset: 400 samples


## Module 3: Model Training

In [22]:
class WhisperTrainer:
    """Handles Whisper model training and fine-tuning."""

    def __init__(self, model_name: str, output_dir: str):
        self.model_name = model_name
        self.output_dir = output_dir

        # Load model and processor
        print(f"Loading Whisper model: {model_name}...")
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
        self.processor = WhisperProcessor.from_pretrained(model_name)

        # Force decoder to use correct language tokens
        self.model.config.forced_decoder_ids = None
        self.model.config.suppress_tokens = []

        print(f"Model loaded. Parameters: {self.model.num_parameters():,}")

    def create_data_collator(self):
        """Create data collator for training."""
        # Use a simple data collator that doesn't interfere with our preprocessing
        from transformers import DefaultDataCollator
        return DefaultDataCollator()

    def setup_training_args(self, config: Dict) -> TrainingArguments:
        """Setup training arguments."""
        return TrainingArguments(
            output_dir=self.output_dir,
            per_device_train_batch_size=config['batch_size'],
            per_device_eval_batch_size=config['batch_size'],
            gradient_accumulation_steps=config['gradient_accumulation_steps'],
            learning_rate=config['learning_rate'],
            warmup_steps=config['warmup_steps'],
            num_train_epochs=config['num_epochs'],
            eval_strategy="steps",
            eval_steps=config['eval_steps'],
            save_steps=config['save_steps'],
            logging_steps=50,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            fp16=True if torch.cuda.is_available() else False,
            dataloader_pin_memory=False,
            report_to=["tensorboard"],
            push_to_hub=False,
            remove_unused_columns=False,
            dataloader_num_workers=0,  # Avoid multiprocessing issues
        )

    def create_trainer(self, train_dataset: Dataset, val_dataset: Dataset, config: Dict) -> Trainer:
        """Create and configure the trainer."""
        training_args = self.setup_training_args(config)
        data_collator = self.create_data_collator()

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            tokenizer=self.processor.tokenizer
        )

        return trainer

    def train(self, train_dataset: Dataset, val_dataset: Dataset, config: Dict):
        """Train the model."""
        print("Setting up trainer...")
        trainer = self.create_trainer(train_dataset, val_dataset, config)

        print("Starting training...")
        trainer.train()

        print("Saving final model...")
        trainer.save_model()
        self.processor.save_pretrained(self.output_dir)

        return trainer

# Initialize trainer
whisper_trainer = WhisperTrainer(
    model_name=CONFIG['model_name'],
    output_dir=CONFIG['output_dir']
)

Loading Whisper model: openai/whisper-tiny...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

Model loaded. Parameters: 37,760,640


In [None]:
# Start training
print("Starting model training...")
trainer = whisper_trainer.train(train_dataset, val_dataset, CONFIG)
print("Training completed!")

Starting model training...
Setting up trainer...
Starting training...


Exception in thread Thread-6 (_loader_worker):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch_xla/distributed/parallel_loader.py", line 165, in _loader_worker
    _, data = next(data_iter)
              ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/accelerate/data_loader.py", line 567, in __iter__
    current_batch = next(dataloader_iter)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 708, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 764, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^

## Module 4: Evaluation

In [None]:
class ModelEvaluator:
    """Handles model evaluation and metrics computation."""

    def __init__(self, model_path: str):
        self.model_path = model_path

        # Load trained model
        print(f"Loading trained model from {model_path}...")
        self.model = WhisperForConditionalGeneration.from_pretrained(model_path)
        self.processor = WhisperProcessor.from_pretrained(model_path)

        # Load WER metric
        self.wer_metric = evaluate.load("wer")

        print("Model and metrics loaded for evaluation.")

    def transcribe_audio(self, audio_features: torch.Tensor) -> str:
        """Transcribe audio features to text."""
        with torch.no_grad():
            # Generate transcription
            predicted_ids = self.model.generate(
                audio_features,
                max_length=225,
                num_beams=1,
                do_sample=False
            )

            # Decode to text
            transcription = self.processor.tokenizer.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]

            return transcription.strip()

    def evaluate_dataset(self, dataset: Dataset, max_samples: Optional[int] = None) -> Dict:
        """Evaluate model on a dataset and compute metrics."""
        print("Running evaluation...")

        predictions = []
        references = []

        eval_samples = min(len(dataset), max_samples) if max_samples else len(dataset)

        self.model.eval()

        for i in range(eval_samples):
            sample = dataset[i]

            # Get input features
            input_features = sample['input_features'].unsqueeze(0)

            # Get reference text (decode labels)
            reference = self.processor.tokenizer.decode(
                sample['labels'],
                skip_special_tokens=True
            ).strip()

            # Generate prediction
            prediction = self.transcribe_audio(input_features)

            predictions.append(prediction)
            references.append(reference)

            if (i + 1) % 50 == 0:
                print(f"Evaluated {i + 1}/{eval_samples} samples...")

        # Compute WER
        wer_score = self.wer_metric.compute(predictions=predictions, references=references)

        results = {
            'wer': wer_score,
            'num_samples': len(predictions),
            'predictions': predictions[:5],  # First 5 for inspection
            'references': references[:5]
        }

        return results

    def print_evaluation_results(self, results: Dict):
        """Print evaluation results in a readable format."""
        print("\n" + "="*50)
        print("EVALUATION RESULTS")
        print("="*50)
        print(f"Word Error Rate (WER): {results['wer']:.4f}")
        print(f"Samples evaluated: {results['num_samples']}")

        print("\nSample Predictions vs References:")
        print("-"*50)

        for i, (pred, ref) in enumerate(zip(results['predictions'], results['references'])):
            print(f"Sample {i+1}:")
            print(f"  Reference: {ref}")
            print(f"  Prediction: {pred}")
            print()

# Initialize evaluator (will load the trained model)
evaluator = ModelEvaluator(CONFIG['output_dir'])

In [None]:
# Evaluate on validation set
print("Evaluating model on validation set...")
eval_results = evaluator.evaluate_dataset(val_dataset, max_samples=100)  # Limit for speed

# Print results
evaluator.print_evaluation_results(eval_results)

## Module 5: Inference and Testing

In [None]:
class InferenceEngine:
    """Handles inference on new audio samples."""

    def __init__(self, model_path: str):
        self.model_path = model_path

        # Load model and processor
        print(f"Loading model for inference from {model_path}...")
        self.model = WhisperForConditionalGeneration.from_pretrained(model_path)
        self.processor = WhisperProcessor.from_pretrained(model_path)

        # Set to eval mode
        self.model.eval()

        print("Inference engine ready.")

    def transcribe_from_features(self, input_features: torch.Tensor) -> str:
        """Transcribe audio from preprocessed features."""
        with torch.no_grad():
            # Generate transcription
            predicted_ids = self.model.generate(
                input_features,
                max_length=225,
                num_beams=2,  # Slightly better quality
                do_sample=False,
                temperature=1.0
            )

            # Decode to text
            transcription = self.processor.tokenizer.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]

            return transcription.strip()

    def transcribe_raw_audio(self, audio_array: np.ndarray, sampling_rate: int) -> str:
        """Transcribe from raw audio array."""
        # Preprocess audio
        if sampling_rate != 16000:
            # Resample to 16kHz
            audio_tensor = torch.from_numpy(audio_array).float()
            if len(audio_tensor.shape) == 1:
                audio_tensor = audio_tensor.unsqueeze(0)

            resampler = torchaudio.transforms.Resample(
                orig_freq=sampling_rate,
                new_freq=16000
            )
            audio_array = resampler(audio_tensor).squeeze().numpy()

        # Extract features
        features = self.processor.feature_extractor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        )

        # Transcribe
        return self.transcribe_from_features(features["input_features"])

    def demo_inference(self, dataset: Dataset, num_samples: int = 3):
        """Run demo inference on dataset samples."""
        print(f"\nRunning demo inference on {num_samples} samples...")
        print("="*60)

        for i in range(min(num_samples, len(dataset))):
            sample = dataset[i]

            # Get reference
            reference = self.processor.tokenizer.decode(
                sample['labels'],
                skip_special_tokens=True
            ).strip()

            # Get prediction
            input_features = sample['input_features'].unsqueeze(0)
            prediction = self.transcribe_from_features(input_features)

            # Display results
            print(f"\nSample {i+1}:")
            print(f"Reference:  {reference}")
            print(f"Prediction: {prediction}")
            print("-" * 60)

# Initialize inference engine
inference_engine = InferenceEngine(CONFIG['output_dir'])

In [None]:
# Run demo inference
inference_engine.demo_inference(val_dataset, num_samples=5)

## Summary and Next Steps

In [None]:
# Summary of the training run
print("\n" + "="*60)
print("BASELINE ASR MODEL TRAINING COMPLETE")
print("="*60)
print(f"Model: {CONFIG['model_name']}")
print(f"Dataset: {CONFIG['dataset_name']}")
print(f"Samples used: {CONFIG['num_samples']}")
print(f"Training epochs: {CONFIG['num_epochs']}")
print(f"Model saved to: {CONFIG['output_dir']}")
print(f"Final WER: {eval_results['wer']:.4f}")

print("\nNext Steps:")
print("1. Fine-tune hyperparameters for better WER")
print("2. Increase dataset size for more robust training")
print("3. Implement EHR structuring pipeline")
print("4. Add domain-specific medical vocabulary")
print("5. Evaluate on held-out test set")

print("\nModel Files:")
import os
if os.path.exists(CONFIG['output_dir']):
    files = os.listdir(CONFIG['output_dir'])
    for file in files:
        print(f"  - {file}")
else:
    print("  Model directory not found.")

## Optional: Test with Different Model Size

In [None]:
# Uncomment and run this cell to quickly test with Whisper Tiny for faster iteration

# CONFIG_TINY = CONFIG.copy()
# CONFIG_TINY['model_name'] = 'openai/whisper-tiny'
# CONFIG_TINY['output_dir'] = './whisper-tiny-medical-asr'
# CONFIG_TINY['num_epochs'] = 1  # Faster training
# CONFIG_TINY['batch_size'] = 16  # Larger batch for smaller model

# print("Testing with Whisper Tiny for faster iteration...")
# print(f"Model: {CONFIG_TINY['model_name']}")

# # Quick training with tiny model
# tiny_trainer = WhisperTrainer(
#     model_name=CONFIG_TINY['model_name'],
#     output_dir=CONFIG_TINY['output_dir']
# )

# # Use a smaller subset for quick testing
# tiny_train = train_dataset.select(range(100))
# tiny_val = val_dataset.select(range(20))

# trainer_tiny = tiny_trainer.train(tiny_train, tiny_val, CONFIG_TINY)
# print("Tiny model training complete!")