# Import Section
---

In [1]:
import os

import argparse
import numpy as np
import tensorflow as tf

In [2]:
# Working directory used when this notebook runs on Google Colab.
COLAB_WORKDIR = r'/content/drive/MyDrive/tflu-kws-cortex-m/Training'

try:
    # Guard only the Colab-specific import, so that errors raised while
    # mounting Drive or changing directory are not mistaken for "not on Colab".
    from google.colab import drive
except ImportError:
    # Running locally: stay in the current directory, which is expected to
    # contain the project's `data` and `models` modules imported next.
    print(r'Running Location:')
    print(os.path.abspath(os.getcwd()))
else:
    print('origin is:')
    print(os.getcwd())
    drive.mount('/content/drive')
    os.chdir(COLAB_WORKDIR)
    print('update to:')
    print(os.getcwd())
import data
import models


Running Location:
C:\Users\USERNAME\MICRO_ML\ML_kws_tflu


# Test Section
---

In [3]:
def _evaluate_split(model, audio_processor, mode, batch_size, label_count, split_name):
    """Evaluate ``model`` on one dataset split and print the results.

    Prints the confusion matrix and the accuracy (with the split size) for the
    split selected by ``mode``.

    Args:
        model: Keras model whose ``predict`` output is reduced with argmax over axis 1.
        audio_processor: ``data.AudioProcessor`` providing ``get_data`` and ``set_size``.
        mode: One of ``audio_processor.Modes`` (e.g. VALIDATION or TESTING).
        batch_size: Batch size used when batching the split.
        label_count: Number of classes for the confusion matrix.
        split_name: Human-readable split name used in the printed summary.

    Returns:
        The accuracy value returned by ``calculate_accuracy``.
    """
    split_data = audio_processor.get_data(mode).batch(batch_size)
    # Labels and predictions come from two separate passes over the same
    # dataset. This relies on get_data() yielding a stable order for this
    # split. NOTE(review): confirm it does not shuffle.
    expected_indices = np.concatenate([y for _, y in split_data])

    predictions = model.predict(split_data)
    predicted_indices = tf.argmax(predictions, axis=1)

    accuracy = calculate_accuracy(predicted_indices, expected_indices)
    confusion_matrix = tf.math.confusion_matrix(expected_indices, predicted_indices,
                                                num_classes=label_count)
    print(confusion_matrix.numpy())
    print(f'{split_name} accuracy = {accuracy * 100:.2f}% '
          f'(N={audio_processor.set_size(mode)})')
    return accuracy


