# Import Section
---

In [1]:
import os
import argparse
from pathlib import Path

import tensorflow as tf
import numpy as np
import shutil

from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder

In [2]:
# Reference location of a local checkout of this project. Nothing below reads
# it; it was only used by a (now disabled) step that changed the working
# directory to this folder when the notebook was started from somewhere else.
# Example of another checkout:
#   r'C:\Users\USERNAME\MICRO_ML\ML-examples-main\tflu-kws-cortex-m\Training'
folder_exc = r'C:\Users\garyc\ML_kws_tflu-main'

# Probe for Google Colab: the import only succeeds inside a Colab runtime.
# (Mounting Google Drive via drive.mount('/content/drive') is left disabled.)
try:
    from google.colab import drive
except ImportError:
    # Local Jupyter session: report the absolute working directory.
    print(r'Running Location:')
    print(os.path.abspath(os.getcwd()))
else:
    print('Colab in:')
    print (os.getcwd())

# Project-local helpers; they must be importable from the working directory
# reported above.
from kws_python import data
from kws_python import models

Running Location:
/ML_KWS/ML_kws_tflu


# Training Section
---

In [3]:
def train(FLAGS, save_cmd_fileName):
    """Train a keyword-spotting model, evaluate it and record the results.

    The function builds the model and the audio pipeline from ``FLAGS``,
    trains with a piecewise-constant learning-rate schedule, keeps the best
    weights (by validation accuracy) under ``<train_dir>/best``, evaluates on
    the test split and writes ``result_record.txt`` (test accuracy, MACs,
    parameter count) into ``FLAGS.train_dir``.

    Args:
        FLAGS: Parsed command-line namespace (see the argument-setting cell).
            The attributes read here are: wanted_words, sample_rate,
            clip_duration_ms, window_size_ms, window_stride_ms,
            dct_coefficient_count, model_architecture, model_size_info,
            data_exist, data_url, data_dir, silence_percentage,
            unknown_percentage, validation_percentage, testing_percentage,
            how_many_training_steps, learning_rate, background_frequency,
            background_volume, time_shift_ms, batch_size, eval_step_interval,
            train_dir and summaries_dir.
        save_cmd_fileName: Name of the saved command file in the current
            working directory; it is copied into ``FLAGS.train_dir`` so the
            settings are stored next to the training results.

    Raises:
        ValueError: If ``how_many_training_steps`` and ``learning_rate`` do
            not contain the same number of comma-separated entries.
    """
    wanted_words = FLAGS.wanted_words.split(',')

    model_settings = models.prepare_model_settings(len(data.prepare_words_list(wanted_words)),
                                                   FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
                                                   FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    # Create the model.
    model = models.create_model(model_settings, FLAGS.model_architecture, FLAGS.model_size_info, True)

    audio_processor = data.AudioProcessor(data_exist=FLAGS.data_exist,
                                          data_url=FLAGS.data_url,
                                          data_dir=FLAGS.data_dir,
                                          silence_percentage=FLAGS.silence_percentage,
                                          unknown_percentage=FLAGS.unknown_percentage,
                                          wanted_words=wanted_words,
                                          validation_percentage=FLAGS.validation_percentage,
                                          testing_percentage=FLAGS.testing_percentage,
                                          model_settings=model_settings)

    # We decay learning rate in a constant piecewise way to help learning.
    # Each entry of `how_many_training_steps` is the LENGTH of one stage and
    # the matching entry of `learning_rate` is the rate used in that stage.
    training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise ValueError(
            '--how_many_training_steps and --learning_rate must have the same number of entries '
            '(got {} and {}).'.format(len(training_steps_list), len(learning_rates_list)))

    # The schedule expects the global step at which each rate change happens,
    # i.e. the running total of the stage lengths (the last total is the end
    # of training and is not a boundary).
    lr_boundary_list = np.cumsum(training_steps_list)[:-1].tolist()
    if lr_boundary_list:
        learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=lr_boundary_list,
                                                                             values=learning_rates_list)
    else:
        # A single stage has no boundary; use the constant rate directly.
        learning_rate = learning_rates_list[0]

    # Specify the optimizer configurations.
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer,
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=['accuracy'])

    # The time shift is given in milliseconds and converted to samples here.
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
    train_data = audio_processor.get_data(audio_processor.Modes.TRAINING,
                                          FLAGS.background_frequency, FLAGS.background_volume,
                                          time_shift_samples)
    # The training data repeats indefinitely; `steps_per_epoch` below decides
    # where one "epoch" ends.
    train_data = train_data.repeat().batch(FLAGS.batch_size).prefetch(tf.data.AUTOTUNE)
    val_data = audio_processor.get_data(audio_processor.Modes.VALIDATION)
    val_data = val_data.batch(FLAGS.batch_size).prefetch(tf.data.AUTOTUNE)

    # We train for a max number of iterations so need to calculate how many 'epochs' this will be.
    training_steps_max = np.sum(training_steps_list)
    training_epoch_max = int(np.ceil(training_steps_max / FLAGS.eval_step_interval))

    # Callbacks.
    train_dir = Path(FLAGS.train_dir) / "best"
    train_dir.mkdir(parents=True, exist_ok=True)
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=(train_dir / (FLAGS.model_architecture + "_{val_accuracy:.3f}_ckpt")),
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.summaries_dir)

    # Save the training settings next to the training results.
    src = Path(os.getcwd()) / save_cmd_fileName
    dst = Path(os.getcwd()) / Path(FLAGS.train_dir) / save_cmd_fileName
    shutil.copy(src, dst)

    # Train the model.
    model.fit(x=train_data,
              steps_per_epoch=FLAGS.eval_step_interval,
              epochs=training_epoch_max,
              validation_data=val_data,
              callbacks=[model_checkpoint_callback, tensorboard_callback])

    # Test the model.
    test_data = audio_processor.get_data(audio_processor.Modes.TESTING)
    test_data = test_data.batch(FLAGS.batch_size)

    test_loss, test_acc = model.evaluate(x=test_data)
    print(f'Final test accuracy: {test_acc*100:.2f}%')

    # Profile a single-sample forward pass to count floating point operations.
    forward_pass = tf.function(
        model.call,
        input_signature=[tf.TensorSpec(shape=(1,) + model.input_shape[1:])])
    graph_info = profile(forward_pass.get_concrete_function().graph,
                         options=ProfileOptionBuilder.float_operation())
    # The //2 is necessary since `profile` counts multiply and accumulate
    # as two flops, here we report the total number of multiply accumulate ops
    flops = graph_info.total_float_ops // 2
    total_para = model.count_params()
    print('TensorFlow:', tf.__version__)
    print('The MACs of this model: {:,}'.format(flops))
    print('The total parameters of this model: {:,}'.format(total_para))

    # Save the result record.
    test_txt_path = os.path.join(os.getcwd(), FLAGS.train_dir, 'result_record.txt')
    with open(test_txt_path, 'w') as f:
        f.write('Test accuracy: {}'.format(test_acc) + '\n')
        f.write('MACs: {}'.format(flops) + '\n')
        f.write('Total Parameters: {}'.format(total_para) + '\n')

