<div align="center">
  <img src="https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png" alt="MicroWakeWord Trainer Logo" width="100" />
  <h1>MicroWakeWord Trainer Docker</h1>
</div>

This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.

**The model generated from this notebook is designed for practical use, but achieving optimal performance will require experimentation with various settings and datasets. The provided scripts and configurations aim to give you a strong foundation to build upon.**

Throughout the notebook, you will find comments suggesting specific settings to modify and experiment with to enhance your model's performance.

By the end of this notebook, you will have:
- A trained TensorFlow Lite model ready for deployment.
- A JSON manifest file to integrate the model with ESPHome.

To use the generated model in ESPHome, refer to the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for integration details. You can also explore example configurations in the [model repository](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2).

In [None]:
# Installs microWakeWord. Be sure to restart the session after this is finished.
import platform
import sys
import os

# The `!` lines below are IPython shell escapes: IPython substitutes `{...}`
# with the value of the Python expression before running the command.
# Invoking pip as `"{sys.executable}" -m pip` installs into the interpreter
# that runs this kernel rather than whichever `pip` is first on PATH.
if platform.system() == "Darwin":
    # `pymicro-features` is installed from a fork to support building on macOS
    !"{sys.executable}" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore

# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter
# (pinned to a specific commit so the install is reproducible).
!"{sys.executable}" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore

# Clone the microWakeWord repository.
# An existing ./microWakeWord directory is reused as-is: it is neither
# validated nor updated (no `git pull`).
repo_path = "./microWakeWord"
if not os.path.exists(repo_path):
    print("Cloning microWakeWord repository...")
    !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}

# Ensure the repository exists before attempting to install.
# `-e` is an editable install, so the package is imported from the cloned
# source directory.
if os.path.exists(repo_path):
    print("Installing microWakeWord...")
    !"{sys.executable}" -m pip install -e {repo_path} --root-user-action=ignore
else:
    print(f"Repository not found at {repo_path}. Cloning might have failed.")

In [None]:
# --- GPU Check (Torch + ONNX Runtime) ---

import torch
import onnxruntime as ort

# Report what PyTorch can see: availability first, then device details only
# when a GPU is usable.
_torch_cuda = torch.cuda.is_available()
print("🔧 Torch CUDA Available:", _torch_cuda)
if not _torch_cuda:
    print("⚠️  Torch cannot see a GPU — check Docker runtime (--gpus all) and nvidia-container-toolkit")
else:
    print("  • Device count:", torch.cuda.device_count())
    _dev_idx = torch.cuda.current_device()
    print("  • Current device:", _dev_idx)
    print("  • Device name:", torch.cuda.get_device_name(_dev_idx))

# Report the ONNX Runtime execution providers; a failure to query them is
# reported instead of stopping the cell.
print("\n🔧 ONNX Runtime Providers:")
try:
    providers = ort.get_available_providers()
except Exception as e:
    print("⚠️  Could not query ONNX Runtime providers:", e)
else:
    print("  •", providers)
    if "CUDAExecutionProvider" not in providers:
        print("⚠️  CUDAExecutionProvider not available — ONNX will fall back to CPU.")


In [None]:
# NVIDIA Linux Docker: generate 1 sample of the target word (robust + CUDA check)

import os, sys, shutil, subprocess, time, platform
from pathlib import Path
from IPython.display import Audio, display

# Phrase to synthesize (also used by the bulk-generation cell).
TARGET_WORD = "hey_tater"

# Upstream sample-generator repository and its local checkout.
REPO_URL = "https://github.com/rhasspy/piper-sample-generator"
REPO_DIR = Path.cwd() / "piper-sample-generator"

# Piper TTS checkpoint, fetched from the generator's v2.0.0 release assets.
MODEL_NAME = "en_US-libritts_r-medium.pt"
MODELS_DIR = REPO_DIR / "models"
MODEL_URL = f"{REPO_URL}/releases/download/v2.0.0/{MODEL_NAME}"

