# Import Section
---

In [4]:
import os
import shutil
import sys
from pathlib import Path
from subprocess import Popen

import regex as re

import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from ipywidgets import AppLayout, Button, Layout, Box, FloatText, Textarea, Dropdown, Label, IntSlider
from IPython.display import display, HTML
from IPython.display import clear_output
from ipyfilechooser import FileChooser

from omegaconf import OmegaConf
import hydra
import shutil

# Widgets Control Section
---

In [33]:
class train_config_and_cmds_widgets():
    """ipywidgets front-end for configuring, training, testing and deploying an audio model.

    The UI has three tabs:
      * Train      - pick a template config, edit training / pre-processing / feature
                     parameters, write them to ``cfg/my_config.yaml`` and run ``train.py``.
      * Test       - pick a ``.tflite`` model and a ``.wav`` file and run ``tflite_inference``.
      * Deployment - pick a ``.tflite`` model and run ``exebat.py`` with the chosen paths.
    """

    # Default Deployment-tab paths. Raw strings keep the Windows separators literal;
    # without the ``r`` prefix '...\vela' would contain a vertical-tab character ('\v').
    DEFAULT_MODEL_SRC_DIR = r'..\workspace\miniresnetv2_ESC10_proj\quantized_models'
    DEFAULT_MODEL_SRC_FILE = r'quantized_model.tflite'
    DEFAULT_GEN_SRC_DIR = r'..\workspace\miniresnetv2_ESC10_proj\quantized_models\vela'

    def __init__(self):
        """Create every form widget; nothing is displayed until ``show_main`` is called."""
        self.cfg_dir = "cfg"
        self.my_cfg_path = Path(self.cfg_dir) / "my_config.yaml"

        # Kept for backward compatibility; only src_audio_test_file is filled in (by choose_audio).
        self.tflite_file_loc = None
        self.src_audio_test_file = None

        # One shared layout for every "label | widget" row.
        form_item_layout = Layout(
            display='flex',
            flex_flow='row',
            justify_content='space-between',
        )

        def row(label, widget):
            """Return a horizontal Box pairing a text label with its input widget."""
            return Box([Label(value=label), widget], layout=form_item_layout)

        def column(children, border=None, width='100%', **extra):
            """Return a vertical, stretched flex Box; `border` is only set when given."""
            layout_kwargs = dict(display='flex', flex_flow='column', align_items='stretch',
                                 width=width, **extra)
            if border is not None:
                layout_kwargs['border'] = border
            return Box(children, layout=Layout(**layout_kwargs))

        def action_button(description):
            """Return the green 30%-wide button used throughout the forms."""
            return widgets.Button(description=description, layout=Layout(width='30%', height='30px'),
                                  button_style='success')

        ### Choose config ###
        self.A_de = Dropdown(options=['miniresnetv2_1stacks', 'miniresnetv2_2stacks', 'yamnet'])
        self.B_de = Dropdown(options=['ESC like'])
        self.C_de = action_button('Set')

        self.form_cfg_prepare = column([
            row('Model Type', self.A_de),
            row('Dataset Type', self.B_de),
            row('Choose this Config', self.C_de),
        ])

        ### train ###
        self.A_ta = widgets.Text(value='miniresnetv2_ESC10_proj', placeholder='Type something', disabled=False)
        self.B_ta = widgets.IntSlider(value=16, min=4, max=64, step=4)
        self.C_ta = widgets.BoundedIntText(value=40, min=5, max=100000, step=1, disabled=False)
        self.E_ta = Dropdown(options=['Adam', 'RMSprop', 'SGD'])
        self.I_ta = widgets.BoundedFloatText(value=0.001, min=0.000001, max=0.1, step=0.000001, disabled=False)
        self.G_ta = Dropdown(options=['reducelronplateau', 'Exponential', 'Cosine', 'Constant'])
        self.D_ta = widgets.FloatSlider(value=0.1, min=0.1, max=0.5, step=0.1)
        self.F_ta = widgets.FloatSlider(value=0.2, min=0.1, max=0.5, step=0.1)
        self.H_ta = Textarea(value='dog,chainsaw,crackling_fire,helicopter,rain,crying_baby,clock_tick,sneezing,rooster,sea_waves',
                             disabled=False)
        self.J_ta = Textarea(value='datasets/ESC-50/audio',
                             disabled=False)
        self.K_ta = Textarea(value='datasets/ESC-50/meta/esc50.csv',
                             disabled=False)
        self.M_ta = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.N_ta = widgets.Checkbox(value=False, disabled=False, indent=False)
        self.O_ta = widgets.Checkbox(value=False, disabled=False, indent=False)

        self.form_output_train_cmd = column([
            row('Project Name', self.A_ta),
            row('Batch Size', self.B_ta),
            row('Epochs', self.C_ta),
            row('Optimizer', self.E_ta),
            row('Learning Rate', self.I_ta),
            row('Learning Rate Scheduler', self.G_ta),
            row('Validation Percent', self.D_ta),
            row('Test Percent', self.F_ta),
            row('Class Names', self.H_ta),
            row('Use Other Class', self.O_ta),
            row('Audio Folder Path', self.J_ta),
            row('Csv File Path', self.K_ta),
            row('Transfer Learning', self.M_ta),
            row('Fine Tune', self.N_ta),
        ], border='solid 1px lightblue')

        ### pre-process ###
        self.A_pp = widgets.IntSlider(value=2, min=1, max=10, step=1)
        self.B_pp = widgets.IntSlider(value=10, min=1, max=12, step=1)
        self.C_pp = widgets.IntSlider(value=16000, min=8000, max=22400, step=1000)
        self.D_pp = widgets.BoundedIntText(value=60, min=20, max=100, step=1, disabled=False)

        self.form_preprocess_cmd = column([
            row('Min Length', self.A_pp),
            row('Max Length', self.B_pp),
            row('Target Rate', self.C_pp),
            row('Top DB', self.D_pp),
        ], border='solid 1px lightblue')

        ### features ###
        self.A_fe = widgets.IntSlider(value=50, min=20, max=120, step=1)
        self.B_fe = widgets.IntSlider(value=64, min=32, max=160, step=1)
        self.C_fe = widgets.BoundedFloatText(value=0.25, min=0, max=0.75, step=0.01, disabled=False)
        self.D_fe = widgets.IntSlider(value=1024, min=256, max=2048, step=256)
        self.E_fe = widgets.IntSlider(value=320, min=160, max=1024, step=1)
        self.F_fe = widgets.Checkbox(value=True, disabled=False, indent=False)

        self.form_feature_cmd = column([
            row('Patch Length', self.A_fe),
            row('Number of Mels', self.B_fe),
            row('Overlap Percent', self.C_fe),
            row('Number of FFT', self.D_fe),
            row('Hop Length', self.E_fe),
            row('Convert to DB', self.F_fe),
        ], border='solid 1px lightblue')

        ### run buttons box ###
        self.A_rb = action_button('Set')
        self.B_rb = action_button('Run')

        self.form_run_buttoms_cmd = column([
            row('Set the Config', self.A_rb),
            row('Start to Train', self.B_rb),
        ], border='solid 1px lightblue')

        ### test ###
        self.A_tt = action_button('Setting')
        self.B_tt = action_button('Setting')
        self.C_tt = action_button('Run')

        self.form_output_test_cmd = column([
            row('Choose the tflite file', self.A_tt),
            row('Choose the audio file', self.B_tt),
            row('Start to Test', self.C_tt),
        ], border='solid 3px lightgreen', width='50%')

        ### convert model cpp ###
        self.A_cm = widgets.Text(value=self.DEFAULT_MODEL_SRC_DIR, placeholder='Type something', disabled=False)
        self.B_cm = widgets.Text(value=self.DEFAULT_MODEL_SRC_FILE, placeholder='Type something', disabled=False)
        self.C_cm = widgets.Text(value=self.DEFAULT_GEN_SRC_DIR, placeholder='Type something', disabled=False)
        self.E_cm = action_button('Setting')
        self.D_cm = action_button('Run')

        form_convert_paths = column([
            row('Choose the tflite file', self.E_cm),
            row('MODEL SRC DIR', self.A_cm),
            row('MODEL SRC FILE', self.B_cm),
            row('GEN SRC DIR', self.C_cm),
        ], border='dotted 3px lightblue', width='70%', justify_content='center')

        self.form_output_convert_cmd = column([
            form_convert_paths,
            row('Convert to cpp & Vela', self.D_cm),
        ], border='solid 3px lightgreen', width='70%', justify_content='center')

    @staticmethod
    def _parse_class_names(text):
        """Split a comma-separated class list, trimming whitespace and dropping empty entries."""
        return [name.strip() for name in text.split(",") if name.strip()]

    @staticmethod
    def _to_parent_relative(path, start=None):
        """Express `path` relative to `start` (default: cwd), prefixed with a '..' component.

        This is the convention used by the MODEL SRC DIR / GEN SRC DIR fields.
        """
        rel = os.path.relpath(path, start if start is not None else os.getcwd())
        return os.path.normpath(os.path.join('..', rel))

    @staticmethod
    def _strip_parent_prefix(path):
        """Undo `_to_parent_relative`: drop one leading '..' component (either separator)."""
        if path == '..':
            return '.'
        for prefix in ('..\\', '../'):
            if path.startswith(prefix):
                return path[len(prefix):]
        return path

    @staticmethod
    def _run_script(line):
        """Execute ``%run <line>`` through the active IPython shell.

        Raises:
            RuntimeError: when called outside an IPython/Jupyter session.
        """
        from IPython import get_ipython
        shell = get_ipython()
        if shell is None:
            raise RuntimeError("This action must be run from an IPython/Jupyter session.")
        shell.run_line_magic('run', line)

    def move_allfiles(self, src_folder, dst_folder):
        """Move every sub-directory of `src_folder` into `dst_folder`.

        Plain files are left in place. One line is printed per moved directory.
        """
        for name in os.listdir(src_folder):
            fullpath = os.path.join(src_folder, name)
            if os.path.isdir(fullpath):
                shutil.move(fullpath, dst_folder)
                print("Move finish: {}".format(name))

    def show_headline(self, output):
        """Display `output` as a bold, light-blue HTML headline."""
        display(widgets.HTML(value=f"<b><font color='lightblue'><font size=4>{output}</b>"))

    def show_main(self):
        """Render the Train / Test / Deployment tabs and connect every button to its action.

        All action output (prints, file choosers) goes to one shared ``Output`` area. That area
        is cleared before each action and whenever any watched form value changes.
        """
        intro_text = 'Please Choose the setting of data prepare & train'
        display(widgets.HTML(value=f"<b><font color='lightgreen'><font size=6>{intro_text}</b>"))

        # Accordion holding the three configuration forms of the Train tab.
        accordion = widgets.Accordion(
            children=[self.form_output_train_cmd, self.form_preprocess_cmd, self.form_feature_cmd]
        ).add_class("parentstyle")
        # Custom CSS colouring the accordion headers (class names found with the browser dev tools).
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        for index, title in enumerate(('Configure the Training',
                                       'Configure the Pre-processing',
                                       'Configure the Feature-extraction')):
            accordion.set_title(index, title)

        # Train tab: config chooser, accordion and Set/Run buttons stacked vertically.
        box_data_train = Box([self.form_cfg_prepare, accordion, self.form_run_buttoms_cmd], layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightgreen',
            align_items='stretch',
            width='50%',
        ))

        tab = widgets.Tab(children=[box_data_train, self.form_output_test_cmd,
                                    self.form_output_convert_cmd]).add_class("parentstyle")
        tab.titles = ['Train', 'Test', 'Deployment']

        output_widgets = widgets.Output(layout=Layout(border='1px solid green'))

        def act_para(**_values):
            """Clear the shared output area whenever any watched widget value changes."""
            with output_widgets:
                output_widgets.clear_output()

        watched_names = (
            'A_de', 'B_de',
            'A_ta', 'B_ta', 'C_ta', 'D_ta', 'E_ta', 'F_ta', 'G_ta', 'H_ta', 'I_ta', 'J_ta', 'K_ta',
            'M_ta', 'N_ta', 'O_ta',
            'A_pp', 'B_pp', 'C_pp', 'D_pp',
            'A_fe', 'B_fe', 'C_fe', 'D_fe', 'E_fe', 'F_fe',
            'A_cm', 'B_cm', 'C_cm',
        )
        out_inter = widgets.interactive_output(act_para, {name: getattr(self, name) for name in watched_names})

        display(tab, out_inter)
        # Shared output area for every button action (kept outside act_para so it is not re-created).
        display(output_widgets)

        def bind(button, action, message=None):
            """Run `action` inside the shared output area (cleared first) when `button` is clicked."""
            def _on_click(_button):
                with output_widgets:
                    clear_output()
                    if message is not None:
                        print(message)
                    action()
            button.on_click(_on_click)

        bind(self.C_de, self.copy_cfg)
        bind(self.A_rb, self.update_cfg)
        bind(self.B_rb, self.run_train)
        bind(self.E_cm, self.choose_tflite)
        bind(self.D_cm, self.convert_tflu, "Convert to cpp & Vela. . .")
        bind(self.A_tt, self.choose_tflite)
        bind(self.B_tt, self.choose_audio)
        bind(self.C_tt, self.run_test_tflite)

    def choose_tflite(self):
        """Show a file chooser for a ``.tflite`` model, starting in ``./workspace``.

        On confirmation it fills MODEL SRC DIR (A_cm), MODEL SRC FILE (B_cm) and GEN SRC DIR (C_cm).
        The directory is written relative to the working directory and prefixed with '..'.
        """
        f_tflite = FileChooser(os.path.join(os.getcwd(), "workspace"))
        # Keep navigation inside the current working directory.
        f_tflite.sandbox_path = os.getcwd()
        f_tflite.filter_pattern = ['*.tflite']
        f_tflite.title = "<b><font color='lightblue'><font size=4>Choose the Tflite for test single audio.</b>"
        display(f_tflite)

        def act_test():
            if f_tflite.selected is None:
                print("Please select a .tflite file first.")
                return
            m_src_dir = self._to_parent_relative(f_tflite.selected_path)
            m_src_tflite = os.path.basename(f_tflite.selected)
            print("The chosen dir: {}".format(m_src_dir))
            print("The chosen tflite: {}".format(m_src_tflite))
            self.A_cm.value = m_src_dir
            self.B_cm.value = m_src_tflite
            self.C_cm.value = os.path.join(m_src_dir, "vela")
            print("Finish!")
        evt = interact_manual(act_test)
        # act_test takes no arguments, so children[0] is the manual "run" button.
        evt.widget.children[0].description = 'Choose this tflite file'
        evt.widget.children[0].button_style = 'primary'

    def choose_audio(self):
        """Show a file chooser for a ``.wav`` file (starting in ``./datasets``) and store the choice."""
        f_audio = FileChooser(os.path.join(os.getcwd(), "datasets"))
        # Keep navigation inside the current working directory.
        f_audio.sandbox_path = os.getcwd()
        f_audio.filter_pattern = ['*.wav']
        f_audio.title = "<b><font color='lightblue'><font size=4>Choose the audio file to test.</b>"
        display(f_audio)

        def act_test():
            if f_audio.selected is None:
                print("Please select a .wav file first.")
                return
            self.src_audio_test_file = f_audio.selected
            print("Finish!")
        evt = interact_manual(act_test)
        # act_test takes no arguments, so children[0] is the manual "run" button.
        evt.widget.children[0].description = 'Choose this audio file'
        evt.widget.children[0].button_style = 'primary'

    def copy_cfg(self):
        """Copy the template config matching the selected model type to ``self.my_cfg_path``.

        The first file in ``self.cfg_dir`` (sorted by name) whose stem contains the lower-cased
        model type is used. The target file itself is never chosen.

        Raises:
            FileNotFoundError: if no template matches the selected model type.
        """
        model_key = self.A_de.value.lower()
        target = Path(self.my_cfg_path)
        candidates = sorted(f for f in Path(self.cfg_dir).iterdir() if f.is_file())
        src_cfg = next((f for f in candidates if model_key in f.stem and f != target), None)
        if src_cfg is None:
            raise FileNotFoundError("There is no {} model config file!!".format(model_key))
        print(src_cfg)
        shutil.copyfile(src_cfg, self.my_cfg_path)

    def update_cfg(self):
        """Write the current widget values into ``self.my_cfg_path`` (OmegaConf YAML)."""
        my_cfg = OmegaConf.load(self.my_cfg_path)

        my_cfg['general']['project_name'] = self.A_ta.value
        my_cfg['train_parameters']['batch_size'] = self.B_ta.value
        my_cfg['train_parameters']['training_epochs'] = self.C_ta.value
        my_cfg['train_parameters']['optimizer'] = self.E_ta.value
        my_cfg['train_parameters']['initial_learning'] = self.I_ta.value
        my_cfg['train_parameters']['learning_rate_scheduler'] = self.G_ta.value

        my_cfg['dataset']['validation_split'] = self.D_ta.value
        my_cfg['dataset']['test_split'] = self.F_ta.value
        # Users may type "dog, rain"; keep only clean, non-empty names.
        my_cfg['dataset']['class_names'] = self._parse_class_names(self.H_ta.value)
        my_cfg['dataset']['use_other_class'] = self.O_ta.value
        my_cfg['dataset']['audio_path'] = self.J_ta.value
        my_cfg['dataset']['csv_path'] = self.K_ta.value
        my_cfg['model']['transfer_learning'] = self.M_ta.value
        my_cfg['model']['fine_tune'] = self.N_ta.value

        my_cfg['pre_processing']['min_length'] = self.A_pp.value
        my_cfg['pre_processing']['max_length'] = self.B_pp.value
        my_cfg['pre_processing']['target_rate'] = self.C_pp.value
        my_cfg['pre_processing']['top_db'] = self.D_pp.value

        my_cfg['feature_extraction']['patch_length'] = self.A_fe.value
        my_cfg['feature_extraction']['n_mels'] = self.B_fe.value
        my_cfg['feature_extraction']['overlap'] = self.C_fe.value
        my_cfg['feature_extraction']['n_fft'] = self.D_fe.value
        my_cfg['feature_extraction']['hop_length'] = self.E_fe.value
        my_cfg['feature_extraction']['to_db'] = self.F_fe.value

        # This must be set by [mels, frames]
        my_cfg['model']['input_shape'] = [self.B_fe.value, self.A_fe.value]

        # '${now:...}' is left unresolved here on purpose; it is a resolver expression kept in the YAML.
        my_cfg['hydra']['run']['dir'] = 'workspace/' + self.A_ta.value + '_${now:%Y_%m_%d_%H_%M_%S}'

        OmegaConf.save(my_cfg, self.my_cfg_path)

        print("Save the updated CFG!")

    def run_train(self):
        """Run ``train.py`` exactly as ``%run train.py`` would."""
        self._run_script('train.py')
        print("Finish !!")

    def run_test_tflite(self):
        """Run ``tflite_inference`` on the chosen model and audio file.

        The model path joins MODEL SRC DIR (with its leading '..' removed, so it resolves from
        the working directory) and MODEL SRC FILE. Nothing runs if no audio file was chosen.
        """
        tflite_location = os.path.join(self._strip_parent_prefix(self.A_cm.value), self.B_cm.value)
        print("The tflite file: {}".format(tflite_location))
        print("The audio file: {}".format(self.src_audio_test_file))
        if self.src_audio_test_file is None:
            print("Please choose an audio file first.")
            return

        # `tflite_inference` is defined in a later cell of this notebook.
        tflite_inference(tflite_location, self.src_audio_test_file)

        print("Finish !!")

    def convert_tflu(self):
        """Run ``exebat.py`` with the Deployment-tab directories and file, like the former ``%run`` line."""
        self._run_script(
            f"exebat.py --SRC_DIR {self.A_cm.value} --SRC_FILE {self.B_cm.value} --GEN_DIR {self.C_cm.value}"
        )
        print('Finish!')
            

