Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
!git clone https://github.com/google-research/google-research.git

In [None]:
import sys
import os
import tarfile
import urllib
import zipfile
sys.path.append('./google-research')

# Examples of streaming and non-streaming inference with TF/TFLite

## Imports

In [None]:
# TF streaming
from kws_streaming.models import models
from kws_streaming.models import utils
from kws_streaming.layers.modes import Modes


In [None]:
import tensorflow as tf
import numpy as np
import tensorflow.compat.v1 as tf1
import logging
from kws_streaming.models import model_flags
from kws_streaming.models import model_params
from kws_streaming.train import test
from kws_streaming.data import input_data
from kws_streaming.data import input_data_utils as du
tf1.disable_eager_execution()

In [None]:
config = tf1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf1.Session(config=config)

In [None]:
# general imports
import matplotlib.pyplot as plt
import os
import json
import numpy as np
import scipy as scipy
import scipy.io.wavfile as wav
import scipy.signal

In [None]:
tf.__version__

In [None]:
# Close the session created earlier before its graph is discarded, so it does
# not linger once `sess` is rebound below.
sess.close()
tf1.reset_default_graph()
# Reuse the ConfigProto built earlier so gpu_options.allow_growth also applies
# to the session that Keras will use.
sess = tf1.Session(config=config)
tf1.keras.backend.set_session(sess)
# Learning phase 0 is the Keras test/inference setting.
tf1.keras.backend.set_learning_phase(0)

## Load wav file

In [None]:
def waveread_as_pcm16(filename):
  """Read a wav file and return its samples and sample rate.

  Args:
    filename: path of the wav file, passed to scipy.io.wavfile.read.

  Returns:
    Tuple (wave_data, samplerate): the samples exactly as returned by
    scipy.io.wavfile.read (no conversion is applied here) and the sample
    rate stored in the file.
  """
  samplerate, wave_data = wav.read(filename)
  return wave_data, samplerate

def wavread_as_float(filename, target_sample_rate=16000):
  """Read a wav file as float32 samples at the requested sample rate.

  Args:
    filename: path of the wav file.
    target_sample_rate: sample rate of the returned data; the audio is
      resampled with scipy.signal.resample when the file uses another rate.

  Returns:
    Tuple (data, target_sample_rate): float32 samples divided by 32768 and
    the sample rate they are expressed in.
  """
  wave_data, samplerate = waveread_as_pcm16(filename)
  # Resample only when needed: an FFT round trip at an unchanged rate costs
  # time and only adds numerical noise to the samples.
  if samplerate != target_sample_rate:
    desired_length = int(
        round(float(len(wave_data)) / samplerate * target_sample_rate))
    wave_data = scipy.signal.resample(wave_data, desired_length)

  # Normalize short ints to floats in range [-1..1).
  # NOTE: the 32768 divisor assumes 16-bit PCM samples; other sample formats
  # would not be mapped onto that range.
  data = np.array(wave_data, np.float32) / 32768.0
  return data, target_sample_rate

In [None]:
# Set the path to the data set (for example Speech Commands V2), which can be
# downloaded from
# https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
# If 00_check-data.ipynb has been run, the data2 folder should already be
# located in the current working directory.
current_dir = os.getcwd()
DATA_PATH = os.path.join(current_dir, "data2/")

In [None]:
# Set path to wav file for testing.
wav_file = os.path.join(DATA_PATH, "left/012187a4_nohash_0.wav")

# read audio file
wav_data, samplerate = wavread_as_float(wav_file)

In [None]:
assert samplerate == 16000

In [None]:
plt.plot(wav_data)

## Prepare batched model

In [None]:
# This notebook is configured to work with 'ds_tc_resnet' and 'svdf'.
MODEL_NAME = 'ds_tc_resnet'
# MODEL_NAME = 'svdf'
MODELS_PATH = os.path.join(current_dir, "models")
MODEL_PATH = os.path.join(MODELS_PATH, MODEL_NAME + "/")
MODEL_PATH

In [None]:
train_dir = os.path.join(MODELS_PATH, MODEL_NAME)

In [None]:
# Alternative way of obtaining the flags: parse flags.json from the model folder.
flags_path = os.path.join(train_dir, 'flags.json')
with tf1.gfile.Open(flags_path, 'r') as flags_file:
  flags_json = json.load(flags_file)

