# Import Section
---


In [1]:
import glob
import os
import random
import shutil
from time import perf_counter

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.lite as tflite
import tensorflow_datasets as tfds
import cv2
from PIL import Image

import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from ipywidgets import AppLayout, Button, Layout, Box, FloatText, Textarea, Dropdown, Label, IntSlider
from IPython.display import display, HTML
from IPython.display import clear_output
from ipyfilechooser import FileChooser

from object_detection.utils import label_map_util

# Expose no CUDA device (index -1 is invalid), so TensorFlow runs on the CPU.
# NOTE: only effective if set before TensorFlow first initialises its GPU devices.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

# TFLite Inference Section
---

In [2]:
class Detector():
    """Run a TFLite object-detection model and draw its detections on images.

    Post-processing maps the model's output tensors, sorted by tensor index, to
    boxes, classes, scores and number of valid detections.
    NOTE(review): that index ordering is an assumption carried over from the
    original code; confirm it for every exported model.
    """

    def __init__(self, tflite_model_loc, max_results=30, num_threads=4):
        """Load a ``.tflite`` model and cache its input/output metadata.

        Args:
            tflite_model_loc: path to the ``.tflite`` model file.
            max_results: maximum number of detections kept per image.
            num_threads: CPU threads used by the TFLite interpreter.
        """
        self._max_results = max_results
        self._interpreter = tflite.Interpreter(model_path=tflite_model_loc, num_threads=num_threads)

        input_detail = self._interpreter.get_input_details()[0]
        self._interpreter.allocate_tensors()
        # Input shape is read as (batch, H, W, C) -- assumed NHWC; stored as (H, W).
        self._model_input_size = (int(input_detail['shape'][1]), int(input_detail['shape'][2]))
        self._is_quantized_input = input_detail['dtype'] == np.uint8
        # (scale, zero_point) of the input; kept for reference, _preprocess does not apply them.
        self.scale, self.zero_point = input_detail['quantization']

        sorted_output_indices = sorted(output['index'] for output in self._interpreter.get_output_details())
        self._output_indices = {
            'BBOX': sorted_output_indices[0],
            'CLASS': sorted_output_indices[1],
            'SCORE': sorted_output_indices[2],
            'VALIDNUM': sorted_output_indices[3],
        }

    def detect(self, input_image):
        """Run the model on one image and return ``(bboxes, classes, scores)``."""
        self._set_input_tensor(self._preprocess(input_image))
        self._interpreter.invoke()
        return self._postprocess()

    def _preprocess(self, input_image):
        """Resize to the model input size and add a batch axis.

        Float models get pixels rescaled with ``x / 127.5 - 1`` (0..255 -> -1..1).
        uint8 models get the resized pixels unchanged.
        """
        height, width = self._model_input_size
        # cv2.resize takes dsize as (width, height) -- the reverse of the stored (H, W).
        input_tensor = cv2.resize(input_image, (width, height))
        if not self._is_quantized_input:
            input_tensor = input_tensor / 127.5 - 1
        return np.expand_dims(input_tensor, axis=0)

    def _set_input_tensor(self, image):
        """Copy ``image`` into the interpreter's first input buffer.

        The leading size-1 batch axis of ``image`` is dropped by numpy's assignment
        broadcasting, and values are cast to the buffer's dtype.
        """
        tensor_index = self._interpreter.get_input_details()[0]['index']
        input_tensor = self._interpreter.tensor(tensor_index)()[0]
        input_tensor[:, :] = image

    def _get_output_tensor(self, name):
        """Return the output tensor registered under ``name``, with size-1 axes squeezed."""
        output_index = self._output_indices[name]
        return np.squeeze(self._interpreter.get_tensor(output_index))

    def _postprocess(self):
        """Return at most ``max_results`` boxes (N, 4), class ids (N,) and scores (N,)."""
        limit = self._max_results
        # reshape after squeeze: with a single detection squeeze alone would
        # collapse the boxes to shape (4,) and the scalars to 0-d arrays.
        bboxes = self._get_output_tensor('BBOX').reshape(-1, 4)[:limit]
        cls = self._get_output_tensor('CLASS').reshape(-1)[:limit]
        scores = self._get_output_tensor('SCORE').reshape(-1)[:limit]
        return bboxes, cls, scores

    @staticmethod
    def _to_pixel_box(bbox, image_shape):
        """Scale a normalized ``(y1, x1, y2, x2)`` box to integer pixel ``(x1, y1, x2, y2)``."""
        height, width = image_shape[0], image_shape[1]
        y1, x1, y2, x2 = bbox
        return int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)

    def _draw_detections(self, image, bboxes, classes, scores, score_threshold, label_fn, colors):
        """Draw, in place on ``image``, every detection scoring >= ``score_threshold``.

        ``label_fn`` maps an int class id to the caption prefix. ``colors`` is an
        (N, 3) palette indexed by class id.
        """
        for bbox, cls, score in zip(bboxes, classes, scores):
            # `continue`, not `break`: do not rely on the scores being sorted.
            if score < score_threshold:
                continue
            class_id = int(cls)
            x1, y1, x2, y2 = self._to_pixel_box(bbox, image.shape)
            # cv2 wants a plain numeric tuple; wrap ids that exceed the palette size.
            color = tuple(float(c) for c in colors[class_id % len(colors)])
            text = '{}_{:.2f}'.format(label_fn(class_id), score)
            cv2.rectangle(image, (x1, y1), (x2, y2), color, 1)
            cv2.putText(image, text, (x1, y1 + 10), cv2.FONT_HERSHEY_COMPLEX, 0.4, color,
                        thickness=1, lineType=cv2.LINE_AA)

    def inf_test_tfds(self, dataset_name = "coco/2017", split_type="validation",
                      NUMPIC = 10, _threshold = 0.5, random_EN = True, _SEED = 3, log_show = False,
                      target_label = 0, shuffle_buffer_size = None):
        """Run the detector on a TFDS split and plot the detections.

        Args:
            dataset_name: TFDS dataset name.
            split_type: split to load. For "test" every image is used; otherwise only
                images whose labels contain ``target_label``.
            NUMPIC: number of images to process.
            _threshold: minimum score for a detection to be drawn.
            random_EN: shuffle the split (seeded with ``_SEED``) before iterating.
            _SEED: shuffle seed.
            log_show: print the raw model outputs for each image.
            target_label: label an image must contain to be processed (non-test splits).
            shuffle_buffer_size: shuffle buffer when ``random_EN``. None means the whole
                split, which buffers every decoded image in memory; pass a smaller
                value for large splits.
        """
        [test_dataset], dataset_info = tfds.load(name=dataset_name, split=[split_type], with_info=True)
        print(len(test_dataset))
        label_fn = dataset_info.features["objects"]["label"].int2str
        colors = np.random.rand(200, 3) * 255

        if not random_EN:
            buffer_size = 1  # a one-element buffer keeps the original order
        elif shuffle_buffer_size is None:
            buffer_size = len(test_dataset)
        else:
            buffer_size = shuffle_buffer_size

        total_time = 0.0
        count_det = 0
        for sample in test_dataset.shuffle(buffer_size=buffer_size, seed=_SEED):
            if count_det >= NUMPIC:
                break
            if split_type != "test" and target_label not in sample['objects']['label'].numpy():
                continue

            plt.figure(figsize=(12, 12))
            image = sample['image'].numpy()

            detection_start = perf_counter()
            bboxes, classes, scores = self.detect(image)
            # Stop the clock before any logging so printing is not counted as inference.
            total_time += perf_counter() - detection_start

            if log_show:
                print("bboxes: {}".format(bboxes))
                print("classes: {}".format(classes))
                print("scores: {}".format(scores))

            self._draw_detections(image, bboxes, classes, scores, _threshold, label_fn, colors)
            count_det += 1
            plt.imshow(image)

        if count_det > 0:
            # total_time is seconds of inference, so report latency and its inverse (FPS).
            mean_time = total_time / count_det
            fps = 1.0 / mean_time if mean_time > 0 else float('inf')
            print("Average inference time: {:.4f} s/image ({:.2f} FPS)".format(mean_time, fps))
        else:
            print(r"The 'count_det' is zero, so there is no pictures that have goals")

    def _load_image_into_numpy_array(self, image):
        """Return a writable (H, W, 3) uint8 RGB copy of a PIL image.

        Converting to RGB first also handles grayscale, RGBA and palette images,
        which cannot be reshaped to three channels directly.
        """
        return np.array(image.convert("RGB"), dtype=np.uint8)

    def inf_test_my_dataset(self, dataset_loc, label_map_path, NUMPIC = 10, _threshold = 0.5, random_EN = True, _SEED = 3):
        """Run the detector on images matched by a glob pattern and plot the detections.

        Args:
            dataset_loc: glob pattern for the test images.
            label_map_path: label-map ``.pbtxt`` used to name detected classes.
            NUMPIC: maximum number of images to process.
            _threshold: minimum score for a detection to be drawn.
            random_EN: shuffle the image list (seeded with ``_SEED``) before taking ``NUMPIC``.
            _SEED: shuffle seed.
        """
        # glob order is filesystem-dependent; sort so the seeded shuffle is reproducible.
        all_paths = sorted(glob.glob(dataset_loc))
        if random_EN:
            # Private Random: same permutation as random.seed + random.shuffle,
            # without resetting the notebook-wide random state.
            random.Random(_SEED).shuffle(all_paths)
        selected_paths = all_paths[:NUMPIC]
        print('The number of all test image set is: {}'.format(len(all_paths)))
        print('The number of test image this time is: {}'.format(len(selected_paths)))

        label_map = label_map_util.load_labelmap(label_map_path)
        categories = label_map_util.convert_label_map_to_categories(
            label_map,
            max_num_classes=label_map_util.get_max_label_map_index(label_map),
            use_display_name=True)
        category_index = label_map_util.create_category_index(categories)

        def label_fn(class_id):
            # Looked up at class_id + 1 -- presumably 0-based model ids vs 1-based
            # label-map ids (confirm); unknown ids fall back to the numeric id.
            entry = category_index.get(class_id + 1)
            return entry['name'] if entry else str(class_id + 1)

        colors = np.random.rand(200, 3) * 255
        total_time = 0.0

        for image_path in selected_paths:
            plt.figure(figsize=(12, 12))
            with Image.open(image_path) as image:  # close the file handle promptly
                input_img = self._load_image_into_numpy_array(image)

            detection_start = perf_counter()
            bboxes, classes, scores = self.detect(input_img)
            total_time += perf_counter() - detection_start

            self._draw_detections(input_img, bboxes, classes, scores, _threshold, label_fn, colors)
            plt.imshow(input_img)

        if selected_paths:
            print("Time per plot: {}".format(total_time / len(selected_paths)))
        else:
            print(r"There is no pictures that have goals")