# Single-audio inference with a TFLite model

In [34]:
import numpy as np
import librosa
sys.path.append(os.path.abspath('../utils'))
from preprocess import load_and_reformat
from feature_extraction import get_patches
import tensorflow as tf
from evaluation import _aggregate_predictions, compute_accuracy_score

def tflite_inference(tflite_location, src_wav_file):
    wave, sr = librosa.load(src_wav_file, sr=16000, duration=10)
    #wave = wave[24000:57408] #wave[24000:40704] #wave[24000:57408]
    
    patches = get_patches(wave=wave,
                          sr=sr,
                          patch_length=50,
                          overlap=0.25,
                          n_fft=1024,
                          hop_length=320,
                          include_last_patch=False,
                          win_length=1024,
                          window='hann',
                          center=True,
                          pad_mode='constant',
                          power=2.0,
                          n_mels=64,
                          fmin=20,
                          fmax=7500,
                          power_to_db_ref=np.max,
                          norm='slaney',
                          htk=False,
                          to_db=True,
                          )

    print("Total wav samples: {}".format(len(wave)))
    print("Num patches: {}".format(len(patches)))
    print("patch_length: {} n_mels: {}".format(len(patches[0][1]), len(patches[0])))
    
    clip_labels = []
    clip_labels.extend([0] * len(patches)) # only 1 test data
    clip_labels = np.array(clip_labels)
    X = []
    y = []
    X.extend(patches)
    X = np.stack(X, axis=0)
    X = np.expand_dims(X, axis=-1)
    
    y.extend(['dog'] * len(patches))
    #vocab = ['dog', 'chainsaw', 'crackling_fire', 'helicopter', 'rain',
    #       'crying_baby', 'clock_tick', 'sneezing', 'rooster', 'sea_waves']
    vocab = ["chainsaw","clock_tick","crackling_fire","crying_baby","dog","helicopter","other","rain","rooster","sea_waves","sneezing"]
    string_lookup_layer = tf.keras.layers.StringLookup(
            vocabulary=sorted(list(vocab)),
            num_oov_indices=0,
            output_mode='one_hot')
    y = np.array(string_lookup_layer(y))

    # Run the tflite
    X_test = X
    y_test = y
    tf.print('[INFO] Evaluating the quantized model ...')
    interpreter_quant = tf.lite.Interpreter(model_path=tflite_location)
    
    input_details = interpreter_quant.get_input_details()[0]
    #print(input_details)
    output_details = interpreter_quant.get_output_details()[0]
    #print(output_details)
    
    tf.print("[INFO] Quantization input details : {}".format(input_details["quantization"]))
    tf.print("[INFO] Dtype input details : {}".format(input_details["dtype"]))
    input_index_quant = interpreter_quant.get_input_details()[0]["index"]
    
    output_index_quant = interpreter_quant.get_output_details()[0]["index"]
    interpreter_quant.resize_tensor_input(input_index_quant, list(X_test.shape))
    interpreter_quant.allocate_tensors()
    X_processed = (X_test / input_details['quantization'][0]) + input_details['quantization'][1]
    
    print(np.iinfo(input_details['dtype']).min, np.iinfo(input_details['dtype']).max)
    #print(np.round(X_processed))
    
    X_processed = np.clip(np.round(X_processed), np.iinfo(input_details['dtype']).min, np.iinfo(input_details['dtype']).max)
    X_processed = X_processed.astype(input_details['dtype'])
    #print(X_processed)
    
    interpreter_quant.set_tensor(input_index_quant, X_processed)
    interpreter_quant.invoke()
    preds = interpreter_quant.get_tensor(output_index_quant)
    
    # Aggregate predictions
    aggregated_preds = _aggregate_predictions(preds=preds,
                                                clip_labels=clip_labels,
                                                is_multilabel=False,
                                                is_truth=False)
    aggregated_truth = _aggregate_predictions(preds=y_test,
                                                clip_labels=clip_labels,
                                                is_multilabel=False,
                                                is_truth=True)
     #generate the confusion matrix for the float model
    patch_level_accuracy = compute_accuracy_score(y_test, preds,
                                                    is_multilabel=False)
    #print("[INFO] : Quantized model patch-level accuracy on test set : {}".format(patch_level_accuracy))


    preds = preds.astype('float') 
    preds_q = (preds - output_details['quantization'][1]) * output_details['quantization'][0]
    print(vocab)
    print(preds_q)
            

