### Preliminary benchmark

In [1]:
import pandas as pd
import os
from transformers import (WhisperProcessor, WhisperForConditionalGeneration,Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    default_data_collator, TrainingArguments, Trainer
)
import torchaudio
import torch
from peft import PeftModel, PeftConfig, PeftType
from datasets import Features, Value, Audio, load_dataset, Dataset
import kagglehub
from jiwer import wer,cer
from tqdm import tqdm
import numpy as np
import torch
from transformers.generation.logits_process import LogitsProcessorList, SuppressTokensLogitsProcessor
from transformers import LogitsProcessor
from rapidfuzz import process, fuzz
import requests
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
import evaluate
from typing import Dict, List, Any
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torch.nn.utils.rnn import pad_sequence

  from .autonotebook import tqdm as notebook_tqdm


In [49]:
dev_manifest = pd.read_csv("piper_out/dev_manifest.csv", sep = "|")
train_manifest = pd.read_csv("piper_out/train_manifest.csv", sep = "|")
test_manifest = pd.read_csv("piper_out/test_manifest.csv", sep = "|")

In [50]:
audio_files = dev_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references = dev_manifest["text"].tolist()

In [51]:
with open("leki_nom.txt", "r", encoding="utf-8") as f:
    leki = f.read().splitlines()

In [10]:
download_dir = "C:/Users/Admin/Downloads/Posts.csv"

path = kagglehub.model_download(
    "msxksm/whisper-medium-medical-pl/transformers/default"
)

In [None]:
api_key = "<together.ai api key>"

In [11]:
checkpoint = "natural_anonym_synth"
SAMPLING_RATE = 16000
language = "pl"
task = "transcribe"
peft_config = PeftConfig.from_pretrained(checkpoint) 
model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, checkpoint)
merged_model = model.merge_and_unload()
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)

In [9]:
def make_dataset(audio_folder, audio_files, references):
    examples = []
    for audio, ref in zip(audio_files, references):
        filename = os.path.join(audio_folder, audio)
        waveform, sr = torchaudio.load(filename)
        if sr != SAMPLING_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
        examples.append({"waveform": waveform, "phrase": ref})
    return examples

In [23]:
class BiasLogitsProcessor(LogitsProcessor):
    def __init__(self, bias_words, processor, boost=5.0):
        self.bias_token_ids = set()
        for word in bias_words:
            ids = processor.tokenizer.encode(word, add_special_tokens=False)
            self.bias_token_ids.update(ids)
        self.boost = boost
    def __call__(self, input_ids, scores):
        for token_id in self.bias_token_ids:
            if token_id < scores.shape[-1]:
                scores[:, token_id] += self.boost
        return scores


In [None]:
def transcribe_example(
    example,
    initial_prompt=None,
    bad_words=None,    
    bias_words = None, boost = 1,
    llm = None
):
    input_features = processor(
        example["waveform"].squeeze(0),
        sampling_rate=SAMPLING_RATE,
        return_tensors="pt"
    ).input_features

    decoder_input_ids = None
    if initial_prompt:
        prompt_ids = processor.tokenizer(
            initial_prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).input_ids  

        forced_ids = torch.tensor([[tok_id for _, tok_id in forced_decoder_ids]])  

        decoder_input_ids = torch.cat([prompt_ids, forced_ids], dim=1)      


    bad_words_ids = None
    if bad_words and len(bad_words) > 0:
        bad_words_ids = processor.tokenizer(
            bad_words,
            add_special_tokens=False
        ).input_ids  
    
    logits_processor_list = []
    if bias_words and len(bias_words) > 0:
        bias_processor = BiasLogitsProcessor(bias_words, processor, boost=boost)
        logits_processor_list.append(bias_processor)

    with torch.no_grad():
        if decoder_input_ids is not None:
            predicted_ids = merged_model.generate( input_features, decoder_input_ids=decoder_input_ids,
                                                   bad_words_ids = bad_words_ids, num_beams=boost,
            logits_processor=logits_processor_list )[0] 
        else: 
            predicted_ids = merged_model.generate( input_features, forced_decoder_ids=forced_decoder_ids,
                                                   bad_words_ids = bad_words_ids, num_beams=boost,
            logits_processor=logits_processor_list )[0]

    transcription = processor.decode(predicted_ids, skip_special_tokens=True)
    return transcription