# Argument Setting
---

In [4]:
def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy. This converter accepts actual
    ``bool`` objects (returned unchanged) as well as the usual textual
    spellings, case-insensitively.

    Args:
        value: A ``bool`` or a string such as 'True', 'false', 'yes', '0'.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}.'.format(value))


# Build the command-line parser shared by the rest of the notebook. In a
# notebook kernel __name__ is "__main__", so `parser` becomes a global.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_exist',
        type=str2bool,
        default=True,
        help='True will skip download and tar.')
    parser.add_argument(
        '--data_url',
        type=str,
        default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
        help='Location of speech training data archive on the web.')
    parser.add_argument(
        '--data_dir',
        type=str,
        default='tmp/speech_dataset',
        help="""\
        Where to download the speech training data to.
        """)
    parser.add_argument(
        '--background_volume',
        type=float,
        default=0.1,
        help="""\
        How loud the background noise should be, between 0 and 1.
        """)
    parser.add_argument(
        '--background_frequency',
        type=float,
        default=0.8,
        help="""\
        How many of the training samples have background noise mixed in.
        """)
    parser.add_argument(
        '--silence_percentage',
        type=float,
        default=10.0,
        help="""\
        How much of the training data should be silence.
        """)
    parser.add_argument(
        '--unknown_percentage',
        type=float,
        default=10.0,
        help="""\
        How much of the training data should be unknown words.
        """)
    parser.add_argument(
        '--time_shift_ms',
        type=float,
        default=100.0,
        help="""\
        Range to randomly shift the training audio by in time.
        """)
    parser.add_argument(
        '--testing_percentage',
        type=int,
        default=10,
        help='What percentage of wavs to use as a test set.')
    parser.add_argument(
        '--validation_percentage',
        type=int,
        default=10,
        help='What percentage of wavs to use as a validation set.')
    parser.add_argument(
        '--sample_rate',
        type=int,
        default=16000,
        help='Expected sample rate of the wavs',)
    parser.add_argument(
        '--clip_duration_ms',
        type=int,
        default=1000,
        help='Expected duration in milliseconds of the wavs',)
    parser.add_argument(
        '--window_size_ms',
        type=float,
        default=30.0,
        help='How long each spectrogram timeslice is',)
    parser.add_argument(
        '--window_stride_ms',
        type=float,
        default=10.0,
        help='How far to move in time between spectrogram timeslices',)
    parser.add_argument(
        '--dct_coefficient_count',
        type=int,
        default=40,
        help='How many bins to use for the MFCC fingerprint',)
    parser.add_argument(
        '--how_many_training_steps',
        type=str,
        # A short schedule such as '15,3' is handy for a quick smoke test.
        default='15000,3000',
        help='How many training loops to run',)
    parser.add_argument(
        '--eval_step_interval',
        type=int,
        default=400,
        help='How often to evaluate the training results.')
    parser.add_argument(
        '--learning_rate',
        type=str,
        default='0.001,0.0001',
        help='How large a learning rate to use when training.')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=100,
        help='How many items to train with at once',)
    parser.add_argument(
        '--summaries_dir',
        type=str,
        default='/tmp/retrain_logs',
        help='Where to save summary logs for TensorBoard.')
    parser.add_argument(
        '--wanted_words',
        type=str,
        default='yes,no',
        help='Words to use (others will be added to an unknown label)',)
    parser.add_argument(
        '--train_dir',
        type=str,
        default='/tmp/speech_commands_train',
        help='Directory to write event logs and checkpoint.')
    parser.add_argument(
        '--model_architecture',
        type=str,
        default='dnn',
        help='What model architecture to use')
    parser.add_argument(
        '--model_size_info',
        type=int,
        nargs="+",
        default=[128, 128, 128],
        help='Model dimensions - different for various models')