# Run Section
---
- A detailed description of every parameter and of what each step does is available [here](#id-train_evl_monitor).
- This step assumes the dataset has already been prepared. If it has not, run `image_dataset\create_data.ipynb` first.

In [36]:
# Build all forms and buttons, then render the Train / Test / Deployment tabs below.
act = train_config_and_cmds_widgets()
act.show_main()

HTML(value="<b><font color='lightgreen'><font size=6>Please Choose the setting of data prepare & train</b>")

Tab(children=(Box(children=(Box(children=(Box(children=(Label(value='Model Type'), Dropdown(options=('miniresn…

Output()

Output(layout=Layout(border_bottom='1px solid green', border_left='1px solid green', border_right='1px solid g…

In [29]:
from omegaconf import OmegaConf
import hydra
import shutil

# Scratch check: what happens when a list is assigned to a scalar config field?
scratch_cfg_file = 'cfg/my_config.yaml'
my_cfg = OmegaConf.load(scratch_cfg_file)
print(my_cfg)
general_section = my_cfg['general']
general_section['project_name'] = [1, 2]
print(general_section['project_name'])
print(my_cfg)




{'general': {'project_name': 'miniresnetv2_ESC10_proj', 'logs_dir': 'logs', 'saved_models_dir': 'saved_models'}, 'train_parameters': {'batch_size': 16, 'training_epochs': 40, 'optimizer': 'Adam', 'initial_learning': 0.001, 'patience': 100, 'learning_rate_scheduler': 'reducelronplateau', 'restore_best_weights': False}, 'dataset': {'name': 'custom', 'class_names': ['dog', 'chainsaw', 'crackling_fire', 'helicopter', 'rain', 'crying_baby', 'clock_tick', 'sneezing', 'rooster', 'sea_waves'], 'audio_path': 'datasets/ESC-50/audio', 'csv_path': 'datasets/ESC-50/meta/esc50.csv', 'file_extension': '.wav', 'validation_split': 0.1, 'test_split': 0.2, 'test_path': None, 'use_other_class': True, 'n_samples_per_other_class': 2, 'to_cache': True}, 'pre_processing': {'min_length': 2, 'max_length': 10, 'target_rate': 16000, 'top_db': 320, 'frame_length': 3200, 'hop_length': 3200, 'trim_last_second': False, 'lengthen': 'after'}, 'feature_extraction': {'patch_length': 50, 'n_mels': 64, 'overlap': 0.2499999

In [26]:
# Persist the experimented config to ./config.yaml so the next cell can reload it.
scratch_out_file = "config.yaml"
OmegaConf.save(my_cfg, scratch_out_file)

In [31]:
# Reload the saved file to see how the values round-trip through YAML.
my_cfg_read = OmegaConf.load("config.yaml")
print(my_cfg_read)

{'general': {'project_name': '[1, 2]', 'logs_dir': 'logs', 'saved_models_dir': 'saved_models'}, 'train_parameters': {'batch_size': 16, 'training_epochs': 40, 'optimizer': 'Adam', 'initial_learning': 0.001, 'patience': 100, 'learning_rate_scheduler': 'reducelronplateau', 'restore_best_weights': False}, 'dataset': {'name': 'custom', 'class_names': ['dog', 'chainsaw', 'crackling_fire', 'helicopter', 'rain', 'crying_baby', 'clock_tick', 'sneezing', 'rooster', 'sea_waves'], 'audio_path': 'datasets/ESC-50/audio', 'csv_path': 'datasets/ESC-50/meta/esc50.csv', 'file_extension': '.wav', 'validation_split': 0.1, 'test_split': 0.2, 'test_path': None, 'use_other_class': True, 'n_samples_per_other_class': 2, 'to_cache': True}, 'pre_processing': {'min_length': 2, 'max_length': 10, 'target_rate': 16000, 'top_db': 320, 'frame_length': 3200, 'hop_length': 3200, 'trim_last_second': False, 'lengthen': 'after'}, 'feature_extraction': {'patch_length': 50, 'n_mels': 64, 'overlap': 0.24999999999999997, 'n_ff