In [38]:
def postprocess_transcription(transcription, known_terms, threshold=85):
    words = transcription.split()
    corrected = []
    for w in words:
        match, score, _ = process.extractOne(w, known_terms, scorer=fuzz.ratio)
        if score >= threshold:
            corrected.append(match)
        else:
            corrected.append(w)
    return " ".join(corrected)

In [7]:
def correct_with_llm(transcription, api_key, model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"):
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    prompt = f"""
    Jesteś ekspertem medycznym. Oto transkrypcja mowy:
    ---
    {transcription}
    ---
    Popraw błędy językowe i zamień nieprecyzyjne wyrażenia
    na właściwe terminy medyczne. Nie dodawaj nowych informacji.
    Zwróć tylko poprawioną wersję.
    """

    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 512,
        "temperature": 0.2
    }

    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"].strip()

In [52]:
dataset = make_dataset("dev_noisy", audio_files, references)



In [53]:
audio_files = train_manifest["wav_path"]
audio_files = [os.path.basename(f) for f in audio_files]
references_train = train_manifest["text"].tolist()

In [54]:
folder = "train_noisy"
folder_files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
folder_set = set(folder_files)
audio_files = [f for f in audio_files if f in folder_set]

In [63]:
train_dataset = make_dataset("train_noisy", audio_files, references_train)



In [None]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

In [111]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.9536
Subset CER: 0.5920


In [112]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nie spróbować kuracji w upromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-PRO.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptycy metafem, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się nurowenworte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam norowę Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci doldit w osłupę, pobudzę węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węz

In [32]:
folder = "subset"
files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

In [33]:
subset_dataset = make_dataset("subset", files, leki)



In [149]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [10:27<00:00,  8.36s/it]


In [150]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 1.0990


In [151]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostarosnie.
Reference: ActiFolin	Hypothesis: Abstryj folii.
Reference: Aliness Kwas Foliowy	Hypothesis: Alinis quas foillowy.
Reference: Amlodipina	Hypothesis: Auto do dipina.
Reference: Amoksycylina	Hypothesis: Amoż licejny.
Reference: Amotaks	Hypothesis: Amotax.
Reference: ApoD3	Hypothesis: Apoedektyczna opieka w KG.
Reference: Arterios	Hypothesis: Aktelio.
Reference: Augmentin	Hypothesis: Augunantnie.
Reference: Avamina	Hypothesis: Awa minął.
Reference: Biaron D	Hypothesis: Jak odbiór rąde.
Reference: Bibloc	Hypothesis: PiWC.
Reference: Bisakodyl	Hypothesis: Dli zakończonych leków nie zalecanią.
Reference: Bisocard	Hypothesis: I sotoc.
Reference: Bisoprolol	Hypothesis: Nie zauważyłem.
Reference: Bisoratio	Hypothesis: Wii sora dnia.
Reference: Concor	Hypothesis: Ponco".
Reference: Coronal	Hypothesis: Trona.
Reference: Crestor	Hypothesis: Krystolat w kontrolnym ZWK nie obniża się.
Reference: Crosuvo	Hypothesis: Proszę wą krossofool.
Reference: D-Vitum 

In [None]:
# potestowac tylko na lekach i na trainie

### caly error analysis
# precision recall accuracy na target terms
# porownac z jakims modelem open source
# word level confidence
# czy jakies konkretne bledy sie powtarzaja

### Post-processing model without full fine-tuning

##### setting initial-prompt parameter

In [31]:
initial_prompt = "To nagranie jest fragmentem z wywiadu medycznego. Zawiera nazwy leków takich jak Paracetamol, Ibuprofen, Metformin."

