In [1]:
# Necessary installations
!pip install transformers[torch]
!pip install datasets
!pip install evaluate
!pip install jiwer
!pip install -U openai-whisper

Collecting openai-whisper
  Using cached openai_whisper-20231117-py3-none-any.whl
  Downloading openai-whisper-20231106.tar.gz (798 kB)
[K     |████████████████████████████████| 798 kB 16.7 MB/s eta 0:00:01
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone


## Dataset preparation

In [2]:
# Download dataset and unzip
!wget https://www.cse.iitb.ac.in/~pjyothi/cs753/dataset.zip -O dataset.zip
!unzip dataset.zip -d dataset

--2024-04-02 17:37:50--  https://www.cse.iitb.ac.in/~pjyothi/cs753/dataset.zip
Resolving www.cse.iitb.ac.in (www.cse.iitb.ac.in)... 10.129.3.3
Connecting to www.cse.iitb.ac.in (www.cse.iitb.ac.in)|10.129.3.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 139124027 (133M) [application/zip]
Saving to: ‘dataset.zip’


2024-04-02 17:37:54 (30.8 MB/s) - ‘dataset.zip’ saved [139124027/139124027]

Archive:  dataset.zip
   creating: dataset/CodeSwitched_Data/
  inflating: dataset/CodeSwitched_Data/valid.json  
  inflating: dataset/CodeSwitched_Data/test_blind_no_transcript.json  
  inflating: dataset/CodeSwitched_Data/train.json  
  inflating: dataset/CodeSwitched_Data/test_blind.json  
   creating: dataset/CodeSwitched_Data/audio/
  inflating: dataset/CodeSwitched_Data/audio/SHA1P_utt00003320.wav  
  inflating: dataset/CodeSwitched_Data/audio/SHA1P_utt00009337.wav  
  inflating: dataset/CodeSwitched_Data/audio/SHA1P_utt00004786.wav  
  inflating: dataset/CodeSwitche

In [3]:
import datasets
from datasets import load_dataset

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union, Optional,Tuple

from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import WhisperForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm
2024-04-02 23:13:04.402134: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-02 23:13:04.438368: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-02 23:13:04.438403: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-02 23:13:04.438426: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-02 23:13:0

In [4]:
 # Load json files for test split
data_files = {
    "test": "dataset/CodeSwitched_Data/test.json"
}
dataset = load_dataset("json", data_files=data_files)

# Update the audio paths to include appropriate folder-name
def prepend_folder_name(row):
    row["audio"] = 'dataset/CodeSwitched_Data/' + row["audio"]
    return row
for key in dataset:
    dataset[key] = dataset[key].map(prepend_folder_name)

# Cast columns to appropriate features
features = datasets.Features(
    {
        "id": datasets.Value("string"),
        "transcription": datasets.Value("string"),
        "audio": datasets.Audio(sampling_rate=16000),
    }
)
dataset = dataset.map(features.encode_example, features=features)

In [5]:
# Load necessary processors
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]
    # encode target text to label ids
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch
dataset = dataset.map(prepare_dataset, num_proc=2)

In [7]:
import json
import numpy as np
from evaluate import load
import whisper
import tqdm

wer = load("wer")
cer = load("cer")

