# Fine-Tuning Whisper for Greek Medical Dictation

This notebook demonstrates the fine-tuning process for OpenAI's Whisper models (Small, Medium, and Large-v2) on Greek medical speech data, as part of the Bachelor Thesis "Automatic Speech Recognition for Greek Medical Dictation" by Vardis Georgilas (Athens University of Economics and Business, August 2025).

## Objective
- Adapt Whisper models with Low-Rank Adaptation (LoRA) to improve Automatic Speech Recognition (ASR) for Greek speech.

## Environment
- **Platform**: Google Colab with NVIDIA A100 GPU.
- **Training Time**: Approximately 3-5 hours per model (depending on size).

## Notes
- **Thesis Context**: This work contributes to a system combining ASR with text post-processing for Greek healthcare, supervised by Prof. Themos Stafylakis.
- **Potential Issues**: Monitor GPU memory usage for the larger models and reduce the batch size if out-of-memory errors occur.
- **Extensions**: Post-fine-tuning, models can be evaluated with reranking (see related notebooks).


In [None]:
# Pin datasets; the other packages use their latest releases.
!pip install datasets==3.6.0
!pip install -q librosa evaluate jiwer gradio
!pip install -q bitsandbytes accelerate loralib
# Install transformers and peft from source (main branch).
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main


In [None]:
from huggingface_hub import login
# Paste your Hugging Face access token below (never commit a real token)
login(token="#########################")

## Loading Greek Mosel Dataset

In [None]:
from datasets import load_dataset, IterableDatasetDict
import os


os.environ["CUDA_VISIBLE_DEVICES"] = "0"


a = IterableDatasetDict()
a_full = load_dataset("Vardis/Greek_Mosel", split="train")
a_temp = a_full.train_test_split(test_size=0.2, seed=42)  # 80% train
a_val_test = a_temp["test"].train_test_split(test_size=0.5, seed=42)  # 10% val + 10% test

a["train"] = a_temp["train"]
a["validation"] = a_val_test["train"]
a["test"] = a_val_test["test"]

## Loading Common Voice Dataset

In [None]:
language_abbr = "el"


b = IterableDatasetDict()
b_full = load_dataset("mozilla-foundation/common_voice_11_0", language_abbr, split="train+validation+test")
b_temp = b_full.train_test_split(test_size=0.2, seed=42)
b_val_test = b_temp["test"].train_test_split(test_size=0.5, seed=42)

b["train"] = b_temp["train"]
b["validation"] = b_val_test["train"]
b["test"] = b_val_test["test"]

## Loading Fleurs Dataset

In [None]:
language_abbr2 = "el_gr"

c = IterableDatasetDict()
c_full = load_dataset("google/fleurs", language_abbr2, split="train+validation+test")
c_temp = c_full.train_test_split(test_size=0.2, seed=42)
c_val_test = c_temp["test"].train_test_split(test_size=0.5, seed=42)

c["train"] = c_temp["train"]
c["validation"] = c_val_test["train"]
c["test"] = c_val_test["test"]
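
The two-stage split used for all three datasets (hold out 20%, then halve the held-out part) can be checked with plain arithmetic. The sketch below is not part of the original pipeline. It assumes that the held-out size is rounded up, and it takes each dataset total as the sum of the split sizes printed later in this notebook. `split_sizes` is a hypothetical helper introduced only for this illustration.

```python
import math

def split_sizes(n_total, test_size=0.2, holdout_test_size=0.5):
    """Return (train, validation, test) sizes for the two-stage split.

    Assumption: held-out sizes are rounded up, which reproduces the
    split sizes printed by this notebook.
    """
    n_holdout = math.ceil(test_size * n_total)         # 20% held out
    n_train = n_total - n_holdout                      # remaining 80%
    n_test = math.ceil(holdout_test_size * n_holdout)  # half of the held-out part
    n_val = n_holdout - n_test                         # the other half
    return n_train, n_val, n_test

# Totals are the sums of the printed train/validation/test sizes.
sizes = {
    "Greek_Mosel": split_sizes(3876),
    "Common Voice": split_sizes(5311),
    "FLEURS": split_sizes(4136),
}
for name, (n_train, n_val, n_test) in sizes.items():
    print(f"{name}: train={n_train}, validation={n_val}, test={n_test}")
```

Under these assumptions the helper returns 3100/388/388, 4248/531/532, and 3308/414/414, matching the printed dataset summaries.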


## Clean and Standardize Columns

- Remove unnecessary columns from datasets.  
- Rename text columns to a common name `"sentence"` for consistency.


In [None]:
b = b.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
c = c.remove_columns(["id", "num_samples", "path", "raw_transcription", "gender", "lang_id", "language", "lang_group_id"])

a = a.rename_column("text", "sentence")
c = c.rename_column("transcription", "sentence")

## Combine and Resample Audio Datasets

- Cast the audio column in datasets to a **16 kHz sampling rate**.  
- Concatenate the train, validation, and test splits from all three datasets.  


In [None]:
from datasets import Audio
from datasets import concatenate_datasets


a = a.cast_column("audio", Audio(sampling_rate=16000))
b = b.cast_column("audio", Audio(sampling_rate=16000))
c = c.cast_column("audio", Audio(sampling_rate=16000))


combined_train = concatenate_datasets([a['train'], b['train'], c['train']])
combined_test = concatenate_datasets([a['test'], b['test'], c['test']])
combined_valid = concatenate_datasets([a['validation'], b['validation'], c['validation']])

combined_dataset = IterableDatasetDict({
    'train': combined_train,
    "validation": combined_valid,
    'test': combined_test
})

dataset = combined_dataset
print(dataset)


IterableDatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 3100
    })
    validation: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 388
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 388
    })
})
IterableDatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 4248
    })
    validation: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 531
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 532
    })
})
IterableDatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 3308
    })
    validation: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 414
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 414
    })
})
IterableDatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
       


## Load Whisper Processor for Greek Transcription

In [None]:

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2", language="Greek", task="transcribe")

## Prepare and Vectorize the Dataset for Whisper

This step extracts input features from the audio, generates prompt token IDs for Greek transcription, tokenizes the target sentences, and constructs label sequences including prompt and EOS tokens. The dataset is then converted to PyTorch tensors and shuffled for training, validation, and testing.


In [None]:

def prepare_dataset(batch):
    audio = batch["audio"]
    sentence = batch["sentence"]

    # Extract input features
    input_features = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["input_features"] = input_features
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # Get prompt IDs for language and task
    prompt_ids = processor.get_decoder_prompt_ids(language="Greek", task="transcribe")

    # Extract only the token IDs from prompt_ids tuples
    prompt_token_ids = [token_id for _, token_id in prompt_ids]

    # Tokenize sentence without adding special tokens
    sentence_ids = processor.tokenizer(sentence, add_special_tokens=False).input_ids

    # Construct full label sequence: prompt tokens + sentence tokens + EOS token
    labels = prompt_token_ids + sentence_ids + [processor.tokenizer.eos_token_id]

    batch["labels"] = labels
    return batch

vectorized_datasets = dataset.map(prepare_dataset, remove_columns=list(next(iter(dataset.values())).features)).with_format("torch")

vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=42)
vectorized_datasets["test"] = vectorized_datasets["test"].shuffle(seed=42)
vectorized_datasets["validation"] = vectorized_datasets["validation"].shuffle(seed=42)




Map:   0%|          | 0/10656 [00:00<?, ? examples/s]

Map:   0%|          | 0/1333 [00:00<?, ? examples/s]

Map:   0%|          | 0/1334 [00:00<?, ? examples/s]
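
To make the label layout in `prepare_dataset` concrete, here is a small sketch that uses plain lists. The token IDs are hypothetical placeholders, since the real values come from the Whisper tokenizer. `shift_right` illustrates the usual teacher-forcing convention, in which the decoder input is the label sequence shifted one position to the right behind a start token. It is written under that assumption and does not copy the library's implementation.

```python
# Hypothetical token IDs, chosen only for illustration.
LANG_ID, TASK_ID, NO_TIMESTAMPS_ID = 101, 102, 103
EOS_ID = 100
DECODER_START_ID = 99  # stands in for the start-of-transcript token

def build_labels(sentence_ids):
    """Prompt tokens + sentence tokens + EOS, mirroring prepare_dataset."""
    # Same shape as the (position, token_id) tuples handled in prepare_dataset.
    prompt_ids = [(1, LANG_ID), (2, TASK_ID), (3, NO_TIMESTAMPS_ID)]
    prompt_token_ids = [token_id for _, token_id in prompt_ids]
    return prompt_token_ids + list(sentence_ids) + [EOS_ID]

def shift_right(labels, start_id=DECODER_START_ID):
    """Decoder inputs for teacher forcing: start token, then all labels but the last."""
    return [start_id] + labels[:-1]

labels = build_labels([7, 8, 9])
decoder_inputs = shift_right(labels)
print(labels)          # [101, 102, 103, 7, 8, 9, 100]
print(decoder_inputs)  # [99, 101, 102, 103, 7, 8, 9]
```

At each position the decoder sees `decoder_inputs[i]` and is trained to predict `labels[i]`, so the language and task tokens are part of what the model learns to emit.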


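## Filter Out Long Audio

This step drops training samples that are 30 seconds or longer. Whisper's feature extractor pads or truncates every clip to a 30-second window, so a longer clip would be cut off while its transcript still covers the full recording.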


In [None]:
max_input_length = 30.0

def is_audio_in_length_range(length):
    return length < max_input_length

vectorized_datasets["train"] = vectorized_datasets["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)
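
The 30-second limit translates directly into sample and frame counts. The sketch below assumes a hop length of 160 samples per spectrogram frame, which is the usual Whisper feature-extractor setting but is not stated in this notebook. `duration_seconds` and `keep_example` are hypothetical helpers that mirror the `input_length` computation and the filter above.

```python
SAMPLING_RATE = 16_000     # Hz, matches the Audio(sampling_rate=16000) cast
MAX_INPUT_LENGTH = 30.0    # seconds, as in the filter above
HOP_LENGTH = 160           # samples per frame (assumed Whisper default)

def duration_seconds(num_samples, sampling_rate=SAMPLING_RATE):
    """Clip duration, computed the same way as batch["input_length"]."""
    return num_samples / sampling_rate

def keep_example(num_samples):
    """Mirror of is_audio_in_length_range: strict less-than comparison."""
    return duration_seconds(num_samples) < MAX_INPUT_LENGTH

max_samples = int(MAX_INPUT_LENGTH * SAMPLING_RATE)
max_frames = max_samples // HOP_LENGTH
print(max_samples, max_frames)              # 480000 3000
print(keep_example(12 * SAMPLING_RATE))     # True: a 12 s clip is kept
print(keep_example(30 * SAMPLING_RATE))     # False: exactly 30 s is dropped
```

Because the comparison is strict, a clip of exactly 30 seconds is removed along with anything longer.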


## Data Collator for Speech-to-Text

This defines a custom data collator that pads input features and labels to create uniform batches for training. It ensures input features and label sequences are properly padded, masks padding tokens in the labels, and prepares the batch as PyTorch tensors.


In [None]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio features and labels need different padding, so handle them separately.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the label sequences to the longest one in the batch.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padded positions with -100 so the loss ignores them.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If every sequence starts with the BOS token, drop it; a start token is added again later.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
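
The label-masking step is easier to follow on plain lists. The sketch below is a simplified stand-in for the tokenizer's `pad` call and the `masked_fill` on the attention mask, not the collator itself. `pad_and_mask` and the token IDs are hypothetical.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def pad_and_mask(label_seqs, pad_id=0):
    """Right-pad label sequences and build the masked version used for the loss."""
    max_len = max(len(seq) for seq in label_seqs)
    padded, masked = [], []
    for seq in label_seqs:
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_id] * n_pad)        # what the tokenizer returns
        masked.append(seq + [IGNORE_INDEX] * n_pad)  # what the loss receives
    return padded, masked

padded, masked = pad_and_mask([[5, 6, 7, 2], [5, 8, 2]])
print(padded)  # [[5, 6, 7, 2], [5, 8, 2, 0]]
print(masked)  # [[5, 6, 7, 2], [5, 8, 2, -100]]
```

Masking by position, as the attention mask does, matters when the padding token shares its ID with the end-of-sequence token. Masking by token value would then also hide the genuine EOS label at the end of each sentence.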


## Evaluation Metrics

This function computes the Word Error Rate (WER) and Character Error Rate (CER) for model predictions.


In [None]:
import evaluate

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

do_normalize_eval = True

def compute_metrics(pred):

    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # convert ids into strings
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=do_normalize_eval)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True, normalize=do_normalize_eval)

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}
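
For readers who want to see what the two metrics measure, here is a from-scratch sketch of the standard definitions. Both divide the edit distance between reference and prediction by the reference length, counted in words for WER and in characters for CER. It does not reproduce the internals of the `evaluate` metrics. `edit_distance`, `error_rate`, and the example sentences are made up for illustration.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences, using one DP row."""
    row = list(range(len(hyp) + 1))
    for i, ref_item in enumerate(ref, start=1):
        prev_diag, row[0] = row[0], i
        for j, hyp_item in enumerate(hyp, start=1):
            cost = min(
                row[j] + 1,                          # deletion
                row[j - 1] + 1,                      # insertion
                prev_diag + (ref_item != hyp_item),  # substitution or match
            )
            prev_diag, row[j] = row[j], cost
    return row[-1]

def error_rate(references, predictions, tokenize):
    """Total edits divided by total reference length, as a percentage."""
    edits = sum(edit_distance(tokenize(r), tokenize(p))
                for r, p in zip(references, predictions))
    total = sum(len(tokenize(r)) for r in references)
    return 100 * edits / total

refs = ["ο ασθενής έχει πυρετό", "χορηγήθηκε παρακεταμόλη"]
hyps = ["ο ασθενής είχε πυρετό", "χορηγήθηκε παρακεταμόλη"]

wer = error_rate(refs, hyps, str.split)  # word-level tokens
cer = error_rate(refs, hyps, list)       # character-level tokens
print(f"WER = {wer:.2f}%, CER = {cer:.2f}%")
```

One wrong word out of six reference words gives a WER of about 16.67%. The CER is lower because the wrong word still shares characters with the reference.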

## Load Whisper Model


In [None]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

# Do not suppress any tokens during generation.
model.config.suppress_tokens = []
# The key/value cache is not used during training with gradient checkpointing.
model.config.use_cache = False
# Force the Greek transcription prompt when generating.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="Greek", task="transcribe")

## Apply LoRA for Efficient Fine-Tuning

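The Whisper model is passed through `prepare_model_for_kbit_training` and then wrapped with **LoRA (Low-Rank Adaptation)**, which keeps the base weights frozen and trains small low-rank adapters on the attention projection layers (`q_proj`, `k_proj`, `v_proj`, `out_proj`). The model is loaded without quantization here, so the k-bit preparation step mainly serves to freeze the base weights.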

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Freeze the base model weights before adding adapters.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Let gradients flow from the inputs, which gradient checkpointing needs when the base model is frozen.
model.enable_input_require_grads()
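
The trainable-parameter count reported by `print_trainable_parameters()` can be reproduced by hand. The sketch below assumes the Whisper large-v2 architecture values (hidden size 1280, 32 encoder layers, 32 decoder layers), which are not stated in this notebook. The rank and alpha come from the `LoraConfig` above.

```python
D_MODEL = 1280             # assumed hidden size of Whisper large-v2
ENCODER_LAYERS = 32        # assumed
DECODER_LAYERS = 32        # assumed
RANK, ALPHA = 32, 64       # r and lora_alpha from LoraConfig
TARGETS_PER_ATTENTION = 4  # q_proj, k_proj, v_proj, out_proj

# Encoder layers have self-attention; decoder layers have self- and cross-attention.
attention_blocks = ENCODER_LAYERS * 1 + DECODER_LAYERS * 2
adapted_layers = attention_blocks * TARGETS_PER_ATTENTION

# Each adapted d x d projection gains two small matrices: A (r x d) and B (d x r).
params_per_layer = RANK * D_MODEL + D_MODEL * RANK
trainable = adapted_layers * params_per_layer

# The low-rank update is scaled by alpha / r before being added to the frozen weight.
scaling = ALPHA / RANK

all_params = 1_574_762_240  # "all params" value reported by this notebook
print(adapted_layers, trainable, scaling)
print(round(100 * trainable / all_params, 4))
```

Under these assumptions the count is 384 adapted layers × 81,920 parameters = 31,457,280, about 2% of all parameters, which agrees with the figure the notebook reports.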

## Trainer Callback for Dataset Shuffling

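Defines a custom `ShuffleCallback` that advances the epoch of a streaming (iterable) training dataset at the start of each training epoch, so that its shuffle order changes between epochs. Map-style datasets, such as the ones loaded above without streaming, are already reshuffled each epoch by the Trainer's sampler, so the callback leaves them untouched.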


In [None]:
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset

# trainer callback to reinitialise and reshuffle the datasets at the beginning of each epoch
class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDatasetShard):
            pass  # set_epoch() is handled by the Trainer
        elif isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)

## Training Setup

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper_large_checkpoints",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    remove_unused_columns=False,
    learning_rate=5e-5,
    warmup_steps=500,
    max_steps=1800,
    weight_decay=0.1,
    gradient_checkpointing=True,
    bf16=True,
    eval_strategy="steps",
    predict_with_generate=True,
    generation_max_length=250,
    save_steps=300,
    eval_steps=300,
    logging_steps=25,
    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor,
    callbacks=[ShuffleCallback()],
)

Filter:   0%|          | 0/9075 [00:00<?, ? examples/s]

trainable params: 31,457,280 || all params: 1,574,762,240 || trainable%: 1.9976
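
The training arguments imply a particular effective batch size and learning-rate curve. The sketch below assumes the Trainer's default linear scheduler, since no `lr_scheduler_type` is set, and a single GPU. `linear_schedule_lr` is a hypothetical helper that follows the usual linear warmup and decay rule, using the values from `Seq2SeqTrainingArguments`.

```python
PER_DEVICE_BATCH = 16
GRAD_ACCUM_STEPS = 2
PEAK_LR = 5e-5
WARMUP_STEPS = 500
MAX_STEPS = 1800

def linear_schedule_lr(step, peak_lr=PEAK_LR, warmup_steps=WARMUP_STEPS, max_steps=MAX_STEPS):
    """Linear warmup from 0 to peak_lr, then linear decay to 0 at max_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(0, max_steps - step)
    return peak_lr * remaining / (max_steps - warmup_steps)

# Examples consumed per optimizer step on one GPU.
effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM_STEPS
print(effective_batch)  # 32

for step in (0, 250, 500, 1150, 1800):
    print(step, f"{linear_schedule_lr(step):.2e}")
```

With these settings the learning rate reaches its peak of 5e-5 at step 500 and falls back to half of that by step 1150.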




## Training

In [None]:
trainer.train()




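| Step | Training Loss | Validation Loss | WER (%) | CER (%) |
|------|---------------|-----------------|---------|---------|
| 250  | 0.1776 | 0.190408 | 13.518935 | 6.742842 |
| 500  | 0.1478 | 0.169763 | 12.552428 | 6.377864 |
| 750  | 0.1229 | 0.160847 | 12.327518 | 6.236739 |
| 1000 | 0.1057 | 0.160479 | 12.145158 | 6.257178 |
| 1250 | 0.0864 | 0.163021 | 12.649687 | 6.650381 |
| 1500 | 0.0677 | 0.164349 | 13.227159 | 7.347245 |
| 1750 | 0.0618 | 0.16811 | 12.856361 | 6.856715 |
| 2000 | 0.0533 | 0.16858 | 12.977934 | 7.004652 |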


TrainOutput(global_step=2000, training_loss=0.122501784324646, metrics={'train_runtime': 16748.8367, 'train_samples_per_second': 3.821, 'train_steps_per_second': 0.119, 'total_flos': 1.38508821098496e+20, 'train_loss': 0.122501784324646, 'epoch': 7.143112701252236})
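
With `metric_for_best_model="wer"` and `greater_is_better=False`, the best checkpoint is the one with the lowest validation WER. The sketch below applies that rule to the evaluation log above. It only illustrates the selection criterion and assumes that a checkpoint exists for each evaluated step. `eval_log` is a hand-copied list introduced for this example.

```python
# (step, validation loss, WER %, CER %) copied from the evaluation log above.
eval_log = [
    (250, 0.190408, 13.518935, 6.742842),
    (500, 0.169763, 12.552428, 6.377864),
    (750, 0.160847, 12.327518, 6.236739),
    (1000, 0.160479, 12.145158, 6.257178),
    (1250, 0.163021, 12.649687, 6.650381),
    (1500, 0.164349, 13.227159, 7.347245),
    (1750, 0.168110, 12.856361, 6.856715),
    (2000, 0.168580, 12.977934, 7.004652),
]

best_by_wer = min(eval_log, key=lambda row: row[2])
best_by_cer = min(eval_log, key=lambda row: row[3])
best_by_loss = min(eval_log, key=lambda row: row[1])

print("lowest WER at step", best_by_wer[0], "->", best_by_wer[2])
print("lowest CER at step", best_by_cer[0], "->", best_by_cer[3])
print("lowest validation loss at step", best_by_loss[0])
```

WER and validation loss both bottom out at step 1000, while CER is marginally lower at step 750. After step 1000 the training loss keeps falling but the validation metrics drift upward, which is the usual sign of overfitting.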

## Pushing the Model to the Hugging Face Hub

In [None]:
# Upload the LoRA adapter weights and the processor to the Hugging Face Hub.
model.push_to_hub("Vardis/Whisper-Large-v2-Greek", token="###################")
processor.push_to_hub("Vardis/Whisper-Large-v2-Greek", token="###################")



adapter_model.safetensors:   0%|          | 0.00/126M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Vardis/Whisper-LoRA-Greek/commit/433fef6b27bfcf8af49b17434cc16d2eaf1c9a23', commit_message='Upload processor', commit_description='', oid='433fef6b27bfcf8af49b17434cc16d2eaf1c9a23', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Vardis/Whisper-LoRA-Greek', endpoint='https://huggingface.co', repo_type='model', repo_id='Vardis/Whisper-LoRA-Greek'), pr_revision=None, pr_num=None)