This is a test line

# Doctor-Patient ASR Baseline Model

This notebook implements a baseline ASR model for doctor-patient conversations using the Hugging Face dataset `Shamus/United-Syn-Med`.

## Pipeline Overview:
1. **Data Loading**: Stream subset from HF dataset
2. **Preprocessing**: Audio resampling + feature extraction
3. **Training**: Fine-tune Whisper Small/Tiny
4. **Evaluation**: Compute WER metrics
5. **Inference**: Test on validation samples

Target: ~2000 samples (~500MB) for baseline prototype

## Setup and Dependencies

In [14]:
# Install required packages
!pip install datasets transformers torch torchaudio accelerate evaluate jiwer tensorboard
!pip install --upgrade huggingface_hub
# Install additional audio dependencies if needed
!pip install soundfile librosa



In [15]:
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
import evaluate
from typing import Dict, List, Optional, Tuple
import warnings
warnings.filterwarnings('ignore')

# Configuration
CONFIG = {
    'dataset_name': 'Shamus/United-Syn-Med',
    'model_name': 'openai/whisper-tiny',  # Change to 'openai/whisper-tiny' for faster testing
    'num_samples': 700,
    'target_sample_rate': 16000,
    'train_split_ratio': 0.8,
    'output_dir': './whisper-medical-asr',
    'max_audio_length': 30.0,  # seconds
    'batch_size': 8,
    'num_epochs': 3,
    'learning_rate': 1e-5,
    'warmup_steps': 500,
    'eval_steps': 500,
    'save_steps': 1000,
    'gradient_accumulation_steps': 2
}

