In [1]:
# Necessary installations
!pip install transformers[torch]
!pip install datasets
!pip install evaluate
!pip install jiwer
!pip install -U openai-whisper

Collecting accelerate>=0.21.0 (from transformers[torch])
  Downloading accelerate-0.28.0-py3-none-any.whl (290 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m290.1/290.1 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->transformers[torch])
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->transformers[torch])
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m62.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->transformers[torch])
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
[2K     

## Dataset preparation

In [2]:
# Download dataset and unzip
!wget https://www.cse.iitb.ac.in/~pjyothi/cs753/dataset.zip
!unzip dataset.zip

--2024-03-27 09:07:10--  https://www.cse.iitb.ac.in/~pjyothi/cs753/dataset.zip
Resolving www.cse.iitb.ac.in (www.cse.iitb.ac.in)... 103.21.127.134
Connecting to www.cse.iitb.ac.in (www.cse.iitb.ac.in)|103.21.127.134|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 139124027 (133M) [application/zip]
Saving to: ‘dataset.zip’


2024-03-27 09:07:26 (10.8 MB/s) - ‘dataset.zip’ saved [139124027/139124027]

Archive:  dataset.zip
   creating: CodeSwitched_Data/
  inflating: CodeSwitched_Data/valid.json  
  inflating: CodeSwitched_Data/test_blind_no_transcript.json  
  inflating: CodeSwitched_Data/train.json  
  inflating: CodeSwitched_Data/test_blind.json  
   creating: CodeSwitched_Data/audio/
  inflating: CodeSwitched_Data/audio/SHA1P_utt00003320.wav  
  inflating: CodeSwitched_Data/audio/SHA1P_utt00009337.wav  
  inflating: CodeSwitched_Data/audio/SHA1P_utt00004786.wav  
  inflating: CodeSwitched_Data/audio/SHA1P_utt00008866.wav  
  inflating: CodeSwitched_Data/audi

In [3]:
import datasets
from datasets import load_dataset

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union, Optional,Tuple

from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import WhisperForConditionalGeneration

In [4]:
# Load json files for test split
# Only a "test" entry is registered, so later cells address the data as dataset["test"].
# NOTE(review): confirm that the archive actually ships test.json — the recorded (truncated)
# unzip listing only shows valid.json, train.json and the two test_blind*.json files.
data_files = {
    "test": "./CodeSwitched_Data/test.json"
}
dataset = load_dataset("json", data_files=data_files)

# Update the audio paths to include appropriate folder-name
def prepend_folder_name(row):
    """
    Prefix the example's `audio` path with the extracted dataset folder.

    The row is modified in place and the same object is returned.
    """
    folder_prefix = 'CodeSwitched_Data/'
    relative_path = row["audio"]
    row["audio"] = folder_prefix + relative_path
    return row
# Rewrite the audio paths of every split in the dataset
for split_name in dataset.keys():
    dataset[split_name] = dataset[split_name].map(prepend_folder_name)

# Cast columns to appropriate features
column_schema = {
    "id": datasets.Value("string"),
    "transcription": datasets.Value("string"),
    "audio": datasets.Audio(sampling_rate=16000),
}
features = datasets.Features(column_schema)
dataset = dataset.map(features.encode_example, features=features)

Generating test split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/61 [00:00<?, ? examples/s]

Map:   0%|          | 0/61 [00:00<?, ? examples/s]

In [5]:
# Load necessary processors
# All three objects come from the same pretrained checkpoint; the tokenizer and
# the processor share the same language/task configuration.
WHISPER_CHECKPOINT = "openai/whisper-small"
language_task_kwargs = {"language": "Hindi", "task": "transcribe"}

feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_CHECKPOINT)
tokenizer = WhisperTokenizer.from_pretrained(WHISPER_CHECKPOINT, **language_task_kwargs)
processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT, **language_task_kwargs)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def prepare_dataset(batch):
    """
    Add a `labels` entry holding the token ids of the example's transcription.

    Uses the notebook-level `tokenizer`. The incoming `batch` object is modified
    in place and returned.
    """
    # Look up the audio entry. `audio` is not used again inside this function.
    # NOTE(review): presumably kept because reading the column triggers decoding /
    # resampling to the 16 kHz declared on the Audio feature — confirm before removing.
    audio = batch["audio"]
    # encode target text to label ids
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch
# Apply the preparation to the whole dataset using two worker processes
dataset = dataset.map(prepare_dataset, num_proc=2)

