# Import Section
---

In [1]:
"""
This script is used for testing a keyword spotting model using TensorFlow
with ipywidgets UI.
"""

import os
import argparse
from collections import OrderedDict
import numpy as np
import tensorflow as tf

import ipywidgets as widgets
from ipywidgets import Layout, Box, Dropdown, Label, IntSlider
from IPython.display import display, HTML, clear_output

from kws_python import data
from kws_python import models

In [2]:
try:
    # Colab support: when google.colab imports successfully, mount Google Drive
    # and switch into the training folder; otherwise fall through to the
    # ImportError branch and just report where the notebook is running.
    from google.colab import drive

    print("origin is:", os.getcwd(), sep="\n")
    drive.mount("/content/drive")

    os.chdir(r"/content/drive/MyDrive/tflu-kws-cortex-m/Training")
    print("update to:", os.getcwd(), sep="\n")

except ImportError:
    print(r"Running Location:", os.path.abspath(os.getcwd()), sep="\n")

Running Location:
c:\CYCHEN38\OpenNuvoton\ML_KWS\ML_kws_tflu


# Test Section
---

In [3]:
def test(flags):
    """Report accuracy and a confusion matrix for the validation and test sets.

    The model is built from the parsed flags and its weights are restored from
    ``flags.checkpoint``; both evaluations are printed, nothing is returned.

    Args:
        flags: Parsed argument namespace providing the model, audio-feature and
            dataset settings read below (e.g. ``wanted_words``, ``sample_rate``,
            ``batch_size``, ``checkpoint``).
    """
    label_count = len(data.prepare_words_list(flags.wanted_words.split(",")))
    model_settings = models.prepare_model_settings(
        label_count,
        flags.sample_rate,
        flags.clip_duration_ms,
        flags.window_size_ms,
        flags.window_stride_ms,
        flags.dct_coefficient_count,
    )

    model = models.create_model(model_settings, flags.model_architecture, flags.model_size_info, False)

    audio_processor = data.AudioProcessor(
        data_exist=flags.data_exist,
        data_url=flags.data_url,
        data_dir=flags.data_dir,
        silence_percentage=flags.silence_percentage,
        unknown_percentage=flags.unknown_percentage,
        wanted_words=flags.wanted_words.split(","),
        validation_percentage=flags.validation_percentage,
        testing_percentage=flags.testing_percentage,
        model_settings=model_settings,
    )
    print(flags.checkpoint)
    model.load_weights(flags.checkpoint).expect_partial()

    def report_split(mode, banner, label):
        """Evaluate one dataset split and print its confusion matrix and accuracy."""
        print(banner)
        batches = audio_processor.get_data(mode).batch(flags.batch_size)
        # Ground-truth labels are gathered by iterating the batched dataset once
        # before the model runs its own pass over it.
        expected = np.concatenate([labels for _, labels in batches])

        scores = model.predict(batches)
        predicted = tf.argmax(scores, axis=1)

        accuracy = calculate_accuracy(predicted, expected)
        matrix = tf.math.confusion_matrix(expected, predicted, num_classes=model_settings["label_count"])
        print(matrix.numpy())
        print(f"{label} accuracy = {accuracy * 100:.2f}%" f"(N={audio_processor.set_size(mode)})")

    # Evaluate on validation set.
    report_split(audio_processor.Modes.VALIDATION, "Running testing on validation set...", "Validation")

    # Evaluate on testing set.
    report_split(audio_processor.Modes.TESTING, "Running testing on test set...", "Test")


def calculate_accuracy(predicted_indices, expected_indices):
    """Compute the fraction of predictions that match the expected labels.

    Args:
        predicted_indices: Predicted integer class indices.
        expected_indices: Expected integer class indices.

    Returns:
        The mean of the element-wise equality of the two inputs, cast to
        float32 (a value between 0 and 1).
    """
    matches = tf.cast(tf.equal(predicted_indices, expected_indices), tf.float32)
    return tf.reduce_mean(matches)

# Argument Setting
---

