In [1]:
import os

# A bare `CUDA_VISIBLE_DEVICES=0,1,2,3,4` only binds a Python tuple; the CUDA
# runtime never sees it. Export the selection through the process environment
# instead. This must run before anything initialises CUDA to take effect.
CUDA_VISIBLE_DEVICES = 0, 1, 2, 3, 4
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(device) for device in CUDA_VISIBLE_DEVICES)

In [2]:
import datasets
from datasets import load_dataset

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union, Optional,Tuple

from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import WhisperForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm
2024-04-11 17:30:48.504728: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-11 17:30:48.542210: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-11 17:30:48.542241: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-11 17:30:48.542258: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-11 17:30:4

In [3]:
import json
import numpy as np
from evaluate import load
import whisper
import tqdm
# Load the fine-tuned checkpoint "whisper-small-finetuned.pt". Later cells read
# this `model` object to build the whisper tokenizer and to decode audio.
model = whisper.load_model("whisper-small-finetuned.pt")

In [4]:
import os
import traceback
import warnings

from whisper.audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.utils import (
    exact_div,
    format_timestamp,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)

In [5]:
# Load json files for test split
# `data_files` maps the split name to its JSON manifest, so the loaded object
# is indexed by split ("test_blind") in the cells below.
data_files = {
    "test_blind": "dataset/CodeSwitched_Data/test_blind.json"
}
dataset = load_dataset("json", data_files=data_files)

# Update the audio paths to include appropriate folder-name
def prepend_folder_name(row):
    """Prefix the row's "audio" path with the dataset folder.

    The row is updated in place and the same object is returned.
    """
    folder = 'dataset/CodeSwitched_Data/'
    row["audio"] = folder + row["audio"]
    return row
# Apply the path fix to every split held in `dataset`.
for key in dataset:
    dataset[key] = dataset[key].map(prepend_folder_name)

# Cast columns to appropriate features
# "id" is declared as a string and "audio" as an Audio feature at 16 kHz; the
# inference loop later reads the waveform through item['audio']['array'].
features = datasets.Features(
    {
        "id": datasets.Value("string"),
        "audio": datasets.Audio(sampling_rate=16000),
    }
)
# NOTE(review): the cast is done by mapping `features.encode_example` over each
# example rather than with a dedicated cast call -- confirm the manifest only
# carries the "id" and "audio" columns declared above.
dataset = dataset.map(features.encode_example, features=features)

In [6]:
# Load necessary processors
# All three are loaded from the base "openai/whisper-small" checkpoint; the
# tokenizer and processor are configured for Hindi transcription.
# NOTE(review): `tokenizer` is rebound by the `get_tokenizer(...)` call in a
# later cell, and neither `feature_extractor` nor `processor` is referenced
# again in the visible cells -- confirm whether this cell is still needed.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
def prepare_dataset(batch):
    """Look up the example's "audio" entry and return the example unchanged."""
    # The looked-up value is not used; the lookup itself is kept so that an
    # example without an "audio" entry still fails here.
    _ = batch["audio"]
    return batch
# Run `prepare_dataset` over the dataset with two worker processes.
dataset = dataset.map(prepare_dataset, num_proc=2)

In [8]:
from whisper.decoding import DecodingOptions, DecodingResult, DecodingTask, GreedyDecoder, TokenDecoder, Inference
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import Tokenizer, get_tokenizer
from whisper.utils import compression_ratio

if TYPE_CHECKING:
    from whisper.model import Whisper

In [9]:
# Build whisper's own tokenizer for Hindi ("hi") transcription from the loaded
# model's settings. This rebinds `tokenizer`, replacing the Hugging Face
# tokenizer created in an earlier cell; the vocabulary cell below encodes with
# this one.
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language="hi", task="transcribe")

In [10]:
# tokens_in_task = set([220, 50257]).union(set(range(50257, 50364)))
# Token ids the constrained decoder is allowed to emit. 220 and 50257 are
# always allowed (presumably the bare-space token and the end-of-text token --
# TODO confirm against the tokenizer).
tokens_in_task = {220, 50257}
# Vocabulary words; used later to snap decoded words to their closest entry.
words_in_task = set()

# One word per line. The encoding is stated explicitly so the result does not
# depend on the platform's default encoding.
with open("word_dictionary.txt", "r", encoding="utf-8") as f:
    words = f.readlines()

for line in words:
    word = line.strip()
    if not word:
        # A blank line would put "" into words_in_task, and cer_custom divides
        # by the length of each reference word.
        continue
    # Encode the word with every combination of leading/trailing space, since
    # the token ids can differ depending on the surrounding whitespace.
    for variant in (word, " " + word, word + " ", " " + word + " "):
        tokens_in_task.update(tokenizer.encode(variant))
    words_in_task.add(word)