Map (num_proc=2):   0%|          | 0/61 [00:00<?, ? examples/s]

In [10]:
# Imported in this cell so that it runs on a fresh kernel from top to bottom:
# the cell that otherwise provides `load`, `np` and `tqdm` sits further down the notebook.
import numpy as np
import tqdm
from evaluate import load

wer = load("wer")
cer = load("cer")

def inferencer2( model, dataset):
    """
    Transcribe every example of `dataset` with `model` and score the outputs.

    Args:
        model: object exposing `transcribe(audio=..., language=...)` that returns a
            mapping with a 'text' entry (here: a model from `whisper.load_model`).
        dataset: iterable of examples providing 'transcription' and
            'audio' -> 'array' (a numpy array).

    Returns:
        Tuple `(cer, wer)` computed by the notebook-level `cer` / `wer` metrics
        over all predictions against the reference transcriptions.
    """
    predictions = []
    ground_truths = []

    for item in tqdm.tqdm(dataset):
        ground_truth = item['transcription']
        # The array is cast to float32 before being wrapped in a tensor for `transcribe`.
        prediction = model.transcribe(audio = torch.tensor(item['audio']['array'].astype(np.float32)), language="hi")['text']

        ground_truths.append(ground_truth)
        predictions.append(prediction)

    cer_ = cer.compute(predictions=predictions, references=ground_truths)
    wer_ = wer.compute(predictions=predictions, references=ground_truths)
    return cer_,wer_

In [11]:
!wget https://www.cse.iitb.ac.in/~pjyothi/cs753/whisper-small-finetuned.pt


--2024-03-27 09:19:33--  https://www.cse.iitb.ac.in/~pjyothi/cs753/whisper-small-finetuned.pt
Resolving www.cse.iitb.ac.in (www.cse.iitb.ac.in)... 103.21.127.134
Connecting to www.cse.iitb.ac.in (www.cse.iitb.ac.in)|103.21.127.134|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 967102564 (922M)
Saving to: ‘whisper-small-finetuned.pt’


2024-03-27 09:26:56 (2.09 MB/s) - ‘whisper-small-finetuned.pt’ saved [967102564/967102564]



In [12]:
# Mount Google Drive at /content/drive (requires the Colab runtime).
# NOTE(review): no visible cell reads from or writes to /content/drive — confirm the mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [13]:
# List the contents of the working directory
!ls

CodeSwitched_Data  drive	whisper-small-finetuned.pt
dataset.zip	   sample_data	zero_shot_whisper.json


In [14]:
# `whisper` is imported here because the cell that otherwise imports it sits further
# down the notebook; without it this cell raises NameError on a fresh top-to-bottom run.
import whisper

# Load the fine-tuned checkpoint from the working directory and score it on the test split
finetuned_model = whisper.load_model("whisper-small-finetuned.pt")
finetuned_cer, finetuned_wer = inferencer2(finetuned_model, dataset["test"])

100%|██████████| 61/61 [06:51<00:00,  6.74s/it]


In [16]:
# Display the (CER, WER) pair returned by inferencer2 for the fine-tuned model
finetuned_cer, finetuned_wer

(0.43545655855151755, 0.6635459332943242)

## Task 2.1: Zero shot whisper greedy decoding

> In this task, you will run a standard greedy decoding of the test utterances using the pretrained Whisper small model.



In [7]:
import json
import numpy as np
from evaluate import load
import whisper
import tqdm

# Metric objects from `evaluate`; the inferencer functions read them through these global names
wer = load("wer")
cer = load("cer")

