<a href="https://colab.research.google.com/github/Hannes1/youtube-transcript-api/blob/master/onnx_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## Install NeMo
BRANCH = 'v1.0.0b2'
## Grab the config we'll use in this example
!mkdir configs
!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml
!pip install onnxruntime
!python -m pip install 'git+https://github.com/NVIDIA/NeMo.git@v1.0.0b2#egg=nemo_toolkit[all]'
!git clone https://Hannes1:Hansie13!@github.com/Hannes1/chatable_asr_data_downloader.git
!pip install nemo-toolkit[nlp]==1.0.0b1
!pip install gdown #Get onnx model
!gdown https://drive.google.com/uc?id=18IgsCFS7fwlv18FGAhDAfnETFPOw3Mo7
%cd ./chatable_asr_data_downloader
!python -m pip install git+https://github.com/nficano/pytube
!pip install -r requirements.txt

In [18]:
# Fetch a sample clip: download the YouTube video's audio into the parent directory
# and convert it to WAV there. To use your own recording instead, skip this cell and
# upload it as test.wav.
from download_audio import download_audio
from pytube import YouTube
from cut_audio import cut_audio, format_to_wav

# Named video_id (not id) so the builtin id() is not shadowed for the rest of the session.
video_id = "BtN-goy9VOY"  # YouTube video ID
download_audio(video_id, '../')
# NOTE(review): the positional args look like (source format, channels, sample rate);
# 16000 matches the 16 kHz the ASR data loader uses — confirm against cut_audio.format_to_wav.
format_to_wav(video_id, '../', "mp4", 1, 16000)


In [None]:
%cd ../content
%mv BtN-goy9VOY.wav test.wav #for test purposes

In [8]:
def cut_text(text, max_length=128):
    """Split ``text`` into word-boundary chunks for the punctuation model.

    Every chunk except the last holds at least ``max_length`` characters and is
    then extended up to and including the next space, so words are never split.
    The last chunk is whatever text remains (kept even if it does not end with a
    space). Concatenating the chunks reproduces ``text`` exactly.

    Args:
        text: the string to split (e.g. a greedy ASR transcript).
        max_length: minimum chunk length, in characters, before cutting at the
            next space.

    Returns:
        List of string chunks; empty when ``text`` is empty.

    Raises:
        ValueError: if ``max_length`` is less than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    chunks = []
    start = 0
    while start < len(text):
        # A full-length chunk ends at index start + max_length - 1; cut at the first
        # space at or after that position so the chunk ends on a word boundary.
        cut = text.find(' ', start + max_length - 1)
        if cut == -1:
            # No further space: keep the remaining tail as the final chunk rather
            # than dropping it (transcripts usually don't end with a space).
            chunks.append(text[start:])
            break
        chunks.append(text[start:cut + 1])
        start = cut + 1
    return chunks

In [None]:
import json
import os
import tempfile
import onnxruntime
import torch

import numpy as np
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.data.audio_to_text import AudioToCharDataset
from nemo.collections.asr.metrics.wer import WER
from ruamel.yaml import YAML
import nemo.collections.nlp as nemo_nlp

# Pretrained NeMo checkpoints: a punctuation/capitalization model for the
# transcripts, and QuartzNet, whose preprocessor is used later for feature extraction.
nlp_model = nemo_nlp.models.PunctuationCapitalizationModel.from_pretrained(
    model_name="Punctuation_Capitalization_with_DistilBERT"
)
quartznet = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")

# Example ASR config fetched by the install cell; loaded with the safe YAML
# loader and printed for inspection.
config_path = './configs/config.yaml'
yaml = YAML(typ='safe')
with open(config_path) as config_file:
    params = yaml.load(config_file)
print(params)


def to_numpy(tensor):
    """Return ``tensor`` as a numpy.ndarray living on the CPU.

    A tensor that requires grad is detached first, because ``.numpy()`` refuses
    to convert a tensor that is still attached to the autograd graph.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()


def setup_transcribe_dataloader(cfg, vocabulary):
    """Build a non-shuffled DataLoader over ``<cfg['temp_dir']>/manifest.json``.

    The keyword arguments mirror those of NeMo's AudioToCharDataset; the
    ``config.get`` lookups for keys not set below fall back to their defaults.

    Args:
        cfg: dict with
            'temp_dir'          -- directory holding the caller-written manifest.json,
            'batch_size'        -- upper bound on the batch size,
            'paths2audio_files' -- list of audio paths listed in the manifest
                                   (a single path string counts as one file),
            'sample_rate'       -- optional, defaults to 16000.
        vocabulary: list of character labels passed to the dataset.

    Returns:
        torch.utils.data.DataLoader that batches with the dataset's ``collate_fn``.
    """
    audio_files = cfg['paths2audio_files']
    # A bare string is one path; len() on it would count characters, not files.
    if isinstance(audio_files, str):
        audio_files = [audio_files]

    config = {
        'manifest_filepath': os.path.join(cfg['temp_dir'], 'manifest.json'),
        'sample_rate': cfg.get('sample_rate', 16000),
        'labels': vocabulary,
        # Never larger than the number of files, and at least 1 (DataLoader rejects 0).
        'batch_size': max(1, min(cfg['batch_size'], len(audio_files))),
        'trim_silence': True,
        'shuffle': False,
    }
    dataset = AudioToCharDataset(
        manifest_filepath=config['manifest_filepath'],
        labels=config['labels'],
        sample_rate=config['sample_rate'],
        int_values=config.get('int_values', False),
        augmentor=None,
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        blank_index=config.get('blank_index', -1),
        unk_index=config.get('unk_index', -1),
        normalize=config.get('normalize_transcripts', False),
        trim=config.get('trim_silence', True),
        load_audio=config.get('load_audio', True),
        parser=config.get('parser', 'en'),
        add_misc=config.get('add_misc', False),
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config['batch_size'],
        collate_fn=dataset.collate_fn,
        drop_last=config.get('drop_last', False),
        shuffle=False,
        num_workers=config.get('num_workers', 0),
        pin_memory=config.get('pin_memory', False),
    )

# qn.onnx is the QuartzNet model exported once with quartznet.export('qn.onnx');
# the install cell's gdown step fetches a copy, so it is not re-exported here.
ort_session = onnxruntime.InferenceSession('./qn.onnx')
vocabulary = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
              "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]

files = ['./test.wav']

# Loop-invariant setup: the ONNX input name and the greedy CTC decoder.
onnx_input_name = ort_session.get_inputs()[0].name
wer = WER(vocabulary=vocabulary,
          batch_dim_index=0, use_cer=False, ctc_decode=True)

with tempfile.TemporaryDirectory() as tmpdir:
    # NeMo-style manifest; duration and text are placeholders since we only infer.
    with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp:
        for audio_file in files:
            entry = {'audio_filepath': audio_file,
                     'duration': 100000, 'text': 'nothing'}
            fp.write(json.dumps(entry) + '\n')

    # Pass the list of files (not a single path string) so batch sizing counts files.
    config = {'paths2audio_files': files,
              'batch_size': 4, 'temp_dir': tmpdir}
    temporary_datalayer = setup_transcribe_dataloader(config, vocabulary)
    for test_batch in temporary_datalayer:
        # Feature extraction uses the PyTorch preprocessor; no gradients are needed.
        with torch.no_grad():
            processed_signal, processed_signal_len = quartznet.preprocessor(
                input_signal=test_batch[0].to('cpu'), length=test_batch[1].to('cpu')
            )
        ort_inputs = {onnx_input_name: to_numpy(processed_signal)}
        ologits = ort_session.run(None, ort_inputs)
        # run() returns one array per model output; the logits are the first.
        logits = torch.from_numpy(ologits[0])
        greedy_predictions = logits.argmax(dim=-1, keepdim=False)
        hypotheses = wer.ctc_decoder_predictions_tensor(greedy_predictions)
        print(hypotheses)
        # Punctuate every transcript in the batch, in word-boundary chunks.
        for transcript in hypotheses:
            queries = cut_text(transcript, 128)
            inference_results = nlp_model.add_punctuation_capitalization(queries)
            for query, result in zip(queries, inference_results):
                print(f'Query : {query}')
                print(f'Result: {result.strip()}\n')
   


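# TODO

1.   Figure out how to export the NLP (punctuation/capitalization) model to ONNX.
2.   Fine-tune the punctuation model on more data, with longer character inputs (transcripts are currently cut into ~128-character chunks).
3.   Serve the pipeline as a Flask app.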

