In [1]:
# Copyright 2019 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# Kaldi TRTIS Inference Online Demo

## Overview


This repository provides a wrapper around the online GPU-accelerated ASR pipeline from the paper [GPU-Accelerated Viterbi Exact Lattice Decoder for Batched Online and Offline Speech Recognition](https://arxiv.org/abs/1910.10032). That work includes a high-performance implementation of a GPU HMM Decoder, a low-latency Neural Net driver, fast Feature Extraction for preprocessing, and new ASR pipelines tailored for GPUs. These different modules have been integrated into the Kaldi ASR framework.

This repository contains a TensorRT Inference Server custom backend for the Kaldi ASR framework. This custom backend calls the high-performance online GPU pipeline from the Kaldi ASR framework. The TensorRT Inference Server integration makes Kaldi ASR inference easy to use, providing a gRPC streaming server, dynamic sequence batching, and multi-instance support. A client connects to the gRPC server, streams audio by sending chunks to the server, and receives the inferred text in return. More information about the TensorRT Inference Server can be found [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/).



### Learning objectives

This notebook demonstrates how to run inference with the Kaldi TRTIS backend server using a Python gRPC client in an online setting: we stream live audio from a microphone to the inference server and receive the transcription results back.

## Content
1. [Prerequisites](#1)
1. [Setup](#2)
1. [Audio helper classes](#3)
1. [Inference](#4)


<a id="1"></a>
## 1. Prerequisites


### 1.1 Docker containers
Follow the steps in the [README](README.md) to build the Kaldi server and client containers.

### 1.2 Hardware
This notebook can be executed on any CUDA-enabled NVIDIA GPU, although a [Tensor Core NVIDIA GPU](https://www.nvidia.com/en-us/data-center/tensorcore/) (Volta, Turing or newer architectures) is recommended for efficient mixed precision inference.

In [1]:
!nvidia-smi

Thu Mar  5 00:28:21 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.48.02    Driver Version: 440.48.02    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Quadro GV100        Off  | 00000000:05:00.0 Off |                  Off |
| 32%   42C    P2    28W / 250W |  17706MiB / 32506MiB |      3%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

This notebook also requires access to a microphone. 

<a id="2"></a>
## 2. Setup
### Import libraries and parameters

In [2]:
import argparse
import os
import sys
import threading  # used by the microphone client's keyboard listener
import numpy as np
import soundfile
import pyaudio as pa
import librosa

import grpc
from tensorrtserver.api import api_pb2
from tensorrtserver.api import grpc_service_pb2
from tensorrtserver.api import grpc_service_pb2_grpc
import tensorrtserver.api.model_config_pb2 as model_config

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--file', help='Unused; absorbs the -f <connection file> argument that Jupyter passes to the kernel')

parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
                    help='Enable verbose output')
parser.add_argument('-a', '--async', dest="async_set", action="store_true", required=False,
                    default=False, help='Use asynchronous inference API')
parser.add_argument('--streaming', action="store_true", required=False, default=False,
                    help='Use streaming inference API')
parser.add_argument('-m', '--model-name', type=str, required=False, default='kaldi_online' ,
                    help='Name of model')
parser.add_argument('-x', '--model-version', type=int, required=False, default=1,
                    help='Version of model. Default is 1.')
parser.add_argument('-b', '--batch-size', type=int, required=False, default=1,
                    help='Batch size. Default is 1.')
parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',
                    help='Inference server URL. Default is localhost:8001.')
parser.add_argument('--chunk_duration', type=float, required=False,
                    default=0.51,
                    help="duration of the audio chunk for streaming "
                            "recognition, in seconds")
parser.add_argument('--input_device_id', type=int, required=False,
                    default=-1, help='Input device id to use to capture audio')
parser.add_argument('--sample_rate', type=int, required=False,
                    default=16000, help='Sample rate.')
# parse_known_args ignores any extra arguments passed to the notebook kernel
FLAGS, _ = parser.parse_known_args()

### Checking server status

We first query the status of the server. The target model is `kaldi_online`. A successful deployment of the Kaldi TRTIS server should produce output similar to the following.

```
request_status {
  code: SUCCESS
  server_id: "inference:0"
  request_id: 17514
}
server_status {
  id: "inference:0"
  version: "1.9.0"
  uptime_ns: 14179155408971
  model_status {
    key: "kaldi_online"
...
```

In [4]:
# Create gRPC stub for communicating with the server
channel = grpc.insecure_channel(FLAGS.url)
grpc_stub = grpc_service_pb2_grpc.GRPCServiceStub(channel)

# Prepare request for Status gRPC
request = grpc_service_pb2.StatusRequest(model_name=FLAGS.model_name)
# Call and receive response from Status gRPC
response = grpc_stub.Status(request)

print(response)

request_status {
  code: SUCCESS
  server_id: "inference:0"
  request_id: 6234
}
server_status {
  id: "inference:0"
  version: "1.9.0"
  uptime_ns: 4061941924008
  model_status {
    key: "kaldi_online"
    value {
      config {
        name: "kaldi_online"
        platform: "custom"
        version_policy {
          latest {
            num_versions: 1
          }
        }
        max_batch_size: 2200
        input {
          name: "WAV_DATA"
          data_type: TYPE_FP32
          dims: 8160
        }
        input {
          name: "WAV_DATA_DIM"
          data_type: TYPE_INT32
          dims: 1
        }
        output {
          name: "TEXT"
          data_type: TYPE_STRING
          dims: 1
        }
        instance_group {
          name: "kaldi_online_0"
          count: 2
          gpus: 0
          kind: KIND_GPU
        }
        default_model_filename: "libkaldi-trtisbackend.so"
        sequence_batching {
          max_sequence_idle_microseconds: 5000000
          

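The maximum chunk size does not have to be hard-coded on the client. A minimal sketch, assuming the `StatusResponse` field layout visible in the printed output above (`server_status.model_status[<model>].config.input[*].name` and `.dims`), reads it from the status response. `max_chunk_samples` is a hypothetical helper written for this notebook, not part of the TRTIS client API.

```python
def max_chunk_samples(status_response, model_name, input_name="WAV_DATA"):
    """Return the first dimension of `input_name` in the model's config.

    Assumes the layout printed above:
    server_status.model_status[model_name].config.input[*].{name, dims}
    """
    config = status_response.server_status.model_status[model_name].config
    for model_input in config.input:
        if model_input.name == input_name:
            return int(model_input.dims[0])
    raise KeyError("Input %r not found in config of model %r"
                   % (input_name, model_name))
```

With the `response` obtained above, `max_chunk_samples(response, FLAGS.model_name)` would give 8160 for the configuration printed here.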
### Testing microphone

Next, we list the audio devices on the system. Select the input device you want to use from the list.

In [5]:
import pyaudio
import wave

p = pyaudio.PyAudio()  # Create an interface to PortAudio

device_info = p.get_host_api_info_by_index(0)
num_devices = device_info.get('deviceCount')

devices = {}
for i in range(0, num_devices):
    #if (p.get_device_info_by_host_api_device_index(0, i).get(
    #    'maxInputChannels')) > 0:
        devices[i] = p.get_device_info_by_host_api_device_index(
            0, i)

if (len(devices) == 0):
    raise RuntimeError("Cannot find any valid input devices")


print("\nInput Devices:")
for id, info in devices.items():
    print("{}: {}".format(id,info.get("name")))
input_device_id = int(input("Enter device id to use: "))


Input Devices:
0: HDA Intel PCH: ALC1150 Analog (hw:0,0)
1: HDA Intel PCH: ALC1150 Digital (hw:0,1)
2: HDA Intel PCH: ALC1150 Alt Analog (hw:0,2)
3: HD Pro Webcam C920: USB Audio (hw:1,0)
4: HDA NVidia: HDMI 0 (hw:2,3)
5: HDA NVidia: HDMI 2 (hw:2,8)
6: HDA NVidia: HDMI 3 (hw:2,9)
7: sysdefault
8: front
9: surround21
10: surround40
11: surround41
12: surround50
13: surround51
14: surround71
15: iec958
16: spdif
17: default
18: dmix
Enter device id to use: 3


We then record three seconds of audio from the selected device and play it back to verify that everything works.

In [6]:
import pprint
pp = pprint.PrettyPrinter(indent=4)
    
print("Device info:")
devinfo = p.get_device_info_by_index(input_device_id)  # the device selected above
pp.pprint(devinfo)

chunk = 1024  # Record in chunks of 1024 frames
sample_format = pyaudio.paInt16  # 16 bits per sample
channels = 1
fs = int(devinfo['defaultSampleRate'])  # Record at the device's default sampling rate
seconds = 3
filename = "test.wav"

print('Recording')

stream = p.open(format=sample_format,
                channels=channels,
                rate=fs,
                frames_per_buffer=chunk,
                input=True,
                input_device_index=input_device_id)

frames = []  # Initialize list to store raw byte buffers

# Store data in chunks for 3 seconds
for i in range(0, int(fs / chunk * seconds)):
    data = stream.read(chunk)
    frames.append(data)

# Stop and close the stream
stream.stop_stream()
stream.close()
# Terminate the PortAudio interface
# p.terminate()

print('Finished recording')

# Save the recorded data as a WAV file
with wave.open(filename, 'wb') as wf:
    wf.setnchannels(channels)
    wf.setsampwidth(p.get_sample_size(sample_format))
    wf.setframerate(fs)
    wf.writeframes(b''.join(frames))


Device info:
{   'defaultHighInputLatency': 0.048,
    'defaultHighOutputLatency': -1.0,
    'defaultLowInputLatency': 0.01196875,
    'defaultLowOutputLatency': -1.0,
    'defaultSampleRate': 32000.0,
    'hostApi': 0,
    'index': 3,
    'maxInputChannels': 2,
    'maxOutputChannels': 0,
    'name': 'HD Pro Webcam C920: USB Audio (hw:1,0)',
    'structVersion': 2}
Recording
Finished recording


In [None]:
import IPython.display as ipd
ipd.Audio(filename)

<a id="3"></a>
## 3. Audio helper classes

Next, we define a helper class for pre-processing audio. The `AudioSegment` class below takes an audio signal and resamples it to the rate required by the Kaldi ASR model, which is 16000 Hz by default.

Note: For historical reasons, Kaldi expects waveforms in the range $(2^{15}-1)\times[-1, 1]$, not the usual DSP range $[-1, 1]$. Therefore, we scale the audio signal by a factor of $2^{15}-1$.
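The effect of this scaling can be checked numerically. The sketch below mirrors the conversion done in `AudioSegment._convert_samples_to_float32` for two cases: 16-bit PCM, as captured from the microphone in this notebook, and float samples already in $[-1, 1]$. The helper names are hypothetical.

```python
import numpy as np

WAV_SCALE_FACTOR = 2**15 - 1  # same constant as in the AudioSegment cell

def int16_to_kaldi_range(pcm):
    """Scale int16 PCM to roughly [-1, 1], then to Kaldi's (2^15-1) x [-1, 1] range."""
    unit = pcm.astype(np.float32) / WAV_SCALE_FACTOR     # usual DSP convention
    return (WAV_SCALE_FACTOR * unit).astype(np.float32)  # Kaldi convention

def float_to_kaldi_range(x):
    """Scale float samples that are already in [-1, 1] to Kaldi's range."""
    return (WAV_SCALE_FACTOR * np.asarray(x, dtype=np.float32)).astype(np.float32)

pcm = np.array([0, 16384, 32767, -32767], dtype=np.int16)
print(int16_to_kaldi_range(pcm))  # approximately the original int16 values
print(float_to_kaldi_range([0.5, -1.0]))
```

For 16-bit microphone input the two steps cancel out, so the samples sent to the server are essentially the raw int16 values stored as float32.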

In [8]:
WAV_SCALE_FACTOR = 2**15-1

class AudioSegment(object):
    """Monaural audio segment abstraction.
    :param samples: Audio samples [num_samples] or [num_samples x num_channels].
    :type samples: ndarray
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises TypeError: If the sample data type is not float or signed int.
    """

    def __init__(self, samples, sample_rate, target_sr=16000, trim=False,
                 trim_db=60):
        """Create audio segment from samples.
        Samples are converted to float32 internally and scaled to Kaldi's
        (2^15-1) x [-1, 1] range; multi-channel audio is averaged to mono.
        """
        samples = self._convert_samples_to_float32(samples)
        # Down-mix [num_samples x num_channels] to mono before resampling
        if samples.ndim >= 2:
            samples = np.mean(samples, axis=1)
        if target_sr is not None and target_sr != sample_rate:
            samples = librosa.resample(samples, orig_sr=sample_rate,
                                       target_sr=target_sr)
            sample_rate = target_sr
        if trim:
            samples, _ = librosa.effects.trim(samples, top_db=trim_db)
        # The server's WAV_DATA input is TYPE_FP32
        self._samples = samples.astype(np.float32)
        self._sample_rate = sample_rate

    @staticmethod
    def _convert_samples_to_float32(samples):
        """Convert samples to float32 in Kaldi's (2^15-1) x [-1, 1] range.
        Audio sample type is usually integer or floating-point. Signed
        integers are first scaled to [-1, 1]; floats are assumed to be
        in [-1, 1] already.
        """
        float32_samples = samples.astype('float32')
        if np.issubdtype(samples.dtype, np.signedinteger):
            bits = np.iinfo(samples.dtype).bits
            float32_samples *= (1. / ((2 ** (bits - 1)) - 1))
        elif not np.issubdtype(samples.dtype, np.floating):
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return WAV_SCALE_FACTOR * float32_samples

    @classmethod
    def from_file(cls, filename, target_sr=16000, offset=0, duration=0,
                  min_duration=0, trim=False):
        """
        Load a file supported by soundfile and return it as an AudioSegment.
        :param filename: path of file to load
        :param target_sr: the desired sample rate
        :param offset: offset in seconds when loading audio
        :param duration: duration in seconds when loading audio (0 = to end of file)
        :param min_duration: zero-pad the audio to at least this many seconds
        :param trim: if true, trim leading and trailing silence
        :return: AudioSegment
        """
        with soundfile.SoundFile(filename, 'r') as f:
            dtype_options = {'PCM_16': 'int16', 'PCM_32': 'int32', 'FLOAT': 'float32'}
            dtype = dtype_options.get(f.subtype, 'float32')
            sample_rate = f.samplerate
            if offset > 0:
                f.seek(int(offset * sample_rate))
            if duration > 0:
                samples = f.read(int(duration * sample_rate), dtype=dtype)
            else:
                samples = f.read(dtype=dtype)

        # soundfile returns [num_samples] or [num_samples x num_channels],
        # still at the file's own sample rate
        num_zero_pad = int(sample_rate * min_duration) - samples.shape[0]
        if num_zero_pad > 0:
            pad_width = [(0, num_zero_pad)] + [(0, 0)] * (samples.ndim - 1)
            samples = np.pad(samples, pad_width, mode='constant')

        return cls(samples, sample_rate, target_sr=target_sr, trim=trim)

    @property
    def samples(self):
        return self._samples.copy()

    @property
    def sample_rate(self):
        return self._sample_rate

<a id="4"></a>
## 4. Inference

We first create an inference context object that connects to the Kaldi TRTIS server via gRPC.

The server expects audio in chunks of at most `WAV_DATA` `dims` samples (8160 by default). At the default 16000 Hz sampling rate, this corresponds to 510 ms of audio per chunk (8160 / 16000 = 0.51 s). The last chunk of a sequence may be shorter than this maximum.
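For a complete recording such as a .wav file, this chunking rule can be sketched as below. It is an illustration that assumes a maximum chunk size of 8160 samples, as in the configuration above; `iter_chunks` is a hypothetical helper. The two booleans correspond to the `FLAG_SEQUENCE_START` and `FLAG_SEQUENCE_END` request flags used by the microphone client later in this notebook.

```python
import numpy as np

MAX_CHUNK_SAMPLES = 8160  # WAV_DATA dims in the model config shown above

def iter_chunks(samples, max_samples=MAX_CHUNK_SAMPLES):
    """Yield (chunk, is_start, is_end) for a complete 1-D signal.

    The first chunk opens the sequence and the last one closes it;
    only the last chunk may be shorter than max_samples.
    """
    n = len(samples)
    for offset in range(0, n, max_samples):
        chunk = samples[offset:offset + max_samples]
        yield chunk, offset == 0, offset + max_samples >= n

signal = np.zeros(20000, dtype=np.float32)  # 1.25 s of silence at 16 kHz
for chunk, is_start, is_end in iter_chunks(signal):
    print(len(chunk), is_start, is_end)
# 8160 True False
# 8160 False False
# 3680 False True
```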

In [9]:
from tensorrtserver.api import *
protocol = ProtocolType.from_str("grpc")

CORRELATION_ID = 11101
ctx = InferContext(FLAGS.url, protocol, FLAGS.model_name, FLAGS.model_version,
                    correlation_id=CORRELATION_ID, verbose=True,
                    streaming=False)

Next, we read chunks of audio from the microphone (each 510 ms long, i.e. 8160 samples after resampling to 16000 Hz) and stream them sequentially to the Kaldi server. The server processes each chunk as soon as it is received.

Unlike a .wav file, a continuous microphone stream has no natural end, so there is no `end` marker. Therefore, the client closes a sequence every 10 chunks and receives a transcript for it. Note that the server resets its state for the sequence once the result is sent out.
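The intended flag schedule for the microphone stream can be modeled in isolation. The sketch below is a simplified model of what `TranscribeFromMicrophone` aims to do, not code taken from it: it yields start/end flags for an endless stream cut into sequences of `MAX_CHUNKS` chunks, and computes how many samples are read per chunk at the device's native rate before resampling. The helper names are hypothetical.

```python
import itertools

MAX_CHUNKS = 10        # chunks per sequence, as in TranscribeFromMicrophone
CHUNK_DURATION = 0.51  # seconds, the default of --chunk_duration

def sequence_flags(max_chunks=MAX_CHUNKS):
    """Yield (is_start, is_end) for each chunk of an endless stream that is
    cut into consecutive sequences of max_chunks chunks."""
    for i in itertools.count():
        pos = i % max_chunks
        yield pos == 0, pos == max_chunks - 1

def device_chunk_size(device_sr, chunk_duration=CHUNK_DURATION):
    """Samples to read per chunk at the microphone's native sampling rate."""
    return int(chunk_duration * device_sr)

# First 12 chunks: a sequence ends at chunk 10 and a new one starts at chunk 11
for n, (is_start, is_end) in enumerate(itertools.islice(sequence_flags(), 12), 1):
    print(n, is_start, is_end)
print(device_chunk_size(32000))
```

With the 32000 Hz webcam microphone used here, each chunk is read as 16320 samples and resampled to 8160 samples at 16000 Hz, so each transcript covers roughly 10 × 0.51 s ≈ 5.1 s of audio.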

In [11]:
class TranscribeFromMicrophone:

    def __init__(self,input_device_id, target_sr, chunk_duration):

        self.recording_state = "init"
        self.target_sr  = target_sr
        self.chunk_duration = chunk_duration

        self.p = pa.PyAudio()

        device_info = self.p.get_host_api_info_by_index(0)
        num_devices = device_info.get('deviceCount')
        devices = {}
        for i in range(0, num_devices):
            if (self.p.get_device_info_by_host_api_device_index(0, i).get(
                'maxInputChannels')) > 0:
                devices[i] = self.p.get_device_info_by_host_api_device_index(
                    0, i)

        if (len(devices) == 0):
            raise RuntimeError("Cannot find any valid input devices")

        if input_device_id is None or input_device_id not in \
            devices.keys():
            print("\nInput Devices:")
            for id, info in devices.items():
                print("{}: {}".format(id,info.get("name")))
            input_device_id = int(input("Enter device id to use: "))

        self.input_device_id = input_device_id
        devinfo = self.p.get_device_info_by_index(input_device_id)
        self.device_default_sr = int(devinfo['defaultSampleRate'])
        print("Device sample rate: %d" % self.device_default_sr)

    def transcribe_audio(self, streaming=True):
        ctx = InferContext(FLAGS.url, protocol, FLAGS.model_name, FLAGS.model_version,
                    correlation_id=CORRELATION_ID, verbose=True,
                    streaming=False)
        
        chunk_size = int(self.chunk_duration*self.device_default_sr)
        self.recording_state = "init"

        def keyboard_listener():
            input("**********Press Enter to start and end transcribing...**********")
            self.recording_state = "capture"
            print("Recording...")

            input("")
            self.recording_state = "release"

        listener = threading.Thread(target=keyboard_listener)
        listener.start()

        print("starting....")

        stream = None
        start = True     # True when the next chunk opens a new sequence
        cnt = 0          # number of chunks sent in the current sequence
        MAX_CHUNKS = 10  # chunks per sequence before requesting a transcript
        while self.recording_state != "release":
            if self.recording_state != "capture":
                # Wait for the first Enter key press without busy-looping
                listener.join(timeout=0.05)
                continue
            try:
                if stream is None:
                    stream = self.p.open(
                        format=pa.paInt16,
                        channels=1,
                        rate=self.device_default_sr,
                        input=True,
                        input_device_index=self.input_device_id,
                        frames_per_buffer=chunk_size)

                # Read an audio chunk from the microphone and resample it
                audio_bytes = stream.read(chunk_size, exception_on_overflow=False)
                audio_signal = np.frombuffer(audio_bytes, dtype=np.int16)
                audio_segment = AudioSegment(audio_signal,
                                             self.device_default_sr,
                                             self.target_sr)
                samples = audio_segment.samples

                # Close the sequence after MAX_CHUNKS chunks, or on the final
                # Enter key press, so that the server returns a transcript
                cnt += 1
                end = cnt == MAX_CHUNKS or self.recording_state == "release"

                flags = InferRequestHeader.FLAG_NONE
                if start:
                    flags |= InferRequestHeader.FLAG_SEQUENCE_START
                if end:
                    flags |= InferRequestHeader.FLAG_SEQUENCE_END

                inputs = {'WAV_DATA': (samples,),
                          'WAV_DATA_DIM': (np.full(shape=1, fill_value=len(samples), dtype=np.int32),)}
                # Only the last chunk of a sequence requests the TEXT output
                outputs = {'TEXT': InferContext.ResultFormat.RAW} if end else {}
                res = ctx.run(inputs, outputs, batch_size=1, flags=flags,
                              corr_id=CORRELATION_ID)
                start = False

                sys.stdout.write("\r" + "." * cnt)
                sys.stdout.flush()

                if end:
                    print("\n" + "".join(t.decode('utf-8') for t in res['TEXT'][0]))
                    # The server resets this sequence, so the next chunk starts a new one
                    start = True
                    cnt = 0

            except Exception as e:
                print(e)
                break

        if stream is not None:
            stream.close()
        self.p.terminate()

In [12]:
transcriber = TranscribeFromMicrophone(input_device_id,
    target_sr=FLAGS.sample_rate,
    chunk_duration=FLAGS.chunk_duration)

Device sample rate: 32000


After you run the cell below, press ENTER to start recording audio chunks from the selected microphone and streaming them continuously to the server. After every 10 chunks, the client retrieves and displays the transcript, and the server resets its state, i.e., it treats the next chunk as the start of a new request.
Press ENTER again to stop the client.



In [None]:
transcriber.transcribe_audio()

# Conclusion

In this notebook, we have walked through the complete process of capturing audio from a microphone, preparing it, and carrying out inference with the Kaldi ASR model.

## What's next
Now it's time to try the Kaldi ASR model on your own data. The online client can also be improved further, for example by detecting natural breaks in the input stream (e.g., silence) to split it into sentences more appropriately.
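As a possible starting point for such silence detection, a minimal energy-based check is sketched below. It assumes chunks in Kaldi's $(2^{15}-1)\times[-1, 1]$ range, as returned by `AudioSegment.samples`, and uses a hypothetical -40 dBFS threshold that would need tuning for a real microphone; `is_silent` is not part of this repository.

```python
import numpy as np

def is_silent(chunk, threshold_db=-40.0, full_scale=2**15 - 1):
    """Return True if the chunk's RMS level is below threshold_db (dBFS).

    Assumes samples in Kaldi's (2^15-1) x [-1, 1] range; the default
    threshold is a hypothetical value, not one taken from this repository.
    """
    chunk = np.asarray(chunk, dtype=np.float64)
    if chunk.size == 0:
        return True
    rms = np.sqrt(np.mean(chunk ** 2)) / full_scale
    # A tiny epsilon keeps an all-zero chunk from producing log10(0)
    level_db = 20.0 * np.log10(rms + 1e-12)
    return bool(level_db < threshold_db)
```

One way to use it would be to close the current sequence (set `end = True` in the microphone loop) when a silent chunk follows speech, instead of only after a fixed number of chunks.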
