# Speech-to-Text with Whisper Transfer Learning

**Objective:** Fine-tune a Whisper base model on the United-Syn-Med dataset to improve medical speech transcription accuracy in a live teleconsultation context.

In [1]:
# Installing required packages

!pip install git+https://github.com/openai/whisper.git
!pip install jiwer datasets torchaudio transformers accelerate soundfile

Collecting git+https://github.com/openai/whisper.git
  Cloning https://github.com/openai/whisper.git to c:\users\ifedi\appdata\local\temp\pip-req-build-y7g0svdv
  Resolved https://github.com/openai/whisper.git to commit dd985ac4b90cafeef8712f2998d62c59c3e62d22
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Collecting more-itertools
  Downloading more_itertools-10.7.0-py3-none-any.whl (65 kB)
     ---------------------------------------- 65.3/65.3 kB 1.7 MB/s eta 0:00:00
Collecting numba
  Downloading numba-0.61.2-cp311-cp311-win_amd64.whl (2.8 MB)
     ---------------------------------------- 2.8/2.8 MB 2.4 MB/s eta 0:00:00
Collecting tiktoken
  Downloading tiktoken-0.9.0-cp311-cp311-win_amd64.whl (893 kB)
 

  Running command git clone --filter=blob:none --quiet https://github.com/openai/whisper.git 'C:\Users\ifedi\AppData\Local\Temp\pip-req-build-y7g0svdv'

[notice] A new release of pip available: 22.3.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
     -------------------------------------- 491.5/491.5 kB 1.4 MB/s eta 0:00:00
Collecting torchaudio
  Downloading torchaudio-2.7.1-cp311-cp311-win_amd64.whl (2.5 MB)
     ---------------------------------------- 2.5/2.5 MB 1.8 MB/s eta 0:00:00
Collecting transformers
  Downloading transformers-4.52.4-py3-none-any.whl (10.5 MB)
     ---------------------------------------- 10.5/10.5 MB 1.8 MB/s eta 0:00:00
Collecting accelerate
  Downloading accelerate-1.7.0-py3-none-any.whl (362 kB)
     -------------------------------------- 362.1/362.1 kB 1.4 MB/s eta 0:00:00
Collecting soundfile
  Downloading soundfile-0.13.1-py2.py3-none-win_amd64.whl (1.0 MB)
     ---------------------------------------- 1.0/1.0 MB 1.3 MB/s eta 0:00:00
Collecting click>=8.1.8
  Using cached click-8.2.1-py3-none-any.whl (102 kB)
Collecting rapidfuzz>=3.9.7
  Downloading ra

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.15.1+cpu requires torch==2.0.0, but you have torch 2.7.1 which is incompatible.

[notice] A new release of pip available: 22.3.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
# import dependent libraries

import os
import torch
import whisper
import pandas as pd
import soundfile as sf
from datasets import Dataset, DatasetDict
from jiwer import wer, cer
from transformers import WhisperProcessor, WhisperForConditionalGeneration, TrainingArguments, Trainer
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torchaudio
import glob

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [4]:
# Loading the data
import os
n = 0
for dirname, _, filenames in os.walk('../data-collection/'):
    for filename in filenames:
        if n < 3:
            print(os.path.join(dirname, filename))
            n += 1
        else: break
    if n >= 3: break

../data-collection/UnitedSynMed.rar
../data-collection/UnitedSynMed\audio\test\drug-brand-en-us-female-0031c803-9529-4e8c-85cc-e69baedc152c.mp3
../data-collection/UnitedSynMed\audio\test\drug-brand-en-us-female-00c6b893-719c-4550-8907-55177ce89e99.mp3


In [5]:
# Paths to the dataset
audio_root = "../data-collection/UnitedSynMed/audio"
transcript_root = "../data-collection/UnitedSynMed/transcript/"

# Load CSVs and match them with audio paths
def load_split(split):
    csv_path = os.path.join(transcript_root, f"{split}.csv")
    df = pd.read_csv(csv_path)
    df["path"] = df["file_name"].apply(lambda x: os.path.join(audio_root, split, x))
    return df

# Create datasets
train_df = load_split("train")
test_df = load_split("test")
val_df = load_split("validation")

# Convert to Hugging Face Dataset
dataset = DatasetDict({
    "train": Dataset.from_pandas(train_df),
    "test": Dataset.from_pandas(test_df),
    "validation": Dataset.from_pandas(val_df)
})


In [6]:
dataset["train"][:5]

