##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [0]:
#@title Default title text
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train your own Keyword Spotting Model.
[Open in Google Colab](https://colab.research.google.com/github/google-research/google-research/blob/master/speech_embedding/record_train.ipynb)

Before running any cells please enable GPUs for this notebook to speed it up. 

* *Edit* → *Notebook Settings*
* select *GPU* from the *Hardware Accelerator* drop-down



In [0]:
#@title Imports
%tensorflow_version 1.x
from __future__ import division

import collections
import IPython
import functools
import math
import matplotlib.pyplot as plt
import numpy as np
import io
import os
import tensorflow as tf
import tensorflow_hub as hub
import random
import scipy.io.wavfile
import tarfile
import time
import sys

from google.colab import output
from google.colab import widgets
from base64 import b64decode

!pip install ffmpeg-python
import ffmpeg

In [0]:
#@title Helper functions and classes
def normalized_read(filename):
  """Reads a wav file and scales its samples into the range [-1, 1].

  The samples are divided by the 99.9th percentile of their absolute values
  and then clipped, so that a handful of outlier peaks do not dictate the
  scaling.

  Args:
    filename: Path of the wav file to read.

  Returns:
    A float np.array holding the normalized samples.
  """
  # The context manager closes the file handle as soon as the samples are
  # read instead of leaving it open until garbage collection.
  with open(filename, mode='rb') as wav_file:
    _, data = scipy.io.wavfile.read(wav_file)
  samples_99_percentile = np.percentile(np.abs(data), 99.9)
  if samples_99_percentile == 0:
    # (Almost) silent clip: dividing by zero would turn the zero samples into
    # NaN. np.sign keeps the rare non-zero outliers at +/-1 (the value the
    # clipping below would have produced) and maps silence to 0.
    return np.sign(data).astype(np.float64)
  normalized_samples = data / samples_99_percentile
  return np.clip(normalized_samples, -1, 1)

class EmbeddingDataFileList(object):
  """Container that loads audio, stores it as embeddings and can rebalance it.

  The data is kept per label in `self._data_per_label`; the `get_all_data*`
  methods return it as a flat list of `(feature, label_index)` tuples.
  """

  def __init__(self, filelist,
               data_dest_dir,
               targets=None,
               label_max=10000,
               negative_label="negative",
               negative_multiplier=25,
               target_samples=32000,
               progress_bar=None,
               embedding_model=None):
    """Creates an instance of `EmbeddingDataFileList`.

    Args:
      filelist: Dict mapping a label to a list of wav file paths, each
        relative to `data_dest_dir`.
      data_dest_dir: Directory the paths in `filelist` are relative to.
      targets: List of all known labels; the position of a label in this list
        is its integer class index. Defaults to the labels of `filelist`.
      label_max: Maximum number of files loaded per (non-negative) label.
      negative_label: Name of the label holding the negative examples.
      negative_multiplier: The negative label may load up to
        `negative_multiplier * label_max` files.
      target_samples: Clips with fewer samples than this are zero padded.
      progress_bar: Optional display handle; its `update` method is called
        after every loaded file.
      embedding_model: Optional model with a `create_embedding` method. When
        set, the embeddings are stored instead of the raw samples.

    Raises:
      ValueError: If `filelist` contains a label that is not in `targets`.
    """
    self._negative_label = negative_label
    self._data_per_label = collections.defaultdict(list)
    self._labelcounts = {}
    # Without explicit targets fall back to the labels that are present,
    # instead of failing on `label not in None` below.
    self._label_list = targets if targets is not None else list(filelist)
    # `.get` neither raises for plain dicts without negatives nor inserts an
    # empty entry into a caller-owned defaultdict.
    negative_files = filelist.get(negative_label, [])
    total_examples = sum(min(len(x), label_max) for x in filelist.values())
    total_examples -= min(len(negative_files), label_max)
    total_examples += min(len(negative_files), negative_multiplier * label_max)
    print("loading %d examples" % total_examples)
    example_count = 0
    for label in filelist:
      if label not in self._label_list:
        raise ValueError("Unknown label: {}".format(label))
      # Shuffle a copy so that the caller's list keeps its order.
      label_files = list(filelist[label])
      random.shuffle(label_files)
      multiplier = negative_multiplier if label == negative_label else 1
      for wav_file in label_files[:label_max * multiplier]:
        data = normalized_read(os.path.join(data_dest_dir, wav_file))
        required_padding = target_samples - data.shape[0]
        if required_padding > 0:
          # NOTE: this pads BOTH sides by the full deficit, so the padded clip
          # is `target_samples + required_padding` samples long with the
          # audio centered. Kept as is because the resulting embedding length
          # depends on it.
          data = np.pad(data, (required_padding, required_padding), 'constant')
        self._labelcounts[label] = self._labelcounts.get(label, 0) + 1
        if embedding_model:
          data = embedding_model.create_embedding(data)[0][0, :, :, :]
        self._data_per_label[label].append(data)
        if progress_bar is not None:
          example_count += 1
          progress_bar.update(progress(100 * example_count / total_examples))

  @property
  def labels(self):
    """List of all known labels; a label's position is its class index."""
    return self._label_list

  def get_label(self, idx):
    """Returns the position of `idx` in the label list."""
    return self.labels.index(idx)

  def _get_filtered_data(self, label, filter_fn):
    """Returns `(filter_fn(example), label_index)` for all data of `label`."""
    idx = self.labels.index(label)
    return [(filter_fn(x), idx) for x in self._data_per_label[label]]

  def _multply_data(self, data, factor):
    """Repeats `data` `factor` times.

    The integer part of `factor` repeats the whole list, the fractional part
    adds a random sample (without replacement) of the matching size.
    """
    samples = int((factor - math.floor(factor)) * len(data))
    return int(factor) * data + random.sample(data, samples)

  def full_rebalance(self, negatives, labeled):
    """Rebalances for a given ratio of labeled to negatives.

    Requires that negative examples and at least one labeled example have
    been loaded; otherwise the count lookup or the division below fails.

    Args:
      negatives: Share of negative examples in the desired ratio.
      labeled: Share of labeled examples in the desired ratio.
    """
    negative_count = self._labelcounts[self._negative_label]
    labeled_count = sum(self._labelcounts[key]
                        for key in self._labelcounts.keys()
                        if key != self._negative_label)
    labeled_multiply = labeled * negative_count / (negatives * labeled_count)
    for label in self._data_per_label:
      if label == self._negative_label:
        continue
      self._data_per_label[label] = self._multply_data(
          self._data_per_label[label], labeled_multiply)
      self._labelcounts[label] = len(self._data_per_label[label])

  def get_all_data_shuffled(self, filter_fn):
    """Returns a shuffled list containing all the data."""
    return self.get_all_data(filter_fn, shuffled=True)

  def get_all_data(self, filter_fn, shuffled=False):
    """Returns a list of `(filter_fn(example), label_index)` for all data."""
    data = []
    for label in self._data_per_label:
      data += self._get_filtered_data(label, filter_fn)
    if shuffled:
      random.shuffle(data)
    return data

def cut_middle_frame(embedding, num_frames, flatten):
  """Returns the `num_frames` frames around the middle of an embedding.

  Args:
    embedding: np.array whose first axis indexes the frames.
    num_frames: Number of frames to keep.
    flatten: If true, the selected frames are flattened into one dimension.
  """
  start = (embedding.shape[0] - num_frames) // 2
  middle = embedding[start:start + num_frames]
  return middle.flatten() if flatten else middle


def progress(value, maximum=100):
  """Builds an HTML `<progress>` bar showing `value` out of `maximum`."""
  bar_template = """
  <progress value='{value}' max='{max}' style='width: 80%'>{value}</progress>
    """
  bar_markup = bar_template.format(value=value, max=maximum)
  return IPython.display.HTML(bar_markup)

In [0]:
#@title HeadTrainerClass and head model functions

def _fully_connected_model_fn(embeddings, num_labels):
  """Head model: flattens the embeddings and applies one dense layer.

  Returns the logits (no activation) with `num_labels` outputs.
  """
  flat_embeddings = tf.layers.flatten(embeddings)
  return tf.compat.v1.layers.dense(
      flat_embeddings, num_labels, activation=None)

# Short aliases for the TF 1.x contrib modules used by the convolutional head
# model below (`arg_scope` and the `conv2d` / `batch_norm` / pooling layers).
framework = tf.contrib.framework
layers = tf.contrib.layers

def _conv_head_model_fn(embeddings, num_labels, context):
  """Builds a small convolutional head model and returns its logits.

  Two [3, 1] convolutions with 96 filters, each followed by batch norm and
  ELU. The first is followed by a [2, 1] max pool, the second by a max pool
  over the remaining extent. The result is flattened and fed into a fully
  connected output layer without activation.

  Args:
    embeddings: Input tensor of the head model.
    num_labels: Number of output logits.
    context: Size of `embeddings` along the pooled axis; presumably the
      number of embedding frames -- TODO confirm against the caller.

  Returns:
    The logits tensor produced by the fully connected output layer.
  """
  activation_fn = tf.nn.elu
  # NOTE(review): `is_training=True` is fixed here, so batch norm is also
  # built in training mode when the model is only used for inference.
  normalizer_fn = functools.partial(
      layers.batch_norm, scale=True, is_training=True)
  # The convolutions are created without bias and without activation; batch
  # norm and ELU are applied explicitly after each of them.
  with framework.arg_scope([layers.conv2d], biases_initializer=None,
                           activation_fn=None, stride=1, padding="SAME"):
    net = embeddings
    net = layers.conv2d(net, 96, [3, 1])
    net = normalizer_fn(net)
    net = activation_fn(net)
    net = layers.max_pool2d(net, [2, 1], stride=[2, 1], padding="VALID")
    # The pooling above halved the first spatial axis.
    context //= 2
    net = layers.conv2d(net, 96, [3, 1])
    net = normalizer_fn(net)
    net = activation_fn(net)
    # Pool over everything that is left of both spatial axes.
    net = layers.max_pool2d(net, [context, net.shape[2]], padding="VALID")
  net = tf.layers.flatten(net)
  logits = layers.fully_connected(
      net, num_labels, activation_fn=None)
  return logits

class HeadTrainer(object):
  """A tensorflow classifier to quickly train and test on embeddings.

  Only use this if you are training a very small model on a very limited amount
  of data. If you expect the training to take any more than 15 - 20 min then use
  something else.
  """

  def __init__(self, model_fn, input_shape, num_targets,
               head_learning_rate=0.001, batch_size=64):
    """Creates a `HeadTrainer`.

    Args:
      model_fn: function called as `model_fn(embeddings, num_targets)` that
          builds the tensorflow head model and returns its logits.
      input_shape: list describing the shape of the models input feature.
          Does not include the batch dimension.
      num_targets: Target number of keywords.
      head_learning_rate: Learning rate passed to the Adam optimizer.
      batch_size: Default number of examples per batch.
    """
    self._input_shape = input_shape
    self._output_dim = num_targets
    self._batch_size = batch_size
    self._graph = tf.Graph()
    with self._graph.as_default():
      self._feature = tf.placeholder(tf.float32, shape=([None] + input_shape))
      self._labels = tf.placeholder(tf.int64, shape=(None))
      module_spec = hub.create_module_spec(
          module_fn=self._get_headmodule_fn(model_fn, num_targets))
      self._module = hub.Module(module_spec, trainable=True)
      logits = self._module(self._feature)
      self._predictions = tf.nn.softmax(logits)
      self._loss, self._accuracy = self._get_loss(
          logits, self._labels, self._predictions)
      self._update_weights = tf.train.AdamOptimizer(
          learning_rate=head_learning_rate).minimize(self._loss)
    self._sess = tf.Session(graph=self._graph)
    with self._sess.as_default():
      with self._graph.as_default():
        self._sess.run(tf.local_variables_initializer())
        self._sess.run(tf.global_variables_initializer())

  def _get_headmodule_fn(self, model_fn, num_targets):
    """Wraps the model_fn in a tf hub module."""
    def module_fn():
      embeddings = tf.placeholder(
          tf.float32, shape=([None] + self._input_shape))
      logit = model_fn(embeddings, num_targets)
      hub.add_signature(name='default', inputs=embeddings, outputs=logit)
    return module_fn

  def _get_loss(self, logits, labels, predictions):
    """Defines the model's loss (mean cross entropy) and accuracy."""
    xentropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    loss = tf.reduce_mean(xentropy_loss)
    accuracy = tf.contrib.metrics.accuracy(tf.argmax(predictions, 1), labels)
    return loss, accuracy

  def save_head_model(self, save_directory):
    """Exports the head model as a tf hub module to `save_directory`."""
    with self._graph.as_default():
      self._module.export(save_directory, self._sess)

  def _feature_transform(self, batch_features, batch_labels):
    """Transforms lists of features and labels into model inputs."""
    return np.stack(batch_features), np.stack(batch_labels)

  def _batch_data(self, data, batch_size=None):
    """Splits the input data into batches.

    Args:
      data: List of tuples (feature, label).
      batch_size: Number of examples per batch; all of `data` when unset.

    Yields:
      Tuples (features, labels) of stacked np arrays. The last batch may be
      smaller than `batch_size`.

    Raises:
      ValueError: If a feature has the wrong shape or a label is out of range.
    """
    batch_features = []
    batch_labels = []
    batch_size = batch_size or len(data)
    for feature, label in data:
      if feature.shape != tuple(self._input_shape):
        raise ValueError(
            "Feature shape ({}) doesn't match model shape ({})".format(
                feature.shape, self._input_shape))
      if not 0 <= label < self._output_dim:
        raise ValueError('Label value ({}) outside of target range'.format(
            label))
      batch_features.append(feature)
      batch_labels.append(label)
      if len(batch_features) == batch_size:
        yield self._feature_transform(batch_features, batch_labels)
        # Safe to clear in place: np.stack copied the data into new arrays.
        del batch_features[:]
        del batch_labels[:]
    if batch_features:
      yield self._feature_transform(batch_features, batch_labels)

  def _run_pass(self, data, batch_size, train):
    """Runs all batches of `data` once through the model.

    Args:
      data: List of tuples (feature, label).
      batch_size: Number of examples per batch.
      train: Whether the weight update op is run as well.

    Returns:
      Tuple (correct, seen, loss): the per-batch accuracies weighted by batch
      size, the number of examples seen and the loss of the final batch.
    """
    fetches = [self._loss, self._accuracy]
    if train:
      fetches.append(self._update_weights)
    correct = 0.0
    seen = 0
    loss = None
    for features, labels in self._batch_data(data, batch_size):
      results = self._sess.run(
          fetches, feed_dict={self._feature: features, self._labels: labels})
      loss, accuracy = results[0], results[1]
      # Weight by the batch size so that a short final batch does not count
      # as much as a full one.
      correct += accuracy * len(labels)
      seen += len(labels)
    return correct, seen, loss

  def epoch_train(self, data, epochs=1, batch_size=None):
    """Trains the model on the provided data.

    Args:
      data: List of tuples (feature, label) where feature is a np array of
          shape `self._input_shape` and label an int less than self._output_dim.
      epochs: Number of times this data should be trained on.
      batch_size: Number of feature, label pairs per batch. Overwrites
          `self._batch_size` when set.

    Returns:
      tuple of accuracy, loss;
          accuracy: Average training accuracy over all examples seen.
          loss: Loss of the final batch.

    Raises:
      ValueError: If `data` is empty or `epochs` is smaller than 1.
    """
    if len(data) == 0:
      raise ValueError("Cannot train on an empty data set.")
    if epochs < 1:
      raise ValueError("epochs must be at least 1, got {}".format(epochs))
    batch_size = batch_size or self._batch_size
    correct = 0.0
    seen = 0
    loss = None
    for _ in range(epochs):
      epoch_correct, epoch_seen, loss = self._run_pass(
          data, batch_size, train=True)
      correct += epoch_correct
      seen += epoch_seen
    return correct / seen, loss

  def test(self, data, batch_size=None):
    """Evaluates the model on the provided data.

    Args:
      data: List of tuples (feature, label) where feature is a np array of
          shape `self._input_shape` and label an int less than self._output_dim.
      batch_size: Number of feature, label pairs per batch. Overwrites
          `self._batch_size` when set.

    Returns:
      tuple of accuracy, loss;
          accuracy: Average accuracy over all examples in `data`.
          loss: Loss of the final batch.

    Raises:
      ValueError: If `data` is empty.
    """
    if len(data) == 0:
      raise ValueError("Cannot evaluate on an empty data set.")
    batch_size = batch_size or self._batch_size
    correct, seen, loss = self._run_pass(data, batch_size, train=False)
    return correct / seen, loss

  def infer(self, example_feature):
    """Runs inference on a single example and returns its class probabilities.

    Raises:
      ValueError: If `example_feature` does not have the model's input shape.
    """
    if example_feature.shape != tuple(self._input_shape):
      raise ValueError(
          "Feature shape ({}) doesn't match model shape ({})".format(
              example_feature.shape, self._input_shape))
    return self._sess.run(
        self._predictions,
        feed_dict={self._feature: np.expand_dims(example_feature, axis=0)})

In [0]:
#@title TfHubWrapper Class

class TfHubWrapper(object):
  """Loads a tf hub embedding module and runs it on audio samples."""

  def __init__(self, embedding_model_dir):
    """Creates a `TfHubWrapper` with its own graph and session.

    Args:
      embedding_model_dir: Location of the tf hub module to load.
    """
    self._graph = tf.Graph()
    self._sess = tf.Session(graph=self._graph)
    with self._graph.as_default(), self._sess.as_default():
      embedding_module = hub.Module(hub.load_module_spec(embedding_model_dir))
      self._samples = tf.placeholder(
          tf.float32, shape=[1, None], name='audio_samples')
      self._embedding = embedding_module(self._samples)
      self._sess.run(tf.global_variables_initializer())
    print("Embedding model loaded, embedding shape:", self._embedding.shape)

  def create_embedding(self, samples):
    """Embeds `samples`; returns a one element list holding the embedding."""
    batched_samples = samples.reshape((1, -1))
    return self._sess.run(
        [self._embedding], feed_dict={self._samples: batched_samples})

In [0]:
#@title Define AudioClipRecorder Class
# HTML/JavaScript template of the recording widget. It is filled in with
# `str.format`: `{keyphrase}` and `{clip_length_ms}` are placeholders, doubled
# braces are literal JavaScript braces.
#
# Clicking the button requests microphone access and starts a MediaRecorder.
# `clip_length_ms` after the click the recorder is stopped, and one second
# later the recorded data (a base64 data URL) is sent to the kernel callback
# 'notebook.AddAudioItem{keyphrase}'.
#
# NOTE(review): the stop timer starts at the click, not when `getUserMedia`
# resolves, so `recorder` may still be unset (or belong to the previous clip)
# if granting microphone access takes longer than `clip_length_ms` -- confirm.
AUDIOCLIP_HTML ='''
<span style="font-size:30px">Recorded audio clips of {keyphrase}:</span> 
<div id='target{keyphrase}'></div>
<span id = "status_label{keyphrase}" style="font-size:30px">
  Ready to record.</span>
<button id='Add{keyphrase}Audio'>Record</button>
<script>
var recorder;
var base64data = 0;

function sleep(ms) {{
  return new Promise(resolve => setTimeout(resolve, ms));
}}

var handleSuccess = function(stream) {{
  recorder = new MediaRecorder(stream);
  recorder.ondataavailable = function(e) {{            
    reader = new FileReader();
    reader.readAsDataURL(e.data); 
    reader.onloadend = function() {{
      base64data = reader.result;
    }}
  }};
  recorder.start();
}};

document.querySelector('#Add{keyphrase}Audio').onclick = () => {{
  var label = document.getElementById("status_label{keyphrase}"); 
  navigator.mediaDevices.getUserMedia({{audio: true}}).then(handleSuccess);
  label.innerHTML = "Recording ... please say {keyphrase}!".fontcolor("red");; 
    sleep({clip_length_ms}).then(() => {{
    recorder.stop();
    label.innerHTML = "Recording finished ... processing audio."; 
    sleep(1000).then(() => {{
      google.colab.kernel.invokeFunction('notebook.AddAudioItem{keyphrase}',
      [base64data.toString()], {{}});
      label.innerHTML = "Ready to record.";
    }});
}});
}};
</script>'''

class AudioClipRecorder:
  """Python class that creates a JS microphone clip recorder."""

  # Sample rate (Hz) the recordings are decoded to and exported with.
  _SAMPLE_RATE = 16000

  def __init__(self, keyphrase="test", clip_length_ms=2100):
    """Creates an AudioClipRecorder instance.

    When created this class prints an empty <div> tag into which the
    recorded clips will be printed and a record audio button that uses
    javascript to access the microphone and record an audio clip.

    Args:
      keyphrase: The name of the keyphrase that should be recorded.
        This will be displayed in the recording prompt and used as a
        directory name when the recordings are exported.
      clip_length_ms: The length (in ms) of each recorded audio clip.
        Due to the async nature of javascript the actual amount of recorded
        audio may vary by ~20-80ms.
    """
    self._counter = 0
    self._keyphrase = keyphrase
    self._audio_clips = {}
    IPython.display.display(IPython.display.HTML(AUDIOCLIP_HTML.format(
        keyphrase=keyphrase, clip_length_ms=clip_length_ms)))
    output.register_callback('notebook.AddAudioItem' + keyphrase,
                             self.add_list_item)
    output.register_callback('notebook.RemoveAudioItem' + keyphrase,
                             self.rm_audio)

  def add_list_item(self, data):
    """Adds the recorded audio to the list of clips.

    This function is called from javascript after clip_length_ms audio has
    been recorded. It prints the recorded audio clip to the <div> together with
    a button that allows for it to be deleted.

    Args:
      data: The recorded audio in webm format, as a base64 data URL.
    """
    # A data URL looks like 'data:<mime>;base64,<payload>'; keep the payload.
    raw_string_data = data.split(',')[1]
    samples, rate = self.decode_webm(raw_string_data)
    length_samples = len(samples)
    with output.redirect_to_element('#target{keyphrase}'.format(
        keyphrase=self._keyphrase)):
      with output.use_tags('{keyphrase}_audio_{counter}'.format(
          counter=self._counter, keyphrase=self._keyphrase)):
        IPython.display.display(IPython.display.HTML('''Audio clip {counter} - 
        {length} samples - 
        <button id=\'delbutton{keyphrase}{counter}\'>del</button>
        <script>
        document.querySelector('#delbutton{keyphrase}{counter}').onclick = () => {{
          google.colab.kernel.invokeFunction('notebook.RemoveAudioItem{keyphrase}', [{counter}], {{}});
        }};
        </script>'''.format(counter=self._counter, length=length_samples,
                            keyphrase=self._keyphrase)))
        IPython.display.display(IPython.display.Audio(data=samples, rate=rate))
        IPython.display.display(IPython.display.HTML('<br><br>'))
        self._audio_clips[self._counter]=samples
      self._counter+=1

  def rm_audio(self, count):
    """Removes the audioclip 'count' from the list of clips."""
    output.clear(output_tags="{0}_audio_{1}".format(self._keyphrase, count))
    # Tolerate a second delete request for the same clip.
    self._audio_clips.pop(count, None)

  def decode_webm(self, data):
    """Decodes a webm audio clip into a np.array of samples.

    Args:
      data: The base64 encoded webm recording.

    Returns:
      Tuple (audio, sample_rate): the float32 samples (in the int16 value
      range) and the sample rate they were decoded to.

    Raises:
      RuntimeError: If ffmpeg fails and produces no audio at all.
    """
    sample_rate = self._SAMPLE_RATE
    process = (ffmpeg
      .input('pipe:0')
      .output('pipe:1', format='s16le', ar=sample_rate)
      .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True,
                 quiet=True, overwrite_output=True)
    )
    raw_audio, err = process.communicate(input=b64decode(data))
    if process.returncode != 0 and not raw_audio:
      # Report the failure instead of silently storing an empty clip.
      raise RuntimeError('ffmpeg failed to decode the recording: {}'.format(
          err.decode('utf-8', 'replace') if err else 'unknown error'))
    audio = np.frombuffer(raw_audio, dtype=np.int16).astype(np.float32)
    return audio, sample_rate

  def save_as_wav_files(self, base_output_dir,
                        file_prefix='recording_', file_suffix=''):
    """Exports all audio clips as wav files.

    The wav files will be written to 'base_output_dir/self._keyphrase'
    and will be named: file_prefix + str(clip_id) + file_suffix + '.wav'.
    Missing directories, including nested ones, are created.
    """
    keyphrase_output_dir = os.path.join(base_output_dir, self._keyphrase)
    if not os.path.isdir(keyphrase_output_dir):
      # makedirs also creates `base_output_dir` and any missing parents.
      os.makedirs(keyphrase_output_dir)
    for clip_id in self._audio_clips:
      filename = file_prefix + str(clip_id) + file_suffix + '.wav'
      output_file = os.path.join(keyphrase_output_dir, filename)
      print("Creating:", output_file)
      scipy.io.wavfile.write(
          output_file, self._SAMPLE_RATE, self._audio_clips[clip_id])



