### Preliminary benchmark

In [81]:
import pandas as pd
import os
from transformers import (WhisperProcessor, WhisperForConditionalGeneration,Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    default_data_collator, TrainingArguments, Trainer, pipeline
)
import torchaudio
import torch
from peft import PeftModel, PeftConfig, PeftType
from datasets import Features, Value, Audio, load_dataset, Sequence
import kagglehub
from jiwer import wer,cer
from tqdm import tqdm
import numpy as np
import torch
from transformers.generation.logits_process import LogitsProcessorList, SuppressTokensLogitsProcessor
from transformers import LogitsProcessor
from rapidfuzz import process, fuzz
import requests
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
import evaluate
from typing import Dict, List, Any
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import login, Repository, HfApi
import soundfile as sf
import math
import librosa

In [2]:
dev_manifest = pd.read_csv("piper_out/dev_manifest.csv", sep = "|")
train_manifest = pd.read_csv("piper_out/train_manifest.csv", sep = "|")
test_manifest = pd.read_csv("piper_out/test_manifest.csv", sep = "|")

In [15]:
audio_files = dev_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references = dev_manifest["text"].tolist()

In [3]:
with open("leki_nom.txt", "r", encoding="utf-8") as f:
    leki = f.read().splitlines()

In [5]:
download_dir = "C:/Users/Admin/Downloads/Posts.csv"

path = kagglehub.model_download(
    "msxksm/whisper-medium-medical-pl/transformers/default"
)

In [None]:
api_key = "<together.ai api key>"

In [5]:
checkpoint = "natural_anonym_synth"
SAMPLING_RATE = 16000
language = "pl"
task = "transcribe"
peft_config = PeftConfig.from_pretrained(checkpoint) 
model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, checkpoint)
merged_model = model.merge_and_unload()
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)

In [16]:
def make_dataset(audio_folder, audio_files, references):
    examples = []
    for audio, ref in zip(audio_files, references):
        filename = os.path.join(audio_folder, audio)
        waveform, sr = torchaudio.load(filename)
        if sr != SAMPLING_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
        examples.append({"waveform": waveform, "phrase": ref})
    return examples

In [23]:
class BiasLogitsProcessor(LogitsProcessor):
    def __init__(self, bias_words, processor, boost=5.0):
        self.bias_token_ids = set()
        for word in bias_words:
            ids = processor.tokenizer.encode(word, add_special_tokens=False)
            self.bias_token_ids.update(ids)
        self.boost = boost
    def __call__(self, input_ids, scores):
        for token_id in self.bias_token_ids:
            if token_id < scores.shape[-1]:
                scores[:, token_id] += self.boost
        return scores


In [17]:
def transcribe_example(
    example,
    initial_prompt=None,
    bad_words=None,    
    bias_words = None, boost = 1,
    llm = None, merged_model = model
):
    input_features = processor(
        example["waveform"].squeeze(0),
        sampling_rate=SAMPLING_RATE,
        return_tensors="pt"
    ).input_features

    decoder_input_ids = None
    if initial_prompt:
        prompt_ids = processor.tokenizer(
            initial_prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).input_ids  

        forced_ids = torch.tensor([[tok_id for _, tok_id in forced_decoder_ids]])  

        decoder_input_ids = torch.cat([prompt_ids, forced_ids], dim=1)      


    bad_words_ids = None
    if bad_words and len(bad_words) > 0:
        bad_words_ids = processor.tokenizer(
            bad_words,
            add_special_tokens=False
        ).input_ids  
    
    logits_processor_list = []
    if bias_words and len(bias_words) > 0:
        bias_processor = BiasLogitsProcessor(bias_words, processor, boost=boost)
        logits_processor_list.append(bias_processor)

    with torch.no_grad():
        if decoder_input_ids is not None:
            predicted_ids = merged_model.generate( input_features, decoder_input_ids=decoder_input_ids,
                                                   bad_words_ids = bad_words_ids, num_beams=boost,
            logits_processor=logits_processor_list )[0] 
        else: 
            predicted_ids = merged_model.generate( input_features, forced_decoder_ids=forced_decoder_ids,
                                                   bad_words_ids = bad_words_ids, num_beams=boost,
            logits_processor=logits_processor_list )[0]

    transcription = processor.decode(predicted_ids, skip_special_tokens=True)
    return transcription



In [38]:
def postprocess_transcription(transcription, known_terms, threshold=85):
    words = transcription.split()
    corrected = []
    for w in words:
        match, score, _ = process.extractOne(w, known_terms, scorer=fuzz.ratio)
        if score >= threshold:
            corrected.append(match)
        else:
            corrected.append(w)
    return " ".join(corrected)

In [54]:
def correct_with_llm(transcription, api_key, model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"):
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    prompt = f"""
    Jesteś ekspertem medycznym. Oto transkrypcja mowy:
    ---
    {transcription}
    ---
    Popraw błędy językowe i zamień nieprecyzyjne wyrażenia
    na właściwe terminy medyczne. Nie dodawaj nowych informacji.
    Zwróć tylko poprawioną wersję.
    """

    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 512,
        "temperature": 0.2
    }

    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"].strip()

In [18]:
dataset = make_dataset("dev_noisy", audio_files, references)



In [10]:
audio_files = train_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references_train = train_manifest["text"].tolist()

In [11]:
folder = "train_noisy"
folder_files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
folder_set = set(folder_files)
audio_files = [f for f in audio_files if f in folder_set]

In [55]:
train_dataset = make_dataset("train_noisy", audio_files, references_train)

In [None]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

In [111]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.9536
Subset CER: 0.5920


In [112]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nie spróbować kuracji w upromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-PRO.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptycy metafem, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się nurowenworte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam norowę Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci doldit w osłupę, pobudzę węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węz

In [34]:
folder = "subset"
files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

In [35]:
subset_dataset = make_dataset("subset", files, leki)



In [149]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [10:27<00:00,  8.36s/it]


In [150]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 1.0990


In [151]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostarosnie.
Reference: ActiFolin	Hypothesis: Abstryj folii.
Reference: Aliness Kwas Foliowy	Hypothesis: Alinis quas foillowy.
Reference: Amlodipina	Hypothesis: Auto do dipina.
Reference: Amoksycylina	Hypothesis: Amoż licejny.
Reference: Amotaks	Hypothesis: Amotax.
Reference: ApoD3	Hypothesis: Apoedektyczna opieka w KG.
Reference: Arterios	Hypothesis: Aktelio.
Reference: Augmentin	Hypothesis: Augunantnie.
Reference: Avamina	Hypothesis: Awa minął.
Reference: Biaron D	Hypothesis: Jak odbiór rąde.
Reference: Bibloc	Hypothesis: PiWC.
Reference: Bisakodyl	Hypothesis: Dli zakończonych leków nie zalecanią.
Reference: Bisocard	Hypothesis: I sotoc.
Reference: Bisoprolol	Hypothesis: Nie zauważyłem.
Reference: Bisoratio	Hypothesis: Wii sora dnia.
Reference: Concor	Hypothesis: Ponco".
Reference: Coronal	Hypothesis: Trona.
Reference: Crestor	Hypothesis: Krystolat w kontrolnym ZWK nie obniża się.
Reference: Crosuvo	Hypothesis: Proszę wą krossofool.
Reference: D-Vitum 

