# Speech-to-Text with Whisper Transfer Learning

**Objective:** Fine-tune a Whisper base model on the United-Syn-Med dataset to improve medical speech transcription accuracy in a live teleconsultation context.

In [1]:
# Installing required packages

!pip install git+https://github.com/openai/whisper.git
!pip install jiwer datasets torchaudio transformers accelerate soundfile

Collecting git+https://github.com/openai/whisper.git
  Cloning https://github.com/openai/whisper.git to /tmp/pip-req-build-rju39vv6
  Running command git clone --filter=blob:none --quiet https://github.com/openai/whisper.git /tmp/pip-req-build-rju39vv6
  Resolved https://github.com/openai/whisper.git to commit dd985ac4b90cafeef8712f2998d62c59c3e62d22
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->openai-whisper==20240930)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->openai-whisper==20240930)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->openai-whisper==20240930)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-many

In [2]:
# import dependent libraries

import os
import torch
import whisper
import pandas as pd
import soundfile as sf
from datasets import Dataset, DatasetDict
from jiwer import wer, cer
from transformers import WhisperProcessor, WhisperForConditionalGeneration, TrainingArguments, Trainer
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torchaudio
from glob import glob
from tqdm import tqdm  # for progress bar

2025-06-17 09:44:08.732191: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750153448.892769      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750153448.936656      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Loading the data
n = 0
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        if n < 3:
            print(os.path.join(dirname, filename))
            n += 1
        else: break
    if n >= 3: break

/kaggle/input/unitedsnymedsmall/unitedsynmed_small/transcript/validation.csv
/kaggle/input/unitedsnymedsmall/unitedsynmed_small/transcript/train.csv
/kaggle/input/unitedsnymedsmall/unitedsynmed_small/transcript/test.csv


In [4]:
# Paths to the dataset
audio_root = "/kaggle/input/unitedsnymedsmall/unitedsynmed_small/audio"
transcript_root = "/kaggle/input/unitedsnymedsmall/unitedsynmed_small/transcript/"

# Load CSVs and match them with audio paths
def load_split(split):
    csv_path = os.path.join(transcript_root, f"{split}.csv")
    df = pd.read_csv(csv_path)
    df["path"] = df["file_name"].apply(lambda x: os.path.join(audio_root, split, x))
    return df

# Create datasets
train_df = load_split("train")
test_df = load_split("test")
val_df = load_split("validation")

# Convert to Hugging Face Dataset
dataset = DatasetDict({
    "train": Dataset.from_pandas(train_df),
    "test": Dataset.from_pandas(test_df),
    "validation": Dataset.from_pandas(val_df)
})


In [5]:
dataset["train"][:5]

