# Decoding

In this notebook, we provide an interface to experiment with several decoding strategies and evaluate them.

### Setting Up

Simply run the cell below to load the libraries, define important functions and set important global variables.

In [None]:
import whisper, torch
import numpy as np
from data_processing.tokenize_audio import *
from tqdm import tqdm
import tempfile
import torchaudio
import torchaudio.transforms as at
import os, re, csv
from datetime import datetime
import ufal.morphodita
from pydub import AudioSegment
from pydub.silence import split_on_silence
from scipy.signal import butter, lfilter
from gtts import gTTS
from IPython.display import Markdown,Audio, display

# Leading run of word characters (Czech letters listed explicitly). Compiled
# once so that every call reuses the same pattern object.
_PREFIX_PATTERN = re.compile(r"[\wěščřžýáíéóúůĎťňŘŠČŽÝÁÍÉÓÚŮ]+")


def extract_prefix(s):
    """Return the leading run of word characters of ``s``.

    Matching is anchored at the start of the string, so anything from the
    first non-word character onwards is discarded.

    Args:
        s (str): Text to inspect.

    Returns:
        str: The matched prefix, or ``""`` when ``s`` does not start with a
        word character.
    """
    match = _PREFIX_PATTERN.match(s)
    return match.group(0) if match else ""


# Load the Czech MorphoDiTa tagger model once, when this cell runs.
# `lemmatize_czech` below reads this module-level `tagger` and raises a
# RuntimeError if it is falsy.
tagger = ufal.morphodita.Tagger.load("czech-morfflex/czech-morfflex2.0-pdtc1.0-220710-pos_only.tagger")

def lemmatize_czech(word: str) -> str:
    """
    Look up the lemma of a Czech word with the module-level MorphoDiTa tagger.

    Args:
        word (str): The input word in Czech.

    Returns:
        str: The lemma of the first tagged token, or ``word`` unchanged when
        the tagger produced no lemma.

    Raises:
        RuntimeError: If the tagger is not loaded or no tokenizer can be
            created from it.
    """
    if not tagger:
        raise RuntimeError("Tagger model failed to load.")

    # A fresh tokenizer is obtained from the tagger on every call.
    word_tokenizer = tagger.newTokenizer()
    if not word_tokenizer:
        raise RuntimeError("Failed to create tokenizer.")

    # Output containers that MorphoDiTa fills in place.
    sentence_forms = ufal.morphodita.Forms()
    token_ranges = ufal.morphodita.TokenRanges()
    tagged = ufal.morphodita.TaggedLemmas()

    word_tokenizer.setText(word)
    # Each tagging pass writes into the same `tagged` container.
    while word_tokenizer.nextSentence(sentence_forms, token_ranges):
        tagger.tag(sentence_forms, tagged)

    # Single-word input is assumed, so only the first lemma is of interest.
    if tagged:
        return tagged[0].lemma
    return word

def load_wave(wave_path, segment_start, segment_end, sample_rate:int=16000) -> torch.Tensor:
    """Load a slice of an audio file as a waveform averaged over its first axis.

    Args:
        wave_path: Path handed to ``torchaudio.load``.
        segment_start: Start of the slice; converted to a sample index as
            ``segment_start * sample_rate / 1000`` (i.e. milliseconds).
        segment_end: End of the slice, same unit as ``segment_start``.
        sample_rate: Target sampling rate; the audio is resampled when the
            file's own rate differs.

    Returns:
        torch.Tensor: The selected samples, averaged over dimension 0 when the
        slice is two-dimensional.
    """
    samples, native_rate = torchaudio.load(wave_path, normalize=True)
    if native_rate != sample_rate:
        resampler = at.Resample(native_rate, sample_rate)
        samples = resampler(samples)

    # Slice boundaries are computed against the target sampling rate.
    first_sample = int(segment_start * sample_rate/1000)
    last_sample = int(segment_end * sample_rate/1000)
    clip = samples[:, first_sample:last_sample]

    if clip.dim() == 2:
        clip = torch.mean(clip, dim = 0)
    return clip