In [None]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing:   0%|          | 0/146 [00:00<?, ?it/s]Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Transcribing: 100%|██████████| 146/146 [41:27<00:00, 17.04s/it]


In [22]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.4496
Subset CER: 0.1566


In [26]:
for ref, hyp in zip(references[30:], hypotheses[30:]):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Na recepcie zapiszę Hydrochlorotiazyd.	Hypothesis: Na recepcie zapiszę hydrochorotiazyt.
Reference: Powinnaś przyjmować Hydrochlorotiazyd na nadciśnienie.	Hypothesis: Pożynnaś przyjmować hydrochlorotiazytna nadciśnienie.
Reference: Najlepiej w twoim przypadku sprawdzi się Losartan.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się losartem.
Reference: Możesz kupić w aptece Losartan, powinien pomóc.	Hypothesis:  Możesz kupić w aptecjum sartan, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Avamina.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się wamin.
Reference: Powinnaś spróbować kuracji Avaminą.	Hypothesis: Pożynne nasz to jest kreatura, kreaturia waminowa.
Reference: Przepiszę ci Formetic, bo zwykle dobrze działa.	Hypothesis: Przepiszę onciwormic, bo zwykle dobrze działa.
Reference: Możesz kupić w aptece Formetic, powinien pomóc.	Hypothesis:  Możesz kupić w Abtec Informatik, powinien pomóc.
Reference: Na recepcie zapiszę Glucophage XR jako lek

In [35]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [27:56<00:00, 22.35s/it]


In [36]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 2.5681


In [37]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostrosz.
Reference: ActiFolin	Hypothesis: Abstij folii.
Reference: Aliness Kwas Foliowy	Hypothesis: Alinis quas fojolowy.
Reference: Amlodipina	Hypothesis: Analogipna.
Reference: Amoksycylina	Hypothesis: Amocylina.
Reference: Amotaks	Hypothesis: Amo tachs.
Reference: ApoD3	Hypothesis: Apoedet 3.
Reference: Arterios	Hypothesis: Ateria.
Reference: Augmentin	Hypothesis: Paukowa nąciń.
Reference: Avamina	Hypothesis:  Awamina.
Reference: Biaron D	Hypothesis: Jak wiarą de do zespoňu zespoňu zespoňu zespoňem.
Reference: Bibloc	Hypothesis: Pierć.
Reference: Bisakodyl	Hypothesis: Dziesia kodyl.
Reference: Bisocard	Hypothesis: I SOTS op.
Reference: Bisoprolol	Hypothesis: Nie zrozumiem.
Reference: Bisoratio	Hypothesis: Wi sora tie.
Reference: Concor	Hypothesis: Ponco.
Reference: Coronal	Hypothesis: Trona.
Reference: Crestor	Hypothesis: Tak to.
Reference: Crosuvo	Hypothesis: Proszuję do poradni lekarza lekarza lekarza lekarza.
Reference: D-Vitum Forte	Hypothesis: W

In [39]:
initial_prompt = "Rozmowa medyczna. Mogą wystąpić nazwy leków, w tym: Ibuprom, Metformina, Amlodipina, Ramipryl, Omeprazol, Bisoprolol."

In [28]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [46:20<00:00, 19.05s/it] 


In [29]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.4657
Subset CER: 0.2351


In [30]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nespróbować kuracji ibupromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzić się ibuprom.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis:  Możesz kupić wapnę Ticimetaphen, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzić snu rofenforte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam Noropher Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę nidolgit, bo zwykle dobrze działa.
Reference: Na recepcie zapiszę Dolgit jako lek na ból i stan zapalny.	Hypothesis: Na recepcie zapisze on dolegijego leka ból i stan zapalny.
Reference: Na nadciś

In [40]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, initial_prompt=initial_prompt)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [31:26<00:00, 25.15s/it]   


In [43]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 1.7799