# Where the preview sample is written; the first sample is expected at 0.wav.
AUDIO_OUT_DIR = Path.cwd() / "generated_samples"
AUDIO_PATH = AUDIO_OUT_DIR / "0.wav"

def run(cmd, check=True):
    """Run *cmd* (no shell) and stream its combined stdout/stderr to the notebook.

    Args:
        cmd: Argument list. Items may be ``str`` or path-like; they are
            converted with ``str()`` so they can be logged and joined.
        check: When true, raise ``RuntimeError`` on a non-zero exit code.

    Returns:
        The process exit code.

    Raises:
        RuntimeError: if *check* is true and the command exits non-zero.
    """
    cmd = [str(part) for part in cmd]
    print("→", " ".join(cmd))
    # The context manager closes the pipe and reaps the child. On an
    # interrupt (e.g. stopping the cell) the child is killed explicitly so it
    # does not keep running in the background.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merged so messages keep their relative order
        text=True,
        errors="replace",  # undecodable bytes must not abort the output stream
    ) as proc:
        try:
            for line in proc.stdout:
                print(line, end="")
        except BaseException:
            proc.kill()
            raise
        rc = proc.wait()
    if check and rc != 0:
        raise RuntimeError(f"Command failed with exit code {rc}: {' '.join(cmd)}")
    return rc

def pip_install(*pkgs):
    """Install *pkgs* into the interpreter running this kernel.

    pip itself is upgraded first on a best-effort basis: a failed self-upgrade
    is ignored. A failed package install raises ``RuntimeError`` via ``run``.
    """
    pip_cmd = [sys.executable, "-m", "pip", "install"]
    run(pip_cmd + ["--upgrade", "pip"], check=False)
    run(pip_cmd + list(pkgs))

def safe_clone(repo_url, branch=None, dest=REPO_DIR, retries=2):
    """Shallow-clone *repo_url* into *dest*, retrying on failure.

    A directory at *dest* that has no ``.git`` folder is treated as a leftover
    partial clone and removed first. An existing clone is left untouched: it
    is not pulled or switched to *branch*.

    Args:
        repo_url: Git URL to clone.
        branch: Optional branch or tag passed to ``git clone --branch``.
        dest: Target directory (``str`` or ``Path``).
        retries: Number of extra attempts after the first failure.

    Raises:
        Whatever ``run`` raises on the final failed attempt (``RuntimeError``
        for a non-zero exit).
    """
    dest = Path(dest)
    if dest.exists() and not (dest / ".git").exists():
        print("⚠️  Found partial clone. Removing…")
        shutil.rmtree(dest, ignore_errors=True)
    if dest.exists():
        return

    cmd = ["git", "clone", "--depth", "1"]
    if branch:
        cmd += ["--branch", branch]
    cmd += [repo_url, str(dest)]

    attempts = retries + 1
    for attempt in range(1, attempts + 1):
        try:
            run(cmd)
            return
        except Exception as e:
            if attempt == attempts:
                raise
            # A failed clone can leave a half-populated directory behind;
            # `git clone` refuses to write into a non-empty directory, so
            # clear it before retrying.
            shutil.rmtree(dest, ignore_errors=True)
            print(f"Clone failed ({attempt}/{attempts}). Retrying in 2s… [{e}]")
            time.sleep(2)

