In [None]:
# 🥔 MicroWakeWord Trainer — Tater Totterson Edition
# ==================================================
# Welcome, friend! 👋 This notebook will help you train your very own wake word model.
# Think of it like teaching Tater Totterson to recognize when you say a special word.
#
# By the end, you'll have:
#   ✅ A trained TensorFlow Lite model ready for on-device detection.
#   ✅ A matching JSON manifest you can drop straight into ESPHome.
#
# This flow is optimized for Python 3.10 and NVIDIA GPUs (but should work elsewhere too).
# You can customize the wake word, play with training parameters, and experiment with
# different datasets until you get something that feels just right. 💪
#
# ⚡ Quick Tips:
#   • Change TARGET_WORD below to whatever you want your wake word to be.
#   • Rerun the notebook from the top if you change it (to regenerate everything).
#   • Expect to experiment — tweaking hyperparameters is part of the fun!
#
# When you’re done, you’ll get two files:
#   1️⃣ <wakeword>.tflite — your trained model.
#   2️⃣ <wakeword>.json — a manifest for ESPHome integration.
#
# More info & examples:
# 🔗 https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker

# --- Set your wake word here ---
# 🗣️ Change this to whatever phrase you want! Later cells pass it to the sample
# generator and use it to name the exported .tflite / .json files.
TARGET_WORD = "hey_tater"
print("🥔 Tater Totterson is listening for: '%s'" % TARGET_WORD)

In [None]:
# Installs microWakeWord. Be sure to restart the session after this is finished.
import platform
import sys
import os

# IPython `!` lines run in a shell; `{...}` interpolates Python values, and using
# `sys.executable` makes pip target this kernel's interpreter.
if platform.system() == "Darwin":
    # `pymicro-features` is installed from a fork to support building on macOS
    !"{sys.executable}" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore

# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter
# (pinned to a specific commit hash).
!"{sys.executable}" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore

# Clone the microWakeWord repository.
# Skipped when ./microWakeWord already exists, so re-running this cell neither
# re-clones nor pulls newer commits.
repo_path = "./microWakeWord"
if not os.path.exists(repo_path):
    print("Cloning microWakeWord repository...")
    !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}

# Ensure the repository exists before attempting to install.
# `pip install -e` is an editable install: the kernel imports `microwakeword`
# directly from this checkout.
if os.path.exists(repo_path):
    print("Installing microWakeWord...")
    !"{sys.executable}" -m pip install -e {repo_path} --root-user-action=ignore
else:
    print(f"Repository not found at {repo_path}. Cloning might have failed.")

In [None]:
# --- GPU Check (Torch + ONNX Runtime) ---

import torch
import onnxruntime as ort

# Report what PyTorch can see first; no GPU here usually means the container
# was started without GPU access.
cuda_ok = torch.cuda.is_available()
print("🔧 Torch CUDA Available:", cuda_ok)
if not cuda_ok:
    print("⚠️  Torch cannot see a GPU — check Docker runtime (--gpus all) and nvidia-container-toolkit")
else:
    dev_index = torch.cuda.current_device()
    print("  • Device count:", torch.cuda.device_count())
    print("  • Current device:", dev_index)
    print("  • Device name:", torch.cuda.get_device_name(dev_index))

# Then list ONNX Runtime execution providers; a missing CUDA provider is
# reported, not fatal.
print("\n🔧 ONNX Runtime Providers:")
try:
    providers = ort.get_available_providers()
    print("  •", providers)
    if "CUDAExecutionProvider" not in providers:
        print("⚠️  CUDAExecutionProvider not available — ONNX will fall back to CPU.")
except Exception as e:
    print("⚠️  Could not query ONNX Runtime providers:", e)


In [None]:
# NVIDIA Linux Docker: generate 1 sample of the target word (robust + CUDA check)

import os, sys, shutil, subprocess, time, platform
from pathlib import Path
from IPython.display import Audio, display

# Where the piper-sample-generator checkout, its generator checkpoint, and the
# single preview sample produced by this cell live.
REPO_URL = "https://github.com/rhasspy/piper-sample-generator"
REPO_DIR = Path.cwd() / "piper-sample-generator"
MODELS_DIR = REPO_DIR.joinpath("models")
MODEL_NAME = "en_US-libritts_r-medium.pt"
MODEL_URL = f"{REPO_URL}/releases/download/v2.0.0/{MODEL_NAME}"
AUDIO_OUT_DIR = Path.cwd().joinpath("generated_samples")
AUDIO_PATH = AUDIO_OUT_DIR.joinpath("0.wav")