def save_and_load_audiosegment(segment, sample_rate=16000):
    """Convert an audio segment to a waveform tensor via a temporary WAV file.

    The segment is exported to a temporary ``.wav`` file, read back in full
    with ``load_wave`` and the file is deleted again.

    Args:
        segment: Object providing ``export(path, format=...)`` and ``len()``
            (the length is passed to ``load_wave`` as the end of the slice).
        sample_rate: Sampling rate requested from ``load_wave``.

    Returns:
        Whatever ``load_wave`` returns for the whole exported file.
    """
    # Only the path is needed: reserve a unique name and close the handle
    # before the segment is exported to that path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        temp_path = tmpfile.name

    try:
        segment.export(temp_path, format="wav")
        return load_wave(temp_path, 0, len(segment), sample_rate)
    finally:
        # Remove the file even when exporting or loading fails, so failed
        # conversions do not pile up temporary files.
        os.remove(temp_path)

# Start from the stock "small" Whisper model and overwrite its weights with
# the fine-tuned checkpoint below.
MODEL = whisper.load_model("small")
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load("best-checkpoint-epoch=0006-run-small_jasmi_kd5e-2.ckpt", map_location="cpu")
ckpt_state_dict = checkpoint["state_dict"]
# Keep only the entries stored under the "model." prefix and strip exactly
# that leading prefix (str.replace would also remove later occurrences of
# "model." inside a key).
state_dict = {k[len("model."):]: v for k, v in ckpt_state_dict.items() if k.startswith("model.")}
MODEL.load_state_dict(state_dict)

### Define Decoding Strategies

In the cell below, we define a class `WhisperSpecialistDecoder`. This class takes a Whisper model as input and provides an interface for word-by-word transcription strategies that are more sophisticated than what the interface of the `whisper` library offers. The strategies implemented so far are:
1. **Unguided transcription**: uses a beam search to find N most likely transcriptions.
2. **Guided transcription**: uses a beam search to find N most likely transcriptions. The transcriptions are limited to existing Czech words.
3. **Guided transcription with merged segments**: uses a beam search to find N most likely transcriptions. The transcriptions are limited to existing Czech words. The true transcription of previous segments is used as well.

All of these strategies are implemented via `decode_beam` method. We show how to call each of them further below.


In [None]:
import csv