In [44]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostrosi.
Reference: ActiFolin	Hypothesis: Abstij folii.
Reference: Aliness Kwas Foliowy	Hypothesis: Alina z Kwas Foiowem.
Reference: Amlodipina	Hypothesis: Amlodipina.
Reference: Amoksycylina	Hypothesis: Amoż siedzylina.
Reference: Amotaks	Hypothesis: Amotax.
Reference: ApoD3	Hypothesis: Apoedetry.
Reference: Arterios	Hypothesis: Ateria.
Reference: Augmentin	Hypothesis: Pałka nańczyna.
Reference: Avamina	Hypothesis:  Awaimina.
Reference: Biaron D	Hypothesis: Jedl.
Reference: Bibloc	Hypothesis: Pierdzol.
Reference: Bisakodyl	Hypothesis: Dli zakonyli.
Reference: Bisocard	Hypothesis: I.
Reference: Bisoprolol	Hypothesis: Nie zrozumiał.
Reference: Bisoratio	Hypothesis: Wi sora tie.
Reference: Concor	Hypothesis: Ponco.
Reference: Coronal	Hypothesis: Trona.
Reference: Crestor	Hypothesis: Przepływ do oparzeń dożylnych do zespołu podatku w poradni lewej komory, z powodu oparcia zespołu powierzchni, z powodu oparzenia z oparzeń z oparzeń z oparzeń z oparzeń z opa

In [45]:
print(references)
print(hypotheses)

['Actarosin', 'ActiFolin', 'Aliness Kwas Foliowy', 'Amlodipina', 'Amoksycylina', 'Amotaks', 'ApoD3', 'Arterios', 'Augmentin', 'Avamina', 'Biaron D', 'Bibloc', 'Bisakodyl', 'Bisocard', 'Bisoprolol', 'Bisoratio', 'Concor', 'Coronal', 'Crestor', 'Crosuvo', 'D-Vitum Forte', 'Dapagliflozyna', 'Devikap', 'Dolgit', 'Folik', 'Formetic', 'Glucophage XR', 'Hydrochlorotiazyd', 'Ibuprom', 'Ibuvit D3', 'KFD Omega 3', 'Laktuloza', 'Linagliptyna', 'Liraglutyd', 'Losartan', 'Macromax', 'Melatonina', 'Metafen', 'Metformax', 'Metformina', 'Mirtazapina', 'Nurofen Forte', 'Olicaps Witamina D3', 'Omeprazol', 'Oriovit-D 1000', 'Ospen 1000', 'OstroVit Omega 3', 'Pantoprazol', 'Perindopryl', 'Prestilol', 'Ramipryl', 'Ridlip', 'Romazic', 'Rosucard', 'Rosutrox', 'Rosuvastatin Medical Valley', 'Roswera', 'Sitagliptyna', 'Sobycor', 'Sukralfat', 'Sumamed', 'Suvardio', 'Synjardy', 'Trazodon', 'Vigalex Bio', 'Vigalex Forte', 'Vigalex Max', 'Vigantol', 'Vigantoletten', 'Witamina D3 Forte', 'Xigduo', 'Zahron', 'Zarant

### Using suppressed words 

In [None]:
bad_words = [
    "Awaimina", "Netformaz", "Metaphen", "Mieta zapina", "suwardia", "wigantoletem", "i bufet", "amotax"
    "Nurowendą", "Oliczaw z Vitamina", "Tarzodą", "hydrochorotiazyt", "medformax", "devicap"
    "Witamina D340", "Ostrosi", "Metaphen", "Netformaz", "Nurowendą", "kąco", "corozalne", "rosferę"
]

In [56]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, bad_words=bad_words)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [1:24:00<00:00, 34.52s/it]  


In [57]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.9536
Subset CER: 0.5920


In [58]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nie spróbować kuracji w upromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-PRO.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptycy metafem, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się nurowenworte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam norowę Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci doldit w osłupę, pobudzę węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węzeł węz

### Testing soft bias with tokens favorization