### Post-processing model without full fine-tuning

##### setting initial-prompt parameter

In [31]:
initial_prompt = "To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprofen, Metformin."

In [None]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing:   0%|          | 0/146 [00:00<?, ?it/s]Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Transcribing: 100%|██████████| 146/146 [41:27<00:00, 17.04s/it]


In [22]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.4496
Subset CER: 0.1566


In [26]:
for ref, hyp in zip(references[30:], hypotheses[30:]):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Na recepcie zapiszę Hydrochlorotiazyd.	Hypothesis: Na recepcie zapiszę hydrochorotiazyt.
Reference: Powinnaś przyjmować Hydrochlorotiazyd na nadciśnienie.	Hypothesis: Pożynnaś przyjmować hydrochlorotiazytna nadciśnienie.
Reference: Najlepiej w twoim przypadku sprawdzi się Losartan.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się losartem.
Reference: Możesz kupić w aptece Losartan, powinien pomóc.	Hypothesis:  Możesz kupić w aptecjum sartan, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Avamina.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się wamin.
Reference: Powinnaś spróbować kuracji Avaminą.	Hypothesis: Pożynne nasz to jest kreatura, kreaturia waminowa.
Reference: Przepiszę ci Formetic, bo zwykle dobrze działa.	Hypothesis: Przepiszę onciwormic, bo zwykle dobrze działa.
Reference: Możesz kupić w aptece Formetic, powinien pomóc.	Hypothesis:  Możesz kupić w Abtec Informatik, powinien pomóc.
Reference: Na recepcie zapiszę Glucophage XR jako lek

In [35]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [27:56<00:00, 22.35s/it]


In [36]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 2.5681


In [37]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostrosz.
Reference: ActiFolin	Hypothesis: Abstij folii.
Reference: Aliness Kwas Foliowy	Hypothesis: Alinis quas fojolowy.
Reference: Amlodipina	Hypothesis: Analogipna.
Reference: Amoksycylina	Hypothesis: Amocylina.
Reference: Amotaks	Hypothesis: Amo tachs.
Reference: ApoD3	Hypothesis: Apoedet 3.
Reference: Arterios	Hypothesis: Ateria.
Reference: Augmentin	Hypothesis: Paukowa nąciń.
Reference: Avamina	Hypothesis:  Awamina.
Reference: Biaron D	Hypothesis: Jak wiarą de do zespoňu zespoňu zespoňu zespoňem.
Reference: Bibloc	Hypothesis: Pierć.
Reference: Bisakodyl	Hypothesis: Dziesia kodyl.
Reference: Bisocard	Hypothesis: I SOTS op.
Reference: Bisoprolol	Hypothesis: Nie zrozumiem.
Reference: Bisoratio	Hypothesis: Wi sora tie.
Reference: Concor	Hypothesis: Ponco.
Reference: Coronal	Hypothesis: Trona.
Reference: Crestor	Hypothesis: Tak to.
Reference: Crosuvo	Hypothesis: Proszuję do poradni lekarza lekarza lekarza lekarza.
Reference: D-Vitum Forte	Hypothesis: W

In [39]:
initial_prompt = "Rozmowa medyczna. Mogą wystąpić nazwy leków, w tym: Ibuprom, Metformina, Amlodipina, Ramipryl, Omeprazol, Bisoprolol."

In [28]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [46:20<00:00, 19.05s/it] 


In [29]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.4657
Subset CER: 0.2351


In [30]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nespróbować kuracji ibupromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzić się ibuprom.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis:  Możesz kupić wapnę Ticimetaphen, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzić snu rofenforte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam Noropher Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę nidolgit, bo zwykle dobrze działa.
Reference: Na recepcie zapiszę Dolgit jako lek na ból i stan zapalny.	Hypothesis: Na recepcie zapisze on dolegijego leka ból i stan zapalny.
Reference: Na nadciś

In [40]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [31:26<00:00, 25.15s/it]   


In [43]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 1.7799


In [44]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostrosi.
Reference: ActiFolin	Hypothesis: Abstij folii.
Reference: Aliness Kwas Foliowy	Hypothesis: Alina z Kwas Foiowem.
Reference: Amlodipina	Hypothesis: Amlodipina.
Reference: Amoksycylina	Hypothesis: Amoż siedzylina.
Reference: Amotaks	Hypothesis: Amotax.
Reference: ApoD3	Hypothesis: Apoedetry.
Reference: Arterios	Hypothesis: Ateria.
Reference: Augmentin	Hypothesis: Pałka nańczyna.
Reference: Avamina	Hypothesis:  Awaimina.
Reference: Biaron D	Hypothesis: Jedl.
Reference: Bibloc	Hypothesis: Pierdzol.
Reference: Bisakodyl	Hypothesis: Dli zakonyli.
Reference: Bisocard	Hypothesis: I.
Reference: Bisoprolol	Hypothesis: Nie zrozumiał.
Reference: Bisoratio	Hypothesis: Wi sora tie.
Reference: Concor	Hypothesis: Ponco.
Reference: Coronal	Hypothesis: Trona.
Reference: Crestor	Hypothesis: Przepływ do oparzeń dożylnych do zespołu podatku w poradni lewej komory, z powodu oparcia zespołu powierzchni, z powodu oparzenia z oparzeń z oparzeń z oparzeń z oparzeń z opa

In [45]:
print(references)
print(hypotheses)