print(f"Using model: {CONFIG['model_name']}")
print(f"Target samples: {CONFIG['num_samples']}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

Using model: openai/whisper-tiny
Target samples: 700
Device: cuda


## Module 1: Data Loading

In [16]:
class DataLoader:
    """Handles loading and streaming of the medical conversation dataset."""

    def __init__(self, dataset_name: str, num_samples: int, train_split_ratio: float = 0.8):
        self.dataset_name = dataset_name
        self.num_samples = num_samples
        self.train_split_ratio = train_split_ratio
        self.raw_data = None

    def load_dataset_subset(self) -> List[Dict]:
        """Load a subset of the dataset using direct parquet access to avoid audio decoding issues."""
        print(f"Loading {self.num_samples} samples from {self.dataset_name}...")

        try:
            # Method 1: Direct parquet loading to avoid audio decoding
            print("Attempting direct parquet loading...")
            samples = self._load_from_parquet()
            if samples:
                self.raw_data = samples
                return samples
        except Exception as e:
            print(f"Parquet loading failed: {e}")

        try:
            # Method 2: Use datasets library but process samples carefully
            print("Attempting careful streaming with manual audio handling...")
            samples = self._load_with_manual_processing()
            if samples:
                self.raw_data = samples
                return samples
        except Exception as e:
            print(f"Manual processing failed: {e}")

        raise Exception("All loading methods failed")

    def _load_from_parquet(self) -> List[Dict]:
        """Load dataset directly from parquet files."""
        import pandas as pd
        from huggingface_hub import list_repo_files

        # Get parquet files
        files = list_repo_files(self.dataset_name, repo_type="dataset")
        parquet_files = [f for f in files if f.endswith('.parquet') and 'train' in f]

        if not parquet_files:
            raise Exception("No parquet files found")

        print(f"Found {len(parquet_files)} parquet files")

        samples = []
        samples_per_file = self.num_samples // len(parquet_files) + 1

        for file_idx, file_path in enumerate(parquet_files):
            if len(samples) >= self.num_samples:
                break

            try:
                print(f"Loading from {file_path}...")
                file_url = f"https://huggingface.co/datasets/{self.dataset_name}/resolve/main/{file_path}"
                df = pd.read_parquet(file_url, engine='pyarrow')

                # Take only the samples we need from this file
                remaining_samples = self.num_samples - len(samples)
                df_subset = df.head(min(samples_per_file, remaining_samples))

                for _, row in df_subset.iterrows():
                    if len(samples) >= self.num_samples:
                        break

                    # Convert row to dict and handle audio separately
                    sample = {}
                    for col, value in row.items():
                        if col == 'audio':
                            # Keep audio as raw data structure
                            if isinstance(value, dict):
                                sample[col] = value
                            else:
                                # If it's a different format, create a dict structure
                                sample[col] = {'bytes': value, 'path': None, 'sampling_rate': 16000}
                        elif col == 'transcription':
                            # Map transcription to text field
                            sample['text'] = str(value) if value is not None else ""
                        else:
                            sample[col] = value

                    # Ensure we have a text field
                    if 'text' not in sample:
                        sample['text'] = ""

                    samples.append(sample)

                print(f"Loaded {len(samples)} samples so far...")

            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                continue

        return samples

    def _load_with_manual_processing(self) -> List[Dict]:
        """Fallback method using datasets library with careful processing."""
        # Try loading without any schema constraints
        dataset_stream = load_dataset(
            self.dataset_name,
            split='train',
            streaming=True
        )

        samples = []
        error_count = 0
        max_errors = 10  # Allow some errors before giving up

        iterator = iter(dataset_stream)

        while len(samples) < self.num_samples and error_count < max_errors:
            try:
                sample = next(iterator)

                # Process sample carefully
                processed_sample = {}

                for key, value in sample.items():
                    if key == 'audio':
                        # Try to keep audio without triggering decoding
                        if hasattr(value, 'keys') and callable(getattr(value, 'keys')):
                            # It's dict-like, extract raw data
                            processed_sample[key] = {
                                'bytes': getattr(value, 'bytes', None),
                                'path': getattr(value, 'path', None),
                                'sampling_rate': getattr(value, 'sampling_rate', 16000)
                            }
                        else:
                            # Store as-is and hope for the best
                            processed_sample[key] = value
                    elif key == 'transcription':
                        processed_sample['text'] = str(value) if value is not None else ""
                    else:
                        processed_sample[key] = value

                # Ensure text field exists
                if 'text' not in processed_sample:
                    processed_sample['text'] = ""

                samples.append(processed_sample)

                if (len(samples) + 1) % 100 == 0:
                    print(f"Loaded {len(samples)} samples...")

            except StopIteration:
                print("Reached end of dataset")
                break
            except Exception as e:
                error_count += 1
                print(f"Error processing sample {len(samples) + error_count}: {e}")
                if error_count >= max_errors:
                    print(f"Too many errors ({error_count}), stopping...")
                    break

        return samples

    def create_train_val_split(self, samples: List[Dict]) -> Tuple[Dataset, Dataset]:
        """Split the data into train and validation sets."""
        split_idx = int(len(samples) * self.train_split_ratio)

        train_samples = samples[:split_idx]
        val_samples = samples[split_idx:]

        # Convert to HF datasets
        train_dataset = Dataset.from_list(train_samples)
        val_dataset = Dataset.from_list(val_samples)

        print(f"Train samples: {len(train_dataset)}")
        print(f"Validation samples: {len(val_dataset)}")

        return train_dataset, val_dataset

    def inspect_sample(self, idx: int = 0) -> None:
        """Inspect a sample to understand the data structure."""
        if self.raw_data is None:
            print("No data loaded. Run load_dataset_subset() first.")
            return

        sample = self.raw_data[idx]
        print(f"Sample {idx} structure:")
        for key, value in sample.items():
            if key == 'audio':
                if isinstance(value, dict):
                    print(f"  {key}: dict with keys {list(value.keys())}")
                    for sub_key, sub_value in value.items():
                        if sub_key == 'bytes' and sub_value is not None:
                            print(f"    {sub_key}: {len(sub_value)} bytes")
                        else:
                            print(f"    {sub_key}: {sub_value}")
                else:
                    print(f"  {key}: {type(value)}")
            else:
                print(f"  {key}: {str(value)[:100]}..." if len(str(value)) > 100 else f"  {key}: {value}")

# Initialize data loader
data_loader = DataLoader(
    dataset_name=CONFIG['dataset_name'],
    num_samples=CONFIG['num_samples'],
    train_split_ratio=CONFIG['train_split_ratio']
)

In [17]:
# Test cell to examine the original dataset structure without triggering audio decoding
def check_dataset_structure_safe():
    """Check the actual structure of the dataset without triggering audio decoding."""
    print("Examining dataset structure safely...")

    try:
        # Load dataset info without streaming to avoid audio decoding
        from datasets import get_dataset_config_names, get_dataset_split_names
        from datasets.utils.info_utils import get_dataset_infos

        print("Dataset configs:", get_dataset_config_names(CONFIG['dataset_name']))
        print("Dataset splits:", get_dataset_split_names(CONFIG['dataset_name']))

        # Try to get dataset info
        dataset_infos = get_dataset_infos(CONFIG['dataset_name'])
        print("Dataset info keys:", list(dataset_infos.keys()))

        # If we have default config, print its features
        if 'default' in dataset_infos:
            features = dataset_infos['default'].features
            print("Dataset features:")
            for name, feature in features.items():
                print(f"  {name}: {feature}")

        return True

    except Exception as e:
        print(f"Error getting dataset info: {e}")

        # Alternative: try loading parquet files directly
        try:
            print("Trying direct parquet approach...")
            import pandas as pd
            from huggingface_hub import list_repo_files

            # List files in the repository
            files = list_repo_files(CONFIG['dataset_name'], repo_type="dataset")
            parquet_files = [f for f in files if f.endswith('.parquet') and 'train' in f]
            print(f"Found {len(parquet_files)} training parquet files")

            if parquet_files:
                # Load just the first few rows of the first parquet file to check structure
                first_file = parquet_files[0]
                print(f"Examining first file: {first_file}")

                # Load parquet file directly
                file_url = f"https://huggingface.co/datasets/{CONFIG['dataset_name']}/resolve/main/{first_file}"
                df_sample = pd.read_parquet(file_url, engine='pyarrow').head(3)

                print("Parquet file structure:")
                print(f"Columns: {list(df_sample.columns)}")
                print(f"Shape: {df_sample.shape}")

                for col in df_sample.columns:
                    print(f"  {col}: {df_sample[col].dtype}")
                    if col != 'audio':  # Skip audio column to avoid issues
                        print(f"    Sample value: {str(df_sample[col].iloc[0])[:100]}...")

                return df_sample

        except Exception as e2:
            print(f"Parquet approach also failed: {e2}")
            return None

# Run the safe structure check
print("Checking dataset structure safely...")
sample_structure = check_dataset_structure_safe()

Checking dataset structure safely...
Examining dataset structure safely...
Error getting dataset info: cannot import name 'get_dataset_infos' from 'datasets.utils.info_utils' (/usr/local/lib/python3.11/dist-packages/datasets/utils/info_utils.py)
Trying direct parquet approach...
Found 33 training parquet files
Examining first file: data/train-00000-of-00033.parquet
Parquet file structure:
Columns: ['audio', 'transcription']
Shape: (3, 2)
  audio: object
  transcription: object
    Sample value: Durysta is a medication used to reduce eye pressure in patients with open-angle glaucoma or ocular h...


In [18]:
# Load the dataset subset
samples = data_loader.load_dataset_subset()

# Inspect a sample to understand the structure
data_loader.inspect_sample(0)

# Test audio decoding on the first sample
print("\n" + "="*50)
print("Testing audio decoding...")
try:
    first_sample = samples[0]
    audio_info = first_sample['audio']

    # Try to decode the audio manually
    import io
    if 'bytes' in audio_info and audio_info['bytes'] is not None:
        print("Audio stored as bytes - decoding...")
        audio_bytes = audio_info['bytes']
        audio_file = io.BytesIO(audio_bytes)
        waveform, sample_rate = torchaudio.load(audio_file)
        print(f"Successfully decoded audio: shape={waveform.shape}, sample_rate={sample_rate}")
    elif 'path' in audio_info:
        print(f"Audio path: {audio_info['path']}")
        waveform, sample_rate = torchaudio.load(audio_info['path'])
        print(f"Successfully loaded audio: shape={waveform.shape}, sample_rate={sample_rate}")
    else:
        print(f"Unexpected audio format: {audio_info}")

except Exception as e:
    print(f"Error decoding audio: {e}")
    print("This might indicate an issue with the audio format.")

print("="*50)

Loading 700 samples from Shamus/United-Syn-Med...
Attempting direct parquet loading...
Found 33 parquet files
Loading from data/train-00000-of-00033.parquet...
Loaded 22 samples so far...
Loading from data/train-00001-of-00033.parquet...
Loaded 44 samples so far...
Loading from data/train-00002-of-00033.parquet...
Loaded 66 samples so far...
Loading from data/train-00003-of-00033.parquet...
Loaded 88 samples so far...
Loading from data/train-00004-of-00033.parquet...
Loaded 110 samples so far...
Loading from data/train-00005-of-00033.parquet...
Loaded 132 samples so far...
Loading from data/train-00006-of-00033.parquet...
Loaded 154 samples so far...
Loading from data/train-00007-of-00033.parquet...
Loaded 176 samples so far...
Loading from data/train-00008-of-00033.parquet...
Loaded 198 samples so far...
Loading from data/train-00009-of-00033.parquet...
Loaded 220 samples so far...
Loading from data/train-00010-of-00033.parquet...
Loaded 242 samples so far...
Loading from data/train-0

## Module 2: Audio Preprocessing

In [19]:
class AudioPreprocessor:
    """Handles audio preprocessing for Whisper model."""

    def __init__(self, model_name: str, target_sample_rate: int = 16000, max_audio_length: float = 30.0):
        self.model_name = model_name
        self.target_sample_rate = target_sample_rate
        self.max_audio_length = max_audio_length

        # Initialize Whisper components
        print(f"Loading Whisper processor for {model_name}...")
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.feature_extractor = self.processor.feature_extractor
        self.tokenizer = self.processor.tokenizer

        print(f"Processor loaded. Target sample rate: {self.target_sample_rate}Hz")

    def decode_audio_bytes(self, audio_info: Dict) -> Tuple[np.ndarray, int]:
        """Decode audio from bytes using torchaudio."""
        import io

        if 'bytes' in audio_info and audio_info['bytes'] is not None:
            # Load audio from bytes
            audio_bytes = audio_info['bytes']
            audio_file = io.BytesIO(audio_bytes)

            # Use torchaudio to load the audio
            waveform, sample_rate = torchaudio.load(audio_file)

            # Convert to numpy and handle multi-channel audio
            audio_array = waveform.numpy()
            if len(audio_array.shape) > 1:
                # Convert to mono by averaging channels
                audio_array = np.mean(audio_array, axis=0)

            return audio_array, sample_rate

        elif 'path' in audio_info and audio_info['path'] is not None:
            # Load audio from file path
            waveform, sample_rate = torchaudio.load(audio_info['path'])

            # Convert to numpy and handle multi-channel audio
            audio_array = waveform.numpy()
            if len(audio_array.shape) > 1:
                # Convert to mono by averaging channels
                audio_array = np.mean(audio_array, axis=0)

            return audio_array, sample_rate

        elif 'array' in audio_info:
            # Already decoded audio
            return np.array(audio_info['array']), audio_info.get('sampling_rate', 16000)

        else:
            raise ValueError(f"Unsupported audio format: {audio_info.keys()}")

    def resample_audio(self, audio_array: np.ndarray, original_sr: int) -> np.ndarray:
        """Resample audio to target sample rate."""
        if original_sr == self.target_sample_rate:
            return audio_array

        # Convert to tensor for resampling
        audio_tensor = torch.from_numpy(audio_array).float()
        if len(audio_tensor.shape) == 1:
            audio_tensor = audio_tensor.unsqueeze(0)  # Add channel dimension

        # Resample
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sr,
            new_freq=self.target_sample_rate
        )
        resampled = resampler(audio_tensor)

        return resampled.squeeze().numpy()

    def trim_or_pad_audio(self, audio_array: np.ndarray) -> np.ndarray:
        """Trim or pad audio to max length."""
        max_samples = int(self.max_audio_length * self.target_sample_rate)

        if len(audio_array) > max_samples:
            # Trim to max length
            return audio_array[:max_samples]
        elif len(audio_array) < max_samples:
            # Pad with zeros
            padding = max_samples - len(audio_array)
            return np.pad(audio_array, (0, padding), mode='constant')
        else:
            return audio_array

    def preprocess_batch(self, batch: Dict) -> Dict:
        """Preprocess a batch of audio samples."""
        # Extract audio data
        audio_data = []
        texts = []

        for i in range(len(batch['audio'])):
            try:
                # Handle audio - decode from bytes/path
                audio_info = batch['audio'][i]
                audio_array, sampling_rate = self.decode_audio_bytes(audio_info)

                # Preprocess audio
                audio_array = self.resample_audio(audio_array, sampling_rate)
                audio_array = self.trim_or_pad_audio(audio_array)
                audio_data.append(audio_array)

                # Handle text (check multiple possible field names)
                text_field = ""
                for field in ['text', 'transcription', 'sentence', 'transcript']:
                    if field in batch and i < len(batch[field]):
                        text_field = batch[field][i]
                        break

                if text_field is None:
                    text_field = ""  # Empty fallback
                texts.append(str(text_field))

            except Exception as e:
                print(f"Error processing sample {i}: {e}")
                # Skip this sample or use empty data
                audio_data.append(np.zeros(int(self.max_audio_length * self.target_sample_rate)))
                texts.append("")

        # Process with Whisper feature extractor
        features = self.feature_extractor(
            audio_data,
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
            padding=True
        )

        # Tokenize texts with proper padding and truncation for Whisper
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(
                texts,
                max_length=448,  # Whisper's max sequence length
                padding="max_length",  # Force consistent padding
                truncation=True,
                return_tensors="pt"
            ).input_ids

        # Replace padding token id with -100 so it's ignored in loss computation
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_features": features["input_features"],
            "labels": labels
        }

