# 🎯 Training a microWakeWord Model

<div style="background-color: #f0f7fb; padding: 15px; border-radius: 10px; border-left: 5px solid #3498db; margin-bottom: 20px;">
    <h2 style="margin-top: 0; color: #3498db;">Welcome to microWakeWord Training!</h2>
    <p>This notebook guides you through training a custom wake word model using microWakeWord. It wraps the steps of the basic training process in <b>visual, interactive</b> controls (text fields, sliders, dropdowns, and audio previews) so you can adjust settings without editing code.</p>
    <p><b>Python 3.10</b> is recommended for the best experience.</p>
</div>

<div style="background-color: #fff3cd; padding: 15px; border-radius: 10px; border-left: 5px solid #f0ad4e; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #8a6d3b;">⚠️ Important Note</h3>
    <p>The model generated will require experimentation to achieve good results. You may need to adjust various settings to get a model that:</p>
    <ul>
        <li>Reliably detects your wake word</li>
        <li>Doesn't trigger falsely on similar sounds</li>
        <li>Works well in your specific environment</li>
    </ul>
    <p>This notebook is designed to be run <b>sequentially</b>. Run each cell in order by pressing Shift+Enter.</p>
</div>

## 📋 What You'll Learn

1. How to set up the microWakeWord training environment
2. How to generate and augment wake word samples
3. How to train a custom wake word model
4. How to evaluate and export your model for use on devices

<div style="background-color: #dff0d8; padding: 15px; border-radius: 10px; border-left: 5px solid #5cb85c; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #3c763d;">💡 Pro Tip</h3>
    <p>Training is <b>much faster</b> on a local GPU. Before you start, make sure the GPU drivers and libraries that TensorFlow needs for GPU acceleration are installed.</p>
</div>
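To confirm that TensorFlow (which microWakeWord trains with) can actually see a GPU, a quick check like the following may help. This is a minimal sketch: it assumes TensorFlow is installed in the notebook's environment and returns `None` when it is not (for example, before the setup cell has run).

```python
def available_gpus():
    """Return the names of GPUs TensorFlow can see, or None if TensorFlow is not installed."""
    try:
        import tensorflow as tf
    except ImportError:
        return None
    # list_physical_devices returns PhysicalDevice objects that expose a .name attribute
    return [device.name for device in tf.config.list_physical_devices("GPU")]


gpus = available_gpus()
if gpus is None:
    print("TensorFlow is not installed yet; run the setup cell first.")
elif gpus:
    print(f"Training will use: {', '.join(gpus)}")
else:
    print("No GPU detected; training will run on the CPU and be much slower.")
```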

At the end of this notebook, you'll have a TFLite model file ready for deployment on ESPHome devices. For deployment instructions, see the [ESPHome documentation](https://esphome.io/components/micro_wake_word) and [model examples](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2).


In [None]:
# 🔧 Setup: Install microWakeWord and dependencies
# This cell installs all necessary packages for training

import platform
import sys
from IPython.display import HTML, display

# Display a progress message
display(HTML(
    "<div style='background-color: #f8f9fa; padding: 10px; border-radius: 5px;'>"
    "<h3 style='margin-top: 0;'>📦 Installing dependencies...</h3>"
    "<p>This may take a few minutes. Please wait until completion.</p>"
    "</div>"
))

# Platform-specific installations
if platform.system() == "Darwin":
    # `pymicro-features` is installed from a fork to support building on macOS
    !pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'

# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter
!pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'

# Install ipywidgets for interactive notebook elements, plus tqdm and matplotlib
!pip install ipywidgets tqdm matplotlib

# Clone and install microWakeWord
!git clone https://github.com/BigPappy098/microWakeWord
!pip install -e ./microWakeWord

# Display success message
display(HTML(
    "<div style='background-color: #dff0d8; padding: 10px; border-radius: 5px;'>"
    "<h3 style='margin-top: 0; color: #3c763d;'>✅ Setup Complete!</h3>"
    "<p>All dependencies have been installed successfully.</p>"
    "</div>"
))

## 🎤 Step 1: Choose Your Wake Word

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">Selecting an Effective Wake Word</h3>
    <p>A good wake word should be:</p>
    <ul>
        <li><b>Distinctive</b>: Unique sound patterns not common in everyday speech</li>
        <li><b>Multi-syllabic</b>: 2-5 syllables work best (e.g., "hey computer", "alexa")</li>
        <li><b>Clear pronunciation</b>: Avoid words that are difficult to pronounce consistently</li>
    </ul>
    <p>Phonetic spellings can help the text-to-speech sample generator pronounce the word the way you say it. For example, "computer" might work better written as "kuhm-pyoo-ter".</p>
</div>
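The text field in the next cell expects underscores instead of spaces. The helpers below are hypothetical conveniences (not part of microWakeWord): one converts a phrase into that underscore form, and the other gives a rough syllable count with a simple vowel-group heuristic, which can misjudge words with silent letters such as a final "e".

```python
import re


def to_wake_word_id(phrase: str) -> str:
    """Lowercase a phrase and join its words with underscores, e.g. 'Hey Computer' -> 'hey_computer'."""
    words = re.findall(r"[a-z0-9]+", phrase.lower())
    return "_".join(words)


def estimate_syllables(phrase: str) -> int:
    """Rough syllable count: one syllable per run of vowels (a, e, i, o, u, y) in each word."""
    total = 0
    for word in re.findall(r"[a-z]+", phrase.lower()):
        # Every word has at least one syllable, even if the heuristic finds no vowel group
        total += max(1, len(re.findall(r"[aeiouy]+", word)))
    return total


for candidate in ["Hey Computer", "Alexa", "Okay Nabu"]:
    print(to_wake_word_id(candidate), estimate_syllables(candidate))
```

A count outside the 2-5 range suggested above is a hint to try a different phrase, not a hard rule.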

In [None]:
# 🎤 Set your wake word and generate a sample to verify

import os
import sys
import platform
import ipywidgets as widgets
from IPython.display import Audio, display, HTML

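# Create an interactive text input for the wake word.
# Re-running this cell keeps the text you typed instead of resetting it to the default.
wake_word_input = widgets.Text(
    value=wake_word_input.value if 'wake_word_input' in globals() else 'hey_computer',
    description='Wake Word:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)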

# Help text
help_text = widgets.HTML(
    value="<p style='color: #666; font-style: italic;'>Use underscores instead of spaces (e.g., 'hey_computer'). Try phonetic spellings for better results.</p>"
)

# Display the input widget
display(wake_word_input)
display(help_text)

# Setup piper sample generator if not already done
if not os.path.exists("./piper-sample-generator"):
    display(HTML("<p>Setting up sample generator...</p>"))
    
    if platform.system() == "Darwin":
        !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator
    else:
        !git clone https://github.com/rhasspy/piper-sample-generator

    !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'

    # Install the Python dependencies of the sample generator
    !pip install torch torchaudio piper-phonemize-cross==1.2.1

    if "piper-sample-generator/" not in sys.path:
        sys.path.append("piper-sample-generator/")
        
# Create output directory if it doesn't exist
if not os.path.exists("generated_samples"):
    os.makedirs("generated_samples")

# Generate a sample for the current wake word
target_word = wake_word_input.value
display(HTML(f"<p>Generating sample for '{target_word}'...</p>"))

!python3 piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples 1 \
--batch-size 1 \
--output-dir generated_samples

# Display the audio
display(HTML("<p style='color: green;'>✅ Sample generated! Listen below:</p>"))
display(Audio("generated_samples/0.wav", autoplay=True))

# Add a note about changing the wake word
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> If you want to change the wake word, edit the text field above and run this cell again.</p>"
    "</div>"
))

## 🔊 Step 2: Generate Training Samples

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">About Training Samples</h3>
    <p>To train a robust model, we need many examples of our wake word. The sample generator creates synthetic speech samples with different voices and variations.</p>
    <p>You can adjust the number of samples and the batch size (how many samples are generated per pass) below. More samples generally lead to better models but take longer to generate.</p>
</div>
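Once the samples exist, it can be worth sanity-checking them before augmentation. This sketch uses only Python's standard `wave` module to count the clips, total their duration, and flag files with an unexpected sample rate. The expected rate of 16 kHz is an assumption based on the `_16k` background datasets used later; change `expected_rate` if your generator is configured differently.

```python
import glob
import os
import wave


def summarize_wavs(directory: str, expected_rate: int = 16000):
    """Return (file_count, total_seconds, files_with_unexpected_rate) for the WAVs in a directory."""
    count, total_seconds, mismatched = 0, 0.0, []
    for path in sorted(glob.glob(os.path.join(directory, "*.wav"))):
        with wave.open(path, "rb") as wav:
            rate = wav.getframerate()
            total_seconds += wav.getnframes() / rate
        count += 1
        if rate != expected_rate:
            mismatched.append(path)
    return count, total_seconds, mismatched


count, seconds, mismatched = summarize_wavs("generated_samples")
print(f"{count} clips, {seconds:.1f} s of audio, {len(mismatched)} with an unexpected sample rate")
```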

In [None]:
# 🔊 Generate multiple wake word samples for training

from IPython.display import display, HTML
import ipywidgets as widgets
import glob

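# Create sliders for sample count and batch size.
# Re-running the cell keeps your previous slider positions instead of resetting them.
sample_count_slider = widgets.IntSlider(
    value=sample_count_slider.value if 'sample_count_slider' in globals() else 1000,
    min=100,
    max=5000,
    step=100,
    description='Sample Count:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

batch_size_slider = widgets.IntSlider(
    value=batch_size_slider.value if 'batch_size_slider' in globals() else 100,
    min=10,
    max=200,
    step=10,
    description='Batch Size:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)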

# Display the sliders
display(sample_count_slider)
display(batch_size_slider)

# Add a note about adjusting settings
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> Adjust the sliders above if needed, then run this cell to generate samples.</p>"
    "</div>"
))

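# Generate the samples
target_word = wake_word_input.value
sample_count = sample_count_slider.value
batch_size = batch_size_slider.value

# Remove clips from earlier runs (e.g., a previous wake word or a larger sample count)
# so the directory only contains samples for the current settings
import os
for old_clip in glob.glob('generated_samples/*.wav'):
    os.remove(old_clip)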

display(HTML(f"<p>Generating {sample_count} samples with batch size {batch_size}...</p>"))

!python3 piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples {sample_count} \
--batch-size {batch_size} \
--output-dir generated_samples

# Count the generated files
file_count = len(glob.glob('generated_samples/*.wav'))

# Display success message
display(HTML(f"<p style='color: green;'>✅ Generated {file_count} samples successfully!</p>"))

## 📥 Step 3: Download Negative Samples

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">About Negative Samples</h3>
    <p>To train a robust model, we need "negative" samples - audio that is NOT the wake word. These help the model learn what to ignore.</p>
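    <p>The pre-generated negative datasets are precomputed spectrogram features (stored as memory-mapped arrays), not raw audio. They include:</p>
    <ul>
        <li><b>Speech</b>: General speech samples</li>
        <li><b>Dinner Party</b>: Conversational audio with multiple speakers</li>
        <li><b>No Speech</b>: Environmental sounds without speech</li>
        <li><b>Dinner Party Eval</b>: Held-out conversational audio used only for validation and testing</li>
    </ul>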
</div>
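The training configuration in Step 6 reads four subdirectories of `negative_datasets`. Because the download cell skips everything whenever `negative_datasets` already exists, an interrupted download can leave some of them missing. A quick check like this one can catch that; it is a sketch, with the folder names taken from that configuration.

```python
import os

# Subdirectories referenced by the "features" entries of the training configuration
EXPECTED_NEGATIVE_DIRS = ["speech", "dinner_party", "dinner_party_eval", "no_speech"]


def missing_negative_datasets(root: str = "negative_datasets") -> list:
    """Return the expected dataset folders that are not present under root, in order."""
    return [name for name in EXPECTED_NEGATIVE_DIRS
            if not os.path.isdir(os.path.join(root, name))]


missing = missing_negative_datasets()
if missing:
    print("Missing:", ", ".join(missing), "- delete negative_datasets and re-run the download cell")
else:
    print("All negative datasets are present.")
```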

In [None]:
# 📥 Download pre-generated negative datasets

import os
from IPython.display import display, HTML

display(HTML("<p>Starting download of negative datasets...</p>"))

output_dir = './negative_datasets'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']
    
    for i, fname in enumerate(filenames):
        display(HTML(f"<p>Downloading {fname} ({i+1}/{len(filenames)})...</p>"))
        link = link_root + fname
        zip_path = f"negative_datasets/{fname}"
        !wget -O {zip_path} {link}
        
        display(HTML(f"<p>Extracting {fname}...</p>"))
        !unzip -q {zip_path} -d {output_dir}

    display(HTML("<p style='color: green;'>✅ All negative datasets downloaded and extracted successfully!</p>"))
else:
    display(HTML("<p style='color: blue;'>ℹ️ Negative datasets already exist. Skipping download.</p>"))

# Add a note about the datasets
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> The negative datasets are essential for training a model that doesn't trigger on non-wake word sounds.</p>"
    "</div>"
))

## 🔄 Step 4: Set Up Audio Augmentation

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">About Audio Augmentation</h3>
    <p>Audio augmentation helps create more robust models by applying various transformations to our samples:</p>
    <ul>
        <li><b>Background Noise</b>: Adds realistic background sounds</li>
        <li><b>Room Effects</b>: Simulates different acoustic environments</li>
        <li><b>Audio Effects</b>: Applies distortion, EQ, and other modifications</li>
    </ul>
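    <p>The first three controls below set the <b>probability</b> that each kind of augmentation is applied to a clip. The SNR range sets how loud the background noise is relative to the wake word: lower (or negative) values make the background louder. The augmenter reads room impulse responses from <code>mit_rirs</code> and background audio from <code>fma_16k</code> and <code>audioset_16k</code>; these directories must already exist in the working directory, because this notebook does not download them.</p>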
</div>
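To make the SNR range concrete: SNR in decibels is 10 * log10(P_signal / P_noise), where P is the mean power of each signal. The sketch below scales a noise clip so that adding it to a clean clip hits a target SNR. It uses NumPy and only illustrates the definition; it is not microWakeWord's augmentation code.

```python
import numpy as np


def mix_at_snr(signal: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:
    """Add noise to signal, scaled so that 10*log10(P_signal / P_noise) equals snr_db."""
    noise = noise[: len(signal)]  # assumes the noise clip is at least as long as the signal
    signal_power = np.mean(signal ** 2)
    noise_power = np.mean(noise ** 2)
    # Noise power required for the target SNR, converted into an amplitude scale factor
    target_noise_power = signal_power / (10 ** (snr_db / 10))
    scale = np.sqrt(target_noise_power / noise_power)
    return signal + scale * noise


rng = np.random.default_rng(0)
t = np.arange(16000) / 16000
tone = 0.5 * np.sin(2 * np.pi * 440 * t)       # stand-in for one second of a wake word clip
noise_clip = rng.normal(size=16000)             # stand-in for background noise
mixed = mix_at_snr(tone, noise_clip, snr_db=-5) # negative SNR: background louder than the "speech"
```

With the default range of -5 to 10 dB, some training clips have background noise louder than the wake word itself.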

In [None]:
# 🔄 Configure audio augmentation settings

import ipywidgets as widgets
from IPython.display import display, HTML, Audio

# Create sliders for augmentation parameters.
# Re-running the cell keeps your previous slider positions instead of resetting them.
background_prob_slider = widgets.FloatSlider(
    value=background_prob_slider.value if 'background_prob_slider' in globals() else 0.75,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Background Noise:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

rir_prob_slider = widgets.FloatSlider(
    value=rir_prob_slider.value if 'rir_prob_slider' in globals() else 0.5,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Room Effects:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

effects_prob_slider = widgets.FloatSlider(
    value=effects_prob_slider.value if 'effects_prob_slider' in globals() else 0.1,
    min=0.0,
    max=0.5,
    step=0.05,
    description='Audio Effects:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

snr_range_slider = widgets.IntRangeSlider(
    value=snr_range_slider.value if 'snr_range_slider' in globals() else [-5, 10],
    min=-20,
    max=20,
    step=1,
    description='SNR Range (dB):',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

# Display the sliders
display(HTML("<h4>Augmentation Probability and SNR Controls</h4>"))
display(background_prob_slider)
display(rir_prob_slider)
display(effects_prob_slider)
display(snr_range_slider)

# Add a note about adjusting settings
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> Adjust the sliders above if needed, then run this cell to set up augmentation.</p>"
    "</div>"
))

# Set up augmentation
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips

display(HTML("<p>Setting up audio augmentation...</p>"))

# Get values from sliders
bg_prob = background_prob_slider.value
rir_prob = rir_prob_slider.value
effect_prob = effects_prob_slider.value
min_snr, max_snr = snr_range_slider.value

# Set up clips and augmentation
clips = Clips(input_directory='generated_samples',
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
              )

augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities = {
                                "SevenBandParametricEQ": effect_prob,
                                "TanhDistortion": effect_prob,
                                "PitchShift": effect_prob,
                                "BandStopFilter": effect_prob,
                                "AddColorNoise": effect_prob,
                                "AddBackgroundNoise": bg_prob,
                                "Gain": 1.0,
                                "RIR": rir_prob,
                            },
                         impulse_paths = ['mit_rirs'],
                         background_paths = ['fma_16k', 'audioset_16k'],
                         background_min_snr_db = min_snr,
                         background_max_snr_db = max_snr,
                         min_jitter_s = 0.195,
                         max_jitter_s = 0.205,
                         )

display(HTML("<p style='color: green;'>✅ Audio augmentation configured successfully!</p>"))

# Generate a preview
display(HTML("<p>Generating augmented preview...</p>"))
from microwakeword.audio.audio_utils import save_clip

# Get a random clip and augment it
random_clip = clips.get_random_clip()
augmented_clip = augmenter.augment_clip(random_clip)
save_clip(augmented_clip, 'augmented_preview.wav')

display(HTML("<p style='color: green;'>✅ Preview generated! Listen below:</p>"))
display(Audio("augmented_preview.wav", autoplay=True))

# Add a note about augmentation
display(HTML(
    "<div style='background-color: #dff0d8; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Tip:</b> Experiment with different augmentation settings to improve model robustness. "
    "Higher values create more challenging training data but may make training more difficult.</p>"
    "</div>"
))

## 🔧 Step 5: Generate Augmented Features

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">About Feature Generation</h3>
    <p>In this step, we'll generate spectrograms from our augmented audio samples. These spectrograms will be used to train the neural network.</p>
    <p>We'll create three sets of data:</p>
    <ul>
        <li><b>Training</b>: Used to train the model</li>
        <li><b>Validation</b>: Used to evaluate the model during training</li>
        <li><b>Testing</b>: Used for final evaluation</li>
    </ul>
</div>
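The `Clips` object from Step 4 divides the generated clips into these three sets using `split_count=0.1` and `random_split_seed=10`. Below is a hypothetical, standard-library illustration of a seeded three-way split, assuming `split_count` is the fraction assigned to each of validation and testing. It is not the `Clips` implementation, whose exact rules may differ.

```python
import random


def seeded_split(items, fraction=0.1, seed=10):
    """Shuffle a copy of items with a fixed seed and carve off validation and test sets.

    Each of validation and test receives round(fraction * len(items)) items (an assumption
    for illustration); the remainder is used for training.
    """
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # same seed -> same split on every run
    n_holdout = round(fraction * len(shuffled))
    validation = shuffled[:n_holdout]
    test = shuffled[n_holdout:2 * n_holdout]
    train = shuffled[2 * n_holdout:]
    return train, validation, test


files = [f"{i}.wav" for i in range(1000)]
train, validation, test = seeded_split(files)
print(len(train), len(validation), len(test))
```

A fixed seed keeps the same clips out of training across re-runs, so validation and test results stay comparable between experiments.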

In [None]:
# 🔧 Generate augmented features for training, validation, and testing

import os
from IPython.display import display, HTML

from microwakeword.audio.spectrograms import SpectrogramGeneration
from mmap_ninja.ragged import RaggedMmap

display(HTML("<p>Starting feature generation...</p>"))

output_dir = 'generated_augmented_features'
os.makedirs(output_dir, exist_ok=True)

# Per-split settings: (folder name, split name used by Clips, repetitions, slide_frames)
# - repetitions: how many times each clip is augmented and converted
# - slide_frames=10 reuses each spectrogram shifted over by one frame at a time, simulating
#   streaming inference while training/validating in non-streaming mode
# - testing uses the streaming model, so no artificial sliding is needed (slide_frames=1)
split_settings = [
    ("training", "train", 2, 10),
    ("validation", "validation", 1, 10),
    ("testing", "test", 1, 1),
]

for i, (folder, split_name, repetition, slide_frames) in enumerate(split_settings):
    display(HTML(f"<p>Processing {folder} set ({i+1}/{len(split_settings)})...</p>"))

    out_dir = os.path.join(output_dir, folder)
    os.makedirs(out_dir, exist_ok=True)

    spectrograms = SpectrogramGeneration(clips=clips,
                                         augmenter=augmenter,
                                         slide_frames=slide_frames,
                                         step_ms=10,
                                         )

    display(HTML(f"<p>Generating spectrograms for {folder} set...</p>"))

    # Write the variable-length spectrograms to a ragged memory-mapped array on disk
    RaggedMmap.from_generator(
        out_dir=os.path.join(out_dir, 'wakeword_mmap'),
        sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),
        batch_size=100,
        verbose=True,
    )

display(HTML("<p style='color: green;'>✅ All features generated successfully!</p>"))

# Add a note about feature generation
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> Feature generation may take several minutes depending on your hardware. "
    "These spectrograms will be used to train the neural network model.</p>"
    "</div>"
))

## ⚙️ Step 6: Configure Training Parameters

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">About Training Configuration</h3>
    <p>The training configuration controls how the model learns from our data. Key parameters include:</p>
    <ul>
        <li><b>Training Steps</b>: How long to train the model</li>
        <li><b>Batch Size</b>: How many samples to process at once</li>
        <li><b>Class Weights</b>: How strongly mistakes on positive (wake word) and negative (non-wake word) examples are penalized</li>
        <li><b>Learning Rate</b>: How quickly the model adapts to the training data</li>
    </ul>
    <p>Adjust these parameters to find the optimal balance for your wake word.</p>
</div>
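To see why a larger negative class weight reduces false positives, consider a class-weighted binary cross-entropy, where each example's loss is multiplied by the weight of its true class. This is a generic sketch for intuition; microWakeWord's exact loss computation may differ.

```python
import math


def weighted_bce(probability: float, is_wake_word: bool,
                 positive_weight: float = 1.0, negative_weight: float = 20.0) -> float:
    """Binary cross-entropy for one prediction, scaled by the weight of its true class."""
    eps = 1e-7  # keep log() finite when probability is exactly 0 or 1
    p = min(max(probability, eps), 1 - eps)
    if is_wake_word:
        return -positive_weight * math.log(p)
    return -negative_weight * math.log(1 - p)


# An equally confident mistake costs 20x more on a negative clip than on a positive one
false_accept = weighted_bce(0.9, is_wake_word=False)  # model says "wake word" on background audio
false_reject = weighted_bce(0.1, is_wake_word=True)   # model misses the wake word
print(round(false_accept / false_reject, 2))
```

Raising the negative weight pushes the model toward caution on non-wake-word audio; lowering it makes the model more willing to fire on the wake word.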

In [None]:
# ⚙️ Configure training parameters

import yaml
import os
import ipywidgets as widgets
from IPython.display import display, HTML

# Create sliders for training parameters.
# Re-running the cell keeps your previous slider positions instead of resetting them.
training_steps_slider = widgets.IntSlider(
    value=training_steps_slider.value if 'training_steps_slider' in globals() else 10000,
    min=5000,
    max=30000,
    step=1000,
    description='Training Steps:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

# Named differently from Step 2's sample-generation batch size slider so the two don't overwrite each other
train_batch_size_slider = widgets.IntSlider(
    value=train_batch_size_slider.value if 'train_batch_size_slider' in globals() else 128,
    min=32,
    max=256,
    step=32,
    description='Batch Size:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

negative_weight_slider = widgets.IntSlider(
    value=negative_weight_slider.value if 'negative_weight_slider' in globals() else 20,
    min=5,
    max=50,
    step=5,
    description='Negative Class Weight:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

learning_rate_slider = widgets.FloatLogSlider(
    value=learning_rate_slider.value if 'learning_rate_slider' in globals() else 0.001,
    base=10,
    min=-4,  # 10^-4 = 0.0001
    max=-2,  # 10^-2 = 0.01
    step=0.1,
    description='Learning Rate:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

# Display the sliders
display(HTML("<h4>Training Parameters</h4>"))
display(training_steps_slider)
display(train_batch_size_slider)
display(negative_weight_slider)
display(learning_rate_slider)

# Add a note about adjusting settings
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> Adjust the sliders above if needed, then run this cell to create the training configuration.</p>"
    "</div>"
))

# Create configuration
display(HTML("<p>Creating training configuration...</p>"))

# Get values from sliders
training_steps = training_steps_slider.value
batch_size = train_batch_size_slider.value
negative_weight = negative_weight_slider.value
learning_rate = learning_rate_slider.value

# Create configuration dictionary
config = {}

config["window_step_ms"] = 10
config["train_dir"] = "trained_models/wakeword"

# Configure feature directories
config["features"] = [
    {
        "features_dir": "generated_augmented_features",
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    { # Only used for validation and testing
        "features_dir": "negative_datasets/dinner_party_eval",
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

# Training parameters
config["training_steps"] = [training_steps]
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [negative_weight]
config["learning_rates"] = [learning_rate]
config["batch_size"] = batch_size

# SpecAugment parameters (disabled by default)
config["time_mask_max_size"] = [0]
config["time_mask_count"] = [0]
config["freq_mask_max_size"] = [0]
config["freq_mask_count"] = [0]

# Evaluation parameters
config["eval_step_interval"] = 500
config["clip_duration_ms"] = 1500

# Model selection criteria
config["target_minimization"] = 0.9
config["minimization_metric"] = None
config["maximization_metric"] = "average_viable_recall"

# Save configuration to file
with open("training_parameters.yaml", "w") as file:
    yaml.dump(config, file)

display(HTML("<p style='color: green;'>✅ Training configuration created successfully!</p>"))

# Add a note about configuration
display(HTML(
    "<div style='background-color: #dff0d8; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Tip:</b> For most wake words, the default settings work well as a starting point. "
    "If your model doesn't perform well, try adjusting these parameters:</p>"
    "<ul>"
    "<li>Increase <b>Training Steps</b> for better accuracy (but longer training time)</li>"
    "<li>Increase <b>Negative Class Weight</b> to reduce false positives</li>"
    "<li>Decrease <b>Negative Class Weight</b> if the model rarely detects the wake word</li>"
    "</ul>"
    "</div>"
))

## 🚀 Step 7: Train the Model

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">About Model Training</h3>
    <p>Now we'll train the neural network model using the configuration we created. The training process:</p>
    <ul>
        <li>Feeds batches of spectrograms to the neural network</li>
        <li>Adjusts the model weights based on prediction errors</li>
        <li>Periodically evaluates the model on validation data</li>
        <li>Saves the best-performing model weights</li>
    </ul>
    <p>Training may take several minutes to hours depending on your hardware and the number of training steps.</p>
</div>
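The configuration stores `training_steps` and `learning_rates` as lists, which suggests training can run in several back-to-back phases, each with its own step count and learning rate; this notebook uses a single phase. Under that assumption, the hypothetical helper below (not part of microWakeWord) maps a global step to its phase's learning rate and shows the total step count a progress bar should use.

```python
from itertools import accumulate


def learning_rate_at(step, training_steps, learning_rates):
    """Return the learning rate for a global step, assuming phases run back to back."""
    if len(training_steps) != len(learning_rates):
        raise ValueError("training_steps and learning_rates must have the same length")
    # accumulate([10000, 5000]) yields the phase end points 10000 and 15000
    for phase_end, rate in zip(accumulate(training_steps), learning_rates):
        if step < phase_end:
            return rate
    raise ValueError(f"step {step} is past the last phase ({sum(training_steps)} steps)")


steps, rates = [10000, 5000], [0.001, 0.0001]
total_steps = sum(steps)  # denominator for a progress bar that spans all phases
print(total_steps, learning_rate_at(0, steps, rates), learning_rate_at(12000, steps, rates))
```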

In [None]:
# 🚀 Train the wake word model

import ipywidgets as widgets
from IPython.display import display, HTML
import subprocess
import os

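# Create model architecture selection dropdown.
# Re-running the cell keeps your previous selections instead of resetting them.
# Only 'mixednet' gets tuned hyperparameters below; the other options only work if the
# installed microWakeWord version provides a model with that name.
model_architecture = widgets.Dropdown(
    options=[
        ('MixedNet (Recommended)', 'mixednet'),
        ('MobileNet', 'mobilenet'),
        ('ResNet', 'resnet')
    ],
    value=model_architecture.value if 'model_architecture' in globals() else 'mixednet',
    description='Model Architecture:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

# Create model size selection dropdown (applies to MixedNet only)
model_size = widgets.Dropdown(
    options=[
        ('Small', 'small'),
        ('Medium (Recommended)', 'medium'),
        ('Large', 'large')
    ],
    value=model_size.value if 'model_size' in globals() else 'medium',
    description='Model Size:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)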

# Display the dropdowns
display(HTML("<h4>Model Architecture Settings</h4>"))
display(model_architecture)
display(model_size)

# Add a note about adjusting settings
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> Adjust the model architecture and size if needed, then run this cell to train the model.</p>"
    "</div>"
))

# Get selected architecture and size
arch = model_architecture.value
size = model_size.value

display(HTML(f"<p>Starting model training with {arch} architecture ({size} size)...</p>"))
display(HTML("<p>This may take a while. Please be patient.</p>"))

import re
import sys

# Arguments shared by every architecture. The command is built as a list and run without a shell,
# using the notebook's own Python interpreter; -u disables output buffering so logs stream promptly.
common_args = [
    sys.executable, "-u", "-m", "microwakeword.model_train_eval",
    "--training_config=training_parameters.yaml",
    "--train", "1",
    "--restore_checkpoint", "1",
    "--test_tf_nonstreaming", "0",
    "--test_tflite_nonstreaming", "0",
    "--test_tflite_nonstreaming_quantized", "0",
    "--test_tflite_streaming", "0",
    "--test_tflite_streaming_quantized", "1",
    "--use_weights", "best_weights",
]

if arch == 'mixednet':
    # Size presets: (pointwise filters, mixconv kernel sizes, first conv filters)
    mixednet_presets = {
        'small': ("48,48,48,48", "[5],[7,11],[9,15],[17]", 24),
        'medium': ("64,64,64,64", "[5],[7,11],[9,15],[23]", 32),
        'large': ("96,96,96,96", "[5],[7,11],[9,15],[23]", 48),
    }
    pointwise_filters, kernel_sizes, first_conv_filters = mixednet_presets[size]
    cmd = common_args + [
        "mixednet",
        "--pointwise_filters", pointwise_filters,
        "--repeat_in_block", "1,1,1,1",
        "--mixconv_kernel_sizes", kernel_sizes,
        "--residual_connection", "0,0,0,0",
        "--first_conv_filters", str(first_conv_filters),
        "--first_conv_kernel_size", "5",
        "--stride", "3",
    ]
else:
    # Other architectures run with their default hyperparameters; the size setting is ignored
    cmd = common_args + [arch]

# Create a progress bar
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=100,
    description='Training:',
    bar_style='info',
    orientation='horizontal'
)
display(progress_bar)

# Output area for the training log
output_text = widgets.Output()
display(output_text)

# Run training and stream its log (stderr is merged into stdout)
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

# Matches progress lines such as "Step #500" or "step 500"; adjust if your version logs differently
step_pattern = re.compile(r"\bstep\s*#?\s*(\d+)", re.IGNORECASE)
total_steps = training_steps_slider.value

with output_text:
    for line in process.stdout:
        print(line.rstrip())
        match = step_pattern.search(line)
        if match:
            step = int(match.group(1))
            progress_bar.value = min(100, step * 100 // total_steps)

# Wait for process to complete
process.wait()

if process.returncode == 0:
    display(HTML("<p style='color: green;'>✅ Model training completed successfully!</p>"))
else:
    display(HTML("<p style='color: red;'>❌ Model training failed. Check the output for errors.</p>"))

# Add a note about training
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> Training can take a long time, especially with many training steps. "
    "The process may appear stuck for several minutes between updates. "
    "This is normal - the training is still running in the background.</p>"
    "</div>"
))

## 📤 Step 8: Export the Model

<div style="background-color: #e8f4f8; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #2980b9;">About Model Export</h3>
    <p>The final step is to export the trained model for use on devices. Training already converted the model to a quantized, streaming TensorFlow Lite model, which is smaller and faster for on-device inference; this step copies that file into an export folder.</p>
    <p>The cell below also writes a model manifest (<code>manifest.json</code>) that ESPHome reads. The manifest contains metadata about the model and its detection parameters.</p>
</div>
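On the device, the detection threshold is not applied to a single inference. As we understand ESPHome's micro_wake_word component, it averages the model's most recent output probabilities over a sliding window and triggers when that average reaches the cutoff. The sketch below illustrates the idea in plain Python; the window size of 5 and the exact triggering rule are assumptions for illustration, not ESPHome's source, and it ignores any cooldown after a detection.

```python
from collections import deque


def detect(probabilities, cutoff=0.5, window_size=5):
    """Yield the indices at which the mean of the last window_size probabilities reaches cutoff."""
    window = deque(maxlen=window_size)  # keeps only the most recent probabilities
    for index, probability in enumerate(probabilities):
        window.append(probability)
        if len(window) == window_size and sum(window) / window_size >= cutoff:
            yield index


# A single noisy spike is ignored; a sustained run of high probabilities triggers a detection.
stream = [0.0, 0.95, 0.0, 0.0, 0.0, 0.0, 0.7, 0.8, 0.9, 0.9, 0.9, 0.1]
print(list(detect(stream, cutoff=0.8)))
```

A larger window suppresses brief spikes (fewer false positives) at the cost of slightly slower detection.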

In [None]:
# 📤 Export the trained model

import ipywidgets as widgets
from IPython.display import display, HTML
import json
import os
import shutil

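# Create a slider for detection threshold.
# Re-running the cell keeps your previous setting instead of resetting it.
threshold_slider = widgets.FloatSlider(
    value=threshold_slider.value if 'threshold_slider' in globals() else 0.5,
    min=0.1,
    max=0.9,
    step=0.05,
    description='Detection Threshold:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)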

# Display the slider
display(HTML("<h4>Model Export Settings</h4>"))
display(threshold_slider)

# Add a note about adjusting settings
display(HTML(
    "<div style='background-color: #fcf8e3; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<p><b>Note:</b> Adjust the detection threshold if needed, then run this cell to export the model.</p>"
    "</div>"
))

display(HTML("<p>Exporting model...</p>"))

# Get the wake word and threshold
wake_word = wake_word_input.value
threshold = threshold_slider.value

# Path to the trained model
model_path = "trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"

# Check if the model exists
if not os.path.exists(model_path):
    display(HTML("<p style='color: red;'>❌ Model file not found. Make sure training completed successfully.</p>"))
else:
    # Create export directory
    export_dir = f"exported_model_{wake_word}"
    if not os.path.exists(export_dir):
        os.makedirs(export_dir)
    
    # Copy the model file
    export_model_path = os.path.join(export_dir, f"{wake_word}.tflite")
    shutil.copy(model_path, export_model_path)
    
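    # Create the manifest in ESPHome's micro_wake_word v2 format
    # (compare with the v2 example manifests linked at the top of this notebook)
    manifest = {
        "type": "micro",
        "wake_word": wake_word.replace("_", " "),
        "author": "your-name",  # placeholder
        "website": "https://github.com/kahrendt/microWakeWord",
        "model": os.path.basename(export_model_path),
        "trained_languages": ["en"],
        "version": 2,
        "micro": {
            "probability_cutoff": threshold,
            "sliding_window_size": 5,
            "feature_step_size": 10,  # matches window_step_ms in the training configuration
            # Estimate only: increase it if the device reports a tensor arena allocation error
            "tensor_arena_size": 30000,
            "minimum_esphome_version": "2024.7.0",
        },
    }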
    
    # Save manifest file
    manifest_path = os.path.join(export_dir, "manifest.json")
    with open(manifest_path, 'w') as f:
        json.dump(manifest, f, indent=2)
    
    # Create an example ESPHome configuration.
    # micro_wake_word also needs an ESP32 board and a configured microphone; see the ESPHome docs.
    esphome_config = f"""
# Wake word configuration
micro_wake_word:
  models:
    - model: manifest.json  # path to the manifest, relative to your ESPHome config directory
      probability_cutoff: {threshold}
  on_wake_word_detected:
    - logger.log: "Wake word detected!"
    # Add your actions here
"""
    
    # Save ESPHome config example
    config_path = os.path.join(export_dir, "esphome_example.yaml")
    with open(config_path, 'w') as f:
        f.write(esphome_config)
    
    # Render the summary as one HTML block so the list markup stays well-formed
    display(HTML(
        f"<p style='color: green;'>✅ Model exported successfully to {export_dir}!</p>"
        "<p>Files created:</p>"
        "<ul>"
        f"<li><b>{os.path.basename(export_model_path)}</b> - The TFLite model file</li>"
        "<li><b>manifest.json</b> - Model metadata for ESPHome</li>"
        "<li><b>esphome_example.yaml</b> - Example ESPHome configuration</li>"
        "</ul>"
        f"<p>Files are saved in the <code>{export_dir}</code> directory.</p>"
    ))

# Add a note about using the model
display(HTML(
    "<div style='background-color: #dff0d8; padding: 10px; border-radius: 5px; margin-top: 10px;'>"
    "<h4 style='margin-top: 0;'>Using Your Model with ESPHome</h4>"
    "<p>To use your trained model with ESPHome:</p>"
    "<ol>"
    "<li>Copy the .tflite file and manifest.json to your ESPHome configuration directory</li>"
    "<li>Add the configuration from esphome_example.yaml to your device's YAML file</li>"
    "<li>Adjust the detection threshold if needed (higher = fewer false positives, but may miss some activations)</li>"
    "<li>Flash your device with the updated configuration</li>"
    "</ol>"
    "<p>For more information, see the <a href='https://esphome.io/components/micro_wake_word' target='_blank'>ESPHome documentation</a>.</p>"
    "</div>"
))

## 🎉 Congratulations!

<div style="background-color: #f0f7fb; padding: 15px; border-radius: 10px; border-left: 5px solid #3498db; margin-bottom: 20px;">
    <h3 style="margin-top: 0; color: #3498db;">You've Successfully Trained a Custom Wake Word Model!</h3>
    <p>You've completed all the steps to train and export a custom wake word model using microWakeWord. Here's what you've accomplished:</p>
    <ul>
        <li>Generated synthetic wake word samples</li>
        <li>Applied audio augmentation to improve robustness</li>
        <li>Configured and trained a neural network model</li>
        <li>Exported the model for use on devices</li>
    </ul>
    <p>If your model doesn't perform as expected, try experimenting with different settings:</p>
    <ul>
        <li>Try different phonetic spellings of your wake word</li>
        <li>Adjust augmentation parameters</li>
        <li>Increase training steps</li>
        <li>Modify class weights</li>
        <li>Try different model architectures</li>
    </ul>
    <p>Happy wake word detecting!</p>
</div>