In [28]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, bias_words=leki)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [36:13<00:00, 14.89s/it] 


In [29]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.5101
Subset CER: 0.3221


In [30]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Poza tym w Klinikacji Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum Kulicjum
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-Prom.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pożynne spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić wapnę Tafan, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi

In [34]:
hypotheses = []
references = []

for ex in tqdm(subset_dataset, desc="Transcribing"):
    transcription = transcribe_example(ex, bias_words=leki)
    hypotheses.append(transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 75/75 [2:30:35<00:00, 120.48s/it]   


In [35]:
cer_score = cer(references, hypotheses)

print(f"Subset CER: {cer_score:.4f}")

Subset CER: 4.7263


In [36]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Actarosin	Hypothesis: Ostrożnie.
Reference: ActiFolin	Hypothesis: XG-Foli.
Reference: Aliness Kwas Foliowy	Hypothesis: Alina z KVAS fojjowym.
Reference: Amlodipina	Hypothesis: Auto-drypina.
Reference: Amoksycylina	Hypothesis: Amocylina.
Reference: Amotaks	Hypothesis: Amutax.
Reference: ApoD3	Hypothesis: Apoid 3.
Reference: Arterios	Hypothesis: Ateria.
Reference: Augmentin	Hypothesis: Pałka na 3.
Reference: Avamina	Hypothesis: Awamina.
Reference: Biaron D	Hypothesis: Kieruję do poradni lekarza lekarza leczo-pokrwia.
Reference: Bibloc	Hypothesis: PiWC.
Reference: Bisakodyl	Hypothesis: Diakodyl.
Reference: Bisocard	Hypothesis: I sotoc.
Reference: Bisoprolol	Hypothesis: Nisoprodu.
Reference: Bisoratio	Hypothesis: WIJSR 3.
Reference: Concor	Hypothesis: Pontyk onakowy-vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv vvvv vvvv vvv v v v v v v v v v v v

### Fuzzy Matching

In [45]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    postprocessed_transcription = postprocess_transcription(transcription=transcription, known_terms=leki)
    hypotheses.append(postprocessed_transcription)
    references.append(ex["phrase"])

Transcribing: 100%|██████████| 146/146 [36:17<00:00, 14.91s/it] 


In [47]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.5484
Subset CER: 0.2668


In [48]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po winu nie spróbować kuracji w upromem.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się IBU-PRO.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pobinna spróbować kuracji metafenem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptycy metafem, powinien pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się nurowenworte.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam norowę Fortem, bo dobrze działa nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci dole git, bo zwykle dobrze działa.
Reference: Na recepcie zapiszę Dolgit jako lek na ból i stan zapalny.	Hypothesis: Na recepcie zapiszę dolegięcie do leka ból i stan zapalny.
Reference: Na nadciś

### LLM correction

In [None]:
hypotheses = []
references = []

for ex in tqdm(dataset, desc="Transcribing"):
    transcription = transcribe_example(ex)
    corrected = correct_with_llm(transcription=transcription, api_key = api_key)
    hypotheses.append(corrected)
    references.append(ex["phrase"])

In [55]:
wer_score = wer(references, hypotheses)
cer_score = cer(references, hypotheses)

print(f"Subset WER: {wer_score:.4f}")
print(f"Subset CER: {cer_score:.4f}")

Subset WER: 0.8586
Subset CER: 0.7186


In [56]:
for ref, hyp in zip(references, hypotheses):
    print(f"Reference: {ref}\tHypothesis: {hyp}")

Reference: Powinnaś spróbować kuracji Ibupromem.	Hypothesis: Po spożyciu alkoholu nie należy próbować kuracji odwykowej.
Reference: Najlepiej w twoim przypadku sprawdzi się Ibuprom.	Hypothesis: W Pana przypadku najskuteczniejsze będzie ibuprofenum.
Reference: Powinnaś spróbować kuracji Metafenem.	Hypothesis: Pobinna spróbować kuracji metafenem. -> Powinna spróbować kuracji metamizolem.
Reference: Możesz kupić w aptece Metafen, powinien pomóc.	Hypothesis: Możesz kupić w aptece metforminę, powinna pomóc.
Reference: Najlepiej w twoim przypadku sprawdzi się Nurofen Forte.	Hypothesis: Najlepiej w twoim przypadku sprawdzi się niesteroidowy lek przeciwzapalny.
Reference: Zalecam Nurofen Forte, bo dobrze działa na ból i stan zapalny.	Hypothesis: Zalecam doustnie Fortem, ponieważ dobrze działa na nabór i stan zapalny.
Reference: Przepiszę ci Dolgit, bo zwykle dobrze działa.	Hypothesis: Przepiszę ci diklofenak w żelu, bo zwykle dobrze działa.
Reference: Na recepcie zapiszę Dolgit jako lek na ból

## Wnioski:
- najlepiej WER i CER zdaje się obniżać parametr initial_prompt

## Fine tuning

In [34]:
OUTPUT_DIR = "./whisper_medical_finetuned_lora"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLING_RATE = 16000
BASE_CHECKPOINT = checkpoint
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "v_proj"] 