def run(cmd, check=True):
    """Run ``cmd`` (an argv list), echoing it and streaming merged stdout/stderr.

    Args:
        cmd: argument list; elements may be ``str`` or path-like (they are
            ``str()``-ed only for the echoed command line / error message).
        check: if True, raise when the process exits non-zero.

    Returns:
        The process exit code.

    Raises:
        RuntimeError: if ``check`` is True and the exit code is non-zero.
    """
    printable = " ".join(map(str, cmd))
    print("→", printable)
    # The context manager closes the pipe and reaps the child. If streaming is
    # interrupted (e.g. a notebook KeyboardInterrupt), the child is killed
    # rather than left running in the background.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc:
        try:
            for line in proc.stdout:
                print(line, end="")
        except BaseException:
            proc.kill()
            raise
        rc = proc.wait()
    if check and rc != 0:
        raise RuntimeError(f"Command failed with exit code {rc}: {printable}")
    return rc

def pip_install(*pkgs):
    """Upgrade pip (best effort, failures ignored), then install ``pkgs`` into this interpreter.

    Raises whatever ``run`` raises if the package install itself fails.
    """
    pip_cmd = [sys.executable, "-m", "pip", "install"]
    run(pip_cmd + ["--upgrade", "pip"], check=False)
    run(pip_cmd + list(pkgs))

def safe_clone(repo_url, branch=None, dest=REPO_DIR, retries=2):
    """Shallow-clone ``repo_url`` into ``dest`` unless a checkout is already there.

    A ``dest`` directory without a ``.git`` folder is treated as a broken partial
    clone and deleted first. A clone is attempted up to ``retries + 1`` times,
    2 s apart; the final failure is re-raised.

    Args:
        repo_url: git URL to clone.
        branch: optional branch/tag passed as ``--branch``.
        dest: target directory (``Path``).
        retries: extra attempts after the first failure.
    """
    if dest.exists() and not (dest / ".git").exists():
        print("⚠️  Found partial clone. Removing…")
        shutil.rmtree(dest, ignore_errors=True)
    if dest.exists():
        return

    branch_args = ["--branch", branch] if branch else []
    clone_cmd = ["git", "clone", "--depth", "1", *branch_args, repo_url, str(dest)]
    attempts = retries + 1
    for attempt in range(1, attempts + 1):
        try:
            run(clone_cmd)
        except Exception as exc:
            if attempt == attempts:
                raise
            print(f"Clone failed ({attempt}/{attempts}). Retrying in 2s… [{exc}]")
            time.sleep(2)
        else:
            break

def ensure_model(min_bytes=100 * 1024):
    """Make sure the piper generator checkpoint MODELS_DIR/MODEL_NAME is on disk.

    Downloads MODEL_URL when the file is missing or smaller than ``min_bytes``.
    The body is written to ``<name>.part`` and only renamed into place after the
    size check passes. An interrupted or truncated download is therefore never
    left at the final path, where the next run would accept it as a valid model.

    Args:
        min_bytes: smallest file size accepted as a real checkpoint.

    Raises:
        RuntimeError: if the downloaded file is smaller than ``min_bytes``.
    """
    import urllib.request

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    mp = MODELS_DIR / MODEL_NAME
    if not mp.exists() or mp.stat().st_size < min_bytes:
        tmp = mp.with_name(mp.name + ".part")
        print(f"Downloading model to {mp} …")
        try:
            with urllib.request.urlopen(MODEL_URL, timeout=60) as r, open(tmp, "wb") as f:
                shutil.copyfileobj(r, f)
            if tmp.stat().st_size < min_bytes:
                raise RuntimeError("Downloaded model looks too small; download may have failed.")
            os.replace(tmp, mp)  # atomic on the same filesystem
        finally:
            # Only still present if something above failed.
            if tmp.exists():
                tmp.unlink()
    print(f"✅ Model ready: {mp}")

# 1) Clone main repo (Linux/NVIDIA). No platform detection happens here; the
#    message states the environment this cell is written for.
print("Linux/NVIDIA detected — using main piper-sample-generator repo.")
safe_clone(REPO_URL)

# 2) Install generator dependencies:
#   - piper-tts provides the `piper` module (required by generate_samples.py)
#   - piper-phonemize-cross does the phonemization
#   - onnxruntime-gpu enables CUDA (container must have NVIDIA runtime)
deps = ["piper-tts>=1.2.0", "piper-phonemize-cross==1.2.1"]
deps += ["soundfile", "numpy"]
deps.append("onnxruntime-gpu>=1.16.0")
pip_install(*deps)

