# MicroWakeWord V2 Model Trainer for Google Colab

This notebook is specifically for use in Google Colab. The code below trains a basic microWakeWord model and is intended as a **starting point** for advanced users. It requires Python 3.10; if you use a different version, code changes will be required.

Without any code changes, this notebook produces a very rough V2 .tflite model compatible with ESPHome, but you'll need to experiment with the settings to produce a reliable model. This is especially true for very short or very long wake words, which may be unusable with the default settings. Inline comments point out the most impactful settings and what they affect.

> **Configuring the notebook**: You must supply a phonetic name and a directory-friendly name for your wake word in step 2 below, and change the runtime type to GPU: T4 for free accounts, or optionally A100 for paid accounts.
	If you have a _huggingface.co_ account, you are encouraged to provide your API token in a Colab secret named "HF_SECRET" (the name step 2 looks for), though this is not required.

> **Running the notebook**:
	Once you've supplied your wake word values and any settings you'd like to change, **click the run/play button** on the left of the first card and wait for it to finish. If you see an error from pip's dependency resolver, it is safe to ignore. Once step 1 has completed, click **Runtime** > **Restart session**. After restarting, do not re-run step 1; instead, run each card in order starting at step 2.

> **Using the Micro Wake Word V2 .tflite model**:
	Upon completion of this notebook, you will be prompted to download a .tflite file. To use it in ESPHome, you need to write a V2 model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for ESPHome configuration details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for V2 .json examples.

In [None]:
################################################################################
# Install MicroWakeWord 
################################################################################
import os

###############
# You may see an error in the output about pip's dependency resolver not taking
# installed packages into account: this is safe to ignore.
###############

!git clone https://github.com/kahrendt/microWakeWord.git
!pip install -e ./microWakeWord


In [None]:
################################################################################
# Define the wake word, set up some output directories, and generate a sample of
# the wake word to play for verification
#
# Use a 'phonetic' spelling of your wake word when setting target_word below.
# Once this step completes, you'll hear a sample of the expected pronunciation.
# You are free to adjust target_word and re-run this step until the wake word
# sounds like what you expect to speak.
################################################################################

target_word = "hey wahbee"  # Phonetic spellings produce better samples for many words (run this sample, then change it to 'hey wabi' and run again for an example)
target_word_friendly = "hey_wabi" # Directory safe non-phonetic spelling of wake word (use _ in place of spaces)

###############
# If you have a Hugging Face API key, make sure it's in a Colab secret named "HF_SECRET"
###############

hf_api_token_secret_name = "HF_SECRET"


###############
# Should not need to change these
###############

archive_download_path = "./tmp"

rir_wav_path = "./mit_rirs_16k"

audioset_wav_path = "./audioset_16k"
# AudioSet is an enormous dataset, so by default only 35% of it is used, which
# is enough to produce a great model and saves time. Change to 100 for maximum precision.
audioset_sample_limit_pct = 35

fma_wav_path = "./fma_16k"

expected_dirs = [ archive_download_path, rir_wav_path, audioset_wav_path, fma_wav_path ]

# Avoid shadowing the built-in dir(); makedirs(exist_ok=True) replaces the exists check
for expected_dir in expected_dirs:
  os.makedirs(expected_dir, exist_ok=True)

output_ext = ".wav"

generated_samples_output_dir = f"generated_samples/{target_word_friendly}"

print(f"Using {generated_samples_output_dir} for Piper generated wake word samples")

import os
import sys
from IPython.display import Audio

if not os.path.exists("./piper-sample-generator"):
  !git clone https://github.com/rhasspy/piper-sample-generator

if not os.path.isfile("./piper-sample-generator/models/en_US-libritts_r-medium.pt"):
  !wget -q -c --show-progress -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'

  # Install the Python dependencies for piper-sample-generator
  !pip install torch
  !pip install torchaudio
  !pip install piper-phonemize

if "piper-sample-generator/" not in sys.path:
  sys.path.append("piper-sample-generator/")

!python3 piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples 1 \
--batch-size 1 \
--output-dir "{generated_samples_output_dir}"

Audio(f"{generated_samples_output_dir}/0.wav", autoplay=True)

In [None]:
################################################################################
# Generates a large number of wake word samples
#
# Start here when trying to improve your model.
# See https://github.com/rhasspy/piper-sample-generator for the full set of
# parameters. In particular, experiment with noise-scales and noise-scale-ws,
# generating negative samples similar to the wake word, and generating many more
# wake word samples, possibly with different phonetic pronunciations.
################################################################################

!python3 piper-sample-generator/generate_samples.py "{target_word}" \
--model piper-sample-generator/models/en_US-libritts_r-medium.pt \
--max-samples 2048 \
--batch-size 92 \
--max-speakers 777 \
--noise-scales 0.667 \
--noise-scale-ws 0.667 \
--length-scales 0.85 1.0 1.15 \
--output-dir "{generated_samples_output_dir}"
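The sample generator accepts several values per flag (noise scales, length scales). Assuming `generate_samples.py` declares these flags with argparse `nargs="+"` (worth confirming in the script itself), repeating a flag keeps only its last occurrence, while listing every value after a single flag keeps them all. The parser below is a hypothetical stand-in for illustration, not the script's own parser:

```python
import argparse

# Hypothetical parser mirroring a typical multi-value flag declaration
parser = argparse.ArgumentParser()
parser.add_argument("--length-scales", nargs="+", type=float, default=[1.0])

# Repeating the flag: each occurrence replaces the previously stored list
repeated = parser.parse_args(
    ["--length-scales", "0.85", "--length-scales", "1.0", "--length-scales", "1.15"]
)

# Listing all values after one flag keeps every value
single = parser.parse_args(["--length-scales", "0.85", "1.0", "1.15"])

print(repeated.length_scales)  # [1.15]
print(single.length_scales)    # [0.85, 1.0, 1.15]
```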

In [None]:
################################################################################
# Download audio data for augmentation
#
# Based on openWakeWord's automatic_model_training.ipynb; March 4, 2024
#
# IMPORTANT! The data downloaded here has a mixture of different
# licenses and usage restrictions. As such, any custom models trained with this
# data should be considered as appropriate for personal use only. You must seek
# a license from the rights holders for any commercial use.
#
# This can take 30-45 minutes
################################################################################

import datasets
import scipy.io.wavfile  # import the submodule explicitly for scipy.io.wavfile.write
import os
import soundfile
import numpy as np
from dataclasses import dataclass

from pathlib import Path
from datasets import IterableDataset, Dataset, DownloadConfig, load_dataset, DatasetInfo
from typing import Dict, Any
from google.colab import userdata

@dataclass
class ExistingFileMapping:
    FilesPresent: Dict[str, Path]
    InitialPath: str


class FileMapper:
    """Tracks which files of a given extension already exist under a directory."""
    file_mapping: ExistingFileMapping
    initial_path: str

    def __init__(self, path: str, file_mapping: ExistingFileMapping | None = None):
        # Build a fresh mapping per instance. A shared default instance (with an
        # empty InitialPath) would make find_files_under_path_by_ext a silent no-op.
        self.file_mapping = file_mapping if file_mapping is not None \
            else ExistingFileMapping(FilesPresent={}, InitialPath=path)

        os.makedirs(path, exist_ok=True)
        self.initial_path = path

    def find_files_under_path_by_ext(self, file_ext: str):
        # Accept ".wav" as well as "wav"; otherwise the glob becomes "**/*..wav"
        ext = file_ext.lstrip(".")
        if not self.file_mapping.InitialPath or not ext:
            self.file_mapping.FilesPresent = {}
            return

        self.file_mapping.FilesPresent = {
            x.name: x for x in Path(self.file_mapping.InitialPath).glob(f"**/*.{ext}")
        }

    def contains_filename(self, filename: str) -> bool:
        return filename in self.file_mapping.FilesPresent


################################################
# Local WAV file producer for datasets
# Writes each row to a local .wav file. Rows whose FLAC data makes libsndfile fail
# to decode are skipped by resuming (recursively) just past the failing row.
#
# Should the script fail or the session get interrupted during the (potentially long)
# download and conversion process, files that already exist are not reprocessed
# (for datasets that have a "path" column).
#
# Note: this function preserves dataset async streaming; use streaming=True where possible when loading datasets
# Warning: if you call Dataset.to_iterable_dataset, to_list, etc., the entire dataset will be downloaded and hydrated synchronously
################################################
def create_wav_files_from_dataset(dataset: Dataset, output_path: str, start_index: int = 0):
  # start_index continues the 1.wav, 2.wav, ... numbering when resuming after an error
  i = start_index

  # Track the .wav files that already exist under output_path
  existing_file_mapper = FileMapper(output_path)
  existing_file_mapper.find_files_under_path_by_ext(output_ext)

  dataset_info: DatasetInfo = dataset.info
  dataset_info_output_path = f"{output_path}/dataset_info"
  os.makedirs(dataset_info_output_path, exist_ok=True)
  dataset_info.write_to_directory(dataset_info_output_path)

  dataset_size_bytes: int = dataset_info.size_in_bytes or 0
  # Only a map-style Dataset knows its row count up front; streaming datasets don't
  dataset_num_rows: int = dataset.num_rows if isinstance(dataset, Dataset) else 0

  # Check the schema rather than `"path" in dataset`, which iterates over every row
  # (downloading a streaming dataset in full) and compares each row to a string
  has_path_column = "path" in (dataset.column_names or [])

  if has_path_column:
    # With input_columns=["path"], the filter function receives the path string itself;
    # compare it against the .wav name this function writes for that path
    unprocessed_audio = dataset.filter(
        lambda path: not existing_file_mapper.contains_filename(
            os.path.splitext(os.path.basename(path))[0] + output_ext),
        input_columns=["path"])
  else:
    unprocessed_audio = dataset

  # Casting is the last step before writing .wavs to disk. Rows are not shuffled:
  # the error handler below resumes with skip(), which needs a stable row order.
  unprocessed_audio = unprocessed_audio.cast_column("audio", datasets.Audio(sampling_rate=16000))

  ds_size_kb = round(dataset_size_bytes / 1024)
  ds_size_disp = f"{round(ds_size_kb / 1024, ndigits=2)}MB" if ds_size_kb > 1024 else f"{ds_size_kb}KB"

  print("Processing Dataset")
  print(f"\tDataset size: {ds_size_disp}")
  if dataset_num_rows > 0:
    print(f"\tConverting {dataset_num_rows} clips to 16-bit WAV...\n")

  rows_done = 0  # rows of unprocessed_audio consumed by this call
  try:
    for row in unprocessed_audio:
      rows_done += 1
      i += 1

      # Without a path column, number the files sequentially
      dest_wav_file_name = f"{i}{output_ext}"
      # With one, keep the original file name so reruns can skip it
      if has_path_column:
        dest_wav_file_name = os.path.splitext(os.path.basename(row["path"]))[0] + output_ext

      dest_wav_file_path = os.path.join(output_path, dest_wav_file_name)

      # Clip to [-1, 1] so loud samples saturate instead of overflowing int16
      samples = np.clip(row["audio"]["array"], -1.0, 1.0)
      scipy.io.wavfile.write(dest_wav_file_path, 16000, (samples * 32767).astype(np.int16))
  except soundfile.LibsndfileError:
    print("Failed to decode row audio data, skipping file\n")
    # Decoding happens while fetching the next row, so the failing row is the one right
    # after the rows_done rows already written: skip past it and keep the numbering going
    create_wav_files_from_dataset(unprocessed_audio.skip(rows_done + 1), output_path, start_index=i)


################################################
# Reference only: manually create dataset from local files
# Instead use load_dataset, which has a handler for local files
################################################
def create_dataset_from_files_under_path(path: str, fileext: str):
  file_dict: Dict[str, Path] = {x.name: x for x in Path(f"{path}/").glob(f"**/*.{fileext}")}
  # The Audio feature expects string paths; values() must be called
  files_dataset = datasets.Dataset.from_dict({"audio": [str(p) for p in file_dict.values()]})

  # datasets.Audio, not IPython.display.Audio (imported earlier in the notebook)
  return files_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))


################################################
# Simple dataset download config provider
################################################

def get_download_config(dataset_name: str) -> DownloadConfig:
  use_hf_secret_for_token = True
  hf_api_token: str = ""

  try:
    hf_api_token = userdata.get(f"{hf_api_token_secret_name}")
  except userdata.SecretNotFoundError:
    use_hf_secret_for_token = False
  dataset_download_config = datasets.DownloadConfig(
      cache_dir=f"{dataset_name}/cache",
      token=hf_api_token if use_hf_secret_for_token else False,
      )

  return dataset_download_config

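################################################
# Load a dataset
# If you know there is no path data, set save_infos=True
# and we'll try to salvage what we can from DatasetInfo
#
# For now, torch_format defaults to False because we iterate a
# datasets.IterableDataset rather than feeding a torch model
################################################
def retrieve_dataset(path: str,
                     download_config: DownloadConfig,
                     split: str = "train",
                     sample_limit_pct: int = 100,
                     streaming: bool = True,
                     save_infos: bool = False,
                     torch_format: bool = False):

  # Percent slices such as "train[:35%]" only work when not streaming
  if sample_limit_pct < 100 and not streaming:
    split = f"{split}[:{sample_limit_pct}%]"

  loaded_ds = datasets.load_dataset(path,
                                    split=split,
                                    streaming=streaming,
                                    download_config=download_config,
                                    save_infos=save_infos)

  if sample_limit_pct < 100 and streaming:
    # Streaming datasets can't be percent-sliced; take() the equivalent row count
    # when the Hub metadata reports the split size
    split_info = (loaded_ds.info.splits or {}).get(split)
    if split_info is not None and split_info.num_examples:
      loaded_ds = loaded_ds.take(split_info.num_examples * sample_limit_pct // 100)
    else:
      print(f"Split size unknown; ignoring the {sample_limit_pct}% sample limit")

  return loaded_ds.with_format("torch") if torch_format else loaded_ds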

################################################
# Download MIT RIR data
################################################

print("Retrieving MIT RIR dataset\n")

download_config=get_download_config("mit_rir")
rir_dataset = retrieve_dataset(path = "davidscripka/MIT_environmental_impulse_responses", \
                               download_config = download_config)

create_wav_files_from_dataset(rir_dataset, rir_wav_path)


################################################
# AudioSet annotated audio training data
# (https://research.google.com/audioset/dataset/index.html)
# Stream AudioSet from Hugging Face and convert it to 16 kHz WAV files
# For full-scale training, it's recommended to download the entire dataset from
# https://huggingface.co/datasets/agkphysics/AudioSet, and
# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)
################################################

print("Retrieving Audioset dataset\n")

download_config=get_download_config("audioset")
audioset_dataset = retrieve_dataset(path = "agkphysics/Audioset", \
                               download_config = download_config, \
                               sample_limit_pct = audioset_sample_limit_pct)

create_wav_files_from_dataset(audioset_dataset, audioset_wav_path)


################################################
# Free Music Archive dataset
# https://github.com/mdeff/fma
# (Third-party mchl914 extra small set)
################################################

print("Retrieving FMA dataset\n")

fma_archive_filename = "fma_xs.zip"
fma_archive_url = f"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/{fma_archive_filename}"

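# split="train" returns a Dataset rather than a DatasetDict;
# the keyword argument is download_config (lowercase)
fma_dataset = load_dataset("audiofolder",
                           data_files=fma_archive_url,
                           split="train",
                           download_config=get_download_config("fma"))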

create_wav_files_from_dataset(fma_dataset, fma_wav_path)
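The resume logic in the cell above boils down to comparing the `.wav` names already on disk with the name each source file would be written to. A standalone sketch of that check, using hypothetical helper names (`existing_wav_names`, `needs_conversion`) that are not part of microWakeWord or `datasets`:

```python
from pathlib import Path


def existing_wav_names(directory: str, file_ext: str = ".wav") -> set:
    """Return the names of all files with file_ext anywhere under directory."""
    # Normalize ".wav" and "wav" to the same pattern, "**/*.wav"
    ext = file_ext.lstrip(".")
    return {p.name for p in Path(directory).glob(f"**/*.{ext}")}


def needs_conversion(source_path: str, done: set, file_ext: str = ".wav") -> bool:
    """True if source_path has no converted counterpart (same stem, new extension)."""
    target_name = Path(source_path).stem + file_ext
    return target_name not in done
```

Normalizing the extension matters: globbing `**/*.{file_ext}` with `file_ext=".wav"` produces the pattern `**/*..wav`, which matches nothing, so every file would look unprocessed.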

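The WAV writer in the cell above scales float samples to 16-bit integers. A minimal sketch of that conversion on its own, with `float_to_int16` as a hypothetical helper name; clipping first keeps out-of-range samples from overflowing when cast to `int16`:

```python
import numpy as np


def float_to_int16(audio: np.ndarray) -> np.ndarray:
    """Convert float audio nominally in [-1.0, 1.0] to 16-bit PCM samples."""
    # Values outside [-1, 1] would not fit in int16 after scaling, so saturate them
    clipped = np.clip(audio, -1.0, 1.0)
    # astype truncates toward zero, e.g. 0.5 * 32767 = 16383.5 becomes 16383
    return (clipped * 32767).astype(np.int16)


pcm = float_to_int16(np.array([0.0, 0.5, 1.0, -1.0, 1.5, -2.0]))
```

Scaling by 32767 rather than 32768 keeps +1.0 representable; -1.0 maps to -32767, one step short of the int16 minimum.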
In [None]:
################################################################################
# Set up the augmentations
# To improve your model, experiment with these settings and use more sources of
# background clips.
################################################################################

from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

clips = Clips(input_directory=f"{generated_samples_output_dir}",
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
              )
augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities = {
                                "SevenBandParametricEQ": 0.1,
                                "TanhDistortion": 0.1,
                                "PitchShift": 0.1,
                                "BandStopFilter": 0.1,
                                "AddColorNoise": 0.1,
                                "AddBackgroundNoise": 0.75,
                                "Gain": 1.0,
                                "RIR": 0.5,
                            },
                         impulse_paths = ['mit_rirs_16k'],  # must match rir_wav_path from step 2
                         background_paths = ['fma_16k', 'audioset_16k'],
                         background_min_snr_db = -5,
                         background_max_snr_db = 10,
                         min_jitter_s = 0.195,
                         max_jitter_s = 0.205,
                         )
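`background_min_snr_db` and `background_max_snr_db` bound the signal-to-noise ratio at which background clips are mixed in; a negative value means the background is louder than the wake word. A sketch of mixing at a target SNR under the usual power-ratio definition (the augmentation library's exact implementation may differ), with `mix_at_snr` as a hypothetical name and synthetic arrays standing in for real clips:

```python
import numpy as np


def mix_at_snr(signal: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:
    """Add noise to signal, scaled so that 10*log10(P_signal / P_noise) == snr_db."""
    signal_power = np.mean(signal ** 2)
    noise_power = np.mean(noise ** 2)
    # Solve snr_db = 10*log10(signal_power / (gain**2 * noise_power)) for gain
    gain = np.sqrt(signal_power / (noise_power * 10 ** (snr_db / 10)))
    return signal + gain * noise


rng = np.random.default_rng(0)
t = np.arange(16000) / 16000                     # one second at 16 kHz
speech_like = 0.3 * np.sin(2 * np.pi * 220 * t)  # stand-in for a wake word clip
background = rng.normal(size=t.shape)            # stand-in for a background clip
snr_db = rng.uniform(-5, 10)                     # the SNR range configured above
mixed = mix_at_snr(speech_like, background, snr_db)
```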

In [None]:
# Augment a random clip and play it back to verify it works well

from IPython.display import Audio
from microwakeword.audio.audio_utils import save_clip

random_clip = clips.get_random_clip()
augmented_clip = augmenter.augment_clip(random_clip)
save_clip(augmented_clip, "augmented_clip.wav")

Audio("augmented_clip.wav", autoplay=True)

In [None]:
################################################################################
# Augment samples and save the training, validation, and testing sets.
# Validating and testing on samples generated the same way as the training data
# can make the model benchmark better than it performs in real-world use. Use real
# samples, or TTS samples generated with a different TTS engine, to get
# potentially more accurate benchmarks.
#
# lelando - to-do: since mww seems to support streaming, and the training data is
# streamed, the next step should be eliminating .wav output except where
# necessary, e.g. for Audio(..., autoplay=True)
################################################################################

import os
from mmap_ninja.ragged import RaggedMmap

if not os.path.exists("generated_augmented_features"):
  os.mkdir("generated_augmented_features")

generated_features_output_dir = f"generated_augmented_features/{target_word_friendly}"

if not os.path.exists(generated_features_output_dir):
    os.mkdir(generated_features_output_dir)

splits = ["training", "validation", "testing"]
for split in splits:
  out_dir = os.path.join(generated_features_output_dir, split)
  if not os.path.exists(out_dir):
      os.mkdir(out_dir)

  split_name = "train"
  repetition = 2

  # No backslash continuations: a comment after "\" is a SyntaxError, and the
  # parentheses already allow the call to span several lines
  spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=10,    # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.
                                     step_ms=10,
                                     )
  if split == "validation":
    split_name = "validation"
    repetition = 1
  elif split == "testing":
    split_name = "test"
    repetition = 1
    spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=1,    # The testing set uses the streaming version of the model, so no artificial repetition is necessary
                                     step_ms=10,
                                     )

  RaggedMmap.from_generator(
      out_dir=os.path.join(out_dir, f'{target_word_friendly}_mmap'),
      sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),
      batch_size=92,
      verbose=True,
  )
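The training and validation spectrograms use `slide_frames=10`, which the inline comment describes as reusing the same spectrogram shifted over by one frame at a time to simulate streaming inference. An illustrative sketch of that idea; it is not microWakeWord's code, and the helper name, window length, and shift direction are assumptions:

```python
import numpy as np


def shifted_windows(spectrogram: np.ndarray, window_frames: int, slide_frames: int) -> list:
    """Return up to slide_frames windows of window_frames frames each.

    Window k ends k frames before the end of the spectrogram, so consecutive
    windows share all but one frame, like successive streaming inferences.
    """
    total_frames = spectrogram.shape[0]
    windows = []
    for k in range(slide_frames):
        end = total_frames - k
        start = end - window_frames
        if start < 0:  # not enough frames left for a full window
            break
        windows.append(spectrogram[start:end])
    return windows


# 20 frames x 40 feature bins of dummy data
spec = np.arange(20 * 40, dtype=np.float32).reshape(20, 40)
windows = shifted_windows(spec, window_frames=15, slide_frames=10)
```

Here only 6 windows fit, because a 20-frame spectrogram can shift a 15-frame window by at most 5 frames. In this sketch, `slide_frames=1` yields a single window per clip, matching the testing set's setting.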

In [None]:
################################################################################
# Download pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets.
# This can be slow!
#
# lelando - to-do: use retrieve_dataset
################################################################################

output_dir = './negative_datasets'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']
    for fname in filenames:
        link = link_root + fname

        zip_path = f"negative_datasets/{fname}"
        !wget -O {zip_path} {link}
        !unzip -q {zip_path} -d {output_dir}
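Because this cell only downloads when `./negative_datasets` does not exist yet, an interrupted run leaves a partial directory that later runs skip entirely. A sketch of a per-file alternative using only the standard library; `fetch_and_extract` is a hypothetical helper, and re-running it re-extracts archives, overwriting files already extracted from them:

```python
import os
import urllib.request
import zipfile


def fetch_and_extract(url: str, zip_path: str, extract_dir: str) -> None:
    """Download url to zip_path unless it is already there, then extract it."""
    os.makedirs(os.path.dirname(zip_path) or ".", exist_ok=True)
    if not os.path.exists(zip_path):
        # Download under a temporary name so an interrupted transfer never
        # leaves a truncated archive at zip_path
        partial_path = zip_path + ".part"
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, zip_path)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(extract_dir)


# Usage with the URLs from the cell above (requires network access):
# link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
# for fname in ["dinner_party.zip", "dinner_party_eval.zip", "no_speech.zip", "speech.zip"]:
#     fetch_and_extract(link_root + fname, f"negative_datasets/{fname}", "negative_datasets")
```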

In [None]:
################################################################################
# Save the training configuration YAML file
# IMPORTANT! These hyperparameters can make a huge difference in model quality.
# Experiment with sampling and penalty weights and increasing the number of
# training steps.
################################################################################

import yaml
import os

config = {}

config["window_step_ms"] = 10

config["train_dir"] = (
    f"trained_models/{target_word_friendly}"
)


# Each features_dir should have at least one of the following folders with this structure:
#  training/
#    ragged_mmap_folders_ending_in_mmap
#  testing/
#    ragged_mmap_folders_ending_in_mmap
#  testing_ambient/
#    ragged_mmap_folders_ending_in_mmap
#  validation/
#    ragged_mmap_folders_ending_in_mmap
#  validation_ambient/
#    ragged_mmap_folders_ending_in_mmap
#
#  sampling_weight: Weight for choosing a spectrogram from this set in the batch
#  penalty_weight: Penalizing weight for incorrect predictions from this set
#  truth: Boolean whether this set has positive samples or negative samples
#  truncation_strategy: How spectrograms in the set are truncated if they are longer than necessary for training
#       - random: choose a random portion of the entire spectrogram - useful for long negative samples
#       - truncate_start: remove the start of the spectrogram
#       - truncate_end: remove the end of the spectrogram
#       - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets

config["features"] = [
    {
        "features_dir": f'{generated_features_output_dir}',
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    { # Only used for validation and testing
        "features_dir": "negative_datasets/dinner_party_eval",
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

# Number of training steps in each iteration; various other settings are configured as lists whose entries correspond to these iterations
config["training_steps"] = [10000]

# Penalizing weight for incorrect class predictions - lists that correspond to training steps
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]

config["learning_rates"] = [
    0.001,
]  # Learning rates for Adam optimizer - list that corresponds to training steps
config["batch_size"] = 92

config["time_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["time_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps
config["freq_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["freq_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps

config["eval_step_interval"] = (
    500  # Test the validation sets after every this many steps
)
config["clip_duration_ms"] = (
    1300  # Maximum length of wake word that the streaming model will accept
)

# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization
# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize
# Available metrics:
#   - "loss" - cross entropy error on validation set
#   - "accuracy" - accuracy of validation set
#   - "recall" - recall of validation set
#   - "precision" - precision of validation set
#   - "false_positive_rate" - false positive rate of validation set
#   - "false_negative_rate" - false negative rate of validation set
#   - "ambient_false_positives" - count of false positives from the split validation_ambient set
#   - "ambient_false_positives_per_hour" - estimated number of false positives per hour on the split validation_ambient set
config["target_minimization"] = 0.9
config["minimization_metric"] = None  # Set to None to disable

config["maximization_metric"] = "average_viable_recall"

with open(os.path.join("training_parameters.yaml"), "w") as file:
    documents = yaml.dump(config, file)
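Reading `sampling_weight` as a relative weight (an assumption about how microWakeWord fills each batch, not something this notebook states), each feature set's expected share of a batch is its weight divided by the sum of all weights. For the settings above:

```python
# Sampling weights copied from config["features"] above
sampling_weights = {
    "wake word samples": 2.0,
    "negative_datasets/speech": 10.0,
    "negative_datasets/dinner_party": 10.0,
    "negative_datasets/no_speech": 5.0,
    "negative_datasets/dinner_party_eval": 0.0,  # validation/testing only
}
batch_size = 92  # config["batch_size"]

total_weight = sum(sampling_weights.values())
for name, weight in sampling_weights.items():
    share = weight / total_weight
    print(f"{name:40s} {share:6.1%}  ~{share * batch_size:5.1f} spectrograms per batch")
```

Under that reading, roughly 7 of every 92 spectrograms in a batch would be wake word samples.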

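The configuration comments describe how the best weights are chosen: first by pushing the minimization metric below `target_minimization`, then by maximizing the maximization metric. A sketch of that rule as we read it; `pick_best` is a hypothetical helper, and both the `<=` comparison and the fallback to the lowest minimization value while the target is unmet are our interpretation, not confirmed against microWakeWord's code:

```python
from typing import Dict, Optional, Sequence


def pick_best(evaluations: Sequence[Dict[str, float]],
              maximization_metric: str,
              minimization_metric: Optional[str] = None,
              target_minimization: float = 0.0) -> int:
    """Return the index of the evaluation whose weights would be kept."""
    if not evaluations:
        raise ValueError("no evaluations to choose from")
    indices = range(len(evaluations))

    if minimization_metric is None:
        # Only maximize
        return max(indices, key=lambda i: evaluations[i][maximization_metric])

    meeting_target = [i for i in indices
                      if evaluations[i][minimization_metric] <= target_minimization]
    if not meeting_target:
        # Target not met yet: keep whichever evaluation comes closest to it
        return min(indices, key=lambda i: evaluations[i][minimization_metric])

    # Target met: maximize among the evaluations that meet it
    return max(meeting_target, key=lambda i: evaluations[i][maximization_metric])


# Hypothetical validation results after three evaluation intervals
history = [
    {"recall": 0.80, "ambient_false_positives_per_hour": 2.0},
    {"recall": 0.90, "ambient_false_positives_per_hour": 1.5},
    {"recall": 0.85, "ambient_false_positives_per_hour": 0.5},
]
```

With this notebook's settings (`minimization_metric = None`), the rule reduces to keeping the evaluation with the highest `average_viable_recall`.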
In [None]:
################################################################################
# Train the model
# When finished, the code will quantize and convert the model to a
# streaming version suitable for on-device detection.
# This will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Change --train 0 to only convert and test the best-weighted model.
# Google Colab does not print mini-batch results, so it may appear
# stuck for several minutes.
# This can be slow!
################################################################################

!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_nonstreaming_quantized 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,64,64,64" \
--repeat_in_block  "1, 1, 1, 1" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32 \
--first_conv_kernel_size 5 \
--stride 3

In [None]:
################################################################################
# Download the tflite model file
# To use the wake word with ESPHome, you need to write a
# V2 model JSON file. See https://esphome.io/components/micro_wake_word for the
# documentation and
# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for
# examples of the V2 JSON format. Adjust the probability threshold based on
# the test results obtained.
# After training is complete, you may also need to increase the tensor arena
# size if the model later fails to load in ESPHome.
################################################################################

import os
from google.colab import files

# Prompt the user to download their model .tflite. (to-do?) It might be nice to implement something to generate
# a basic V2 .json, since we have all the relevant values, then add both files to an archive and
# offer that for download instead
tflite_output_dir = f"trained_models/{target_word_friendly}/tflite_stream_state_internal_quant"
model_tflite_filename = f'trained_models/{target_word_friendly}/{target_word_friendly}.tflite'
os.rename(f'{tflite_output_dir}/stream_state_internal_quant.tflite', model_tflite_filename)

files.download(f"{model_tflite_filename}")
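The to-do above suggests generating a basic V2 manifest alongside the model. A minimal sketch, assuming the field names used by the V2 examples in the esphome/micro-wake-word-models repository; the cutoff, tensor arena size, and minimum ESPHome version are placeholders, so copy current values from an existing V2 example and tune `probability_cutoff` from your test results. `write_v2_manifest` is a hypothetical helper:

```python
import json
from pathlib import Path


def write_v2_manifest(wake_word: str, model_filename: str, output_path: str,
                      probability_cutoff: float = 0.97,
                      tensor_arena_size: int = 30000) -> dict:
    """Write a minimal V2 micro_wake_word manifest and return it as a dict.

    Field names are assumed from the V2 examples in the model repository;
    the default values are placeholders, not recommendations.
    """
    manifest = {
        "type": "micro",
        "wake_word": wake_word,
        "author": "",
        "website": "",
        "model": model_filename,  # .tflite file name, relative to the JSON file
        "trained_languages": ["en"],
        "version": 2,
        "micro": {
            "probability_cutoff": probability_cutoff,  # tune from test results
            "feature_step_size": 10,                    # matches window_step_ms
            "sliding_window_size": 5,
            "tensor_arena_size": tensor_arena_size,     # raise if loading fails
            "minimum_esphome_version": "2024.7.0",
        },
    }
    Path(output_path).write_text(json.dumps(manifest, indent=2))
    return manifest
```

For example, `write_v2_manifest("Hey Wabi", f"{target_word_friendly}.tflite", f"trained_models/{target_word_friendly}/{target_word_friendly}.json")` writes the JSON next to the renamed model, ready to download with `files.download`.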