In [None]:
import math
from typing import Iterable, List, Optional

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torchaudio.models import Hypothesis, RNNTBeamSearch

from syllabe_vocab import VOCAB_TOKENS, syllbe_vocab_size, BLANK_ID

class SampleConfig:
    """Plain holder for the audio front-end settings.

    Every field is a per-instance attribute; the two dict fields start out
    empty and are fresh objects for each instance.
    """

    def __init__(self):
        # Empty per-instance dicts (not populated anywhere in this cell).
        self.resamplers = dict()
        self.mel_spectrograms = dict()

        # Target sample rate (Hz) and resampling-filter parameters.
        self.resample_rate = 16_000
        self.lowpass_filter_width = 64
        self.rolloff = 0.9475937167399596
        self.resampling_method = "kaiser_window"
        self.beta = 14.769656459379492

        # Mel-spectrogram geometry.
        self.n_fft = 1_024
        self.n_mels = 80

class KinspeakEmformerRNNT(torch.nn.Module):
    """Training wrapper around torchaudio's ``emformer_rnnt_base`` model.

    ``self.rnnt`` holds the RNN-T network; :meth:`forward` runs it on a batch
    and returns the RNN-T loss.

    Args:
        target_vocab_size: Number of output symbols of the RNN-T.
        target_blank_id: Token id used as the RNN-T blank; it is also
            prepended to every target sequence fed to the predictor.
    """

    def __init__(self, target_vocab_size,
                 target_blank_id):
        super().__init__()
        self.target_vocab_size = target_vocab_size
        self.target_blank_id = target_blank_id
        self.rnnt = torchaudio.models.emformer_rnnt_base(target_vocab_size)
        # NOTE(review): never used by forward(), which calls
        # torchaudio.functional.rnnt_loss with reduction='mean' instead.
        self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)

    def forward(self, log_mel_spectrograms: torch.Tensor,  # (N, F, L)
                log_mel_spectrogram_lengths: List[int],
                target_syllabe_ids: torch.Tensor, target_syllabe_id_lengths: List[int],
                target_syllabe_ids_with_eos=True,
                target_syllabe_gpt_output=None):
        """Return the mean RNN-T loss for one batch.

        Args:
            log_mel_spectrograms: Batched features laid out as (N, F, L).
            log_mel_spectrogram_lengths: Valid length of each utterance.
            target_syllabe_ids: All target ids of the batch concatenated
                into one flat tensor.
            target_syllabe_id_lengths: Number of ids per utterance, used to
                split ``target_syllabe_ids``.
            target_syllabe_ids_with_eos: Unused.
            target_syllabe_gpt_output: Unused.
        """
        # The RNN-T consumes (N, L, F).
        sources = log_mel_spectrograms.transpose(1, 2)
        source_lengths = torch.tensor(log_mel_spectrogram_lengths,
                                      dtype=torch.int32, device=sources.device)

        # Split the flat id tensor per utterance and zero-pad into (N, U).
        flat_ids = target_syllabe_ids.to(dtype=torch.int32)
        targets = pad_sequence(flat_ids.split(target_syllabe_id_lengths), batch_first=True)
        target_lengths = torch.tensor(target_syllabe_id_lengths,
                                      dtype=torch.int32, device=targets.device)

        # Predictor input: a blank column in front of every target row.
        blank_column = targets.new_full((targets.size(0), 1), self.target_blank_id)
        prepended_targets = torch.cat((blank_column, targets), dim=1)
        prepended_target_lengths = target_lengths + 1

        output, src_lengths, _, _ = self.rnnt(
            sources, source_lengths, prepended_targets, prepended_target_lengths)
        # NOTE(review): logit lengths are the model's lengths minus one, so the
        # last output frame of each utterance is excluded from the loss —
        # confirm this is intended.
        return torchaudio.functional.rnnt_loss(
            output, targets, src_lengths - 1, target_lengths,
            blank=self.target_blank_id, reduction='mean', clamp=1.0,
        )

def post_process_hypos(
    tokens: List[int], tgt_dict: List[str], lstrip: bool = True,
    *, remove_ids: Optional[Iterable[int]] = None,
) -> str:
    """Turn one hypothesis' token ids into a transcript string.

    The first token is always skipped (presumably the start/blank token the
    beam search seeds each hypothesis with — confirm). Ids in ``remove_ids``
    are dropped. The rest are mapped through ``tgt_dict`` and concatenated,
    and every ``'|'`` becomes a space.

    Args:
        tokens: Token ids of a single hypothesis.
        tgt_dict: Token strings indexed by id.
        lstrip: Strip leading whitespace from the result when true.
        remove_ids: Ids to discard before lookup. Defaults to the special
            ids ``(0, 1, 2, 3, 4, 6)`` this function has always removed.

    Returns:
        The transcript.
    """
    # Build the filter set once per call: O(1) membership instead of a list scan.
    removed = frozenset((0, 1, 2, 3, 4, 6) if remove_ids is None else remove_ids)
    pieces = (tgt_dict[idx] for idx in tokens[1:] if idx not in removed)
    output_string = "".join(pieces).replace('|', ' ')
    return output_string.lstrip() if lstrip else output_string