class WhisperSpecialistDecoder:
    """Beam-search decoder that transcribes one word with a Whisper-style model.

    The vocabulary is encoded with the supplied tokenizer and stored as two
    token tries (nested dicts keyed by token id, each word terminated by the
    ``<|endoftext|>`` token):

    * ``first_token_dict`` - words encoded exactly as given; used when
      decoding starts without a previous transcription.
    * ``next_token_dict`` - words lower-cased and prefixed with a space; used
      when a previous transcription is supplied.

    In guided mode every decoding step only allows the tokens that are
    children of the hypothesis' current trie node.
    """

    def __init__(self, vocabulary_source, tokenizer, model, penalty_strength = 0.3):
        """Encode the vocabulary and build the lookup structures.

        Args:
            vocabulary_source: Path to a CSV file (str) or a list of words.
            tokenizer: Object providing ``encode(text, allowed_special=...)``,
                ``decode(tokens)`` and an ``eot`` token id.
            model: Object exposing ``encoder`` and ``decoder`` callables.
            penalty_strength: Stored on the instance; no method of this class
                reads it.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.transcriptions = []
        self.first_token_dict = {}
        self.next_token_dict = {}
        self.token_lists = []
        self.probs_when_unrelated = []
        self.penalty_strength = penalty_strength

        # Load CSV and process transcriptions
        self._load_csv(vocabulary_source)
        self._build_token_dict(first=True)
        self._build_token_dict(first=False)
        self._build_token_list()

    def _load_csv(self, vocabulary_source):
        """Fill ``self.transcriptions`` from a CSV path or a ready-made list.

        For a path, the first column of every non-empty row is used with its
        first character dropped (``row[0][1:]``).
        NOTE(review): presumably the file stores a leading marker/space in
        front of each word - confirm against the data.
        A list is stored as-is; any other type leaves the vocabulary empty.
        """
        if isinstance(vocabulary_source, str):
            with open(vocabulary_source, newline='', encoding='utf-8') as f:
                self.transcriptions = [row[0][1:] for row in csv.reader(f) if row]
        elif isinstance(vocabulary_source, list):
            self.transcriptions = vocabulary_source

    def _build_token_dict(self, first = True):
        """Insert every vocabulary word into one of the two token tries.

        Args:
            first: ``True`` fills ``first_token_dict`` with the words as given;
                ``False`` fills ``next_token_dict`` with the lower-cased words
                prefixed by a space.
        """
        root = self.first_token_dict if first else self.next_token_dict
        # Show a progress bar only for vocabularies of 200 words or more.
        words = self.transcriptions if len(self.transcriptions) < 200 else tqdm(self.transcriptions)
        for text in words:
            if not first:
                text = " " + text.lower()
            node = root
            for token in self.tokenizer.encode(f"{text}<|endoftext|>", allowed_special = "all"):
                # Descend into the child for this token, creating it if needed.
                node = node.setdefault(token, {})

    def _build_token_list(self):
        """Append the token sequence of every vocabulary word to ``token_lists``."""
        self.token_lists.extend(
            self.tokenizer.encode(f"{text}<|endoftext|>", allowed_special = "all")
            for text in self.transcriptions
        )

    # OPTIONAL: add your own decoding strategies
    def decode_beam(self, audio_segment, beam_size=5, n_results = 3, previous_transcription = "", guided = True):
        """
        Beam-search decode one audio segment, batching all live hypotheses per step.

        Args:
            audio_segment: Input audio waveform; padded/trimmed and converted
                to a log-mel spectrogram with the ``whisper`` helpers.
            beam_size: Number of hypotheses kept after every step.
            n_results: Controls how many hypotheses are returned (see Returns).
            previous_transcription: Text appended to the decoding prompt. When
                non-empty, decoding starts from ``next_token_dict`` instead of
                ``first_token_dict``.
            guided: When ``True``, tokens that are not children of the
                hypothesis' current trie node are masked out.

        Returns:
            A list of ``(word, log_prob)`` tuples, best hypothesis first.
            ``word`` is the first whitespace-separated word of the decoded
            text ('' if there is none). Up to ``n_results + 1`` entries are
            returned; this count is kept for backward compatibility, since
            existing callers read that many candidates.
        """
        assert beam_size >= n_results, "N results to be returned cannot exceed the beam size"
        with torch.no_grad():
            # Prepare input
            audio = whisper.pad_or_trim(audio_segment)
            mel = whisper.log_mel_spectrogram(audio, n_mels=80)
            encoder_output = self.model.encoder(mel.unsqueeze(0))

            # Initialize search; each hypothesis is (tokens, log_prob, trie node).
            initial_tokens = self.tokenizer.encode(
                f"<|startoftranscript|><|cs|><|transcribe|><|notimestamps|>{previous_transcription}",
                allowed_special="all",
            )
            start_dict = self.next_token_dict if previous_transcription else self.first_token_dict
            beam = [(initial_tokens, 0.0, start_dict)]

            # Beam search loop
            for _ in range(25):  # Max length of sequence
                # Finished hypotheses are carried over unchanged and keep
                # competing with the expansions of the unfinished ones.
                new_beam = [hyp for hyp in beam if hyp[0][-1] == self.tokenizer.eot]
                beam = [hyp for hyp in beam if hyp[0][-1] != self.tokenizer.eot]
                if not beam:
                    beam = new_beam
                    break  # Stop early if all hypotheses are finished

                # Prepare batched inputs
                token_batch = [tokens for tokens, _, _ in beam]
                encoded_txt_tokens = torch.tensor(token_batch).long().to(encoder_output.device)

                # Decode batch; only the logits of the last position are needed.
                logits = self.model.decoder(encoded_txt_tokens, encoder_output)[:, -1, :]

                if guided:
                    # One mask row per hypothesis: 0 for allowed tokens, -inf otherwise.
                    masks = torch.full((len(beam), logits.shape[-1]), float('-inf'), device=logits.device)
                    for i, (_, _, current_dict) in enumerate(beam):
                        masks[i, list(current_dict.keys())] = 0
                    logits = logits + masks

                log_probs = torch.log_softmax(logits, dim=-1)

                # Get top-k candidates for each beam hypothesis
                topk_log_probs, topk_indices = torch.topk(log_probs, beam_size, dim=-1)
                topk_log_probs = topk_log_probs.tolist()
                topk_indices = topk_indices.tolist()

                # Expand beam
                for i, (tokens, log_prob, current_dict) in enumerate(beam):
                    for step_log_prob, new_token in zip(topk_log_probs[i], topk_indices[i]):
                        # A masked token can still show up in the top-k when
                        # fewer than beam_size tokens are allowed. Keeping it
                        # would put a dead hypothesis (empty trie node) into
                        # the beam and yield NaN scores on the next step.
                        if step_log_prob == float('-inf'):
                            continue
                        new_beam.append((
                            tokens + [new_token],
                            log_prob + step_log_prob,
                            current_dict.get(new_token, {}),  # Navigate token tree
                        ))

                # Keep top beam_size hypotheses
                beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_size]

            # Decode the best hypotheses: drop the prompt and the final token,
            # then keep the first word of the remaining text.
            results = []
            for tokens, log_prob, _ in beam[:n_results+1]:
                words = self.tokenizer.decode(tokens[len(initial_tokens):-1]).split()
                results.append((words[0] if words else '', log_prob))
            return results


But before we play around with the decoder, we have to load all the Czech words and store them in the decoder.

In [None]:
# Read the comma-separated vocabulary. The file is closed again and decoded
# as UTF-8 explicitly, like the other files opened in this notebook.
with open("data/czech_vocabulary.txt", encoding="utf-8") as vocabulary_file:
    czech_word_list = vocabulary_file.read().split(",")
# Trim surrounding whitespace, drop empty entries, capitalise and de-duplicate.
# The result is sorted so the vocabulary is identical on every run; every
# unique word is kept (slicing an unordered set would drop an arbitrary one).
czech_word_list = sorted({word.strip().capitalize() for word in tqdm(czech_word_list) if word.strip()})

Finally, we can create our decoder:

In [None]:
# Decoding options for Czech without timestamps; only `.task` is read from
# them in the next statement.
woptions = whisper.DecodingOptions(language="cs", without_timestamps=True)
# NOTE(review): the positional `True` is presumably the `multilingual` flag of
# get_tokenizer - confirm against the installed whisper version.
tokenizer = whisper.tokenizer.get_tokenizer(True, language="cs", task = woptions.task) 
# Build the decoder over the Czech word list with the fine-tuned MODEL.
decoder = WhisperSpecialistDecoder(czech_word_list,tokenizer, MODEL)

100%|██████████| 35116/35116 [00:05<00:00, 6224.53it/s]
100%|██████████| 35116/35116 [00:05<00:00, 6153.86it/s]


### Transcribe Conversations

For this task, we have recorded three special conversations in which the patient was asked to keep pauses of at least 1 s between her individual words. We load them in the cell below.

In [3]:
# `load_audio` is not defined in this notebook; presumably it comes from the
# star-import of data_processing.tokenize_audio and returns a pydub
# AudioSegment (the results are later sliced by time and handed to functions
# annotated with AudioSegment) - TODO confirm.
audio_segment_108 = load_audio("data/audio/Jasmi_Mimozemstan_108.mp3")
audio_segment_109 = load_audio("data/audio/Jasmi_Hlaseni_109.mp3")
audio_segment_110 = load_audio("data/audio/Jasmi_Rytiri_kosmu_110.mp3")

Next, we split the recordings into shorter segments corresponding to individual words.

In [None]:
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Apply a Butterworth band-pass filter to ``data``.

    Args:
        data: Samples passed to ``scipy.signal.lfilter``.
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling frequency of ``data``.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered samples as returned by ``lfilter``.
    """
    # `butter` expects the band edges as fractions of the Nyquist frequency.
    half_rate = 0.5 * fs
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(order, band_edges, btype='band')
    return lfilter(numerator, denominator, data)