def inferencer(json_file_path, model, dataset):
    """
    Transcribe each example of `dataset` with `model.transcribe` and store the results.

    The JSON file written to `json_file_path` is a list whose first element holds the
    corpus-level 'cer' and 'wer', followed by one record per example with its 'id',
    'ground_truth' and 'prediction'.
    """
    records = []
    for example in tqdm.tqdm(dataset):
        reference = example['transcription']
        waveform = torch.tensor(example['audio']['array'].astype(np.float32))
        hypothesis = model.transcribe(audio = waveform, language="hi")['text']

        records.append({
            'id': example['id'],
            'ground_truth': reference,
            'prediction': hypothesis
        })

    # Metrics are computed once, over the whole set of utterances
    hypotheses = [record['prediction'] for record in records]
    references = [record['ground_truth'] for record in records]
    cer_score = cer.compute(predictions=hypotheses, references=references)
    wer_score = wer.compute(predictions=hypotheses, references=references)

    # The summary entry goes first, followed by the per-utterance records
    payload = [{'cer': cer_score, 'wer': wer_score}, *records]

    with open(json_file_path, 'w') as json_file:
        json.dump(payload, json_file, indent=4)

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

In [8]:
# Zero-shot baseline: load whisper's stock "small" checkpoint (not the fine-tuned file)
model = whisper.load_model("small")
json_file_path = "./zero_shot_whisper.json"
# Transcribes the test split and writes metrics plus per-utterance outputs to json_file_path
inferencer(json_file_path, model, dataset['test'])

100%|████████████████████████████████████████| 461M/461M [00:03<00:00, 141MiB/s]
100%|██████████| 61/61 [10:18<00:00, 10.13s/it]


In [9]:
# Show the first five lines of the results file (the cer/wer summary entry comes first)
!head -n 5 ./zero_shot_whisper.json