# Initialize preprocessor
preprocessor = AudioPreprocessor(
    model_name=CONFIG['model_name'],
    target_sample_rate=CONFIG['target_sample_rate'],
    max_audio_length=CONFIG['max_audio_length']
)

Loading Whisper processor for openai/whisper-tiny...
Processor loaded. Target sample rate: 16000Hz


In [20]:
# Create train/val splits
train_dataset, val_dataset = data_loader.create_train_val_split(samples)

# Apply preprocessing to datasets
print("Preprocessing training data...")
train_dataset = train_dataset.map(
    preprocessor.preprocess_batch,
    batched=True,
    batch_size=8,
    remove_columns=train_dataset.column_names
)

print("Preprocessing validation data...")
val_dataset = val_dataset.map(
    preprocessor.preprocess_batch,
    batched=True,
    batch_size=8,
    remove_columns=val_dataset.column_names
)

print("Preprocessing complete!")
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")

Train samples: 560
Validation samples: 140
Preprocessing training data...


Map:   0%|          | 0/560 [00:00<?, ? examples/s]

Preprocessing validation data...


Map:   0%|          | 0/140 [00:00<?, ? examples/s]

Preprocessing complete!
Train dataset: 560 samples
Val dataset: 140 samples


## Module 3: Model Training

In [21]:
class WhisperTrainer:
    """Handles Whisper model training and fine-tuning."""

    def __init__(self, model_name: str, output_dir: str):
        self.model_name = model_name
        self.output_dir = output_dir

        # Load model and processor
        print(f"Loading Whisper model: {model_name}...")
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
        self.processor = WhisperProcessor.from_pretrained(model_name)

        # Force decoder to use correct language tokens
        self.model.config.forced_decoder_ids = None
        self.model.config.suppress_tokens = []

        print(f"Model loaded. Parameters: {self.model.num_parameters():,}")

    def create_data_collator(self):
        """Create data collator for training."""
        # Use DefaultDataCollator since we're handling padding in preprocessing
        from transformers import DefaultDataCollator
        return DefaultDataCollator()

    def setup_training_args(self, config: Dict) -> TrainingArguments:
        """Setup training arguments."""
        return TrainingArguments(
            output_dir=self.output_dir,
            per_device_train_batch_size=config['batch_size'],
            per_device_eval_batch_size=config['batch_size'],
            gradient_accumulation_steps=config['gradient_accumulation_steps'],
            learning_rate=config['learning_rate'],
            warmup_steps=config['warmup_steps'],
            num_train_epochs=config['num_epochs'],
            eval_strategy="steps",
            eval_steps=config['eval_steps'],
            save_steps=config['save_steps'],
            logging_steps=50,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            fp16=True if torch.cuda.is_available() else False,
            dataloader_pin_memory=False,
            dataloader_num_workers=0,  # Avoid multiprocessing issues
            report_to=["tensorboard"],
            push_to_hub=False,
            remove_unused_columns=False
        )

    def create_trainer(self, train_dataset: Dataset, val_dataset: Dataset, config: Dict) -> Trainer:
        """Create and configure the trainer."""
        training_args = self.setup_training_args(config)
        data_collator = self.create_data_collator()

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            tokenizer=self.processor.tokenizer
        )

        return trainer

    def train(self, train_dataset: Dataset, val_dataset: Dataset, config: Dict):
        """Train the model."""
        print("Setting up trainer...")
        trainer = self.create_trainer(train_dataset, val_dataset, config)

        print("Starting training...")
        trainer.train()

        print("Saving final model...")
        trainer.save_model()
        self.processor.save_pretrained(self.output_dir)

        return trainer

# Initialize trainer
whisper_trainer = WhisperTrainer(
    model_name=CONFIG['model_name'],
    output_dir=CONFIG['output_dir']
)

Loading Whisper model: openai/whisper-tiny...
Model loaded. Parameters: 37,760,640


In [22]:
# Test preprocessed data structure
def test_preprocessed_data():
    """Test that the preprocessed data has the correct structure for training."""
    print("Testing preprocessed data structure...")

    # Get a few samples from the training dataset
    sample_indices = [0, 1, 2] if len(train_dataset) >= 3 else [0]

    for idx in sample_indices:
        sample = train_dataset[idx]
        print(f"\nSample {idx} structure:")
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: tensor with shape {value.shape}, dtype {value.dtype}")
            else:
                print(f"  {key}: {type(value)}")

    # Test data collator
    print("\nTesting data collator...")
    data_collator = whisper_trainer.create_data_collator()

    # Try to collate a small batch
    batch_samples = [train_dataset[i] for i in range(min(3, len(train_dataset)))]
    try:
        collated_batch = data_collator(batch_samples)
        print("\nCollated batch structure:")
        for key, value in collated_batch.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: tensor with shape {value.shape}, dtype {value.dtype}")
            else:
                print(f"  {key}: {type(value)}")
        print("✅ Data collation successful!")
        return True
    except Exception as e:
        print(f"❌ Data collation failed: {e}")
        return False

# Run the test
test_success = test_preprocessed_data()

if test_success:
    print("\n🎉 Data preprocessing and collation working correctly!")
    print("Ready to start training...")
else:
    print("\n⚠️ Issues found with data preprocessing. Please check the errors above.")

