# Training a microWakeWord Model

This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.

**The generated model will most likely not be usable for everyday use; it may be difficult to trigger or may falsely activate too frequently. You will most likely have to experiment with many different settings to obtain a decent model!**

In the comment at the start of certain blocks, I note some specific settings to consider modifying.

This runs on Google Colab, but it is extremely slow compared to training on a local GPU. If you must use Colab, be sure to change the runtime type to a GPU. Even then, it is still slow!

At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples.
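
As a rough illustration of what such a manifest might contain, the sketch below builds one in Python. The field names mirror the v2 examples in the model repo as I understand them, and every value (author, website, cutoff, window size, arena size, minimum version) is a placeholder assumption. `build_manifest` is a hypothetical helper, so check the ESPHome documentation and your own test results before relying on any of it.

```python
import json

def build_manifest(wake_word, model_file, probability_cutoff=0.97,
                   tensor_arena_size=30000):
    """Return a dict shaped like the v2 manifests in the model repo (assumed layout)."""
    return {
        "type": "micro",
        "wake_word": wake_word,
        "author": "Your Name",             # placeholder
        "website": "https://example.com",  # placeholder
        "model": model_file,
        "trained_languages": ["en"],
        "version": 2,
        "micro": {
            "probability_cutoff": probability_cutoff,  # placeholder; tune from test results
            "sliding_window_size": 5,                  # assumed value
            "feature_step_size": 10,                   # same 10 ms step used in this notebook
            "tensor_arena_size": tensor_arena_size,    # placeholder; raise if the model fails to load
            "minimum_esphome_version": "2024.7.0",     # assumed value
        },
    }

manifest = build_manifest("hey bob", "stream_state_internal_quant.tflite")
manifest_json = json.dumps(manifest, indent=2)
print(manifest_json)
```

Saving `manifest_json` to a `.json` file next to the downloaded tflite file gives you a starting point to compare against the published examples.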

In [None]:
# Installs microWakeWord. Be sure to restart the session after this is finished.
import platform
!pip install edge-tts

if platform.system() == "Darwin":
    # `pymicro-features` is installed from a fork to support building on macOS
    !pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'

# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter
!pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'

!git clone https://github.com/kahrendt/microWakeWord
!pip install -e ./microWakeWord
!pip install piper piper-tts

In [None]:
!wget https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/amy/medium/en_US-amy-medium.onnx
!wget https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/amy/medium/en_US-amy-medium.onnx.json
!wget https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/kusal/medium/en_US-kusal-medium.onnx
!wget https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/kusal/medium/en_US-kusal-medium.onnx.json

In [None]:
# Generates 1 sample of the target word for manual verification.

target_word = 'hey bob'  # Phonetic spellings may produce better samples

import os
import sys
import platform

from IPython.display import Audio

if not os.path.exists("./piper-sample-generator"):
    if platform.system() == "Darwin":
        !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator
    else:
        !git clone https://github.com/rhasspy/piper-sample-generator

    !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'

    # Install Python dependencies
    !pip install torch torchaudio piper-phonemize-cross==1.2.1

    if "piper-sample-generator/" not in sys.path:
        sys.path.append("piper-sample-generator/")

!python3 piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples 1 \
--batch-size 1 \
--output-dir generated_samples


Audio("generated_samples/0.wav", autoplay=True)

In [None]:
# Generates a larger number of wake word samples.
# Start here when trying to improve your model.
# See https://github.com/rhasspy/piper-sample-generator for the full set of
# parameters. In particular, experiment with noise-scales and noise-scale-ws,
# generating negative samples similar to the wake word, and generating many more
# wake word samples, possibly with different phonetic pronunciations.

# 1. Positive samples for the wake word ("hey bob")
!python3 piper-sample-generator/generate_samples.py "hey bob" \
  --max-samples 10000 \
  --batch-size 500 \
  --noise-scales 0.5 1.0 \
  --noise-scale-ws 0.5 1.0 \
  --output-dir /pee/generated_samples

!python3 piper-sample-generator/generate_samples.py "pop rob bop bab mob kop lop dop bok zop aop alexa google cat dog hello start stop boob" \
  --max-samples 5000 \
  --batch-size 500 \
  --noise-scales 0.5 1.0 \
  --noise-scale-ws 0.5 1.0 \
  --output-dir /pee/pee

import edge_tts
import asyncio
import os
import subprocess
import random

output_dir = "/pee/validation"
os.makedirs(output_dir, exist_ok=True)

# Diverse voices across accents
voices = [
    "en-US-AriaNeural", "en-US-GuyNeural", "en-US-JennyNeural", "en-US-TonyNeural",
    "en-GB-LibbyNeural", "en-GB-RyanNeural", "en-GB-SoniaNeural", "en-GB-ThomasNeural",
    "en-IN-NeerjaNeural", "en-IN-PrabhatNeural",
    "en-AU-NatashaNeural", "en-AU-WilliamNeural",
    "en-IE-ConnorNeural", "en-IE-EmilyNeural",
    "en-CA-ClaraNeural", "en-CA-LiamNeural",
    "en-ZA-LeahNeural", "en-ZA-LukeNeural",
    "en-NG-EzinneNeural", "en-NG-AbeoNeural",
    "en-NZ-MitchellNeural", "en-NZ-MollyNeural",
    "en-PH-JamesNeural", "en-PH-RosaNeural",
    "en-KE-AsiliaNeural", "en-KE-ChilembaNeural",
    "en-SG-LunaNeural", "en-SG-WayneNeural",
]

SAMPLES_PER_VOICE = 50  # 50 × 28 voices = 1400 samples

async def main():
    idx = 0
    for voice in voices:
        for j in range(SAMPLES_PER_VOICE):
            tmp_path = os.path.join(output_dir, f"tmp_{idx}.mp3")
            wav_path = os.path.join(output_dir, f"bob_{idx}_{voice}.wav")

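            # Random rate/pitch for natural variation. edge-tts expects signed
            # strings such as "+5%" or "-3Hz" and does not accept raw SSML, so
            # the prosody settings are passed as keyword arguments instead.
            # Note: this speaks "bob" only, not the full "hey bob" phrase.
            rate = f"{random.randint(-20, 20):+d}%"    # -20% to +20%
            pitch = f"{random.randint(-10, 10):+d}Hz"  # -10Hz to +10Hz

            communicate = edge_tts.Communicate("bob", voice, rate=rate, pitch=pitch)
            await communicate.save(tmp_path)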

            # Convert to 16kHz mono WAV
            subprocess.run([
                "ffmpeg", "-y", "-i", tmp_path,
                "-ar", "16000", "-ac", "1", wav_path
            ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

            os.remove(tmp_path)
            print(f"Saved {wav_path}")
            idx += 1

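# Notebooks already run an event loop, so asyncio.run() raises a RuntimeError
# here; use top-level await instead.
await main()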
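
Because the `ffmpeg` output is silenced above, a failed conversion can go unnoticed, so it may help to confirm that the saved clips really are 16 kHz mono WAV files. The following is a minimal sketch using only the standard library. `summarize_wavs` is a hypothetical helper, not part of microWakeWord, and the demo writes throwaway files to a temporary directory rather than touching the real output folders.

```python
import os
import tempfile
import wave

def summarize_wavs(directory):
    """Count .wav files and collect those that are not 16 kHz mono 16-bit PCM."""
    total, bad = 0, []
    for name in sorted(os.listdir(directory)):
        if not name.lower().endswith(".wav"):
            continue
        total += 1
        path = os.path.join(directory, name)
        try:
            with wave.open(path, "rb") as wf:
                ok = (wf.getframerate() == 16000 and wf.getnchannels() == 1
                      and wf.getsampwidth() == 2 and wf.getnframes() > 0)
        except (wave.Error, EOFError):
            ok = False  # unreadable or truncated file
        if not ok:
            bad.append(name)
    return total, bad

# Demo on a temporary directory with one valid clip and one empty file.
demo_dir = tempfile.mkdtemp()
with wave.open(os.path.join(demo_dir, "good.wav"), "wb") as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)
    wf.setframerate(16000)
    wf.writeframes(b"\x00\x00" * 1600)  # 0.1 s of silence
open(os.path.join(demo_dir, "broken.wav"), "wb").close()

demo_total, demo_bad = summarize_wavs(demo_dir)
print(demo_total, demo_bad)
```

On the real data, calling `summarize_wavs("/pee/validation")` should return an empty list as its second element if every conversion succeeded.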


In [None]:
# Downloads audio data for augmentation. This can be slow!
# Adapted from openWakeWord's automatic_model_training.ipynb

import datasets
import scipy
import os

import numpy as np

from pathlib import Path
from tqdm import tqdm

## ----------------------------
## Download MIT RIR data
## ----------------------------

output_dir = "/pee/mit_rirs"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    rir_dataset = datasets.load_dataset("davidscripka/MIT_environmental_impulse_responses", split="train", streaming=True)
    # Save clips to 16-bit PCM wav files
    for row in tqdm(rir_dataset):
        name = row['audio']['path'].split('/')[-1]
        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))

## ----------------------------
## Download noise and background audio
## ----------------------------

# AudioSet Dataset (https://research.google.com/audioset/dataset/index.html)
# HuggingFace mirror: https://huggingface.co/datasets/agkphysics/AudioSet

if not os.path.exists("/pee/audioset"):
    os.mkdir("/pee/audioset")

    # Grab two parts of the dataset instead of one
    for fname in ["bal_train09.tar", "bal_train08.tar"]:
        out_dir = f"/pee/audioset/{fname}"
        link = "https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/" + fname
        !wget -O {out_dir} {link}
        !cd /pee/audioset && tar -xf {fname}

    output_dir = "/pee/audioset_16k"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # Save clips to 16-bit PCM wav files
    audioset_dataset = datasets.Dataset.from_dict({"audio": [str(i) for i in Path("/pee/audioset/audio").glob("**/*.flac")]})
    audioset_dataset = audioset_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
    for row in tqdm(audioset_dataset):
        name = row['audio']['path'].split('/')[-1].replace(".flac", ".wav")
        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))

# Free Music Archive dataset (extra small subset)
# https://github.com/mdeff/fma

output_dir = "/pee/fma"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    fname = "fma_xs.zip"
    link = "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/" + fname
    out_dir = f"{output_dir}/{fname}"
    !wget -O {out_dir} {link}
    !cd {output_dir} && unzip -q {fname}

    output_dir = "/pee/fma_16k"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # Save clips to 16-bit PCM wav files
    fma_dataset = datasets.Dataset.from_dict({"audio": [str(i) for i in Path("/pee/fma/fma_small").glob("**/*.mp3")]})
    fma_dataset = fma_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
    for row in tqdm(fma_dataset):
        name = row['audio']['path'].split('/')[-1].replace(".mp3", ".wav")
        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))


output_dir = "/pee/fsd50k"
os.makedirs(output_dir, exist_ok=True)

# Stream dataset (does not download all at once)
fsd_dataset = datasets.load_dataset("Fhrozen/FSD50k", split="train", streaming=True)
fsd_dataset = fsd_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
# Download and save only 10k clips
for idx, row in enumerate(tqdm(fsd_dataset, total=10000)):
    if idx >= 10000:   # <-- stop after 10k
        break
    name = f"fsd50k_{idx}.wav"
    scipy.io.wavfile.write(
        os.path.join(output_dir, name),
        16000,
        (row["audio"]["array"] * 32767).astype(np.int16)
    )


In [None]:
# Sets up the augmentations.
# To improve your model, experiment with these settings and use more sources of
# background clips.

from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

# Positive wake word samples ("bob")
clips = Clips(input_directory='/pee/generated_samples',
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
              )

augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities={
                                "SevenBandParametricEQ": 0.6,   # most get EQ
                                "TanhDistortion": 0.3,          # distortion fairly often
                                "PitchShift": 0.4,              # many get pitch shifts
                                "BandStopFilter": 0.3,          # frequent filtering
                                "AddColorNoise": 0.5,           # about half get extra noise
                                "AddBackgroundNoise": 0.8,      # almost all get bg noise
                                "Gain": 0.7,                    # gain adjustment most times
                                "RIR": 0.5,                     # half get reverb
                            },
                         impulse_paths=['/pee/mit_rirs'],
                         background_paths=['/pee/fma_16k', '/pee/audioset_16k', '/pee/fsd50k'],
                         background_min_snr_db=-5,
                         background_max_snr_db=10,
                         min_jitter_s=0.195,
                         max_jitter_s=0.205,
                         )

# Negative samples ("pee" folder)
clips_neg = Clips(input_directory='/pee/pee',
                  file_pattern='*.wav',
                  max_clip_duration_s=None,
                  remove_silence=False,
                  random_split_seed=10,
                  split_count=0.1,
                  )

augmenter_neg = Augmentation(augmentation_duration_s=3.2,
                              augmentation_probabilities={
                                    "SevenBandParametricEQ": 0.6,
                                    "TanhDistortion": 0.3,
                                    "PitchShift": 0.4,
                                    "BandStopFilter": 0.3,
                                    "AddColorNoise": 0.5,
                                    "AddBackgroundNoise": 0.8,
                                    "Gain": 0.7,
                                    "RIR": 0.5,
                              },
                              impulse_paths=['/pee/mit_rirs'],
                              background_paths=['/pee/fma_16k', '/pee/audioset_16k', '/pee/fsd50k'],
                              background_min_snr_db=-5,
                              background_max_snr_db=10,
                              min_jitter_s=0.195,
                              max_jitter_s=0.205,
                              )
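
`background_min_snr_db` and `background_max_snr_db` bound the signal-to-noise ratio at which background clips are mixed in; at -5 dB the background is louder than the wake word. As a rough illustration of what an SNR in dB means (this is not the augmentation library's actual implementation), the sketch below scales a noise signal so that the mix hits a requested SNR. `mix_at_snr`, `power`, and the toy sine-wave signals are made up for the example.

```python
import math

def power(x):
    """Mean squared amplitude of a sequence of samples."""
    return sum(v * v for v in x) / len(x)

def mix_at_snr(signal, noise, snr_db):
    """Scale `noise` so that 10*log10(P_signal / P_noise) == snr_db, then add it."""
    target_noise_power = power(signal) / (10 ** (snr_db / 10))
    scale = math.sqrt(target_noise_power / power(noise))
    scaled = [scale * n for n in noise]
    mixed = [s + n for s, n in zip(signal, scaled)]
    return mixed, scaled

# One second of toy audio at 16 kHz: a 440 Hz "signal" and a quiet 97 Hz "noise".
signal = [math.sin(2 * math.pi * 440 * t / 16000) for t in range(16000)]
noise = [0.1 * math.sin(2 * math.pi * 97 * t / 16000 + 1.0) for t in range(16000)]

for snr_db in (-5, 10):
    _, scaled_noise = mix_at_snr(signal, noise, snr_db)
    measured = 10 * math.log10(power(signal) / power(scaled_noise))
    print(f"requested {snr_db:+d} dB, measured {measured:+.2f} dB")
```

Lowering `background_min_snr_db` makes the hardest training examples noisier, which may help robustness but can also make the wake word harder to learn.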


In [None]:
# Augment a random clip and play it back to verify it works well

from IPython.display import Audio
from microwakeword.audio.audio_utils import save_clip

random_clip = clips.get_random_clip()
augmented_clip = augmenter.augment_clip(random_clip)
save_clip(augmented_clip, 'augmented_clip.wav')

Audio("augmented_clip.wav", autoplay=True)

In [None]:
# Augment samples and save the training, validation, and testing sets.
# Positive = generated_samples/
# Negative = pee/
# Validation = validation/ (pre-generated clean TTS voices, no splitting)

import os
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.spectrograms import SpectrogramGeneration

output_dir = 'generated_augmented_features'
os.makedirs(output_dir, exist_ok=True)

# ---------------------------
# Feature-generation helper
# ---------------------------
def process_split(split, clips, augmenter, tag, external_validation=False):
    """Helper to generate mmap features for a given dataset (positive/negative)."""

    out_dir = os.path.join(output_dir, split)
    os.makedirs(out_dir, exist_ok=True)

    split_name = "train"
    repetition = 2
    slide_frames = 10

    # Validation/test settings
    if split == "validation":
        split_name = "validation"
        repetition = 1
    elif split == "testing":
        split_name = "test"
        repetition = 1
        slide_frames = 1

    # External validation (real TTS voices): skip augmenter
    if external_validation:
        spectrograms = SpectrogramGeneration(
            clips=clips,
            augmenter=None,
            slide_frames=slide_frames,
            step_ms=10,
        )
    else:
        spectrograms = SpectrogramGeneration(
            clips=clips,
            augmenter=augmenter,
            slide_frames=slide_frames,
            step_ms=10,
        )

    RaggedMmap.from_generator(
        out_dir=os.path.join(out_dir, f'{tag}_mmap'),
        sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),
        batch_size=100,
        verbose=True,
    )


# ---------------------------
# Generate training/val/test
# ---------------------------

splits = ["training", "validation", "testing"]

# Positive samples
for split in splits:
    process_split(split, clips, augmenter, "wakeword")

# Negative samples
for split in splits:
    process_split(split, clips_neg, augmenter_neg, "negative")

# Real TTS validation (clean voices in "validation/" dir, no augmentation)
from microwakeword.audio.clips import Clips
tts_validation = Clips(input_directory='/pee/validation', file_pattern='*.wav')
process_split("validation", tts_validation, None, "tts_validation", external_validation=True)


In [None]:
# Downloads pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets. This can be slow!

import os

output_dir = '/pee/pee'
if os.path.exists(output_dir):
    print('Downloading negative feature sets')
    link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip', 'speech_background.zip']
    for fname in filenames:
        link = link_root + fname

        zip_path = f"/pee/pee/{fname}"
        !wget -O {zip_path} {link}
        !unzip -q {zip_path} -d {output_dir}

In [None]:
import yaml
import os

config = {}

config["window_step_ms"] = 10
config["train_dir"] = "/pee/trained_models/wakeword"

config["features"] = [
    {
        # Training positives
        "features_dir": "generated_augmented_features/training",
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        # Validation positives (TTS Bob, 90% clean + 10% augmented)
        "features_dir": "generated_augmented_features/validation",
        "sampling_weight": 0.0,   # eval only
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        # Testing positives (streaming, clean)
        "features_dir": "generated_augmented_features/testing",
        "sampling_weight": 0.0,   # eval only
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "/pee/pee/speech",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "/pee/pee/dinner_party",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "/pee/pee/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    { # Ambient eval negatives (for validation/testing false alarms)
        "features_dir": "/pee/pee/dinner_party_eval",
        "sampling_weight": 0.0,   # eval only
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
    { # Extra validation_ambient set for silent/noisy conditions
        "features_dir": "/pee/pee/no_speech",
        "sampling_weight": 0.0,   # eval only
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

# Training schedule
config["training_steps"] = [20000]
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]

config["learning_rates"] = [0.001]
config["batch_size"] = 64

# SpecAugment disabled
config["time_mask_max_size"] = [0]
config["time_mask_count"] = [0]
config["freq_mask_max_size"] = [0]
config["freq_mask_count"] = [0]

# Eval settings
config["eval_step_interval"] = 500
config["clip_duration_ms"] = 1500

# Model selection criteria
config["target_minimization"] = 0.9
config["minimization_metric"] = None
config["maximization_metric"] = "average_viable_recall"

with open("training_parameters.yaml", "w") as file:
    yaml.dump(config, file)
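
A quick structural check on the `features` list can catch mistakes before a long training run. The sketch below is a hypothetical helper (`check_features` is not part of microWakeWord) and it makes a narrow assumption about what counts as a problem: it only verifies that each directory exists and that at least one positive and one negative entry have `sampling_weight > 0`, meaning they are actually sampled during training. The demo list is invented for illustration.

```python
import os

def check_features(features, path_exists=os.path.exists):
    """Return a list of human-readable problems found in a features list."""
    problems = []
    for i, entry in enumerate(features):
        if not path_exists(entry["features_dir"]):
            problems.append(f"entry {i}: missing directory {entry['features_dir']}")
    sampled = [e for e in features if e["sampling_weight"] > 0]
    if not any(e["truth"] for e in sampled):
        problems.append("no positive entry is sampled for training")
    if not any(not e["truth"] for e in sampled):
        problems.append("no negative entry is sampled for training")
    return problems

demo_features = [
    {"features_dir": "demo/positives", "sampling_weight": 2.0, "truth": True},
    {"features_dir": "demo/negatives_eval", "sampling_weight": 0.0, "truth": False},
]
# Pretend every directory exists so only the sampling checks fire in the demo.
demo_problems = check_features(demo_features, path_exists=lambda p: True)
print(demo_problems)
```

In this notebook, `check_features(config["features"])` would run the same checks against the real directories.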


In [None]:
# Trains a model. When finished, it will quantize and convert the model to a
# streaming version suitable for on-device detection.
# It will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Set --train to 0 to only convert and test the best-weighted model.
# On Google Colab, it doesn't print the mini-batch results, so it may appear
# stuck for several minutes! Additionally, it is very slow compared to training
# on a local GPU.

!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_nonstreaming_quantized 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,64,64,64" \
--repeat_in_block  "1, 1, 1, 1" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32 \
--first_conv_kernel_size 5 \
--stride 3

In [None]:
# Downloads the tflite model file. To use on the device, you need to write a
# Model JSON file. See https://esphome.io/components/micro_wake_word for the
# documentation and
# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for
# examples. Adjust the probability threshold based on the test results obtained
# after training is finished. You may also need to increase the tensor arena
# size if the model fails to load.

from google.colab import files

files.download("/pee/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite")
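
Choosing the probability threshold is a trade-off between missed activations and false activations. One possible way to reason about it, sketched below, is to set a budget for false accepts per hour and pick the cutoff with the lowest false-reject rate inside that budget. `pick_cutoff` is a hypothetical helper, and the rows are invented numbers, not output from this notebook; substitute the values reported by your own test run.

```python
def pick_cutoff(rows, max_false_accepts_per_hour):
    """rows: iterable of (cutoff, false_reject_rate, false_accepts_per_hour)."""
    viable = [r for r in rows if r[2] <= max_false_accepts_per_hour]
    if not viable:
        return None  # no cutoff meets the budget
    # Prefer the lowest false-reject rate; break ties with the lower cutoff.
    return min(viable, key=lambda r: (r[1], r[0]))[0]

# Hypothetical test results, for illustration only.
demo_rows = [
    (0.50, 0.02, 3.0),
    (0.80, 0.04, 1.2),
    (0.90, 0.06, 0.5),
    (0.97, 0.10, 0.1),
]
chosen = pick_cutoff(demo_rows, max_false_accepts_per_hour=0.5)
print(chosen)
```

The chosen value would then go into the probability threshold field of the model manifest.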