def ensure_model(url=None, models_dir=None, model_name=None, min_bytes=100 * 1024):
    """Download the Piper checkpoint unless a plausible copy already exists.

    With no arguments this uses the module-level ``MODEL_URL``, ``MODELS_DIR``
    and ``MODEL_NAME``. The download goes to a ``.part`` file. That file is
    renamed into place only after it passes the size check, so an interrupted
    or failed download is never mistaken for a valid model on a later run.

    Args:
        url: Source URL (defaults to ``MODEL_URL``).
        models_dir: Destination directory (defaults to ``MODELS_DIR``).
        model_name: Destination file name (defaults to ``MODEL_NAME``).
        min_bytes: Files smaller than this are treated as broken downloads.

    Returns:
        Path to the model file.

    Raises:
        RuntimeError: if the downloaded file is smaller than *min_bytes*.
        urllib.error.URLError: on network failures.
    """
    import urllib.request

    url = MODEL_URL if url is None else url
    models_dir = Path(MODELS_DIR if models_dir is None else models_dir)
    model_name = MODEL_NAME if model_name is None else model_name

    models_dir.mkdir(parents=True, exist_ok=True)
    mp = models_dir / model_name
    # Use the same threshold as the post-download check, so a truncated file
    # from an earlier run is replaced rather than accepted.
    if not mp.exists() or mp.stat().st_size < min_bytes:
        tmp = mp.with_name(mp.name + ".part")
        print(f"Downloading model to {mp} …")
        try:
            with urllib.request.urlopen(url, timeout=60) as r, open(tmp, "wb") as f:
                shutil.copyfileobj(r, f)
            if tmp.stat().st_size < min_bytes:
                raise RuntimeError("Downloaded model looks too small; download may have failed.")
            tmp.replace(mp)
        finally:
            # No-op after a successful rename; removes the partial file otherwise.
            tmp.unlink(missing_ok=True)
    print(f"✅ Model ready: {mp}")
    return mp

# 1) Clone the main repo. This cell is written for Linux/NVIDIA containers;
#    check the OS instead of assuming it, and warn (but continue) elsewhere.
_os_name = platform.system()
if _os_name == "Linux":
    print("Linux detected — using main piper-sample-generator repo.")
else:
    print(f"⚠️  This cell targets Linux/NVIDIA containers but found {_os_name}; "
          "continuing with the main piper-sample-generator repo.")
safe_clone(REPO_URL)

# 2) Install deps (GPU ONNX)
#   - piper-phonemize-cross provides phonemization
#   - onnxruntime-gpu enables CUDA (container must have CUDA + drivers)
deps = [
    "piper-phonemize-cross==1.2.1",
    "soundfile",
    "numpy",
    "onnxruntime-gpu>=1.16.0",
]
pip_install(*deps)

# 3) Verify CUDA provider is available.
#    NOTE: if onnxruntime was already imported in this kernel (the GPU check
#    cell imports it), this import returns the cached module from
#    sys.modules. A package just installed or upgraded above is only picked
#    up after a kernel restart.
try:
    import onnxruntime as ort
    providers = ort.get_available_providers()
    print(f"ONNX Runtime providers: {providers}")
    if "CUDAExecutionProvider" not in providers:
        print("⚠️ CUDAExecutionProvider not available. "
              "The sample will still run on CPU, but check your NVIDIA container setup "
              "(nvidia-container-toolkit, runtime, and driver).")
except Exception as e:
    print("⚠️ Could not import onnxruntime to verify providers:", e)

# 4) Ensure model present
ensure_model()

# 5) Generate one sample
AUDIO_OUT_DIR.mkdir(parents=True, exist_ok=True)
gen_script = REPO_DIR / "generate_samples.py"
if not gen_script.exists():
    raise FileNotFoundError(f"Missing generator: {gen_script}")

cmd = [
    sys.executable, str(gen_script),
    TARGET_WORD,
    "--max-samples", "1",
    "--batch-size", "1",
    "--output-dir", str(AUDIO_OUT_DIR),
]
# Remove a preview left by a previous run, so the playback step below plays
# this run's output (or reports it missing) instead of a stale file.
AUDIO_PATH.unlink(missing_ok=True)
run(cmd)

# 6) Play the audio (if the notebook frontend supports it)
if AUDIO_PATH.exists():
    print(f"🎧 Playing {AUDIO_PATH}")
    display(Audio(str(AUDIO_PATH), autoplay=True))
else:
    print(f"Audio file not found at {AUDIO_PATH}")