['Actarosin', 'ActiFolin', 'Aliness Kwas Foliowy', 'Amlodipina', 'Amoksycylina', 'Amotaks', 'ApoD3', 'Arterios', 'Augmentin', 'Avamina', 'Biaron D', 'Bibloc', 'Bisakodyl', 'Bisocard', 'Bisoprolol', 'Bisoratio', 'Concor', 'Coronal', 'Crestor', 'Crosuvo', 'D-Vitum Forte', 'Dapagliflozyna', 'Devikap', 'Dolgit', 'Folik', 'Formetic', 'Glucophage XR', 'Hydrochlorotiazyd', 'Ibuprom', 'Ibuvit D3', 'KFD Omega 3', 'Laktuloza', 'Linagliptyna', 'Liraglutyd', 'Losartan', 'Macromax', 'Melatonina', 'Metafen', 'Metformax', 'Metformina', 'Mirtazapina', 'Nurofen Forte', 'Olicaps Witamina D3', 'Omeprazol', 'Oriovit-D 1000', 'Ospen 1000', 'OstroVit Omega 3', 'Pantoprazol', 'Perindopryl', 'Prestilol', 'Ramipryl', 'Ridlip', 'Romazic', 'Rosucard', 'Rosutrox', 'Rosuvastatin Medical Valley', 'Roswera', 'Sitagliptyna', 'Sobycor', 'Sukralfat', 'Sumamed', 'Suvardio', 'Synjardy', 'Trazodon', 'Vigalex Bio', 'Vigalex Forte', 'Vigalex Max', 'Vigantol', 'Vigantoletten', 'Witamina D3 Forte', 'Xigduo', 'Zahron', 'Zarant

### Using suppressed words 

In [None]:
bad_words = [
    "Awaimina", "Netformaz", "Metaphen", "Mieta zapina", "suwardia", "wigantoletem", "i bufet", "amotax"
    "Nurowendą", "Oliczaw z Vitamina", "Tarzodą", "hydrochorotiazyt", "medformax", "devicap"
    "Witamina D340", "Ostrosi", "Metaphen", "Netformaz", "Nurowendą", "kąco", "corozalne", "rosferę"
]

In [56]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, bad_words=bad_words)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [1:24:00<00:00, 34.52s/it]  


In [57]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.9536
Subset CER: 0.5920


In [58]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nie spróbować kuracji w upromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-PRO.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptycy metafem, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się nurowenworte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam norowę Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci doldit w osłupę, pobudzę węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węz

### Testing soft bias with tokens favorization

In [28]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, bias_words=leki)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [36:13<00:00, 14.89s/it] 


In [29]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.5101
Subset CER: 0.3221


In [30]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Poza tym w Klinikacji Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-Prom.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić wapnę Tafan, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi

In [34]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, bias_words=leki)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [2:30:35<00:00, 120.48s/it]   


In [35]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 4.7263


In [36]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostrożnie.
Reference: ActiFolin	Hypothesis: XG-Foli.
Reference: Aliness Kwas Foliowy	Hypothesis: Alina z KVAS fojjowym.
Reference: Amlodipina	Hypothesis: Auto-drypina.
Reference: Amoksycylina	Hypothesis: Amocylina.
Reference: Amotaks	Hypothesis: Amutax.
Reference: ApoD3	Hypothesis: Apoid 3.
Reference: Arterios	Hypothesis: Ateria.
Reference: Augmentin	Hypothesis: Pałka na 3.
Reference: Avamina	Hypothesis: Awamina.
Reference: Biaron D	Hypothesis: Kieruję do poradni lekarza lekarza leczo-pokrwia.
Reference: Bibloc	Hypothesis: PiWC.
Reference: Bisakodyl	Hypothesis: Diakodyl.
Reference: Bisocard	Hypothesis: I sotoc.
Reference: Bisoprolol	Hypothesis: Nisoprodu.
Reference: Bisoratio	Hypothesis: WIJSR 3.
Reference: Concor	Hypothesis: Pontyk onakowy-vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv vvvv vvvv vvv v v v v v v v v v v v

### Fuzzy Matching

In [45]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    postprocessed_transcription = postprocess_transcription(transcription=transcription, known_terms=leki)
    hypotheses.append(postprocessed_transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [36:17<00:00, 14.91s/it] 


In [47]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.5484
Subset CER: 0.2668


In [48]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nie spróbować kuracji w upromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-PRO.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pobinna spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptycy metafem, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się nurowenworte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam norowę Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci dole git, bo zwykle dobrze działa.
Reference: Na recepcie zapiszę Dolgit jako lek na ból i stan zapalny.	Hypothesis: Na recepcie zapiszę dolegięcie do leka ból i stan zapalny.
Reference: Na nadciś

### LLM correction

In [None]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    corrected = correct_with_llm(transcription=transcription, api_key = api_key)
    hypotheses.append(corrected)
    references.append(ex["phrase"])

In [55]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.8586
Subset CER: 0.7186


In [56]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po spożyciu alkoholu nie należy próbować kuracji odwykowej.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: W Pana przypadku najskuteczniejsze będzie ibuprofenum.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pobinna spróbować kuracji metafenem. -> Powinna spróbować kuracji metamizolem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptece metforminę, powinna pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się niesteroidowy lek przeciwzapalny.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam doustnie Fortem, ponieważ dobrze działa na nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci diklofenak w żelu, bo zwykle dobrze działa.
Reference: Na recepcie zapiszę Dolgit jako lek na ból

## Wnioski:
- najlepiej WER i CER zdaje się obniżać parametr initial_prompt

## Fine tuning

In [124]:
checkpoint = "natural_anonym_synth"
SAMPLING_RATE = 16000
language = "pl"
task = "transcribe"
OUTPUT_DIR = "./whisper_medical_finetuned_lora"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_CHECKPOINT = checkpoint

LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "v_proj"] 

NUM_EPOCHS = 3
PER_DEVICE_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 5e-4
FP16 = True

MED_WEIGHT = 5.0

In [75]:
peft_config = PeftConfig.from_pretrained(checkpoint) 
model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, checkpoint)

In [76]:
class WhisperTuner(PeftModel):
    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        super().__init__(model, peft_config, adapter_name)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None,
                decoder_inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None,
                return_dict=None, task_ids=None, **kwargs):
        allowed_args = {
            "input_features", "attention_mask", "decoder_input_ids", "decoder_attention_mask",
            "decoder_inputs_embeds", "labels", "output_attentions", "output_hidden_states", "return_dict"
        }

        model_args = {}
        for arg_name in allowed_args:
            arg_val = locals().get(arg_name)
            if arg_val is not None:
                model_args[arg_name] = arg_val

        for k, v in kwargs.items():
            if k in allowed_args and k not in model_args:
                model_args[k] = v

        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if peft_config.peft_type == PeftType.POLY:
                model_args["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(**model_args)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if "input_ids" in model_kwargs and "input_features" not in model_kwargs:
            model_kwargs["input_features"] = model_kwargs.pop("input_ids")
        return model_kwargs

In [77]:
lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=TARGET_MODULES,
        lora_dropout=LORA_DROPOUT,
        bias="none"
    )
model = WhisperTuner(model, lora_config)