# Widgets Control Section
---

In [11]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from ipywidgets import AppLayout, Button, Layout, Box, FloatText, Textarea, Dropdown, Label, IntSlider
from IPython.display import display, HTML
from IPython.display import Image, clear_output

from collections import OrderedDict

class init_train_widgets():
    """Widget front-end for configuring and launching a training run.

    The constructor builds two forms (training settings and data settings).
    `show_main` displays them together with two toggle buttons: one saves the
    chosen values as a command line in ``train_cmd.txt`` and the other reads
    that file back, parses it with the notebook-level ``parser`` and calls the
    notebook-level ``train`` function.
    """

    def __init__(self):
        """Create all input widgets and group them into the two form boxes."""
        form_item_layout = Layout(
        display='flex',
        flex_flow='row',
        justify_content='space-between',
        )

        ### train model parameters widgets ###
        self.A_ta = Dropdown(options=['ds_cnn', 'dnn', 'cnn', 'basic_lstm'])  # model architecture
        self.B_ta = widgets.BoundedIntText(value=10, min=0, max=50, step=1, disabled=False)  # testing percentage
        self.C_ta = widgets.BoundedIntText(value=10, min=0, max=50, step=1, disabled=False)  # validation percentage
        self.D_ta = widgets.Text(value='5000,10000,10000', placeholder='Type something', description='String:', disabled=False)  # training steps
        self.E_ta = widgets.Text(value='0.0005,0.0001,0.00002', placeholder='Type something', description='String:', disabled=False)  # learning rates
        self.F_ta = widgets.IntSlider(value=400, min=100, max=1500, step=100)  # eval step interval
        self.G_ta = widgets.IntSlider(value=100, min=50, max=1000, step=50)  # batch size
        self.H_ta = widgets.Text(value='5 64 10 4 2 2 64 3 3 1 1 64 3 3 1 1 64 3 3 1 1 64 3 3 1 1', placeholder='Type something', description='Int:', disabled=False)  # model size info
        self.I_ta = widgets.Textarea(value='yes,no,up,down,left,right,on,off,stop,go', placeholder='Type something', description='String:', disabled=False)  # wanted words
        self.J_ta = widgets.Text(value='work/DS_CNN/1/retrain_logs', placeholder='Type something', description='String:', disabled=False)  # summaries directory
        self.K_ta = widgets.Text(value='work/DS_CNN/1/training', placeholder='Type something', description='String:', disabled=False)  # train directory

        form_train_items = [
            Box([Label(value = 'Model Architecture'), self.A_ta], layout=form_item_layout),
            Box([Label(value = 'Testing percentage'), self.B_ta], layout=form_item_layout),
            Box([Label(value = 'Validation percentage'), self.C_ta], layout=form_item_layout),
            Box([Label(value = 'Training Steps'), self.D_ta], layout=form_item_layout),
            Box([Label(value = 'Learning rates'), self.E_ta], layout=form_item_layout),
            Box([Label(value = 'Eval step interval'), self.F_ta], layout=form_item_layout),
            Box([Label(value = 'Batch size'), self.G_ta], layout=form_item_layout),
            Box([Label(value = 'Model size (dimension)'), self.H_ta], layout=form_item_layout),
            Box([Label(value = 'Wanted words'), self.I_ta], layout=form_item_layout),
            Box([Label(value = 'Summaries directory'), self.J_ta], layout=form_item_layout),
            Box([Label(value = 'Train directory'), self.K_ta], layout=form_item_layout)
        ]

        self.form_box_train_para = Box(form_train_items, layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightblue',
            align_items='stretch',
            width='50%',
        ))

        ### data parameters widgets ###
        self.A_da = IntSlider(value=10, min=10, max=50)  # DCT coefficient count
        self.B_da = widgets.Checkbox(value=True, disabled=False, indent=False)  # data exist
        self.C_da = widgets.FloatSlider(value=0.5, min=0.0, max=1.0)  # background volume
        self.D_da = widgets.FloatSlider(value=0.9, min=0.0, max=1.0)  # background frequency
        self.E_da = widgets.FloatSlider(value=20.0, min=0.0, max=50.0)  # silence percentage
        self.F_da = widgets.FloatSlider(value=20.0, min=0.0, max=50.0)  # unknown percentage
        self.G_da = widgets.FloatSlider(value=200.0, min=50.0, max=500.0, step=10.0)  # time shift (ms)
        self.H_da = widgets.IntSlider(value=16000, min=16000, max=32000, step=16000)  # sample rate
        self.I_da = widgets.IntSlider(value=1000, min=800, max=3000, step=200)  # clip duration (ms)
        self.J_da = widgets.IntSlider(value=40, min=10, max=100, step=10)  # window size (ms)
        self.K_da = widgets.IntSlider(value=20, min=10, max=100, step=10)  # window stride (ms)

        form_data_items = [
            Box([Label(value = 'DCT coefficient count'), self.A_da], layout=form_item_layout),
            Box([Label(value = 'Data exist'), self.B_da], layout=form_item_layout),
            Box([Label(value = 'Background volume'), self.C_da], layout=form_item_layout),
            Box([Label(value = 'Background frequency'), self.D_da], layout=form_item_layout),
            Box([Label(value = 'Silence percentage'), self.E_da], layout=form_item_layout),
            Box([Label(value = 'Unknown percentage'), self.F_da], layout=form_item_layout),
            Box([Label(value = 'Time shift (ms)'), self.G_da], layout=form_item_layout),
            Box([Label(value = 'Sample rate'), self.H_da], layout=form_item_layout),
            Box([Label(value = 'Clip duration (ms)'), self.I_da], layout=form_item_layout),
            Box([Label(value = 'Window size (ms)'), self.J_da], layout=form_item_layout),
            Box([Label(value = 'Window stride (ms)'), self.K_da], layout=form_item_layout)
        ]

        self.form_box_data_para = Box(form_data_items, layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightblue',
            align_items='stretch',
            width='50%',
        ))

    def folder_num_check(self, train_loc, dataset_list_check, min_files=15):
        """Check that every label folder holds enough files.

        Args:
            train_loc: Directory that contains one sub-folder per label.
            dataset_list_check: Iterable of label (sub-folder) names to check.
            min_files: Minimum number of regular files required per folder.

        Returns:
            1 if any folder is missing or holds fewer than `min_files` regular
            files, otherwise 0.
        """
        for fld_name in dataset_list_check:
            check_fld = os.path.join(train_loc, fld_name)
            # A missing label folder means there is no data for that label.
            if not os.path.isdir(check_fld):
                return 1
            file_count = sum(1 for entry in os.listdir(check_fld)
                             if os.path.isfile(os.path.join(check_fld, entry)))
            if file_count < min_files:
                return 1
        return 0

    def create_command(self, cm_list):
        """Write the chosen settings to ``train_cmd.txt`` as a command line.

        Args:
            cm_list: Setting values in the same order as `argument_list`
                below (training settings first, then data settings).

        Returns:
            0 once the file has been written.

        Raises:
            ValueError: If more values are given than there are arguments.
        """
        argument_list = ['--model_architecture', '--testing_percentage', '--validation_percentage', '--how_many_training_steps',
                         '--learning_rate', '--eval_step_interval', '--batch_size', '--model_size_info', '--wanted_words',
                         '--summaries_dir', '--train_dir',
                         '--dct_coefficient_count', '--data_exist', '--background_volume','--background_frequency', '--silence_percentage',
                         '--unknown_percentage', '--time_shift_ms', '--sample_rate','--clip_duration_ms', '--window_size_ms', '--window_stride_ms']
        if len(cm_list) > len(argument_list):
            raise ValueError('Got {} values for {} arguments.'.format(len(cm_list), len(argument_list)))

        # These values are comma-separated lists that must stay one token in
        # the whitespace-separated command file, so inner whitespace is removed.
        comma_list_args = ('--how_many_training_steps', '--learning_rate', '--wanted_words')

        cm_dict = OrderedDict()
        for key, val in zip(argument_list, cm_list):
            if key == '--model_size_info':
                # Accept both comma- and space-separated dimensions and turn
                # them into a list of separate tokens.
                cm_dict[key] = str(val).replace(',', ' ').split()
            elif key in comma_list_args:
                cm_dict[key] = ''.join(str(val).split())
            else:
                cm_dict[key] = val
        print(cm_dict)

        with open('train_cmd.txt','w') as f:  #save the complete command for train.py
            for key, value in cm_dict.items():
                # An option without a value would break parsing; leave it out
                # so the parser default applies.
                if isinstance(value, (str, list)) and not value:
                    continue
                if isinstance(value, list):
                    f.write('%s %s ' % (key, ' '.join(value)))
                else:
                    f.write('%s %s ' % (key, value))

        return 0

    def show_main(self):
        """Display the settings forms and the save / run toggle buttons."""
        intro_text = 'Please Choose the parameters of the training or using the default'
        htmlWidget = widgets.HTML(value = f"<b><font color='lightblue'><font size=4>{intro_text}</b>")
        display(htmlWidget)

        #Create an accordion and put the 2 boxes
        accordion = widgets.Accordion(children=[self.form_box_train_para, self.form_box_data_para]).add_class("parentstyle")
        #Add a custom style tag to the notebook, you can use dev tool to inspect the class names
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        accordion.set_title(0, 'Train Setting')
        accordion.set_title(1, 'Data Setting')

        def act_para(model,test_per,vali_per,steps,lr,step_inter,batch,dims,outputs,sum_dir,train_dir,
                     dct_coe,data,b_vol,b_freq,silence,unk,t_sft,rate,dura,win_size,win_str):
            # Called by interactive_output with the current value of every
            # widget; builds the two toggle buttons bound to those values.
            toggle_train_save = widgets.ToggleButton(description='Save Train Setting', layout=Layout(width='30%', height='30px'), button_style='success')
            toggle_run = widgets.ToggleButton(description='Start to Run', layout=Layout(width='30%', height='30px'), button_style='success')
            out = widgets.Output(layout=Layout(border = '1px solid green'))

            def para_process(obj):
                # Toggle switched on: save the settings; switched off: just report.
                with out:
                    if obj['new']:
                        self.create_command([model,test_per,vali_per,steps,lr,step_inter,batch,dims,outputs,sum_dir,train_dir,
                              dct_coe,data,b_vol,b_freq,silence,unk,t_sft,rate,dura,win_size,win_str])

                        text0 = 'The training setting is finish and saved'
                        html0= widgets.HTML(value = f"<b><font color='lightblue'><font size=2>{text0}</b>")
                        display(html0)

                    else:
                        print('re-start...')

            def run(obj):
                # Toggle switched on: start training from the saved command file.
                with out:
                    if obj['new']:
                        self.run_test()
                    else:
                        print('stop')

            toggle_train_save.observe(para_process, 'value')
            toggle_run.observe(run, 'value')
            display(toggle_train_save, toggle_run)
            display(out)

        out = widgets.interactive_output(act_para, {'model': self.A_ta, 'test_per': self.B_ta, 'vali_per': self.C_ta, 'steps': self.D_ta,
                                                    'lr': self.E_ta, 'step_inter': self.F_ta, 'batch': self.G_ta, 'dims': self.H_ta,
                                                    'outputs': self.I_ta, 'sum_dir': self.J_ta, 'train_dir': self.K_ta,
                                                    'dct_coe': self.A_da, 'data': self.B_da, 'b_vol': self.C_da, 'b_freq': self.D_da,
                                                    'silence': self.E_da, 'unk': self.F_da, 't_sft': self.G_da, 'rate': self.H_da,
                                                    'dura': self.I_da, 'win_size': self.J_da,  'win_str': self.K_da})

        display(accordion, out)

    def run_test(self):   ###run the mainprogram
        """Read ``train_cmd.txt``, parse it and start the training.

        Uses the notebook-level ``parser`` and ``train``. When the
        'Data exist' checkbox is ticked, training only starts if every wanted
        word has enough files in the dataset directory.
        """
        save_cmd_fileName = 'train_cmd.txt'
        try:
            with open(save_cmd_fileName,'r') as f:  # load the saved command for train.py
                train_cmd_line = f.read()
        except FileNotFoundError:
            print('The train_cmd.txt is missing, please save the train setting first!')
            return
        cmd_list = train_cmd_line.split()
        print(cmd_list)

        if not cmd_list:
            print('The train_cmd.txt is empty!')
            return
        print('read the train commands!')

        for idx, val in enumerate(cmd_list):
            if val == 'False':
                # bool('False') is True, so a bool-typed option would read the
                # text as True; substitute a real False, which converts to False.
                print('change to bool')
                cmd_list[idx] = False

        FLAGS, _ = parser.parse_known_args(args = cmd_list)

        # The label folders to check are the wanted words actually in effect
        # (parser default included when the option is absent from the file).
        dataset_list_check = FLAGS.wanted_words.split(',')
        dataset_loc_for_check = os.path.abspath(FLAGS.data_dir)

        # Only check the local dataset when it is declared to exist already.
        if self.B_da.value and self.folder_num_check(dataset_loc_for_check, dataset_list_check):
            print("The data is not enough, please > 15 files in each label folder.")
            print("The dataset path: {}".format(dataset_loc_for_check))
            return

        train(FLAGS, save_cmd_fileName)
        print('Finish')