{'file_name': ['drug-male-0b01f9d4-980d-451f-a8f1-18e899158859.wav',
  'drug-male-d58aac86-05d3-40ea-a61d-e1cbb7f3e790.wav',
  'drug-female-06c23421-e597-4cf4-a912-1d44c187a4f3.wav',
  'drug-male-9300288f-77c3-4c42-a0f6-166877f7f965.wav',
  'drug-female-86945722-12e1-4983-bf51-6aa27b196dc9.wav'],
 'transcription': ['Iron calx is a commonly used medicine to treat iron deficiency anemia.',
  'If you experience nausea or vomiting, DOMPAR may help alleviate your symptoms.',
  'AGROBEN-I is a reliable medicine for treating infections in plants.',
  "Make sure to follow your healthcare provider's instructions carefully while taking FEVIBID for optimal results.",
  'Clinical trials have shown favorable results with maralixibat chloride in pediatric patients.'],
 'path': ['/kaggle/input/unitedsnymedsmall/unitedsynmed_small/audio/train/drug-male-0b01f9d4-980d-451f-a8f1-18e899158859.wav',
  '/kaggle/input/unitedsnymedsmall/unitedsynmed_small/audio/train/drug-male-d58aac86-05d3-40ea-a61d-e1cbb7f3

In [None]:
# # Load Whisper processor
# processor = WhisperProcessor.from_pretrained("openai/whisper-base")

# # Set target sample rate
# target_sample_rate = 16000

# def preprocess(batch):
#     audio_input, sr = sf.read(batch["path"])
    
#     # If the sample rate is not 16kHz, resample it
#     # if sr != target_sample_rate:
#     waveform = torch.tensor(audio_input, dtype=torch.float32).float().to(device) 
#     if len(waveform.shape) > 1 and waveform.shape[0] > 1:
#         waveform = waveform.mean(dim=0)  # Convert to mono
#     resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
#     audio_input = resampler(waveform).cpu().numpy()
    
#     inputs = processor(audio_input, sampling_rate=target_sample_rate, return_tensors="pt").to(device)
#     batch["input_features"] = inputs.input_features[0].to
#     batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
#     return batch

# # Apply preprocessing
# dataset = dataset.map(preprocess)

In [6]:
# Check if GPU is available and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
print(f"Using device: {device}")

Using device: cuda


In [8]:
import torchaudio
from transformers import WhisperProcessor
import torch

# Initialize processor
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
device = "cuda" if torch.cuda.is_available() else "cpu"

def preprocess(batch):
    
    # 1. Load audio file
    waveform, sr = torchaudio.load(batch["path"])
    waveform = waveform.to(device)  # Move to GPU here
    
    # 2. Verify sample rate (optional if you're certain)
    if sr != 16000:
        raise ValueError(f"Invalid sample rate {sr}Hz (expected 16000Hz)")
    
    # 3. Convert to mono if needed
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    
    # 4. Process audio - key fix is making sure we use the tensor, not the method
    audio_array = waveform.squeeze().cpu().numpy()  # Explicitly move to CPU first
    
    # 5. Generate features
    inputs = processor(
        audio_array,
        sampling_rate=16000,
        return_tensors="pt"
    )
    
    # 6. Prepare output - ensure we're using the actual tensors
    batch["input_features"] = inputs.input_features[0].numpy()  # Convert to numpy array
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    batch["input_features"] = inputs.input_features[0].cpu().numpy()
    
    return batch
        

# Apply preprocessing
dataset = dataset.map(preprocess)

preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

Map:   0%|          | 0/9500 [00:00<?, ? examples/s]

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

In [10]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    apply_spec_augment: bool = True
    apply_noise: bool = True

    def __call__(self, features):
        # 1. Pad input features and labels
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        # 2. Apply SpecAugment (correct implementation)
        if self.apply_spec_augment:
            # Time masking
            time_masked = batch["input_features"].clone()
            time_mask_length = torch.randint(low=0, high=30, size=(1,)).item()
            time_mask_start = torch.randint(low=0, high=max(1, batch["input_features"].shape[1] - time_mask_length), size=(1,)).item()
            time_masked[:, time_mask_start:time_mask_start+time_mask_length, :] = 0
            
            # Frequency masking
            freq_masked = time_masked.clone()
            freq_mask_length = torch.randint(low=0, high=10, size=(1,)).item()
            freq_mask_start = torch.randint(low=0, high=max(1, batch["input_features"].shape[2] - freq_mask_length), size=(1,)).item()
            freq_masked[:, :, freq_mask_start:freq_mask_start+freq_mask_length] = 0
            
            batch["input_features"] = freq_masked
        
        # 3. Add Gaussian noise
        if self.apply_noise and torch.rand(1) < 0.05:
            batch["input_features"] += torch.randn_like(batch["input_features"]) * 0.01
        
        # 4. Normalize
        batch["input_features"] = (batch["input_features"] - batch["input_features"].mean()) / (batch["input_features"].std() + 1e-7)
        
        # 5. Process labels
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch["input_ids"] == self.processor.tokenizer.pad_token_id, -100)
        batch["labels"] = labels

        return batch

### Train Model

In [11]:

from transformers import WhisperForConditionalGeneration, Seq2SeqTrainingArguments, TrainerCallback, Trainer
from transformers.integrations import TensorBoardCallback
import torch
import os

# 1. Load base Whisper model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")


config.json:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/290M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.81k [00:00<?, ?B/s]

In [12]:

# 2. Freeze encoder except last 2 layers (optional fine-tuning strategy)
for param in model.model.encoder.parameters():
    param.requires_grad = False
for layer in model.model.encoder.layers[-2:]:
    for param in layer.parameters():
        param.requires_grad = True


In [13]:

# 3. Define compute_metrics
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Limit to first 100 samples to reduce memory usage
    max_samples = 100
    pred_ids = pred_ids[:max_samples]
    label_ids = label_ids[:max_samples]
    
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    
    wer_score = wer(label_str, pred_str)
    cer_score = cer(label_str, pred_str)
    
    return {"eval_loss": wer_score, "eval_cer": cer_score}


In [14]:
# 4. Clear memory callback
class ClearMemoryCallback(TrainerCallback):
    def on_evaluate(self, args, state, control, **kwargs):
        torch.cuda.empty_cache()

In [15]:

# 5. Create logging directory
os.makedirs("./logs", exist_ok=True)

# 6. Set label names for Trainer
if hasattr(model, 'config'):
    model.config.label_names = ["labels"]
    
# 7. Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-medical-v2",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=10,
    learning_rate=3e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True,
    gradient_checkpointing=True,
    report_to="tensorboard",
    logging_dir="./logs",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    remove_unused_columns=False,
    eval_accumulation_steps=2,
    prediction_loss_only=True
)