In [None]:
model.to(DEVICE)
print("Model loaded. Trainable parameters:")
try:
    model.print_trainable_parameters()
except Exception:
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {n_trainable} / {n_total}")
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)

forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task, initial_prompt = initial_prompt)
model.config.forced_decoder_ids = forced_decoder_ids

Model loaded. Trainable parameters:
trainable params: 2,359,296 || all params: 766,217,216 || trainable%: 0.3079


In [79]:
audio_files = dev_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references = dev_manifest["text"].tolist()

In [None]:
data = []
audio_folder = "dev_noisy"

for audio_name, ref in zip(audio_files, references):
    filename = os.path.join(audio_folder, audio_name)
    waveform, sr = torchaudio.load(filename)

    if sr != SAMPLING_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    waveform = waveform.squeeze(0)

    inputs = processor(waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    input_features = inputs.input_features[0].numpy()

    text_tokens = processor.tokenizer(ref, add_special_tokens=True).input_ids
    labels = prompt_tokens + text_tokens

    labels = processor.tokenizer(ref, add_special_tokens=True).input_ids
 
    data.append({
        "input_features": input_features, 
        "labels": labels 
    })
df = pd.DataFrame(data)




LibsndfileError: Error opening 'dev_noisy\\Ibuprom_0001_female.wav': System error.

In [None]:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    if input_ids.dim() != 2:
        raise ValueError("shift_tokens_right expects a 2D tensor")
    shifted = input_ids.new_full(input_ids.shape, pad_token_id)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted[:, :len(prompt_tokens)] = input_ids[:, :len(prompt_tokens)]
    shifted[input_ids == -100] = pad_token_id
    return shifted

def collate_fn(batch):
    feats = []
    prompt_length = len(prompt_tokens)
    for i, x in enumerate(batch):
        t = x["input_features"]
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=torch.float32)
        else:
            t = t.float()

        if t.ndim == 1:
            raise ValueError(f"input_features item {i} is 1D; expected 2D.")

        if t.shape[-1] == 80:
            seq = t
        elif t.shape[0] == 80:
            seq = t.transpose(0, 1)
        else:
            raise ValueError(f"input_features item {i} unexpected shape {t.shape}; expected feat dim 80.")
        feats.append(seq)

    padded = pad_sequence(feats, batch_first=True) 
    input_features = padded.transpose(1, 2) 

    labels_list = []
    for x in batch:
        lab = x["labels"]
        lab_t = torch.tensor(lab, dtype=torch.long) if not torch.is_tensor(lab) else lab.long()
        labels_list.append(lab_t)

    pad_id = processor.tokenizer.pad_token_id
    labels_padded = pad_sequence(labels_list, batch_first=True, padding_value=pad_id) 
    labels_for_loss = labels_padded.clone()
    labels_for_loss[labels_for_loss == pad_id] = -100

    bos_id = processor.tokenizer.bos_token_id if processor.tokenizer.bos_token_id is not None else processor.tokenizer.cls_token_id
    decoder_input_ids = []
    for labels_tensor in labels_list:
        decoder_seq = torch.cat([
            torch.tensor(prompt_tokens, dtype=torch.long),
            labels_tensor[:-1]
        ])
        decoder_input_ids.append(decoder_seq)

    decoder_input_ids = pad_sequence(decoder_input_ids, batch_first=True, padding_value=pad_id)
    decoder_input_ids = shift_tokens_right(labels_padded, pad_token_id=pad_id, decoder_start_token_id=bos_id)

    return {
        "input_features": input_features,    
        "labels": labels_for_loss,            
        "decoder_input_ids": decoder_input_ids
    }


In [82]:
# obiekt dataset z huggingface nie chciał przyjąć input_features
class FeatureDataset(Dataset):
    def __init__(self, df):
        self.feats = []
        for arr in df["input_features"].tolist():
            a = np.asarray(arr, dtype=np.float32)
            if a.ndim == 1:
                raise ValueError("Expected 2D input_features, got 1D.")
            self.feats.append(torch.from_numpy(a))
        self.labels = df["labels"].tolist()

    def __len__(self):
        return len(self.feats)

    def __getitem__(self, idx):
        labels = self.labels[idx]
        return {
            "input_features": self.feats[idx],   
            "labels": labels
        }

val_ds = FeatureDataset(df)
val_loader = DataLoader(val_ds, batch_size=PER_DEVICE_BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0, pin_memory=True)

In [83]:
audio_files = train_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references_train = train_manifest["text"].tolist()
folder = "train_noisy"
folder_files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
folder_set = set(folder_files)
audio_files = [f for f in audio_files if f in folder_set]