In [0]:
#@title Define AudioClipEval Class
class AudioClipEval(AudioClipRecorder):
  """Clip recorder that also classifies every recorded clip."""

  def __init__(self, embedding_model, head_model, filter_fn, labels,
               name="eval1", clip_length_ms=2100):
    """Creates an AudioClipEval instance.

    When created this class prints an empty <div> tag into which the
    recorded clips will be printed and a record audio button that uses
    javascript to access the microphone and record an audio clip.

    Args:
      embedding_model: The embedding model.
      head_model: The default head model.
      filter_fn: function that prepares the input to the head model.
      labels: List of head model target labels.
      name: Name of this recorder. It is displayed in the recording prompt
        and used as a directory name when the recordings are exported.
      clip_length_ms: The length (in ms) of each recorded audio clip.
        Due to the async nature of javascript the actual amount of recorded
        audio may vary by ~20-80ms.
    """
    self._embedding_model = embedding_model
    self._head_model = head_model
    self._filter_fn = filter_fn
    self._labels = labels
    # The parent initialises the clip bookkeeping, renders the widget and
    # registers `self.add_list_item` / `self.rm_audio` as kernel callbacks.
    super(AudioClipEval, self).__init__(
        keyphrase=name, clip_length_ms=clip_length_ms)

  def add_list_item(self, data):
    """Adds the recorded audio to the list of clips and classifies it.

    This function is called from javascript after clip_length_ms audio has
    been recorded. It prints the recorded audio clip to the <div> together with
    the detected label and a button that allows for it to be deleted.

    Args:
      data: The recorded audio in webm format, as a base64 data URL.
    """
    raw_string_data = data.split(',')[1]
    samples, rate = self.decode_webm(raw_string_data)
    length_samples = len(samples)
    detection, confidence = self.eval_audio(samples)
    with output.redirect_to_element('#target{keyphrase}'.format(
        keyphrase=self._keyphrase)):
      with output.use_tags('{keyphrase}_audio_{counter}'.format(
          counter=self._counter, keyphrase=self._keyphrase)):
        # The button id contains the keyphrase (as in the parent class) so
        # that ids stay unique when several recorders share one output.
        IPython.display.display(IPython.display.HTML('''Audio clip {counter} - 
        {length} samples - 
        <button id=\'delbutton{keyphrase}{counter}\'>del</button>
        <script>
        document.querySelector('#delbutton{keyphrase}{counter}').onclick = () => {{
          google.colab.kernel.invokeFunction('notebook.RemoveAudioItem{keyphrase}', [{counter}], {{}});
        }};
        </script>'''.format(counter=self._counter, length=length_samples,
                            keyphrase=self._keyphrase)))
        IPython.display.display(IPython.display.Audio(data=samples, rate=rate))
        IPython.display.display(IPython.display.HTML(
            '''<span id = "result{counter}" style="font-size:24px">
              detected: {detection} ({confidence})</span>'''.format(
                  counter=self._counter, detection=detection,
                  confidence=confidence)))
        IPython.display.display(IPython.display.HTML('<br><br>'))
        self._audio_clips[self._counter]=samples
      self._counter+=1

  def eval_audio(self, samples, head_model=None):
    """Classifies the audio using the current or a provided model.

    Returns:
      Tuple (label, confidence) of the most probable label.
    """
    embeddings = self._embedding_model.create_embedding(samples)[0][0,:,:,:]
    model = head_model if head_model else self._head_model
    probs = model.infer(self._filter_fn(embeddings))
    return self._labels[np.argmax(probs)], np.amax(probs)

  def eval_on_new_model(self, head_model):
    """Reclassifies the clips using a new head model."""
    for clip_id in self._audio_clips:
      samples = self._audio_clips[clip_id]
      length_samples = len(samples)
      detection, confidence = self.eval_audio(samples, head_model=head_model)
      IPython.display.display(IPython.display.HTML(
            '''Audio clip {counter} -  {length} samples - 
            <span id = "result{counter}" style="font-size:24px">
              detected: {detection} ({confidence})</span>'''.format(
                counter=clip_id, length=length_samples,
                detection=detection, confidence=confidence)))
      # 16000 is the sample rate `decode_webm` decodes the recordings to.
      IPython.display.display(IPython.display.Audio(data=samples, rate=16000))