# 8. Data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# 9. Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"].with_format("torch"),
    eval_dataset=dataset["validation"].with_format("torch"),
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[ClearMemoryCallback()],
)

In [16]:
# Confirm GPU is being used

print(f"GPU available: {torch.cuda.is_available()}")
print(f"Device being used: {trainer.args.device}")  # After creating trainer

GPU available: True
Device being used: cuda:0


In [17]:
# Train Model

trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Epoch,Training Loss,Validation Loss
1,2.9684,2.700359
2,2.176,2.23067
3,1.7284,2.032025
4,1.3823,1.907453
5,1.1409,1.845043
6,0.9342,1.783167
7,0.7733,1.757256
8,0.6625,1.750147
9,0.5615,1.750251
10,0.524,1.759988


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=5940, training_loss=1.439986103391808, metrics={'train_runtime': 3279.2, 'train_samples_per_second': 28.97, 'train_steps_per_second': 1.811, 'total_flos': 6.1616996352e+18, 'train_loss': 1.439986103391808, 'epoch': 10.0})

In [18]:
# Save Model

trainer.save_model("./whisper-medical-v2-final")
processor.save_pretrained("./whisper-medical-v2-final")


[]

In [None]:
# # Reload model

# from transformers import WhisperForConditionalGeneration, WhisperProcessor

# model = WhisperForConditionalGeneration.from_pretrained("./whisper-medical-v2-final")
# processor = WhisperProcessor.from_pretrained("./whisper-medical-v2-final")


In [19]:
results = trainer.evaluate()
print(results)

{'eval_loss': 1.7540605068206787, 'eval_runtime': 33.2195, 'eval_samples_per_second': 36.123, 'eval_steps_per_second': 18.062, 'epoch': 10.0}


In [20]:
# Download the saved model
import shutil

# Zip the model directory
shutil.make_archive("whisper-medical-v2-final", 'zip', "./whisper-medical-v2-final")


'/kaggle/working/whisper-medical-v2-final.zip'

#### Test the trained model on a single audio file

In [25]:
import torchaudio
from transformers import pipeline

# 1. Load audio file correctly
test_file = dataset["test"][215]["path"]  # or your custom path
print(f"Testing on: {test_file}")
waveform, sample_rate = torchaudio.load(test_file)

# 2. Convert to mono if needed and resample
if waveform.shape[0] > 1:  # If multi-channel
    waveform = waveform.mean(dim=0, keepdim=True)  # Convert to mono
    
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    waveform = resampler(waveform)


Testing on: /kaggle/input/unitedsnymedsmall/unitedsynmed_small/audio/test/drug-male-2879b019-a935-4a3e-98fc-0617ce82a124.wav


In [26]:
# 3. Prepare input format (single channel, numpy array)
audio_input = waveform.squeeze().numpy()  # Must be 1D numpy array


In [27]:
# 4. Create ASR pipeline with trained model
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=0 if torch.cuda.is_available() else -1
)

Device set to use cuda:0


In [28]:
# 5. Run inference
results = asr_pipe(
    audio_input,  # 1D numpy array
    return_timestamps="word",
    generate_kwargs={
        "language": "en",
        "task": "transcribe"
    }
)

You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, None], [2, 50359]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


In [29]:
# 6. Print results
print("Transcription:")
print(results["text"])

if "chunks" in results:
    print("\nTimestamps:")
    for chunk in results["chunks"]:
        print(f"{chunk['timestamp']}: {chunk['text']}")

Transcription:
 Make sure to take Telmaze-ED exactly as prescribed by your healthcare provider.                                                                                                                                                                                                                                                                                                                                                                                                                                           

Timestamps:
(0.26, 0.38):  Make
(0.38, 0.64):  sure
(0.64, 0.82):  to
(0.82, 1.04):  take
(1.04, 1.64):  Telmaze
(1.64, 2.18): -ED
(2.18, 2.66):  exactly
(2.66, 3.12):  as
(3.12, 3.58):  prescribed
(3.58, 3.96):  by
(3.96, 4.16):  your
(4.16, 4.6):  healthcare
(4.6, 5.6):  provider.
(5.6, None):                                                                                                                                                                                    

In [31]:
!pip install evaluate jiwer

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


In [32]:
# Get reference transcription
reference = dataset["test"][215]["transcription"]

# Calculate WER
from evaluate import load
wer = load("wer")
print(f"\nWER: {100 * wer.compute(predictions=[results['text']], references=[reference]):.2f}%")
print(f"Reference: {reference}")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]


WER: 8.33%
Reference: Make sure to take TELMA-ACT exactly as prescribed by your healthcare provider.


In [33]:
# Download the trained model folder
import shutil

# Zip the model directory
shutil.make_archive("whisper-medical-v2", 'zip', "./whisper-medical-v2")


'/kaggle/working/whisper-medical-v2.zip'