In [84]:
data = []
audio_folder = "train_noisy"
for audio_name, ref in zip(audio_files, references_train):
    filename = os.path.join(audio_folder, audio_name)
    waveform, sr = torchaudio.load(filename)

    if sr != SAMPLING_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    waveform = waveform.squeeze(0) 

    inputs = processor(waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    input_features = inputs.input_features[0].numpy()

    labels = processor.tokenizer(ref, add_special_tokens=True).input_ids
 
    data.append({
        "input_features": input_features, 
        "labels": labels 
    })

train_df = pd.DataFrame(data)


In [85]:
train_ds = FeatureDataset(train_df)
train_loader = DataLoader(train_ds, batch_size=PER_DEVICE_BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn, num_workers=0, pin_memory=True)

In [86]:
vocab_size = len(processor.tokenizer)
token_weights = torch.ones(vocab_size, dtype=torch.float32, device=DEVICE)
for word in leki:
    ids = processor.tokenizer.encode(word, add_special_tokens=False)
    for t in ids:
        if 0 <= t < vocab_size:
            token_weights[t] = MED_WEIGHT

In [87]:
wer_metric = evaluate.load("wer")

In [None]:
def compute_metrics_eval(preds: List[str], refs: List[str]) -> Dict[str, float]:
    wer = wer_metric.compute(references=refs, predictions=preds)
    tp = fp = fn = 0
    meds_lower = [m.lower() for m in leki]
    for hyp, ref in zip(preds, refs):
        hyp_tokens = set([t.lower().strip(".,") for t in hyp.split()])
        ref_tokens = set([t.lower().strip(".,") for t in ref.split()])
        for m in meds_lower:
            in_ref = m in ref_tokens
            in_hyp = m in hyp_tokens
            if in_ref and in_hyp:
                tp += 1
            elif in_hyp and not in_ref:
                fp += 1
            elif in_ref and not in_hyp:
                fn += 1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return {"wer": wer, "precision_drugs": precision, "recall_drugs": recall}

In [89]:
def compute_metrics_for_trainer(eval_pred):
    preds_ids, labels_ids = eval_pred
    decoded_preds = processor.batch_decode(preds_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    decoded_labels = processor.batch_decode(labels_ids, skip_special_tokens=True)
    return compute_metrics_eval(decoded_preds, decoded_labels)

In [None]:
class WeightedLossTrainer(Seq2SeqTrainer):
    def __init__(self, token_weights_tensor: torch.Tensor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_weights = token_weights_tensor

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        decoder_input_ids = inputs.pop("decoder_input_ids", None)
        outputs = model(input_features=inputs["input_features"], decoder_input_ids=decoder_input_ids, forced_decoder_ids = forced_decoder_ids)
        logits = outputs.logits

        batch_size, seq_len, vocab = logits.size()
        logits_flat = logits.view(-1, vocab)
        labels_flat = labels.view(-1)
        loss_per_token = F.cross_entropy(logits_flat, labels_flat, reduction="none", ignore_index=-100)
        loss_per_token = loss_per_token.view(batch_size, seq_len)

        labels_masked = labels.clone()
        labels_masked[labels_masked == -100] = 0
        weights = self.token_weights[labels_masked]
        mask = (labels != -100).float()
        weights = weights * mask

        weighted_loss = (loss_per_token * weights).sum()
        normalizer = weights.sum().clamp_min(1.0)
        loss = weighted_loss / normalizer
        return (loss, outputs) if return_outputs else loss
    def get_train_dataloader(self):
        return train_loader

    def get_eval_dataloader(self, eval_dataset=None):
        return val_loader

### First entry arguments

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    do_train=True,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    fp16=FP16,
    num_train_epochs=1, 
    learning_rate=LEARNING_RATE,
    predict_with_generate=True, 
    save_strategy="steps",
    save_steps=4,
    save_total_limit=3, 
    logging_steps=1, 
    generation_max_length=64, 
    eval_strategy="steps",
    eval_steps=2, 
    logging_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    remove_unused_columns=False,
    max_steps = 8, 
    report_to = "none"
)
trainer = WeightedLossTrainer(
    token_weights_tensor=token_weights,
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor.tokenizer,
    compute_metrics=compute_metrics_for_trainer,
    data_collator=collate_fn
)


  super().__init__(*args, **kwargs)


In [92]:
print("Starting training...")
trainer.train()
print("Training finished. Saving model...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

Starting training...


Step,Training Loss,Validation Loss,Wer,Precision Drugs,Recall Drugs
2,49.8138,3.029984,0.362903,1.0,0.21978
4,40.0702,2.997135,0.360887,1.0,0.21978
6,38.5398,2.968384,0.361895,1.0,0.21978


Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


KeyboardInterrupt: 

### Final training:
- increased eval_steps, loggins_steps for faster training
- increased med_weight because of high precision and low recall
- model after postprocessing will be used (initial prompt)
- increase learning rate to make learning faster
- increase lora to better generalize on medicines



In [5]:
checkpoint = "natural_anonym_synth"
SAMPLING_RATE = 16000
language = "pl"
task = "transcribe"
OUTPUT_DIR = "./whisper_medical_drugs"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_CHECKPOINT = checkpoint

LORA_R = 10
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "v_proj"] 

NUM_EPOCHS = 3
PER_DEVICE_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 9e-4
FP16 = True

MED_WEIGHT = 10.0
initial_prompt = "To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprom."

In [6]:
peft_config = PeftConfig.from_pretrained(checkpoint) 
model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, checkpoint)

In [7]:
class WhisperTuner(PeftModel):
    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        super().__init__(model, peft_config, adapter_name)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None,
                decoder_inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None,
                return_dict=None, task_ids=None, **kwargs):
        allowed_args = {
            "input_features", "attention_mask", "decoder_input_ids", "decoder_attention_mask",
            "decoder_inputs_embeds", "labels", "output_attentions", "output_hidden_states", "return_dict"
        }

        model_args = {}
        for arg_name in allowed_args:
            arg_val = locals().get(arg_name)
            if arg_val is not None:
                model_args[arg_name] = arg_val
                
        for k, v in kwargs.items():
            if k in allowed_args and k not in model_args:
                model_args[k] = v

        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if peft_config.peft_type == PeftType.POLY:
                model_args["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(**model_args)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if "input_ids" in model_kwargs and "input_features" not in model_kwargs:
            model_kwargs["input_features"] = model_kwargs.pop("input_ids")
        if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids:
            model_kwargs["forced_decoder_ids"] = self.config.forced_decoder_ids
        
        return model_kwargs

In [8]:
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none"
)

In [9]:
model = WhisperTuner(model, lora_config)
model.to(DEVICE)

model.processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)

print("Model loaded. Trainable parameters:")
try:
    model.print_trainable_parameters()
except Exception:
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {n_trainable} / {n_total}")

processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
model.config.forced_decoder_ids = forced_decoder_ids



Model loaded. Trainable parameters:
trainable params: 2,949,120 || all params: 766,807,040 || trainable%: 0.3846


In [10]:
audio_files = dev_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references = dev_manifest["text"].tolist()
data = []
audio_folder = "dev_noisy"

prompt_tokens = processor.tokenizer(initial_prompt, add_special_tokens=False).input_ids
prompt_length = len(prompt_tokens)

In [11]:
for audio_name, ref in zip(audio_files, references):
    filename = os.path.join(audio_folder, audio_name)
    waveform, sr = torchaudio.load(filename)

    if sr != SAMPLING_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    waveform = waveform.squeeze(0)

    inputs = processor(waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    input_features = inputs.input_features[0].numpy()

    text_tokens = processor.tokenizer(ref, add_special_tokens=True).input_ids
    labels = prompt_tokens + text_tokens 

    data.append({
        "input_features": input_features, 
        "labels": labels 
    })

df = pd.DataFrame(data)



In [12]:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    if input_ids.dim() != 2:
        raise ValueError("shift_tokens_right expects a 2D tensor")
    shifted = input_ids.new_full(input_ids.shape, pad_token_id)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted[input_ids == -100] = pad_token_id
    return shifted

def collate_fn(batch):
    feats = []
    for i, x in enumerate(batch):
        t = x["input_features"]
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=torch.float32)
        else:
            t = t.float()

        if t.ndim == 1:
            raise ValueError(f"input_features item {i} is 1D; expected 2D.")

        if t.shape[-1] == 80:
            seq = t
        elif t.shape[0] == 80:
            seq = t.transpose(0, 1)
        else:
            raise ValueError(f"input_features item {i} unexpected shape {t.shape}; expected feat dim 80.")
        feats.append(seq)

    padded = pad_sequence(feats, batch_first=True) 
    input_features = padded.transpose(1, 2) 

    labels_list = []
    for x in batch:
        lab = x["labels"]
        lab_t = torch.tensor(lab, dtype=torch.long) if not torch.is_tensor(lab) else lab.long()
        labels_list.append(lab_t)

    pad_id = processor.tokenizer.pad_token_id
    labels_padded = pad_sequence(labels_list, batch_first=True, padding_value=pad_id) 
    labels_for_loss = labels_padded.clone()
    labels_for_loss[labels_for_loss == pad_id] = -100

    bos_id = processor.tokenizer.bos_token_id if processor.tokenizer.bos_token_id is not None else processor.tokenizer.cls_token_id
    decoder_input_ids = shift_tokens_right(labels_padded, pad_token_id=pad_id, decoder_start_token_id=bos_id)

    return {
        "input_features": input_features,    
        "labels": labels_for_loss,            
        "decoder_input_ids": decoder_input_ids
    }