def _piecewise_linear_log(x):
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


class ModelWrapper(torch.nn.Module):
    """Bundle feature extraction, RNN-T beam search and detokenization.

    Builds a mel-spectrogram front end from ``SampleConfig``, restores the
    Emformer RNN-T weights from a ``KinspeakEmformerRNNT`` checkpoint and
    wraps them in torchaudio's ``RNNTBeamSearch``.

    Args:
        tgt_dict: Token strings indexed by token id.
        checkpoint_path: File loaded with ``torch.load``; its
            ``'model_state_dict'`` entry must hold ``KinspeakEmformerRNNT``
            weights.
    """

    def __init__(self, tgt_dict: List[str],
                 checkpoint_path: str = '/home/user/kinspeak_asr_syllabe_emformer_rnnt_base.pt'):
        super().__init__()
        cfg = SampleConfig()
        win_length = cfg.resample_rate * 25 // 1000  # 25ms
        hop_length = cfg.resample_rate * 10 // 1000  # 10ms
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=cfg.resample_rate, n_fft=cfg.n_fft,
            win_length=win_length, hop_length=hop_length,
            center=True, pad_mode="reflect", power=2.0,
            norm="slaney", onesided=True, n_mels=cfg.n_mels,
            mel_scale="htk",
        )

        model = KinspeakEmformerRNNT(syllbe_vocab_size(), BLANK_ID)
        # NOTE(review): torch.load unpickles arbitrary objects — only point
        # this at trusted checkpoint files.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(state_dict['model_state_dict'])
        # Reuse the restored sub-module directly rather than building a second
        # emformer_rnnt_base and copying the same weights into it.
        rnnt = model.rnnt
        del state_dict, model

        rnnt.eval()
        self.decoder = RNNTBeamSearch(rnnt, BLANK_ID)
        self.decoder.eval()

        self.tgt_dict = tgt_dict

    def feature_extractor(self, input):
        """Return ``(features, length)`` for one waveform segment.

        ``features`` is the natural log of the mel power spectrogram,
        transposed so that time is the first axis. ``length`` is a 1-element
        tensor holding the number of frames. The ``transpose(1, 0)`` assumes
        the spectrogram is 2-D, i.e. a 1-D waveform input.
        """
        log_eps = 1e-36  # keeps log() finite on silent bins
        spectrogram = self.mel_spectrogram(input).transpose(1, 0)
        features = torch.log(spectrogram + log_eps)
        length = torch.tensor([features.shape[0]])
        return features, length

    def token_processor(self, hypos, lstrip=False):
        """Convert one hypothesis' token ids into a transcript string."""
        return post_process_hypos(hypos, self.tgt_dict, lstrip=lstrip)

    def infer(self, features, length, state, hypothesis, beam_width: int = 8):
        """Run one streaming beam-search step.

        ``state`` and ``hypothesis`` are the values returned by the previous
        call (``None`` on the first call).

        Returns:
            ``(hypos, state)`` exactly as produced by ``RNNTBeamSearch.infer``.
        """
        hypos, state = self.decoder.infer(features, length, beam_width,
                                          state=state, hypothesis=hypothesis)
        return hypos, state