def test(FLAGS):
    """Calculate accuracy and confusion matrices on validation and test sets.

    The model is created and its weights are loaded from the supplied
    command-line arguments.

    Args:
        FLAGS: Parsed ``argparse`` namespace. It must provide the model, data
            and checkpoint options defined in the argument-setting cell.
    """
    model_settings = models.prepare_model_settings(len(data.prepare_words_list(FLAGS.wanted_words.split(','))),
                                                   FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
                                                   FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    model = models.create_model(model_settings, FLAGS.model_architecture, FLAGS.model_size_info, False)

    audio_processor = data.AudioProcessor(data_exist=FLAGS.data_exist,
                                          data_url=FLAGS.data_url,
                                          data_dir=FLAGS.data_dir,
                                          silence_percentage=FLAGS.silence_percentage,
                                          unknown_percentage=FLAGS.unknown_percentage,
                                          wanted_words=FLAGS.wanted_words.split(','),
                                          validation_percentage=FLAGS.validation_percentage,
                                          testing_percentage=FLAGS.testing_percentage,
                                          model_settings=model_settings)
    print(FLAGS.checkpoint)
    # expect_partial() silences warnings about checkpoint values (e.g. optimizer
    # slots) that are not needed for inference.
    model.load_weights(FLAGS.checkpoint).expect_partial()

    label_count = model_settings['label_count']

    print("Running testing on validation set...")
    _evaluate_split(model, audio_processor, audio_processor.Modes.VALIDATION,
                    FLAGS.batch_size, label_count, 'Validation')

    print("Running testing on test set...")
    _evaluate_split(model, audio_processor, audio_processor.Modes.TESTING,
                    FLAGS.batch_size, label_count, 'Test')
    
def calculate_accuracy(predicted_indices, expected_indices):
    """Return the fraction of positions where the prediction equals the label.

    Args:
        predicted_indices: Predicted integer class indices.
        expected_indices: Ground-truth integer class indices of the same length.

    Returns:
        A scalar float32 tensor between 0 and 1.
    """
    hits = tf.cast(tf.equal(predicted_indices, expected_indices), tf.float32)
    return tf.reduce_mean(hits)

# Argument Setting
---

In [4]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_exist',
        type=bool,
        default=True,
        help='True will skip download and tar.')
    parser.add_argument(
        '--data_url',
        type=str,
        default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
        help='Location of speech training data archive on the web.')
    parser.add_argument(
        '--data_dir',
        type=str,
        default='tmp/speech_dataset/',
        help="""\
        Where to download the speech training data to.
        """)
    parser.add_argument(
        '--silence_percentage',
        type=float,
        default=10.0,
        help="""\
        How much of the training data should be silence.
        """)
    parser.add_argument(
        '--unknown_percentage',
        type=float,
        default=10.0,
        help="""\
        How much of the training data should be unknown words.
        """)
    parser.add_argument(
        '--testing_percentage',
        type=int,
        default=10,
        help='What percentage of wavs to use as a test set.')
    parser.add_argument(
        '--validation_percentage',
        type=int,
        default=10,
        help='What percentage of wavs to use as a validation set.')
    parser.add_argument(
        '--sample_rate',
        type=int,
        default=16000,
        help='Expected sample rate of the wavs',)
    parser.add_argument(
        '--clip_duration_ms',
        type=int,
        default=1000,
        help='Expected duration in milliseconds of the wavs',)
    parser.add_argument(
        '--window_size_ms',
        type=float,
        default=30.0,
        help='How long each spectrogram timeslice is',)
    parser.add_argument(
        '--window_stride_ms',
        type=float,
        default=10.0,
        help='How long each spectrogram timeslice is',)
    parser.add_argument(
        '--dct_coefficient_count',
        type=int,
        default=40,
        help='How many bins to use for the MFCC fingerprint',)
    parser.add_argument(
        '--batch_size',
        type=int,
        default=100,
        help='How many items to train with at once',)
    parser.add_argument(
        '--wanted_words',
        type=str,
        default='yes,no,up,down,left,right,on,off,stop,go',
        help='Words to use (others will be added to an unknown label)',)
    parser.add_argument(
        '--checkpoint',
        type=str,
        help='Checkpoint to load the weights from.')
    parser.add_argument(
        '--model_architecture',
        type=str,
        default='dnn',
        help='What model architecture to use')
    parser.add_argument(
        '--model_size_info',
        type=int,
        nargs="+",
        default=[128, 128, 128],
        help='Model dimensions - different for various models')

# Widgets Control Section
---

In [5]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from ipywidgets import AppLayout, Button, Layout, Box, FloatText, Textarea, Dropdown, Label, IntSlider
from IPython.display import display, HTML
from IPython.display import Image, clear_output

from collections import OrderedDict

