# Import Section
---

In [1]:
import tensorflow as tf
import regex as re
import shutil
import json
import os

from object_detection.utils import label_map_util

import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from ipywidgets import AppLayout, Button, Layout, Box, FloatText, Textarea, Dropdown, Label, IntSlider
from IPython.display import display, HTML
from IPython.display import clear_output
from ipyfilechooser import FileChooser

import random
import numpy as np
from glob import glob

# Convert the Model to TFLite
---

In [2]:
class my_tflite_trans():
    """Convert a TF2 SavedModel into TFLite flatbuffers with optional quantization.

    Supported quantization options (selected by name in ``run_tflite``):
    'None' (plain float32), 'Dynamic' (dynamic-range), 'Full' (int8, calibrated
    with a representative dataset) and 'Float16'. Each option writes to
    ``output_tflite_location`` plus the suffix listed in ``_QUANT_SUFFIXES``.
    """

    # Output-file suffix appended to ``output_tflite_location`` for each option.
    _QUANT_SUFFIXES = {
        'None': '.tflite',
        'Dynamic': '_quant.tflite',
        'Full': '_fullquant.tflite',
        'Float16': '_f16quant.tflite',
    }

    def __init__(self, source_model_folder, output_tflite_location, rep_dataset_loc, input_img_size,
                 num_rep_samples=128):
        """
        Args:
            source_model_folder: path of the SavedModel directory to convert.
            output_tflite_location: output path WITHOUT extension; a suffix that
                depends on the quantization option is appended.
            rep_dataset_loc: glob pattern (e.g. ``folder/*.jpg``) matching the
                images used as representative (calibration) data.
            input_img_size: square model input size; any value ``int()``
                accepts (the widget UI passes e.g. the string '320').
            num_rep_samples: maximum number of randomly chosen images used for
                calibration. Defaults to 128.
        """
        self.source_model_folder = source_model_folder
        self.output_tflite_location = output_tflite_location

        self.rep_dataset_loc = rep_dataset_loc
        self.input_img_size = input_img_size
        self.num_rep_samples = num_rep_samples

    def tflite_preprocess(self, image, height, width):
        """Batch, resize and normalise one decoded image for calibration.

        Non-float32 images are converted with ``tf.image.convert_image_dtype``
        (which rescales integer pixels), a leading batch axis is added, the
        image is bilinearly resized to (height, width), and values are mapped
        via ``(x - 0.5) * 2`` — i.e. to [-1, 1] assuming pixels were in [0, 1].

        Args:
            image: 3-D (H, W, C) image tensor — rank is required because a batch
                axis is added before the 4-D resize op.
            height, width: target spatial size.

        Returns:
            A (1, height, width, C) float32 tensor.
        """
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)

        # Keep the batch axis: each yielded sample is fed to the converter as a batch of one.
        image = tf.expand_dims(image, 0)
        image = tf.compat.v1.image.resize_bilinear(image, [height, width],
                                                   align_corners=False)

        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        return image

    def representative_dataset(self):
        """Yield up to ``num_rep_samples`` preprocessed calibration images.

        Files matching ``rep_dataset_loc`` are shuffled, truncated, decoded as
        JPEG and preprocessed. Images that do not decode to exactly 3 channels
        (e.g. grayscale) are skipped, so fewer samples may be yielded.

        Yields:
            A one-element list holding a (1, size, size, 3) float32 tensor.
        """
        size = int(self.input_img_size)
        files = glob(self.rep_dataset_loc)
        random.shuffle(files)
        for file in files[:self.num_rep_samples]:
            image = tf.compat.v1.image.decode_jpeg(tf.io.read_file(file))
            if image.get_shape()[2] != 3:  # skip pictures without exactly 3 channels
                continue
            yield [self.tflite_preprocess(image, size, size)]

    def _output_path(self, quant_options):
        """Return the .tflite output path for ``quant_options``.

        Raises:
            ValueError: if ``quant_options`` is not a key of ``_QUANT_SUFFIXES``.
        """
        try:
            suffix = self._QUANT_SUFFIXES[quant_options]
        except KeyError:
            raise ValueError(
                "Unknown quant_options {!r}; expected one of {}".format(
                    quant_options, list(self._QUANT_SUFFIXES))) from None
        return self.output_tflite_location + suffix

    def run_tflite(self, quant_options):
        """Convert the SavedModel using ``quant_options`` and write the result.

        Args:
            quant_options: one of 'None', 'Dynamic', 'Full', 'Float16'.

        Raises:
            ValueError: for an unknown ``quant_options``. This is raised before any
                conversion work starts (previously this surfaced late as an
                unbound ``output_location``).
        """
        # Refer to: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md#step-2-convert-to-tflite
        output_location = self._output_path(quant_options)
        print("Start to convert, please wait...")

        converter = tf.lite.TFLiteConverter.from_saved_model(self.source_model_folder)

        if quant_options in ('Dynamic', 'Full', 'Float16'):
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            print("tf lite Optimize")

        if quant_options == 'Full':
            # Listing TFLITE_BUILTINS after the INT8 set permits float fallback for
            # ops without an int8 kernel (see the TF post-training quantization guide).
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8, tf.lite.OpsSet.TFLITE_BUILTINS]
            converter.representative_dataset = self.representative_dataset
            converter.inference_input_type = tf.int8  # or tf.uint8
            # inference_output_type is intentionally left unset: per the original
            # note, the head of TF ssd_mobileNetv2 already dequantizes its outputs.
            print("Full quantation")

        if quant_options == 'Float16':
            converter.target_spec.supported_types = [tf.float16]
            # NOTE(review): the float16 recipe in the TF docs does not use a
            # representative dataset; kept to preserve the original behaviour.
            converter.representative_dataset = self.representative_dataset
            print("Float16 quantation")

        tflite_model = converter.convert()

        # Save the model.
        with open(output_location, 'wb') as f:
            f.write(tflite_model)
        print("The tflite output location: {}".format(output_location))
        print("Finish!")