def inferencer2( model, dataset):


    predictions = []
    ground_truths = []

    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        prediction = model.transcribe(audio = torch.tensor(item['audio']['array'].astype(np.float32)), language="hi")['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)
    return cer_,wer_

In [10]:
!wget https://www.cse.iitb.ac.in/~pjyothi/cs753/whisper-small-finetuned.pt -O whisper-small-finetuned.pt

--2024-04-02 17:39:42--  https://www.cse.iitb.ac.in/~pjyothi/cs753/whisper-small-finetuned.pt
Resolving www.cse.iitb.ac.in (www.cse.iitb.ac.in)... 10.129.3.3
Connecting to www.cse.iitb.ac.in (www.cse.iitb.ac.in)|10.129.3.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 967102564 (922M)
Saving to: ‘whisper-small-finetuned.pt’


2024-04-02 17:40:09 (34.8 MB/s) - ‘whisper-small-finetuned.pt’ saved [967102564/967102564]



In [8]:
torch.cuda.set_device(4) 

In [28]:
finetuned_model = whisper.load_model("whisper-small-finetuned.pt")
finetuned_cer,finetuned_wer=inferencer2(finetuned_model,dataset["test"])

100%|██████████| 61/61 [10:05<00:00,  9.93s/it]


In [29]:
finetuned_cer, finetuned_wer

(0.4630715123094959, 0.6834406085430076)

## Task 2.1: Zero shot whisper greedy decoding

> In this task, you will run a standard greedy decoding of the test utterances using the pretrained Whisper small model.



In [9]:
import json
import numpy as np
from evaluate import load
import whisper
import tqdm

wer = load("wer")
cer = load("cer")

def inferencer(json_file_path, model, dataset):
    """
    Perform inference on every example within `dataset` using `model` and save it in `json_file_path`.
    """

    predictions = []
    ground_truths = []
    to_json = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        prediction = model.transcribe(audio = torch.tensor(item['audio']['array'].astype(np.float32)), language="hi")['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

        to_json.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    to_json = [{'cer':cer_, 'wer': wer_}] + to_json

    # Write updated data back to JSON file
    with open(json_file_path, 'w') as json_file:
        json.dump(to_json, json_file, indent=4)

In [43]:
model = whisper.load_model("small")
json_file_path = "./zero_shot_whisper.json"
inferencer(json_file_path, model, dataset['test'])

100%|███████████████████████████████████████| 461M/461M [00:07<00:00, 68.9MiB/s]
100%|██████████| 61/61 [13:53<00:00, 13.66s/it]


In [44]:
!head -n 5 ./zero_shot_whisper.json

[
    {
        "cer": 0.6453041552689853,
        "wer": 0.8660035108250439
    },


In [38]:
device = torch.device("cuda:4")

## Task 2.2: Constrained Filtering-based Greedy Decoding

In [11]:
from whisper.decoding import DecodingOptions, DecodingResult, DecodingTask, GreedyDecoder, TokenDecoder, Inference
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import Tokenizer, get_tokenizer
from whisper.utils import compression_ratio

if TYPE_CHECKING:
    from whisper.model import Whisper

In [39]:
"""
TODO: Define helper functions to perform constrained filtering here, if needed.
"""
import torch
from typing import Tuple

def sample_batch(logits: torch.Tensor, N: int = 10, p: float = 0.9) -> Tuple[torch.Tensor, torch.Tensor]:
    """
      Sample the next best token for every example within a batch, using constrained filtering.
    """
    # TODO: Implement this function
    # TODO: You can create any helper functions that sample_batch needs within this cell
    soft_logits = F.softmax(logits)
    values, indices = torch.topk(soft_logits.reshape(-1), N)
    
    sum = 0
    for i, value in enumerate(values):
       sum += value
       if (sum >= p):
         values = values[ :i+1]
         indices = indices[ :i+1]

    
    distribution = torch.distributions.Categorical(values)
    
    return torch.tensor([indices[distribution.sample()]]).to(device)


In [23]:
import os
import traceback
import warnings

from whisper.audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.utils import (
    exact_div,
    format_timestamp,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)

In [19]:
"""
Custom classes to integrate constrained filtering within greedy decoding.
"""

class CustomGreedyDecoder(GreedyDecoder):
    """
    Updates the existing `GreedyDecoder` class form whisper to use constrained_filtering
    """
    def __init__(self, temperature: float, eot: int):
        super().__init__(temperature, eot)

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """
        Updates the token list `tokens` to add a next best token using constrained_filtering
        """
        # TODO: sample_batch should be modified to implement constrained_filtering
        # print("Logits       ##############", logits.shape)
        # print("SUM_LOGPROBS ##############", sum_logprobs.shape)
        # print("TOKENS       ##############", tokens.shape)

        next_tokens = sample_batch(logits)
        
        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)  

        # print("Next Tokens  ##############", next_tokens.shape)
        # print("currLogprobs ##############", current_logprobs.shape)
        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

class CustomDecodingTask(DecodingTask):
    """
    Updates the existing `DecodingTask` class form whisper to incorporate constrained_filtering
    """
    def __init__(self, model: "Whisper", options: DecodingOptions):
        super().__init__(model,options)
        self.decoder = CustomGreedyDecoder(
            self.options.temperature, self.tokenizer.eot
        )
@torch.no_grad()
def custom_decode(
    model: torch.nn.Module,
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),

    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    decode function to perform constrained_filtering based greedy decoding
    """
    single = mel.ndim
    if single == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = CustomDecodingTask(model, options).run(mel)

    return result[0] if single else result

def custom_transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    condition_on_previous_text: bool = True,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,，!！?？:：”)]}、",
    **decode_options,
):
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES

    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=language,
        task=task,
    )

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        decode_result = None

        temp = 0.0
        kwargs = {**decode_options}
        kwargs.pop("best_of", None)

        options = DecodingOptions(**kwargs, temperature=temp)
        decode_result = custom_decode(model, segment, options)

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []

    prompt_reset_since = 0
    initial_prompt_tokens = []

    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    last_speech_timestamp = 0.0
    while seek < content_frames:
        time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
        mel_segment = mel[:, seek : seek + N_FRAMES]
        segment_size = min(N_FRAMES, content_frames - seek)
        segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
        mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

        decode_options["prompt"] = all_tokens[prompt_reset_since:]
        result: DecodingResult = decode_with_fallback(mel_segment)
        tokens = torch.tensor(result.tokens)

        previous_seek = seek
        current_segments = []

        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
        single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
        consecutive.add_(1)
        if len(consecutive) > 0:
            # if the output contains two consecutive timestamp tokens
            slices = consecutive.tolist()
            if single_timestamp_ending:
                slices.append(len(tokens))

            last_slice = 0
            for current_slice in slices:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_pos = (
                    sliced_tokens[0].item() - tokenizer.timestamp_begin
                )
                end_timestamp_pos = (
                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
                )
                current_segments.append(
                    new_segment(
                        start=time_offset + start_timestamp_pos * time_precision,
                        end=time_offset + end_timestamp_pos * time_precision,
                        tokens=sliced_tokens,
                        result=result,
                    )
                )
                last_slice = current_slice

            if single_timestamp_ending:
                # single timestamp at the end means no speech after the last timestamp.
                seek += segment_size
            else:
                # otherwise, ignore the unfinished segment and seek to the last timestamp
                last_timestamp_pos = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_pos * input_stride
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if (
                len(timestamps) > 0
                and timestamps[-1].item() != tokenizer.timestamp_begin
            ):
                # no consecutive timestamps but it has a timestamp; use the last one.
                last_timestamp_pos = (
                    timestamps[-1].item() - tokenizer.timestamp_begin
                )
                duration = last_timestamp_pos * time_precision

            current_segments.append(
                new_segment(
                    start=time_offset,
                    end=time_offset + duration,
                    tokens=tokens,
                    result=result,
                )
            )
            seek += segment_size

        # if a segment is instantaneous or does not contain text, clear it
        for i, segment in enumerate(current_segments):
            if segment["start"] == segment["end"] or segment["text"].strip() == "":
                segment["text"] = ""
                segment["tokens"] = []
                segment["words"] = []

        all_segments.extend(
            [
                {"id": i, **segment}
                for i, segment in enumerate(
                    current_segments, start=len(all_segments)
                )
            ]
        )
        all_tokens.extend(
            [token for segment in current_segments for token in segment["tokens"]]
        )

        if not condition_on_previous_text or result.temperature > 0.5:
              # do not feed the prompt tokens if a high temperature was used
              prompt_reset_since = len(all_tokens)

    return dict(
        text=tokenizer.decode(all_tokens),
        segments=all_segments,
        language=language,
    )

In [34]:
import re

# Function to ignore tags in transcripts
def remove_tags(input_string):
    tag_pattern = re.compile(r'<[^>]+>')
    cleaned_string = re.sub(tag_pattern, '', input_string)
    return cleaned_string

def inferencer(json_file_path, model, dataset):

    predictions = []
    ground_truths = []
    to_json = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        prediction = custom_transcribe(model = model, audio = torch.tensor(item['audio']['array'].astype(np.float32)), language="hi")['text']

        ground_truths.append(remove_tags(ground_truth))
        predictions.append(remove_tags(prediction))

        to_json.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    to_json = [{'cer':cer_, 'wer': wer_}] + to_json

    # Write updated data back to JSON file
    with open(json_file_path, 'w') as json_file:
        json.dump(to_json, json_file, indent=4)

In [40]:
model = whisper.load_model("whisper-small-finetuned.pt")
json_file_path = "./fine_tuned_whisper_with_constrained_filtering.json"
inferencer(json_file_path, model, dataset['test'])

  soft_logits = F.softmax(logits)
100%|██████████| 61/61 [07:29<00:00,  7.37s/it]


In [41]:
!head -n 5 ./fine_tuned_whisper_with_constrained_filtering.json

[
    {
        "cer": 0.5599843688940993,
        "wer": 0.7717963721474547
    },



## Task 2.3: Beam search Decoding using finetuned Whisper model


In [30]:
def inferencer(json_file_path, model, dataset, beam_size):
    """
    Perform inference on every example within `dataset` using `model` and save it in `json_file_path`.
    """

    predictions = []
    ground_truths = []
    to_json = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        prediction = model.transcribe(audio = torch.tensor(item['audio']['array'].astype(np.float32)), language="hi", beam_size=beam_size)['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

        to_json.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    to_json = [{'cer':cer_, 'wer': wer_}] + to_json

    # Write updated data back to JSON file
    with open(json_file_path, 'w') as json_file:
        json.dump(to_json, json_file, indent=4)

In [15]:
model = whisper.load_model("whisper-small-finetuned.pt")

In [31]:
json_file_path = "./fine_tuned_whisper_with_beam_search.json"
inferencer(json_file_path, model, dataset['test'], beam_size=3)

100%|██████████| 61/61 [16:34<00:00, 16.30s/it]


In [32]:
!head -n 5 ./fine_tuned_whisper_with_beam_search.json

[
    {
        "cer": 0.49863227823368506,
        "wer": 0.713867758923347
    },


## Task 2.4: Constrained Beam Search Decoding on finetuned Whisper model



In [24]:
from whisper.decoding import BeamSearchDecoder

"""
Custom classes to integrate constraints within beam search decoding.
"""
class CustomBeamSearchDecoder(BeamSearchDecoder):
    """
    Updates the existing `BeamSearchDecoder` class form whisper to use constrained beam search
    """
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        cutoff: int = 3,
        patience: Optional[float] = None,
    ):
        super().__init__(beam_size, eot, inference, patience)
        # TODO: Add any additional variables you need for constrained beam search here

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """
        TODO: This is the main routine that implements constrained beam search
        Updates the token list `tokens` to add a next best token using constrained beam search
        Refer to https://github.com/openai/whisper/blob/main/whisper/decoding.py#L301-L404 for the original update function
        """
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens, sum_logprobs):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        """
        TODO: Add new code (or copy existing finalize code) here to implement constrained beam search
        Refer to https://github.com/openai/whisper/blob/main/whisper/decoding.py#L301-L404 for the original finalize function
        """
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs

class CustomDecodingTask(DecodingTask):
    """
    Updates the existing `DecodingTask` class form whisper to use constrained beam search
    """
    def __init__(self, model: "Whisper", options: DecodingOptions):
        options = replace(options, beam_size = 2 * options.beam_size if options.beam_size is not None else None)
        super().__init__(model, options)
        self.decoder = CustomBeamSearchDecoder(
            self.options.beam_size, self.tokenizer.eot, self.inference, self.options.patience
        )

@torch.no_grad()
def custom_decode(
    model: torch.nn.Module,
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    decode function to perform constrained beam search decoding
    """
    single = mel.ndim
    if single == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = CustomDecodingTask(model, options).run(mel)

    return result[0] if single else result


In [27]:
def inferencer(json_file_path, model, dataset, beam_size):
    """
    Perform inference on every example within `dataset` using `model` and save it in `json_file_path`.
    """

    predictions = []
    ground_truths = []
    to_json = []
    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        prediction = custom_transcribe(model=model, audio = torch.tensor(item['audio']['array'].astype(np.float32)), language="hi", beam_size=beam_size)['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

        to_json.append({
            'id': item['id'],
            'ground_truth': ground_truth,
            'prediction': prediction
        })

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)

    to_json = [{'cer':cer_, 'wer': wer_}] + to_json

    # Write updated data back to JSON file
    with open(json_file_path, 'w') as json_file:
        json.dump(to_json, json_file, indent=4)

In [28]:
json_file_path = "./fine_tuned_whisper_with_constrained_beam_search.json"
inferencer(json_file_path, model, dataset['test'], beam_size=3)

100%|██████████| 61/61 [17:47<00:00, 17.49s/it]


In [29]:
!head -n 5 ./fine_tuned_whisper_with_constrained_beam_search.json

[
    {
        "cer": 0.6959749902305589,
        "wer": 0.874195435927443
    },