# Sorted list: the decoder maps positions in this list back to token ids.
tokens_in_task = sorted(tokens_in_task)

In [11]:
# Persist the allowed token ids, one id per line.
with open("tokens_in_task.txt", "w") as f:
    for t in tokens_in_task:
        print(t, file=f)

### DECODER 

In [12]:
from whisper.decoding import BeamSearchDecoder

"""
Custom classes to integrate constraints within beam search decoding.
"""
class CustomBeamSearchDecoder(BeamSearchDecoder):
    """
    Updates the existing `BeamSearchDecoder` class from whisper to use constrained beam search.

    At every step the candidate continuations of each beam are drawn only from
    the module-level `tokens_in_task` list instead of the full vocabulary.
    """
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        cutoff: int = 3,
        patience: Optional[float] = None,
    ):
        """
        Parameters
        ----------
        beam_size, eot, inference, patience :
            Forwarded unchanged to `BeamSearchDecoder.__init__`.
        cutoff :
            Stored on the instance as `self.cutoff`; not read by `update` or
            `finalize` (reserved for further constraints, see the TODO below).
        """
        super().__init__(beam_size, eot, inference, patience)
        # TODO: Add any additional variables you need for constrained beam search here
        self.cutoff = cutoff

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """
        Advance every beam by one token chosen from `tokens_in_task`.

        Parameters
        ----------
        tokens : token sequences so far, one row per (audio, beam) pair.
        logits : next-token logits, one row per row of `tokens`.
        sum_logprobs : cumulative log-probabilities per row; overwritten in
            place with the scores of the beams that are kept.

        Returns
        -------
        (tokens, completed) : the extended sequences of the kept beams, and
        whether every audio has collected `self.max_candidates` finished
        sequences.

        Raises
        ------
        ValueError : if the row count is not a multiple of `beam_size`, or if
            `tokens_in_task` holds fewer than `beam_size + 1` entries.
        """
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        # Each beam proposes beam_size + 1 candidates, so the allowed
        # vocabulary must be at least that large for topk to succeed.
        n_candidates = self.beam_size + 1
        if len(tokens_in_task) < n_candidates:
            raise ValueError(
                f"tokens_in_task has {len(tokens_in_task)} entries; at least "
                f"{n_candidates} are needed for beam_size={self.beam_size}"
            )

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        # Gather the allowed columns once for all rows instead of re-indexing
        # with the Python list for every beam.
        allowed_ids = torch.as_tensor(
            tokens_in_task, dtype=torch.long, device=logprobs.device
        )
        allowed_logprobs = logprobs[:, allowed_ids]

        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                # topk runs over the restricted columns, so `token` is a
                # position in tokens_in_task, not a vocabulary id.
                for logprob, token in zip(*allowed_logprobs[idx].topk(n_candidates)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [tokens_in_task[token.item()]])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        """
        Collect the finished sequences of every audio.

        Audios with fewer than `beam_size` finished sequences are topped up
        with their unfinished beams, taken in decreasing `sum_logprobs` order
        and closed with `eot`.

        Returns
        -------
        (tokens, sum_logprobs) : per audio, the candidate sequences as tensors
        and their cumulative log-probabilities as floats.
        """
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs

class CustomDecodingTask(DecodingTask):
    """
    Updates the existing `DecodingTask` class from whisper to use constrained beam search.

    The requested beam size is doubled before the parent task is built, and the
    decoder chosen by the parent is then replaced with a
    `CustomBeamSearchDecoder`.
    """
    def __init__(self, model: "Whisper", options: DecodingOptions):
        """
        Raises
        ------
        ValueError : if `options.beam_size` is None. The constrained decoder is
            a beam search and needs a beam size.
        """
        if options.beam_size is None:
            # NOTE(review): without this guard None reaches the beam-search
            # constructor, where it presumably fails with a less helpful error.
            raise ValueError(
                "CustomDecodingTask requires `beam_size`: constrained decoding "
                "is only implemented for beam search"
            )
        options = replace(options, beam_size=2 * options.beam_size)
        super().__init__(model, options)
        # `patience` must be passed by keyword: the fourth positional parameter
        # of CustomBeamSearchDecoder is `cutoff`, so a positional argument
        # would land there and the patience setting would be lost.
        self.decoder = CustomBeamSearchDecoder(
            self.options.beam_size,
            self.tokenizer.eot,
            self.inference,
            patience=self.options.patience,
        )

