##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");


In [0]:
#@title License

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Train your own Keyword Spotting Model.

Before running any cells please enable GPUs for this notebook to speed it up. 

* *Edit* → *Notebook Settings*
* select *GPU* from the *Hardware Accelerator* drop-down



## Download and install Cloud Text-to-Speech SDK

Please remember to click the *RESTART RUNTIME* button before proceeding to later cells.


In [0]:
!pip install --upgrade google-cloud-texttospeech


## <font color="red">STOP !</font>

Before proceeding to the next step, did you click on the *RESTART RUNTIME* button in the output?
If not please do so now. After the restart, proceed with the next step.

## Authenticate and get Cloud credential

Next we will authenticate and get Cloud credential in order to call the Google Cloud Text-to-Speech API to synthesize the speech examples.

In order to do so, please follow the 5 steps in the 5 cells labeled *A* through *E* below.


### A.  Create and/or select a Google Cloud project to use.


1.  Go to the [Google Cloud project selector](https://console.cloud.google.com/projectselector2/home/dashboard) page.

2.  Select a project to use with this notebook, or create a new project by following the next few sub-items. You will need to create a new project if you have none to select from.
    * Click on *CREATE PROJECT* near the upper right corner.
    * Edit the project name to something more meaningful if you prefer.
    * Complete the project creation process by clicking on the *CREATE* button.


### B.  Enable billing on the Google Cloud project

Please note that with the free quota, you will not be charged unless you explicitly grant permission.

1.  First, ensure you have a billing account by clicking [here](https://console.cloud.google.com/billing). If you do not have a billing account, you may add one now:
    * Click on the *Add billing account* button.
    * Follow the instructions in the 2-page workflow. Please note the  text *No auto-payment after free trial ends* on the right.
2.  Next, enable billing on the project you selected (or just created) if billing is not already enabled on the project.
    * Go to the project page by clicking [here](https://console.cloud.google.com/billing/projects).
    * Click on the 3 dots to the right of the selected project and select *Change billing* from the menu.
    * In the pop-up dialog, select the billing account you want to use. If there is only one billing account available, the message will inform you so and you can directly click on the *SET ACCOUNT* button.


### C.  Enable the Google Cloud Text-to-Speech API on the project

1.  Click [here](https://console.cloud.google.com/flows/enableapi?apiid=texttospeech.googleapis.com) to go to the page for enabling Google Cloud Text-to-Speech API.
2. Select the project from the drop-down menu on top (to the right of *Google Cloud Platform*) for registering the application to use Cloud Text-to-Speech API.
3.  Click on *Continue*. When the page returns with the message that *The API is enabled*, you may close the page.


### D.  Set up authentication

1.  Go to the [Create service account key](https://console.cloud.google.com/apis/credentials/serviceaccountkey) page.
2. Select the project from the drop-down menu on top (to the right of *Google Cloud Platform*).
3.  Select or create a new service account. Follow these steps if you need to create a new service account:
    * From the *Service account* drop-down, select *New service account*.
    * Enter a *Service account name*.
    * For convenience, from the *Role* drop-down, select *Project → Owner*.
    * Complete service account creation by clicking the *Create* button at the bottom left.
    * You may save the private key for the json file on your computer at this point and close the pop-up dialog.


### E.  Copy the email address for the service account and paste it into the form field in the next cell

If when running the next cell you encounter the following error:

    ERROR: (gcloud.iam.service-accounts.keys.create) RESOURCE_EXHAUSTED: Maximum number of keys on account reached.

Go to the [credentials page](https://console.cloud.google.com/apis/credentials) and click on the service account. This takes you to a page listing all the keys associated with the service account, and you can delete unused ones by clicking on the garbage can icon to the right.



In [0]:
# Email address of the service account from step D; edit it via the form field.
service_account_name = 'YOUR_SERVICE_KEY_ACCOUNT_HERE@SOMETHING.iam.gserviceaccount.com'  #@param {type: "string"}

# Log in interactively, then create a new key for the service account and
# write it to /tmp/key.json.
!gcloud auth login
!gcloud iam service-accounts keys create /tmp/key.json  --iam-account $service_account_name

import os
# Expose the key file through GOOGLE_APPLICATION_CREDENTIALS
# (presumably read by the Cloud client library created later -- confirm).
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/tmp/key.json'


## General imports and helper function


In [0]:
#@title Imports
%tensorflow_version 1.x\n
from __future__ import division

import collections
import IPython
import functools
import math
import matplotlib.pyplot as plt
import numpy as np
import io
import os
import tensorflow as tf
import tensorflow_hub as hub
import random
import scipy.io.wavfile
import tarfile
import time
import sys

from google.colab import output
from google.colab import widgets
from base64 import b64decode

!pip install ffmpeg-python
import ffmpeg

In [0]:
#@title Helper function

from IPython.display import HTML, display

def progress(value, maximum=100):
  """Builds an HTML `<progress>` element showing `value` out of `maximum`.

  Args:
    value: Current progress value.
    maximum: Value that corresponds to a full bar.

  Returns:
    An `IPython.display.HTML` object wrapping the progress element.
  """
  bar_markup = (
      "\n"
      "  <progress value='{value}' max='{max}' style='width: 80%'>{value}</progress>\n"
      "    ")
  return IPython.display.HTML(bar_markup.format(value=value, max=maximum))


## Get and load the test data

The following cells are responsible for getting the data into the colab and creating the embeddings on top of which the model is trained.

To train a model on a different source of data, replace the next cell with one that copies in your data and change the file scanning cell to scan it correctly.

Finally, ensure that the global variable MODEL_LABELS is appropriately set.

File scanning is performed to create 2 lists of wav files:
 * A training file list containing all possible training files. (All files not in testing_list.txt or validation_list.txt)
 * An evaluation file list that we will use for testing (validation_list.txt)


File lists are actually dictionaries with the following structure:

```
{'keyword1': ['path/to/word1/example1.wav', 'path/to/word1/example2.wav'],
 'keyword2': ['path/to/word2/example1.wav', 'path/to/word2/example2.wav'],
 ...
 'negative': ['path/to/negative_example1.wav', 'path/to/negative_example2.wav']}
 ```

The subsequent cells assume that the file lists are stored in the  variables: *all_eval_example_files* and *all_train_example_files*.



In [0]:
#@title Download and extract the speech commands data set
data_source = "http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz"
data_dest_dir = "speech_commands_v0.02"
test_list = data_dest_dir + "/testing_list.txt"
valid_list = data_dest_dir + "/validation_list.txt"

TARGET_WORDS = 'yes,no,up,down,left,right,on,off,stop,go'
ALL_WORDS = 'backward,bed,bird,cat,dog,down,eight,five,follow,forward,four,go,' + 'happy,house,learn,left,marvin,nine,no,off,on,one,right,seven,sheila,six,stop,' + 'three,tree,two,up,visual,wow,yes,zero'

# Note: This example colab doesn't train the silence output. 
MODEL_LABELS = ['negative', 'silence'] + TARGET_WORDS.split(',')

!wget http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz

print("extracting tar archive.. this may take a few minutes.")

if not os.path.exists(data_dest_dir):
  os.makedirs(data_dest_dir)
#tarfile.open("speech_commands_v0.02.tar.gz", 'r:gz').extractall(data_dest_dir)
file_count = 0
progress_bar = IPython.display.display(progress(0, 100), display_id=True)

with tarfile.open("speech_commands_v0.02.tar.gz", 'r:gz') as speech_commands_tar:
  for member_info in speech_commands_tar.getmembers():
    if file_count % 100 == 0:
      progress_bar.update(progress(100 * file_count/105800))
    speech_commands_tar.extract(member_info, data_dest_dir)
    file_count+=1


In [0]:
#@title Scan files

# Fresh progress bar for the file scan; the returned display handle is
# updated as each word directory is processed.
progress_bar = IPython.display.display(progress(0, 100), display_id=True)
print("loading filelists from: %s " % data_dest_dir)

def get_test_files(word, test_file_names, data_dir=None):
  """Returns the files of `word` that are listed in `test_file_names`.

  Args:
    word: Name of the word sub-directory to scan.
    test_file_names: Iterable of file paths of the form `word/filename`
        (relative to the data directory) that belong to the test split.
    data_dir: Directory containing the word sub-directories. Defaults to the
        notebook-level `data_dest_dir`.

  Returns:
    A shuffled list of `word/filename` paths that exist on disk and are
    present in `test_file_names`.
  """
  if data_dir is None:
    data_dir = data_dest_dir
  word_dir = os.path.join(data_dir, word)
  # A set makes each membership check O(1); the test list has thousands of
  # entries and is checked once per file in the directory.
  wanted = set(test_file_names)
  word_test_files = []
  for file_name in os.listdir(word_dir):
    if not os.path.isfile(os.path.join(word_dir, file_name)):
      continue
    relative_path = os.path.join(word, file_name)
    if relative_path in wanted:
      word_test_files.append(relative_path)
  random.shuffle(word_test_files)
  return word_test_files

# Paths (relative to data_dest_dir) of every file reserved for testing.
# The context manager closes the list file once it has been read.
with open(test_list, encoding="ISO-8859-1") as test_list_file:
  test_files = [line.rstrip() for line in test_list_file]

# Maps a model label to the list of evaluation wav files for that label.
all_eval_example_files = collections.defaultdict(list)

all_word_list = ALL_WORDS.split(',')

word_count = 0
for word in all_word_list:
  # Words the model is not trained to recognise are pooled as "negative".
  label = word if word in MODEL_LABELS else "negative"
  eval_files = get_test_files(word, test_files)
  all_eval_example_files[label].extend(eval_files)
  word_count += 1
  progress_bar.update(progress(100 * word_count/len(all_word_list)))


## Prepare synthetic data


Here the wav files from both evaluation and training sets are:
 * Opened and decoded.
 * Loudness normalized.
 * Passed through the embedding model to create embeddings.
 * Added to a data structure that lets us change the balance between negative and labeled outputs.


 resulting in two objects *eval_data* and *train_data*.
 

In [0]:
#@title Define list of cloud voices of interest

# Voice variants to synthesize with, per English locale
# (Australia, UK, India, US), for each of the two voice families.
_STANDARD_VARIANTS = (
    ("en-AU", "ABCD"),
    ("en-GB", "ABCD"),
    ("en-IN", "ABC"),
    ("en-US", "BCDE"),
)
_WAVENET_VARIANTS = (
    ("en-AU", "ABCD"),
    ("en-GB", "ABCD"),
    ("en-IN", "ABC"),
    ("en-US", "ABCDEF"),
)

standard_voices = [
    "{}-Standard-{}".format(locale, variant)
    for locale, variants in _STANDARD_VARIANTS
    for variant in variants
]

wavenet_voices = [
    "{}-Wavenet-{}".format(locale, variant)
    for locale, variants in _WAVENET_VARIANTS
    for variant in variants
]

# Every voice used for synthesis: standard voices first, then WaveNet voices.
cloud_voices = standard_voices + wavenet_voices



In [0]:
#@title Setup Text-to-Speech synthesis

from google.cloud import texttospeech

# Text-to-Speech client shared by the synthesis helpers below.
client = texttospeech.TextToSpeechClient()

# Sample rate (Hz) requested for all synthesized audio.
SAMPLE_RATE = 16000
def synthesize_speech(client, text, cloud_voice, language='en-US'):
  """Synthesizes `text` with the given Cloud voice.

  Args:
    client: Text-to-Speech client used to issue the request.
    text: Text to be spoken.
    cloud_voice: Name of the Cloud voice to use.
    language: Language code sent along with the voice name.

  Returns:
    The response of `client.synthesize_speech` for a LINEAR16 request at
    `SAMPLE_RATE` Hz.
  """
  # What to say.
  request_input = texttospeech.types.SynthesisInput(text=text)

  # Who says it.
  voice_params = texttospeech.types.VoiceSelectionParams(
      name=cloud_voice,
      language_code=language)

  # How the audio should be encoded.
  output_format = texttospeech.types.AudioConfig(
      sample_rate_hertz=SAMPLE_RATE,
      audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16)

  return client.synthesize_speech(request_input, voice_params, output_format)



from google.colab import files
from io import BytesIO
import scipy.io.wavfile as sciwav
import collections
import time

def synthesize_texts(voices, texts, target_words, delay_seconds=0.200):
  """Synthesizes every text with every voice and saves the audio as wav files.

  Each example is written to `<text>_<voice>.wav` in the current working
  directory, using the notebook-level `client` and `synthesize_speech`.

  Args:
    voices: Sequence of Cloud voice names.
    texts: Sequence of words to synthesize.
    target_words: Collection of words that keep their own (lower-cased)
        label; every other word is filed under 'negative'.
    delay_seconds: Pause after each request, in seconds.

  Returns:
    A `collections.defaultdict(list)` mapping label to the list of wav file
    names written for that label. Empty when there is nothing to synthesize.
  """
  cloud_wavs = collections.defaultdict(list)
  progress_bar = display(progress(0, 100), display_id=True)
  total_examples_to_synthesize = len(voices) * len(texts)
  if total_examples_to_synthesize == 0:
    # Nothing to do; also avoids dividing by zero in the progress updates.
    progress_bar.update(progress(100))
    return cloud_wavs

  num_synthesized_examples = 0
  for voice in voices:
    for text in texts:
      response = synthesize_speech(client, text, voice)
      filename = '{}_{}.wav'.format(text, voice)
      with open(filename, 'wb') as f_out:
        f_out.write(response.audio_content)
      label = text.lower()
      if label not in target_words:
        label = 'negative'
      cloud_wavs[label].append(filename)
      time.sleep(delay_seconds)
      num_synthesized_examples += 1
      # Redraw the bar only every fifth example.
      if num_synthesized_examples % 5 == 0:
        progress_bar.update(
            progress(100 * num_synthesized_examples/total_examples_to_synthesize))
  progress_bar.update(
      progress(100 * num_synthesized_examples/total_examples_to_synthesize))
  return cloud_wavs




In [0]:
#@title Synthesize speech examples

%%time

print("Synthesizing speech examples, this may take a few minutes.")

cloud_wavs = synthesize_texts(
    cloud_voices, ALL_WORDS.split(','), TARGET_WORDS.split(','))


## Load the embedding model

The following info messages can be ignored

> *INFO:tensorflow:Saver not created because there are no variables in the graph to restore*

Don't worry: TF Hub is restoring all the variables.

You can test the model by having it produce an embedding on zeros:


```
speech_embedding_model.create_embedding(np.zeros((1,66000)))
```


In [0]:
#@title Helper functions and classes
def normalized_read(filename):
  """Reads a wav file and normalizes its loudness.

  Samples are divided by the 99.9th percentile of their absolute values and
  then clipped to [-1, 1].

  Args:
    filename: Path of the wav file to read.

  Returns:
    A float numpy array of normalized samples. If the 99.9th percentile is 0
    (e.g. a silent file) an all-zero array of the same shape is returned
    instead of dividing by zero.
  """
  # The context manager closes the file handle once the samples are loaded.
  with open(filename, mode='rb') as wav_file:
    _, data = scipy.io.wavfile.read(wav_file)
  samples_99_percentile = np.percentile(np.abs(data), 99.9)
  if samples_99_percentile == 0:
    return np.zeros(data.shape, dtype=np.float64)
  normalized_samples = data / samples_99_percentile
  normalized_samples = np.clip(normalized_samples, -1, 1)
  return normalized_samples

class EmbeddingDataFileList(object):
  """Container that loads audio, stores it as embeddings and can
  rebalance it.

  Data is kept per label; `full_rebalance` changes how often the labeled
  examples are repeated relative to the negative examples.
  """

  def __init__(self, filelist,
               data_dest_dir,
               targets=None,
               label_max=10000,
               negative_label="negative",
               negative_multiplier=25,
               target_samples=32000,
               progress_bar=None,
               embedding_model=None):
    """Creates an instance of `EmbeddingDataFileList`.

    Args:
      filelist: Dict mapping label to a list of wav file paths.
      data_dest_dir: Directory the wav file paths are relative to.
      targets: List of all valid labels; a label's position in this list is
          its numeric class index. Defaults to the labels of `filelist`.
      label_max: Maximum number of files loaded per non-negative label.
      negative_label: Label under which negative examples are stored.
      negative_multiplier: The negative label may load up to
          `negative_multiplier * label_max` files.
      target_samples: Clips shorter than this are zero-padded.
      progress_bar: Optional display handle updated after every example.
      embedding_model: Optional model with a `create_embedding` method; when
          given, embeddings are stored instead of raw samples.

    Raises:
      ValueError: If `filelist` contains a label that is not in `targets`.
    """
    self._negative_label = negative_label
    self._data_per_label = collections.defaultdict(list)
    self._labelcounts = {}
    self._label_list = list(filelist) if targets is None else targets
    # `.get` keeps plain dicts without a negative entry working.
    negative_files = filelist.get(negative_label, [])
    total_examples = sum([min(len(x), label_max) for x in filelist.values()])
    total_examples -= min(len(negative_files), label_max)
    total_examples += min(len(negative_files), negative_multiplier * label_max)
    print("loading %d examples" % total_examples)
    example_count = 0
    for label in filelist:
      if label not in self._label_list:
        raise ValueError("Unknown label: {}".format(label))
      # Shuffle a copy so the caller's lists are left untouched.
      label_files = list(filelist[label])
      random.shuffle(label_files)
      multiplier = negative_multiplier if label == negative_label else 1
      for wav_file in label_files[:label_max * multiplier]:
        data = normalized_read(os.path.join(data_dest_dir, wav_file))
        required_padding = target_samples - data.shape[0]
        if required_padding > 0:
          # Pads by the full missing amount on BOTH sides, so the clip ends up
          # centred and longer than `target_samples`.
          data = np.pad(data, (required_padding, required_padding), 'constant')
        self._labelcounts[label] = self._labelcounts.get(label, 0) + 1
        if embedding_model:
          # create_embedding returns a list; keep the first output and drop
          # its leading dimension.
          data = embedding_model.create_embedding(data)[0][0,:,:,:]
        self._data_per_label[label].append(data)
        if progress_bar is not None:
          example_count += 1
          progress_bar.update(progress(100 * example_count/total_examples))

  @property
  def labels(self):
    """List of all valid labels; positions are the numeric class indices."""
    return self._label_list

  def get_label(self, idx):
    """Returns the numeric class index of the label `idx`."""
    return self.labels.index(idx)

  def _get_filtered_data(self, label, filter_fn):
    """Returns `(filter_fn(example), class_index)` pairs for `label`."""
    idx = self.labels.index(label)
    return [(filter_fn(x), idx) for x in self._data_per_label[label]]

  def _multply_data(self, data, factor):
    """Repeats `data` `factor` times; a fractional part adds a random sample."""
    samples = int((factor - math.floor(factor)) * len(data))
    return int(factor) * data + random.sample(data, samples)

  def full_rebalance(self, negatives, labeled):
    """Rebalances for a given ratio of labeled to negatives.

    Every non-negative label is repeated (or subsampled) by the same factor so
    that labeled:negative example counts approach `labeled`:`negatives`.

    Args:
      negatives: Weight of the negative examples; must be positive.
      labeled: Weight of the labeled examples; must not be negative.

    Raises:
      ValueError: If a weight is out of range, or if either the negative or
          the labeled examples are missing entirely.
    """
    if negatives <= 0 or labeled < 0:
      raise ValueError(
          "Invalid weights: negatives={}, labeled={}".format(negatives, labeled))
    negative_count = self._labelcounts.get(self._negative_label, 0)
    labeled_count = sum(count
                        for key, count in self._labelcounts.items()
                        if key != self._negative_label)
    if negative_count == 0 or labeled_count == 0:
      raise ValueError(
          "Cannot rebalance: {} negative and {} labeled examples loaded".format(
              negative_count, labeled_count))
    labeled_multiply = labeled * negative_count / (negatives * labeled_count)
    for label in self._data_per_label:
      if label == self._negative_label:
        continue
      self._data_per_label[label] = self._multply_data(
          self._data_per_label[label], labeled_multiply)
      self._labelcounts[label] = len(self._data_per_label[label])

  def get_all_data_shuffled(self, filter_fn):
    """Returns a shuffled list containing all the data."""
    return self.get_all_data(filter_fn, shuffled=True)

  def get_all_data(self, filter_fn, shuffled=False):
    """Returns a list of `(filter_fn(example), class_index)` for all data."""
    data = []
    for label in self._data_per_label:
      data += self._get_filtered_data(label, filter_fn)
    if shuffled:
      random.shuffle(data)
    return data

def cut_middle_frame(embedding, num_frames, flatten):
  """Extracts the middle frames of an embedding.

  Args:
    embedding: Array whose first axis indexes frames.
    num_frames: Number of consecutive frames to keep from the centre.
    flatten: If true, the selected frames are flattened into a 1-D array.

  Returns:
    The `num_frames` central frames of `embedding`, optionally flattened.

  Raises:
    ValueError: If `embedding` has fewer than `num_frames` frames. Without
        this check the negative offset would silently select the wrong frames.
  """
  available_frames = embedding.shape[0]
  if num_frames > available_frames:
    raise ValueError(
        "Cannot cut {} frames from an embedding with only {} frames".format(
            num_frames, available_frames))
  left_context = (available_frames - num_frames) // 2
  middle = embedding[left_context:left_context + num_frames]
  return middle.flatten() if flatten else middle


In [0]:
#@title TfHubWrapper Class

class TfHubWrapper(object):
  """Loads a TF Hub embedding model and runs it on audio samples."""
  def __init__(self, embedding_model_dir):
    """Creates a `TfHubWrapper`.

    Builds a private graph and session, loads the hub module into it and
    initializes the graph's global variables.

    Args:
      embedding_model_dir: Location of the hub module, passed to
          `hub.load_module_spec`.
    """
    self._graph = tf.Graph()
    self._sess = tf.Session(graph=self._graph)
    with self._graph.as_default():
      with self._sess.as_default():
        module_spec = hub.load_module_spec(embedding_model_dir)
        embedding_module = hub.Module(module_spec)
        # One clip at a time: batch size 1, any number of samples.
        self._samples = tf.placeholder(
            tf.float32, shape=[1, None], name='audio_samples')
        self._embedding = embedding_module(self._samples)
        self._sess.run(tf.global_variables_initializer())
    print("Embedding model loaded, embedding shape:", self._embedding.shape)

  def create_embedding(self, samples):
    """Runs the embedding model on `samples`.

    Args:
      samples: Numpy array of audio samples; it is reshaped to a single row.

    Returns:
      A one-element list containing the embedding computed by the model.
    """
    samples = samples.reshape((1, -1))
    output = self._sess.run(
        [self._embedding],
        feed_dict={self._samples: samples})
    return output


In [0]:
# TF Hub location of the speech embedding module; the wrapper created here
# is used to embed both the evaluation and the training audio.
embedding_model_url = "https://tfhub.dev/google/speech_embedding/1"
speech_embedding_model = TfHubWrapper(embedding_model_url)


## Load training and eval data

In [0]:
#@title Load evaluation data
# Fresh progress bar for the evaluation load.
progress_bar = display(progress(0, 100), display_id=True)

print("loading eval data")
# Embed the scanned evaluation files found under `data_dest_dir`.
eval_data = EmbeddingDataFileList(
    filelist=all_eval_example_files,
    data_dest_dir=data_dest_dir,
    targets=MODEL_LABELS,
    progress_bar=progress_bar,
    embedding_model=speech_embedding_model)



In [0]:
#@title Load synthesized examples wav files for training

# Fresh progress bar for the training load.
progress_bar = display(progress(0, 100), display_id=True)

print("loading train data")
# The synthesized wav file names are used as-is, hence the empty directory.
train_data = EmbeddingDataFileList(
    filelist=cloud_wavs,
    data_dest_dir='',
    targets=MODEL_LABELS,
    progress_bar=progress_bar,
    embedding_model=speech_embedding_model)


## Train and Evaluate a Head Model

In [0]:
#@title HeadTrainerClass and head model functions

def _fully_connected_model_fn(embeddings, num_labels):
  """Builds a head model made of a single fully connected output layer.

  Args:
    embeddings: Input embedding tensor.
    num_labels: Number of output classes.

  Returns:
    The logits tensor produced by a dense layer without activation.
  """
  flat_embeddings = tf.layers.flatten(embeddings)
  return tf.compat.v1.layers.dense(
      flat_embeddings, num_labels, activation=None)

# Short aliases for the TF 1.x contrib modules used by the convolutional head.
framework = tf.contrib.framework
layers = tf.contrib.layers

def _conv_head_model_fn(embeddings, num_labels, context):
  """Builds a convolutional head model with a fully connected output layer.

  Two blocks of conv -> batch norm -> ELU, each followed by max pooling, and
  a final fully connected layer that produces the logits.

  Args:
    embeddings: Input embedding tensor.
    num_labels: Number of output classes.
    context: Number of frames along the first non-batch axis of `embeddings`;
        used to size the final pooling window.

  Returns:
    The logits tensor.
  """
  activation_fn = tf.nn.elu
  # NOTE(review): batch norm is always built with is_training=True, so the
  # same graph uses batch statistics during evaluation and inference -- confirm
  # this is intended.
  normalizer_fn = functools.partial(
      layers.batch_norm, scale=True, is_training=True)
  with framework.arg_scope([layers.conv2d], biases_initializer=None,
                           activation_fn=None, stride=1, padding="SAME"):
    net = embeddings
    net = layers.conv2d(net, 96, [3, 1])
    net = normalizer_fn(net)
    net = activation_fn(net)
    # Halve the number of frames; `context` tracks the remaining count.
    net = layers.max_pool2d(net, [2, 1], stride=[2, 1], padding="VALID")
    context //= 2
    net = layers.conv2d(net, 96, [3, 1])
    net = normalizer_fn(net)
    net = activation_fn(net)
    # Pool over all remaining frames.
    net = layers.max_pool2d(net, [context, net.shape[2]], padding="VALID")
  net = tf.layers.flatten(net)
  logits = layers.fully_connected(
      net, num_labels, activation_fn=None)
  return logits

class HeadTrainer(object):
  """A tensorflow classifier to quickly train and test on embeddings.

  Only use this if you are training a very small model on a very limited amount
  of data. If you expect the training to take any more than 15 - 20 min then use
  something else.
  """

  def __init__(self, model_fn, input_shape, num_targets,
               head_learning_rate=0.001, batch_size=64):
    """Creates a `HeadTrainer`.

    Args:
      model_fn: function called as `model_fn(embeddings, num_targets)` that
          builds the tensorflow head model and returns its logits.
      input_shape: list describing the shape of the model's input feature.
          Does not include the batch dimension.
      num_targets: Target number of keywords.
      head_learning_rate: Learning rate of the Adam optimizer.
      batch_size: Default number of examples per batch.
    """
    self._input_shape = input_shape
    self._output_dim = num_targets
    self._batch_size = batch_size
    self._graph = tf.Graph()
    with self._graph.as_default():
      self._feature = tf.placeholder(tf.float32, shape=([None] + input_shape))
      # Note: `(None)` is plain `None`, i.e. the labels' shape is unspecified.
      self._labels = tf.placeholder(tf.int64, shape=(None))
      module_spec = hub.create_module_spec(
          module_fn=self._get_headmodule_fn(model_fn, num_targets))
      self._module = hub.Module(module_spec, trainable=True)
      logits = self._module(self._feature)
      self._predictions = tf.nn.softmax(logits)
      self._loss, self._accuracy = self._get_loss(
          logits, self._labels, self._predictions)
      self._update_weights = tf.train.AdamOptimizer(
          learning_rate=head_learning_rate).minimize(self._loss)
    self._sess = tf.Session(graph=self._graph)
    with self._sess.as_default():
      with self._graph.as_default():
        self._sess.run(tf.local_variables_initializer())
        self._sess.run(tf.global_variables_initializer())

  def _get_headmodule_fn(self, model_fn, num_targets):
    """Wraps the model_fn in a tf hub module."""
    def module_fn():
      embeddings = tf.placeholder(
          tf.float32, shape=([None] + self._input_shape))
      logit = model_fn(embeddings, num_targets)
      hub.add_signature(name='default', inputs=embeddings, outputs=logit)
    return module_fn

  def _get_loss(self, logits, labels, predictions):
    """Defines the model's loss and accuracy.

    Returns:
      tuple of (mean sparse softmax cross-entropy loss, accuracy of the
      arg-max prediction).
    """
    xentropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    loss = tf.reduce_mean(xentropy_loss)
    accuracy = tf.contrib.metrics.accuracy(tf.argmax(predictions, 1), labels)
    return loss, accuracy

  def save_head_model(self, save_directory):
    """Exports the head model as a hub module to `save_directory`."""
    with self._graph.as_default():
      self._module.export(save_directory, self._sess)

  def _feature_transform(self, batch_features, batch_labels):
    """Transforms lists of features and labels into model inputs."""
    return np.stack(batch_features), np.stack(batch_labels)

  def _batch_data(self, data, batch_size=None):
    """Splits the input data into batches.

    Args:
      data: Iterable of (feature, label) tuples.
      batch_size: Examples per batch; defaults to a single batch holding all
          of `data`.

    Yields:
      (features, labels) numpy arrays; the last batch may be smaller.

    Raises:
      ValueError: If a feature has the wrong shape or a label is outside
          `[0, self._output_dim)`.
    """
    batch_features = []
    batch_labels = []
    batch_size = batch_size or len(data)
    for feature, label in data:
      if feature.shape != tuple(self._input_shape):
        raise ValueError(
            "Feature shape ({}) doesn't match model shape ({})".format(
                feature.shape, self._input_shape))
      if not 0 <= label < self._output_dim:
        raise ValueError('Label value ({}) outside of target range'.format(
            label))
      batch_features.append(feature)
      batch_labels.append(label)
      if len(batch_features) == batch_size:
        yield self._feature_transform(batch_features, batch_labels)
        # Safe to reuse the lists: np.stack copied their contents.
        del batch_features[:]
        del batch_labels[:]
    if batch_features:
      yield self._feature_transform(batch_features, batch_labels)

  def epoch_train(self, data, epochs=1, batch_size=None):
    """Trains the model on the provided data.

    Args:
      data: List of tuples (feature, label) where feature is a np array of
          shape `self._input_shape` and label an int less than self._output_dim.
      epochs: Number of times this data should be trained on.
      batch_size: Number of feature, label pairs per batch. Overwrites
          `self._batch_size` when set.

    Returns:
      tuple of accuracy, loss;
          accuracy: Training accuracy averaged over all batches.
          loss: Loss of the final batch.

    Raises:
      ValueError: If no batch was trained on (empty `data` or `epochs` < 1).
    """
    batch_size = batch_size or self._batch_size
    accuracy_list = []
    loss = None
    for _ in range(epochs):
      for features, labels in self._batch_data(data, batch_size):
        loss, accuracy, _ = self._sess.run(
            [self._loss, self._accuracy, self._update_weights],
            feed_dict={self._feature: features, self._labels: labels})
        accuracy_list.append(accuracy)
    if not accuracy_list:
      raise ValueError(
          "epoch_train() needs at least one example and one epoch.")
    return sum(accuracy_list) / len(accuracy_list), loss

  def test(self, data, batch_size=None):
    """Evaluates the model on the provided data.

    Args:
      data: List of tuples (feature, label) where feature is a np array of
          shape `self._input_shape` and label an int less than self._output_dim.
      batch_size: Number of feature, label pairs per batch. Overwrites
          `self._batch_size` when set.

    Returns:
      tuple of accuracy, loss;
          accuracy: Evaluation accuracy averaged over all batches.
          loss: Loss of the final batch.

    Raises:
      ValueError: If `data` is empty.
    """
    batch_size = batch_size or self._batch_size
    accuracy_list = []
    loss = None
    for features, labels in self._batch_data(data, batch_size):
      loss, accuracy = self._sess.run(
          [self._loss, self._accuracy],
          feed_dict={self._feature: features, self._labels: labels})
      accuracy_list.append(accuracy)
    if not accuracy_list:
      raise ValueError("test() needs at least one example.")
    return sum(accuracy_list) / len(accuracy_list), loss

  def infer(self, example_feature):
    """Runs inference on example_feature.

    Args:
      example_feature: np array of shape `self._input_shape`.

    Returns:
      The softmax predictions for a batch containing only this example.

    Raises:
      ValueError: If `example_feature` has the wrong shape.
    """
    if example_feature.shape != tuple(self._input_shape):
      raise ValueError(
          "Feature shape ({}) doesn't match model shape ({})".format(
              example_feature.shape, self._input_shape))
    return self._sess.run(
        self._predictions,
        feed_dict={self._feature: np.expand_dims(example_feature, axis=0)})

In [0]:
#@title Rebalance and filter data.

# Relative weights of labeled vs. negative examples, and the number of
# embedding frames fed to the head model (all editable via the form sliders).
labeled_weight = 8 #@param {type:"slider", min:1, max:25, step:1}
negatives_weight = 1 #@param {type:"slider", min:1, max:25, step:1}
context_size = 16 #@param {type:"slider", min:1, max:28, step:1}


# Keeps the `context_size` middle frames of each embedding, unflattened.
filter_fn = functools.partial(cut_middle_frame, num_frames=context_size, flatten=False)


# Rebalance each data set in place, then flatten it into a shuffled list of
# (feature, label index) pairs.
eval_data.full_rebalance(negatives=negatives_weight, labeled=labeled_weight)
all_eval_data = eval_data.get_all_data_shuffled(filter_fn=filter_fn)

train_data.full_rebalance(negatives=negatives_weight, labeled=labeled_weight)
all_train_data = train_data.get_all_data_shuffled(filter_fn=filter_fn)

In [0]:
#@title Run training and evaluation
# Which head architecture to train on top of the embeddings.
head_model = "Convolutional" #@param ["Convolutional", "Fully_Connected"] {type:"string"}

#@markdown Suggested **learning_rate** range 0.00001 - 0.01.
learning_rate = 0.002 #@param {type:"number"}
batch_size = 64
#@markdown **epochs_per_eval** and **train_eval_loops** control how long the
#@markdown the model is trained. An epoch is defined as the model having seen
#@markdown each example at least once, with some examples twice to ensure the
#@markdown correct labeled / negatives balance.

epochs_per_eval = 1 #@param {type:"slider", min:1, max:15, step:1}
train_eval_loops = 15 #@param {type:"slider", min:5, max:80, step:5}

# The convolutional head additionally needs to know the number of frames.
if head_model == "Convolutional":
  model_fn = functools.partial(_conv_head_model_fn, context=context_size)
else:
  model_fn = _fully_connected_model_fn


# Input features are `context_size` frames of shape [1, 96] each.
trainer = HeadTrainer(model_fn=model_fn,
                      input_shape=[context_size,1,96],
                      num_targets=len(MODEL_LABELS),
                      head_learning_rate=learning_rate,
                      batch_size=batch_size)

# Bookkeeping for the accuracy plot: x values (examples trained on so far)
# and the matching train / eval accuracies.
data_trained_on = 0
data = [] 
train_results = []
eval_results = []
# Upper x-axis limit: total number of examples seen by the end, plus a margin.
max_data = len(all_train_data) * epochs_per_eval * train_eval_loops + 10

def plot_step(plot, max_data, data, train_results, eval_results):
  """Redraws the accuracy curves on `plot`.

  Args:
    plot: Pyplot-like object the drawing calls are issued on.
    max_data: Upper limit of the x axis.
    data: X values (number of examples trained on).
    train_results: Training accuracies, one per x value.
    eval_results: Evaluation accuracies, one per x value; skipped when falsy.
  """
  plot.clf()
  plot.xlim(0, max_data)
  plot.ylim(0.85, 1.05)

  # Each curve is drawn twice: once as markers and once as a labeled line.
  curves = [(train_results, "b", "train_results")]
  if eval_results:
    curves.append((eval_results, "r", "eval_results"))
  for accuracies, color, legend_label in curves:
    plot.plot(data, accuracies, color + "o")
    plot.plot(data, accuracies, color, label=legend_label)

  plot.legend(loc='lower right', fontsize=24)
  plot.xlabel('number of examples trained on', fontsize=22)
  plot.ylabel('Accuracy', fontsize=22)
  for set_ticks in (plot.xticks, plot.yticks):
    set_ticks(fontsize=20)

plt.figure(figsize=(25, 7))
# Alternate between training for `epochs_per_eval` epochs and evaluating,
# redrawing the accuracy plot after every round.
for loop in range(train_eval_loops):
  train_accuracy, loss = trainer.epoch_train(all_train_data,
                                             epochs=epochs_per_eval)
  train_results.append(train_accuracy)
  if all_eval_data:
    eval_accuracy, loss = trainer.test(all_eval_data)
    eval_results.append(eval_accuracy)
  else:
    # No evaluation data: plot_step skips the eval curve for a falsy value.
    eval_results = None

  data_trained_on += len(all_train_data) * epochs_per_eval
  data.append(data_trained_on)
  plot_step(plt, max_data, data, train_results, eval_results)

  IPython.display.display(plt.gcf())
  if all_eval_data:
    print("Highest eval accuracy: %.2f percent." % (100 * max(eval_results)))
  # wait=True defers clearing until new output arrives, so the previous plot
  # stays visible until the next one is ready.
  IPython.display.clear_output(wait=True)

if all_eval_data:
  print("Highest eval accuracy: %.2f percent." % (100 * max(eval_results)))
