In [1]:
# Necessary installations
# !pip install transformers[torch]
# !pip install datasets
# !pip install evaluate
# !pip install jiwer
# !pip install -U openai-whisper

## Dataset preparation

In [2]:
# Download dataset and unzip
## !wget https://www.cse.iitb.ac.in/~pjyothi/cs753/dataset.zip
## !unzip dataset.zip

In [3]:
import datasets
from datasets import load_dataset

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union, Optional,Tuple

from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import WhisperForConditionalGeneration

In [4]:
# Load json files for test split
# Only the "test" split is loaded. Each JSON row is expected to provide "id",
# "transcription" and a relative "audio" path (see the feature cast below).
# NOTE(review): `CodeSwitched_Data/` is presumably the folder extracted from
# dataset.zip in the download cell above — confirm.
data_files = {
    "test": "./CodeSwitched_Data/test.json"
}
dataset = load_dataset("json", data_files=data_files)

# Update the audio paths to include appropriate folder-name
def prepend_folder_name(row):
    """Prefix the row's relative ``audio`` path with the dataset folder; mutates and returns ``row``."""
    folder = 'CodeSwitched_Data/'
    relative_path = row["audio"]
    row["audio"] = folder + relative_path
    return row
# Rewrite the audio path of every split so it is relative to the notebook's directory.
for key in dataset:
    dataset[key] = dataset[key].map(prepend_folder_name)

# Cast columns to appropriate features.
# With `datasets.Audio(sampling_rate=16000)`, reading a row yields the waveform
# (resampled to 16 kHz) under `row["audio"]["array"]`, which is what the
# `inferencer` functions below consume.
features = datasets.Features(
    {
        "id": datasets.Value("string"),
        "transcription": datasets.Value("string"),
        "audio": datasets.Audio(sampling_rate=16000),
    }
)
dataset = dataset.map(features.encode_example, features=features)

In [5]:
# Load necessary processors
# Only `tokenizer` is used later in this notebook (by `prepare_dataset`);
# `feature_extractor` and `processor` are loaded but never referenced again.
# language="Hindi", task="transcribe" configure the tokenizer/processor for Hindi transcription.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def prepare_dataset(batch):
    """
    Add a ``labels`` field holding the Whisper tokenizer's token ids for the row's transcription.

    The audio column is deliberately not touched. Resampling to 16 kHz is done by the
    ``datasets.Audio(sampling_rate=16000)`` feature cast earlier, and reading
    ``batch["audio"]`` here would only trigger a needless decode.
    """
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch
# NOTE(review): `labels` is not used by any later cell in this notebook.
dataset = dataset.map(prepare_dataset, num_proc=2)

## Task 2.1: Zero shot whisper greedy decoding

> In this task, you will run a standard greedy decoding of the test utterances using the pretrained Whisper small model.



In [7]:
import json
import numpy as np
from evaluate import load
import whisper
import tqdm

# Hugging Face `evaluate` metrics. Every `inferencer` defined in this notebook reads
# these module-level `cer` / `wer` objects to score its predictions.
wer = load("wer")
cer = load("cer")