In [None]:
# Generates a larger amount of wake word samples.
# Start here when trying to improve your model.
# See https://github.com/rhasspy/-m piper-sample-generator for the full set of
# parameters. In particular, experiment with noise-scales and noise-scale-ws,
# generating negative samples similar to the wake word, and generating many more
# wake word samples, possibly with different phonetic pronunciations.

!"{sys.executable}" piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples 50000 \
--batch-size 100 \
--output-dir generated_samples

In [None]:
# Downloads audio data for augmentation. This can be slow!
# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024
#
# **Important note!** The data downloaded here has a mixture of different
# licenses and usage restrictions. As such, any custom models trained with this
# data should be considered appropriate for **non-commercial** personal use only.

import os
import scipy.io.wavfile
import numpy as np
from datasets import Dataset, Audio, load_dataset
from pathlib import Path
from tqdm import tqdm
import soundfile as sf

# -----------------------------
# Download and Process MIT RIR
# -----------------------------
output_dir = "./mit_rirs"
if not os.path.exists(output_dir):
    # Write into a temporary directory and rename it only after every file has
    # been written. An interrupted download therefore leaves no ./mit_rirs
    # behind and is retried on the next run instead of being skipped.
    partial_dir = output_dir + ".partial"
    os.makedirs(partial_dir, exist_ok=True)
    rir_dataset = load_dataset("davidscripka/MIT_environmental_impulse_responses", split="train", streaming=True)
    print(f"Downloading MIT RIR dataset to {output_dir}...")
    for row in tqdm(rir_dataset):
        name = row["audio"]["path"].split("/")[-1]
        # Clip before the int16 cast: values outside [-1, 1] would otherwise
        # wrap around instead of saturating.
        samples = np.clip(row["audio"]["array"], -1.0, 1.0)
        scipy.io.wavfile.write(
            os.path.join(partial_dir, name),
            16000,
            (samples * 32767).astype(np.int16),
        )
    os.rename(partial_dir, output_dir)
    print(f"Finished downloading MIT RIR dataset to {output_dir}.\n")
else:
    print(f"{output_dir} already exists. Skipping download.")

# -----------------------------
# Download and Process Audioset
# -----------------------------
import subprocess
from math import gcd
from scipy.signal import resample_poly