def apply_bandpass(audio: AudioSegment, lowcut=165, highcut=3000):
    """Return a band-pass filtered copy of ``audio``.

    Args:
        audio: Segment to filter.
        lowcut: Lower cut-off frequency in Hz.
        highcut: Upper cut-off frequency in Hz.

    Returns:
        AudioSegment: New segment with the same frame rate, sample width and
        channel count as ``audio``.
    """
    sample_rate = audio.frame_rate
    channels = audio.channels

    # Convert audio to NumPy array
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)

    # Multi-channel audio is stored interleaved. Filter every channel on its
    # own so that samples of different channels are never mixed in the filter.
    filtered_samples = np.empty(samples.shape, dtype=np.float64)
    for channel in range(channels):
        filtered_samples[channel::channels] = butter_bandpass_filter(
            samples[channel::channels], lowcut, highcut, sample_rate
        )

    # Write the result with the integer type that matches the declared sample
    # width (16-bit when the width is not one of the known ones), and clip
    # first so that filter overshoot saturates instead of wrapping around.
    sample_dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(audio.sample_width, np.int16)
    limits = np.iinfo(sample_dtype)
    filtered_samples = np.clip(filtered_samples, limits.min, limits.max).astype(sample_dtype)

    # Create new AudioSegment
    return AudioSegment(
        filtered_samples.tobytes(),
        frame_rate=sample_rate,
        sample_width=audio.sample_width,
        channels=channels
    )

