In [111]:
import numpy as np
import math
import soundfile as sf
import librosa
import matplotlib.pyplot as plt
import torch
from pydub import AudioSegment
from torch.utils.data import DataLoader, IterableDataset

import wave
import os


# Configuration dictionary
from utils.asr_utils import *


# ort_session_en_ner, model_tokenizer_en, filterbank_featurizer = create_ort_session(model_name="EN_ner_emotion_commonvoice", model_shelf="/disk1/artifacts/whissle/model_shelf")
ort_session_en_ner, model_tokenizer_en, filterbank_featurizer = create_ort_session(model_name="EN_ner_conformer_ctc_large", model_shelf="/home/svanga/test/model_shelf")


spm_model_path:  /home/svanga/test/model_shelf/EN_IoT_conformer_ctc_large/tokenizer/tokenizer.model


In [112]:
# Function to convert stereo audio to mono
def convert_to_stereo_mono(audio_path, output_path, channel=1):
    stereo_audio = AudioSegment.from_wav(audio_path)
    mono_audio = stereo_audio.set_channels(channel)
    mono_audio.export(output_path, format="wav")
    return output_path

def calculate_n_buffers(audio_file_path, chunk_len_in_secs, buffer_len_in_secs, stride):
    with wave.open(audio_file_path, 'r') as wav_file:
        sample_rate = wav_file.getframerate()
        num_samples = wav_file.getnframes()
        duration_in_secs = num_samples / sample_rate

    effective_chunk_len = chunk_len_in_secs + buffer_len_in_secs - stride

    n_buffers = int(duration_in_secs / effective_chunk_len)

    return n_buffers



def convert_to_mono(audio_path, output_path, channel=1):
    # Load the stereo audio file
    stereo_audio = AudioSegment.from_wav(audio_path)
    
    # Select the channel: 0 for left, 1 for right
    if channel == 0:
        mono_audio = stereo_audio.split_to_mono()[0]  # Left channel
    elif channel == 1:
        mono_audio = stereo_audio.split_to_mono()[1]  # Right channel
    else:
        raise ValueError("Channel must be 0 (left) or 1 (right)")
    
    # Export the mono audio to the specified output path
    mono_audio.export(output_path, format="wav")
    return output_path



In [113]:
# A simple iterator class to return successive chunks of samples
class AudioChunkIterator:
    def __init__(self, samples, chunk_len_in_secs, sample_rate):
        self._samples = samples
        self._chunk_len = int(chunk_len_in_secs * sample_rate)
        self._start = 0
        self.output = True
        self.sample_rate = sample_rate

    def __iter__(self):
        return self

    def __next__(self):
        if not self.output:
            raise StopIteration
        last = self._start + self._chunk_len
        if last <= len(self._samples):
            chunk = self._samples[self._start: last]
            start_time = self._start / self.sample_rate
            self._start = last
        else:
            chunk = np.zeros(self._chunk_len, dtype='float32')
            samp_len = len(self._samples) - self._start
            chunk[:samp_len] = self._samples[self._start:len(self._samples)]
            start_time = self._start / self.sample_rate
            self.output = False
        return chunk, start_time

# A helper function for extracting samples as a numpy array from the audio file
def get_samples(audio_file, target_sr=16000):
    with sf.SoundFile(audio_file, 'r') as f:
        sample_rate = f.samplerate
        samples = f.read(dtype='float32')
        if sample_rate != target_sr:
            samples = librosa.resample(samples.T, orig_sr=sample_rate, target_sr=target_sr).T
        return samples.flatten()

# Function to prepare buffers and their start times
def prepare_buffers(samples, config):
    sample_rate = config['sample_rate']
    buffer_len_in_secs = config['buffer_len_in_secs']
    chunk_len_in_secs = config['chunk_len_in_secs']
    n_buffers = config['n_buffers']

    buffer_len = sample_rate * buffer_len_in_secs
    sampbuffer = np.zeros(buffer_len, dtype=np.float32)
    
    chunk_reader = AudioChunkIterator(samples, chunk_len_in_secs, sample_rate)
    chunk_len = sample_rate * chunk_len_in_secs
    buffer_list = []
    buffer_start_times = []
    
    for count, (chunk, start_time) in enumerate(chunk_reader, start=1):
        sampbuffer[:-chunk_len] = sampbuffer[chunk_len:] 
        sampbuffer[-chunk_len:] = chunk
        buffer_list.append(np.array(sampbuffer))
        buffer_start_times.append(start_time)
    return buffer_list, buffer_start_times

def speech_collate_fn(batch):
    _, audio_lengths = zip(*batch)
    max_audio_len = max(audio_lengths).item()
   
    audio_signal= []
    for sig, sig_len in batch:
        sig_len = sig_len.item()
        if sig_len < max_audio_len:
            pad = (0, max_audio_len - sig_len)
            sig = torch.nn.functional.pad(sig, pad)
        audio_signal.append(sig)
        
    audio_signal = torch.stack(audio_signal)
    audio_lengths = torch.stack(audio_lengths)
    return audio_signal, audio_lengths