NUM_EPOCHS = 6
PER_DEVICE_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 1e-4
FP16 = True

MED_WEIGHT = 5.0

In [35]:
checkpoint = "natural_anonym_synth"
SAMPLING_RATE = 16000
language = "pl"
task = "transcribe"
peft_config = PeftConfig.from_pretrained(checkpoint) 
model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, checkpoint)

In [36]:
class WhisperTuner(PeftModel):
    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        super().__init__(model, peft_config, adapter_name)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None,
                decoder_inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None,
                return_dict=None, task_ids=None, **kwargs):
        allowed = {"input_features","attention_mask","decoder_input_ids","decoder_attention_mask",
                "decoder_inputs_embeds","labels","output_attentions","output_hidden_states","return_dict"}
        filtered = {k:v for k,v in kwargs.items() if k in allowed}
        if input_features is not None:
            filtered["input_features"] = input_features
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if peft_config.peft_type == PeftType.POLY:
                kwargs["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(**filtered)
    def prepare_inputs_for_generation(self, *args, **kwargs):
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if "input_ids" in model_kwargs and "input_features" not in model_kwargs:
            model_kwargs["input_features"] = model_kwargs.pop("input_ids")
        return model_kwargs

In [37]:
lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=TARGET_MODULES,
        lora_dropout=LORA_DROPOUT,
        bias="none"
    )
model = WhisperTuner(model, lora_config)



In [38]:
model.to(DEVICE)
print("Model loaded. Trainable parameters:")
try:
    model.print_trainable_parameters()
except Exception:
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {n_trainable} / {n_total}")
forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)

Model loaded. Trainable parameters:
trainable params: 2,359,296 || all params: 766,217,216 || trainable%: 0.3079