# 3) Report whether ONNX Runtime can use CUDA. Informational only; generation
#    proceeds either way.
try:
    import onnxruntime as ort
    providers = ort.get_available_providers()
    print(f"ONNX Runtime providers: {providers}")
    cuda_missing = "CUDAExecutionProvider" not in providers
    if cuda_missing:
        print("⚠️ CUDAExecutionProvider not available. "
              "The sample will still run on CPU, but check your NVIDIA container setup "
              "(nvidia-container-toolkit, runtime, and driver).")
except Exception as e:
    print("⚠️ Could not import onnxruntime to verify providers:", e)

# 4) Make sure the generator checkpoint is on disk.
ensure_model()

# 5) Generate exactly one preview sample of TARGET_WORD into AUDIO_OUT_DIR.
AUDIO_OUT_DIR.mkdir(parents=True, exist_ok=True)
gen_script = REPO_DIR / "generate_samples.py"
if not gen_script.exists():
    raise FileNotFoundError(f"Missing generator: {gen_script}")

cmd = [sys.executable, str(gen_script), TARGET_WORD]
cmd += ["--model", str(MODELS_DIR / MODEL_NAME)]  # pass the generator .pt explicitly
cmd += ["--max-samples", "1", "--batch-size", "1"]
cmd += ["--output-dir", str(AUDIO_OUT_DIR)]
run(cmd)

# 6) Play the sample if it was written (autoplay depends on the notebook frontend).
if not AUDIO_PATH.exists():
    print(f"Audio file not found at {AUDIO_PATH}")
else:
    print(f"🎧 Playing {AUDIO_PATH}")
    display(Audio(str(AUDIO_PATH), autoplay=True))

In [None]:
# Generate a large number of wake word samples for training
import sys, subprocess
from pathlib import Path

REPO_DIR = Path.cwd() / "piper-sample-generator"
MODELS_DIR = REPO_DIR / "models"
MODEL_NAME = "en_US-libritts_r-medium.pt"

# Flag/value pairs forwarded to generate_samples.py (dicts keep insertion order).
generator_options = {
    "--model": str(MODELS_DIR / MODEL_NAME),
    "--max-samples": "50000",
    "--batch-size": "100",
    "--output-dir": "generated_samples",
}
cmd = [sys.executable, str(REPO_DIR / "generate_samples.py"), TARGET_WORD]
for flag, value in generator_options.items():
    cmd.extend([flag, value])

print("→", " ".join(cmd))
subprocess.run(cmd, check=True)

In [None]:
# NVIDIA/Linux dataset prep to match the Apple behavior, but without datasets.Audio (no TorchCodec)
# MIT RIR -> resample to 16 kHz
# AudioSet -> NO resample
# FMA -> resample to 16 kHz mono

import os, sys, scipy.io.wavfile, numpy as np
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
import librosa
from datasets import load_dataset

def write_wav(dst: Path, data: np.ndarray, sr: int):
    """Save float audio to ``dst`` as a 16-bit PCM WAV tagged with rate ``sr``.

    Samples are clipped to [-1.0, 1.0] and scaled by 32767 (truncated toward
    zero by the int16 cast). No resampling happens here.
    """
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    scipy.io.wavfile.write(dst, sr, pcm)