class init_train_widgets():
    """ipywidgets form that configures and runs the keyword-spotting test.

    The class name is kept for backward compatibility, but the class drives
    the *test* workflow:

    * ``create_command`` writes the chosen options to ``test_cmd.txt``.
    * ``run_test`` parses that file with the module-level ``parser`` and calls
      ``test``.
    """

    # Flags written to test_cmd.txt. In manual mode they are matched
    # positionally with create_command's ``cm_list[1:]``, so this order must
    # follow the argument order of ``act_para`` in show_main.
    _ARGUMENT_LIST = ['--checkpoint',
                      '--model_architecture', '--testing_percentage', '--validation_percentage',
                      '--batch_size', '--model_size_info', '--wanted_words',
                      '--dct_coefficient_count', '--data_exist', '--background_volume',
                      '--background_frequency', '--silence_percentage',
                      '--unknown_percentage', '--time_shift_ms', '--sample_rate',
                      '--clip_duration_ms', '--window_size_ms', '--window_stride_ms']

    @staticmethod
    def _section_box(items):
        """Stack ``items`` vertically in a bordered, half-width box."""
        return Box(items, layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightblue',
            align_items='stretch',
            width='50%',
        ))

    def __init__(self):
        """Create every form widget and group them into three section boxes."""
        form_item_layout = Layout(
            display='flex',
            flex_flow='row',
            justify_content='space-between',
        )

        def row(text, widget):
            # One "label ... control" line of a form section.
            return Box([Label(value=text), widget], layout=form_item_layout)

        ### follow parameters widgets ###
        self.A_ch = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.B_ch = widgets.Text(value='work/DNN/DNN2/training/best/dnn_0.826_ckpt', placeholder='Type something', disabled=False)
        self.form_box_follow_para = self._section_box([
            row('Follow the train process setting(recommend)', self.A_ch),
            row('Model location', self.B_ch),
        ])

        ### test (model) parameters widgets ###
        self.A_ta = Dropdown(options=['dnn', 'cnn', 'ds_cnn', 'basic_lstm'])
        self.B_ta = widgets.BoundedIntText(value=10, min=0, max=50.0, step=1, disabled=False)
        self.C_ta = widgets.BoundedIntText(value=10, min=0, max=50.0, step=1, disabled=False)
        self.G_ta = widgets.IntSlider(value=100, min=50, max=1000, step=50)
        self.H_ta = widgets.Text(value='128,128,128', placeholder='Type something', description='Int:', disabled=False)
        self.I_ta = widgets.Textarea(value='yes,no,up,down,left,right,on,off,stop,go', placeholder='Type something', description='String:', disabled=False)
        self.form_box_train_para = self._section_box([
            row('Model Architecture', self.A_ta),
            row('Testing percentage', self.B_ta),
            row('Validation percentage', self.C_ta),
            row('Batch size', self.G_ta),
            row('Model size (dimension)', self.H_ta),
            row('Wanted words', self.I_ta),
        ])

        ### data parameters widgets ###
        self.A_da = IntSlider(value=10, min=10, max=50)
        self.B_da = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.C_da = widgets.FloatSlider(value=0.1, min=0.0, max=1.0)
        self.D_da = widgets.FloatSlider(value=0.8, min=0.0, max=1.0)
        self.E_da = widgets.FloatSlider(value=10.0, min=0.0, max=30.0)
        self.F_da = widgets.FloatSlider(value=10.0, min=0.0, max=30.0)
        self.G_da = widgets.FloatSlider(value=100.0, min=50.0, max=200.0, step=10.0)
        self.H_da = widgets.IntSlider(value=16000, min=16000, max=32000, step=16000)
        self.I_da = widgets.IntSlider(value=1000, min=800, max=3000, step=200)
        self.J_da = widgets.IntSlider(value=40, min=10, max=100, step=10)
        self.K_da = widgets.IntSlider(value=40, min=10, max=100, step=10)
        self.form_box_data_para = self._section_box([
            row('DCT coefficient count', self.A_da),
            row('Data exist', self.B_da),
            row('Background volume', self.C_da),
            row('Background frequency', self.D_da),
            row('Silence percentage', self.E_da),
            row('Unknown percentage', self.F_da),
            row('Time shift (ms)', self.G_da),
            row('Sample rate', self.H_da),
            row('Clip duration (ms)', self.I_da),
            row('Window size (ms)', self.J_da),
            row('Window stride (ms)', self.K_da),
        ])

    def _read_train_args(self, path='train_cmd.txt'):
        """Extract the recognised flags and their values from a saved train command.

        For a repeated flag, the last occurrence wins. ``--model_size_info``
        collects every value up to the next ``--`` flag or the end of the file.
        A flag with no following value is ignored.

        Returns:
            OrderedDict mapping flag -> value (str, or list of str for
            ``--model_size_info``). It is empty when ``path`` does not exist.
        """
        try:
            with open(path, 'r') as f:
                tokens = f.read().split()
        except FileNotFoundError:
            print(f'{path} not found; only the model location will be used')
            return OrderedDict()
        if tokens:
            print('read the exist train_cmd.txt')

        args = OrderedDict()
        for idx, token in enumerate(tokens):
            if token not in self._ARGUMENT_LIST:
                continue
            if token == '--model_size_info':
                values = []
                for following in tokens[idx + 1:]:
                    if '--' in following:
                        break
                    values.append(following)
                args[token] = values
            elif idx + 1 < len(tokens):
                args[token] = tokens[idx + 1]
        return args

    def create_command(self, cm_list):
        """Build the test command line and save it to ``test_cmd.txt``.

        Args:
            cm_list: ``[follow, checkpoint, *values]``.
                When ``follow`` is truthy, only ``checkpoint`` is taken from the
                list, and every other recognised flag is copied from
                ``train_cmd.txt``. The checkpoint typed into the form always
                takes precedence.
                Otherwise ``cm_list[1:]`` is matched positionally with
                ``_ARGUMENT_LIST``. ``--model_size_info`` is given as a
                comma-separated string.

        Returns:
            0 (kept for backward compatibility).
        """
        cm_dict = OrderedDict()

        if cm_list[0]:
            cm_dict['--checkpoint'] = cm_list[1]
            for key, value in self._read_train_args().items():
                if key != '--checkpoint':
                    cm_dict[key] = value
        else:
            for key, value in zip(self._ARGUMENT_LIST, cm_list[1:]):
                if key == '--model_size_info':
                    # nargs="+" flag: split "128,128,128" into separate tokens.
                    cm_dict[key] = value.split(',')
                else:
                    cm_dict[key] = value

        # Save the complete command line read back by run_test().
        with open('test_cmd.txt', 'w') as f:
            for key, value in cm_dict.items():
                if isinstance(value, list):
                    f.write('%s ' % key)
                    for item in value:
                        f.write('%s ' % item)
                else:
                    f.write('%s %s ' % (key, value))

        return 0

    def show_main(self):
        """Display the parameter form plus the "save" and "run" toggle buttons."""
        intro_text = 'Please Choose the parameters of the testing or using the default'
        display(widgets.HTML(value=f"<b><font color='lightblue'><font size=4>{intro_text}</b>"))

        # Accordion holding the three parameter sections.
        accordion = widgets.Accordion(children=[self.form_box_follow_para,
                                                self.form_box_train_para,
                                                self.form_box_data_para]).add_class("parentstyle")

        # Custom style tag for the accordion headers (class names found with browser dev tools).
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        for index, title in enumerate(['Follow Setting', 'Test Setting', 'Data Setting']):
            accordion.set_title(index, title)

        def act_para(follow, model_loc, model, test_per, vali_per, batch, dims, outputs,
                     dct_coe, data, b_vol, b_freq, silence, unk, t_sft, rate, dura, win_size, win_str):
            # Called by interactive_output with the current widget values; the
            # callbacks below capture those values.
            toggle_save = widgets.ToggleButton(description='Save Test Setting', layout=Layout(width='30%', height='30px'), button_style='success')
            toggle_run = widgets.ToggleButton(description='Start to Run', layout=Layout(width='30%', height='30px'), button_style='success')
            out = widgets.Output(layout=Layout(border='1px solid green'))

            # The manual sections are hidden while following the training settings.
            visibility = 'hidden' if follow else 'visible'
            self.form_box_train_para.layout.visibility = visibility
            self.form_box_data_para.layout.visibility = visibility

            def para_process(change):
                with out:
                    if change['new']:
                        self.create_command([follow, model_loc, model, test_per, vali_per, batch, dims, outputs,
                                             dct_coe, data, b_vol, b_freq, silence, unk, t_sft, rate, dura, win_size, win_str])
                        text0 = 'The test setting is finished and saved'
                        display(widgets.HTML(value=f"<b><font color='lightblue'><font size=2>{text0}</b>"))
                    else:
                        out.clear_output()

            def run(change):
                with out:
                    if change['new']:
                        self.run_test()
                        print('Finish')
                    else:
                        out.clear_output()

            toggle_save.observe(para_process, 'value')
            toggle_run.observe(run, 'value')
            display(toggle_save, toggle_run)
            display(out)

        param_out = widgets.interactive_output(act_para, {'follow': self.A_ch, 'model_loc': self.B_ch,
                                                          'model': self.A_ta, 'test_per': self.B_ta, 'vali_per': self.C_ta, 'batch': self.G_ta, 'dims': self.H_ta,
                                                          'outputs': self.I_ta,
                                                          'dct_coe': self.A_da, 'data': self.B_da, 'b_vol': self.C_da, 'b_freq': self.D_da,
                                                          'silence': self.E_da, 'unk': self.F_da, 't_sft': self.G_da, 'rate': self.H_da,
                                                          'dura': self.I_da, 'win_size': self.J_da, 'win_str': self.K_da})
        display(accordion, param_out)

    def run_test(self):
        """Parse ``test_cmd.txt`` with the module-level ``parser`` and run ``test``.

        Returns early, printing a message, when the file is missing or empty.
        In that case no usable checkpoint would be available.
        """
        try:
            with open('test_cmd.txt', 'r') as f:
                cmd_list = f.read().split()
        except FileNotFoundError:
            print('The test_cmd.txt is not found! Please save the test setting first.')
            return
        if not cmd_list:
            print('The test_cmd.txt is empty!')
            return

        print('read the test commands!')
        # parse_known_args: flags such as --background_volume are not defined
        # by the test parser and are ignored.
        FLAGS, _ = parser.parse_known_args(args=cmd_list)
        test(FLAGS)

# Run Section
---
- A detailed description of all the parameters is available [here](#id-PDD).
- `Follow the train process setting`: reuse the training settings that were saved for the same model (recommended).
- `Model Location`: the path to the trained model's `*_ckpt` checkpoint file, for example: work/DNN/DNN2/training/dnn_0.826_ckpt


In [6]:
act = init_train_widgets()
act.show_main()

HTML(value="<b><font color='lightblue'><font size=4>Please Choose the parameters of the testing or using the d…

Accordion(children=(Box(children=(Box(children=(Label(value='Follow the train process setting(recommend)'), Ch…

Output()

<a id="id-PDD"></a>
# Parameter Description
---
- This notebook is based on https://github.com/ARM-software/ML-examples/tree/main/tflu-kws-cortex-m.
- The parameters are the same as those used for training; see `train.ipynb` for their descriptions.