[
    {
        "cer": 0.6618470756806044,
        "wer": 0.8765359859566998
    },


## Task 2.2: Constrained Filtering-based Greedy Decoding

In [None]:
!wget https://www.cse.iitb.ac.in/~pjyothi/cs753/whisper-small-finetuned.pt

In [None]:
from whisper.decoding import DecodingOptions, DecodingResult, DecodingTask, GreedyDecoder, TokenDecoder, Inference
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import Tokenizer, get_tokenizer
from whisper.utils import compression_ratio

if TYPE_CHECKING:
    from whisper.model import Whisper

In [None]:
"""
TODO: Define helper functions to perform constrained filtering here, if needed.
"""

def sample_batch(logits: torch.Tensor, N: int = 10, p: float = 0.9) -> Tuple[torch.Tensor, torch.Tensor]:
    """
      Sample the next best token for every example within a batch, using constrained filtering.

      Unimplemented template stub: the body holds nothing but this docstring and
      the TODO comments, so a call currently returns None.

      NOTE(review): the annotation promises a 2-tuple of tensors, whereas
      `CustomGreedyDecoder.update` uses the return value directly as a tensor of
      token ids (indexing and in-place assignment) — settle on one contract when implementing.
      NOTE(review): `N` and `p` look like top-N / cumulative-probability limits — confirm
      against the task statement.
    """
    # TODO: Implement this function
    # TODO: You can create any helper functions that sample_batch needs within this cell

In [None]:
import os
import traceback
import warnings

from whisper.audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.utils import (
    exact_div,
    format_timestamp,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)

In [None]:
"""
Custom classes to integrate constrained filtering within greedy decoding.
"""

class CustomGreedyDecoder(GreedyDecoder):
    """
    Variant of whisper's `GreedyDecoder` that delegates the next-token choice to `sample_batch`
    """
    def __init__(self, temperature: float, eot: int):
        super().__init__(temperature, eot)

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """
        Append one token per sequence to `tokens`, chosen by `sample_batch`.

        `sum_logprobs` is updated in place with the log-probability of the chosen
        token for every sequence that had not yet ended. Returns the extended
        token tensor and whether every sequence now ends in the end-of-text token.
        """
        # TODO: sample_batch should be modified to implement constrained_filtering
        chosen = sample_batch(logits)

        # Sequences whose last token is already end-of-text
        already_ended = tokens[:, -1] == self.eot

        log_probs = F.log_softmax(logits.float(), dim=-1)
        row_index = torch.arange(log_probs.shape[0])
        chosen_logprobs = log_probs[row_index, chosen]
        # Ended sequences contribute nothing to the running score
        sum_logprobs += chosen_logprobs * (~already_ended)

        # Ended sequences keep emitting end-of-text
        chosen[already_ended] = self.eot
        tokens = torch.cat([tokens, chosen[:, None]], dim=-1)

        all_ended = (tokens[:, -1] == self.eot).all()
        return tokens, all_ended

class CustomDecodingTask(DecodingTask):
    """
    `DecodingTask` from whisper with its decoder swapped for `CustomGreedyDecoder`
    """
    def __init__(self, model: "Whisper", options: DecodingOptions):
        super().__init__(model,options)
        # Replace whatever decoder the parent constructor installed
        temperature = self.options.temperature
        end_of_text = self.tokenizer.eot
        self.decoder = CustomGreedyDecoder(temperature, end_of_text)
@torch.no_grad()
def custom_decode(
    model: torch.nn.Module,
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),

    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    decode function to perform constrained_filtering based greedy decoding

    Args:
        model: model handed to `CustomDecodingTask`.
        mel: spectrogram tensor; a 2-D tensor is treated as a single example and
            gets a leading batch dimension, anything else is passed on as given.
        options: decoding options; any `kwargs` are applied on top via
            `dataclasses.replace`.

    Returns:
        A single `DecodingResult` when `mel` was 2-D, otherwise the list of
        results produced by `CustomDecodingTask.run`.
    """
    # `single` must be a boolean: storing the raw `mel.ndim` (never zero) made the
    # function return only the first result even for batched input.
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = CustomDecodingTask(model, options).run(mel)

    return result[0] if single else result

def custom_transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    condition_on_previous_text: bool = True,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,，!！?？:：”)]}、",
    **decode_options,
):
    """
    Transcribe `audio` window by window, decoding every window with `custom_decode`.

    Args:
        model: Whisper model; its `device`, `dims`, `is_multilingual` and
            `num_languages` attributes are read here.
        audio: input forwarded unchanged to `log_mel_spectrogram`.
        compression_ratio_threshold, logprob_threshold, prepend_punctuations,
            append_punctuations: accepted but never read in this body.
        condition_on_previous_text: when False, the tokens of earlier windows are
            not fed as prompt to later windows.
        **decode_options: forwarded to `DecodingOptions`. "language" is required
            (a missing key raises KeyError); "fp16" (default True) selects the
            dtype; "task" defaults to "transcribe"; "best_of" is dropped.
            The "fp16" and "prompt" entries are overwritten by this function.

    Returns:
        dict with `text` (all kept tokens decoded), `segments` (list of segment
        dicts with running `id`) and `language`.
    """
    # Half precision unless disabled through decode_options; on CPU fall back to FP32
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    # Keep the options passed to the decoder consistent with the dtype actually used
    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    # presumably N_FRAMES is the frame count of the padding added above, so this
    # is the number of frames of real audio — TODO confirm
    content_frames = mel.shape[-1] - N_FRAMES

    # The language must be supplied by the caller; there is no detection step here
    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=language,
        task=task,
    )

    # Despite its name this helper performs a single decode at temperature 0.0;
    # there is no retry at other temperatures.
    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        decode_result = None

        temp = 0.0
        kwargs = {**decode_options}
        kwargs.pop("best_of", None)

        options = DecodingOptions(**kwargs, temperature=temp)
        decode_result = custom_decode(model, segment, options)

        return decode_result

    # Current read position in the spectrogram, in mel frames
    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []

    # Index into all_tokens from which the prompt for the next window starts
    prompt_reset_since = 0
    # Not read anywhere below
    initial_prompt_tokens = []

    # Builds one segment record; `seek` is read from the enclosing scope at call time
    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        # Only ids below tokenizer.eot are decoded into the segment text
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    # Not read anywhere below
    last_speech_timestamp = 0.0
    while seek < content_frames:
        # Start time of this window in seconds, and the window itself (up to N_FRAMES frames)
        time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
        mel_segment = mel[:, seek : seek + N_FRAMES]
        segment_size = min(N_FRAMES, content_frames - seek)
        segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
        mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

        # Tokens kept from earlier windows (since the last reset) become the prompt
        decode_options["prompt"] = all_tokens[prompt_reset_since:]
        result: DecodingResult = decode_with_fallback(mel_segment)
        tokens = torch.tensor(result.tokens)

        # Assigned but not read again in this function
        previous_seek = seek
        current_segments = []

        # Boolean mask of the positions holding timestamp tokens (ids >= timestamp_begin)
        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
        # True when the output ends with exactly one timestamp token
        single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

        # Indices of the second token of every pair of adjacent timestamp tokens
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
        consecutive.add_(1)
        if len(consecutive) > 0:
            # if the output contains two consecutive timestamp tokens
            slices = consecutive.tolist()
            if single_timestamp_ending:
                slices.append(len(tokens))

            last_slice = 0
            for current_slice in slices:
                # Each slice runs from one timestamp pair boundary to the next;
                # its first and last tokens give the start and end positions
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_pos = (
                    sliced_tokens[0].item() - tokenizer.timestamp_begin
                )
                end_timestamp_pos = (
                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
                )
                current_segments.append(
                    new_segment(
                        start=time_offset + start_timestamp_pos * time_precision,
                        end=time_offset + end_timestamp_pos * time_precision,
                        tokens=sliced_tokens,
                        result=result,
                    )
                )
                last_slice = current_slice

            if single_timestamp_ending:
                # single timestamp at the end means no speech after the last timestamp.
                seek += segment_size
            else:
                # otherwise, ignore the unfinished segment and seek to the last timestamp
                last_timestamp_pos = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_pos * input_stride
        else:
            # No adjacent timestamp pair: the whole window becomes a single segment
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if (
                len(timestamps) > 0
                and timestamps[-1].item() != tokenizer.timestamp_begin
            ):
                # no consecutive timestamps but it has a timestamp; use the last one.
                last_timestamp_pos = (
                    timestamps[-1].item() - tokenizer.timestamp_begin
                )
                duration = last_timestamp_pos * time_precision

            current_segments.append(
                new_segment(
                    start=time_offset,
                    end=time_offset + duration,
                    tokens=tokens,
                    result=result,
                )
            )
            seek += segment_size

        # if a segment is instantaneous or does not contain text, clear it
        for i, segment in enumerate(current_segments):
            if segment["start"] == segment["end"] or segment["text"].strip() == "":
                segment["text"] = ""
                segment["tokens"] = []
                segment["words"] = []

        # Number the new segments continuing from those already collected
        all_segments.extend(
            [
                {"id": i, **segment}
                for i, segment in enumerate(
                    current_segments, start=len(all_segments)
                )
            ]
        )
        # Cleared segments have an empty token list, so they add nothing here
        all_tokens.extend(
            [token for segment in current_segments for token in segment["tokens"]]
        )

        # NOTE(review): decoding above always requests temperature 0.0, so the
        # temperature test presumably never fires — confirm against DecodingResult.
        if not condition_on_previous_text or result.temperature > 0.5:
              # do not feed the prompt tokens if a high temperature was used
              prompt_reset_since = len(all_tokens)

    return dict(
        text=tokenizer.decode(all_tokens),
        segments=all_segments,
        language=language,
    )

In [None]:
import re

# Function to ignore tags in transcripts
def remove_tags(input_string):
    """
    Return `input_string` with every `<...>` span removed.

    A span starts at '<', contains at least one character other than '>', and
    ends at the next '>'. The surrounding text, including whitespace, is kept as is.
    """
    return re.compile(r'<[^>]+>').sub('', input_string)

def inferencer(json_file_path, model, dataset):
    """
    Transcribe each example of `dataset` with `custom_transcribe` and store the results.

    CER/WER are computed on the texts after `remove_tags`, while the JSON records keep
    the raw reference and prediction. The file written to `json_file_path` is a list
    starting with the {'cer', 'wer'} summary followed by one record per example.
    """
    records = []
    for example in tqdm.tqdm(dataset):
        reference = example['transcription']
        waveform = torch.tensor(example['audio']['array'].astype(np.float32))
        hypothesis = custom_transcribe(model = model, audio = waveform, language="hi")['text']

        records.append({
            'id': example['id'],
            'ground_truth': reference,
            'prediction': hypothesis
        })

    # Scoring ignores tags on both sides
    scored_references = [remove_tags(record['ground_truth']) for record in records]
    scored_hypotheses = [remove_tags(record['prediction']) for record in records]
    cer_score = cer.compute(predictions=scored_hypotheses, references=scored_references)
    wer_score = wer.compute(predictions=scored_hypotheses, references=scored_references)

    # The summary entry goes first, followed by the per-utterance records
    payload = [{'cer': cer_score, 'wer': wer_score}, *records]

    with open(json_file_path, 'w') as json_file:
        json.dump(payload, json_file, indent=4)

In [None]:
# Load the checkpoint from the working directory, which is where the wget cell saves it.
# A relative path matches the earlier `whisper.load_model("whisper-small-finetuned.pt")`
# call and does not depend on the notebook running under /content.
model = whisper.load_model("whisper-small-finetuned.pt")
json_file_path = "./fine_tuned_whisper_with_constrained_filtering.json"
inferencer(json_file_path, model, dataset['test'])

In [None]:
# Show the first five lines of the results file (the cer/wer summary entry comes first)
!head -n 5 ./fine_tuned_whisper_with_constrained_filtering.json


## Task 2.3: Beam search Decoding using finetuned Whisper model


In [None]:
def inferencer(json_file_path, model, dataset, beam_size):
    """
    Transcribe each example of `dataset` with `model.transcribe`, forwarding `beam_size`.

    The JSON file written to `json_file_path` is a list whose first element holds the
    corpus-level 'cer' and 'wer', followed by one record per example with its 'id',
    'ground_truth' and 'prediction'.
    """
    records = []
    for example in tqdm.tqdm(dataset):
        reference = example['transcription']
        waveform = torch.tensor(example['audio']['array'].astype(np.float32))
        hypothesis = model.transcribe(audio = waveform, language="hi", beam_size=beam_size)['text']

        records.append({
            'id': example['id'],
            'ground_truth': reference,
            'prediction': hypothesis
        })

    # Metrics are computed once, over the whole set of utterances
    hypotheses = [record['prediction'] for record in records]
    references = [record['ground_truth'] for record in records]
    cer_score = cer.compute(predictions=hypotheses, references=references)
    wer_score = wer.compute(predictions=hypotheses, references=references)

    # The summary entry goes first, followed by the per-utterance records
    payload = [{'cer': cer_score, 'wer': wer_score}, *records]

    with open(json_file_path, 'w') as json_file:
        json.dump(payload, json_file, indent=4)

In [None]:
# `model` is whatever the most recently executed load cell assigned; the nearest
# one above this cell loads the fine-tuned checkpoint.
json_file_path = "./fine_tuned_whisper_with_beam_search.json"
inferencer(json_file_path, model, dataset['test'], beam_size=3)

 28%|██▊       | 17/61 [29:07<46:38, 63.59s/it]

In [None]:
# Show the first five lines of the results file (the cer/wer summary entry comes first)
!head -n 5 ./fine_tuned_whisper_with_beam_search.json

## Task 2.4: Constrained Beam Search Decoding on finetuned Whisper model



In [None]:
from whisper.decoding import BeamSearchDecoder

"""
Custom classes to integrate constraints within beam search decoding.
"""
class CustomBeamSearchDecoder(BeamSearchDecoder):
    """
    Updates the existing `BeamSearchDecoder` class from whisper to use constrained beam search

    Template only: `update` and `finalize` are still unimplemented (see the notes in each).
    """
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        cutoff: int = 3,
        patience: Optional[float] = None,
    ):
        # NOTE(review): `cutoff` is accepted but neither stored nor forwarded, and it sits
        # before `patience` in the positional order — callers should pass `patience` by keyword.
        super().__init__(beam_size, eot, inference, patience)
        # TODO: Add any additional variables you need for constrained beam search here

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """
        TODO: This is the main routine that implements constrained beam search
        Updates the token list `tokens` to add a next best token using constrained beam search
        Refer to https://github.com/openai/whisper/blob/main/whisper/decoding.py#L301-L404 for the original update function

        NOTE(review): `completed` is never assigned in this body, so calling the method
        as it stands raises NameError at the return statement.
        """

        return tokens, completed

    def finalize(self, preceding_tokens, sum_logprobs):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        """
        TODO: Add new code (or copy existing finalize code) here to implement constrained beam search
        Refer to https://github.com/openai/whisper/blob/main/whisper/decoding.py#L301-L404 for the original finalize function

        NOTE(review): the parameter is named `preceding_tokens`, but the return statement
        uses `tokens`, which is not defined in this scope (NameError as it stands).
        """

        return tokens, sum_logprobs

class CustomDecodingTask(DecodingTask):
    """
    Updates the existing `DecodingTask` class from whisper to use constrained beam search

    The requested `beam_size` is doubled (when one is given) before the parent
    constructor runs, and the decoder is then replaced by `CustomBeamSearchDecoder`.
    """
    def __init__(self, model: "Whisper", options: DecodingOptions):
        options = replace(options, beam_size = 2 * options.beam_size if options.beam_size is not None else None)
        super().__init__(model, options)
        # `patience` has to be passed by keyword: the fourth positional parameter of
        # CustomBeamSearchDecoder is `cutoff`, so a positional patience value would
        # silently land in `cutoff` and leave patience unset.
        self.decoder = CustomBeamSearchDecoder(
            self.options.beam_size,
            self.tokenizer.eot,
            self.inference,
            patience=self.options.patience,
        )

@torch.no_grad()
def custom_decode(
    model: torch.nn.Module,
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    decode function to perform constrained beam search decoding

    Args:
        model: model handed to `CustomDecodingTask`.
        mel: spectrogram tensor; a 2-D tensor is treated as a single example and
            gets a leading batch dimension, anything else is passed on as given.
        options: decoding options; any `kwargs` are applied on top via
            `dataclasses.replace`.

    Returns:
        A single `DecodingResult` when `mel` was 2-D, otherwise the list of
        results produced by `CustomDecodingTask.run`.
    """
    # `single` must be a boolean: storing the raw `mel.ndim` (never zero) made the
    # function return only the first result even for batched input.
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = CustomDecodingTask(model, options).run(mel)

    return result[0] if single else result


In [None]:
def inferencer(json_file_path, model, dataset, beam_size):
    """
    Transcribe each example of `dataset` with `custom_transcribe`, forwarding `beam_size`.

    The JSON file written to `json_file_path` is a list whose first element holds the
    corpus-level 'cer' and 'wer', followed by one record per example with its 'id',
    'ground_truth' and 'prediction'.
    """
    records = []
    for example in tqdm.tqdm(dataset):
        reference = example['transcription']
        waveform = torch.tensor(example['audio']['array'].astype(np.float32))
        hypothesis = custom_transcribe(model=model, audio = waveform, language="hi", beam_size=beam_size)['text']

        records.append({
            'id': example['id'],
            'ground_truth': reference,
            'prediction': hypothesis
        })

    # Metrics are computed once, over the whole set of utterances
    hypotheses = [record['prediction'] for record in records]
    references = [record['ground_truth'] for record in records]
    cer_score = cer.compute(predictions=hypotheses, references=references)
    wer_score = wer.compute(predictions=hypotheses, references=references)

    # The summary entry goes first, followed by the per-utterance records
    payload = [{'cer': cer_score, 'wer': wer_score}, *records]

    with open(json_file_path, 'w') as json_file:
        json.dump(payload, json_file, indent=4)

In [None]:
# Uses the notebook-level `model` together with the most recently defined
# `inferencer` / `custom_decode` / `CustomDecodingTask`, so the cells above must have run first.
json_file_path = "./fine_tuned_whisper_with_constrained_beam_search.json"
inferencer(json_file_path, model, dataset['test'], beam_size=3)

In [None]:
# Show the first five lines of the results file (the cer/wer summary entry comes first)
!head -n 5 ./fine_tuned_whisper_with_constrained_beam_search.json