@torch.no_grad()
def custom_decode(
    model: torch.nn.Module,
    mel: torch.Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    decode function to perform constrained beam search decoding

    Parameters
    ----------
    model : model handed to `CustomDecodingTask`.
    mel : mel spectrogram. A 2-D tensor is treated as a single segment and is
        given a leading batch dimension; any other rank is passed through
        as-is.
    options : decoding options; fields are overridden by `kwargs` when given.

    Returns
    -------
    A single `DecodingResult` for 2-D input, otherwise the list of results
    produced by the task.
    """
    # `single` has to be a boolean. Storing `mel.ndim` itself is always truthy,
    # which made batched input return only its first result.
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = CustomDecodingTask(model, options).run(mel)

    return result[0] if single else result


In [13]:
def custom_transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    condition_on_previous_text: bool = True,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,，!！?？:：”)]}、",
    **decode_options,
):
    """
    Transcribe `audio` window by window with the constrained decoder.

    The audio is turned into a log-mel spectrogram and consumed in windows of
    `N_FRAMES` frames. Each window is decoded once with `custom_decode` at
    temperature 0. Tokens below `tokenizer.eot` that are not in the
    module-level `tokens_in_task` are discarded, and the remaining tokens are
    collected into segments.

    Parameters
    ----------
    model : model used for decoding; its device decides whether FP16 is used.
    audio : passed unchanged to `log_mel_spectrogram`.
    compression_ratio_threshold, logprob_threshold, prepend_punctuations,
    append_punctuations :
        Accepted for signature compatibility; not read by this implementation.
    condition_on_previous_text : when False, the prompt is reset after every
        window, so no earlier tokens are fed to the next window.
    **decode_options : forwarded to `DecodingOptions`. "language" is required
        (a missing key raises KeyError). "task" defaults to "transcribe".
        "fp16" defaults to True and is forced to False on CPU. "best_of" is
        dropped.

    Returns
    -------
    dict with "text" (decoding of all kept tokens), "segments" (list of
    per-segment dicts) and "language".
    """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES

    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=language,
        task=task,
    )

    # Set view of the allowed vocabulary: membership is tested once per
    # decoded token, and `tokens_in_task` itself is a list.
    allowed_tokens = set(tokens_in_task)

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        # Single pass at temperature 0; despite the name there is no
        # temperature fallback.
        kwargs = {**decode_options}
        kwargs.pop("best_of", None)

        options = DecodingOptions(**kwargs, temperature=0.0)
        return custom_decode(model, segment, options)

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []

    prompt_reset_since = 0

    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        # Keep special tokens (>= eot) and text tokens from the task vocabulary.
        kept_tokens = [
            token
            for token in tokens
            if token >= tokenizer.eot or token in allowed_tokens
        ]
        dropped_tokens = set(tokens).difference(kept_tokens)
        if dropped_tokens:
            # Report only when something was actually filtered out, instead of
            # printing an empty set for every segment.
            print(dropped_tokens)
        text_tokens = [token for token in kept_tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": kept_tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    while seek < content_frames:
        time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
        mel_segment = mel[:, seek : seek + N_FRAMES]
        segment_size = min(N_FRAMES, content_frames - seek)
        segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
        mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

        decode_options["prompt"] = all_tokens[prompt_reset_since:]
        result: DecodingResult = decode_with_fallback(mel_segment)
        tokens = torch.tensor(result.tokens)

        current_segments = []

        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
        single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
        consecutive.add_(1)
        if len(consecutive) > 0:
            # if the output contains two consecutive timestamp tokens
            slices = consecutive.tolist()
            if single_timestamp_ending:
                slices.append(len(tokens))

            last_slice = 0
            for current_slice in slices:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_pos = (
                    sliced_tokens[0].item() - tokenizer.timestamp_begin
                )
                end_timestamp_pos = (
                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
                )
                current_segments.append(
                    new_segment(
                        start=time_offset + start_timestamp_pos * time_precision,
                        end=time_offset + end_timestamp_pos * time_precision,
                        tokens=sliced_tokens,
                        result=result,
                    )
                )
                last_slice = current_slice

            if single_timestamp_ending:
                # single timestamp at the end means no speech after the last timestamp.
                seek += segment_size
            else:
                # otherwise, ignore the unfinished segment and seek to the last timestamp
                last_timestamp_pos = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_pos * input_stride
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if (
                len(timestamps) > 0
                and timestamps[-1].item() != tokenizer.timestamp_begin
            ):
                # no consecutive timestamps but it has a timestamp; use the last one.
                last_timestamp_pos = (
                    timestamps[-1].item() - tokenizer.timestamp_begin
                )
                duration = last_timestamp_pos * time_precision

            current_segments.append(
                new_segment(
                    start=time_offset,
                    end=time_offset + duration,
                    tokens=tokens,
                    result=result,
                )
            )
            seek += segment_size

        # if a segment is instantaneous or does not contain text, clear it
        for segment in current_segments:
            if segment["start"] == segment["end"] or segment["text"].strip() == "":
                segment["text"] = ""
                segment["tokens"] = []
                segment["words"] = []

        all_segments.extend(
            [
                {"id": i, **segment}
                for i, segment in enumerate(
                    current_segments, start=len(all_segments)
                )
            ]
        )
        all_tokens.extend(
            [token for segment in current_segments for token in segment["tokens"]]
        )

        if not condition_on_previous_text or result.temperature > 0.5:
            # do not feed the prompt tokens if a high temperature was used
            prompt_reset_since = len(all_tokens)

    return dict(
        text=tokenizer.decode(all_tokens),
        segments=all_segments,
        language=language,
    )

In [37]:
!pip install editdistance
import editdistance

def cer_custom(hypothesis):
    """
    Find the vocabulary word closest to `hypothesis` by character error rate.

    The character error rate against a reference word is the edit distance
    between the two strings divided by the length of the reference word. Every
    non-empty word in the module-level `words_in_task` is tried.

    Returns
    -------
    (best_cer, best_word) : the lowest rate found and the word that achieved
    it. If no word is available the result is (100000, None).
    """
    best_cer = 100000
    best_word = None
    for word in words_in_task:
        if not word:
            # An empty reference has length 0 and would divide by zero.
            continue
        # editdistance compares the strings character by character.
        cer = editdistance.eval(hypothesis, word) / len(word)
        if cer < best_cer:
            best_word = word
            best_cer = cer

    return best_cer, best_word

def process_sentence(sentence, max_cer=0.25):
    """
    Snap every word of a raw transcript onto the task vocabulary.

    Steps, in order:
      * an empty sentence is returned unchanged;
      * one leading "+" or whitespace character is removed;
      * a word containing "assador" ends processing: the part of that word
        before "assador" is kept (when non-empty) and everything after it is
        dropped;
      * every other word is replaced by its closest vocabulary word from
        `cer_custom`, or skipped when the character error rate exceeds
        `max_cer`;
      * immediate repetitions of the same vocabulary word are collapsed.

    Parameters
    ----------
    sentence : raw transcript text.
    max_cer : highest character error rate at which a word is still accepted.

    Returns
    -------
    The kept words joined by single spaces.
    """
    if sentence == "":
        return sentence

    # Remove the + from the sentences
    # Only a leading "+" or whitespace is removed: slicing unconditionally
    # would cut the first letter off a sentence that has no such prefix.
    if sentence[0] == "+" or sentence[0].isspace():
        sentence = sentence[1:]

    new_words = []

    # Remove words not in given vocab and compress multiple occurences
    for word in sentence.split():
        if "assador" in word:
            prefix = word[: word.find("assador")]
            if prefix:
                # An empty prefix would only add a stray separator.
                new_words.append(prefix)
            break

        cer, closest_word = cer_custom(word)
        if cer > max_cer:
            continue
        if not new_words or new_words[-1] != closest_word:
            new_words.append(closest_word)

    return " ".join(new_words)
    



In [None]:
# One record per utterance, before and after vocabulary post-processing.
to_json_after_processing = []
to_json_before_processing = []
for item in tqdm.tqdm(dataset["test_blind"]):
    # Decode the utterance from its float32 waveform.
    waveform = torch.tensor(item['audio']['array'].astype(np.float32))
    transcription = custom_transcribe(model=model, audio=waveform, language="hi", beam_size=4)
    prediction = transcription['text']
    print(prediction)
    to_json_before_processing.append({'ID': item['id'], 'TRANSCRIPT': prediction})

    # Snap the raw transcript onto the task vocabulary.
    prediction = process_sentence(prediction)

    print(prediction)
    to_json_after_processing.append({'ID': item['id'], 'TRANSCRIPT': prediction})

In [38]:
# pandas is imported here because the notebook's only other pandas import sits
# in a later cell, so this cell would fail on a fresh top-to-bottom run.
import pandas as pd

bf = pd.read_csv("predictions_before_proc.csv")

# Empty transcripts come back from the CSV as NaN, and str(NaN) is "nan",
# which would then be processed as if it were a word. Treat them as empty.
sentences = bf["TRANSCRIPT"].fillna("")
processed = [process_sentence(str(sentence)) for sentence in sentences]

bf["TRANSCRIPT"] = processed
# bf.to_csv("predictions.csv")

In [39]:
# Write the re-processed transcripts to predictions.csv; index = False keeps
# the DataFrame's row index out of the file.
bf.to_csv("predictions.csv", index = False)

In [None]:
import pandas as pd
csv_file_path_after_processing = "predictions.csv"
csv_file_path_before_processing = "predictions_before_proc.csv"

# Dump the raw predictions first, then the post-processed ones, each to its
# own CSV without the row index.
for records, csv_path in (
    (to_json_before_processing, csv_file_path_before_processing),
    (to_json_after_processing, csv_file_path_after_processing),
):
    x = pd.DataFrame(records)
    x.to_csv(csv_path, index = False)
