In [None]:
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio
import json
import torchaudio

import utils.myutils as myutils

# Plotly renderer: "jupyterlab" for JupyterLab, or "notebook_connected" for the classic Jupyter Notebook.
pio.renderers.default = "jupyterlab"

In [None]:
# ------------------------------------------- Set up parameters -------------------------------------------
# Sample rate (Hz) declared to the Wav2Vec2 processor for the input audio.
SAMPLING_RATE = 16000
# Directory holding the model, its processor files and vocab.json (placeholder: set before running).
PATH_MODEL = "PATH_TO_MODEL"

In [None]:
def stringEdit(listText, isPhoneme=False):
    """Normalise transcripts into the '|'-separated form used for vocabulary lookup.

    Newlines, hyphens and apostrophes are treated as word breaks, runs of
    whitespace are collapsed (leading/trailing whitespace is dropped), and
    the remaining word separators are written as '|'.

    :param listText: iterable of transcript strings
    :param isPhoneme: accepted for compatibility; currently unused
    :return: list with one normalised string per input string
    """
    # Map each break character to a plain space; split() then collapses them.
    break_chars = str.maketrans("\n-'", "   ")
    return ["|".join(raw.translate(break_chars).split()) for raw in listText]

# Reference transcript for the recording (placeholder: replace with the real text).
transcript = 'TRANSCRIPT'

processor = Wav2Vec2Processor.from_pretrained(PATH_MODEL)
model = Wav2Vec2ForCTC.from_pretrained(PATH_MODEL)

# stringEdit already collapses whitespace and joins words with '|', so no further
# whitespace clean-up is needed before looking characters up in the vocabulary.
transcript = stringEdit([transcript])[0]
print(f"Transcript {transcript}")

# Load vocab file (maps each output character to its token id):
with open(PATH_MODEL + '/vocab.json', encoding='utf-8') as json_file:
    vocab = json.load(json_file)

# `waveform` is (channels, samples) at the file's native sample rate `sr`.
waveform, sr = torchaudio.load('PATH_TO_AUDIO_FILE')

# Keep only the first channel (in case the recording has more than one).
input_audio = waveform[0]

# The processor is told the audio is at SAMPLING_RATE, so resample when the file
# was recorded at a different rate; otherwise the model sees time-stretched audio.
# `waveform` and `sr` are left as loaded.
if sr != SAMPLING_RATE:
    input_audio = torchaudio.functional.resample(input_audio, orig_freq=sr, new_freq=SAMPLING_RATE)

input_values = processor(input_audio, sampling_rate=SAMPLING_RATE, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values).logits

# Greedy decoding: most likely token id per frame, turned into text by the processor.
predicted_ids = torch.argmax(logits, dim=-1)
prediction = processor.batch_decode(predicted_ids)
print(f"Prediction {prediction}")

# Map every transcript character to its token id, reporting all unknown characters at once.
missing_chars = sorted(set(transcript) - vocab.keys())
if missing_chars:
    raise KeyError(f"Transcript characters not in vocab.json: {missing_chars}")
tokens = [vocab[c] for c in transcript]

In [None]:
def temperature_scaling(logits, temperature):
    """
    Divide logits element-wise by a temperature factor.

    A temperature above 1 flattens the softmax distribution computed from the
    result; a temperature below 1 sharpens it.

    :param logits: torch tensor of logits
    :param temperature: scalar divisor
    :return: logits / temperature
    """
    scaled_logits = logits / temperature
    return scaled_logits

def topk_normalize(probabilities, topk = 3):
    """
    Rescale the top-k entries along the last dimension by the row maximum,
    leaving every other entry unchanged.

    After rescaling, the largest entry of each row becomes 1 and the other
    top-k entries become their ratio to it. The result is therefore NOT a
    probability distribution: rows generally no longer sum to 1.

    :param probabilities: input tensor; top-k is taken over the last dimension.
    :param topk: number of top entries per row to rescale (must not exceed the
        size of the last dimension).
    :return: tensor of the same shape and dtype as the input.
    """
    # torch.topk returns values sorted in descending order, so index 0 is the row max.
    top_values, top_indices = torch.topk(probabilities, topk)
    top_rescaled = top_values / top_values[..., :1]

    # Start from a copy so non-top-k entries keep their original values. Selecting
    # the top-k once (rather than separate top/bottom topk calls) guarantees every
    # index is written exactly once, even when probabilities tie at the k boundary.
    new_probabilities = probabilities.clone()
    new_probabilities.scatter_(-1, top_indices, top_rescaled)

    return new_probabilities

# Temperature > 1 flattens the per-frame distributions so lower-ranked candidates
# become visible in the heatmap; only the `topk` best tokens per frame are rescaled.
temperature = 10
topk = 3

# Apply temperature scaling with a factor of temperature
scaled_logits = temperature_scaling(logits, temperature)

# Compute per-frame probabilities over the vocabulary
scaled_probabilities = torch.nn.functional.softmax(scaled_logits, dim=-1)

# Rescale the top-k probabilities and keep the first (only) batch item, then take
# the log; `emission` holds one row per frame and one column per token id.
emission = topk_normalize(scaled_probabilities, topk=topk)[0]
emission = torch.log(emission)

# Token labels ordered by token id so they line up with the emission columns
# (assumes ids in vocab.json are 0..V-1 without gaps — TODO confirm).
vocab_sort = dict(sorted(vocab.items(), key=lambda item: item[1]))
list_token = list(vocab_sort)
assert len(list_token) == emission.shape[-1], (
    f"vocab.json has {len(list_token)} tokens but the model emits {emission.shape[-1]} classes"
)

fig = px.imshow(
    emission.T.exp(),
    y=list_token,
    labels=dict(x="Frame", y="Token", color="Rescaled probability"),
    title=f"Per-frame token scores (temperature={temperature}, top-{topk} rescaled)",
    color_continuous_scale='viridis',
    height=1200,
)
fig.show()