# Import Section
---

In [1]:
import os
from pathlib import Path

import argparse
import numpy as np
import tensorflow as tf

import logging
import shutil

In [2]:
# Work out where the notebook is running.
#  - On Google Colab: mount Google Drive and move into the project's Training folder.
#  - Anywhere else: only report the current folder. If needed, chdir into your local
#    ML_kws_tflu checkout yourself, e.g. os.chdir(r'C:\Users\USERNAME\MICRO_ML\ML_kws_tflu').


def _print_location(header):
    """Print a header line followed by the current working directory."""
    print(header)
    print(os.getcwd())


try:
    from google.colab import drive  # only importable on Google Colab
    _print_location('origin is:')
    drive.mount('/content/drive')
    os.chdir(r'/content/drive/MyDrive/tflu-kws-cortex-m/Training')
    _print_location('update to:')
except ImportError:
    print(r'Running Location:')
    print(os.path.abspath(os.getcwd()))
from kws_python import data
from kws_python import models
from kws_python.test_tflite import tflite_test

Running Location:
C:\Users\USER\Desktop\ML\opennuvoton\ML_KWS\ML_kws_tflu


# Convert Section
---

In [3]:
NUM_REP_DATA_SAMPLES = 100  # How many samples to use for post training quantization.


def convert(FLAGS, model_settings, audio_processor, checkpoint, quantize, inference_type, tflite_path):
    """Load our trained floating point model and convert it.

    TFLite conversion or post training quantization is performed and the
    resulting model is saved as a TFLite file.
    We use samples from the validation set to do post training quantization.

    Args:
        FLAGS: Parsed command-line flags; ``model_architecture`` and
            ``model_size_info`` are used to rebuild the model.
        model_settings: Dictionary of common model settings.
        audio_processor: Audio processor class object.
        checkpoint: Path to training checkpoint to load.
        quantize: Whether to quantize the model or convert to fp32 TFLite model.
        inference_type: Input/output type of the quantized model.
        tflite_path: Output TFLite file save path.
    """
    model = models.create_model(model_settings, FLAGS.model_architecture, FLAGS.model_size_info, False)
    # expect_partial() silences warnings about checkpoint values that are not restored
    # (the inference model does not need everything saved during training).
    model.load_weights(checkpoint).expect_partial()

    val_data = audio_processor.get_data(audio_processor.Modes.VALIDATION).batch(1)

    def _rep_dataset():
        """Yield exactly NUM_REP_DATA_SAMPLES (or fewer, if the set is smaller) single-sample batches."""
        # take() bounds the sample count precisely; the previous manual counter
        # (`if i > N: break`) yielded N + 1 samples.
        for mfcc, _label in val_data.take(NUM_REP_DATA_SAMPLES):
            yield [mfcc]

    if quantize:
        # Quantize model with the representative validation samples.
        tflite_model = post_training_quantize(model, inference_type, _rep_dataset)
        saved_message = 'Quantized model saved to'
    else:
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        tflite_model = converter.convert()
        saved_message = 'Converted model saved to'

    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)
    print(f'{saved_message} {tflite_path}.')


def post_training_quantize(keras_model, inference_type, rep_dataset):
    """Run full-integer post training quantization on a Keras model.

    Weights and activations are quantized to int8 using ranges calibrated on
    ``rep_dataset``. ``inference_type`` only controls the model's input/output
    tensor type: ``"int8"`` gives int8 I/O; any other value leaves the
    converter's default (float32) I/O.

    See https://www.tensorflow.org/lite/performance/post_training_quantization#full_integer_quantization for
    more details.

    Args:
        keras_model: The trained tf Keras model used for post training quantization.
        inference_type: Input/output type of the quantized model.
        rep_dataset: Function to use as a representative dataset, must be callable.

    Returns:
        Quantized TFLite model (serialized bytes) ready for saving to disk.
    """
    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Int8 post training quantization needs a representative dataset for calibration,
    # and restricting to int8 builtin ops makes the conversion fail rather than fall back to float ops.
    tflite_converter.representative_dataset = rep_dataset
    tflite_converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

    if inference_type == "int8":
        tflite_converter.inference_input_type = tf.int8
        tflite_converter.inference_output_type = tf.int8

    return tflite_converter.convert()


def main_convert(FLAGS):
    """Convert the checkpoint named in FLAGS to a TFLite file and evaluate it on the test set.

    The output file name is ``<model_architecture>.tflite`` (no quantization),
    ``<model_architecture>_int8quant.tflite`` (quantized, int8 I/O) or
    ``<model_architecture>_dyquant.tflite`` (quantized, other inference types).
    It is placed in the part of ``FLAGS.checkpoint`` that precedes the first
    occurrence of the substring ``'best'``.

    Args:
        FLAGS: Parsed command-line flags (see the argument parser cell).

    Returns:
        Path of the written TFLite file.
    """
    word_count = len(data.prepare_words_list(FLAGS.wanted_words.split(',')))
    model_settings = models.prepare_model_settings(word_count,
                                                   FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
                                                   FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    audio_processor = data.AudioProcessor(data_exist=FLAGS.data_exist,
                                          data_url=FLAGS.data_url,
                                          data_dir=FLAGS.data_dir,
                                          silence_percentage=FLAGS.silence_percentage,
                                          unknown_percentage=FLAGS.unknown_percentage,
                                          wanted_words=FLAGS.wanted_words.split(','),
                                          validation_percentage=FLAGS.validation_percentage,
                                          testing_percentage=FLAGS.testing_percentage,
                                          model_settings=model_settings)

    if not FLAGS.quantize:
        name_suffix = ''
    elif FLAGS.inference_type == 'int8':
        name_suffix = '_int8quant'
    else:
        name_suffix = '_dyquant'
    tflite_file_name = f'{FLAGS.model_architecture}{name_suffix}.tflite'

    # NOTE: plain substring split -- everything before the first 'best' in the path.
    output_dir = FLAGS.checkpoint.split('best')[0]
    tflite_path = os.path.join(output_dir, tflite_file_name)

    # Load floating point model from checkpoint and convert it.
    convert(FLAGS, model_settings, audio_processor, FLAGS.checkpoint,
            FLAGS.quantize, FLAGS.inference_type, tflite_path)

    # Test the newly converted model on the test set.
    tflite_test(model_settings, audio_processor, tflite_path)

    return tflite_path

# Argument Setting
---

In [4]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_exist',
        type=bool,
        default=True,
        help='True will skip download and tar.')
    parser.add_argument(
        '--data_url',
        type=str,
        default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
        help='Location of speech training data archive on the web.')
    parser.add_argument(
        '--data_dir',
        type=str,
        default='tmp/speech_dataset/',
        help="""\
        Where to download the speech training data to.
        """)
    parser.add_argument(
        '--silence_percentage',
        type=float,
        default=10.0,
        help="""\
        How much of the training data should be silence.
        """)
    parser.add_argument(
        '--unknown_percentage',
        type=float,
        default=10.0,
        help="""\
        How much of the training data should be unknown words.
        """)
    parser.add_argument(
        '--testing_percentage',
        type=int,
        default=10,
        help='What percentage of wavs to use as a test set.')
    parser.add_argument(
        '--validation_percentage',
        type=int,
        default=10,
        help='What percentage of wavs to use as a validation set.')
    parser.add_argument(
        '--sample_rate',
        type=int,
        default=16000,
        help='Expected sample rate of the wavs',)
    parser.add_argument(
        '--clip_duration_ms',
        type=int,
        default=1000,
        help='Expected duration in milliseconds of the wavs',)
    parser.add_argument(
        '--window_size_ms',
        type=float,
        default=30.0,
        help='How long each spectrogram timeslice is',)
    parser.add_argument(
        '--window_stride_ms',
        type=float,
        default=10.0,
        help='How long each spectrogram timeslice is',)
    parser.add_argument(
        '--dct_coefficient_count',
        type=int,
        default=40,
        help='How many bins to use for the MFCC fingerprint',)
    parser.add_argument(
        '--wanted_words',
        type=str,
        default='yes,no,up,down,left,right,on,off,stop,go',
        help='Words to use (others will be added to an unknown label)',)
    parser.add_argument(
        '--model_architecture',
        type=str,
        default='dnn',
        help='What model architecture to use')
    parser.add_argument(
        '--model_size_info',
        type=int,
        nargs="+",
        default=[128, 128, 128],
        help='Model dimensions - different for various models')
    parser.add_argument(
        '--checkpoint',
        type=str,
        help='Checkpoint to load the weights from.')
    parser.add_argument(
        '--quantize',
        dest='quantize',
        action="store_true",
        default=True,
        help='Whether to quantize the model or convert to fp32 TFLite model. Defaults to True.')
    parser.add_argument(
        '--no-quantize',
        dest='quantize',
        action="store_false",
        help='Whether to quantize the model or convert to fp32 TFLite model. Defaults to True.')
    parser.add_argument(
        '--inference_type',
        type=str,
        default='fp32',
        help='If quantize is true, whether the model input and output is float32 or int8')

# Widgets Control Section
---

In [5]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from ipywidgets import AppLayout, Button, Layout, Box, FloatText, Textarea, Dropdown, Label, IntSlider
from IPython.display import display, HTML
from IPython.display import Image, clear_output

from collections import OrderedDict

class init_train_widgets():
    def __init__(self):   ###intial the widgets elements
        
        self.cmd_list = [] # command list
        self.tflite_path = 'dnn_quantized.tflite' # tflite file name
        self.tflu_model_dir = 'my_tflu_model'
        self.tflu_files_list = os.listdir(self.tflu_model_dir)
        
        self.tflu_c_proj_saveLoc = 'C:/Users/ML_M460_NuKws_SampleCode/SampleCode/tflu_kws_arm_rt_mc/Generated/DNN'
        self.tflu_c_proj_runDir = 'MyRunModel'
        self.tflu_c_proj_runName = 'runModel.cc'
        
        form_item_layout = Layout(
        display='flex',
        flex_flow='row',
        justify_content='space-between',
        )
        
        ### follow parameters widgets ###
        self.A_ch = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.B_ch = widgets.Text(value='work/DS_CNN/1/training/best/ds_cnn_0.933_ckpt', placeholder='Type something', disabled=False)
        self.C_ch = widgets.Checkbox(value=False, disabled=False, indent=False)
        self.D_ch = Dropdown(value='int8', options=['fp32', 'int8'])  
        self.E_ch = widgets.Text(value='number_en', placeholder='Type something', disabled=False)
        self.F_ch = widgets.Dropdown(options=['M55M1', 'M467'], value='M55M1', disabled=False)
        self.G_ch = widgets.Button(description='Start to Run', layout=Layout(width='50%', height='30px'), button_style='success')
        self.H_ch = widgets.Button(description='Start to Run', layout=Layout(width='50%', height='30px'), button_style='success')
        
        form_follow_items = [
            Box([Label(value = 'Follow the train process setting(must)'), self.A_ch], layout=form_item_layout),
            Box([Label(value = 'Model Location'), self.B_ch], layout=form_item_layout),
            Box([Label(value = 'No-Quantize'), self.C_ch], layout=form_item_layout),
            Box([Label(value = 'Inference Type'), self.D_ch], layout=form_item_layout),
            Box([Label(value = 'Convert to tflite model'), self.G_ch], layout=form_item_layout),
            Box([Label(value = 'Cpp Model Name'), self.E_ch], layout=form_item_layout),
            Box([Label(value = 'Deployment Board'), self.F_ch], layout=form_item_layout),
            Box([Label(value = 'tflite to tflu'), self.H_ch], layout=form_item_layout)
        ]    
        self.form_box_follow_para = Box(form_follow_items, layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightblue',
            align_items='stretch',
            width='50%',
        ))
        
        ### deployment parameters widgets ###
        self.A_dp = widgets.Dropdown(options=self.tflu_files_list)
        self.B_dp = widgets.Textarea(value=self.tflu_c_proj_saveLoc, placeholder='Type something', disabled=False)
        self.C_dp = widgets.ToggleButton(description='Deploy Model', layout=Layout(width='30%', height='30px'), button_style='success')
        form_deploy_items = [
            Box([Label(value = 'Choose the model'), self.A_dp], layout=form_item_layout),
            Box([Label(value = 'The location of model deployment'), self.B_dp], layout=form_item_layout),
            Box([Label(value = 'Copy to your proj.'), self.C_dp], layout=form_item_layout)
        ]    
        self.form_box_deploy_para = Box(form_deploy_items, layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightblue',
            align_items='stretch',
            width='50%',
        ))
     
    def create_folder(self, dir_path):
        try:
            os.mkdir(dir_path)
        except OSError as error:
            print(error)
            print('skip create')
    
    def create_command(self, value_list):
        argument_list = ['--checkpoint', '--no-quantize', '--inference_type']
        cm_dict = OrderedDict()
             
        if(value_list[0]):
            train_cmd_path = Path(self.B_ch.value).parents[1] / 'train_cmd.txt'
            with open(train_cmd_path,'r') as f:  #save the complete command for train.py
                train_cmd_line = f.read()
            self.cmd_list = train_cmd_line.split()
            
            if(self.cmd_list != []):
                print('read the train commands!')
            else:
                print('The train_cmd.txt is empty!')
                
            for idx, val in enumerate(value_list[1:]):
                if(idx == 1):   #--no-quantize attr
                    if(val == True):
                        self.cmd_list.append(argument_list[idx])
                else:    
                    self.cmd_list.append(argument_list[idx])
                    self.cmd_list.append(val)
        #print(self.cmd_list)
        
    def tflite_to_tflu(self, inf_type_s, my_f_name, tflite_name):
        out_file = my_f_name + '_' + tflite_name.split('/')[-1].split('.tflite')[0] + '.cc'
        out_file = self.tflu_model_dir + '/' + out_file
        ! python tflite_to_tflu.py --tflite_path $tflite_name --output_path $out_file
        print(tflite_name)
        return out_file
    
    def tflite_to_tflu_para(self, inf_type_s, my_f_name, tflite_name, para_list, board_type):
        para_string = ''
        for key in para_list:
            para_string = para_string + key + ' ' + para_list[key] + ' '
       
        out_file = my_f_name + '_' + tflite_name.split('/')[-1].split('.tflite')[0] + '.cc'
        out_file = Path(tflite_name).parent / board_type / out_file

        self.create_folder(Path(tflite_name).parent / board_type)
        
        ! python kws_python/tflite_to_tflu_para.py --tflite_path $tflite_name --output_path $out_file $para_string
        
        print(f'original TFLite: {tflite_name}')
        return out_file

    def tflite_to_tflu_para_m55m1(self, inf_type_s, my_f_name, tflite_name, para_list, board_type):
        para_string = ''
        for key in para_list:
            para_string = para_string + key + ' ' + para_list[key] + ' '
       
        out_file = my_f_name + '_' + tflite_name.split('/')[-1].split('.tflite')[0] + '.cc'
        out_file = Path(tflite_name).parent / board_type / out_file

        self.create_folder(Path(tflite_name).parent / board_type)
        
        ! python kws_python/tflite_to_tflu_para_m55m1.py --tflite_path $tflite_name --output_path $out_file $para_string
        print(f'Original TFLite: {tflite_name}')
        return out_file

    def autogen_label_cc(self, tflite_name, para_list, board_type):
        para_string_list = para_list['--wanted_words'].split(',')
        para_string_label = '-l _silence_ _unknown_'
        for label in para_string_list:
            para_string_label = para_string_label + ' ' + label

        out_dir = Path(tflite_name).parent / board_type
        
        ! python kws_python/gen_label_cpp.py --output_dir $out_dir $para_string_label
    
    def get_train_parameter(self, wanted_para_list):
            train_cmd_path = Path(self.B_ch.value).parents[1] / 'train_cmd.txt'
            with open(train_cmd_path,'r') as f:
                train_cmd_line = f.read()
                
            train_cmd_list = train_cmd_line.split()
            if(train_cmd_list != []):
                print('read the exist train_cmd.txt')
            else:
                print('There is no train_cmd.txt')
            
            cm_para_dict = OrderedDict()
            for idx, val in enumerate(train_cmd_list):
                if val in wanted_para_list:
                    cm_para_dict[val] = train_cmd_list[idx + 1] 
                        
            return cm_para_dict
    
    def deploy_tflu_to_proj(self, model_name, dst_loc):
        src_model_loc = os.path.join(self.tflu_model_dir, model_name)
        if not os.path.exists(dst_loc):
            print('Not exist: ')
            print(dst_loc)
            os.mkdir(dst_loc)
            
        ###copy to a save folder    
        shutil.copy(src_model_loc, os.path.join(dst_loc, model_name))
        print('The copy saved model is here:')
        print(os.path.join(dst_loc, model_name))
        print('\n')
        
        ###copy to a run folder
        dst_run_loc = os.path.join(os.path.split(dst_loc)[0], self.tflu_c_proj_runDir)
        self.create_folder(dst_run_loc)
        dst_run_loc = os.path.join(dst_run_loc, self.tflu_c_proj_runName)
        shutil.copy(src_model_loc, dst_run_loc) 
        print('The run model is here:')
        print(dst_run_loc)
        
    def show_main(self):   ###interactive swection
        
        intro_text = 'Please Choose the parameters of the testing or using the default'
        htmlWidget = widgets.HTML(value = f"<b><font color='lightblue'><font size=4>{intro_text}</b>")
        display(htmlWidget)
        
        #Create an accordion and put the 2 boxes
        accordion = widgets.Accordion(children=[self.form_box_follow_para, self.form_box_deploy_para]).add_class("parentstyle")
        #Add a custom style tag to the notebook, you can use dev tool to inspect the class names
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        accordion.set_title(0, 'Quantizing Setting')
        accordion.set_title(1, 'Deployment')
        
        
        def act_para(follow,model_loc,no_qu,inf_type,my_f_name,para_bring,model_cpy,cpy_loc,d_button):
            out = widgets.Output()
            
            ###Deployment section###
            self.tflu_files_list = os.listdir(self.tflu_model_dir) ### update the files list at each action
            self.A_dp.options = self.tflu_files_list
            
            if d_button:  
                with out:
                    #print('yes\n')
                    self.deploy_tflu_to_proj(model_cpy, cpy_loc)
            else:
                with out:
                    out.clear_output()
                    #print('no')
                    
            display(out)
                   
        
        out = widgets.interactive_output(act_para, {'follow': self.A_ch, 'model_loc': self.B_ch, 'no_qu': self.C_ch, 
                                                    'inf_type': self.D_ch, 'my_f_name': self.E_ch, 'para_bring' : self.F_ch,
                                                    'model_cpy':self.A_dp, 'cpy_loc':self.B_dp, 'd_button':self.C_dp
                                                    })
        display(accordion, out)
        
        #------------------#
        # buttoms event control in widgets.Accordion
        #------------------# 
        output_button = widgets.Output(layout=Layout(border = '1px solid green'))
        display(output_button)
        def on_button_clicked_convert_tflite(b):
            with output_button:
                clear_output()
                
                self.create_command([self.A_ch.value, self.B_ch.value, self.C_ch.value, self.D_ch.value])
                text0 = 'The convert setting is finish and saved'
                html0= widgets.HTML(value = f"<b><font color='lightblue'><font size=2>{text0}</b>")
                display(html0)
                
                self.run_convert()
                print('Finish')    
        self.G_ch.on_click(on_button_clicked_convert_tflite)
         
        def on_button_clicked_tflu(b):
            with output_button:
                clear_output()
                ### update the tflite_path
                train_cmd_path = Path(self.B_ch.value).parents[1] / 'train_cmd.txt'
                with open(train_cmd_path,'r') as f:
                    train_cmd_line = f.read()
                train_cmd_list = train_cmd_line.split()
                for idx, val in enumerate(train_cmd_list):
                    if val == '--model_architecture':
                        if not self.C_ch.value:
                            if self.D_ch.value == 'int8':
                                self.tflite_path = f'{train_cmd_list[idx + 1]}_int8quant.tflite'
                            else:
                                self.tflite_path = f'{train_cmd_list[idx + 1]}_dyquant.tflite'
                        else:
                            self.tflite_path = f'{train_cmd_list[idx + 1]}.tflite'
                    
                self.tflite_path = os.path.join(self.B_ch.value.split('best')[0], self.tflite_path)           
            
                ### weather to bring kws specify parameter to fflu.cc
                if self.F_ch.value == 'M55M1':
                    wanted_para_list = ['--window_size_ms', '--window_stride_ms', '--dct_coefficient_count', '--sample_rate', '--clip_duration_ms']
                    para_list = self.get_train_parameter(wanted_para_list)
                    out_ccfile_path = self.tflite_to_tflu_para_m55m1(self.D_ch.value, self.E_ch.value, self.tflite_path, para_list, self.F_ch.value)
                    print('Finish converting to:  {}'.format(out_ccfile_path))
                elif self.F_ch.value == 'M467':
                    wanted_para_list = ['--window_size_ms', '--window_stride_ms', '--dct_coefficient_count']
                    para_list = self.get_train_parameter(wanted_para_list)
                    print('Finish converting to:  {}'.format(
                        self.tflite_to_tflu_para(self.D_ch.value, self.E_ch.value, self.tflite_path, para_list, self.F_ch.value)))

                # AutoGen label file
                wanted_para_list = ['--wanted_words']
                para_list = self.get_train_parameter(wanted_para_list)
                self.autogen_label_cc(self.tflite_path, para_list, self.F_ch.value)
                
        self.H_ch.on_click(on_button_clicked_tflu)        
        
    
    
    def run_convert(self):   ###run the mainprogram
        
        FLAGS, _ = parser.parse_known_args(args = self.cmd_list)
        #FLAGS, _ = parser.parse_known_args(args = ['--model_architecture','dnn','--checkpoint',r'work\DNN\DNN3\training\best\dnn_0.835_ckpt',
        #'--model_size_info','128','128','128'])
        #print(FLAGS)
        logger = logging.getLogger()
        logger.setLevel(logging.CRITICAL)

        self.tflite_path = main_convert(FLAGS)

# Run Section
---
- A detailed description of all parameters is in the [Parameter Description](#id-PDD) section.
- `Follow the train process setting`: Reuse the training settings of the same model (read from `train_cmd.txt`). This option is required.
- After finishing the settings, click `Convert to tflite model` to convert the checkpoint to a TFLite model.
- Finally, click `tflite to tflu` to convert the TFLite model to a TFLu C++ source file.


In [6]:
act = init_train_widgets()
act.show_main()

HTML(value="<b><font color='lightblue'><font size=4>Please Choose the parameters of the testing or using the d…

Accordion(children=(Box(children=(Box(children=(Label(value='Follow the train process setting(must)'), Checkbo…

Output()

Output(layout=Layout(border_bottom='1px solid green', border_left='1px solid green', border_right='1px solid g…

<a id="id-PDD"></a>
# Parameter Description
---
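- This notebook is based on [ARM-software/ML-examples](https://github.com/ARM-software/ML-examples/tree/main/tflu-kws-cortex-m).
- `Model Location`: The trained model checkpoint (the `*_ckpt` file) inside the `best` folder, for example: work/DNN/DNN2/training/best/dnn_0.826_ckpt. The `train_cmd.txt` of that training run is read from two levels up (here `work/DNN/DNN2/training/train_cmd.txt`).
- `No-Quantize`: When checked, the model is converted to an fp32 TFLite model without quantization. Unchecked by default, i.e. the model is quantized.
- `Inference Type`: If quantization is enabled, whether the model input and output tensors are float32 or int8.
- `Cpp Model Name`: Prefix of the generated TFLu C++ file (`<name>_<tflite file name>.cc`). This file can be loaded onto the MCU.
- `Deployment Board`: Target board (`M55M1` or `M467`). The KWS-specific parameters found in `train_cmd.txt` are written into the generated tflu `.cc` file, so they do not need to be updated manually in the MCU C++ code. For both boards these are window size, window stride and DCT coefficient count; for M55M1 they also include sample rate and clip duration. A label file is generated from the wanted words as well.
- Post-training quantization: [Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)
- More description: checking `No-Quantize` produces an unquantized fp32 model. Otherwise full-integer post-training quantization is applied: int8 ops, calibrated on validation samples. In that case `Inference Type` only selects whether the input/output tensors are float32 (file suffix `_dyquant`) or int8 (file suffix `_int8quant`).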

# nuvoTon m460 for KWS running example 
---
- There are 4 examples: 2 offline and 2 online.
    

## offline examples are in SampleCode/tflu_kws_arm & SampleCode/tflu_kws_arm_mc
- tflu_kws_arm can run DNN. You can change the `#include "raw/<keyWord>.h"` line in `main.c` to test the different PCM header-style data in the `raw` folder.
- tflu_kws_arm_mc can run both DNN and DS-CNN models; switch between them by updating `#define <which model>` in `model.h`.
- There is a small notebook called `transferPWM.ipynb` which can help you transfer `*.wav` file to C style `<keyWord>.h`. In this way, you can test the model offline with any new sliced `*.wav` file.

## online examples are in SampleCode/tflu_kws_arm_rt & SampleCode/tflu_kws_arm_rt_mc
- tflu_kws_arm_rt can run DNN. (Detail: on the MCU, each inference runs after 1/25 × 16000 samples have been collected by PDMA through I2S and the codec.)
- tflu_kws_arm_rt_mc can run both DNN and DS-CNN models; switch between them by updating `#define <which model>` in `model.h`. (Detail: on the MCU, each inference runs after 16000 samples have been collected by PDMA through I2S and the codec.)