Testing preprocessed data structure...

Sample 0 structure:
  input_features: <class 'list'>
  labels: <class 'list'>

Sample 1 structure:
  input_features: <class 'list'>
  labels: <class 'list'>

Sample 2 structure:
  input_features: <class 'list'>
  labels: <class 'list'>

Testing data collator...

Collated batch structure:
  input_features: tensor with shape torch.Size([3, 80, 3000]), dtype torch.float32
  labels: tensor with shape torch.Size([3, 448]), dtype torch.int64
✅ Data collation successful!

🎉 Data preprocessing and collation working correctly!
Ready to start training...


In [23]:
# Start training
print("Starting model training...")
trainer = whisper_trainer.train(train_dataset, val_dataset, CONFIG)
print("Training completed!")

Starting model training...
Setting up trainer...
Starting training...


Step,Training Loss,Validation Loss


Saving final model...
Training completed!


## Module 4: Evaluation

In [26]:
class ModelEvaluator:
    """Handles model evaluation and metrics computation."""
    
    def __init__(self, model_path: str):
        self.model_path = model_path
        
        # Load trained model
        print(f"Loading trained model from {model_path}...")
        self.model = WhisperForConditionalGeneration.from_pretrained(model_path)
        self.processor = WhisperProcessor.from_pretrained(model_path)
        
        # Fix generation config issues
        self.model.generation_config.forced_decoder_ids = None
        self.model.generation_config.suppress_tokens = []
        
        # Load WER metric
        self.wer_metric = evaluate.load("wer")
        
        print("Model and metrics loaded for evaluation.")
    
    def transcribe_audio(self, audio_features: torch.Tensor) -> str:
        """Transcribe audio features to text."""
        with torch.no_grad():
            # Generate transcription with explicit parameters to avoid conflicts
            predicted_ids = self.model.generate(
                audio_features,
                max_length=225,
                num_beams=1,
                do_sample=False,
                language="en",  # Force English to avoid language detection
                task="transcribe",  # Explicit task
                forced_decoder_ids=None,  # Explicitly set to None
                suppress_tokens=[],  # Empty suppress tokens
            )
            
            # Decode to text
            transcription = self.processor.tokenizer.batch_decode(
                predicted_ids, 
                skip_special_tokens=True
            )[0]
            
            return transcription.strip()
    
    def evaluate_dataset(self, dataset: Dataset, max_samples: Optional[int] = None) -> Dict:
        """Evaluate model on a dataset and compute metrics."""
        print("Running evaluation...")
        
        predictions = []
        references = []
        
        eval_samples = min(len(dataset), max_samples) if max_samples else len(dataset)
        
        self.model.eval()
        
        for i in range(eval_samples):
            try:
                sample = dataset[i]
                
                # Handle input_features - convert to tensor if needed
                input_features = sample['input_features']
                if isinstance(input_features, list):
                    input_features = torch.tensor(input_features)
                elif not isinstance(input_features, torch.Tensor):
                    input_features = torch.tensor(input_features)
                
                # Ensure correct shape - add batch dimension if needed
                if len(input_features.shape) == 2:
                    input_features = input_features.unsqueeze(0)
                
                # Handle labels
                labels = sample['labels']
                if isinstance(labels, list):
                    labels = torch.tensor(labels)
                elif not isinstance(labels, torch.Tensor):
                    labels = torch.tensor(labels)
                
                # Get reference text (decode labels)
                reference = self.processor.tokenizer.decode(
                    labels, 
                    skip_special_tokens=True
                ).strip()
                
                # Generate prediction
                prediction = self.transcribe_audio(input_features)
                
                predictions.append(prediction)
                references.append(reference)
                
                if (i + 1) % 50 == 0:
                    print(f"Evaluated {i + 1}/{eval_samples} samples...")
                    
            except Exception as e:
                print(f"Error processing sample {i}: {e}")
                # Add empty results to maintain alignment
                predictions.append("")
                references.append("")
                continue
        
        # Filter out empty predictions/references
        valid_pairs = [(p, r) for p, r in zip(predictions, references) if p.strip() and r.strip()]
        if valid_pairs:
            valid_predictions, valid_references = zip(*valid_pairs)
        else:
            valid_predictions, valid_references = [], []
        
        # Compute WER
        if len(valid_predictions) > 0:
            wer_score = self.wer_metric.compute(predictions=list(valid_predictions), references=list(valid_references))
        else:
            wer_score = 1.0  # 100% error if no valid predictions
        
        results = {
            'wer': wer_score,
            'num_samples': len(valid_predictions),
            'predictions': list(valid_predictions[:5]),  # First 5 for inspection
            'references': list(valid_references[:5])
        }
        
        return results
    
    def print_evaluation_results(self, results: Dict):
        """Print evaluation results in a readable format."""
        print("\n" + "="*50)
        print("EVALUATION RESULTS")
        print("="*50)
        print(f"Word Error Rate (WER): {results['wer']:.4f}")
        print(f"Samples evaluated: {results['num_samples']}")
        
        print("\nSample Predictions vs References:")
        print("-"*50)
        
        for i, (pred, ref) in enumerate(zip(results['predictions'], results['references'])):
            print(f"Sample {i+1}:")
            print(f"  Reference: {ref}")
            print(f"  Prediction: {pred}")
            print()

# Initialize evaluator (will load the trained model)
evaluator = ModelEvaluator(CONFIG['output_dir'])

Loading trained model from ./whisper-medical-asr...
Model and metrics loaded for evaluation.


In [27]:
# Evaluate on validation set
print("Evaluating model on validation set...")
eval_results = evaluator.evaluate_dataset(val_dataset, max_samples=100)  # Limit for speed

# Print results
evaluator.print_evaluation_results(eval_results)

`generation_config` default values have been modified to match model-specific defaults: {'suppress_tokens': [], 'begin_suppress_tokens': [220, 50257]}. If this is not desired, please set these values explicitly.


Evaluating model on validation set...
Running evaluation...


A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> to see related `.generate()` flags.
A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> to see related `.generate()` flags.


Evaluated 50/100 samples...
Evaluated 100/100 samples...

EVALUATION RESULTS
Word Error Rate (WER): 0.1766
Samples evaluated: 100

Sample Predictions vs References:
--------------------------------------------------
Sample 1:
  Reference: Remember to follow your healthcare provider's instructions carefully when taking umeclidinium.
  Prediction: Remember to follow your healthcare providers instructions carefully when taking U-Micladinium.

Sample 2:
  Reference: DUORANDIL is a commonly prescribed medication for individuals with heart conditions.
  Prediction: Durandil is a commonly prescribed medication for individuals with heart conditions.

Sample 3:
  Reference: It is important to follow the dosage instructions when taking JILAZO to ensure its effectiveness.
  Prediction: It is important to follow the dosage instructions when taking gelazzo to ensure its effectiveness.

Sample 4:
  Reference: Have you tried Bevon softules for an easy-to-take multivitamin solution?
  Prediction: Have

## Module 5: Inference and Testing