In [4]:
def str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    interprets the usual spellings explicitly instead.

    Args:
        value: A bool, or a string such as "True", "false", "1", "0", "yes", "no".

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # The parser is kept at module (notebook) scope because
    # InitTestWidgets.run_test() looks it up as a global.
    parser = argparse.ArgumentParser()
    # str_to_bool (not bool) so that "--data_exist False" really disables the flag.
    parser.add_argument("--data_exist", type=str_to_bool, default=True, help="True will skip download and tar.")
    parser.add_argument("--data_url", type=str, default="http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz", help="Location of speech training data archive on the web.")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="tmp/speech_dataset/",
        help="Where to download the speech training data to.",
    )
    parser.add_argument(
        "--silence_percentage",
        type=float,
        default=10.0,
        help="How much of the data should be silence.",
    )
    parser.add_argument(
        "--unknown_percentage",
        type=float,
        default=10.0,
        help="How much of the data should be unknown words.",
    )
    parser.add_argument("--testing_percentage", type=int, default=10, help="What percentage of wavs to use as a test set.")
    parser.add_argument("--validation_percentage", type=int, default=10, help="What percentage of wavs to use as a validation set.")
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="Expected sample rate of the wavs",
    )
    parser.add_argument(
        "--clip_duration_ms",
        type=int,
        default=1000,
        help="Expected duration in milliseconds of the wavs",
    )
    parser.add_argument(
        "--window_size_ms",
        type=float,
        default=30.0,
        help="How long each spectrogram timeslice is",
    )
    parser.add_argument(
        "--window_stride_ms",
        type=float,
        default=10.0,
        help="How far to move in time between spectrogram timeslices",
    )
    parser.add_argument(
        "--dct_coefficient_count",
        type=int,
        default=40,
        help="How many bins to use for the MFCC fingerprint",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=100,
        help="How many items to evaluate at once",
    )
    parser.add_argument(
        "--wanted_words",
        type=str,
        default="yes,no,up,down,left,right,on,off,stop,go",
        help="Words to use (others will be added to an unknown label)",
    )
    parser.add_argument("--checkpoint", type=str, help="Checkpoint to load the weights from.")
    parser.add_argument("--model_architecture", type=str, default="dnn", help="What model architecture to use")
    parser.add_argument("--model_size_info", type=int, nargs="+", default=[128, 128, 128], help="Model dimensions - different for various models")

# Widgets Control Section
---