# -----------------------------
# MIT RIR (resample to 16 kHz)
# -----------------------------
print("=== MIT RIR ===")
rir_out = Path("mit_rirs")
rir_out.mkdir(exist_ok=True)
if not any(rir_out.rglob("*.wav")):
    ok = 0
    try:
        # Intended to keep datasets.Audio / TorchCodec out by decoding each row's
        # file path with librosa. NOTE(review): whether row["audio"]["path"] is a
        # local file when streaming (and whether iterating rows triggers the
        # dataset's own audio decoding) depends on the installed `datasets`
        # version — confirm.
        print("⬇️ MIT RIR (streaming + manual decode)…")
        ds = load_dataset("davidscripka/MIT_environmental_impulse_responses",
                          split="train", streaming=True)
        failed, first_error = 0, None
        for i, row in enumerate(tqdm(ds)):
            try:
                audio_path = row["audio"]["path"]
                y, sr = librosa.load(audio_path, sr=16000, mono=True)
                write_wav(rir_out / f"rir_{i:04d}.wav", y, 16000)
                ok += 1
            except Exception as row_err:
                # Best effort per row, but counted so a total failure is detectable.
                failed += 1
                if first_error is None:
                    first_error = row_err
        if ok == 0:
            # Nothing decoded: raise so the ZIP fallback below runs instead of
            # reporting success with an empty mit_rirs/.
            raise RuntimeError(
                f"no RIR rows could be decoded ({failed} failed; first error: {first_error})"
            )
        print(f"✅ MIT RIR saved: {ok} files")
    except Exception as e:
        print(f"⚠️ MIT RIR download failed: {e}")
        # Fallback to the official ZIP.
        try:
            print("⬇️ MIT RIR (fallback ZIP)…")
            zip_url = "https://mcdermottlab.mit.edu/Reverb/IRMAudio/Audio.zip"
            zip_path = rir_out.parent / "MIT_RIR_Audio.zip"
            if not zip_path.exists() or zip_path.stat().st_size == 0:
                if os.system(f"wget -q -O '{zip_path}' '{zip_url}'") != 0:
                    # Don't leave a truncated ZIP behind; a rerun would reuse it.
                    zip_path.unlink(missing_ok=True)
                    raise RuntimeError(f"wget failed for {zip_url}")
            if os.system(f'unzip -q -o "{zip_path}" -d "{rir_out}"') != 0:
                zip_path.unlink(missing_ok=True)  # likely corrupt; re-download next run
                raise RuntimeError(f"unzip failed for {zip_path}")
            wavs = list(rir_out.rglob("*.wav"))
            if not wavs:
                raise RuntimeError(f"{zip_path} contained no .wav files")
            # Normalize to 16k mono (first channel when the file is multichannel).
            for p in tqdm(wavs, desc="Normalize MIT RIR"):
                a, sr = sf.read(p, always_2d=False)
                if a.ndim > 1:
                    a = a[:, 0]
                if sr != 16000:
                    a, _ = librosa.load(p, sr=16000, mono=True)
                write_wav(p, a, 16000)
            print("✅ MIT RIR fallback complete")
        except Exception as e2:
            print(f"❌ MIT RIR fallback failed: {e2}")
else:
    print("✅ mit_rirs exists; skipping.")

# -----------------------------
# AudioSet (NO resample — fast)
# -----------------------------
print("\n=== AudioSet subset ===")
audioset_dir = Path("audioset"); audioset_dir.mkdir(exist_ok=True)
audioset_out = Path("audioset_16k"); audioset_out.mkdir(exist_ok=True)

links = [f"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar"
         for i in range(10)]
for link in links:
    fname = link.split("/")[-1]
    out_tar = audioset_dir / fname
    if out_tar.exists():
        continue  # downloaded and extracted on an earlier run
    print(f"⬇️ {fname}")
    if os.system(f"wget -q -O '{out_tar}' '{link}'") != 0:
        # A failed wget can leave a partial/empty file. Remove it so a rerun
        # retries this shard instead of skipping it forever.
        out_tar.unlink(missing_ok=True)
        print(f"⚠️ Download failed: {fname}")
        continue
    print(f"📦 Extract {fname}")
    if os.system(f"tar -xf '{out_tar}' -C '{audioset_dir}'") != 0:
        # Extraction only runs right after a download, so drop the tar to get
        # both steps retried next time.
        out_tar.unlink(missing_ok=True)
        print(f"⚠️ Extraction failed: {fname}")

flacs = list(audioset_dir.rglob("*.flac"))
print(f"🔎 FLAC files: {len(flacs)}")
corrupt = []
for p in tqdm(flacs, desc="AudioSet→WAV (no resample)"):
    try:
        a, sr = sf.read(p, always_2d=False)
        if a is None or len(a) == 0:
            raise ValueError("empty audio")
        if a.ndim > 1:
            a = a[:, 0]  # keep the first channel only
        # Apple behavior: write as 16-bit and label 16 kHz (no resample).
        # NOTE(review): the file's real rate `sr` is ignored; if it is not 16000
        # the WAV plays back at the wrong speed/pitch — confirm this is intended.
        write_wav(audioset_out / (p.stem + ".wav"), a, 16000)
    except Exception as e:
        corrupt.append(f"{p}:{e}")
if corrupt:
    (audioset_out / "audioset_corrupted_files.log").write_text("\n".join(corrupt))
print("✅ AudioSet processing complete!")

# -----------------------------
# FMA xsmall (resample to 16 kHz mono)
# -----------------------------
print("\n=== FMA xsmall ===")
fma_zip_dir = Path("fma"); fma_zip_dir.mkdir(exist_ok=True)
fma_out = Path("fma_16k"); fma_out.mkdir(exist_ok=True)

zipname = "fma_xs.zip"
zipurl  = f"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/{zipname}"
zipout  = fma_zip_dir / zipname
if not zipout.exists():
    if os.system(f"wget -q -O '{zipout}' '{zipurl}'") != 0:
        # Remove the partial file so a rerun downloads again instead of skipping.
        zipout.unlink(missing_ok=True)
        print(f"⚠️ Download failed: {zipname}")
    # -o: overwrite without prompting (os.system offers no way to answer a prompt);
    # -d: extract into fma/, the same place the previous `cd fma && unzip` used.
    elif os.system(f"unzip -q -o '{zipout}' -d '{fma_zip_dir}'") != 0:
        zipout.unlink(missing_ok=True)  # likely corrupt; re-download next run
        print(f"⚠️ Extraction failed: {zipname}")