# Run Section
---
- A detailed description of every parameter is available in [Parameter Description](#id-PD).
- The first time you run this notebook, download the Google training data: open the `Data Setting` tab and uncheck `Data exist`.

In [12]:
act = init_train_widgets()
act.show_main()

HTML(value="<b><font color='lightblue'><font size=4>Please Choose the parameters of the training or using the …

Accordion(children=(Box(children=(Box(children=(Label(value='Model Architecture'), Dropdown(options=('ds_cnn',…

Output()

<a id="id-PD"></a>
# Parameter Description
---
- This notebook is based on https://github.com/ARM-software/ML-examples/tree/main/tflu-kws-cortex-m.

## Train Setting
- `Model Architecture`: What model architecture to use.
- `Testing percentage`: What percentage of wavs to use as a test set.
- `Validation percentage`: What percentage of wavs to use as a validation set.
- `Training steps`: How many training loops to run. It matches with the learning rates.
- `Learning rates`: How large a learning rate to use when training. It matches with the training steps.
- `Eval step interval`: How often to evaluate the training results.
- `Batch size`: How many items to train with at once.
- `Model size (dimension)`: Model dimensions - different for various models. For more detail, please check the `train_commands.txt`.
- `Wanted words`: Words to use (others will be added to an unknown label).
- `Summaries directory`: Where to save summary logs for TensorBoard.
- `Train directory`: Directory to write event logs and checkpoints (the trained model and weights).

## Data Setting
- `DCT coefficient count`: How many bins to use for the MFCC fingerprint
- `Data exist`: When checked, downloading and extracting the default TensorFlow speech dataset is skipped. (Note) The first time you run this notebook, uncheck it so that the training dataset is downloaded.
- `Background volume`: How loud the background noise should be, between 0 and 1.
- `Background frequency`: How many of the training samples have background noise mixed in.
- `Silence percentage`: How much of the training data should be silence.
- `Unknown percentage`: How much of the training data should be unknown words.
- `Time shift (ms)`: Range to randomly shift the training audio by in time.
- `Sample rate`: Expected sample rate of the wavs.
- `Clip duration (ms)`: Expected duration in milliseconds of the wavs.
- `Window size (ms)`: How long each spectrogram timeslice is.
- `Window stride (ms)`: How far to move in time, in milliseconds, between successive spectrogram timeslices.