In [13]:
class FeatureDataset(Dataset):
    def __init__(self, df):
        self.feats = []
        for arr in df["input_features"].tolist():
            a = np.asarray(arr, dtype=np.float32)
            if a.ndim == 1:
                raise ValueError("Expected 2D input_features, got 1D.")
            self.feats.append(torch.from_numpy(a))
        self.labels = df["labels"].tolist()

    def __len__(self):
        return len(self.feats)

    def __getitem__(self, idx):
        labels = self.labels[idx]
        return {
            "input_features": self.feats[idx],   
            "labels": labels
        }

val_ds = FeatureDataset(df)
val_loader = DataLoader(val_ds, batch_size=PER_DEVICE_BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0, pin_memory=True)

In [14]:
audio_files = train_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references_train = train_manifest["text"].tolist()
folder = "train_noisy"
folder_files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
folder_set = set(folder_files)
audio_files = [f for f in audio_files if f in folder_set]
data = []
audio_folder = "train_noisy"


In [None]:
for audio_name, ref in zip(audio_files, references_train):
    filename = os.path.join(audio_folder, audio_name)
    waveform, sr = torchaudio.load(filename)

    if sr != SAMPLING_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    waveform = waveform.squeeze(0) 

    inputs = processor(waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    input_features = inputs.input_features[0].numpy()

    text_tokens = processor.tokenizer(ref, add_special_tokens=True).input_ids
    labels = prompt_tokens + text_tokens 

    data.append({
        "input_features": input_features, 
        "labels": labels 
    })

train_df = pd.DataFrame(data)
train_ds = FeatureDataset(train_df)
train_loader = DataLoader(train_ds, batch_size=PER_DEVICE_BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn, num_workers=0, pin_memory=True)

In [16]:
vocab_size = len(processor.tokenizer)
token_weights = torch.ones(vocab_size, dtype=torch.float32, device=DEVICE)
for word in leki:
    ids = processor.tokenizer.encode(word, add_special_tokens=False)
    for t in ids:
        if 0 <= t < vocab_size:
            token_weights[t] = MED_WEIGHT

In [83]:
wer_metric = evaluate.load("wer")

Downloading builder script: 5.13kB [00:00, 5.33MB/s]


In [18]:
def compute_metrics_eval(preds: List[str], refs: List[str]) -> Dict[str, float]:
    preds_clean = [pred.replace(initial_prompt, "").strip() for pred in preds]
    refs_clean = [ref.replace(initial_prompt, "").strip() for ref in refs]
    
    wer = wer_metric.compute(references=refs_clean, predictions=preds_clean)
    tp = fp = fn = 0
    meds_lower = [m.lower() for m in leki]
    for hyp, ref in zip(preds_clean, refs_clean):
        hyp_tokens = set([t.lower().strip(".,") for t in hyp.split()])
        ref_tokens = set([t.lower().strip(".,") for t in ref.split()])
        for m in meds_lower:
            in_ref = m in ref_tokens
            in_hyp = m in hyp_tokens
            if in_ref and in_hyp:
                tp += 1
            elif in_hyp and not in_ref:
                fp += 1
            elif in_ref and not in_hyp:
                fn += 1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return {"wer": wer, "precision_drugs": precision, "recall_drugs": recall}

In [20]:
def compute_metrics_for_trainer(eval_pred):
    preds_ids, labels_ids = eval_pred
    decoded_preds = processor.batch_decode(preds_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    decoded_labels = processor.batch_decode(labels_ids, skip_special_tokens=True)
    return compute_metrics_eval(decoded_preds, decoded_labels)

In [21]:
class WeightedLossTrainer(Seq2SeqTrainer):
    def __init__(self, token_weights_tensor: torch.Tensor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_weights = token_weights_tensor

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        decoder_input_ids = inputs.pop("decoder_input_ids", None)
        outputs = model(input_features=inputs["input_features"], decoder_input_ids=decoder_input_ids)
        logits = outputs.logits

        batch_size, seq_len, vocab = logits.size()
        logits_flat = logits.view(-1, vocab)
        labels_flat = labels.view(-1)
        loss_per_token = F.cross_entropy(logits_flat, labels_flat, reduction="none", ignore_index=-100)
        loss_per_token = loss_per_token.view(batch_size, seq_len)

        labels_masked = labels.clone()
        labels_masked[labels_masked == -100] = 0
        weights = self.token_weights[labels_masked]
        mask = (labels != -100).float()
        weights = weights * mask

        weighted_loss = (loss_per_token * weights).sum()
        normalizer = weights.sum().clamp_min(1.0)
        loss = weighted_loss / normalizer
        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        return train_loader

    def get_eval_dataloader(self, eval_dataset=None):
        return val_loader

In [22]:
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    do_train=True,
    do_eval = True,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    num_train_epochs=1, 
    learning_rate=LEARNING_RATE,
    predict_with_generate=True, 
    save_strategy="steps",
    save_steps=20,
    save_total_limit=2, 
    generation_max_length=64 + prompt_length, 
    eval_strategy="steps",
    eval_steps=20,  
    logging_strategy="steps",
    logging_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    remove_unused_columns=False,
    max_steps = -1, 
    report_to = "none"
)

In [23]:
trainer = WeightedLossTrainer(
    token_weights_tensor=token_weights,
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor.tokenizer,
    compute_metrics=compute_metrics_for_trainer,
    data_collator=collate_fn
)

  super().__init__(*args, **kwargs)


In [24]:
print("Starting training...")
trainer.train()
print("Training finished. Saving model...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

Starting training...




Step,Training Loss,Validation Loss,Wer,Precision Drugs,Recall Drugs
20,12.1352,0.932314,1.131048,1.0,0.175824
40,6.8104,0.4259,1.715726,0.471698,0.274725
60,4.4499,0.402264,3.173387,0.45,0.197802


Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Training finished. Saving model...


[]

### Merging model with base

In [6]:
BASE_CHECKPOINT = "./natural_anonym_synth"   
LORA_DIR = "./whisper_medical_drugs"        
MERGED_DIR = "./whisper_merged" 
language = "pl"
task = "transcribe"
SAMPLING_RATE = 16000

In [22]:
peft_config = PeftConfig.from_pretrained(BASE_CHECKPOINT) 
base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)

processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)

In [23]:
model_with_peft = PeftModel.from_pretrained(base_model, LORA_DIR)
merged_model = model_with_peft.merge_and_unload()
merged_model.to("cpu")



WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1024, 1024, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1024)
      (layers): ModuleList(
        (0-23): 24 x WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=Tru

In [24]:
os.makedirs(MERGED_DIR, exist_ok=True)
merged_model.save_pretrained(MERGED_DIR)
processor.save_pretrained(MERGED_DIR)



[]

In [7]:
model = WhisperForConditionalGeneration.from_pretrained(MERGED_DIR)
processor = WhisperProcessor.from_pretrained(MERGED_DIR, language=language, task=task)
model.eval()

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1024, 1024, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1024)
      (layers): ModuleList(
        (0-23): 24 x WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=Tru

In [40]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt = "To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprom.", merged_model=merged_model)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing:   0%|          | 0/146 [00:00<?, ?it/s]