def extract_loudest_segments(audio: AudioSegment, n_words: int, min_period: int, take_before: int, take_after: int):
    """
    Identifies up to n_words loudest 100ms chunks that are at least min_period apart,
    and extracts segments around them.

    Loudness is the dBFS of the band-pass filtered chunk. Fewer than n_words
    segments are returned when the audio does not contain enough peaks that
    satisfy the spacing constraint.

    :param audio: AudioSegment to be analyzed
    :param n_words: Maximum number of loudest chunks to extract; values <= 0 yield an empty list
    :param min_period: Minimum distance (ms) between peaks
    :param take_before: Time (ms) to include before each peak
    :param take_after: Time (ms) to include after each peak
    :return: List of AudioSegment chunks in temporal order
    """
    # Nothing to extract. Without this guard the stop condition below is never
    # reached and every eligible peak would be returned.
    if n_words <= 0:
        return []

    chunk_size = 100  # 100ms chunks
    num_chunks = len(audio) // chunk_size

    # Compute (start time, dBFS) for each chunk
    chunk_loudness = [
        (i * chunk_size, apply_bandpass(audio[i * chunk_size:(i + 1) * chunk_size]).dBFS)
        for i in range(num_chunks)
    ]

    # Sort by loudness (descending order)
    chunk_loudness.sort(key=lambda x: x[1], reverse=True)

    # Greedily select the loudest chunks that respect the min_period constraint
    selected_peaks = []
    for peak_time, _ in chunk_loudness:
        if all(abs(peak_time - p) >= min_period for p in selected_peaks):
            selected_peaks.append(peak_time)
            if len(selected_peaks) == n_words:
                break

    # Sort selected peaks by their temporal order
    selected_peaks.sort()

    # Extract segments around peaks, clamped to the bounds of the audio
    return [
        audio[max(0, peak - take_before):min(len(audio), peak + take_after)]
        for peak in selected_peaks
    ]