def inferencer(json_file_path, model, dataset, **transcribe_options):
    """
    Transcribe every example in `dataset` with `model.transcribe` and save the results.

    Args:
        json_file_path: output path. The file holds a JSON list. Its first element is
            ``{"cer": ..., "wer": ...}`` (corpus-level scores); it is followed by one
            ``{"id", "ground_truth", "prediction"}`` record per utterance.
        model: a Whisper model exposing ``transcribe``.
        dataset: iterable of rows with ``id``, ``transcription`` and a decoded
            ``audio`` whose waveform is ``audio["array"]``.
        **transcribe_options: extra keyword arguments forwarded to ``model.transcribe``
            (e.g. ``beam_size=3``). ``language`` defaults to ``"hi"``.
    """
    options = {"language": "hi", **transcribe_options}

    predictions = []
    ground_truths = []
    records = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        audio = torch.tensor(item['audio']['array'].astype(np.float32))
        prediction = model.transcribe(audio=audio, **options)['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

        records.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    # Metrics header first, then the per-utterance records.
    to_json = [{'cer': cer_, 'wer': wer_}] + records

    # ensure_ascii=False keeps non-ASCII (e.g. Hindi) text readable in the file.
    # The explicit encoding makes that safe regardless of the platform default.
    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json.dump(to_json, json_file, indent=4, ensure_ascii=False)

In [8]:
# Task 2.1 baseline: the pretrained Whisper "small" checkpoint, transcribed with
# `model.transcribe(language="hi")` through the `inferencer` defined in the cell above.
model = whisper.load_model("small")
json_file_path = "./zero_shot_whisper.json"
inferencer(json_file_path, model, dataset['test'])

100%|██████████| 61/61 [06:02<00:00,  5.94s/it]


In [9]:
!head -n 5 ./zero_shot_whisper.json

[
    {
        "cer": 0.6903738439494594,
        "wer": 0.88765359859567
    },


In [10]:
## !wget https://www.cse.iitb.ac.in/~pjyothi/cs753/whisper-small-finetuned.pt

In [11]:
# Same evaluation as the zero-shot run, but with the fine-tuned checkpoint loaded
# from a local .pt file (fetched by the commented-out wget cell above).
model = whisper.load_model("./whisper-small-finetuned.pt")
json_file_path = "./zero_shot_whisper_fine_tuned.json"
inferencer(json_file_path, model, dataset['test'])

100%|██████████| 61/61 [05:13<00:00,  5.15s/it]


In [12]:
!head -n 5 ./zero_shot_whisper_fine_tuned.json

[
    {
        "cer": 0.4706265468281881,
        "wer": 0.6705675833820948
    },


## Task 2.2: Constrained Filtering-based Greedy Decoding

In [13]:
## !wget https://www.cse.iitb.ac.in/~pjyothi/cs753/whisper-small-finetuned.pt

In [14]:
from whisper.decoding import DecodingOptions, DecodingResult, DecodingTask, GreedyDecoder, TokenDecoder, Inference
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import Tokenizer, get_tokenizer
from whisper.utils import compression_ratio

if TYPE_CHECKING:
    from whisper.model import Whisper

In [31]:
"""
TODO: Define helper functions to perform constrained filtering here, if needed.
"""

def sample_batch(logits: torch.Tensor, N: int = 10, p: float = 0.9,
                 generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """
    Sample one next token per batch row using (N, p) constrained filtering.

    For every row independently:
      1. keep the ``N`` highest-scoring tokens;
      2. renormalise their scores into a distribution;
      3. drop each candidate whose cumulative probability (in descending order)
         reaches ``p``, always keeping the single best token so the row is never empty;
      4. sample one token from what remains.

    Args:
        logits: (batch, vocab) next-token logits.
        N: maximum number of candidate tokens per row (clipped to the vocab size).
        p: cumulative-probability threshold applied within the top-``N`` candidates.
        generator: optional ``torch.Generator`` for reproducible sampling.

    Returns:
        (batch,) tensor of sampled token ids.
    """
    k = min(N, logits.shape[-1])
    # Work in float32: exp() of fp16 logits overflows easily.
    top_values, top_indices = torch.topk(logits.float(), k, dim=-1)
    # Per-row, numerically stable renormalisation within the top-k candidates.
    top_probs = F.softmax(top_values, dim=-1)
    keep = torch.cumsum(top_probs, dim=-1) < p
    # Always keep the most likely token of every row (avoids an all -inf row -> NaN).
    keep[..., 0] = True
    filtered = top_values.masked_fill(~keep, float('-inf'))
    probabilities = F.softmax(filtered, dim=-1)
    choice = torch.multinomial(probabilities, 1, generator=generator)  # (batch, 1)
    return torch.gather(top_indices, -1, choice).squeeze(-1)

In [32]:
import os
import traceback
import warnings

from whisper.audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.utils import (
    exact_div,
    format_timestamp,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)

In [33]:
"""
Custom classes to integrate constrained filtering within greedy decoding.
"""

class CustomGreedyDecoder(GreedyDecoder):
    """
    Whisper's `GreedyDecoder`, modified so each next token is *sampled* from the
    (N, p)-filtered distribution produced by `sample_batch` rather than taken by argmax.
    """
    def __init__(self, temperature: float, eot: int):
        super().__init__(temperature, eot)

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """
        Append one constrained-filtering token to every sequence in `tokens`.

        Args:
            tokens: 2-D tensor of token ids decoded so far, one row per sequence.
            logits: next-token logits, one row per sequence in `tokens`.
            sum_logprobs: running log-probability per sequence; updated in place.

        Returns:
            (tokens extended by one column, whether every sequence's last token is EOT)
        """
        next_tokens = sample_batch(logits)
        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        # Accumulate scores only for sequences that have not finished yet, as whisper's
        # GreedyDecoder does, so the reported log-probabilities reflect the sampled tokens.
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        # Sequences that already ended keep emitting EOT (padding).
        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

class CustomDecodingTask(DecodingTask):
    """
    Whisper `DecodingTask` whose token decoder is replaced by `CustomGreedyDecoder`,
    so every decoding step uses constrained filtering.
    """
    def __init__(self, model: "Whisper", options: DecodingOptions):
        super().__init__(model, options)
        # Swap out the decoder built by the base class for the constrained one.
        decoding_temperature = self.options.temperature
        end_of_text = self.tokenizer.eot
        self.decoder = CustomGreedyDecoder(decoding_temperature, end_of_text)
@torch.no_grad()
def custom_decode(
    model: torch.nn.Module,
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),

    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Decode log-Mel segment(s) with `CustomDecodingTask` (constrained-filtering greedy decoding).

    Args:
        model: the Whisper model to decode with.
        mel: a single 2-D segment or a 3-D batch of segments.
        options: decoding options; any ``**kwargs`` override individual fields.

    Returns:
        A single ``DecodingResult`` for 2-D input, otherwise a list with one result per segment.

    NOTE: `CustomDecodingTask` is looked up when this function runs, so it resolves to
    whichever `CustomDecodingTask` definition was executed last in the notebook.
    """
    # Must be a bool: `mel.ndim` itself (2 or 3) is always truthy, which would make
    # batched input return only its first result.
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = CustomDecodingTask(model, options).run(mel)

    return result[0] if single else result

def custom_transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    condition_on_previous_text: bool = True,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,，!！?？:：”)]}、",
    **decode_options,
):
    """
    Transcribe `audio` window by window, decoding each window with `custom_decode`.

    This is a trimmed-down copy of `whisper.transcribe`. The audio is converted to a
    log-Mel spectrogram and processed in `N_FRAMES`-sized windows. Each window's tokens
    are split into segments at timestamp tokens. Compared with whisper's version there
    is no language detection, and there is no temperature fallback: each window is
    decoded exactly once at temperature 0.0.

    `custom_decode` is looked up at call time. This function therefore uses whichever
    `custom_decode` / `CustomDecodingTask` definition was executed most recently in the
    notebook (constrained-filtering greedy or constrained beam search).

    Args:
        model: the Whisper model.
        audio: waveform (tensor/array) or anything `log_mel_spectrogram` accepts.
        compression_ratio_threshold, logprob_threshold, prepend_punctuations,
        append_punctuations: accepted for signature compatibility; unused here.
        condition_on_previous_text: if True, previously decoded tokens are passed as the
            prompt for the next window.
        **decode_options: forwarded to `DecodingOptions`. It must include ``language``;
            a missing key raises KeyError. ``best_of`` is dropped.

    Returns:
        dict with ``text`` (all kept tokens decoded), ``segments`` and ``language``.
    """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    # Number of frames of real (unpadded) audio.
    content_frames = mel.shape[-1] - N_FRAMES

    # No language detection: the caller must supply `language`.
    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=language,
        task=task,
    )

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        # Despite the name there is no fallback: a single decode at temperature 0.0.
        # Reads `decode_options` from the enclosing scope, including the "prompt"
        # entry that the main loop updates before every call.
        decode_result = None

        temp = 0.0
        kwargs = {**decode_options}
        kwargs.pop("best_of", None)

        options = DecodingOptions(**kwargs, temperature=temp)
        decode_result = custom_decode(model, segment, options)

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []

    prompt_reset_since = 0
    # Unused leftover from whisper.transcribe (no initial_prompt support here).
    initial_prompt_tokens = []

    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        # Build a segment dict; only tokens below `eot` (i.e. text tokens) go into "text".
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    # Unused leftover from whisper.transcribe (word timestamps are not supported here).
    last_speech_timestamp = 0.0
    while seek < content_frames:
        time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
        mel_segment = mel[:, seek : seek + N_FRAMES]
        segment_size = min(N_FRAMES, content_frames - seek)
        segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
        mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

        # Condition this window on the tokens decoded since the last prompt reset.
        decode_options["prompt"] = all_tokens[prompt_reset_since:]
        result: DecodingResult = decode_with_fallback(mel_segment)
        tokens = torch.tensor(result.tokens)

        # Unused leftover from whisper.transcribe.
        previous_seek = seek
        current_segments = []

        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
        single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
        consecutive.add_(1)
        if len(consecutive) > 0:
            # if the output contains two consecutive timestamp tokens
            slices = consecutive.tolist()
            if single_timestamp_ending:
                slices.append(len(tokens))

            last_slice = 0
            for current_slice in slices:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_pos = (
                    sliced_tokens[0].item() - tokenizer.timestamp_begin
                )
                end_timestamp_pos = (
                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
                )
                current_segments.append(
                    new_segment(
                        start=time_offset + start_timestamp_pos * time_precision,
                        end=time_offset + end_timestamp_pos * time_precision,
                        tokens=sliced_tokens,
                        result=result,
                    )
                )
                last_slice = current_slice

            if single_timestamp_ending:
                # single timestamp at the end means no speech after the last timestamp.
                seek += segment_size
            else:
                # otherwise, ignore the unfinished segment and seek to the last timestamp
                # NOTE(review): if that timestamp is <|0.00|>, `seek` does not advance and
                # the same window is decoded again (possible infinite loop) — confirm.
                last_timestamp_pos = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_pos * input_stride
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if (
                len(timestamps) > 0
                and timestamps[-1].item() != tokenizer.timestamp_begin
            ):
                # no consecutive timestamps but it has a timestamp; use the last one.
                last_timestamp_pos = (
                    timestamps[-1].item() - tokenizer.timestamp_begin
                )
                duration = last_timestamp_pos * time_precision

            current_segments.append(
                new_segment(
                    start=time_offset,
                    end=time_offset + duration,
                    tokens=tokens,
                    result=result,
                )
            )
            seek += segment_size

        # if a segment is instantaneous or does not contain text, clear it
        for i, segment in enumerate(current_segments):
            if segment["start"] == segment["end"] or segment["text"].strip() == "":
                segment["text"] = ""
                segment["tokens"] = []
                segment["words"] = []

        all_segments.extend(
            [
                {"id": i, **segment}
                for i, segment in enumerate(
                    current_segments, start=len(all_segments)
                )
            ]
        )
        all_tokens.extend(
            [token for segment in current_segments for token in segment["tokens"]]
        )

        # Decoding always uses temperature 0.0 here, so in practice only
        # `condition_on_previous_text=False` resets the prompt.
        if not condition_on_previous_text or result.temperature > 0.5:
              # do not feed the prompt tokens if a high temperature was used
              prompt_reset_since = len(all_tokens)

    return dict(
        text=tokenizer.decode(all_tokens),
        segments=all_segments,
        language=language,
    )

In [34]:
import re

# Function to ignore tags in transcripts
def remove_tags(input_string):
    """Return `input_string` with every non-empty ``<...>`` tag removed; text between tags is kept."""
    return re.compile(r'<[^>]+>').sub('', input_string)

def inferencer(json_file_path, model, dataset, **decode_options):
    """
    Transcribe `dataset` with `custom_transcribe` and save the results as JSON.

    Metrics are computed after stripping ``<...>`` tags from both reference and prediction
    (`remove_tags`). The per-utterance JSON records keep the raw, unstripped text.
    NOTE(review): the other `inferencer` variants in this notebook score the raw text,
    so these CER/WER values may not be directly comparable if references contain tags.

    Args:
        json_file_path: output path. The JSON list starts with ``{"cer", "wer"}``,
            followed by one ``{"id", "ground_truth", "prediction"}`` record per utterance.
        model: the Whisper model passed to `custom_transcribe`.
        dataset: iterable of rows with ``id``, ``transcription`` and decoded ``audio``.
        **decode_options: extra keyword arguments for `custom_transcribe`.
            ``language`` defaults to ``"hi"``.
    """
    options = {"language": "hi", **decode_options}

    predictions = []
    ground_truths = []
    records = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        audio = torch.tensor(item['audio']['array'].astype(np.float32))
        prediction = custom_transcribe(model=model, audio=audio, **options)['text']

        ground_truths.append(remove_tags(ground_truth))
        predictions.append(remove_tags(prediction))

        records.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    # Metrics header first, then the per-utterance records.
    to_json = [{'cer': cer_, 'wer': wer_}] + records

    # ensure_ascii=False keeps non-ASCII (e.g. Hindi) text readable in the file.
    # The explicit encoding makes that safe regardless of the platform default.
    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json.dump(to_json, json_file, indent=4, ensure_ascii=False)

In [35]:
# Task 2.2: fine-tuned model + constrained-filtering greedy decoding.
# `inferencer` here is the `custom_transcribe`-based version from the previous cell.
# NOTE: `custom_transcribe` resolves `custom_decode` / `CustomDecodingTask` at call time,
# so this cell must run before the Task 2.4 cell redefines them.
model = whisper.load_model("./whisper-small-finetuned.pt")
json_file_path = "./fine_tuned_whisper_with_constrained_filtering.json"
inferencer(json_file_path, model, dataset['test'])

100%|██████████| 61/61 [02:50<00:00,  2.80s/it]


In [36]:
!head -n 5 ./fine_tuned_whisper_with_constrained_filtering.json

[
    {
        "cer": 0.5576397030089879,
        "wer": 0.7653598595669983
    },



## Task 2.3: Beam search Decoding using finetuned Whisper model


In [21]:
def inferencer(json_file_path, model, dataset, beam_size, **transcribe_options):
    """
    Transcribe `dataset` with whisper's built-in beam search and save the results as JSON.

    Args:
        json_file_path: output path. The JSON list starts with ``{"cer", "wer"}``,
            followed by one ``{"id", "ground_truth", "prediction"}`` record per utterance.
        model: a Whisper model exposing ``transcribe``.
        dataset: iterable of rows with ``id``, ``transcription`` and decoded ``audio``.
        beam_size: beam width forwarded to ``model.transcribe``.
        **transcribe_options: further keyword arguments for ``model.transcribe``.
            ``language`` defaults to ``"hi"``.
    """
    options = {"language": "hi", "beam_size": beam_size, **transcribe_options}

    predictions = []
    ground_truths = []
    records = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        audio = torch.tensor(item['audio']['array'].astype(np.float32))
        prediction = model.transcribe(audio=audio, **options)['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

        records.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    # Metrics header first, then the per-utterance records.
    to_json = [{'cer': cer_, 'wer': wer_}] + records

    # ensure_ascii=False keeps non-ASCII (e.g. Hindi) text readable in the file.
    # The explicit encoding makes that safe regardless of the platform default.
    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json.dump(to_json, json_file, indent=4, ensure_ascii=False)

In [22]:
# Task 2.3: fine-tuned model + whisper's built-in (unconstrained) beam search, beam size 3,
# via the `model.transcribe`-based `inferencer` defined in the cell above.
model = whisper.load_model("./whisper-small-finetuned.pt")
json_file_path = "./fine_tuned_whisper_with_beam_search.json"
inferencer(json_file_path, model, dataset['test'], beam_size=3)

100%|██████████| 61/61 [06:54<00:00,  6.80s/it]


In [23]:
!head -n 5 ./fine_tuned_whisper_with_beam_search.json

[
    {
        "cer": 0.47596717467760846,
        "wer": 0.6933879461673493
    },


## Task 2.4: Constrained Beam Search Decoding on finetuned Whisper model



In [24]:
from whisper.decoding import BeamSearchDecoder

"""
Custom classes to integrate constraints within beam search decoding.
"""
class CustomBeamSearchDecoder(BeamSearchDecoder):
    """
    Whisper's `BeamSearchDecoder` extended with an "English token" constraint.

    The beams are split into two groups:
      * ``ucsize`` unconstrained beams, extended with the overall top tokens;
      * ``csize`` constrained beams, which must pick an English token (see `isengtoken`)
        as long as their prefix does not contain one yet.
    Only finished hypotheses that contain at least one English token are recorded
    as finished candidates.
    """

    # vocab size -> CPU bool mask marking "English" token ids. The cache is shared by
    # all instances because `custom_decode` builds a new task/decoder for every call;
    # the full-vocabulary scan therefore runs once instead of once per instance.
    # Assumes one tokenizer per vocabulary size.
    _ENG_MASK_CACHE: Dict[int, Tensor] = {}

    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        cutoff: int = 3,
        patience: Optional[float] = None,
    ):
        """
        Args:
            beam_size: total number of beams (constrained + unconstrained).
            eot: end-of-transcript token id.
            inference: whisper inference wrapper. Its ``model`` attribute, if present,
                is used to build the tokenizer.
            cutoff: unused; kept for backward compatibility.
            patience: forwarded to `BeamSearchDecoder`.
        """
        super().__init__(beam_size, eot, inference, patience)

        self.csize = self.beam_size // 2
        self.ucsize = self.beam_size - self.csize

        # Prefer the model owned by the inference wrapper over the notebook's global
        # `model`. Fall back to the global only if the wrapper does not expose one.
        whisper_model = getattr(inference, "model", None)
        if whisper_model is None:
            whisper_model = model
        self.tokenizer1 = get_tokenizer(
            whisper_model.is_multilingual,
            num_languages=whisper_model.num_languages,
            language="hi",
            task="transcribe"
        )
        # Memo: tuple of token ids -> whether any of them is an English token.
        self.hasEngMem = {}
        # Bool mask over the vocabulary, built lazily by `_english_mask`.
        self.engMask = None

    def isengtoken(self, token, ret_print=False):
        """
        Return True if `token` decodes to text containing an ASCII Latin letter.

        Special tokens (decoded text starting with "<|") never count as English.
        `ret_print` is unused and kept only for call compatibility.
        """
        text = self.tokenizer1.decode([int(token)])
        if text.startswith("<|"):
            return False
        return any(('a' <= ch <= 'z') or ('A' <= ch <= 'Z') for ch in text)

    def hasEng(self, tokens, ret_print=False):
        """
        Return True if any token in `tokens` is an English token (memoised).

        `tokens` may be a tuple/list of ints or a 1-D tensor. It is normalised to a
        tuple of ints so the memo gets real hits. The answer for ``prefix + (t,)`` is
        then derived from the memoised ``prefix`` with a single extra decode.
        `ret_print` is unused and kept only for call compatibility.
        """
        if isinstance(tokens, Tensor):
            tokens = tokens.tolist()
        key = tuple(tokens)
        if key in self.hasEngMem:
            return self.hasEngMem[key]
        if key and key[:-1] in self.hasEngMem:
            result = self.hasEngMem[key[:-1]] or self.isengtoken(key[-1])
        else:
            result = any(self.isengtoken(token) for token in key)
        self.hasEngMem[key] = result
        return result

    def getEng(self, tokens, ret_print = False):
        """Debug helper: decode each token separately and join the pieces with '_'."""
        text = []
        for token in tokens:
            text.append(self.tokenizer1.decode([token]))
        return '_'.join(text)

    def _english_mask(self, n_vocab: int, device) -> Tensor:
        """Return the bool mask of English token ids for a vocabulary of size `n_vocab`, on `device`."""
        if (
            self.engMask is None
            or self.engMask.shape[0] != n_vocab
            or self.engMask.device != device
        ):
            cpu_mask = self._ENG_MASK_CACHE.get(n_vocab)
            if cpu_mask is None:
                # Built on CPU in one shot instead of one device write per token.
                cpu_mask = torch.tensor(
                    [self.isengtoken(i) for i in range(n_vocab)], dtype=torch.bool
                )
                self._ENG_MASK_CACHE[n_vocab] = cpu_mask
            self.engMask = cpu_mask.to(device)
        return self.engMask

    def topkeng(self, logProbs: Tensor, k : int, prefix):
        """
        Return the `k` highest log-probabilities among English tokens, and their token ids.

        Args:
            logProbs: 1-D log-probabilities over the vocabulary for one beam.
            k: number of English tokens to return.
            prefix: unused; kept for call compatibility.

        Returns:
            (values, token_ids), sorted by descending log-probability.
        """
        mask = self._english_mask(logProbs.shape[0], logProbs.device)
        eng_ids = torch.nonzero(mask, as_tuple=True)[0]
        values, positions = torch.topk(logProbs[eng_ids], k, largest=True)
        return values, eng_ids[positions]

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """
        Extend every beam by one token using constrained beam search.

        For each audio in the batch:
          1. Each beam proposes ``beam_size + 1`` candidates.
             - If its prefix has no English token yet, its overall top tokens go to the
               unconstrained pool and its top *English* tokens go to the constrained pool.
             - If its prefix already contains English, its top tokens go to both pools.
          2. The best ``ucsize`` unfinished unconstrained candidates and the best ``csize``
             unfinished constrained candidates become the new beams. Candidates ending in
             EOT are recorded as finished only if they contain an English token.

        Adapted from whisper's BeamSearchDecoder.update (whisper/decoding.py).

        Returns:
            (new tokens, whether every audio has ``max_candidates`` finished hypotheses)
        """
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []

        for i in range(n_audio):
            finished = {}
            c_scores, c_sources = {}, {}
            uc_scores, uc_sources = {}, {}

            # STEP 1: collect candidate extensions for the two pools
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                if not self.hasEng(prefix):
                    for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                        new_logprob = (sum_logprobs[idx] + logprob).item()
                        sequence = tuple(prefix + [token.item()])
                        uc_scores[sequence] = new_logprob
                        uc_sources[sequence] = idx
                    for logprob, token in zip(*self.topkeng(logprobs[idx], self.beam_size + 1, prefix)):
                        new_logprob = (sum_logprobs[idx] + logprob).item()
                        sequence = tuple(prefix + [token.item()])
                        c_scores[sequence] = new_logprob
                        c_sources[sequence] = idx
                else:
                    # Constraint already satisfied: both pools get the same candidates.
                    # NOTE(review): the same sequence can then be selected both as a
                    # constrained and as an unconstrained beam (duplicate beams).
                    for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                        new_logprob = (sum_logprobs[idx] + logprob).item()
                        sequence = tuple(prefix + [token.item()])
                        c_scores[sequence] = new_logprob
                        c_sources[sequence] = idx
                        uc_scores[sequence] = new_logprob
                        uc_sources[sequence] = idx

            # STEP 2: keep the top `ucsize` unconstrained and `csize` constrained beams.
            # NOTE(review): writing sum_logprobs[len(next_tokens)] assumes each pool yields
            # its full quota of unfinished candidates, as whisper's decoder does.
            saved_uc = 0
            for sequence in sorted(uc_scores, key=uc_scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    # Finished hypotheses only count if they satisfy the constraint.
                    if self.hasEng(sequence):
                        finished[sequence] = uc_scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = uc_scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(uc_sources[sequence])
                    saved_uc += 1
                    if saved_uc == self.ucsize:
                        break

            saved_c = 0
            for sequence in sorted(c_scores, key=c_scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    if not self.hasEng(sequence):
                        raise RuntimeError("Constrained sequence without english tokens is not possible. Something went wrong")
                    finished[sequence] = c_scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = c_scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(c_sources[sequence])
                    saved_c += 1
                    if saved_c == self.csize:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # Mark as completed once every audio has enough finished hypotheses. Every
        # stored hypothesis contains English by construction; the memoised `hasEng`
        # check is a cheap safety net.
        completed = all(
            len(sequences) >= self.max_candidates and all(self.hasEng(sequence) for sequence in sequences)
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens, sum_logprobs):
        """
        Collect the finished hypotheses per audio; pad with unfinished ones if too few finished.

        Same logic as whisper's BeamSearchDecoder.finalize.
        NOTE(review): unfinished sequences used as padding are not required to contain
        English tokens.
        """
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs

class CustomDecodingTask(DecodingTask):
    """
    Whisper `DecodingTask` that runs constrained beam search via `CustomBeamSearchDecoder`.

    The requested ``beam_size`` is doubled. The decoder splits its beams into a half that
    must contain English tokens and a half that is unconstrained.

    Raises:
        ValueError: if ``options.beam_size`` is not set (constrained beam search needs it).
    """
    def __init__(self, model: "Whisper", options: DecodingOptions):
        if options.beam_size is None:
            raise ValueError("constrained beam search requires options.beam_size to be set")
        options = replace(options, beam_size=2 * options.beam_size)
        super().__init__(model, options)
        self.decoder = CustomBeamSearchDecoder(
            self.options.beam_size,
            self.tokenizer.eot,
            self.inference,
            # Must be passed by keyword: the 4th positional parameter is `cutoff`, so
            # passing it positionally silently dropped the patience setting.
            patience=self.options.patience,
        )

@torch.no_grad()
def custom_decode(
    model: torch.nn.Module,
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Decode log-Mel segment(s) with `CustomDecodingTask` (constrained beam search).

    Args:
        model: the Whisper model to decode with.
        mel: a single 2-D segment or a 3-D batch of segments.
        options: decoding options (``beam_size`` must be set); ``**kwargs`` override fields.

    Returns:
        A single ``DecodingResult`` for 2-D input, otherwise a list with one result per segment.
    """
    # Must be a bool: `mel.ndim` itself (2 or 3) is always truthy, which would make
    # batched input return only its first result.
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = CustomDecodingTask(model, options).run(mel)

    return result[0] if single else result


In [25]:
def inferencer(json_file_path, model, dataset, beam_size, **decode_options):
    """
    Transcribe `dataset` with `custom_transcribe` using constrained beam search; save results as JSON.

    Args:
        json_file_path: output path. The JSON list starts with ``{"cer", "wer"}``,
            followed by one ``{"id", "ground_truth", "prediction"}`` record per utterance.
        model: the Whisper model passed to `custom_transcribe`.
        dataset: iterable of rows with ``id``, ``transcription`` and decoded ``audio``.
        beam_size: requested beam width, forwarded to `custom_transcribe`.
        **decode_options: further keyword arguments for `custom_transcribe`.
            ``language`` defaults to ``"hi"``.
    """
    options = {"language": "hi", "beam_size": beam_size, **decode_options}

    predictions = []
    ground_truths = []
    records = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        audio = torch.tensor(item['audio']['array'].astype(np.float32))
        prediction = custom_transcribe(model=model, audio=audio, **options)['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

        records.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    # Metrics header first, then the per-utterance records.
    to_json = [{'cer': cer_, 'wer': wer_}] + records

    # ensure_ascii=False keeps non-ASCII (e.g. Hindi) text readable in the file.
    # The explicit encoding makes that safe regardless of the platform default.
    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json.dump(to_json, json_file, indent=4, ensure_ascii=False)

In [26]:
# Task 2.4: fine-tuned model + constrained beam search. `CustomDecodingTask` doubles
# beam_size=3 to 6 beams, split between English-constrained and unconstrained halves.
model = whisper.load_model("./whisper-small-finetuned.pt")
json_file_path = "./fine_tuned_whisper_with_constrained_beam_search.json"
inferencer(json_file_path, model, dataset['test'], beam_size=3)

100%|██████████| 61/61 [29:56<00:00, 29.45s/it] 


In [27]:
!head -n 5 ./fine_tuned_whisper_with_constrained_beam_search.json

[
    {
        "cer": 0.6510355607659242,
        "wer": 0.8308952603861908
    },