Transcribing: 100%|██████████| 146/146 [40:02<00:00, 16.45s/it]


In [87]:
def compute_metrics_eval(preds: List[str], refs: List[str], leki: List[str], initial_prompt: str) -> Dict[str, float]:
    preds_clean = [pred.replace(initial_prompt, "").strip() for pred in preds]
    refs_clean = [ref.replace(initial_prompt, "").strip() for ref in refs]

    tp = fp = fn = 0
    meds_lower = [m.lower() for m in leki]

    for hyp, ref in zip(preds_clean, refs_clean):
        hyp_tokens = set([t.lower().strip(".,") for t in hyp.split()])
        ref_tokens = set([t.lower().strip(".,") for t in ref.split()])

        for m in meds_lower:
            in_ref = m in ref_tokens
            in_hyp = m in hyp_tokens
            if in_ref and in_hyp:
                tp += 1
            elif in_hyp and not in_ref:
                fp += 1
            elif in_ref and not in_hyp:
                fn += 1

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    return {"precision_drugs": precision, "recall_drugs": recall}

In [44]:
metrics = compute_metrics_eval(hypotheses, references, leki, initial_prompt="To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprom.")
print(metrics)

{'wer': 0.34274193548387094, 'precision_drugs': 0.96, 'recall_drugs': 0.26373626373626374}


In [47]:
for ref, hyp in zip(references[60:], hypotheses[60:]):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Jeśli masz wysoki cholesterol dobrze sprawdza się Crestor.	Hypothesis:  Jeśli masz wysoki cholesterol dobrze sprawdza się Crestor.
Reference: Przepiszę ci Crestor, bo zwykle dobrze działa.	Hypothesis:  Przepiszecie Crestor, bo zwykle dobrze działa.
Reference: Najlepiej w twoim przypadku sprawdzi się Crosuvo.	Hypothesis:  Najlepiej w twoim przypadku sprawdzi się Crosuvo.
Reference: Możesz kupić w aptece Crosuvo, powinien pomóc.	Hypothesis:  Możesz kupić w aptece Crossowo, powinien pomóc.
Reference: Na wysoki cholesterol przepiszę ci Ridlip.	Hypothesis:  Na wysoki cholesterol przepisze Ciliglip.
Reference: Przepiszę ci Ridlip, bo zwykle dobrze działa.	Hypothesis:  Przepiszę Ci Lidl, bo zwykle dobrze działa.
Reference: Na wysoki cholesterol przepiszę ci Romazic.	Hypothesis:  Na wysoki cholesterol przepiszę Ciromazic.
Reference: Powinnaś spróbować kuracji Romazicem.	Hypothesis:  Powinna spróbować kuracji romazicem.
Reference: Na recepcie zapiszę Rosucard.	Hypothesis:  Na recepci

### Deployment

In [None]:
token = "<hugging face token>"
login(token)

In [53]:
api = HfApi()
api.create_repo("medical-polish-drugs-whisper", exist_ok=True) 

RepoUrl('https://huggingface.co/wysokAIczad/medical-polish-drugs-whisper', endpoint='https://huggingface.co', repo_type='model', repo_id='wysokAIczad/medical-polish-drugs-whisper')