def get_row_chunks(csv_path, audio_segments):
    """Cut the annotated utterances of a CSV file into word-level chunks.

    Each CSV row is read as ``text, start, end`` (further columns are
    ignored). Rows with fewer than three columns are skipped and produce no
    entry in the result. Whenever a row's end time is smaller than the
    previous row's end time, the rows are taken to continue in the next
    recording of ``audio_segments``.

    Args:
        csv_path: Path of the annotation CSV file.
        audio_segments: Recordings, in the order in which they appear in the CSV.

    Returns:
        list: One list of extracted chunks per processed row.
    """
    extracted_per_row = []
    recording_index = 0
    previous_end = -np.inf

    with open(csv_path, newline='', encoding='utf-8') as csvfile:
        for row in tqdm(csv.reader(csvfile)):
            # Skip malformed rows
            if len(row) < 3:
                continue

            # Times are assumed to be in milliseconds
            text, start_time, end_time = row[0], float(row[1]), float(row[2])

            # Timestamps going backwards mark the start of the next recording.
            if end_time < previous_end:
                recording_index += 1
            previous_end = end_time

            # One chunk is requested per whitespace-separated word of the text.
            utterance = audio_segments[recording_index][start_time:end_time]
            extracted_per_row.append(
                extract_loudest_segments(utterance, len(text.split()), 3000, 1500, 2500)
            )

    return extracted_per_row

# The empty list is immediately replaced by the result of the call below:
# one list of word-level chunks per CSV row that has at least three columns.
# The three recordings are passed in the order in which the CSV refers to them.
all_extracted_segments = []
all_extracted_segments = get_row_chunks("data/annotations/long_pause_sessions.csv",[audio_segment_108,audio_segment_109,audio_segment_110])

Let's also load the transcriptions.

In [None]:
csv_path, n_lines = "data/annotations/long_pause_sessions.csv",148
with open(csv_path, newline='', encoding='utf-8') as f:
    reader = csv.reader(f)
    # Skip rows with fewer than three columns, exactly as get_row_chunks does,
    # so that row_list[i] and all_extracted_segments[i] describe the same row.
    row_list = [row for row in reader if len(row) >= 3]

Now we can finally start transcribing the individual segments. At the very beginning, let's look at the transcriptions generated by the provided interface to see why a custom decoder is necessary.

In [None]:
# The extracted chunks are pydub AudioSegments, so convert the chunk to a
# waveform tensor first (as the multi-segment cell further below does)
# instead of handing the AudioSegment to transcribe() directly.
result = MODEL.transcribe(save_and_load_audiosegment(all_extracted_segments[0][0]), language = "cs")#, beam_size = 110)
result

{'text': 'proces se vyrábějí, kdy se mu podaří',
 'segments': [{'id': 0,
   'seek': 0,
   'start': 0.0,
   'end': 4.1,
   'text': 'proces se vyrábějí, kdy se mu podaří',
   'tokens': [50364,
    4318,
    887,
    369,
    371,
    6016,
    27879,
    9648,
    73,
    870,
    11,
    350,
    3173,
    369,
    2992,
    2497,
    64,
    15781,
    870],
   'temperature': 0.0,
   'avg_logprob': -0.475648832321167,
   'compression_ratio': 0.8541666666666666,
   'no_speech_prob': 7.932011709781139e-13}],
 'language': 'cs'}

As we can see, the model returns only one transcription, which is insufficient for our needs. Next, let's look at just the transcription texts of multiple segments:

In [None]:
# Get predictions
# Half a second of silence is inserted between consecutive word chunks.
silence = AudioSegment.silent(500)
for r,row_chunks in enumerate(all_extracted_segments[:15]):
    # Only rows whose fourth column is 0 are transcribed.
    if int(row_list[r][3]) != 0: continue
    print(f"================ {r} ================")
    merged_segment = AudioSegment.empty()
    for c, chunk in enumerate(row_chunks): 
        if c: merged_segment += silence
        merged_segment += chunk
    whisper_input = save_and_load_audiosegment(merged_segment)
    predicted_transcription = MODEL.transcribe(whisper_input, language = "cs")["text"]
    # The reference text belongs to the row (index r), not to the index of the
    # last chunk processed by the inner loop.
    true_transcription = row_list[r][0]
    print("Predicted transcription:", predicted_transcription)
    # NOTE(review): Audio() receives a gTTS object here rather than audio bytes
    # or a file name - confirm that this renders as intended.
    display(Audio(gTTS(text=predicted_transcription, lang = "cs", slow = False)))
    print("True transcription:", true_transcription)
    display(Audio(gTTS(text=true_transcription, lang = "cs", slow = False)))

Now, let's look at the outputs of our own decoder. These are presented in a slightly different form because, this time, we provide not just one but five predictions for every word.

In [None]:
def word_by_word_transcription(decoding_startegy):
    """Print the predictions of a decoding strategy for each word of the first 15 rows.

    Rows whose fourth column is not 0 are skipped. For every remaining row the
    true transcription is printed, followed by the ranked predictions for each
    word chunk.

    Args:
        decoding_startegy: Callable ``(chunk, merged_segment)`` returning an
            iterable of ``(text, log_prob)`` pairs. ``chunk`` is the current
            word chunk; ``merged_segment`` is all chunks of the row up to and
            including the current one, joined by 500 ms of silence.
    """
    # Defined locally so the function does not depend on another cell having
    # created a global `silence`.
    silence = AudioSegment.silent(500)

    for r, row_chunks in enumerate(all_extracted_segments[:15]):
        if int(row_list[r][3]) != 0:continue
        # Print the whole true transcription
        print(f"================ {r} ================")
        true_transcription = row_list[r][0]
        transcription_words = true_transcription.split()
        print("True transcription:", true_transcription)
        # NOTE(review): Audio() receives a gTTS object here rather than audio
        # bytes or a file name - confirm that this renders as intended.
        display(Audio(gTTS(text=true_transcription, lang = "cs", slow = False)))
        # Loop the words
        for ch, chunk in enumerate(row_chunks):
            # Rebuild the audio of all words seen so far, current one included.
            merged_segment = AudioSegment.empty()
            for c, earlier_chunk in enumerate(row_chunks[:ch + 1]):
                if c: merged_segment += silence
                merged_segment += earlier_chunk

            word_list = decoding_startegy(chunk, merged_segment)
            # Display predictions
            print("True word:", transcription_words[ch])
            print("Predictions:")
            for p, (pred_text, pred_log_prob) in enumerate(word_list):
                pred_prob = np.exp(pred_log_prob)
                print(f"{p}. Prediction: {pred_text} (p = {pred_prob})")

# Transcription strategies
# Every strategy takes (chunk, merged_segment) and returns the result of
# decoder.decode_beam with a beam of 15 and n_results=5.
def unguided_transcribe(chunk, merged_segment):
    """Decode the isolated chunk without restricting tokens to the vocabulary."""
    return decoder.decode_beam(save_and_load_audiosegment(chunk), beam_size=15, n_results=5, guided=False)


def guided_transcribe(chunk, merged_segment):
    """Decode the isolated chunk with tokens restricted to the vocabulary."""
    return decoder.decode_beam(save_and_load_audiosegment(chunk), beam_size=15, n_results=5, guided=True)


def guided_and_merging_transcribe(chunk, merged_segment):
    """Decode the merged audio of the row so far with tokens restricted to the vocabulary."""
    return decoder.decode_beam(save_and_load_audiosegment(merged_segment), beam_size=15, n_results=5, guided=True)
# OPTIONAL: Call/design your own transcription strategies


#TODO: replace the undefined variable unselected_strategy by the strategy of your choice
word_by_word_transcription(unselected_strategy)