mp3s = list((fma_zip_dir / "fma_small").rglob("*.mp3"))
print(f"🎵 FMA mp3 count: {len(mp3s)}")
if not mp3s:
    print("⚠️ No FMA mp3s found under fma/fma_small — check the download/extraction above.")
corrupt = []
for p in tqdm(mp3s, desc="FMA→16k WAV"):
    try:
        y, sr = librosa.load(p, sr=16000, mono=True)  # proper decode+resample
        if y.size == 0:
            raise ValueError("empty audio")
        write_wav(fma_out / (p.stem + ".wav"), y, 16000)
    except Exception as e:
        corrupt.append(f"{p}:{e}")
if corrupt:
    Path("fma_corrupted_files.log").write_text("\n".join(corrupt))
print("\n✅ Dataset prep complete!")

In [None]:
# Sets up the augmentations.
# To improve your model, experiment with these settings and use more sources of
# background clips.

import os
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

def validate_directories(paths):
    """Return True when every path in ``paths`` exists.

    Otherwise print an error naming the first missing path and return False.
    """
    first_missing = next((path for path in paths if not os.path.exists(path)), None)
    if first_missing is not None:
        print(f"Error: Directory {first_missing} does not exist. Please ensure preprocessing is complete.")
        return False
    return True

# Paths to augmented data: impulse-response (RIR) folders and background-noise
# folders produced by the dataset-prep cell.
impulse_paths = ['mit_rirs']
background_paths = ['fma_16k', 'audioset_16k']

if not validate_directories([*impulse_paths, *background_paths]):
    raise ValueError("One or more required directories are missing.")

# Positive examples: the Piper-generated wake-word WAVs.
clip_options = dict(
    input_directory='./generated_samples',
    file_pattern='*.wav',
    max_clip_duration_s=5,
    remove_silence=True,
    random_split_seed=10,
    split_count=0.1,
)
clips = Clips(**clip_options)

# Weight per augmentation name (presumably the chance each transform is applied
# to a clip — confirm in microwakeword's Augmentation docs).
augmentation_probabilities = {
    "SevenBandParametricEQ": 0.1,
    "TanhDistortion": 0.05,
    "PitchShift": 0.15,
    "BandStopFilter": 0.1,
    "AddColorNoise": 0.1,
    "AddBackgroundNoise": 0.7,
    "Gain": 0.8,
    "RIR": 0.7,
}
augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities=augmentation_probabilities,
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)


In [None]:
# Augment a random generated-sample WAV and play it back (pass ndarray to augmenter)
from pathlib import Path
from IPython.display import Audio, display
import numpy as np
import soundfile as sf
import librosa, random, glob

output_dir = Path("./augmented_clips")
output_dir.mkdir(exist_ok=True)

# 1) Pick a random WAV from the Piper outputs
candidates = glob.glob("generated_samples/*.wav")
if not candidates:
    raise SystemExit("No files in generated_samples/. Run the TTS sample cell first.")
src_path = random.choice(candidates)

# 2) Load as 16 kHz mono float32
y, sr = librosa.load(src_path, sr=16000, mono=True)
y = y.astype(np.float32, copy=False)

# 3) Augment. The bare 1-D array is tried first and the (samples, sr) tuple is
#    a fallback. NOTE(review): which form microwakeword's augment_clip accepts
#    is not visible here — confirm and drop the unused branch.
try:
    y_aug = augmenter.augment_clip(y)
except Exception:
    # If this also fails, its traceback is chained to the first failure.
    y_aug = augmenter.augment_clip((y, sr))

# 4) Save and play. Clip to [-1, 1] first, matching write_wav: gain/noise
#    augmentations can push samples past full scale, and libsndfile's
#    float -> PCM_16 conversion does not clip unless clipping is enabled, so
#    overs would wrap around into loud clicks.
out_path = output_dir / "augmented_clip.wav"
y_out = np.clip(np.asarray(y_aug, dtype=np.float32), -1.0, 1.0)
sf.write(str(out_path), y_out, sr, subtype="PCM_16")
print(f"Augmented clip saved to {out_path}")
display(Audio(str(out_path), autoplay=True))

In [None]:
# Augment samples and save the training, validation, and testing sets.
# This version avoids datasets.Audio entirely by driving Clips from local WAVs.