In [55]:
merged_model.push_to_hub("wysokAIczad/medical-polish-drugs-whisper")
processor.push_to_hub("wysokAIczad/medical-polish-drugs-whisper")

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            
[A
Processing Files (0 / 1)                :   0%|          | 99.6kB / 3.06GB, 55.3kB/s  
[A
[A
[A
[A
[A
[A
Processing Files (0 / 1)                :   2%|▏         | 50.5MB / 3.06GB, 15.8MB/s  
Processing Files (0 / 1)                :   3%|▎         |  101MB / 3.06GB, 29.7MB/s  
Processing Files (0 / 1)                :   5%|▍         |  151MB / 3.06GB, 42.1MB/s  
Processing Files (0 / 1)                :   6%|▌         |  185MB / 3.06GB, 48.8MB/s  
Processing Files (0 / 1)                :   7%|▋         |  218MB / 3.06GB, 54.7MB/s  
Processing Files (0 / 1)                :   8%|▊         |  251MB / 3.06GB, 59.9MB/s  
Processing Files (0 / 1)                :  10%|▉         |  293MB / 3.06GB, 66.6MB/s  
Processing Files (0 / 1)                :  11%|█         |  327MB / 3.06GB, 71.2MB/s  
Processing Files (0 / 1)                :  12%|█▏        |  368MB / 3.06GB, 77.0MB/s  
Processing Files (0

CommitInfo(commit_url='https://huggingface.co/wysokAIczad/medical-polish-drugs-whisper/commit/50e6957f4bf120516cc4f7fc4964cec0be3f96c4', commit_message='Upload processor', commit_description='', oid='50e6957f4bf120516cc4f7fc4964cec0be3f96c4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/wysokAIczad/medical-polish-drugs-whisper', endpoint='https://huggingface.co', repo_type='model', repo_id='wysokAIczad/medical-polish-drugs-whisper'), pr_revision=None, pr_num=None)

### Whole pipeline for usage

In [75]:
processor = WhisperProcessor.from_pretrained("pwysoc/medical-polish-drugs-whisper")
model = WhisperForConditionalGeneration.from_pretrained("pwysoc/medical-polish-drugs-whisper")

In [70]:
waveform, sr = torchaudio.load("file_with_medical_name.wav")
SAMPLING_RATE = 16000
if sr != 16000:
    waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)
if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)
input_features = processor(
        waveform.squeeze(0),
        sampling_rate=SAMPLING_RATE,
        return_tensors="pt"
    ).input_features

decoder_input_ids = None
initial_prompt="To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprom."
prompt_ids = processor.tokenizer(
    initial_prompt,
    add_special_tokens=False,
    return_tensors="pt"
).input_ids  

forced_ids = torch.tensor([[tok_id for _, tok_id in forced_decoder_ids]])  

decoder_input_ids = torch.cat([prompt_ids, forced_ids], dim=1)   

predicted_ids = model.generate( input_features, decoder_input_ids=decoder_input_ids)[0] 

transcription = processor.decode(predicted_ids, skip_special_tokens=True)

print(transcription)



 Najlepiej w twoim przypadku sprawdzi się Acifolin.


### Evaluation on test set
- wer
- precision and recall on drugs
- confidence per word
- manual analysis of common mistakes
- data recorded by real human

In [73]:
audio_files = test_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references = test_manifest["text"].tolist()

In [74]:
test_dataset = make_dataset("test_noisy", audio_files, references)



In [77]:
hypotheses = []
references = []

for ex in tqdm(test_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt = "To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprom.", merged_model=model)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [43:01<00:00, 17.68s/it]


In [88]:
wer_score = wer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")

Subset WER: 0.3547


In [89]:
metrics = compute_metrics_eval(hypotheses, references, leki, initial_prompt="To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprom.")
print(metrics)

{'precision_drugs': 1.0, 'recall_drugs': 0.23255813953488372}


In [90]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Na recepcie zapiszę Ibuprom jako lek na ból i stan zapalny.	Hypothesis:  Na recepcie zapiszą Ibuprom jako lek na ból i stan zapalny.
Reference: Powinnaś przyjmować Ibuprom na ból i stan zapalny.	Hypothesis:  Powinnaś przyjmować imion na ból i stan zapalny.
Reference: Powinnaś przyjmować Metafen na ból i stan zapalny.	Hypothesis:  Powinnaś przyjmować metafen na ból i stan zapalny.
Reference: Na recepcie zapiszę Metafen jako lek na ból i stan zapalny.	Hypothesis:  Na recepcie zapisze metafen jako lek na ból i stan zapalny.
Reference: Przepiszę ci Nurofen Forte, bo zwykle dobrze działa.	Hypothesis:  Przepisze Ci nur off and forte, bo zwykle dobrze działa.
Reference: Powinnaś spróbować kuracji Nurofenem Forte.	Hypothesis:  Powinna spróbować kuracji nurofenem forte.
Reference: Możesz kupić w aptece Dolgit, powinien pomóc.	Hypothesis:  Możesz kupić w aptece Dolgit, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Dolgit.	Hypothesis:  Najlepiej w twoim przypadku 

In [91]:
def preprocess_audio(audio, model = model):
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    input_features = []
    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=30 * 16000
    )
    input_features.append(inputs.input_features.to(device))

    return input_features

In [92]:
def greedy_decode_with_logprobs(input_feature, max_tokens=448):
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    
    encoder_outputs = model.get_encoder()(input_feature)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], device=device)

    tokens = []
    logps = []

    for _ in range(max_tokens):
        outputs = model(
            input_features=None,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            return_dict=True
        )
        logits = outputs.logits[:, -1, :]              
        log_probs = F.log_softmax(logits, dim=-1)         

        next_id = torch.argmax(log_probs, dim=-1)           # greedy token
        lp = log_probs[0, next_id].item()                   # log-probability wybranego

        tokens.append(next_id.item())
        logps.append(lp)

        decoder_input_ids = torch.cat([decoder_input_ids, next_id.unsqueeze(0)], dim=-1)

        if next_id.item() == processor.tokenizer.eos_token_id:
            break

    return tokens, logps

In [93]:
def tokens_to_word_confidences(tokens, logps):
    special_ids = set(processor.tokenizer.all_special_ids)
    filtered = [(t, lp) for t, lp in zip(tokens, logps) if t not in special_ids]
    if not filtered:
        return []
    filt_tokens, filt_logps = zip(*filtered)

    tok_strs = processor.tokenizer.convert_ids_to_tokens(list(filt_tokens))

    words = []
    confidences = []
    curr_tokens = []
    curr_logps = []

    for tok, lp, tid in zip(tok_strs, filt_logps, filt_tokens):
        
        is_new_word = tok.startswith('Ġ')

        if is_new_word and curr_tokens:

            word_text = processor.tokenizer.decode(curr_tokens, skip_special_tokens=True).strip()
            avg_lp = sum(curr_logps) / len(curr_logps)
            words.append(word_text)
            confidences.append(math.exp(avg_lp))
            curr_tokens = []
            curr_logps = []

        curr_tokens.append(tid)
        curr_logps.append(lp)

    if curr_tokens:
        word_text = processor.tokenizer.decode(curr_tokens, skip_special_tokens=True).strip()
        avg_lp = sum(curr_logps) / len(curr_logps)
        words.append(word_text)
        confidences.append(math.exp(avg_lp))

    return list(zip(words, confidences))

In [None]:
THRESHOLD = 0.7
def transcribe_with_confidence(audio_path, output_file="word_confidence.txt", threshold=THRESHOLD):
    feats_list = preprocess_audio(audio_path)
    
    with open(output_file, "a", encoding="utf-8") as f:
        
        for idx, feat in enumerate(feats_list, start=1):
            tokens, logps = greedy_decode_with_logprobs(feat)
            pairs = tokens_to_word_confidences(tokens, logps)

            full_transcript = " ".join(word for word, _ in pairs)
            f.write(f"Transkrypcja pełna: {full_transcript}\n")

            low_conf_words = [(word, conf) for word, conf in pairs if conf < threshold]
            if low_conf_words:
                for word, conf in low_conf_words:
                    f.write(f"{word}: confidence={conf:.2f} ")
                f.write("\n")


In [105]:
def load_audio_file(audio_path, sr=16000):
    if audio_path == None:
        return None
    audio, _ = librosa.load(audio_path, sr=sr)
    return audio

In [106]:
folder_wav = "test_noisy"
files = sorted(os.listdir(folder_wav))

for path in files[::2]:
    audio = load_audio_file(os.path.join(folder_wav, path))
    transcribe_with_confidence(audio)