# Widget Control Section
---

In [3]:
from ipywidgets.widgets.interaction import show_inline_matplotlib_plots

%matplotlib inline
class init_inference_tflite_widgets():
    """ipywidgets front-end: choose a .tflite model, then run `Detector` on a TFDS split or a local image folder."""

    def __init__(self):
        """Create the setting widgets and the three panels that `show_main` puts in an accordion."""
        self.tflite_file_loc = ""
        # Click handler currently attached to the model-chooser button; tracked so
        # show_main can detach it before attaching a replacement.
        self._tflite_chooser_handler = None

        form_item_layout = Layout(
            display='flex',
            flex_flow='row',
            justify_content='space-between',
        )

        def row(text, widget):
            # One "label ... widget" line inside a panel.
            return Box([Label(value = text), widget], layout=form_item_layout)

        ### open source data download###
        self.A_ta = widgets.Text(value='coco/2017', placeholder='Type something', disabled=False)
        self.B_ta = Dropdown(options=['validation', 'test', 'train'])
        self.C_ta = widgets.BoundedIntText(value=1, min=0, max=1000, step=1, description='Text:', disabled=False)
        self.D_ta = widgets.FloatSlider(value=0.5, min=0.1, max=1.0, step=0.02)
        self.E_ta = widgets.Checkbox(value=False, disabled=False, indent=False)
        self.F_ta = widgets.IntSlider(value=3, min=1, max=100, step=1)

        self.form_box_train_para = self._make_panel([
            row('Dataset Name', self.A_ta),
            row('Dataset Type', self.B_ta),
            row('Number of Test', self.C_ta),
            row('Threshold of Positive', self.D_ta),
            row('Random Enable', self.E_ta),
            row('Random Seed', self.F_ta),
        ])

        ### custom data labeling###
        self.A_da = widgets.Textarea(value='C:/Users/USERNAME/image_detection/TensorFlow/workspace/training_demo_8000/images/test/*.jpg',
                                     placeholder='Type something', disabled=False)
        self.B_da = widgets.Textarea(value='C:/Users/USERNAME/image_detection/TensorFlow/workspace/training_demo_8000/annotations/label_map.pbtxt',
                                     placeholder='Type something', disabled=False)
        self.C_da = widgets.BoundedIntText(value=10, min=0, max=1000, step=1, description='Text:', disabled=False)
        self.D_da = widgets.FloatSlider(value=0.5, min=0.1, max=1.0, step=0.02)
        self.E_da = widgets.Checkbox(value=False, disabled=False, indent=False)
        self.F_da = widgets.IntSlider(value=3, min=1, max=100, step=1)

        self.form_box_data_para = self._make_panel([
            row('Dataset Location', self.A_da),
            row('Label Map Location', self.B_da),
            row('Number of Test', self.C_da),
            row('Threshold of Positive', self.D_da),
            row('Random Enable', self.E_da),
            row('Random Seed', self.F_da),
        ])

        ### choose the tflite###
        self.A_tfc = widgets.Button(description='tfliteChooser', layout=Layout(width='30%', height='30px'), button_style='success')
        self.form_box_tflite_load = self._make_panel([
            row('Please Choose tflite model', self.A_tfc),
        ])

    @staticmethod
    def _make_panel(items):
        """Stack `items` vertically in a half-width Box with a light-green border."""
        return Box(items, layout=Layout(
            display='flex',
            flex_flow='column',
            border='solid 3px lightgreen',
            align_items='stretch',
            width='50%',
        ))

    def move_allfiles(self, src_folder, dst_folder):
        """Move every sub-directory of `src_folder` into `dst_folder`; plain files stay where they are."""
        for name in os.listdir(src_folder):
            full_path = os.path.join(src_folder, name)
            if os.path.isdir(full_path):
                shutil.move(full_path, dst_folder)
                print("Move finish: {}".format(name))

    def show_main(self):
        """Display the settings accordion and the run buttons bound to the current settings."""
        intro_text = 'Please Choose the setting of inference data'
        htmlWidget = widgets.HTML(value = f"<b><font color='lightgreen'><font size=4>{intro_text}</b>")
        display(htmlWidget)

        #Create an accordion and put the 3 boxes
        accordion = widgets.Accordion(children=[self.form_box_tflite_load, self.form_box_train_para, self.form_box_data_para
                                                ]).add_class("parentstyle")
        #Add a custom style tag to the notebook, you can use dev tool to inspect the class names
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        accordion.set_title(0, 'Choose tflite model')
        accordion.set_title(1, 'Inference tfds')
        accordion.set_title(2, 'Inference my dataset')

        def act_para(dataset_name, dataset_type, num, th, ran_en, ran_seed,
                     dataset_loc, label_map_loc, num_myD, th_myD, ran_en_myD, ran_seed_myD):
            # Re-run by interactive_output on every setting change; rebuilds the run buttons
            # so their callbacks see the latest values.
            toggle_run_tfds = widgets.ToggleButton(description='Start Inference tfds',
                                                   layout=Layout(width='30%', height='30px'), button_style='success')
            toggle_run_mydataset = widgets.ToggleButton(description='Start Inference my dataset',
                                                   layout=Layout(width='30%', height='30px'), button_style='success')
            out = widgets.Output(layout=Layout(border = '1px solid green'))

            def on_button_clicked_tfliteChooser(b):
                with out:
                    clear_output()
                    self.tflite_File_Choose()

            # A_tfc is shared across re-runs: detach the previous handler, otherwise
            # each click would fire once per past re-run, into stale outputs.
            if self._tflite_chooser_handler is not None:
                self.A_tfc.on_click(self._tflite_chooser_handler, remove=True)
            self._tflite_chooser_handler = on_button_clicked_tfliteChooser
            self.A_tfc.on_click(on_button_clicked_tfliteChooser)

            def run_inference_tfds(obj):
                with out:
                    if obj['new']:
                        out.clear_output()
                        print('Run tflite test on tfds...')
                        if self.tflite_file_loc:
                            det_obj = Detector(self.tflite_file_loc)
                            det_obj.inf_test_tfds(dataset_name, split_type = dataset_type,
                                                  NUMPIC = int(num), _threshold = float(th),
                                                  random_EN = ran_en, _SEED = int(ran_seed))
                            show_inline_matplotlib_plots()
                            print('Finish')
                        else:
                            print("There is no tflite model!!!")
                    else:
                        print('stop')
                        out.clear_output()

            def run_inference_mydataset(obj):
                with out:
                    if obj['new']:
                        out.clear_output()
                        print('Run tflite test on my dataset...')
                        if self.tflite_file_loc:
                            det_obj_mydata = Detector(self.tflite_file_loc)
                            det_obj_mydata.inf_test_my_dataset(dataset_loc, label_map_loc,
                                                               NUMPIC = int(num_myD), _threshold = float(th_myD),
                                                               random_EN = ran_en_myD, _SEED = int(ran_seed_myD))
                            show_inline_matplotlib_plots()
                            print('Finish')
                        else:
                            print("There is no tflite model!!!")
                    else:
                        print('Stop')
                        out.clear_output()

            toggle_run_tfds.observe(run_inference_tfds, 'value')
            toggle_run_mydataset.observe(run_inference_mydataset, 'value')
            display(toggle_run_tfds, toggle_run_mydataset)
            display(out)

        # Bind every setting widget to act_para's keyword arguments.
        out_inter = widgets.interactive_output(act_para, {'dataset_name': self.A_ta, 'dataset_type': self.B_ta, 'num': self.C_ta,
                                                          'th': self.D_ta, 'ran_en': self.E_ta, 'ran_seed': self.F_ta,
                                                          'dataset_loc' : self.A_da, 'label_map_loc' : self.B_da, 'num_myD': self.C_da,
                                                          'th_myD' : self.D_da, 'ran_en_myD' : self.E_da, 'ran_seed_myD': self.F_da
                                                          })
        display(accordion, out_inter)

    def tflite_File_Choose(self):
        """Show a file chooser and a 'Load tflite' button that stores the selection in `tflite_file_loc`."""
        path_fc = os.path.join(os.getcwd(), "tflite_example")
        if not os.path.isdir(path_fc):
            # Fall back to the working directory when the example folder is absent.
            path_fc = os.getcwd()
        fc = FileChooser(path_fc)
        fc.title = "<b><font color='lightblue'><font size=4>Choose the tflite.</b>"
        display(fc)

        def act_load_tflite():
            if not fc.selected:
                # Keep the previous model instead of storing None.
                print("No tflite file selected yet.")
                return
            self.tflite_file_loc = fc.selected
            print("Load tflite file: {}".format(self.tflite_file_loc))

        evt = interact_manual(act_load_tflite)
        # act_load_tflite takes no parameters, so children[0] is the manual-run button.
        evt.widget.children[0].description = 'Load tflite'
        evt.widget.children[0].button_style = 'primary'

In [4]:
# Build the setting widgets, then display the accordion and the run buttons in this cell's output.
act = init_inference_tflite_widgets()
act.show_main()

HTML(value="<b><font color='lightgreen'><font size=4>Please Choose the setting of inference data</b>")

Accordion(children=(Box(children=(Box(children=(Label(value='Please Choose tflite model'), Button(button_style…

Output()