In [114]:
class AudioBuffersDataLayer(IterableDataset):
    def __init__(self):
        super().__init__()
        
    def __iter__(self):
        return self
    
    def __next__(self):
        if self._buf_count == len(self.signal):
            raise StopIteration
        self._buf_count += 1
        return torch.as_tensor(self.signal[self._buf_count-1], dtype=torch.float32), \
               torch.as_tensor(self.signal_shape[0], dtype=torch.int64)
        
    def set_signal(self, signals):
        self.signal = signals
        self.signal_shape = self.signal[0].shape
        self._buf_count = 0

    def __len__(self):
        return 1

class ChunkBufferDecoder:
    def __init__(self, ort_session, tokenizer, featurizer, stride, config):
        self.ort_session = ort_session
        self.input_names = [input.name for input in ort_session.get_inputs()]
        self.output_names = [output.name for output in ort_session.get_outputs()]
        self.tokenizer = tokenizer
        self.featurizer = featurizer
        self.data_layer = AudioBuffersDataLayer()
        self.data_loader = DataLoader(self.data_layer, batch_size=1, collate_fn=speech_collate_fn)
        self.buffers = []
        self.all_preds = []
        self.chunk_len = config['chunk_len_in_secs']
        self.buffer_len = config['buffer_len_in_secs']
        assert config['chunk_len_in_secs'] <= config['buffer_len_in_secs']
        
        feature_stride = 0.01 # window_stride
        self.model_stride_in_secs = feature_stride * stride
        self.n_tokens_per_chunk = math.ceil(self.chunk_len / self.model_stride_in_secs) -2
        self.blank_id = 1024
        self.plot = False
        
    @torch.no_grad()    
    def transcribe_buffers(self, buffers, buffer_start_times, merge=True, plot=False):
        self.plot = plot
        self.buffers = buffers
        self.buffer_start_times = buffer_start_times
        self.data_layer.set_signal(buffers[:])
        self._get_batch_preds()
        hyp, timestamps = self.decode_final(merge)    
        return hyp, timestamps
    
    def _get_batch_preds(self):
        device = 'cpu'
        for batch in iter(self.data_loader):
            audio_signal, audio_signal_len = batch
            audio_signal, audio_signal_len = audio_signal.to(device), audio_signal_len.to(device)
            features, features_length = self.featurizer.forward(audio_signal, audio_signal_len)
            input_data = {
                self.input_names[0]: features.cpu().numpy(),
                self.input_names[1]: features_length.cpu().numpy()
            }
            log_probs = self.ort_session.run([self.output_names[0]], input_data)
            greedy_predictions = torch.tensor(log_probs[0]).argmax(dim=-1, keepdim=False)
            preds = torch.unbind(greedy_predictions)
            for pred in preds:
                self.all_preds.append(pred.cpu().numpy())

    
    def decode_final(self, merge=True):
        self.unmerged = []
        self.toks_unmerged = []
        delay = math.ceil((self.chunk_len + (self.buffer_len - self.chunk_len) / 2) / self.model_stride_in_secs)

        decoded_frames = []
        all_toks = []
        for idx, pred in enumerate(self.all_preds):
            ids, toks, offsets = self._greedy_decoder(pred, self.tokenizer, self.buffer_start_times[idx])
            decoded_frames.append((ids, offsets))
            all_toks.append(toks)

        for decoded, offsets in decoded_frames:
            self.unmerged += list(zip(decoded[len(decoded) - 1 - delay:len(decoded) - 1 - delay + self.n_tokens_per_chunk],
                                      offsets[len(offsets) - 1 - delay:len(offsets) - 1 - delay + self.n_tokens_per_chunk]))

        if self.plot:
            for i, tok in enumerate(all_toks):
                plt.plot(self.buffers[i])
                plt.show()
                print(tok)
                print("\nGreedy labels collected from this buffer")
                print(tok[len(tok) - 1 - delay:len(tok) - 1 - delay + self.n_tokens_per_chunk])
                self.toks_unmerged += tok[len(tok) - 1 - delay:len(tok) - 1 - delay + self.n_tokens_per_chunk]
            print("\nTokens collected from successive buffers before CTC merge")
            print(self.toks_unmerged)

        if not merge:
            return self.unmerged
        hyp, timestamps = self.greedy_merge(self.unmerged)
        return hyp, timestamps
    
    def _greedy_decoder(self, preds, tokenizer, buffer_start_time):
        s = []
        ids = []
        offsets = []
        for i in range(preds.shape[0]):
            if preds[i] == self.blank_id:
                s.append("_")
            else:
                pred = preds[i]
                s.append(tokenizer.id_to_piece(pred.item()))
            ids.append(preds[i])
            offsets.append(buffer_start_time + i * self.model_stride_in_secs)
        return ids, s, offsets
         
    def greedy_merge(self, preds):
        decoded_prediction = []
        word_timestamps = []
        previous = self.blank_id
        for p, offset in preds:
            if (p != previous or previous == self.blank_id) and p != self.blank_id:
                decoded_prediction.append((p.item(), offset))
            previous = p
        hypothesis = self.tokenizer.decode_ids([p for p, _ in decoded_prediction])
        
        # Print word-level start and end timestamps
        words = []
        timestamps = []
        current_word = ""
        current_start_timestamp = 0.0
        current_end_timestamp = 0.0
        for token, timestamp in decoded_prediction:
            subword = self.tokenizer.id_to_piece(token)
            if subword.startswith("▁"):
                if current_word:
                    words.append(current_word)
                    timestamps.append((current_start_timestamp, current_end_timestamp + self.model_stride_in_secs))
                current_word = subword.replace("▁", "")
                current_start_timestamp = timestamp
            else:
                current_word += subword
            current_end_timestamp = timestamp
            # print(current_word)
        if current_word:
            words.append(current_word)
            timestamps.append((current_start_timestamp, current_end_timestamp + self.model_stride_in_secs))

        for word, (start, end) in zip(words, timestamps):
            word_timestamps.append({'word':word, 'start': start, 'end': end})
            #print(f"{word}: {start:.2f}s - {end:.2f}s")
        
        return hypothesis, word_timestamps