import os, glob, random
from pathlib import Path
import types
import numpy as np
import librosa
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.spectrograms import SpectrogramGeneration

# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----
def audio_generator_from_wavs(self, split="train", repeat=1):
    """
    Yield 1-D float32 arrays (16 kHz mono via librosa) from generated_samples/*.wav.

    Files are sorted, shuffled with a fixed seed (10), then partitioned:
    train gets the leading files, validation the next max(1, int(0.1*n)), and
    test the rest (max(1, int(0.1*n)) when enough files exist). With very few
    files, train and/or test can be empty. ``split=None`` yields every file in
    sorted order. NOTE(review): this mirrors how microwakeword's
    Clips.audio_generator is presumed to treat None — confirm. Any other
    unknown split name yields nothing.

    Args:
        self: the Clips instance this function is bound to (unused).
        split: "train", "validation", "test", or None for all files.
        repeat: number of passes over the selected files (at least 1).

    Raises:
        SystemExit: if generated_samples/ contains no WAV files.
    """
    files = sorted(glob.glob("generated_samples/*.wav"))
    if not files:
        raise SystemExit("❌ No WAVs in generated_samples/. Generate TTS samples first.")

    rng = random.Random(10)   # deterministic shuffling like Clips(random_split_seed=10)
    files_shuf = files[:]
    rng.shuffle(files_shuf)

    n = len(files_shuf)
    n_val = max(1, int(0.10 * n))
    n_test = max(1, int(0.10 * n))
    n_train = max(0, n - n_val - n_test)
    splits = {
        "train":      files_shuf[:n_train],
        "validation": files_shuf[n_train:n_train + n_val],
        "test":       files_shuf[n_train + n_val:],
    }
    file_list = files if split is None else splits.get(split, [])
    if not file_list:
        return  # nothing to yield

    for _ in range(max(1, int(repeat))):
        for p in file_list:
            y, _sr = librosa.load(p, sr=16000, mono=True)
            yield y.astype(np.float32, copy=False)

# Bind the patched generator to the existing `clips` instance (function.__get__
# produces a bound method, like types.MethodType). The SpectrogramGeneration
# objects built below then pull audio from it.
clips.audio_generator = audio_generator_from_wavs.__get__(clips)
print("✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).")

# ---- Validate augmentation asset folders exist ----
def validate(paths):
    """Abort with SystemExit naming the first entry of ``paths`` that does not exist."""
    missing = [p for p in paths if not Path(p).exists()]
    if missing:
        raise SystemExit(f"❌ Missing directory: {missing[0]}. Run dataset prep first.")

impulse_paths = ["mit_rirs"]
background_paths = ["fma_16k", "audioset_16k"]
validate(impulse_paths + background_paths)

# ---- Output root ----
out_root = Path("generated_augmented_features")
out_root.mkdir(exist_ok=True)

# ---- Split config (same as before) ----
# Output folder -> Clips split name, number of passes over the clips, and the
# slide_frames value handed to SpectrogramGeneration.
split_cfg = {
    "training":   {"name": "train",      "repetition": 2, "slide_frames": 10},
    "validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
    "testing":    {"name": "test",       "repetition": 1, "slide_frames": 1},
}

# ---- Generate features: one RaggedMmap per split under <split>/wakeword_mmap ----
for split, cfg in split_cfg.items():
    out_dir = out_root / split
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"🧪 Processing {split} …")

    spectros = SpectrogramGeneration(
        clips=clips,                 # backed by the patched WAV loader
        augmenter=augmenter,
        slide_frames=cfg["slide_frames"],
        step_ms=10,
    )
    mmap_dir = str(out_dir / "wakeword_mmap")
    samples = spectros.spectrogram_generator(split=cfg["name"], repeat=cfg["repetition"])
    RaggedMmap.from_generator(
        out_dir=mmap_dir,
        sample_generator=samples,
        batch_size=100,
        verbose=True,
    )

print("✅ Features ready (generated_augmented_features/*/wakeword_mmap)")

In [None]:
# Downloads pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets. This can be slow!

import os
import requests
import zipfile
from pathlib import Path
from tqdm import tqdm