{'file_name': ['drug-female-ca56662a-065c-4318-b1c6-41095684a20a.mp3',
  'drug-male-4e02e7a8-ee16-4943-9b7f-4e09d32b558a.mp3',
  'drug-male-133fccfc-22ea-4b3b-a89c-9b5df813bff8.mp3',
  'drug-female-d0565c6e-c35b-43a0-b55b-624802e078c1.mp3',
  'drug-brand-en-us-female-9c816919-b532-4ffb-8be7-08c4a0a671fc.mp3'],
 'transcription': ['Meglumine diatrizoate is commonly used as a contrast medium in medical imaging procedures.',
  'ZOTACAL is a prescription medication used to treat osteoporosis by increasing bone density.',
  'The extract from pongamia pinnata seeds is gaining popularity for its potential therapeutic properties.',
  'The use of aminoacridine may offer a new avenue for the treatment of certain types of leukemia.',
  'Bolofen is commonly used as a muscle relaxant medication.'],
 'path': ['../data-collection/UnitedSynMed/audio\\train\\drug-female-ca56662a-065c-4318-b1c6-41095684a20a.mp3',
  '../data-collection/UnitedSynMed/audio\\train\\drug-male-4e02e7a8-ee16-4943-9b7f-4e09d32b5

In [10]:
# Define source and target folders
source_root = "C:/Users/ifedi/Documents/Applied AI and ML/Session 2/INFO8665 - Projects in Machine Learning/SageCare-2.0/data-collection/UnitedSynMed/audio"
target_root = "C:/Users/ifedi/Documents/Applied AI and ML/Session 2/INFO8665 - Projects in Machine Learning/SageCare-2.0/data-collection/UnitedSynMed/audio_resampled"
target_sample_rate = 16000

os.makedirs(target_root, exist_ok=True)

splits = ['train', 'test', 'validation']

for split in splits:
    src_dir = os.path.join(source_root, split)
    tgt_dir = os.path.join(target_root, split)
    os.makedirs(tgt_dir, exist_ok=True)

    audio_files = glob.glob(os.path.join(src_dir, "*.mp3"))

    for file in audio_files:
        waveform, sr = torchaudio.load(file)
        if sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
            waveform = resampler(waveform)

        filename = os.path.splitext(os.path.basename(file))[0] + ".wav"
        torchaudio.save(os.path.join(tgt_dir, filename), waveform, target_sample_rate)

print("✅ All audio resampled and saved to:", target_root)

✅ All audio resampled and saved to: C:/Users/ifedi/Documents/Applied AI and ML/Session 2/INFO8665 - Projects in Machine Learning/SageCare-2.0/data-collection/UnitedSynMed/audio_resampled


In [None]:
# # Paths to the dataset
# audio_root = "../data-collection/UnitedSynMed/audio_resampled"
# transcript_root = "../data-collection/UnitedSynMed/transcript/"

# # Load CSVs and match them with audio paths
# def load_split(split):
#     csv_path = os.path.join(transcript_root, f"{split}.csv")
#     df = pd.read_csv(csv_path)
#     df["path"] = df["file_name"].apply(lambda x: os.path.join(audio_root, split, x))
#     return df

# # Create datasets
# train_df = load_split("train")
# test_df = load_split("test")
# val_df = load_split("validation")

# # Convert to Hugging Face Dataset
# dataset = DatasetDict({
#     "train": Dataset.from_pandas(train_df),
#     "test": Dataset.from_pandas(test_df),
#     "validation": Dataset.from_pandas(val_df)
# })


In [13]:
# Paths to the dataset
audio_root = "../data-collection/UnitedSynMed/audio_resampled"
transcript_root = "../data-collection/UnitedSynMed/transcript/"

# Load CSVs and match them with audio paths
def load_split(split):
    csv_path = os.path.join(transcript_root, f"{split}.csv")
    df = pd.read_csv(csv_path)
    df["path"] = df["file_name"].apply(lambda x: os.path.join(audio_root, split, x[:-4] + ".wav"))
    return df

# Create datasets
train_df = load_split("train")
test_df = load_split("test")
val_df = load_split("validation")

# Convert to Hugging Face Dataset
dataset = DatasetDict({
    "train": Dataset.from_pandas(train_df),
    "test": Dataset.from_pandas(test_df),
    "validation": Dataset.from_pandas(val_df)
})


In [14]:

# Load Whisper processor
processor = WhisperProcessor.from_pretrained("openai/whisper-base")

# Set target sample rate
target_sample_rate = 16000

def preprocess(batch):
    audio_input, sr = sf.read(batch["path"])
    
    # If the sample rate is not 16kHz, resample it
    if sr != target_sample_rate:
        waveform = torch.tensor(audio_input).float()
        if len(waveform.shape) > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0)  # Convert to mono
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
        audio_input = resampler(waveform).numpy()
    
    inputs = processor(audio_input, sampling_rate=target_sample_rate, return_tensors="pt")
    batch["input_features"] = inputs.input_features[0]
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    return batch

# Apply preprocessing
dataset = dataset.map(preprocess)

Map:  57%|█████▋    | 35999/63250 [45:59<34:49, 13.04 examples/s]     


ArrowMemoryError: realloc of size 68719476736 failed

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": f["input_features"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        labels = labels_batch["input_ids"].masked_fill(labels_batch["input_ids"] == self.processor.tokenizer.pad_token_id, -100)
        batch["labels"] = labels

        return batch

In [None]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
# Freeze encoder layers
for param in model.model.encoder.parameters():
    param.requires_grad = False

In [None]:

training_args = TrainingArguments(
    output_dir="./whisper-medical",
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,
    logging_dir="./logs",
    learning_rate=1e-4,
    warmup_steps=500,
    fp16=True,
    push_to_hub=False,
)

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)

In [None]:
trainer.train()

In [None]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer_score = wer(label_str, pred_str)
    cer_score = cer(label_str, pred_str)

    return {"wer": wer_score, "cer": cer_score}

results = trainer.evaluate()
print(results)

In [None]:
model.save_pretrained("whisper-medical-finetuned")
processor.save_pretrained("whisper-medical-finetuned")