## Load the embedding model

The following info messages can be ignored

> *INFO:tensorflow:Saver not created because there are no variables in the graph to restore*

Don't worry: TF Hub still restores all the variables.

You can test the model by having it produce an embedding on zeros:


```
speech_embedding_model.create_embedding(np.zeros((1,66000)))
```



In [0]:
# TF Hub handle of the speech embedding module.
embedding_model_url = "https://tfhub.dev/google/speech_embedding/1"
# Loaded once here; the data loading and evaluation cells below reuse it.
speech_embedding_model = TfHubWrapper(embedding_model_url)

## Record training data or copy from google drive

The following cells allow you to define a set of target keyphrases and record some examples for training.

### Optional Google Drive access.

The recorded wav files can be uploaded to (and later downloaded from) your Google Drive using [PyDrive](https://googleworkspace.github.io/PyDrive/docs/build/html/index.html). When you run the *Set up Google drive access* cell, it will prompt you to log in and grant this colab permission to access your Google Drive. You can only run the other Google Drive cells after you have done this.



In [0]:
#@title Optional: Set up Google drive access
!pip install PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Prompts for the Google login; the resulting application default credentials
# are handed to PyDrive.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# `drive` is used by the optional download / upload cells below.
drive = GoogleDrive(gauth)

In [0]:
#@title Optional: Download and untar an archive from drive

filename = ''#@param {type:"string"}
#@markdown You can find the file_id by looking at its share-link.
#@markdown e.g. *1b9Lkfie2NHX-O06vPGrqzyGcGWUPul36*
file_id =  ''#@param {type:"string"}


def _extract_tar_safely(archive_path, dest_dir='.'):
  """Extracts a .tar.gz archive, refusing members that escape `dest_dir`.

  The archive comes from an external source, so absolute paths, '..'
  components and links are rejected instead of being extracted blindly.

  Args:
    archive_path: Path of the gzip compressed tar archive.
    dest_dir: Directory to extract into.

  Raises:
    ValueError: If a member is a link or would be written outside `dest_dir`.
  """
  dest_root = os.path.realpath(dest_dir)
  dest_prefix = dest_root if dest_root.endswith(os.sep) else dest_root + os.sep
  with tarfile.open(archive_path, 'r:gz') as data_tar_file:
    for member_info in data_tar_file.getmembers():
      member_path = os.path.realpath(
          os.path.join(dest_root, member_info.name))
      if member_path != dest_root and not member_path.startswith(dest_prefix):
        raise ValueError(
            'Refusing to extract outside of the target directory: %s'
            % member_info.name)
      if member_info.issym() or member_info.islnk():
        raise ValueError(
            'Refusing to extract link from archive: %s' % member_info.name)
      print(member_info.name)
      data_tar_file.extract(member_info, dest_root)


if not filename or not file_id:
  raise ValueError(
      'Please fill in both `filename` and `file_id` before running this cell.')

downloaded = drive.CreateFile({'id':file_id})
downloaded.GetContentFile(filename)
_extract_tar_safely(filename)

In [0]:
#@title Setup recording session and define model targets

#@markdown Only use letters and _ for the **RECORDING_NAME** and **TARGET_WORDS**. 
RECORDING_NAME = 'transportation' #@param {type:"string"}
target_word1 = 'hogwarts_express'  #@param {type:"string"}
target_word2 = 'nautilus'  #@param {type:"string"}
target_word3 = 'millennium_falcon'  #@param {type:"string"}
target_word4 = 'enterprise'  #@param {type:"string"}
target_word5 = ''  #@param {type:"string"}
target_word6 = ''  #@param {type:"string"}
clip_lengh_ms = 2100 #@param {type:"integer"}

#@markdown ### Microphone access
#@markdown Please connect the microphone that you want to use
#@markdown before running this cell. You may also be asked
#@markdown to grant colab permission to use it.
#@markdown If you have any problems check your browser settings
#@markdown and rerun the cell.

target_words = [target_word1, target_word2, target_word3,
                target_word4, target_word5, target_word6]

# Keep only the form fields that were filled in. `!=` compares the value;
# the identity check `is not ''` only worked by accident of string interning.
OWN_TARGET_WORDS = ','.join([w for w in target_words if w != ''])

# Without any target word `''.split(',')` would yield a bogus empty label.
word_list = OWN_TARGET_WORDS.split(',') if OWN_TARGET_WORDS else []
OWN_MODEL_LABELS = ['negative', 'silence'] + word_list

t = widgets.TabBar(word_list)

# One recorder per target word, each rendered into its own tab. enumerate
# keeps the tab index right even if a word was entered twice.
clip_recorders = {}
for tab_index, label in enumerate(word_list):
  with t.output_to(tab_index):
    # Honour the clip length chosen in the form above.
    clip_recorders[label] = AudioClipRecorder(keyphrase=label,
                                              clip_length_ms=clip_lengh_ms)

# Switch back to the first tab.
with t.output_to(0):
  print()

In [0]:
#@title Create wav files from recording session.

session =  'recording1_'#@param {type:"string"}
speaker =  '_spk1'#@param {type:"string"}

# Every recorder writes its clips to RECORDING_NAME/<keyphrase>/, named
# session + clip id + speaker + '.wav'.
for label, recorder in clip_recorders.items():
  recorder.save_as_wav_files(
      base_output_dir=RECORDING_NAME, file_prefix=session, file_suffix=speaker)

In [0]:
#@title Load files for training.
                                    
# Maps each target word to its wav files, as paths relative to RECORDING_NAME.
all_train_example_files = collections.defaultdict(list)

for label in OWN_TARGET_WORDS.split(','):
  label_dir = os.path.join(RECORDING_NAME, label)
  all_label_files = []
  for entry in os.listdir(label_dir):
    # Only regular files are training examples; skip sub-directories.
    if os.path.isfile(os.path.join(label_dir, entry)):
      all_label_files.append(os.path.join(label, entry))
  all_train_example_files[label] += all_label_files

# Progress bar that is updated while the files are loaded and embedded.
progress_bar = IPython.display.display(progress(0, 100), display_id=True)
print("loading train data")
train_data = EmbeddingDataFileList(
    all_train_example_files,
    RECORDING_NAME,
    targets=OWN_MODEL_LABELS,
    embedding_model=speech_embedding_model,
    progress_bar=progress_bar)

In [0]:
#@title Optional: save recorded data to drive.

# The timestamp keeps archives of different sessions apart.
archive_name = "{}_{}.tar.gz".format(RECORDING_NAME, str(int(time.time())))

def make_tarfile(output_filename, source_dir):
  """Packs `source_dir` into the gzip compressed tar `output_filename`."""
  archive_root = os.path.basename(source_dir)
  with tarfile.open(output_filename, "w:gz") as tar:
    tar.add(source_dir, arcname=archive_root)

make_tarfile(archive_name, RECORDING_NAME)

# Upload the archive and report the id needed to download it again.
file1 = drive.CreateFile({'title': archive_name})
file1.SetContentFile(archive_name)
file1.Upload()
upload_info = (file1['title'], file1['id'])
print('Saving to drive: %s, id: %s' % upload_info)

# Train a model on your recorded data

In [0]:
#@title Run training

#@markdown We assume that the keyphrase is spoken roughly in the middle
#@markdown of the loaded audio clips. With **context_size** we can choose the
#@markdown number of embeddings around the middle to use as a model input.
context_size = 16 #@param {type:"slider", min:1, max:28, step:1}

# Cuts the `context_size` middle frames out of every embedding; the same
# function is later reused by the evaluation cell.
filter_fn = functools.partial(cut_middle_frame, num_frames=context_size, flatten=False)
all_train_data = train_data.get_all_data_shuffled(filter_fn=filter_fn)
# No held-out data in this notebook; set this to enable the eval curve.
all_eval_data = None

head_model = "Convolutional" #@param ["Convolutional", "Fully_Connected"] {type:"string"}

#@markdown Suggested **learning_rate** range: 0.00001 - 0.01.
learning_rate = 0.001 #@param {type:"number"}
batch_size = 32
#@markdown **epochs_per_eval** and **train_eval_loops** control how long
#@markdown the model is trained. An epoch is defined as the model having seen
#@markdown each example at least once, with some examples twice to ensure the
#@markdown correct labeled / negatives balance.

epochs_per_eval = 1 #@param {type:"slider", min:1, max:15, step:1}
train_eval_loops = 30 #@param {type:"slider", min:5, max:80, step:5}

# The convolutional head needs to know the context size for its final pooling.
if head_model == "Convolutional":
  model_fn = functools.partial(_conv_head_model_fn, context=context_size)
else:
  model_fn = _fully_connected_model_fn

# NOTE(review): the trailing [1, 96] is assumed to be the per-frame shape of
# the speech embedding -- confirm against the printed embedding shape.
trainer = HeadTrainer(model_fn=model_fn,
                      input_shape=[context_size,1,96],
                      num_targets=len(OWN_MODEL_LABELS),
                      head_learning_rate=learning_rate,
                      batch_size=batch_size)

# Bookkeeping for the live plot: x values (`data`) and the accuracy curves.
data_trained_on = 0
data = [] 
train_results = []
eval_results = []
# Right edge of the x axis: total number of examples that will be trained on.
max_data = len(all_train_data) * epochs_per_eval * train_eval_loops + 10

def plot_step(plot, max_data, data, train_results, eval_results):
  """Redraws the accuracy curves; `plot` is the pyplot module."""
  plot.clf()
  plot.xlim(0, max_data)
  # Only accuracies between 0.85 and 1.05 are visible in the plot.
  plot.ylim(0.85, 1.05)
  plot.plot(data, train_results, "bo")
  plot.plot(data, train_results, "b", label="train_results")
  if eval_results:
    plot.plot(data, eval_results, "ro")
    plot.plot(data, eval_results, "r", label="eval_results")
  plot.legend(loc='lower right', fontsize=24)
  plot.xlabel('number of examples trained on', fontsize=22)
  plot.ylabel('Accuracy', fontsize=22)
  plot.xticks(fontsize=20)
  plot.yticks(fontsize=20) 

plt.figure(figsize=(25, 7))
for loop in range(train_eval_loops):
  train_accuracy, loss = trainer.epoch_train(all_train_data,
                                             epochs=epochs_per_eval)
  train_results.append(train_accuracy)
  if all_eval_data:
    eval_accuracy, loss = trainer.test(all_eval_data)
    eval_results.append(eval_accuracy)
  else:
    # Tells plot_step to skip the eval curve.
    eval_results = None

  data_trained_on += len(all_train_data) * epochs_per_eval
  data.append(data_trained_on)
  plot_step(plt, max_data, data, train_results, eval_results)

  IPython.display.display(plt.gcf())
  if all_eval_data:
    print("Highest eval accuracy: %.2f percent." % (100 * max(eval_results)))
  # wait=True defers the clearing until the next output arrives, which avoids
  # flickering between the plot updates.
  IPython.display.clear_output(wait=True)

if all_eval_data:
  print("Highest eval accuracy: %.2f percent." % (100 * max(eval_results)))


In [0]:
#@title Test the model
# Recorder that classifies every new clip with the head model trained above.
clip_eval = AudioClipEval(speech_embedding_model, trainer, filter_fn, OWN_MODEL_LABELS)

In [0]:
#@title Rerun the test using a new head model (train a new head model first)

# Re-classifies the clips recorded in the previous cell with the current
# `trainer`; rerun the training cell first to get a new head model.
clip_eval.eval_on_new_model(trainer)

## FAQ

Q: **My model isn't very good?**

A: The head model is very small and depends a lot on the initialisation weights:
 * This default setup doesn't have a negative class so it will always detect *something*. 
 * Try retraining it a couple of times.
 * Reduce the learning rate a little bit.
 * Add more training examples:
   * At 1 - 5 examples per keyphrase the model probably won't be very good.
   * With around 10-20 examples per keyphrase it may work reasonably well; however, it may still fail to learn a keyphrase.
   * If you only have examples from a single speaker, then it may only learn how that speaker pronounces the keyphrase.
 * Make sure your keyphrases are distinctive enough:
   * e.g. heads up vs ketchup




Q: **Can I export the model and use it somewhere?**

A: Yes, there's some example code in the following cells that demonstrates how that could be done. However, this simple example model only trains a between-word classifier.
If you want to use it in any realistic setting, you will probably also want to add:
 * A negative or non-target-word speech class: You could do this by recording 2-10 min of continuous speech that doesn't contain your target keyphrases.
 * A non-speech / silence / background-noise class: The speech commands dataset contains some examples of non-speech background audio that could be used for this, and/or you could just leave your microphone on and record some ambient audio from the future deployment location.

# Export and reuse the head model
The following cells show how the head model you just trained can be exported and reused in a graph.

In [0]:
#@title Save the head model

# Directory (relative to the working directory) the head model is exported to
# as a tf hub module.
head_model_module_dir = "head_model_module_dir"
trainer.save_head_model(head_model_module_dir)

In [0]:
#@title FullModelWrapper - Example Class

class FullModelWrapper(object):
  """Chains the embedding module and a saved head module into a classifier."""

  def __init__(self, embedding_model_dir, head_model_dir):
    """Loads both tf hub modules into one graph.

    Args:
      embedding_model_dir: Location of the embedding tf hub module.
      head_model_dir: Location of the exported head model module.
    """
    self._graph = tf.Graph()
    self._sess = tf.Session(graph=self._graph)
    with self._graph.as_default():
      embedding_module = hub.Module(hub.load_module_spec(embedding_model_dir))
      head_module = hub.Module(hub.load_module_spec(head_model_dir))
      self._samples = tf.placeholder(
          tf.float32, shape=[1, None], name='audio_samples')
      # Audio samples -> embedding -> logits -> class probabilities.
      logits = head_module(embedding_module(self._samples))
      self._predictions = tf.nn.softmax(logits)
      with self._sess.as_default():
        self._sess.run(tf.global_variables_initializer())

  def infer(self, samples):
    """Classifies `samples`; returns a one element list of predictions."""
    batched_samples = samples.reshape((1, -1))
    return self._sess.run(
        [self._predictions], feed_dict={self._samples: batched_samples})

In [0]:
#@title Test the full model on zeros
# Smoke test: run the chained embedding + head model on 32000 zero samples.
full_model = FullModelWrapper(embedding_model_url, head_model_module_dir)
full_model.infer(np.zeros((1,32000)))