### Quantitative Decoding Evaluation

Here, we quantitatively evaluate all of our decoding strategies on all usable segments from the imported CSV file. As our evaluation metric, we use the recall of the top N predictions, where N ranges from 1 to 10. We lemmatize both the predictions and the targets before computing the metric.

In [None]:
import time

def evaluate_recall(strategy, max_n=10):
    """Measure the top-N recall (N = 1..max_n) and the speed of a decoding strategy.

    Only rows whose fourth column is 0 are evaluated. Predictions and targets
    are lemmatized, reduced to their leading word characters, stripped of
    underscores and capitalized before they are compared.

    Args:
        strategy: Callable ``(chunk, merged_segment)`` returning a ranked list
            of ``(text, log_prob)`` pairs. ``merged_segment`` holds all chunks
            of the row up to and including ``chunk``, joined by 500 ms of
            silence.
        max_n: Largest N for which the recall is computed.

    Returns:
        tuple: ``(seconds_per_sample, recall)``, where ``recall[k]`` is the
        fraction of evaluated chunks whose target is among the first ``k + 1``
        predictions. Both are zero when no chunk was evaluated.
    """
    def normalise(word):
        # Shared preprocessing of predictions and targets.
        return extract_prefix(lemmatize_czech(word)).replace("_","").capitalize()

    silence = AudioSegment.silent(500)
    hits = [0] * max_n      # hits[k]: chunks whose target is within the top k + 1
    evaluated_chunks = 0    # only chunks that were actually decoded

    start_time = time.time()
    for r, row in enumerate(tqdm(row_list)):
        if int(row[3]) != 0:
            continue
        transcription = row[0].split()
        row_chunks = all_extracted_segments[r]
        for ch, (chunk, transcription_word) in enumerate(zip(row_chunks, transcription)):
            # Audio of all words of the row seen so far, current one included.
            merged_segment = AudioSegment.empty()
            for c, earlier_chunk in enumerate(row_chunks[:ch + 1]):
                if c: merged_segment += silence
                merged_segment += earlier_chunk

            # Get predictions from the strategy under evaluation
            decoded = strategy(chunk, merged_segment)
            evaluated_chunks += 1

            # Deterministically preprocess predictions and the target
            processed_words = [normalise(decoded_word.replace(" ","")) for decoded_word,_ in decoded][:max_n]
            target = normalise(transcription_word)

            # A hit at rank k counts for every N >= k + 1.
            if target in processed_words:
                for k in range(processed_words.index(target), max_n):
                    hits[k] += 1
    elapsed = time.time() - start_time

    if not evaluated_chunks:
        return 0.0, [0.0] * max_n
    second_per_sample = elapsed / evaluated_chunks
    recall = [hit_count / evaluated_chunks for hit_count in hits]
    return second_per_sample, recall

# Transcription strategies
# The evaluation scores the top-1 ... top-10 predictions, so ten hypotheses are
# requested here: decode_beam returns n_results + 1 hypotheses, hence n_results=9.
unguided_transcribe = lambda chunk, merged_segment: decoder.decode_beam(save_and_load_audiosegment(chunk), beam_size=15, n_results=9, guided=False)
guided_transcribe = lambda chunk, merged_segment: decoder.decode_beam(save_and_load_audiosegment(chunk), beam_size=15, n_results=9, guided=True)
guided_and_merging_transcribe = lambda chunk, merged_segment: decoder.decode_beam(save_and_load_audiosegment(merged_segment), beam_size=15, n_results=9, guided=True)
# OPTIONAL: Call/design your own transcription strategies


#TODO: replace the undefined variable unselected_strategy by the strategy of your choice
second_per_sample, recall = evaluate_recall(unselected_strategy)

# Show results
# `recall` is a flat list of numbers; entry k holds the top-(k + 1) recall.
for r, recall_item in enumerate(recall, start=1):
    print(f"Top-{r} Recall:", recall_item)
print("Seconds per sample:",second_per_sample)