class DictStruct(object):
  """Plain attribute container: each keyword argument becomes an attribute."""

  def __init__(self, **entries):
    # Store the entries directly in the instance dictionary.
    vars(self).update(entries)

flags = DictStruct(**flags_json)

In [None]:
flags.data_dir = DATA_PATH

In [None]:
# Compute the total time stride of the model.

total_stride = 1
if MODEL_NAME == 'ds_tc_resnet':
  # This could be automated by scanning the layers of the model; for now the
  # pool/stride settings of this specific model are read from the flags.
  # Only factors greater than 1 are collected.
  time_stride = [1]
  time_stride += [pool for pool in utils.parse(flags.ds_pool) if pool > 1]
  time_stride += [
      stride for stride in utils.parse(flags.ds_stride) if stride > 1
  ]
  total_stride = np.prod(time_stride)

# Override the input data shape for the streaming model with stride/pool.
flags.data_stride = total_stride
flags.data_shape = (total_stride * flags.window_stride_samples,)

In [None]:
# Build the index -> label mapping from the labels used for training.
audio_processor = input_data.AudioProcessor(flags)
index_to_label = {}
for label, index in audio_processor.word_to_index.items():
  if index == du.SILENCE_INDEX:
    # The silence index is shown with the dedicated silence label.
    index_to_label[index] = du.SILENCE_LABEL
  elif index == du.UNKNOWN_WORD_INDEX:
    # The unknown-word index is shown with the dedicated unknown label.
    index_to_label[index] = du.UNKNOWN_WORD_LABEL
  else:
    index_to_label[index] = label

# Display the training labels.
index_to_label

In [None]:
# Fit the audio to exactly flags.desired_samples: zero-pad short clips at the
# end and truncate long ones. np.pad rejects a negative pad width, so a clip
# longer than flags.desired_samples would otherwise raise ValueError.
num_missing = max(0, flags.desired_samples - len(wav_data))
padded_wav = np.pad(
    wav_data[:flags.desired_samples], (0, num_missing), 'constant')

# NOTE(review): this rebinds `input_data`, which was imported earlier as the
# kws_streaming.data.input_data module. Re-running the AudioProcessor cell
# after this one fails until the import cell is executed again.
# Add a leading batch dimension of size 1.
input_data = np.expand_dims(padded_wav, 0)
input_data.shape

In [None]:
# create model with flag's parameters
model_non_stream_batch = models.MODELS[flags.model_name](flags)

# load model's weights
weights_name = 'best_weights'
model_non_stream_batch.load_weights(os.path.join(train_dir, weights_name))

In [None]:
tf.keras.utils.plot_model(
    model_non_stream_batch,
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True)

## Run inference with TF

### TF Run non streaming inference

In [None]:
# convert model to inference mode with batch one
inference_batch_size = 1
tf.keras.backend.set_learning_phase(0)
flags.batch_size = inference_batch_size  # set batch size

model_non_stream = utils.to_streaming_inference(model_non_stream_batch, flags, Modes.NON_STREAM_INFERENCE)
#model_non_stream.summary()

In [None]:
tf.keras.utils.plot_model(
    model_non_stream,
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True)

In [None]:
predictions = model_non_stream.predict(input_data)
predicted_labels = np.argmax(predictions, axis=1)

In [None]:
predicted_labels

In [None]:
index_to_label[predicted_labels[0]]

### TF Run streaming inference with internal state

In [None]:
# convert model to streaming mode
flags.batch_size = inference_batch_size  # set batch size

model_stream = utils.to_streaming_inference(model_non_stream_batch, flags, Modes.STREAM_INTERNAL_STATE_INFERENCE)
#model_stream.summary()

In [None]:
tf.keras.utils.plot_model(
    model_stream,
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True)

In [None]:
stream_output_prediction = test.run_stream_inference_classification(flags, model_stream, input_data)

In [None]:
stream_output_arg = np.argmax(stream_output_prediction)
stream_output_arg

In [None]:
index_to_label[stream_output_arg]

### TF Run streaming inference with external state

In [None]:
# convert model to streaming mode with external state
flags.batch_size = inference_batch_size  # set batch size