In [120]:
config = {
    "sample_rate": 16000,
    "chunk_len_in_secs": 4,
    "context_len_in_secs": 2,
    "buffer_len_in_secs": 6,
    "n_buffers": 5,
    "stride": 4
}


file_path='/home/svanga/PromptingNemo/applications/voicebot/a.wav'
output_mono_path='/home/svanga/PromptingNemo/applications/voicebot/b.wav'
mono_audio_path = convert_to_mono(file_path, output_mono_path, 0)
config['n_buffers'] = calculate_n_buffers(mono_audio_path, config["chunk_len_in_secs"], config["buffer_len_in_secs"], config["stride"])
samples = get_samples(mono_audio_path, target_sr=config["sample_rate"])
buffer_list, buffer_start_times = prepare_buffers(samples, config)


In [121]:
ort_session_en_ner, model_tokenizer_en, filterbank_featurizer = create_ort_session(model_name="EN_IoT_conformer_ctc_large", model_shelf="/home/svanga/test/model_shelf")

stride = 4 # 4, 8 for Citrinet
asr_decoder = ChunkBufferDecoder(ort_session_en_ner, model_tokenizer_en, filterbank_featurizer, stride, config)
transcription, timestamps = asr_decoder.transcribe_buffers(buffer_list, buffer_start_times, plot=False)
print("Transcription:", transcription)
print(timestamps)

spm_model_path:  /home/svanga/test/model_shelf/EN_IoT_conformer_ctc_large/tokenizer/tokenizer.model
Transcription: other the reformorms that  took place around this time included the organisation of units into standdardd formations such as battlealions END.  increased payments to volulunteerersEND landand rants for afffien serviceEND. the establishment ofENTITY-EVENT_NAME anual  train  cabs END usually over ENTITY-EVENT_NAME easter END.  thecre of cris of professional soders known as permaninent staffff END. to provide  training the requirement for officeerss at non commissioned officeiceers . to  pass exams at the esttablishment of minimum required attendance. INTENT-CALENDAR_SET 
[{'word': 'other', 'start': 0.0, 'end': 2.08}, {'word': 'the', 'start': 2.12, 'end': 2.16}, {'word': 'reformorms', 'start': 2.2, 'end': 2.68}, {'word': 'that', 'start': 2.84, 'end': 2.88}, {'word': 'took', 'start': 3.04, 'end': 3.2}, {'word': 'place', 'start': 3.2800000000000002, 'end': 3.3200000000000003}, 

In [122]:
ort_session_en_ner, model_tokenizer_en, filterbank_featurizer = create_ort_session(model_name="EN_ner_conformer_ctc_large", model_shelf="/home/svanga/test/model_shelf")
stride = 8 # 4, 8 for Citrinet
asr_decoder = ChunkBufferDecoder(ort_session_en_ner, model_tokenizer_en, filterbank_featurizer, stride, config)
transcription, timestamps = asr_decoder.transcribe_buffers(buffer_list, buffer_start_times, plot=False)
print("Transcription:", transcription)
print(timestamps)

spm_model_path:  /home/svanga/test/model_shelf/EN_ner_conformer_ctc_large/tokenizer/tokenizer.model
Transcription: ther reforms that took place around this time included the organization of units into standard formations such as battalions increased payments to volunteers , land grants for efficient service , the establishment of NER_DATE annual END training camps , usually over  Easter END , the creation of cadderies of professional soldiers known as permanent staff to provide training , the requirement for officers and non-commissioned officers to pass exams and the establishment of minimum required attendance
[{'word': 'ther', 'start': 0.0, 'end': 2.0}, {'word': 'reforms', 'start': 2.16, 'end': 2.64}, {'word': 'that', 'start': 2.8000000000000003, 'end': 2.8800000000000003}, {'word': 'took', 'start': 2.96, 'end': 3.12}, {'word': 'place', 'start': 3.2, 'end': 3.3600000000000003}, {'word': 'around', 'start': 3.52, 'end': 3.7600000000000002}, {'word': 'this', 'start': 3.84, 'end': 3.92}