# MicroWakeWord V2 Model Trainer for Google Colab

This notebook is specifically for use in Google Colab. The code below trains a basic MicroWakeWord model. It is intended as a **starting point** for advanced users. The notebook requires Python 3.10; if you use a different version, code changes will be required.

This notebook produces a very rough V2 .tflite model that is compatible with ESPHome without code changes, though you'll need to experiment with the settings to produce a reliable model. This is especially true if you are training a very short or very long wake word, which may be unusable with the default settings. Inline comments point out the most impactful settings and what they affect.

> **Configuring the notebook**: You must supply a phonetic spelling and a directory-friendly name for your wake word in step 3 below, and change the runtime type to GPU: T4 for free accounts, or optionally A100 for paid.
	If you have a _huggingface.co_ account, you are encouraged to provide your API token in a secret named "HF_TOKEN", however this is not required.

> **Running the notebook**:
	Once you've supplied your wake word values and any settings you'd like to change, **click the run/play button** on the left of the first cell. Then click **Runtime**>**Restart session**. Then run each cell in order, starting with the second cell.

> **Using the model**:
	Upon completion of this notebook, you will be prompted to download a .tflite file. To use this in ESPHome, you need to write a V2 model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for ESPHome configuration details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for V2 .json examples.
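
As a rough illustration of the manifest mentioned above, the sketch below writes a JSON file to sit next to the downloaded model. The field names and values are assumptions modeled on the published V2 examples in the model repo, not output from this notebook. Check them against the linked documentation, and replace the placeholder author, cutoff, and arena size with your own.

```python
import json


def build_v2_manifest(wake_word: str, model_file: str) -> dict:
    """Return a manifest dict; every value below is a placeholder to adjust."""
    return {
        "type": "micro",
        "wake_word": wake_word,
        "author": "your name",             # placeholder
        "website": "https://example.com",  # placeholder
        "model": model_file,
        "trained_languages": ["en"],
        "version": 2,
        "micro": {
            "probability_cutoff": 0.97,    # assumed starting point; tune from test results
            "sliding_window_size": 5,      # assumed
            "feature_step_size": 10,       # same value as window_step_ms in the training config
            "tensor_arena_size": 30000,    # assumed; raise it if the model fails to load
            "minimum_esphome_version": "2024.7.0",  # assumed
        },
    }


manifest = build_v2_manifest("hey willy", "hey_willy.tflite")
with open("hey_willy.json", "w") as f:
    json.dump(manifest, f, indent=2)
```

The `model` entry is a file name relative to the manifest, so keep the two files together when you host them.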

In [None]:
################################################################################
# Install MicroWakeWord
################################################################################

###############
# You may see an error in the output about pip's dependency resolver not taking
# installed packages into account; this is safe to ignore.
###############
import os

!git clone -b november-update https://github.com/kahrendt/microWakeWord.git
!pip install -e ./microWakeWord


Cloning into 'microWakeWord'...
remote: Enumerating objects: 899, done.
remote: Counting objects: 100% (283/283), done.
remote: Compressing objects: 100% (100/100), done.
remote: Total 899 (delta 208), reused 225 (delta 183), pack-reused 616 (from 1)
Receiving objects: 100% (899/899), 31.55 MiB | 13.70 MiB/s, done.
Resolving deltas: 100% (581/581), done.
Obtaining file:///content/microWakeWord
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Collecting audiomentations (from microwakeword==0.1.0)
  Downloading audiomentations-0.37.0-py3-none-any.whl.metadata (11 kB)
Collecting audio_metadata (from microwakeword==0.1.0)
  Downloading audio_metadata-0.11.1-py3-none-any.whl.metadata (7.0 kB)
Collecting datasets (from microwakeword==0.1.0)
  Downloading datasets-3.1.0-py3-non

In [None]:
################################################################################
# Generate a sample of the wake word and play it for verification
################################################################################

target_word = "hey willy"  # Phonetic spellings may produce better samples
target_word_friendly = "hey_willy" # Directory safe non-phonetic spelling of wake word (use _ in place of spaces)

generated_samples_output_dir = f"generated_samples/{target_word_friendly}"

print(f"Using {generated_samples_output_dir} for Piper generated wake word samples")

import os
import sys
from IPython.display import Audio

if not os.path.exists("./piper-sample-generator"):
  !git clone https://github.com/rhasspy/piper-sample-generator

if not os.path.isfile("./piper-sample-generator/models/en_US-libritts_r-medium.pt"):
  !wget -q -c --show-progress -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'

  # Install system dependencies
  !pip install torch
  !pip install torchaudio
  !pip install piper-phonemize

if "piper-sample-generator/" not in sys.path:
  sys.path.append("piper-sample-generator/")

!python3 piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples 1 \
--batch-size 1 \
--output-dir "{generated_samples_output_dir}"

Audio(f"{generated_samples_output_dir}/0.wav", autoplay=True)

Using generated_samples/hey_willy for Piper generated wake word samples
Cloning into 'piper-sample-generator'...
remote: Enumerating objects: 124, done.
remote: Counting objects: 100% (37/37), done.
remote: Compressing objects: 100% (15/15), done.
remote: Total 124 (delta 28), reused 22 (delta 22), pack-reused 87 (from 1)
Receiving objects: 100% (124/124), 1.03 MiB | 1.15 MiB/s, done.
Resolving deltas: 100% (51/51), done.
Collecting piper-phonemize
  Downloading piper_phonemize-1.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (282 bytes)
Downloading piper_phonemize-1.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (25.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 25.0/25.0 MB 30.5 MB/s eta 0:00:00
Installing collected packages: piper-phonemize
Successfully installed piper-phonemize-1.1.0
DEBUG:__main__:Loading /content/piper-sample-generator/models/en_US-libritts_r-medium.pt
  torch_model = torch.load(model_path)
INFO:__main__

In [None]:
################################################################################
# Generates a large number of wake word samples
#
# Start here when trying to improve your model.
# See https://github.com/rhasspy/piper-sample-generator for the full set of
# parameters. In particular, experiment with noise-scales and noise-scale-ws,
# generating negative samples similar to the wake word, and generating many more
# wake word samples, possibly with different phonetic pronunciations.
################################################################################

import os
import sys

!python3 piper-sample-generator/generate_samples.py "{target_word}" \
--model piper-sample-generator/models/en_US-libritts_r-medium.pt \
--max-samples 2048 \
--batch-size 92 \
--max-speakers 777 \
--noise-scales 0.667 \
--noise-scale-ws 0.667 \
--length-scales 0.85 \
--length-scales 1.0 \
--length-scales 1.15 \
--output-dir "{generated_samples_output_dir}"

DEBUG:__main__:Loading piper-sample-generator/models/en_US-libritts_r-medium.pt
  torch_model = torch.load(model_path)
INFO:__main__:Successfully loaded the model
DEBUG:__main__:CUDA available, using GPU
DEBUG:__main__:Batch 1/22 complete
DEBUG:__main__:Batch 2/22 complete
DEBUG:__main__:Batch 3/22 complete
DEBUG:__main__:Batch 4/22 complete
DEBUG:__main__:Batch 5/22 complete
DEBUG:__main__:Batch 6/22 complete
DEBUG:__main__:Batch 7/22 complete
DEBUG:__main__:Batch 8/22 complete
DEBUG:__main__:Batch 9/22 complete
DEBUG:__main__:Batch 10/22 complete
DEBUG:__main__:Batch 11/22 complete
DEBUG:__main__:Batch 12/22 complete
DEBUG:__main__:Batch 13/22 complete
DEBUG:__main__:Batch 14/22 complete
DEBUG:__main__:Batch 15/22 complete
DEBUG:__main__:Batch 16/22 complete
DEBUG:__main__:Batch 17/22 complete
DEBUG:__main__:Batch 18/22 complete
DEBUG:__main__:Batch 19/22 complete
DEBUG:__main__:Batch 20/22 complete
DEBUG:__main__:Batch 21/22 complete
DEBUG:__main__:Batch 22/22 complete
DEBUG:__main_
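
Because very short or very long wake words may not work with the default settings, it can help to check how long the generated clips actually are before training. The sketch below is a minimal check using only the standard library. The helper names are hypothetical, and the 1300 ms default mirrors the `clip_duration_ms` value set in the training configuration later in this notebook.

```python
import wave
from pathlib import Path


def clip_durations_ms(directory: str) -> list[float]:
    """Return the duration in milliseconds of every .wav file in directory."""
    durations = []
    for wav_path in sorted(Path(directory).glob("*.wav")):
        with wave.open(str(wav_path), "rb") as wav_file:
            frames = wav_file.getnframes()
            rate = wav_file.getframerate()
        durations.append(1000.0 * frames / rate)
    return durations


def summarize(durations: list[float], limit_ms: float = 1300.0) -> dict:
    """Count the clips and how many run longer than limit_ms."""
    return {
        "count": len(durations),
        "longest_ms": max(durations, default=0.0),
        "over_limit": sum(d > limit_ms for d in durations),
    }


# Reports a count of 0 if the samples have not been generated yet
print(summarize(clip_durations_ms("generated_samples/hey_willy")))
```

If many clips exceed the limit, consider raising `clip_duration_ms` or lowering the `--length-scales` values.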

In [None]:
################################################################################
# Download audio data for augmentation
#
# Based on openWakeWord's automatic_model_training.ipynb; March 4, 2024
#
# IMPORTANT! The data downloaded here has a mixture of different
# licenses and usage restrictions. As such, any custom models trained with this
# data should be considered as appropriate for personal use only. You must seek
# a license with rights' owners for any commercial use.
#
# This can take 30-45 minutes
################################################################################

import datasets
import scipy
import os
import random
import soundfile
import json
import numpy as np

from pathlib import Path
from tqdm import tqdm
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, DownloadConfig
from typing import Dict
from google.colab import userdata

hf_api_token_secret_name = "HF_TOKEN"

archive_download_path = "./tmp"
rir_wav_path = "./mit_rirs_16k"
audioset_wav_path = "./audioset_16k"
fma_wav_path = "./fma_16k"
output_ext = ".wav"

if not os.path.exists("./tmp"):
    os.mkdir("./tmp")

# Due to a long standing "bug" in FLAC never solved in libsndfile, protect
# ourselves from psf_fseek() errors
# def audio_convert_wav_from_files(
#     dataset: IterableDataset,
#     output_path:str,
#     sourcefileext:str ):
#   i = 0

#   if not os.path.exists(output_path):
#     os.mkdir(output_path)

#   clip_count = 0

#   clip_count = dataset.num_shards

#   print(f"\nConverting {clip_count} clips to 16-bit WAV...\n")

#   try:
#     for row in dataset:
#         i += 1
#         if os.path.isfile(f"{output_path}/{row['audio']['path']}")
#         scipy \
#             .io \
#             .wavfile \
#             .write( \
#                   os.path.join(output_path, row[0]), \
#                   16000, \
#                   (row['audio']['array']*32767).astype(np.int16) \
#                   )
#         name = row['audio']['path'].split('/')[-1].replace(f".{sourcefileext}", ".wav")
#         if os.path.isfile(os.path.join(output_path, name)):
#           if (i % 50 == 0):
#             print(f"Skipped converting {i} cached files in {output_path}")
#           continue
#         else:
#           scipy \
#             .io \
#             .wavfile \
#             .write( \
#                   os.path.join(output_path, name), \
#                   16000, \
#                   (row['audio']['array']*32767).astype(np.int16) \
#                   )
#   except soundfile.LibsndfileError:
#       print(f"Failed conversion of FLAC to WAV, skipping file\n")
#       i += 1
#       # Here's where I learned a long lesson on indentation
#       audio_convert_wav(dataset.skip(i), output_path, sourcefileext)
#   i = 0
#   return

# Unused by the rest of this notebook: relies on a `processor` object that is
# never defined here, so it is defined but not applied to any dataset.
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

def create_wav_files_from_audio_dataset(dataset: IterableDataset, output_path: str):
  """Write each row's decoded audio to a 16 kHz, 16-bit PCM WAV file in output_path.

  Each row needs a "path" entry (used to name the output file) and a decoded
  "audio" entry.
  """
  os.makedirs(output_path, exist_ok=True)

  print(f"\nConverting clips to 16-bit WAV in {output_path}...\n")

  i = 0  # Rows consumed so far, including cached ones
  try:
    for row in dataset:
      i += 1

      # Name the output after the source file, with the extension swapped
      dest_wav_file_name = row['path'].split('/')[-1].split('.')[0] + output_ext
      dest_wav_file_path = os.path.join(output_path, dest_wav_file_name)

      if os.path.isfile(dest_wav_file_path):
        if i % 50 == 0:
          print(f"Skipped converting {i} cached files in {output_path}")
        continue

      scipy.io.wavfile.write(
          dest_wav_file_path,
          16000,
          (row['audio']['array'] * 32767).astype(np.int16),
      )
  except soundfile.LibsndfileError:
    # A decoding failure ends the iterator, so resume from the row after the
    # one that failed
    print("Failed to convert row audio data to WAV, skipping file\n")
    create_wav_files_from_audio_dataset(dataset.skip(i + 1), output_path)

def create_audio_dataset_from_files(path: str, fileext: str) -> IterableDataset:
  """Build a streaming dataset of 16 kHz audio from every *.fileext file under path."""
  audio_files = sorted(str(p) for p in Path(path).glob(f"**/*.{fileext}"))
  # "audio" is decoded lazily once cast; "path" keeps the source file name
  fileset: Dataset = datasets.Dataset.from_dict({"audio": audio_files, "path": audio_files})
  cast_dataset = fileset.cast_column("audio", datasets.Audio(sampling_rate=16000))
  return cast_dataset.to_iterable_dataset()

# lee@partners.biz - to do - actually handle streaming. casting a column of the
# dataset can be done async, but .to_iterable_dataset forces hydration of the
# entire iterable which defeats the purpose. anywhere datasets functions accept
# a DownloadConfig, give it one. unclear if it's possible to avoid downloading
# compressed bits for every file
def create_audio_from_streamed_dataset(unval_dataset:Dataset | IterableDataset | IterableDatasetDict | DatasetDict) -> IterableDataset:
  """Return a streaming dataset with 16 kHz audio and a top-level "path" column."""
  if isinstance(unval_dataset, Dataset):
    unval_dataset = unval_dataset.to_iterable_dataset()
  elif not isinstance(unval_dataset, IterableDataset):
    raise TypeError("Provided dataset must be a single split with an iterator")

  # Resample lazily to 16 kHz as rows are streamed
  audio_dataset = unval_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))

  def add_path(row, index):
    # Fall back to the row index when the source provides no file name
    return {"path": row["audio"]["path"] or f"clip_{index}"}

  return audio_dataset.map(add_path, with_indices=True)

def stream_audio_from_hydrated_dataset(dataset: Dataset) -> Dataset:
  # Audio is resampled to 16 kHz when rows are accessed
  return dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))

# Use this if streaming=True; use the function above otherwise
def stream_dataset_audio(unval_dataset:Dataset | IterableDataset | IterableDatasetDict | DatasetDict):
  if not isinstance(unval_dataset, IterableDataset):
    raise TypeError("Provided Dataset did not have an iterator and cannot be streamed")

  # Cast once, then yield rows one at a time as they are decoded
  yield from unval_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))

################################################
# Simple dataset download config provider
################################################

def get_download_config(dataset_name: str) -> DownloadConfig:
  use_hf_secret_for_token = True
  hf_api_token: str = ""

  try:
    hf_api_token = userdata.get(f"{hf_api_token_secret_name}")
  except userdata.SecretNotFoundError:
    use_hf_secret_for_token = False
  dataset_download_config = datasets.DownloadConfig(
      cache_dir=f"{archive_download_path}/{dataset_name}",
      resume_download=True,
      token=hf_api_token if use_hf_secret_for_token else False
      )

  return dataset_download_config

################################################
# Download MIT RIR data
################################################

print(f"Retrieving MIT RIR dataset\n");

rir_ddict = datasets.load_dataset("davidscripka/MIT_environmental_impulse_responses", \
                                  split="train", \
                                  streaming=True, \
                                  download_config=get_download_config("mit_rir"))

# With split="train" and streaming=True, load_dataset returns an IterableDataset
rir_audio_dataset = create_audio_from_streamed_dataset(rir_ddict)
create_wav_files_from_audio_dataset(rir_audio_dataset, rir_wav_path)

################################################
# AudioSet annotated audio training data
# (https://research.google.com/audioset/dataset/index.html)
# Stream the AudioSet clips and convert them to 16 kHz
# For full-scale training, it's recommended to download the entire dataset from
# https://huggingface.co/datasets/agkphysics/AudioSet, and
# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)
################################################

print(f"Retrieving Audioset dataset\n");

hf_audioset_dataset = datasets.load_dataset("agkphysics/AudioSet", split="train", streaming=True, download_config=get_download_config("audioset"))
hf_audioset_audio_dataset = create_audio_from_streamed_dataset(hf_audioset_dataset)

create_wav_files_from_audio_dataset(hf_audioset_audio_dataset, audioset_wav_path)


# 10 .tar archives yields ~16,000 clips
# bal_range = list(range(0, 9))
# random.shuffle(bal_range)

# b = 0
# cached_count = 0

# for archhiveNum in bal_range:
#   fname = f"bal_train0{archhiveNum}.tar"
#   link = f"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/{fname}"

#   out_file_path = os.path.join(archive_download_path, fname)
#   total_archives = bal_range.count

#   if not os.path.isfile(out_file_path):
#     print(f"Downloading {fname} ({++b} of {total_archives}....\n")
#     !wget -q -c --show-progress -O {out_file_path} {link}

#     print(f"Extracting {out_file_path}...\n")
#     !tar -xf {out_file_path} -C {archive_download_path}

#     # Downloading the full set pushes storage limits, delete archive but recreate empty file
#     # so it will not download again in case of rerun
#     !rm -f {out_file_path} && touch {out_file_path}
#     print(f"Deleted audio set archive {out_file_path}\n")
#   else:
#     if (++cached_count % 25 == 0):
#       print(f"Skipped {cached_count} extracted files from {fname} (cached)\n")

# # Save clips to 16-bit PCM wav files
# audioset_dataset = create_audio_dataset_from_files(f"{archive_download_path}/audio/bal_train", "flac")
# audio_convert_wav(audioset_dataset, audioset_wav_path, "flac")


################################################
# Free Music Archive dataset
# https://github.com/mdeff/fma
# (Third-party mchl914 extra small set)
################################################

print(f"Retrieving FMA dataset\n");

fname = "fma_xs.zip"
link = "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/" + fname
fma_archive_path = f"{archive_download_path}/{fname}"

if not os.path.isfile(fma_archive_path):

  print(f"Downloading {link}\n")
  !wget -q -c --show-progress -O {fma_archive_path} {link}

  print(f"Unzipping {fname}\n")
  !unzip -q {fma_archive_path} -d {archive_download_path}

  print(f"Deleting {fname}\n")
  !rm -f {fma_archive_path} && touch {fma_archive_path}

# Save clips to 16-bit PCM wav files
fma_audio_dataset = create_audio_dataset_from_files(f"{archive_download_path}/fma_small", "mp3")
create_wav_files_from_audio_dataset(fma_audio_dataset, fma_wav_path)

Retrieving MIT RIR dataset



Resolving data files:   0%|          | 0/270 [00:00<?, ?it/s]

Retrieving Audioset dataset



Resolving data files:   0%|          | 0/882 [00:00<?, ?it/s]

KeyboardInterrupt: 
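
The conversion code above scales floating-point samples by 32767 and casts them to 16-bit integers. Resampled audio can slightly exceed the range [-1.0, 1.0], and such samples can overflow when cast. The sketch below shows the same idea with clipping added, using only the standard library; the helper names are hypothetical, and it rounds to the nearest integer where the notebook's `astype(np.int16)` truncates.

```python
import struct
import wave


def float_to_pcm16(samples: list[float]) -> bytes:
    """Clip samples to [-1.0, 1.0], scale by 32767, and pack as little-endian int16."""
    clipped = (max(-1.0, min(1.0, s)) for s in samples)
    ints = [int(round(s * 32767)) for s in clipped]
    return struct.pack(f"<{len(ints)}h", *ints)


def write_wav_16k(path: str, samples: list[float]) -> None:
    """Write mono, 16 kHz, 16-bit PCM audio to path."""
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16 bits
        wav_file.setframerate(16000)
        wav_file.writeframes(float_to_pcm16(samples))
```

With NumPy arrays, the equivalent step would be to apply `np.clip(array, -1.0, 1.0)` before scaling.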

In [None]:
################################################################################
# Set up the augmentations
# To improve your model, experiment with these settings and use more sources of
# background clips.
################################################################################

from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

clips = Clips(input_directory=f"{generated_samples_output_dir}",
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
              )
augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities = {
                                "SevenBandParametricEQ": 0.1,
                                "TanhDistortion": 0.1,
                                "PitchShift": 0.1,
                                "BandStopFilter": 0.1,
                                "AddColorNoise": 0.1,
                                "AddBackgroundNoise": 0.75,
                                "Gain": 1.0,
                                "RIR": 0.5,
                            },
                         impulse_paths = ['mit_rirs_16k'],
                         background_paths = ['fma_16k', 'audioset_16k'],
                         background_min_snr_db = -5,
                         background_max_snr_db = 10,
                         min_jitter_s = 0.195,
                         max_jitter_s = 0.205,
                         )
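
To get a feel for `background_min_snr_db` and `background_max_snr_db`, it helps to see what a signal-to-noise ratio in decibels means for the mix. The sketch below is a plain-Python illustration of mixing at a target SNR, not the augmentation library's implementation; the function names are hypothetical.

```python
import math


def rms(samples):
    """Root-mean-square level of a sequence of samples."""
    return math.sqrt(sum(s * s for s in samples) / len(samples))


def noise_gain_for_snr(signal, noise, snr_db):
    """Gain to apply to noise so that 20*log10(rms(signal)/rms(scaled noise)) == snr_db."""
    return rms(signal) / (rms(noise) * 10 ** (snr_db / 20))


def mix_at_snr(signal, noise, snr_db):
    """Add the rescaled noise to the signal, sample by sample."""
    gain = noise_gain_for_snr(signal, noise, snr_db)
    return [s + gain * n for s, n in zip(signal, noise)]


signal = [0.5, -0.5] * 100
noise = [0.1, -0.1] * 100
for snr_db in (-5, 10):
    gain = noise_gain_for_snr(signal, noise, snr_db)
    print(f"SNR {snr_db} dB: background RMS is {gain * rms(noise) / rms(signal):.2f}x the wake word RMS")
```

A negative SNR means the background is louder than the wake word, so lowering `background_min_snr_db` makes the hardest training examples harder.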

In [None]:
# Augment a random clip and play it back to verify it works well

from IPython.display import Audio
from microwakeword.audio.audio_utils import save_clip

random_clip = clips.get_random_clip()
augmented_clip = augmenter.augment_clip(random_clip)
save_clip(augmented_clip, "augmented_clip.wav")

Audio("augmented_clip.wav", autoplay=True)

IndexError: Invalid key: 0 is out of bounds for size 0

In [None]:
################################################################################
# Augment samples and save the training, validation, and testing sets.
# Validating and testing samples generated the same way can make the model
# benchmark better than it performs in real-world use. Use real samples or TTS
# samples generated with a different TTS engine to potentially get more accurate
# benchmarks.
################################################################################

import os
from mmap_ninja.ragged import RaggedMmap

if not os.path.exists("generated_augmented_features"):
  os.mkdir("generated_augmented_features")

generated_features_output_dir = f"generated_augmented_features/{target_word_friendly}"

if not os.path.exists(generated_features_output_dir):
    os.mkdir(generated_features_output_dir)

splits = ["training", "validation", "testing"]
for split in splits:
  out_dir = os.path.join(generated_features_output_dir, split)
  if not os.path.exists(out_dir):
      os.mkdir(out_dir)


  split_name = "train"
  repetition = 2

  spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=10,    # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.
                                     step_ms=10,
                                     )
  if split == "validation":
    split_name = "validation"
    repetition = 1
  elif split == "testing":
    split_name = "test"
    repetition = 1
    spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=1,    # The testing set uses the streaming version of the model, so no artificial repetition is necessary
                                     step_ms=10,
                                     )

  RaggedMmap.from_generator(
      out_dir=os.path.join(out_dir, f'{target_word_friendly}_mmap'),
      sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),
      batch_size=92,
      verbose=True,
  )

In [None]:
################################################################################
# Download pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets.
# This can be slow!
################################################################################

output_dir = './negative_datasets'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']
    for fname in filenames:
        link = link_root + fname

        zip_path = f"negative_datasets/{fname}"
        !wget -O {zip_path} {link}
        !unzip -q {zip_path} -d {output_dir}

In [None]:
################################################################################
# Save the training configuration YAML file
# IMPORTANT! These hyperparameters can make a huge difference in model quality.
# Experiment with sampling and penalty weights and increasing the number of
# training steps.
################################################################################

import yaml
import os

config = {}

config["window_step_ms"] = 10

config["train_dir"] = (
    f"trained_models/{target_word_friendly}"
)


# Each feature_dir should have at least one of the following folders with this structure:
#  training/
#    ragged_mmap_folders_ending_in_mmap
#  testing/
#    ragged_mmap_folders_ending_in_mmap
#  testing_ambient/
#    ragged_mmap_folders_ending_in_mmap
#  validation/
#    ragged_mmap_folders_ending_in_mmap
#  validation_ambient/
#    ragged_mmap_folders_ending_in_mmap
#
#  sampling_weight: Weight for choosing a spectrogram from this set in the batch
#  penalty_weight: Penalizing weight for incorrect predictions from this set
#  truth: Boolean whether this set has positive samples or negative samples
#  truncation_strategy: If spectrograms in the set are longer than necessary for training, how they are truncated
#       - random: choose a random portion of the entire spectrogram - useful for long negative samples
#       - truncate_start: remove the start of the spectrogram
#       - truncate_end: remove the end of the spectrogram
#       - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets

config["features"] = [
    {
        "features_dir": f'{generated_features_output_dir}',
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    { # Only used for validation and testing
        "features_dir": "negative_datasets/dinner_party_eval",
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

# Number of training steps in each iteration - various other settings are configured as lists that correspond to different steps
config["training_steps"] = [10000]

# Penalizing weight for incorrect class predictions - lists that correspond to training steps
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]

config["learning_rates"] = [
    0.001,
]  # Learning rates for Adam optimizer - list that corresponds to training steps
config["batch_size"] = 92

config["time_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["time_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps
config["freq_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["freq_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps

config["eval_step_interval"] = (
    500  # Test the validation sets after every this many steps
)
config["clip_duration_ms"] = (
    1300  # Maximum length of wake word that the streaming model will accept
)

# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization
# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize
# Available metrics:
#   - "loss" - cross entropy error on validation set
#   - "accuracy" - accuracy of validation set
#   - "recall" - recall of validation set
#   - "precision" - precision of validation set
#   - "false_positive_rate" - false positive rate of validation set
#   - "false_negative_rate" - false negative rate of validation set
#   - "ambient_false_positives" - count of false positives from the split validation_ambient set
#   - "ambient_false_positives_per_hour" - estimated number of false positives per hour on the split validation_ambient set
config["target_minimization"] = 0.9
config["minimization_metric"] = None  # Set to None to disable

config["maximization_metric"] = "average_viable_recall"

with open("training_parameters.yaml", "w") as file:
    yaml.dump(config, file)
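
The `sampling_weight` values are easier to reason about as shares of a batch. The sketch below assumes that the chance of drawing from a feature set is proportional to its weight, which is an interpretation of the comment in the configuration cell rather than something confirmed from the trainer's source. The set names are shorthand labels for the five entries in `config["features"]`.

```python
# Weights copied from config["features"]; names are shorthand labels
sampling_weights = {
    "wake_word": 2.0,
    "speech": 10.0,
    "dinner_party": 10.0,
    "no_speech": 5.0,
    "dinner_party_eval": 0.0,
}


def sampling_shares(weights):
    """Normalize weights so they sum to 1 (assumes proportional sampling)."""
    total = sum(weights.values())
    return {name: weight / total for name, weight in weights.items()}


shares = sampling_shares(sampling_weights)
batch_size = 92
for name, share in shares.items():
    print(f"{name}: {share:.1%} of draws, about {share * batch_size:.1f} per batch of {batch_size}")
```

Under that assumption, raising the wake word weight from 2.0 to 4.0 would roughly double the number of positive samples the model sees per batch.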

In [None]:
################################################################################
# Train the model
# When finished, the code will quantize and convert the model to a
# streaming version suitable for on-device detection.
# This will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Change --train 0 to only convert and test the best-weighted model.
# Google Colab does not print mini-batch results, so it may appear
# stuck for several minutes.
# This can be slow!
################################################################################

!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_nonstreaming_quantized 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,64,64,64" \
--repeat_in_block  "1, 1, 1, 1" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32 \
--first_conv_kernel_size 5 \
--stride 3

In [None]:
################################################################################
# Download the tflite model file
# To use the wake word model with ESPHome, you need to write a
# V2 model JSON file. See https://esphome.io/components/micro_wake_word for the
# documentation and
# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for
# examples of the V2 JSON format. Adjust the probability threshold based on
# the test results obtained.
# After training is complete, you may also need to increase the tensor arena
# size if the model later fails to load in ESPHome.
################################################################################

import os
from google.colab import files

tflite_output_dir = f"trained_models/{target_word_friendly}/tflite_stream_state_internal_quant"

model_tflite_filename = f"trained_models/{target_word_friendly}/{target_word_friendly}.tflite"

os.rename(f'{tflite_output_dir}/stream_state_internal_quant.tflite', model_tflite_filename)

files.download(f"{model_tflite_filename}")
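
Adjusting the probability threshold is a trade-off between missed wake words and false activations. The sketch below shows one simple way to compare cutoffs. The scores are invented purely for illustration and the function name is hypothetical; substitute per-clip probabilities from your own test run.

```python
def rates_at_cutoff(positive_scores, negative_scores, cutoff):
    """Return (false_reject_rate, false_accept_rate) when detecting at score >= cutoff."""
    false_rejects = sum(score < cutoff for score in positive_scores)
    false_accepts = sum(score >= cutoff for score in negative_scores)
    return (false_rejects / len(positive_scores),
            false_accepts / len(negative_scores))


# Invented scores for illustration only
positive_scores = [0.99, 0.97, 0.95, 0.80, 0.60]  # clips that contain the wake word
negative_scores = [0.02, 0.10, 0.40, 0.85, 0.96]  # clips that do not

for cutoff in (0.5, 0.9, 0.97):
    frr, far = rates_at_cutoff(positive_scores, negative_scores, cutoff)
    print(f"cutoff={cutoff}: false rejects {frr:.0%}, false accepts {far:.0%}")
```

A higher cutoff reduces false activations at the cost of more missed wake words, so pick the lowest value that keeps false accepts acceptable on your negative test data.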