# Speech-to-Text with Whisper Transfer Learning

**Objective:** Fine-tune a Whisper base model on the United-Syn-Med dataset to improve medical speech transcription accuracy in a live teleconsultation context.

In [1]:
# Installing required packages

!pip install git+https://github.com/openai/whisper.git
!pip install jiwer datasets torchaudio transformers accelerate soundfile

Collecting git+https://github.com/openai/whisper.git
  Cloning https://github.com/openai/whisper.git to /tmp/pip-req-build-djjjdjq0
  Running command git clone --filter=blob:none --quiet https://github.com/openai/whisper.git /tmp/pip-req-build-djjjdjq0
  Resolved https://github.com/openai/whisper.git to commit dd985ac4b90cafeef8712f2998d62c59c3e62d22
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->openai-whisper==20240930)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->openai-whisper==20240930)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->openai-whisper==20240930)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-many

In [19]:
# import dependent libraries

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import pandas as pd
import soundfile as sf
import torch
import torchaudio
import whisper
from datasets import Dataset, DatasetDict
from jiwer import wer, cer
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainingArguments,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

In [None]:
# Loading the data
import os

# Print the full path of every file found under /kaggle/input, one per line
input_files = (
    os.path.join(folder, name)
    for folder, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for input_file in input_files:
    print(input_file)

/kaggle/input/unitedsynmed/UnitedSynMed/transcript/validation.csv
/kaggle/input/unitedsynmed/UnitedSynMed/transcript/train.csv
/kaggle/input/unitedsynmed/UnitedSynMed/transcript/test.csv


In [20]:
# Paths to the dataset
# Root of the audio clips; load_split joins it as audio_root/<split>/<file_name>.
audio_root = "/kaggle/input/unitedsynmed/UnitedSynMed/audio"
# Folder holding the transcript tables; load_split reads <split>.csv from here.
transcript_root = "/kaggle/input/unitedsynmed/UnitedSynMed/transcript/"

# Read one split's transcript table and attach the audio location of each row
def load_split(split):
    """Load ``<split>.csv`` and add a ``path`` column pointing at each clip.

    Args:
        split: Split name; used as the CSV stem under ``transcript_root`` and
            as the sub-folder under ``audio_root``.

    Returns:
        The DataFrame read from the CSV with an extra ``path`` column built by
        joining ``audio_root``, ``split`` and the row's ``file_name``.
    """
    frame = pd.read_csv(os.path.join(transcript_root, f"{split}.csv"))

    def to_audio_path(file_name):
        return os.path.join(audio_root, split, file_name)

    frame["path"] = frame["file_name"].apply(to_audio_path)
    return frame

# Load the transcript table of every split (same order as before: train, test, validation)
train_df, test_df, val_df = map(load_split, ("train", "test", "validation"))

# Wrap each pandas frame as a Hugging Face Dataset, keyed by split name
split_frames = {"train": train_df, "test": test_df, "validation": val_df}
dataset = DatasetDict(
    {split_name: Dataset.from_pandas(frame) for split_name, frame in split_frames.items()}
)


In [21]:
# Peek at the first five training examples to sanity-check the columns
# (file_name, transcription, path) before any audio is processed.
dataset["train"][:5]

{'file_name': ['drug-female-defa7fcb-89d7-4b25-8834-90888b201d25.mp3',
  'drug-female-160727b4-dd0c-43c7-ba17-963ae54347a0.mp3',
  'drug-female-637d7dcc-fe73-499c-af76-b2ee28d36374.mp3',
  'drug-male-02a2daf6-0f99-4939-848d-adc95f03d4bd.mp3',
  'drug-brand-en-us-male-421229aa-4f71-48fa-bd43-a9ac606783f8.mp3'],
 'transcription': ['Durysta is a medication used to reduce eye pressure in patients with open-angle glaucoma or ocular hypertension.',
  'Annona muricata extract is known for its potential health benefits as a natural dietary supplement.',
  'Many patients have found relief with REDBURY GOLD for their ongoing health issues.',
  'ALMAL-Z is a popular medication used for treating allergies and cold symptoms.',
  ' Norfazole may cause side effects such as nausea or a metallic taste in the mouth.'],
 'path': ['/kaggle/input/unitedsynmed/UnitedSynMed/audio/train/drug-female-defa7fcb-89d7-4b25-8834-90888b201d25.mp3',
  '/kaggle/input/unitedsynmed/UnitedSynMed/audio/train/drug-female-16

In [None]:
from glob import glob

# Define source and target folders
source_root = "/kaggle/input/unitedsynmed/UnitedSynMed/audio"
# /kaggle/input is mounted read-only in Kaggle notebooks, so creating the output
# folder there fails; write the converted files to the writable working dir.
target_root = "/kaggle/working/UnitedSynMed/audio_resampled"
target_sample_rate = 16000

os.makedirs(target_root, exist_ok=True)

splits = ['train', 'test', 'validation']

# One Resample transform per distinct source rate, built once and reused,
# instead of constructing a new transform for every file.
resamplers = {}

for split in splits:
    src_dir = os.path.join(source_root, split)
    tgt_dir = os.path.join(target_root, split)
    os.makedirs(tgt_dir, exist_ok=True)

    # Sorted so the processing order is the same on every run.
    audio_files = sorted(glob(os.path.join(src_dir, "*.mp3")))

    for file in audio_files:
        filename = os.path.splitext(os.path.basename(file))[0] + ".wav"
        out_path = os.path.join(tgt_dir, filename)

        # Skip clips converted by a previous run so re-running the cell is cheap.
        # NOTE(review): a file left half-written by an interrupted run is also
        # skipped; delete it by hand if a run was killed mid-save.
        if os.path.exists(out_path):
            continue

        # torchaudio.load returns a (channels, frames) tensor and the source rate.
        waveform, sr = torchaudio.load(file)

        # Down-mix multi-channel audio to a single channel before saving.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sr != target_sample_rate:
            if sr not in resamplers:
                resamplers[sr] = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
            waveform = resamplers[sr](waveform)

        torchaudio.save(out_path, waveform, target_sample_rate)

# NOTE(review): the `path` column created by load_split still points at the
# original .mp3 files, so later cells do not read these .wav files unless that
# column is redirected to `target_root`.
print("✅ All audio resampled and saved to:", target_root)

In [None]:

# Checkpoint whose processor (feature extractor + tokenizer) is used for preprocessing
whisper_checkpoint = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(whisper_checkpoint)

# Sample rate (Hz) that every clip is brought to before it reaches the processor
target_sample_rate = 16_000

def preprocess(batch):
    """Turn one dataset example into Whisper training inputs.

    Reads the audio file at ``batch["path"]``, down-mixes it to mono, resamples
    it to ``target_sample_rate`` when the file uses a different rate, then
    stores the processor's ``input_features`` and the tokenized
    ``transcription`` (as ``labels``) on the example.

    Args:
        batch: A single example with ``path`` and ``transcription`` keys.

    Returns:
        The same mapping with ``input_features`` and ``labels`` added.
    """
    # soundfile returns shape (frames,) for mono and (frames, channels) otherwise.
    audio_input, sr = sf.read(batch["path"], dtype="float32")

    # Down-mix on the channel axis (axis 1) regardless of the sample rate.
    # Previously this only ran when resampling and averaged over axis 0 (time),
    # which collapsed the signal instead of the channels.
    if audio_input.ndim > 1:
        audio_input = audio_input.mean(axis=1)

    # If the sample rate is not 16kHz, resample the (now 1-D) waveform
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
        waveform = torch.as_tensor(audio_input, dtype=torch.float32)
        audio_input = resampler(waveform).numpy()

    inputs = processor(audio_input, sampling_rate=target_sample_rate, return_tensors="pt")
    # return_tensors="pt" adds a leading batch axis; keep the single example.
    batch["input_features"] = inputs.input_features[0]
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    return batch

# Apply preprocessing
# `preprocess` handles one example at a time (it reads a single `path`), and
# map() is applied to the whole DatasetDict, i.e. to every split.
# The name `dataset` is rebound to the processed result.
dataset = dataset.map(preprocess)

Map:   0%|          | 0/632548 [00:00<?, ? examples/s]

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of preprocessed examples into a single training batch.

    Audio features and label token ids are padded separately, because they
    have different lengths and need different padding rules. Padded label
    positions are set to -100 (the value this notebook uses to mark positions
    that must not contribute to the loss).
    """

    # Processor exposing `feature_extractor.pad` and `tokenizer.pad`.
    processor: Any
    # Optional id of the decoder start token. When given and every label row
    # begins with it, that first column is dropped so the token is not
    # duplicated if the model prepends it again. Default None keeps labels as-is.
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate `features` (each with `input_features` and `labels`) into a batch.

        Returns:
            The padded feature batch with a `labels` tensor added, where padded
            positions hold -100.
        """
        input_features = [{"input_features": f["input_features"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Mask by attention mask rather than by pad token id: when the pad token
        # shares its id with the end-of-text token, masking by id would also
        # hide the real end-of-text label and the model would never be trained
        # to stop. The attention mask is 0 only on positions added by padding.
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)

        if (
            self.decoder_start_token_id is not None
            and labels.size(1) > 0
            and (labels[:, 0] == self.decoder_start_token_id).all().item()
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [None]:
# Start from the pretrained Whisper base checkpoint
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

# Freeze encoder layers: turn off requires_grad for every encoder parameter in
# one call (trailing `;` keeps the notebook from echoing the returned module)
model.model.encoder.requires_grad_(False);

In [None]:

# Seq2Seq variants of the arguments/trainer are used so that evaluation can run
# `generate()` (predict_with_generate=True) and hand token ids, not raw logits,
# to any metric function attached to the trainer.
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` and the Trainer `tokenizer` argument to `processing_class`;
# confirm against the installed version (the install cell does not pin it).
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-medical",
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,
    logging_dir="./logs",
    learning_rate=1e-4,
    warmup_steps=500,
    fp16=True,
    push_to_hub=False,
    predict_with_generate=True,
)

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    # Per-epoch evaluation uses the validation split; the test split stays
    # untouched during training so the final score is not biased by it.
    eval_dataset=dataset["validation"],
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)

In [None]:
# Run fine-tuning with the trainer and arguments configured in the previous cell.
trainer.train()

In [None]:
def _logits_to_token_ids(logits, labels):
    """Reduce model outputs to token ids before the trainer accumulates them.

    Keeps evaluation memory small (ids instead of per-token vocabulary scores)
    and leaves already-generated token ids untouched.
    """
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    if logits.ndim != 3:
        # Already token ids, e.g. produced by generate().
        return logits
    token_ids = logits.argmax(dim=-1)
    # Positions ignored in the labels are ignored in the predictions too, so
    # tokens predicted past the end of a reference do not count as errors.
    if torch.is_tensor(labels) and labels.shape == token_ids.shape:
        token_ids = token_ids.masked_fill(labels == -100, -100)
    return token_ids


def compute_metrics(pred):
    """Compute word and character error rate for a set of predictions.

    Args:
        pred: Object with `predictions` (token ids, or raw scores with a
            trailing vocabulary axis) and `label_ids` (reference token ids with
            -100 on ignored positions).

    Returns:
        Dict with `wer` and `cer`, both computed by jiwer on decoded text.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, (tuple, list)):
        pred_ids = pred_ids[0]
    if pred_ids.ndim == 3:
        # Raw scores: pick the most likely token at each position.
        pred_ids = pred_ids.argmax(axis=-1)

    # Work on copies so the caller's arrays are not modified, and swap the
    # -100 ignore marker for the pad id so the tokenizer can decode every row.
    pred_ids = pred_ids.copy()
    pred_ids[pred_ids == -100] = pad_id
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = pad_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer_score = wer(label_str, pred_str)
    cer_score = cer(label_str, pred_str)

    return {"wer": wer_score, "cer": cer_score}


# The trainer was created before this function existed, so attach it now;
# otherwise evaluate() reports the loss only and never computes WER/CER.
trainer.compute_metrics = compute_metrics
trainer.preprocess_logits_for_metrics = _logits_to_token_ids

# Final score on the held-out test split.
# NOTE(review): if `trainer` does not generate during evaluation, predictions
# are teacher-forced argmax tokens, which can look better than free-running
# transcription would.
results = trainer.evaluate(eval_dataset=dataset["test"])
print(results)

In [None]:
# Save the fine-tuned model and its processor into the same export folder
export_dir = "whisper-medical-finetuned"
model.save_pretrained(export_dir)
processor.save_pretrained(export_dir)