In [None]:
def prepare_example(example):
    audio = example["waveform"]

    inputs = processor(audio, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    example["input_features"] = inputs.input_features[0]

    labels = processor.tokenizer(example["phrase"], add_special_tokens=True).input_ids
    example["labels"] = torch.tensor(labels, dtype=torch.long)

    return example

In [None]:
train_dataset = Dataset.from_list(train_dataset)
val_dataset = Dataset.from_list(dataset)

In [44]:
train_dataset = train_dataset.map(prepare_example)
val_dataset = val_dataset.map(prepare_example)

Map: 100%|██████████| 1102/1102 [01:04<00:00, 17.15 examples/s]
Map: 100%|██████████| 146/146 [00:07<00:00, 19.11 examples/s]


In [80]:
vocab_size = len(processor.tokenizer)
token_weights = torch.ones(vocab_size, dtype=torch.float32, device=DEVICE)
for word in leki:
    ids = processor.tokenizer.encode(word, add_special_tokens=False)
    for t in ids:
        if 0 <= t < vocab_size:
            token_weights[t] = MED_WEIGHT

In [81]:
wer_metric = evaluate.load("wer")

In [82]:
def compute_metrics_eval(preds: List[str], refs: List[str]) -> Dict[str, float]:
    wer = wer_metric.compute(references=refs, predictions=preds)
    tp = fp = fn = 0
    meds_lower = [m.lower() for m in leki]
    for hyp, ref in zip(preds, refs):
        hyp_tokens = set([t.lower().strip(".,") for t in hyp.split()])
        ref_tokens = set([t.lower().strip(".,") for t in ref.split()])
        for m in meds_lower:
            in_ref = m in ref_tokens
            in_hyp = m in hyp_tokens
            if in_ref and in_hyp:
                tp += 1
            elif in_hyp and not in_ref:
                fp += 1
            elif in_ref and not in_hyp:
                fn += 1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return {"wer": wer, "precision_drugs": precision, "recall_drugs": recall}

In [83]:
def compute_metrics_for_trainer(eval_pred):
    preds_ids, labels_ids = eval_pred
    decoded_preds = processor.batch_decode(preds_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    decoded_labels = processor.batch_decode(labels_ids, skip_special_tokens=True)
    return compute_metrics_eval(decoded_preds, decoded_labels)

In [84]:
def collate_fn(batch):
    input_features = torch.stack([torch.tensor(x["input_features"]) for x in batch]).to(DEVICE)

    labels = [torch.tensor(x["labels"], dtype=torch.long) for x in batch]
    labels_padded = torch.nn.utils.rnn.pad_sequence(
        labels,
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id
    ).to(DEVICE)

    labels_for_loss = labels_padded.clone()
    labels_for_loss[labels_for_loss == processor.tokenizer.pad_token_id] = -100

    batch_out = {
        "input_features": input_features,
        "labels": labels_for_loss,
        "decoder_input_ids": labels_padded
    }
    return batch_out

In [85]:
class WeightedLossTrainer(Seq2SeqTrainer):
    def __init__(self, token_weights_tensor: torch.Tensor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_weights = token_weights_tensor

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        decoder_input_ids = inputs.pop("decoder_input_ids", None)
        input_features = inputs.pop("input_features")
        outputs = model(input_features=inputs["input_features"], labels=labels, decoder_input_ids = decoder_input_ids)
        logits = outputs.logits

        batch_size, seq_len, vocab = logits.size()
        logits_flat = logits.view(-1, vocab)
        labels_flat = labels.view(-1)
        loss_per_token = F.cross_entropy(logits_flat, labels_flat, reduction="none", ignore_index=-100)
        loss_per_token = loss_per_token.view(batch_size, seq_len)

        labels_masked = labels.clone()
        labels_masked[labels_masked == -100] = 0
        weights = self.token_weights[labels_masked]
        mask = (labels != -100).float()
        weights = weights * mask

        weighted_loss = (loss_per_token * weights).sum()
        normalizer = weights.sum().clamp_min(1.0)
        loss = weighted_loss / normalizer
        return (loss, outputs) if return_outputs else loss

In [86]:
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    fp16=FP16,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    predict_with_generate=True, 
    save_strategy="steps",
    save_steps=4, #1000
    save_total_limit=3, 
    logging_steps=1, #100
    eval_strategy="steps",
    eval_steps=2, #500
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    remove_unused_columns=False,
    disable_tqdm=False,
    max_steps = 8
)
trainer = WeightedLossTrainer(
    token_weights_tensor=token_weights,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    tokenizer=processor.tokenizer,
    compute_metrics=compute_metrics_for_trainer,
)


  super().__init__(*args, **kwargs)


In [75]:
print("Starting training...")
trainer.train()
print("Training finished. Saving model...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

Starting training...


  input_features = torch.stack([torch.tensor(x["input_features"]) for x in batch]).to(DEVICE)
  labels = [torch.tensor(x["labels"], dtype=torch.long) for x in batch]


KeyError: 'input_features'