# Widgets Control Section
---

In [3]:

class init_tflite_widgets():
    """ipywidgets form that drives ``my_tflite_trans`` conversions.

    The form collects the working directory, exported model folder, output
    name, representative-dataset folder, input image size and quantization
    option. ``show_main`` renders it together with a toggle button that runs
    the conversion(s).
    """

    # Quantization options offered in the dropdown; 'Run All' converts with each in turn.
    _QUANT_OPTIONS = ('None', 'Dynamic', 'Full', 'Float16')

    def __init__(self):
        # Not read by any code visible in this notebook; kept for external callers.
        self.tflite_file_loc = ""

        form_item_layout = Layout(
            display='flex',
            flex_flow='row',
            justify_content='space-between',
        )

        ### conversion form inputs ###
        self.A_ta = widgets.Text(value='training_demo', placeholder='Type something', disabled=False)
        self.B_ta = widgets.Text(value=r'exported-models\tflite_infer_graph', placeholder='Type something', disabled=False)
        self.C_ta = widgets.Text(value='ssd_mobileNetv2', placeholder='Type something', disabled=False)
        # Placeholder path: replace USERNAME / the folder with your own image directory.
        # Raw string, so single backslashes (doubling them inside r'' kept both).
        self.D_ta = widgets.Textarea(value=r'C:\Users\USERNAME\image_detection\TensorFlow\workspace\training_demo_8000\images\test',
                                     placeholder='Type something', disabled=False)
        self.G_ta = widgets.Text(value='320', placeholder='Type something', disabled=False)
        self.E_ta = Dropdown(options=list(self._QUANT_OPTIONS))
        self.F_ta = widgets.Checkbox(value=False, disabled=False, indent=False)

        form_train_items = [
            Box([Label(value = 'Your Working Directory Name'), self.A_ta], layout=form_item_layout),
            Box([Label(value = 'Source pb Model Folder'), self.B_ta], layout=form_item_layout),
            Box([Label(value = 'Output tflite Name'), self.C_ta], layout=form_item_layout),
            Box([Label(value = 'Rep Dataset Location'), self.D_ta], layout=form_item_layout),
            Box([Label(value = 'Input Image size'), self.G_ta], layout=form_item_layout),
            Box([Label(value = 'Quantization'), self.E_ta], layout=form_item_layout),
            Box([Label(value = 'Run All'), self.F_ta], layout=form_item_layout)
        ]

        self.form_output_train_cmd = Box(form_train_items, layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightgreen',
            align_items='stretch',
            width='50%',
        ))

    def move_allfiles(self, src_folder, dst_folder):
        """Move every sub-directory of ``src_folder`` into ``dst_folder``.

        Only directories are moved; regular files directly inside
        ``src_folder`` are left in place. Prints one line per moved directory.
        """
        for name in os.listdir(src_folder):
            fullpath = os.path.join(src_folder, name)
            if os.path.isdir(fullpath):  # move the whole folder
                shutil.move(fullpath, dst_folder)
                print("Move finish: {}".format(name))

    def show_headline(self, output):
        """Display ``output`` as a bold light-blue HTML headline."""
        html0 = widgets.HTML(value=f"<b><font color='lightblue' size=4>{output}</font></b>")
        display(html0)

    @staticmethod
    def _parse_img_size(text):
        """Return ``text`` parsed as a positive int, or None if it is not one."""
        try:
            size = int(text)
        except (TypeError, ValueError):
            return None
        return size if size > 0 else None

    def show_main(self):
        """Render the settings form plus a 'Convert to tflite' toggle button.

        ``widgets.interactive_output`` re-invokes ``act_para`` whenever a form
        value changes, redrawing a toggle bound to the current values. Turning
        the toggle on runs the conversion(s) into its output area; turning it
        off clears that area.
        """
        intro_text = 'Please Choose the setting of tflite conversion'
        htmlWidget = widgets.HTML(value=f"<b><font color='lightgreen' size=6>{intro_text}</font></b>")
        display(htmlWidget)

        # Put the form box in an accordion.
        accordion = widgets.Accordion(children=[self.form_output_train_cmd]).add_class("parentstyle")
        # Add a custom style tag to the notebook; use dev tools to inspect the class names.
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        accordion.set_title(0, 'Configure the TFLite Conversion')

        def act_para(training_dir, source_model_folder, tflite_name,
                     rep_dataset_folder, quant_options, run_all, input_img_size):
            #------------------#
            # The main executing Toggle_Button
            #------------------#
            toggle_convert_tflite = widgets.ToggleButton(description='Convert to tflite',
                                                         layout=Layout(width='30%', height='30px'), button_style='success')
            out = widgets.Output(layout=Layout(border='1px solid green'))

            #------------------#
            # The main executing Toggle_Button's event function
            #------------------#
            def run_convert_tflite(obj):
                with out:
                    if not obj['new']:
                        print('stop')
                        out.clear_output()
                        return

                    out.clear_output()
                    # Validate early so a typo fails here instead of deep inside the converter.
                    img_size = self._parse_img_size(input_img_size)
                    if img_size is None:
                        print("Input Image size must be a positive integer, got {!r}".format(input_img_size))
                        return

                    self.show_headline('Converting the model graph to tflite... ')
                    source_md = os.path.join(training_dir, source_model_folder, 'saved_model')
                    output_tflite = os.path.join(training_dir, source_model_folder, tflite_name)
                    rep_dataset = os.path.join(rep_dataset_folder, '*.jpg')

                    converter = my_tflite_trans(source_md, output_tflite, rep_dataset, img_size)
                    options = self._QUANT_OPTIONS if run_all else (quant_options,)
                    for option in options:
                        converter.run_tflite(option)

            toggle_convert_tflite.observe(run_convert_tflite, 'value')
            display(toggle_convert_tflite)
            display(out)

        #------------------#
        # widgets.Accordion's interactive input with action function `act_para()`
        #------------------#
        out_inter = widgets.interactive_output(act_para, {'training_dir': self.A_ta, 'source_model_folder': self.B_ta, 'tflite_name': self.C_ta,
                                                          'rep_dataset_folder': self.D_ta, 'quant_options': self.E_ta, 'run_all': self.F_ta,
                                                          'input_img_size': self.G_ta
                                                          })
        display(accordion, out_inter)
        

# Run Section
---
- A detailed description of every parameter and what each step means is available here: [meaning](#id-train_evl_monitor)
- This notebook assumes you have already finished training. If not, please run `workspace\train_evl_monitor.ipynb` first.

In [4]:
# Build the conversion form and render it. Fill in the fields, then press
# "Convert to tflite" to run the conversion.
act = init_tflite_widgets()
act.show_main()

HTML(value="<b><font color='lightgreen'><font size=6>Please Choose the setting of train config</b>")

Accordion(children=(Box(children=(Box(children=(Label(value='Your Working Directory Name'), Text(value='traini…

Output()