model_stream_external = utils.to_streaming_inference(model_non_stream_batch, flags, Modes.STREAM_EXTERNAL_STATE_INFERENCE)
#model_stream_external.summary()

In [None]:
tf.keras.utils.plot_model(
    model_stream_external,
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True)

In [None]:
# One zero-filled float32 array per model input; index 0 is replaced by the
# audio frame on every step, the remaining entries carry the states.
inputs = [
    np.zeros(model_input.shape, dtype=np.float32)
    for model_input in model_stream_external.inputs
]

window_stride = flags.data_shape[0]
num_inputs = len(model_stream_external.inputs)

start = 0
end = window_stride
while end <= input_data.shape[1]:
  # Feed the next frame of the audio stream (input data is at index 0).
  inputs[0] = input_data[:, start:end]

  # Advance the frame boundaries for the following iteration.
  start, end = end, end + window_stride

  # Run inference on the current frame and states.
  outputs = model_stream_external.predict(inputs)

  # Copy the output states back into the inputs so that they are fed to the
  # model in the next inference cycle.
  for state_index in range(1, num_inputs):
    inputs[state_index] = outputs[state_index]

  stream_output_arg = np.argmax(outputs[0])
stream_output_arg

In [None]:
index_to_label[stream_output_arg]

## Run inference with TFLite

### Run non streaming inference with TFLite

In [None]:
tflite_non_streaming_model = utils.model_to_tflite(sess, model_non_stream_batch, flags, Modes.NON_STREAM_INFERENCE)
tflite_non_stream_fname = 'tflite_non_stream.tflite'
with open(os.path.join(MODEL_PATH, tflite_non_stream_fname), 'wb') as fd:
  fd.write(tflite_non_streaming_model)

In [None]:
interpreter = tf.lite.Interpreter(model_content=tflite_non_streaming_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

In [None]:
# set input audio data (by default input data at index 0)
interpreter.set_tensor(input_details[0]['index'], input_data.astype(np.float32))

# run inference
interpreter.invoke()

# get output: classification
out_tflite = interpreter.get_tensor(output_details[0]['index'])

out_tflite_argmax = np.argmax(out_tflite)

out_tflite_argmax

In [None]:
index_to_label[out_tflite_argmax]

### Run streaming inference with TFLite

In [None]:
tflite_streaming_model = utils.model_to_tflite(sess, model_non_stream_batch, flags, Modes.STREAM_EXTERNAL_STATE_INFERENCE)
tflite_stream_fname = 'tflite_stream.tflite'

In [None]:
with open(os.path.join(MODEL_PATH, tflite_stream_fname), 'wb') as fd:
  fd.write(tflite_streaming_model)

In [None]:
# Load the streaming TFLite model into an interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_streaming_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# One zero-filled float32 array per interpreter input, using the shape that
# the interpreter reports for that input.
input_states = [
    np.zeros(detail['shape'], dtype=np.float32) for detail in input_details
]

In [None]:
out_tflite = test.run_stream_inference_classification_tflite(flags, interpreter, input_data, input_states)

In [None]:
out_tflite_argmax = np.argmax(out_tflite[0])

In [None]:
index_to_label[out_tflite_argmax]

## Run evaluation on all testing data

In [None]:
test.tflite_non_stream_model_accuracy(
    flags,
    MODEL_PATH,
    tflite_model_name=tflite_non_stream_fname,
    accuracy_name='tflite_non_stream_model_accuracy.txt')

In [None]:
test.tflite_stream_state_external_model_accuracy(
    flags,
    MODEL_PATH,
    tflite_model_name=tflite_stream_fname,
    accuracy_name='tflite_stream_state_external_model_accuracy.txt',
    reset_state=True)

In [None]:
# Evaluate the streaming TFLite model with reset_state=False.
# A dedicated report name is used here: the preceding reset_state=True
# evaluation is given 'tflite_stream_state_external_model_accuracy.txt', and
# reusing that name would presumably overwrite its result.
test.tflite_stream_state_external_model_accuracy(
    flags,
    MODEL_PATH,
    tflite_model_name=tflite_stream_fname,
    accuracy_name='tflite_stream_state_external_model_accuracy_reset0.txt',
    reset_state=False)