In [28]:
class InferenceEngine:
    """Handles inference on new audio samples."""
    
    def __init__(self, model_path: str):
        self.model_path = model_path
        
        # Load model and processor
        print(f"Loading model for inference from {model_path}...")
        self.model = WhisperForConditionalGeneration.from_pretrained(model_path)
        self.processor = WhisperProcessor.from_pretrained(model_path)
        
        # Fix generation config issues
        self.model.generation_config.forced_decoder_ids = None
        self.model.generation_config.suppress_tokens = []
        
        # Set to eval mode
        self.model.eval()
        
        print("Inference engine ready.")
    
    def transcribe_from_features(self, input_features: torch.Tensor) -> str:
        """Transcribe audio from preprocessed features."""
        with torch.no_grad():
            # Generate transcription with explicit parameters
            predicted_ids = self.model.generate(
                input_features,
                max_length=225,
                num_beams=2,  # Slightly better quality
                do_sample=False,
                temperature=1.0,
                language="en",  # Force English
                task="transcribe",  # Explicit task
                forced_decoder_ids=None,  # Explicitly set to None
                suppress_tokens=[],  # Empty suppress tokens
            )
            
            # Decode to text
            transcription = self.processor.tokenizer.batch_decode(
                predicted_ids, 
                skip_special_tokens=True
            )[0]
            
            return transcription.strip()
    
    def transcribe_raw_audio(self, audio_array: np.ndarray, sampling_rate: int) -> str:
        """Transcribe from raw audio array."""
        # Preprocess audio
        if sampling_rate != 16000:
            # Resample to 16kHz
            audio_tensor = torch.from_numpy(audio_array).float()
            if len(audio_tensor.shape) == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            
            resampler = torchaudio.transforms.Resample(
                orig_freq=sampling_rate, 
                new_freq=16000
            )
            audio_array = resampler(audio_tensor).squeeze().numpy()
        
        # Extract features
        features = self.processor.feature_extractor(
            audio_array, 
            sampling_rate=16000, 
            return_tensors="pt"
        )
        
        # Transcribe
        return self.transcribe_from_features(features["input_features"])
    
    def demo_inference(self, dataset: Dataset, num_samples: int = 3):
        """Run demo inference on dataset samples."""
        print(f"\nRunning demo inference on {num_samples} samples...")
        print("="*60)
        
        for i in range(min(num_samples, len(dataset))):
            try:
                sample = dataset[i]
                
                # Handle input_features - convert to tensor if needed
                input_features = sample['input_features']
                if isinstance(input_features, list):
                    input_features = torch.tensor(input_features)
                elif not isinstance(input_features, torch.Tensor):
                    input_features = torch.tensor(input_features)
                
                # Ensure correct shape - add batch dimension if needed
                if len(input_features.shape) == 2:
                    input_features = input_features.unsqueeze(0)
                
                # Handle labels
                labels = sample['labels']
                if isinstance(labels, list):
                    labels = torch.tensor(labels)
                elif not isinstance(labels, torch.Tensor):
                    labels = torch.tensor(labels)
                
                # Get reference
                reference = self.processor.tokenizer.decode(
                    labels, 
                    skip_special_tokens=True
                ).strip()
                
                # Get prediction
                prediction = self.transcribe_from_features(input_features)
                
                # Display results
                print(f"\nSample {i+1}:")
                print(f"Reference:  {reference}")
                print(f"Prediction: {prediction}")
                print("-" * 60)
                
            except Exception as e:
                print(f"\nError processing sample {i+1}: {e}")
                print("-" * 60)

# Initialize inference engine
inference_engine = InferenceEngine(CONFIG['output_dir'])

Loading model for inference from ./whisper-medical-asr...
Inference engine ready.


In [30]:
# Run demo inference
inference_engine.demo_inference(val_dataset, num_samples=5)


Running demo inference on 5 samples...

Sample 1:
Reference:  Remember to follow your healthcare provider's instructions carefully when taking umeclidinium.
Prediction: Remember to follow your healthcare provider's instruction carefully when taking you Mickladinium.
------------------------------------------------------------

Sample 2:
Reference:  DUORANDIL is a commonly prescribed medication for individuals with heart conditions.
Prediction: Durandel is a commonly prescribed medication for individuals with heart conditions.
------------------------------------------------------------

Sample 3:
Reference:  It is important to follow the dosage instructions when taking JILAZO to ensure its effectiveness.
Prediction: It is important to follow the dosage instructions when taking jillazo to ensure its effectiveness.
------------------------------------------------------------

Sample 4:
Reference:  Have you tried Bevon softules for an easy-to-take multivitamin solution?
Prediction: "Have

## Summary and Next Steps

In [31]:
# Summary of the training run
print("\n" + "="*60)
print("BASELINE ASR MODEL TRAINING COMPLETE")
print("="*60)
print(f"Model: {CONFIG['model_name']}")
print(f"Dataset: {CONFIG['dataset_name']}")
print(f"Samples used: {CONFIG['num_samples']}")
print(f"Training epochs: {CONFIG['num_epochs']}")
print(f"Model saved to: {CONFIG['output_dir']}")
print(f"Final WER: {eval_results['wer']:.4f}")

print("\nNext Steps:")
print("1. Fine-tune hyperparameters for better WER")
print("2. Increase dataset size for more robust training")
print("3. Implement EHR structuring pipeline")
print("4. Add domain-specific medical vocabulary")
print("5. Evaluate on held-out test set")

print("\nModel Files:")
import os
if os.path.exists(CONFIG['output_dir']):
    files = os.listdir(CONFIG['output_dir'])
    for file in files:
        print(f"  - {file}")
else:
    print("  Model directory not found.")


BASELINE ASR MODEL TRAINING COMPLETE
Model: openai/whisper-tiny
Dataset: Shamus/United-Syn-Med
Samples used: 700
Training epochs: 3
Model saved to: ./whisper-medical-asr
Final WER: 0.1766

Next Steps:
1. Fine-tune hyperparameters for better WER
2. Increase dataset size for more robust training
3. Implement EHR structuring pipeline
4. Add domain-specific medical vocabulary
5. Evaluate on held-out test set

Model Files:
  - runs
  - model.safetensors
  - normalizer.json
  - vocab.json
  - generation_config.json
  - config.json
  - checkpoint-54
  - preprocessor_config.json
  - tokenizer_config.json
  - training_args.bin
  - merges.txt
  - special_tokens_map.json
  - added_tokens.json


## Model downloading 


In [32]:
# Download trained model as ZIP file
def download_trained_model():
    """Download the complete trained model as a ZIP file."""
    import zipfile
    import os
    
    # Your model directory
    model_dir = CONFIG['output_dir']  # './whisper-medical-asr'
    zip_filename = 'whisper_medical_asr_model.zip'
    
    # Create ZIP file with all model files
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(model_dir):
            for file in files:
                file_path = os.path.join(root, file)
                # Add file to zip with relative path
                zipf.write(file_path, os.path.relpath(file_path, model_dir))
    
    print(f"✅ Model packaged as {zip_filename}")
    print(f"📁 Size: {os.path.getsize(zip_filename) / (1024*1024):.1f} MB")
    
    # Download the file (works in Colab/Kaggle)
    try:
        from google.colab import files
        files.download(zip_filename)
        print("🚀 Download started!")
    except:
        print("💡 On Kaggle: Find the file in the output section")
        print("💡 On local: File saved in current directory")

# Run this after training
download_trained_model()

✅ Model packaged as whisper_medical_asr_model.zip
📁 Size: 529.4 MB


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

🚀 Download started!


##  Upload to Hugging Face Hub


In [39]:
# Upload to Hugging Face Hub
def upload_to_huggingface():
    """Upload your trained model to Hugging Face Hub."""
    from huggingface_hub import HfApi, create_repo
    
    # Your Hugging Face username and desired repo name
    username = "Abhijeet17o"  # Replace with your HF username
    repo_name = "whisper-tiny-1000-medical-asr"
    
    # Create repository
    api = HfApi()
    repo_url = create_repo(f"{username}/{repo_name}", exist_ok=True)
    
    # Upload model files
    api.upload_folder(
        folder_path=CONFIG['output_dir'],
        repo_id=f"{username}/{repo_name}",
        repo_type="model"
    )
    
    print(f"✅ Model uploaded to: https://huggingface.co/{username}/{repo_name}")
    print("🌐 Now you can load it from anywhere!")

upload_to_huggingface()

model.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

Upload 10 LFS files:   0%|          | 0/10 [00:00<?, ?it/s]