class ContextCacher:
    """Cache the tail of each chunk and prepend it to the next chunk.

    Args:
        segment_length (int): Size of the main segment. Shorter incoming
            chunks are right-padded with zeros to this length.
        context_length (int): Number of trailing samples of each (padded)
            chunk cached and prepended to the following chunk. The first
            call prepends ``context_length`` zeros; ``0`` disables context.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        """Return the cached context followed by ``chunk`` (padded to ``segment_length``).

        Expects a 1-D chunk: padding, concatenation and slicing all act on dim 0.
        """
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        chunk_with_context = torch.cat((self.context, chunk))
        # Use an explicit start index: ``chunk[-0:]`` is the whole chunk, so a
        # zero context_length would otherwise cache (and re-send) everything.
        start = max(0, chunk.size(0) - self.context_length)
        self.context = chunk[start:]
        return chunk_with_context

# Signal that the definitions cell finished executing.
print("Models defined!")


In [None]:

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import IPython
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader

# Streaming geometry, first expressed in feature frames (one per hop).
hop_length = 160
segment_length = 16
right_context_length = 4

sample_rate = 16000
# Convert frame counts into audio-sample counts for the reader and cacher.
context_length = right_context_length * hop_length
segment_length *= hop_length

print(f"Sample rate: {sample_rate}")
print(f"Main segment: {segment_length} frames ({segment_length / sample_rate} seconds)")
print(f"Right context: {context_length} frames ({context_length / sample_rate} seconds)")


In [None]:

# Alternative source:
# src = "https://download-a.akamaihd.net/files/media_periodical/be/w_YW_202105_06.mp3"

# Imbogo
src = "https://www.laits.utexas.edu/phonology/sounds/MP3/142b.mp3"

# Abakobwa
# src = "https://www.laits.utexas.edu/phonology/sounds/MP3/007b.mp3"

# Request one audio output stream at `sample_rate`, delivered
# `segment_length` frames per chunk.
streamer = StreamReader(src)
streamer.add_basic_audio_stream(
    frames_per_chunk=segment_length,
    sample_rate=sample_rate,
)

# Describe source stream 0 and output stream 0.
print(streamer.get_src_stream_info(0))
print(streamer.get_out_stream_info(0))



In [None]:
# Audio context cache and the decoding wrapper used by run_inference().
cacher = ContextCacher(segment_length=segment_length, context_length=context_length)

wrapper = ModelWrapper(tgt_dict=VOCAB_TOKENS)

# Decoder state and beam, carried across successive run_inference() calls.
state = None
hypothesis = None

print("Wrapper defined")

In [None]:
# One shared iterator so repeated run_inference() calls continue where they stopped.
stream_iterator = iter(streamer.stream())


def _plot(feats, num_iter, unit=25):
    """Plot collected per-chunk features as stacked spectrogram strips.

    Each row shows up to ``unit`` chunks. ``feats`` is a list of 2-D tensors
    (presumably (frames, mel bins) as returned by
    ``wrapper.feature_extractor`` — confirm). At most ``len(feats)`` chunks
    are plotted, even if ``num_iter`` is larger. Nothing is drawn when there
    is nothing to plot.
    """
    # The stream may end before num_iter chunks were read; never index past feats.
    num_iter = min(num_iter, len(feats))
    if num_iter <= 0:
        return  # plt.subplots(0, 1) would raise
    unit_dur = segment_length / sample_rate * unit
    num_plots = num_iter // unit + (1 if num_iter % unit else 0)
    # squeeze=False always yields a 2-D array, so one loop handles 1..N rows.
    fig, axes = plt.subplots(num_plots, 1, squeeze=False)
    t0 = 0
    for i, ax in enumerate(axes[:, 0]):
        feats_ = feats[i * unit : (i + 1) * unit]
        t1 = t0 + segment_length / sample_rate * len(feats_)
        feats_ = torch.cat([f[2:-2] for f in feats_])  # remove boundary effect and overlap
        ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower")
        ax.tick_params(which="both", left=False, labelleft=False)
        ax.set_xlim(t0, t0 + unit_dur)
        t0 = t1
    fig.suptitle("MelSpectrogram Feature")
    plt.tight_layout()


@torch.inference_mode()
def run_inference(num_iter=100):
    """Stream up to ``num_iter`` chunks through the model, printing transcripts.

    Chunks come from the module-level ``stream_iterator``. Each chunk gets the
    cached context prepended by ``cacher`` and is decoded through
    ``wrapper``. The decoder ``state`` and ``hypothesis`` are module globals,
    so successive calls continue the same stream.

    Returns:
        An ``IPython.display.Audio`` of the audio consumed by this call, or
        ``None`` if the stream yielded no chunks.
    """
    global state, hypothesis
    chunks = []
    feats = []
    for i, (chunk,) in enumerate(stream_iterator, start=1):
        first = (state is None)
        segment = cacher(chunk[:, 0])  # first channel only
        if first:
            print('segment:', segment.shape, flush=True)
            print('Time:', segment.size(0)/sample_rate, 'secs', flush=True)
        features, length = wrapper.feature_extractor(segment)

        print('features:', features.shape, flush=True)
        print('length:', length, flush=True)

        hypos, state = wrapper.infer(features, length, state, hypothesis)
        # The beam may hold a single hypothesis; only report hypos[1] if present.
        second = hypos[1:2]
        if first:
            print('hypos:', type(hypos), type(hypos[0]), *(type(h) for h in second), type(hypos[0][0]))
        print('hypos:', len(hypos), len(hypos[0]), *(len(h) for h in second), len(hypos[0][0]))
        hypothesis = hypos
        transcript = wrapper.token_processor(hypos[0][0], lstrip=False)
        print(f"\nTranscript: '{transcript}'\n")

        chunks.append(chunk)
        feats.append(features)
        if i == num_iter:
            break

    if not chunks:
        # torch.cat([]) below would raise on an exhausted stream.
        print('Stream exhausted: no audio chunks were read.')
        return None

    # Plot what was actually collected; the stream may end before num_iter.
    _plot(feats, len(feats))
    return IPython.display.Audio(torch.cat(chunks).T.numpy(), rate=sample_rate)

# Signal that the inference cell finished executing.
print("Inference code ready")


In [None]:

run_inference(num_iter=10)