# Directory setup
audioset_dir = "./audioset"
output_dir = "./audioset_16k"
os.makedirs(audioset_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

# Full-scale dataset download links
dataset_links = [
    f"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar"
    for i in range(10)
]

# Download and extract each dataset part. Each archive is downloaded to a
# ".part" name and renamed to its final name only after both the download and
# the extraction succeed. A failed attempt is thus retried on the next run
# instead of being skipped as "already downloaded".
for link in dataset_links:
    file_name = link.split("/")[-1]
    out_path = os.path.join(audioset_dir, file_name)
    if os.path.exists(out_path):
        continue
    part_path = out_path + ".part"
    print(f"Downloading {file_name}...")
    if subprocess.run(["wget", "--quiet", "-O", part_path, link]).returncode != 0:
        print(f"Error downloading {file_name}; it will be retried on the next run.")
        continue
    print(f"Extracting {file_name}...")
    if subprocess.run(["tar", "-xf", part_path, "-C", audioset_dir]).returncode != 0:
        print(f"Error extracting {file_name}; it will be retried on the next run.")
        continue
    os.replace(part_path, out_path)

# Collect all FLAC files for processing
audioset_files = list(Path(audioset_dir).glob("**/*.flac"))
print(f"Number of FLAC files found: {len(audioset_files)}")

if audioset_files:
    corrupted_files = []

    print("Converting Audioset files to 16kHz WAV...")
    for file_path in tqdm(audioset_files, desc="Processing Audioset files"):
        try:
            audio, sampling_rate = sf.read(file_path)

            if audio is None or len(audio) == 0:
                raise ValueError(f"Empty or invalid audio data in file: {file_path}")

            # sf.read returns (frames, channels) for multi-channel files;
            # downmix to a single channel.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)

            # Actually resample. Writing native-rate samples under a 16 kHz
            # header would change the clip's speed and pitch.
            if sampling_rate != 16000:
                factor = gcd(int(sampling_rate), 16000)
                audio = resample_poly(audio, 16000 // factor, int(sampling_rate) // factor)

            # Clip so the int16 conversion saturates instead of wrapping.
            audio = np.clip(audio, -1.0, 1.0)

            output_path = Path(output_dir) / (file_path.stem + ".wav")
            scipy.io.wavfile.write(
                output_path,
                16000,
                (audio * 32767).astype(np.int16),
            )
        except Exception as e:
            # Log the error and skip the file (covers sf.LibsndfileError too)
            print(f"Error converting {file_path}: {e}")
            corrupted_files.append(str(file_path))

    # Log corrupted files
    if corrupted_files:
        log_path = Path(output_dir) / "audioset_corrupted_files.log"
        with open(log_path, "w") as log_file:
            log_file.writelines(f"{file}\n" for file in corrupted_files)
        print(f"Logged corrupted files to {log_path}")
else:
    print("No FLAC files found in Audioset.")

print("Audioset processing complete!")


# -----------------------------
# Download and Process FMA
# -----------------------------
import subprocess

output_dir = "./fma"
fname = "fma_xs.zip"
zip_path = os.path.join(output_dir, fname)
# Key the skip on the archive rather than the directory, so a failed download
# leaves nothing behind that would suppress later attempts.
if not os.path.exists(zip_path):
    os.makedirs(output_dir, exist_ok=True)
    link = "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/" + fname
    part_path = zip_path + ".part"
    if subprocess.run(["wget", "-q", "-O", part_path, link]).returncode != 0:
        print(f"Error downloading {fname}; it will be retried on the next run.")
    # -o overwrites files from an earlier partial extraction instead of
    # prompting, which would hang the notebook.
    elif subprocess.run(["unzip", "-q", "-o", part_path, "-d", output_dir]).returncode != 0:
        print(f"Error extracting {fname}; it will be retried on the next run.")
    else:
        os.replace(part_path, zip_path)

output_dir = "./fma_16k"
os.makedirs(output_dir, exist_ok=True)

# Save clips to 16-bit PCM wav files
fma_files = list(Path("fma/fma_small").glob("**/*.mp3"))
print(f"Number of MP3 files found: {len(fma_files)}")
if fma_files:
    fma_dataset = Dataset.from_dict({"audio": [str(file) for file in fma_files]})
    # Rows are decoded and resampled to 16 kHz when they are accessed.
    fma_dataset = fma_dataset.cast_column("audio", Audio(sampling_rate=16000))

    corrupted_files = []
    print("Converting FMA files to 16kHz WAV...")
    for idx, file in enumerate(tqdm(fma_files)):
        src = str(file)
        try:
            # Access the row inside the try: decoding happens here, so a
            # corrupt MP3 is logged and skipped instead of aborting the loop.
            samples = np.clip(fma_dataset[idx]["audio"]["array"], -1.0, 1.0)
            name = src.split("/")[-1].replace(".mp3", ".wav")
            scipy.io.wavfile.write(
                os.path.join(output_dir, name),
                16000,
                (samples * 32767).astype(np.int16),
            )
        except Exception as e:
            print(f"Error converting {src}: {e}")
            corrupted_files.append(src)

    if corrupted_files:
        with open("fma_corrupted_files.log", "w") as log_file:
            log_file.writelines(f"{file}\n" for file in corrupted_files)
else:
    print("No MP3 files found in FMA.")

print("Dataset preparation complete!")

In [None]:
# Sets up the augmentations.
# To improve your model, experiment with these settings and use more sources of
# background clips.

import os
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

def validate_directories(paths):
    """Return True when every entry of *paths* exists on disk.

    Checking stops at the first missing path. That path is reported with an
    error message and False is returned. An empty *paths* is considered valid.
    """
    first_missing = next((p for p in paths if not os.path.exists(p)), None)
    if first_missing is None:
        return True
    print(f"Error: Directory {first_missing} does not exist. Please ensure preprocessing is complete.")
    return False

# Source directories for augmentation (relative to the working directory).
# They are produced by the dataset-download cell: impulse responses in
# mit_rirs, background audio in fma_16k and audioset_16k.
impulse_paths = ['mit_rirs']
background_paths = ['fma_16k', 'audioset_16k']

# Stop early with a clear error if any source directory is missing.
if not validate_directories(impulse_paths + background_paths):
    raise ValueError("One or more required directories are missing.")

# Positive (wake word) clips generated by piper-sample-generator.
# `clips` is a notebook global used by the preview and feature-generation cells.
# NOTE(review): split_count=0.1 presumably controls the validation/test
# hold-out from these clips, and random_split_seed makes that split
# reproducible — confirm in microwakeword.audio.clips.
clips = Clips(
    input_directory='./generated_samples',
    file_pattern='*.wav',
    max_clip_duration_s=5,
    remove_silence=True,
    random_split_seed=10,
    split_count=0.1,
)

# Augmentation pipeline applied to each clip (also a notebook global).
# The dict maps transform names to probabilities; keep its key order, since
# the library may build the pipeline in that order (not verified here).
# Background noise is mixed at 5-10 dB SNR.
augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities={
        "SevenBandParametricEQ": 0.1,
        "TanhDistortion": 0.05,
        "PitchShift": 0.15,
        "BandStopFilter": 0.1,
        "AddColorNoise": 0.1,
        "AddBackgroundNoise": 0.7,
        "Gain": 0.8,
        "RIR": 0.7,
    },
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)


In [None]:
# Augment a random clip and play it back to verify it works well
from pathlib import Path
from IPython.display import Audio
from microwakeword.audio.audio_utils import save_clip

# Ensure output directory exists (its parent, the working directory, always does)
output_dir = Path('./augmented_clips')
output_dir.mkdir(exist_ok=True)

# Sanity check of the pipeline from the previous cell, using the notebook
# globals `clips` and `augmenter`. Any failure is printed rather than raised,
# so this preview never stops a "run all".
try:
    # Get a random clip and apply augmentation
    random_clip = clips.get_random_clip()
    augmented_clip = augmenter.augment_clip(random_clip)
    
    # Save augmented clip to file (overwritten on every run of this cell)
    output_file = output_dir / 'augmented_clip.wav'
    save_clip(augmented_clip, output_file)
    print(f"Augmented clip saved to {output_file}")
    
    # Playback augmented clip. `Audio` is IPython's player: it is re-imported
    # above because the dataset cell imports `datasets.Audio` under the same
    # name. `display` is available from IPython's namespace / an earlier cell.
    display(Audio(str(output_file), autoplay=True))
except Exception as e:
    print(f"Error during augmentation or playback: {e}")

In [None]:
# Augment samples and save the training, validation, and testing sets.
# Validating and testing samples generated the same way can make the model
# benchmark better than it performs in real-world use. Use real samples or TTS
# samples generated with a different TTS engine to potentially get more accurate
# benchmarks.

import os
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.spectrograms import SpectrogramGeneration

# Output directory for augmented features (read by the training config as
# features_dir "generated_augmented_features")
output_dir = 'generated_augmented_features'
os.makedirs(output_dir, exist_ok=True)

# Per-split settings: `name` is the split requested from the generator and
# `repetition` how many passes are made over the clips.
# NOTE(review): slide_frames presumably is the stride (in frames) between
# windows cut from each spectrogram — 1 for testing gives dense windows;
# confirm in microwakeword.audio.spectrograms.
split_config = {
    "training": {"name": "train", "repetition": 2, "slide_frames": 10},
    "validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
    "testing": {"name": "test", "repetition": 1, "slide_frames": 1},
}

# Generate augmented features for each split
failed_splits = []
for split, config in split_config.items():
    out_dir = os.path.join(output_dir, split)
    os.makedirs(out_dir, exist_ok=True)
    print(f"Processing {split} set...")

    try:
        # Spectrogram generation configuration
        spectrograms = SpectrogramGeneration(
            clips=clips,
            augmenter=augmenter,
            slide_frames=config["slide_frames"],
            step_ms=10,  # keep equal to window_step_ms in the training config
        )

        # Generate and save spectrogram features
        RaggedMmap.from_generator(
            out_dir=os.path.join(out_dir, 'wakeword_mmap'),
            sample_generator=spectrograms.spectrogram_generator(
                split=config["name"], repeat=config["repetition"]
            ),
            batch_size=100,  # Can parameterize this if needed
            verbose=True,
        )
        print(f"Completed processing {split} set. Output saved to {out_dir}")
    except Exception as e:
        # Keep going so the remaining splits are still produced, but remember
        # the failure for the summary below.
        print(f"Error processing {split} set: {e}")
        failed_splits.append(split)

# Fail loudly at the end. Training reads these features, and a silently
# missing split would only surface later as a confusing training error.
if failed_splits:
    raise RuntimeError(
        "Feature generation failed for: " + ", ".join(failed_splits)
        + ". Fix the errors printed above and re-run this cell."
    )

In [None]:
# Downloads pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets. This can be slow!

import os
import requests
import zipfile
from pathlib import Path
from tqdm import tqdm

# Function to download a file with progress bar
def download_file(url, output_path, chunk_size=1 << 20, timeout=60):
    """Stream *url* to *output_path* while showing a progress bar.

    Data is written to ``<output_path>.part`` and renamed only after the body
    has been fully received. An interrupted download therefore never leaves a
    truncated file under the final name (callers skip files that exist).

    Args:
        url: HTTP(S) URL to fetch.
        output_path: Destination file (``str`` or ``Path``).
        chunk_size: Bytes read per iteration.
        timeout: Seconds to wait for the server to connect/respond.

    Raises:
        requests.HTTPError: for non-2xx responses (an HTML error page is never
            saved as if it were the requested file).
        requests.RequestException: for connection errors and timeouts.
    """
    output_path = Path(output_path)
    tmp_path = output_path.with_name(output_path.name + ".part")
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        try:
            with open(tmp_path, "wb") as f, tqdm(
                desc=f"Downloading {output_path.name}",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    bar.update(len(chunk))
            tmp_path.replace(output_path)
        finally:
            # No-op after a successful rename; removes the partial file otherwise.
            tmp_path.unlink(missing_ok=True)
    print(f"Downloaded: {output_path}")

# Function to extract ZIP files
def extract_zip(zip_path, extract_to):
    """Unpack every member of the archive *zip_path* into *extract_to*.

    Prints a confirmation line once extraction is complete.
    """
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(path=extract_to)
    print(f"Extracted: {zip_path} to {extract_to}")

# Directory for negative datasets
output_dir = Path('./negative_datasets')
output_dir.mkdir(exist_ok=True)

# Negative dataset URLs
link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']

# Download and extract files
for fname in filenames:
    link = link_root + fname
    zip_path = output_dir / fname

    # Download only if the file doesn't already exist
    if not zip_path.exists():
        try:
            download_file(link, zip_path)
        except Exception as e:
            print(f"Error downloading {fname}: {e}")
            continue

    # Extract the ZIP file (re-extracted on every run)
    try:
        extract_zip(zip_path, output_dir)
    except zipfile.BadZipFile as e:
        # Downloads are skipped when the file exists, so a corrupt or
        # truncated archive would be reused forever. Delete it so the next run
        # downloads it again.
        print(f"Error extracting {fname}: {e}. Deleting it so it is re-downloaded on the next run.")
        zip_path.unlink(missing_ok=True)
    except Exception as e:
        print(f"Error extracting {fname}: {e}")


In [None]:
# Save a yaml config that controls the training process.
# These hyperparameters can make a huge difference in model quality.
# Experiment with sampling and penalty weights and increasing the number of
# training steps.

import yaml
import os

config = {}

# Keep equal to step_ms used when generating the augmented features.
config["window_step_ms"] = 10

config["train_dir"] = "trained_models/wakeword"

# Feature sets. "truth": True marks the positive (wake word) features produced
# by the augmentation cell; the others are downloaded negative datasets.
# NOTE(review): sampling_weight presumably sets how often each set is drawn
# into training batches and penalty_weight scales its loss. dinner_party_eval
# has sampling_weight 0.0, so it is presumably used for evaluation only —
# confirm against microwakeword's data loader.
config["features"] = [
    {
        "features_dir": "generated_augmented_features",
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 12.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 12.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party_eval",
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

config["training_steps"] = [40000]
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]
config["learning_rates"] = [0.001]
config["batch_size"] = 128

# SpecAugment (time/frequency masking) is DISABLED: every mask size and count
# below is 0. Set non-zero values to enable it.
config["time_mask_max_size"] = [0]
config["time_mask_count"] = [0]
config["freq_mask_max_size"] = [0]
config["freq_mask_count"] = [0]

config["eval_step_interval"] = 500
config["clip_duration_ms"] = 1500

# Checkpoint selection. With minimization_metric set to None,
# target_minimization presumably has no effect and the best weights are
# chosen by average_viable_recall alone — confirm in microwakeword.
config["target_minimization"] = 0.9
config["minimization_metric"] = None
config["maximization_metric"] = "average_viable_recall"

with open("training_parameters.yaml", "w") as file:
    yaml.dump(config, file)

In [None]:
# Trains a model. When finished, it will quantize and convert the model to a
# streaming version suitable for on-device detection.
# It will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Change --train 0 to only convert and test the best-weighted model.
# On Google colab, it doesn't print the mini-batch results, so it may appear
# stuck for several minutes! Additionally, it is very slow compared to training
# on a local GPU.

import os
import sys

# Ensure the library path is correctly set
os.environ['LD_LIBRARY_PATH'] = "/usr/lib/x86_64-linux-gnu:" + os.environ.get('LD_LIBRARY_PATH', '')

# Training command with optimized settings
!"{sys.executable}" -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_nonstreaming_quantized 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,64,64,64" \
--repeat_in_block "1,1,1,1" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32 \
--first_conv_kernel_size 5 \
--stride 2


In [None]:
import shutil
import json
from pathlib import Path
from IPython.display import FileLink, display

# Wake word recorded in the ESPHome manifest. Reuse TARGET_WORD from the
# sample-generation cell so the manifest names the phrase the model was
# actually trained on. The literal fallback only applies if that cell was not
# run in this kernel session.
wake_word = globals().get("TARGET_WORD", "hey_tater")

# Define the source path and desired download location for the TFLite file
source_path = "trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
destination_path = "./stream_state_internal_quant.tflite"

# Fail with an actionable message rather than a bare copy traceback.
if not Path(source_path).is_file():
    raise FileNotFoundError(
        f"Trained model not found at {source_path}. Run the training cell first."
    )

# Copy the TFLite file to the current working directory
shutil.copy(source_path, destination_path)

# Define the JSON manifest content
json_data = {
    "type": "micro",
    "wake_word": wake_word,
    "author": "master phooey",
    "website": "https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker",
    "model": "stream_state_internal_quant.tflite",
    "trained_languages": ["en"],
    "version": 2,
    "micro": {
        "probability_cutoff": 0.97,
        "sliding_window_size": 5,
        # Matches window_step_ms in the training config.
        "feature_step_size": 10,
        "tensor_arena_size": 30000,
        "minimum_esphome_version": "2024.7.0"
    }
}

# Define the JSON file path
json_path = "./stream_state_internal_quant.json"

# Write the JSON file
with open(json_path, "w") as json_file:
    json.dump(json_data, json_file, indent=2)

# Generate download links for both files
print("Download your files:")
print("TFLite Model:")
display(FileLink(destination_path))
print("\nJSON Metadata:")
display(FileLink(json_path))