optimizer.pt:   0%|          | 0.00/298M [00:00<?, ?B/s]

scaler.pt:   0%|          | 0.00/988 [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.30k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

events.out.tfevents.1755499978.27923e26b12c.36.0:   0%|          | 0.00/6.12k [00:00<?, ?B/s]

events.out.tfevents.1755501222.27923e26b12c.36.1:   0%|          | 0.00/6.12k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.30k [00:00<?, ?B/s]

✅ Model uploaded to: https://huggingface.co/Abhijeet17o/whisper-tiny-1000-medical-asr
🌐 Now you can load it from anywhere!


In [38]:
# Authenticate with Hugging Face
from huggingface_hub import login
import os

# Method 1: Direct token input (less secure but quick)
token = "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"  # Replace with your actual token
login(token=token)

print("✅ Logged in to Hugging Face!")

✅ Logged in to Hugging Face!


# Base Whisper model to check the WER score

In [40]:
# Compare with Normal Whisper Model
class BaselineWhisperEvaluator:
    """Evaluate normal Whisper model (not fine-tuned) for comparison."""
    
    def __init__(self, model_name: str = 'openai/whisper-tiny'):
        self.model_name = model_name
        
        # Load normal Whisper model (not fine-tuned)
        print(f"Loading baseline Whisper model: {model_name}...")
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
        self.processor = WhisperProcessor.from_pretrained(model_name)
        
        # Fix generation config
        self.model.generation_config.forced_decoder_ids = None
        self.model.generation_config.suppress_tokens = []
        
        # Load WER metric
        self.wer_metric = evaluate.load("wer")
        
        print("Baseline Whisper model loaded for comparison.")
    
    def transcribe_audio(self, audio_features: torch.Tensor) -> str:
        """Transcribe audio features to text using baseline Whisper."""
        with torch.no_grad():
            # Generate transcription
            predicted_ids = self.model.generate(
                audio_features,
                max_length=225,
                num_beams=1,
                do_sample=False,
                language="en",
                task="transcribe",
                forced_decoder_ids=None,
                suppress_tokens=[],
            )
            
            # Decode to text
            transcription = self.processor.tokenizer.batch_decode(
                predicted_ids, 
                skip_special_tokens=True
            )[0]
            
            return transcription.strip()
    
    def evaluate_dataset(self, dataset: Dataset, max_samples: Optional[int] = None) -> Dict:
        """Evaluate baseline model on the dataset."""
        print("Running baseline Whisper evaluation...")
        
        predictions = []
        references = []
        
        eval_samples = min(len(dataset), max_samples) if max_samples else len(dataset)
        
        self.model.eval()
        
        for i in range(eval_samples):
            try:
                sample = dataset[i]
                
                # Handle input_features
                input_features = sample['input_features']
                if isinstance(input_features, list):
                    input_features = torch.tensor(input_features)
                elif not isinstance(input_features, torch.Tensor):
                    input_features = torch.tensor(input_features)
                
                if len(input_features.shape) == 2:
                    input_features = input_features.unsqueeze(0)
                
                # Handle labels
                labels = sample['labels']
                if isinstance(labels, list):
                    labels = torch.tensor(labels)
                elif not isinstance(labels, torch.Tensor):
                    labels = torch.tensor(labels)
                
                # Get reference text
                reference = self.processor.tokenizer.decode(
                    labels, 
                    skip_special_tokens=True
                ).strip()
                
                # Generate prediction with baseline model
                prediction = self.transcribe_audio(input_features)
                
                predictions.append(prediction)
                references.append(reference)
                
                if (i + 1) % 25 == 0:
                    print(f"Evaluated {i + 1}/{eval_samples} samples...")
                    
            except Exception as e:
                print(f"Error processing sample {i}: {e}")
                predictions.append("")
                references.append("")
                continue
        
        # Filter out empty predictions/references
        valid_pairs = [(p, r) for p, r in zip(predictions, references) if p.strip() and r.strip()]
        if valid_pairs:
            valid_predictions, valid_references = zip(*valid_pairs)
        else:
            valid_predictions, valid_references = [], []
        
        # Compute WER
        if len(valid_predictions) > 0:
            wer_score = self.wer_metric.compute(predictions=list(valid_predictions), references=list(valid_references))
        else:
            wer_score = 1.0
        
        results = {
            'wer': wer_score,
            'num_samples': len(valid_predictions),
            'predictions': list(valid_predictions[:5]),
            'references': list(valid_references[:5])
        }
        
        return results

# Initialize baseline evaluator
baseline_evaluator = BaselineWhisperEvaluator('openai/whisper-tiny')

Loading baseline Whisper model: openai/whisper-tiny...


Downloading builder script: 0.00B [00:00, ?B/s]

Baseline Whisper model loaded for comparison.


In [41]:
# Run comparison evaluation
print("🔄 Evaluating BASELINE Whisper-tiny model...")
baseline_results = baseline_evaluator.evaluate_dataset(val_dataset, max_samples=100)

print("\n" + "="*60)
print("📊 MODEL COMPARISON RESULTS")
print("="*60)

print(f"\n🤖 BASELINE Whisper-tiny (no fine-tuning):")
print(f"   Word Error Rate: {baseline_results['wer']:.4f} ({baseline_results['wer']*100:.2f}%)")
print(f"   Samples: {baseline_results['num_samples']}")

print(f"\n🎯 YOUR FINE-TUNED Model:")
print(f"   Word Error Rate: {eval_results['wer']:.4f} ({eval_results['wer']*100:.2f}%)")
print(f"   Samples: {eval_results['num_samples']}")

# Calculate improvement
improvement = baseline_results['wer'] - eval_results['wer']
improvement_percent = (improvement / baseline_results['wer']) * 100

print(f"\n📈 IMPROVEMENT:")
print(f"   Absolute WER reduction: {improvement:.4f}")
print(f"   Relative improvement: {improvement_percent:.1f}%")

if improvement > 0:
    print(f"   🎉 Your model is {improvement_percent:.1f}% better!")
else:
    print(f"   ⚠️ Baseline performed {abs(improvement_percent):.1f}% better")

print("\n" + "="*60)

🔄 Evaluating BASELINE Whisper-tiny model...
Running baseline Whisper evaluation...


You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.


Evaluated 25/100 samples...
Evaluated 50/100 samples...
Evaluated 75/100 samples...
Evaluated 100/100 samples...

📊 MODEL COMPARISON RESULTS

🤖 BASELINE Whisper-tiny (no fine-tuning):
   Word Error Rate: 0.1892 (18.92%)
   Samples: 100

🎯 YOUR FINE-TUNED Model:
   Word Error Rate: 0.1766 (17.66%)
   Samples: 100

📈 IMPROVEMENT:
   Absolute WER reduction: 0.0126
   Relative improvement: 6.7%
   🎉 Your model is 6.7% better!



# Comparing the baseline and the finetuned model side by side

In [42]:
# Side-by-side prediction comparison
def compare_predictions(baseline_results, finetuned_results, num_samples=5):
    """Compare predictions side by side."""
    print("\n" + "="*80)
    print("🔍 PREDICTION COMPARISON (Baseline vs Fine-tuned)")
    print("="*80)
    
    for i in range(min(num_samples, len(baseline_results['predictions']))):
        print(f"\nSample {i+1}:")
        print(f"📝 Reference:   {baseline_results['references'][i]}")
        print(f"🤖 Baseline:    {baseline_results['predictions'][i]}")
        print(f"🎯 Fine-tuned:  {finetuned_results['predictions'][i]}")
        print("-" * 80)

# Run the comparison
compare_predictions(baseline_results, eval_results)


🔍 PREDICTION COMPARISON (Baseline vs Fine-tuned)

Sample 1:
📝 Reference:   Remember to follow your healthcare provider's instructions carefully when taking umeclidinium.
🤖 Baseline:    Remember to follow your health care providers instructions carefully when taking you McLidinium.
🎯 Fine-tuned:  Remember to follow your healthcare providers instructions carefully when taking U-Micladinium.
--------------------------------------------------------------------------------

Sample 2:
📝 Reference:   DUORANDIL is a commonly prescribed medication for individuals with heart conditions.
🤖 Baseline:    Durandil is a commonly prescribed medication for individuals with heart conditions.
🎯 Fine-tuned:  Durandil is a commonly prescribed medication for individuals with heart conditions.
--------------------------------------------------------------------------------

Sample 3:
📝 Reference:   It is important to follow the dosage instructions when taking JILAZO to ensure its effectiveness.
🤖 Baseline: 

# Techniques for Major WER Improvements

In [45]:
# Advanced Audio Data Augmentation
import torchaudio.transforms as T
import random

class MedicalAudioAugmentor:
    """Advanced audio augmentation for medical conversations."""
    
    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
        
        # Define augmentation transforms
        self.transforms = {
            'noise': self.add_background_noise,
            'speed': self.change_speed,
            'pitch': self.change_pitch,
            'volume': self.change_volume,
            'reverb': self.add_reverb,
            'bandpass': self.bandpass_filter
        }
    
    def add_background_noise(self, audio, noise_factor=0.005):
        """Add subtle background noise (hospital environment)."""
        noise = torch.randn_like(audio) * noise_factor
        return audio + noise
    
    def change_speed(self, audio, speed_factor=None):
        """Change speaking speed (0.9-1.1x)."""
        if speed_factor is None:
            speed_factor = random.uniform(0.9, 1.1)
        
        # Use time stretching
        return torch.nn.functional.interpolate(
            audio.unsqueeze(0).unsqueeze(0),
            scale_factor=speed_factor,
            mode='linear',
            align_corners=False
        ).squeeze()
    
    def change_pitch(self, audio, n_steps=None):
        """Shift pitch slightly (±2 semitones)."""
        if n_steps is None:
            n_steps = random.uniform(-2, 2)
        
        # Pitch shift using resampling approximation
        shift_factor = 2 ** (n_steps / 12)
        return self.change_speed(audio, shift_factor)
    
    def change_volume(self, audio, volume_factor=None):
        """Adjust volume (0.7-1.3x)."""
        if volume_factor is None:
            volume_factor = random.uniform(0.7, 1.3)
        return audio * volume_factor
    
    def add_reverb(self, audio, room_size=0.1):
        """Add subtle room reverb."""
        # Simple reverb using delay and decay
        delay_samples = int(0.05 * self.sample_rate)  # 50ms delay
        decay = 0.3
        
        if len(audio) > delay_samples:
            reverb = torch.zeros_like(audio)
            reverb[delay_samples:] = audio[:-delay_samples] * decay
            return audio + reverb
        return audio
    
    def bandpass_filter(self, audio, low_freq=300, high_freq=8000):
        """Apply bandpass filter (telephone quality)."""
        # Simulate different recording devices
        nyquist = self.sample_rate // 2
        low = low_freq / nyquist
        high = high_freq / nyquist
        
        # Simple frequency domain filtering
        fft = torch.fft.rfft(audio)
        freqs = torch.fft.rfftfreq(len(audio), 1/self.sample_rate)
        
        # Create bandpass mask
        mask = (freqs >= low_freq) & (freqs <= high_freq)
        fft = fft * mask.float()
        
        return torch.fft.irfft(fft, n=len(audio))
    
    def augment_batch(self, audio_batch, augment_prob=0.8, num_augs=2):
        """Apply random augmentations to a batch."""
        augmented_batch = []
        
        for audio in audio_batch:
            if random.random() < augment_prob:
                # Apply random augmentations
                aug_names = random.sample(list(self.transforms.keys()), num_augs)
                
                augmented_audio = audio.clone()
                for aug_name in aug_names:
                    augmented_audio = self.transforms[aug_name](augmented_audio)
                
                augmented_batch.append(augmented_audio)
            else:
                augmented_batch.append(audio)
        
        return augmented_batch

# Initialize augmentor
augmentor = MedicalAudioAugmentor()

In [46]:
# Medical-Specific Preprocessing
class MedicalTextProcessor:
    """Enhanced text processing for medical terminology."""
    
    def __init__(self):
        # Common medical abbreviations and their expansions
        self.medical_abbreviations = {
            'mg': 'milligrams',
            'ml': 'milliliters',
            'bp': 'blood pressure',
            'hr': 'heart rate',
            'temp': 'temperature',
            'wbc': 'white blood cell',
            'rbc': 'red blood cell',
            'ecg': 'electrocardiogram',
            'mri': 'magnetic resonance imaging',
            'ct': 'computed tomography',
            'iv': 'intravenous',
            'po': 'by mouth',
            'bid': 'twice daily',
            'tid': 'three times daily',
            'qid': 'four times daily'
        }
        
        # Common medication name patterns
        self.medication_patterns = [
            r'(\w+)mycin',  # antibiotics
            r'(\w+)cillin', # penicillins
            r'(\w+)pril',   # ACE inhibitors
            r'(\w+)sartan', # ARBs
            r'(\w+)olol',   # beta blockers
            r'(\w+)statin', # statins
        ]
    
    def normalize_medical_text(self, text):
        """Normalize medical text for better training."""
        text = text.lower().strip()
        
        # Expand abbreviations
        for abbrev, expansion in self.medical_abbreviations.items():
            text = text.replace(f' {abbrev} ', f' {expansion} ')
            text = text.replace(f' {abbrev}.', f' {expansion}')
        
        # Normalize medication names
        import re
        for pattern in self.medication_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            for match in matches:
                # Ensure consistent casing for medication names
                original = f"{match}{pattern.split('(')[1].split(')')[1]}"
                normalized = original.lower()
                text = text.replace(original, normalized)
        
        return text
    
    def add_medical_context(self, text):
        """Add context markers for medical conversations."""
        # Add role markers
        if any(word in text.lower() for word in ['doctor', 'physician', 'dr']):
            text = f"[DOCTOR] {text}"
        elif any(word in text.lower() for word in ['patient', 'feel', 'pain', 'hurt']):
            text = f"[PATIENT] {text}"
        
        return text

medical_processor = MedicalTextProcessor()

In [47]:
# Enhanced Audio Preprocessor with Augmentation
class EnhancedAudioPreprocessor(AudioPreprocessor):
    """Enhanced preprocessing with augmentation and medical-specific handling."""
    
    def __init__(self, model_name: str, target_sample_rate: int = 16000, max_audio_length: float = 30.0):
        super().__init__(model_name, target_sample_rate, max_audio_length)
        self.augmentor = MedicalAudioAugmentor(target_sample_rate)
        self.text_processor = MedicalTextProcessor()
        
    def enhanced_audio_processing(self, audio_array: np.ndarray, augment: bool = True) -> np.ndarray:
        """Enhanced audio processing with noise reduction and augmentation."""
        # Convert to tensor
        audio_tensor = torch.from_numpy(audio_array).float()
        
        # Apply spectral noise reduction (simple highpass filter)
        audio_tensor = self.apply_noise_reduction(audio_tensor)
        
        # Apply augmentation during training
        if augment:
            # 50% chance to augment
            if random.random() < 0.5:
                audio_tensor = self.augmentor.add_background_noise(audio_tensor)
            if random.random() < 0.3:
                audio_tensor = self.augmentor.change_volume(audio_tensor)
            if random.random() < 0.2:
                audio_tensor = self.augmentor.change_speed(audio_tensor)
        
        return audio_tensor.numpy()
    
    def apply_noise_reduction(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Apply simple noise reduction."""
        # High-pass filter to remove low-frequency noise
        highpass = T.HighpassBiquad(self.target_sample_rate, cutoff_freq=80)
        audio_tensor = highpass(audio_tensor)
        
        # Normalize amplitude
        audio_tensor = audio_tensor / (torch.max(torch.abs(audio_tensor)) + 1e-8)
        
        return audio_tensor
    
    def preprocess_batch(self, batch: Dict, augment: bool = True) -> Dict:
        """Enhanced batch preprocessing with medical text processing."""
        audio_data = []
        texts = []
        
        for i in range(len(batch['audio'])):
            try:
                # Handle audio
                audio_info = batch['audio'][i]
                audio_array, sampling_rate = self.decode_audio_bytes(audio_info)
                
                # Enhanced audio processing
                audio_array = self.resample_audio(audio_array, sampling_rate)
                audio_array = self.enhanced_audio_processing(audio_array, augment=augment)
                audio_array = self.trim_or_pad_audio(audio_array)
                audio_data.append(audio_array)
                
                # Enhanced text processing
                text_field = ""
                for field in ['text', 'transcription', 'sentence', 'transcript']:
                    if field in batch and i < len(batch[field]):
                        text_field = batch[field][i]
                        break
                
                if text_field:
                    # Apply medical text processing
                    text_field = self.text_processor.normalize_medical_text(str(text_field))
                    text_field = self.text_processor.add_medical_context(text_field)
                
                texts.append(text_field if text_field else "")
                
            except Exception as e:
                print(f"Error processing sample {i}: {e}")
                audio_data.append(np.zeros(int(self.max_audio_length * self.target_sample_rate)))
                texts.append("")
        
        # Process with Whisper feature extractor
        features = self.feature_extractor(
            audio_data, 
            sampling_rate=self.target_sample_rate, 
            return_tensors="pt",
            padding=True
        )
        
        # Enhanced tokenization with medical vocabulary
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(
                texts,
                max_length=448,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).input_ids
        
        # Replace padding tokens
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        return {
            "input_features": features["input_features"],
            "labels": labels
        }

In [48]:
# Enhanced Training Configuration
ENHANCED_CONFIG = {
    'dataset_name': 'Shamus/United-Syn-Med',
    'model_name': 'openai/whisper-tiny',  # Upgrade to small for better performance
    'num_samples': 1000,  # Increase dataset size
    'target_sample_rate': 16000,
    'train_split_ratio': 0.85,  # More training data
    'output_dir': './whisper-enhanced-medical-asr',
    'max_audio_length': 25.0,  # Slightly shorter for efficiency
    
    # Enhanced training parameters
    'batch_size': 12,  # Larger batch size
    'num_epochs': 5,   # More epochs
    'learning_rate': 3e-5,  # Lower learning rate for stability
    'warmup_steps': 200,
    'eval_steps': 300,
    'save_steps': 600,
    'gradient_accumulation_steps': 2,
    
    # Advanced training techniques
    'weight_decay': 0.01,
    'lr_scheduler_type': 'cosine',
    'gradient_checkpointing': True,
    'fp16': True,
    'dataloader_num_workers': 2,
    
    # Regularization
    'dropout': 0.1,
    'attention_dropout': 0.1,
}

print("🚀 Enhanced training configuration loaded!")
print("Expected improvements:")
print("- Data Augmentation: 15-25% WER reduction")
print("- Medical Text Processing: 10-20% WER reduction") 
print("- Larger Model (Small): 8-12% WER reduction")
print("- Advanced Training: 5-10% WER reduction")
print("- Total Expected WER: 8-12% (vs current 17.66%)")

🚀 Enhanced training configuration loaded!
Expected improvements:
- Data Augmentation: 15-25% WER reduction
- Medical Text Processing: 10-20% WER reduction
- Larger Model (Small): 8-12% WER reduction
- Advanced Training: 5-10% WER reduction
- Total Expected WER: 8-12% (vs current 17.66%)


In [50]:
# Complete Enhanced Training Pipeline
def train_enhanced_model():
    """Train the enhanced model with all improvements."""
    
    # Initialize enhanced preprocessor
    enhanced_preprocessor = EnhancedAudioPreprocessor(
        model_name=ENHANCED_CONFIG['model_name'],
        target_sample_rate=ENHANCED_CONFIG['target_sample_rate'],
        max_audio_length=ENHANCED_CONFIG['max_audio_length']
    )
    
    # Load larger dataset
    enhanced_data_loader = DataLoader(
        dataset_name=ENHANCED_CONFIG['dataset_name'],
        num_samples=ENHANCED_CONFIG['num_samples'],
        train_split_ratio=ENHANCED_CONFIG['train_split_ratio']
    )
    
    # Load and preprocess data
    print("Loading enhanced dataset...")
    samples = enhanced_data_loader.load_dataset_subset()
    train_dataset, val_dataset = enhanced_data_loader.create_train_val_split(samples)
    
    # Apply enhanced preprocessing with augmentation
    print("Applying enhanced preprocessing with augmentation...")
    train_dataset = train_dataset.map(
        lambda batch: enhanced_preprocessor.preprocess_batch(batch, augment=True),
        batched=True,
        batch_size=8,
        remove_columns=train_dataset.column_names
    )
    
    val_dataset = val_dataset.map(
        lambda batch: enhanced_preprocessor.preprocess_batch(batch, augment=False),
        batched=True,
        batch_size=8,
        remove_columns=val_dataset.column_names
    )
    
    # Initialize enhanced trainer
    enhanced_trainer = WhisperTrainer(
        model_name=ENHANCED_CONFIG['model_name'],
        output_dir=ENHANCED_CONFIG['output_dir']
    )
    
    # Train enhanced model
    print("Starting enhanced training...")
    trainer = enhanced_trainer.train(train_dataset, val_dataset, ENHANCED_CONFIG)
    
    return trainer

# Run enhanced training
enhanced_trainer = train_enhanced_model()

Loading Whisper processor for openai/whisper-tiny...
Processor loaded. Target sample rate: 16000Hz
Loading enhanced dataset...
Loading 1000 samples from Shamus/United-Syn-Med...
Attempting direct parquet loading...
Found 33 parquet files
Loading from data/train-00000-of-00033.parquet...
Loaded 31 samples so far...
Loading from data/train-00001-of-00033.parquet...
Loaded 62 samples so far...
Loading from data/train-00002-of-00033.parquet...
Loaded 93 samples so far...
Loading from data/train-00003-of-00033.parquet...
Loaded 124 samples so far...
Loading from data/train-00004-of-00033.parquet...
Loaded 155 samples so far...
Loading from data/train-00005-of-00033.parquet...
Loaded 186 samples so far...
Loading from data/train-00006-of-00033.parquet...
Loaded 217 samples so far...
Loading from data/train-00007-of-00033.parquet...
Loaded 248 samples so far...
Loading from data/train-00008-of-00033.parquet...
Loaded 279 samples so far...
Loading from data/train-00009-of-00033.parquet...
Load

Map:   0%|          | 0/850 [00:00<?, ? examples/s]

Error processing sample 0: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 1: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 2: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 3: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 4: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 5: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 6: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 7: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 0: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 1: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 2: module 'torchaudio.transforms' has no attribute 'High

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Error processing sample 0: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 1: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 2: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 3: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 4: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 5: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 6: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 7: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 0: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 1: module 'torchaudio.transforms' has no attribute 'HighpassBiquad'
Error processing sample 2: module 'torchaudio.transforms' has no attribute 'High

ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1694, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1513, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/whisper/modeling_whisper.py", line 882, in forward
    raise ValueError(
ValueError: Whisper expects the mel input features to be of length 3000, but found 2500. Make sure to pad the input mel features to 3000.