# Function to download a file with progress bar
def download_file(url, output_path, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` to ``output_path`` with a tqdm progress bar.

    The body goes to ``<output_path>.part`` and is renamed into place only after
    the full response has been received. An interrupted download therefore never
    leaves a file at ``output_path``; callers that skip existing files would
    otherwise treat a truncated file as complete.

    Args:
        url: HTTP(S) URL to fetch.
        output_path: destination ``Path`` (its ``.name`` labels the bar).
        chunk_size: bytes per read from the response stream.
        timeout: seconds passed to ``requests.get`` for connect/read.

    Raises:
        requests.HTTPError: on a non-2xx response, instead of saving the error
            page as if it were the file.
    """
    tmp_path = output_path.with_name(output_path.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(tmp_path, "wb") as f, tqdm(
                desc=f"Downloading {output_path.name}",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    bar.update(len(chunk))
        os.replace(tmp_path, output_path)
    finally:
        # Only still present if the download failed part-way.
        if tmp_path.exists():
            tmp_path.unlink()
    print(f"Downloaded: {output_path}")

# Function to extract ZIP files
def extract_zip(zip_path, extract_to):
    """Unpack every member of the ZIP archive at ``zip_path`` into ``extract_to`` and log it."""
    archive = zipfile.ZipFile(zip_path, 'r')
    with archive:
        archive.extractall(extract_to)
    print(f"Extracted: {zip_path} to {extract_to}")

# Directory for negative datasets
output_dir = Path('./negative_datasets')
output_dir.mkdir(exist_ok=True)

# Negative dataset URLs
link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']

# Download (if missing) and extract each archive into negative_datasets/.
for fname in filenames:
    link = link_root + fname
    zip_path = output_dir / fname

    # Download only if the file doesn't already exist
    if not zip_path.exists():
        try:
            download_file(link, zip_path)
        except Exception as e:
            print(f"Error downloading {fname}: {e}")
            continue

    # Extract the ZIP file
    try:
        extract_zip(zip_path, output_dir)
    except zipfile.BadZipFile as e:
        # The download step skips existing files, so a corrupt archive would be
        # reused on every rerun. Delete it so the next run fetches it again.
        print(f"Error extracting {fname}: {e} — deleting {zip_path} so the next run re-downloads it")
        zip_path.unlink(missing_ok=True)
    except Exception as e:
        print(f"Error extracting {fname}: {e}")

In [None]:
# GPU memory config (set env BEFORE importing TF)
# The env vars are only set if TensorFlow is not already imported in this kernel
# (restart the kernel to change them). They stay in os.environ, so processes
# launched from this kernel (e.g. a `!` shell command) inherit them.
import os, sys, gc

if "tensorflow" not in sys.modules:
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"              # grow as needed
    os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"          # modern CUDA allocator
    # Assumes the CUDA toolkit lives at /usr/local/cuda in this container — TODO confirm.
    os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda"
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"            # disable XLA JIT (more stable mem)
import tensorflow as tf

# Per-device memory growth (belt + suspenders). Errors are ignored on purpose,
# e.g. when the GPU runtime was already initialized in this process.
for g in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception:
        pass
print("GPUs:", tf.config.list_physical_devices("GPU"))
gc.collect()

# Optional but recommended: mixed precision halves activation memory
# NOTE(review): set_global_policy changes Keras state in THIS kernel process
# only. Training is launched later as a separate `python -m ...` process, which
# does not inherit it. Confirm whether microwakeword enables mixed precision itself.
try:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy("mixed_float16")
    print("Mixed precision policy:", mixed_precision.global_policy())
except Exception as e:
    print("Mixed precision not enabled:", e)

# --- Save a yaml config that controls the training process ---
# The training cell passes this file to microwakeword.model_train_eval via
# --training_config.

import yaml

config = {}

config["window_step_ms"] = 10
config["train_dir"] = "trained_models/wakeword"

# Feature sources. The first entry (truth=True) is the augmented wake-word
# features generated above; the rest are the downloaded negative_datasets/*.
# dinner_party_eval has sampling_weight 0.0 and the "split" truncation strategy
# (presumably eval-only — confirm against microwakeword's docs).
config["features"] = [
    {"features_dir":"generated_augmented_features","sampling_weight":2.0,"penalty_weight":1.0,"truth":True,"truncation_strategy":"truncate_start","type":"mmap"},
    {"features_dir":"negative_datasets/speech","sampling_weight":12.0,"penalty_weight":1.0,"truth":False,"truncation_strategy":"random","type":"mmap"},
    {"features_dir":"negative_datasets/dinner_party","sampling_weight":12.0,"penalty_weight":1.0,"truth":False,"truncation_strategy":"random","type":"mmap"},
    {"features_dir":"negative_datasets/no_speech","sampling_weight":5.0,"penalty_weight":1.0,"truth":False,"truncation_strategy":"random","type":"mmap"},
    {"features_dir":"negative_datasets/dinner_party_eval","sampling_weight":0.0,"penalty_weight":1.0,"truth":False,"truncation_strategy":"split","type":"mmap"},
]

# One value per training phase (single phase here).
config["training_steps"] = [40000]
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]
config["learning_rates"] = [0.001]

# Smaller batch to avoid GPU copy/alloc failures on 3070 laptop VRAM
config["batch_size"] = 16

# SpecAugment off (as before)
config["time_mask_max_size"] = [0]
config["time_mask_count"] = [0]
config["freq_mask_max_size"] = [0]
config["freq_mask_count"] = [0]

config["eval_step_interval"] = 500
config["clip_duration_ms"] = 1500
config["target_minimization"] = 0.9
config["minimization_metric"] = None
config["maximization_metric"] = "average_viable_recall"

with open("training_parameters.yaml", "w") as f:
    yaml.dump(config, f)

# Report only what actually reaches training. Env vars in os.environ are
# inherited by the `!` training subprocess. A Keras mixed-precision policy set
# in this kernel is process-local and is NOT carried over.
print(f"✅ Wrote training_parameters.yaml (batch_size={config['batch_size']}).")
print("   GPU env vars in os.environ are inherited by the training process; "
      "the Keras mixed-precision policy is kernel-local and does not apply to training.")

In [None]:
# Train + export (GPU-friendly env + stable flags)

import os, sys

# --- Runtime env (inherited by the subprocess we're about to launch) ---
os.environ.setdefault("LD_LIBRARY_PATH",
    "/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/x86_64-linux-gnu:" +
    os.environ.get("LD_LIBRARY_PATH","")
)
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")            # quieter logs
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")    # grow VRAM as needed
os.environ.setdefault("TF_GPU_ALLOCATOR", "cuda_malloc_async")# modern allocator
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_cuda_data_dir=/usr/local/cuda")
os.environ.setdefault("TF_XLA_FLAGS", "--tf_xla_auto_jit=0")  # disable XLA JIT (more stable)
os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "1")            # allow TF32 (perf/VRAM win on Ampere+)

# If you still hit GPU memory errors, uncomment to force a smaller workspace:
# os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "256"

# --- Kick off training ---
cmd = f'''"{sys.executable}" -m microwakeword.model_train_eval \
  --training_config="training_parameters.yaml" \
  --train 1 \
  --restore_checkpoint 1 \
  --test_tf_nonstreaming 0 \
  --test_tflite_nonstreaming 0 \
  --test_tflite_nonstreaming_quantized 0 \
  --test_tflite_streaming 0 \
  --test_tflite_streaming_quantized 1 \
  --use_weights "best_weights" \
  mixednet \
  --pointwise_filters "64,64,64,64" \
  --repeat_in_block "1,1,1,1" \
  --mixconv_kernel_sizes "[5], [7,11], [9,15], [23]" \
  --residual_connection "0,0,0,0" \
  --first_conv_filters 32 \
  --first_conv_kernel_size 5 \
  --stride 2'''
print("Running:\n", cmd)
!$cmd

In [None]:
import shutil
import json
from IPython.display import display, HTML

import os
from html import escape as html_escape

# Use the wake word from the first cell (TARGET_WORD)
wake_word = TARGET_WORD

# --- Copy TFLite file to working dir with wake word name ---
# The quantized streaming model expected from the training cell's export step.
source_path = "trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
if not os.path.isfile(source_path):
    raise FileNotFoundError(
        f"Trained model not found at {source_path}. "
        "Check the training cell output — training or the quantized streaming export likely failed."
    )
tflite_filename = f"{wake_word}.tflite"
tflite_path = f"./{tflite_filename}"
shutil.copy(source_path, tflite_path)

# --- Write JSON manifest (for ESPHome) with matching model name ---
json_data = {
    "type": "micro",
    "wake_word": wake_word,
    "author": "Tater Totterson",
    "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
    "model": tflite_filename,
    "trained_languages": ["en"],
    "version": 2,
    "micro": {
        "probability_cutoff": 0.97,
        "sliding_window_size": 5,
        "feature_step_size": 10,
        "tensor_arena_size": 30000,
        "minimum_esphome_version": "2024.7.0"
    }
}
json_filename = f"{wake_word}.json"
json_path = f"./{json_filename}"
with open(json_path, "w") as json_file:
    json.dump(json_data, json_file, indent=2)

# --- Display nice download links ---
# Filenames come from TARGET_WORD, so escape them for use in HTML attributes/text.
safe_tflite = html_escape(tflite_filename, quote=True)
safe_json = html_escape(json_filename, quote=True)
html = f"""
<h3>Download your files:</h3>
<ul>
  <li><a href="{safe_tflite}" download>⬇️ Download Model ({safe_tflite})</a></li>
  <li><a href="{safe_json}" download>⬇️ Download Metadata ({safe_json})</a></li>
</ul>
"""
display(HTML(html))