In [5]:
class InitTestWidgets:
    """
    Interactive ipywidgets front end for configuring and running a model test.

    The widgets are grouped into three accordion pages (follow/checkpoint
    settings, test settings, data settings) plus two buttons. Saving writes the
    chosen arguments to 'test_cmd.txt'; running reads that file back, parses it
    with the notebook-level ``parser`` and calls ``test()``.

    Methods:
    --------
    create_command(cm_list):
        Creates a command for testing based on the provided parameters.
    show_main():
        Displays the main interactive section with all widgets.
    run_test():
        Executes the test based on the saved command.
    """

    def __init__(self):  # initialize the widget elements

        form_item_layout = Layout(
            display="flex",
            flex_flow="row",
            justify_content="space-between",
        )

        # follow parameters widgets
        # a_ch: reuse the settings recorded by the training notebook; b_ch: checkpoint path
        self.a_ch = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.b_ch = widgets.Text(value="work/DS_CNN/1/training/best/ds_cnn_0.933_ckpt", placeholder="Type something", disabled=False)
        form_follow_items = [
            Box([Label(value="Follow the train process setting(recommend)"), self.a_ch], layout=form_item_layout),
            Box([Label(value="Model location"), self.b_ch], layout=form_item_layout),
        ]
        self.form_box_follow_para = Box(
            form_follow_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                border="solid 3px lightblue",
                align_items="stretch",
                width="50%",
            ),
        )

        # test model parameters widgets (only used when "follow" is unchecked)
        self.a_ta = Dropdown(options=["dnn", "cnn", "ds_cnn", "basic_lstm"])
        self.b_ta = widgets.BoundedIntText(value=10, min=0, max=50.0, step=1, disabled=False)
        self.c_ta = widgets.BoundedIntText(value=10, min=0, max=50.0, step=1, disabled=False)
        self.g_ta = widgets.IntSlider(value=100, min=50, max=1000, step=50)
        self.h_ta = widgets.Text(value="128,128,128", placeholder="Type something", description="Int:", disabled=False)
        self.i_ta = widgets.Textarea(value="yes,no,up,down,left,right,on,off,stop,go", placeholder="Type something", description="String:", disabled=False)

        form_train_items = [
            Box([Label(value="Model Architecture"), self.a_ta], layout=form_item_layout),
            Box([Label(value="Testing percentage"), self.b_ta], layout=form_item_layout),
            Box([Label(value="Validation percentage"), self.c_ta], layout=form_item_layout),
            Box([Label(value="Batch size"), self.g_ta], layout=form_item_layout),
            Box([Label(value="Model size (dimension)"), self.h_ta], layout=form_item_layout),
            Box([Label(value="Wanted words"), self.i_ta], layout=form_item_layout),
        ]

        self.form_box_train_para = Box(
            form_train_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                border="solid 3px lightblue",
                align_items="stretch",
                width="50%",
            ),
        )

        # data parameters widgets (only used when "follow" is unchecked)
        self.a_da = IntSlider(value=10, min=10, max=50)
        self.b_da = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.c_da = widgets.FloatSlider(value=0.1, min=0.0, max=1.0)
        self.d_da = widgets.FloatSlider(value=0.8, min=0.0, max=1.0)
        self.e_da = widgets.FloatSlider(value=10.0, min=0.0, max=30.0)
        self.f_da = widgets.FloatSlider(value=10.0, min=0.0, max=30.0)
        self.g_da = widgets.FloatSlider(value=100.0, min=50.0, max=200.0, step=10.0)
        self.h_da = widgets.IntSlider(value=16000, min=16000, max=32000, step=16000)
        self.i_da = widgets.IntSlider(value=1000, min=800, max=3000, step=200)
        self.j_da = widgets.IntSlider(value=40, min=10, max=100, step=10)
        self.k_da = widgets.IntSlider(value=40, min=10, max=100, step=10)

        form_data_items = [
            Box([Label(value="DCT coefficient count"), self.a_da], layout=form_item_layout),
            Box([Label(value="Data exist"), self.b_da], layout=form_item_layout),
            Box([Label(value="Background volume"), self.c_da], layout=form_item_layout),
            Box([Label(value="Background frequency"), self.d_da], layout=form_item_layout),
            Box([Label(value="Silence percentage"), self.e_da], layout=form_item_layout),
            Box([Label(value="Unknown percentage"), self.f_da], layout=form_item_layout),
            Box([Label(value="Time shift (ms)"), self.g_da], layout=form_item_layout),
            Box([Label(value="Sample rate"), self.h_da], layout=form_item_layout),
            Box([Label(value="Clip duration (ms)"), self.i_da], layout=form_item_layout),
            Box([Label(value="Window size (ms)"), self.j_da], layout=form_item_layout),
            Box([Label(value="Window stride (ms)"), self.k_da], layout=form_item_layout),
        ]

        self.form_box_data_para = Box(
            form_data_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                border="solid 3px lightblue",
                align_items="stretch",
                width="50%",
            ),
        )

        # button widgets: a_bu saves the settings, b_bu runs the test
        self.a_bu = widgets.Button(description="Save Test Setting", layout=Layout(width="30%", height="30px"), button_style="success")
        self.b_bu = widgets.Button(description="Start to Run", layout=Layout(width="30%", height="30px"), button_style="success")

        form_button_items = [self.a_bu, self.b_bu]

        self.form_button = Box(
            form_button_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                # border='solid 3px lightblue',
                align_items="stretch",
                width="50%",
            ),
        )

    def create_command(self, cm_list):
        """
        Creates a command for testing a machine learning model based on the provided list of parameters.
        Args:
            cm_list (list): A list of parameters for the command. The first element is a boolean indicating whether to use
                            the training process settings. The second element is the checkpoint path; the remaining
                            elements are the widget values, in the same order as the argument names listed below.
        Returns:
            int: Returns 0 upon successful creation of the command.
        Raises:
            FileNotFoundError: If the training settings are requested but 'train_cmd.txt' does not exist.
        The function performs the following steps:
        1. Initializes a list of argument names.
        2. Creates an ordered dictionary to store the command arguments and their values.
        3. If the first element of cm_list is True, reads the training command from 'train_cmd.txt' and uses its settings.
           - Saves the checkpoint value from cm_list.
           - Reads the training command line from 'train_cmd.txt'.
           - Parses the training command line and extracts the values for the required arguments.
        4. If the first element of cm_list is False, uses the values from cm_list directly.
        5. Writes the complete command to 'test_cmd.txt'.
        """
        argument_list = [
            "--checkpoint",
            "--model_architecture",
            "--testing_percentage",
            "--validation_percentage",
            # Must be spelled with two dashes: the parser only knows "--batch_size",
            # and the training command file uses the same spelling.
            "--batch_size",
            "--model_size_info",
            "--wanted_words",
            "--dct_coefficient_count",
            "--data_exist",
            "--background_volume",
            "--background_frequency",
            "--silence_percentage",
            "--unknown_percentage",
            "--time_shift_ms",
            "--sample_rate",
            "--clip_duration_ms",
            "--window_size_ms",
            "--window_stride_ms",
        ]

        if cm_list[0] is True:  # directly use train process setting
            cm_dict = OrderedDict()
            cm_dict[argument_list[0]] = cm_list[1]  # save the checkpoint

            with open("train_cmd.txt", "r", encoding="utf-8") as f:
                train_cmd_list = f.read().split()
            if train_cmd_list:
                print("read the exist train_cmd.txt")

            for idx, val in enumerate(train_cmd_list):
                if val not in argument_list:  # keep only the needed attrs
                    continue

                if val == "--model_size_info":
                    # This flag takes several values: collect tokens until the
                    # next flag or the end of the command (no index overrun).
                    m_list = []
                    for token in train_cmd_list[idx + 1:]:
                        if token.startswith("--"):
                            break
                        m_list.append(token)
                    cm_dict[val] = m_list
                elif idx + 1 < len(train_cmd_list):
                    cm_dict[val] = train_cmd_list[idx + 1]
                # A trailing flag without a value is skipped so the parser
                # default applies.
        else:
            cm_dict = self.__get_ipyw_vals(cm_list, argument_list)

        self.__write_test_cmd_txt(cm_dict)

        return 0

    def __get_ipyw_vals(self, cm_list, argument_list):
        """
        Pairs the widget values with their command-line argument names.
        Args:
            cm_list (list): A list of values where the first element is ignored, and the rest are processed.
            argument_list (list): A list of command-line arguments corresponding to the values in cm_list.
        Returns:
            OrderedDict: Argument name -> value; the model size string is split on ',' into a list.
        """
        cm_dict = OrderedDict()
        for arg_name, val in zip(argument_list, cm_list[1:]):
            if arg_name == "--model_size_info":  # transfer from single string to list format
                cm_dict[arg_name] = val.split(",")
            else:
                cm_dict[arg_name] = val

        return cm_dict

    def __write_test_cmd_txt(self, cm_dict):
        """
        Writes the complete test command to the 'test_cmd.txt' file.
        Args:
            cm_dict (OrderedDict): An ordered dictionary containing the command-line arguments and their values.
        """
        with open("test_cmd.txt", "w", encoding="utf-8") as f:  # save the complete command for the test run
            for key, value in cm_dict.items():

                if isinstance(value, list):
                    # List values are written as space-separated tokens after the flag.
                    f.write(f"{key} ")
                    for val in value:
                        f.write(f"{val} ")
                else:
                    f.write(f"{key} {value} ")

    def show_main(self):  # interactive section
        """
        Displays the main interactive section for parameter selection and testing.
        This method creates and displays an interactive UI for selecting parameters
        for testing or using default values. It includes two buttons for saving the
        test settings and running the test.
        """

        intro_text = "Please Choose the parameters of the testing or using the default"
        html_widget = widgets.HTML(value=f"<b><font color='lightgreen'><font size=6>{intro_text}</b>")
        display(html_widget)

        # Create an accordion holding the three parameter boxes
        accordion = widgets.Accordion(children=[self.form_box_follow_para, self.form_box_train_para, self.form_box_data_para]).add_class("parentstyle")

        # Add a custom style tag to the notebook, you can use dev tool to inspect the class names
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        accordion.set_title(0, "Follow Setting")
        accordion.set_title(1, "Test Setting")
        accordion.set_title(2, "Data Setting")

        output_widgets = widgets.Output(layout=Layout(border="1px solid green"))

        def act_para(*, follow, model_loc, model, test_per, vali_per, batch, dims, outputs, dct_coe, data_b_da, b_vol, b_freq, silence, unk, t_sft, rate, dura, win_size, win_str):
            # Only `follow` is used; the other arguments exist so that a change
            # to any widget triggers this callback.

            # If any value is changed, clear the output area
            with output_widgets:
                output_widgets.clear_output()

            # The manual test/data settings are irrelevant while following the
            # training settings, so hide them in that case.
            if follow:
                self.form_box_train_para.layout.visibility = "hidden"
                self.form_box_data_para.layout.visibility = "hidden"
            else:
                self.form_box_train_para.layout.visibility = "visible"
                self.form_box_data_para.layout.visibility = "visible"

        # ------------------#
        # widgets.Accordion's interactive input with action function `act_para()`
        # ------------------#
        out_inter = widgets.interactive_output(
            act_para,
            {
                "follow": self.a_ch,
                "model_loc": self.b_ch,
                "model": self.a_ta,
                "test_per": self.b_ta,
                "vali_per": self.c_ta,
                "batch": self.g_ta,
                "dims": self.h_ta,
                "outputs": self.i_ta,
                "dct_coe": self.a_da,
                "data_b_da": self.b_da,
                "b_vol": self.c_da,
                "b_freq": self.d_da,
                "silence": self.e_da,
                "unk": self.f_da,
                "t_sft": self.g_da,
                "rate": self.h_da,
                "dura": self.i_da,
                "win_size": self.j_da,
                "win_str": self.k_da,
            },
        )
        display(accordion, self.form_button, out_inter)
        display(output_widgets)

        def on_button_clicked_save_test_set(b):
            with output_widgets:
                clear_output()
                # Order must match the argument names listed in create_command().
                self.create_command(
                    [
                        self.a_ch.value,
                        self.b_ch.value,
                        self.a_ta.value,
                        self.b_ta.value,
                        self.c_ta.value,
                        self.g_ta.value,
                        self.h_ta.value,
                        self.i_ta.value,
                        self.a_da.value,
                        self.b_da.value,
                        self.c_da.value,
                        self.d_da.value,
                        self.e_da.value,
                        self.f_da.value,
                        self.g_da.value,
                        self.h_da.value,
                        self.i_da.value,
                        self.j_da.value,
                        self.k_da.value,
                    ]
                )
                text0 = "The testing setting is finish and saved"
                html0 = widgets.HTML(value=f"<b><font color='lightblue'><font size=2>{text0}</b>")
                display(html0)
        self.a_bu.on_click(on_button_clicked_save_test_set)

        def on_button_clicked_run_test(b):
            with output_widgets:
                clear_output()
                # Report completion only when the test was actually executed.
                if self.run_test():
                    print("Finish")
        self.b_bu.on_click(on_button_clicked_run_test)

    def run_test(self):
        """
        Executes the test commands specified in the "test_cmd.txt" file.
        This method reads the test commands from the "test_cmd.txt" file, parses them
        with the notebook-level ``parser`` and then runs ``test()`` with the parsed arguments.
        Returns:
            bool: True if the test was run, False if 'test_cmd.txt' is missing or empty.
        """
        try:
            with open("test_cmd.txt", "r", encoding="utf-8") as f:  # load the saved test command
                cmd_list = f.read().split()
        except FileNotFoundError:
            print("The test_cmd.txt is not found! Please save the test setting first.")
            return False

        if not cmd_list:
            # Without any arguments there is no checkpoint to load, so stop here.
            print("The test_cmd.txt is empty!")
            return False

        print("read the test commands!")

        # parse_known_args: the command may contain training-only flags
        # (e.g. --background_volume) that this parser does not define.
        parser_flags, _ = parser.parse_known_args(args=cmd_list)
        test(parser_flags)
        return True

# Run Section
---
- A detailed description of all the parameters is available [here](#id-PDD).
- `Follow the train process setting`: When checked, the test reuses the settings that were saved when the same model was trained (recommended).
- `Model Location`: Fill in the location of the trained model, i.e. the `*_ckpt` file, for example: work/DNN/DNN2/training/dnn_0.826_ckpt


In [6]:
# Build the widget UI and display it; the buttons it shows save the settings and run the test.
act = InitTestWidgets()
act.show_main()

HTML(value="<b><font color='lightgreen'><font size=6>Please Choose the parameters of the testing or using the …

Accordion(children=(Box(children=(Box(children=(Label(value='Follow the train process setting(recommend)'), Ch…

Box(children=(Button(button_style='success', description='Save Test Setting', layout=Layout(height='30px', wid…

Output()

Output(layout=Layout(border_bottom='1px solid green', border_left='1px solid green', border_right='1px solid g…

<a id="id-PDD"></a>
# Parameter Description
---
- This notebook is based on https://github.com/ARM-software/ML-examples/tree/main/tflu-kws-cortex-m.
- The parameter descriptions are the same as for training; please see `train.ipynb`.