# Import Section
---

In [1]:
"""
This script sets up the necessary imports for a Jupyter Notebook that involves
image classification using TensorFlow 2. 
"""
import shutil
import os

import ipywidgets as widgets
from ipywidgets import interact_manual
from ipywidgets import Layout, Box, Dropdown, Label
from IPython.display import display, HTML
from IPython.display import clear_output
from ipyfilechooser import FileChooser

# Widgets Control Section
---

In [2]:
class TrainConfigAndCmdsWidgets:
    """
    A class to create and manage widgets for configuring and running training, testing, and model conversion tasks.
    Attributes:
    """
    def __init__(self):

        self.tflite_file_loc = ""

        form_item_layout = Layout(
            display="flex",
            flex_flow="row",
            justify_content="space-between",
        )

        # data exist
        self.a_de = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.b_de = widgets.Text(value="flower_photos", placeholder="Type something", disabled=False)

        form_data_prepare_items = [
            Box([Label(value="Data Exist"), self.a_de], layout=form_item_layout),
            Box([Label(value="Dataset Name"), self.b_de], layout=form_item_layout),
        ]

        self.form_data_prepare_exist = Box(
            form_data_prepare_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                align_items="stretch",
                width="100%",
            ),
        )

        # data download
        # Another flowers dataset
        # https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
        self.a_dp = widgets.Textarea(value="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", placeholder="Type something", disabled=False)
        # cats_and_dogs.zip
        self.b_dp = widgets.Text(value="flower_photos.tgz", placeholder="Type something", disabled=False)

        form_data_prepare_items = [
            Box([Label(value="URL Link"), self.a_dp], layout=form_item_layout),
            Box([Label(value="Zip Name"), self.b_dp], layout=form_item_layout),
        ]

        self.form_data_prepare_cmd = Box(
            form_data_prepare_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                border="solid 1px lightblue",
                align_items="stretch",
                width="100%",
            ),
        )

        # train
        self.a_ta = widgets.Text(value="MyTask", placeholder="Type something", disabled=False)
        self.e_ta = Dropdown(
            options=[
                "mobilenet_v1",
                "mobilenet_v2",
                "mobilenet_v3_mini",
                "mobilenet_v3",
                "fdmobilenet_w1",
                "fdmobilenet_wd2",
                "fdmobilenet_wd4",
                "shufflenet_g1_wd4",
                "shufflenet_g3_wd4",
                "shufflenet_g1_wd2",
                "shufflenet_g3_wd2",
                "efficientnetB0",
                "efficientnetv2B0",
            ]
        )
        self.o_ta = Dropdown(options=["1: Load Pretrain Model(Not support all setting)", "2: Download Pretrain Model Only", "3: Train From Scratch"])
        self.b_ta = widgets.IntSlider(value=32, min=4, max=192, step=4)
        self.c_ta = widgets.IntSlider(value=224, min=32, max=352, step=32)
        self.m_ta = widgets.FloatSlider(value=0.35, min=0.1, max=1.5, step=0.05)
        self.d_ta = widgets.FloatSlider(value=0.2, min=0.1, max=0.5, step=0.1)
        self.f_ta = widgets.FloatSlider(value=0.2, min=0.0, max=0.5, step=0.1)
        self.g_ta = widgets.Checkbox(value=True, disabled=False, indent=False)
        self.h_ta = widgets.Text(value="10,50", placeholder="Type something", disabled=False)
        self.i_ta = widgets.Text(value="0.001,0.0005", placeholder="Type something", disabled=False)
        self.j_ta = Dropdown(value="3: Transfer and Fine-Tuning training", options=["1: Show the train data and model", "2: Transfer training", "3: Transfer and Fine-Tuning training"])
        self.k_ta = widgets.BoundedIntText(value=40, min=0, max=300, step=1, disabled=False)
        self.l_ta = widgets.Button(description="Run", layout=Layout(width="30%", height="30px"), button_style="success")

        self.fine_tune_la = Label(value="Freezing Layers of Fine-Tuning")
        self.k_ta_box = Box([self.fine_tune_la, self.k_ta], layout=form_item_layout)
        self.alpha_width_la = Label(value="Alpha Width")
        self.m_ta_box = Box([self.alpha_width_la, self.m_ta], layout=form_item_layout)

        form_train_items = [
            Box([Label(value="Project Name"), self.a_ta], layout=form_item_layout),
            Box([Label(value="Model Name"), self.e_ta], layout=form_item_layout),
            Box([Label(value="Model Setting"), self.o_ta], layout=form_item_layout),
            Box([Label(value="Batch Size"), self.b_ta], layout=form_item_layout),
            Box([Label(value="Image Size"), self.c_ta], layout=form_item_layout),
            self.m_ta_box,
            Box([Label(value="Validation Percent"), self.d_ta], layout=form_item_layout),
            Box([Label(value="Test Percent from Val Dataset"), self.f_ta], layout=form_item_layout),
            Box([Label(value="Data Augmentation Enable"), self.g_ta], layout=form_item_layout),
            Box([Label(value="Epochs (Transfer, Fine-Tuning)"), self.h_ta], layout=form_item_layout),
            Box([Label(value="Learning Rate (Transfer, Fine-Tuning)"), self.i_ta], layout=form_item_layout),
            Box([Label(value="Switch Mode"), self.j_ta], layout=form_item_layout),
            self.k_ta_box,
            Box([Label(value="Start to Execute"), self.l_ta], layout=form_item_layout),
        ]

        self.form_output_train_cmd = Box(
            form_train_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                border="solid 1px lightblue",
                align_items="stretch",
                width="100%",
            ),
        )

        # test
        self.a_tt = widgets.Button(description="Setting", layout=Layout(width="30%", height="30px"), button_style="success")
        self.b_tt = widgets.BoundedIntText(value=1, min=1, max=100, step=1, disabled=False)
        self.c_tt = widgets.Button(description="Run", layout=Layout(width="30%", height="30px"), button_style="success")

        form_test_items = [
            Box([Label(value="Choose the tflite file"), self.a_tt], layout=form_item_layout),
            Box([Label(value="Batches for test"), self.b_tt], layout=form_item_layout),
            Box([Label(value="Start to Test"), self.c_tt], layout=form_item_layout),
        ]

        self.form_output_test_cmd = Box(
            form_test_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                border="solid 3px lightgreen",
                align_items="stretch",
                width="50%",
            ),
        )

        # convert model cpp
        self.a_cm = widgets.Text(value=r"..\workspace\\catsdogs\\tflite_model", placeholder="Type something", disabled=False)
        self.b_cm = widgets.Text(value="mobilenet_v2_int8quant.tflite", placeholder="Type something", disabled=False)
        self.c_cm = widgets.Text(value=r"..\workspace\\catsdogs\\tflite_model\\vela", placeholder="Type something", disabled=False)
        self.e_cm = widgets.Button(description="Setting", layout=Layout(width="30%", height="30px"), button_style="success")
        self.d_cm = widgets.Button(description="Run", layout=Layout(width="30%", height="30px"), button_style="success")
        self.f_cm = widgets.Text(value="cats_and_dogs_filtered", placeholder="Type something", disabled=False)
        self.g_cm = widgets.Checkbox(value=True, disabled=False, indent=False)

        form_convert_items_paths = [
            Box([Label(value="Choose the tflite file"), self.e_cm], layout=form_item_layout),
            Box([Label(value="MODEL SRC DIR"), self.a_cm], layout=form_item_layout),
            Box([Label(value="MODEL SRC FILE"), self.b_cm], layout=form_item_layout),
            Box([Label(value="GEN SRC DIR"), self.c_cm], layout=form_item_layout),
        ]

        form_convert_items_label = [
            Box([Label(value="Dataset Name"), self.f_cm], layout=form_item_layout),
            Box([Label(value="Labels Including"), self.g_cm], layout=form_item_layout),
        ]

        form_convert_items = [
            Box(form_convert_items_paths, layout=Layout(display="flex", flex_flow="column", justify_content="center", border="dotted 3px lightblue", align_items="stretch", width="70%")),
            Box(form_convert_items_label, layout=Layout(display="flex", flex_flow="column", justify_content="center", border="dotted 3px lightblue", align_items="stretch", width="70%")),
            Box([Label(value="Convert to cpp & Vela"), self.d_cm], layout=form_item_layout),
        ]

        self.form_output_convert_cmd = Box(
            form_convert_items,
            layout=Layout(
                display="flex",
                flex_flow="column",
                justify_content="center",
                border="solid 3px lightgreen",
                align_items="stretch",
                width="70%",
            ),
        )

    def move_allfiles(self, src_folder, dst_folder):
        """
        Moves all files and directories from the source folder to the destination folder.
        """
        files = os.listdir(src_folder)
        for f in files:
            fullpath = os.path.join(src_folder, f)
            if os.path.isdir(fullpath):  # copy whole folder
                shutil.move(fullpath, dst_folder)
                print(f"Copy finish: {f}")

    def show_headline(self, output):
        """
        Displays a headline with the given output text in a styled HTML format.
        """
        html0 = widgets.HTML(value=f"<b><font color='lightblue'><font size=4>{output}</b>")
        display(html0)

    def show_main(self):
        """
        Displays the main interface for data preparation and training settings.
        This method creates and displays various widgets including HTML, Accordion, Box, and Tab widgets
        to allow users to configure data preparation and training settings. It also sets up observers and
        interactive outputs to handle user interactions and updates.
        Observers and Interactive Outputs:
        - An observer to update dependent values based on the selected model name.
        - An interactive output to handle parameter changes and update the visibility of widgets accordingly.
        Button Click Handlers:
        - Handlers for various buttons to trigger training, model selection, conversion, and testing actions.
        """

        intro_text = "Please Choose the setting of data prepare & train"
        html_widget = widgets.HTML(value=f"<b><font color='lightgreen'><font size=6>{intro_text}</b>")
        display(html_widget)

        # Create an accordion and put the 2 boxes
        accordion = widgets.Accordion(children=[self.form_data_prepare_cmd, self.form_output_train_cmd]).add_class("parentstyle")
        display(HTML("<style>.parentstyle > .p-Accordion-child > .p-Collapse-header{background-color:green}</style>"))
        accordion.set_title(0, "Data Download Setting")
        accordion.set_title(1, "Configure the Training")

        # Create a box combining with 2 elements
        box_data_train = Box(
            [self.form_data_prepare_exist, accordion],
            layout=Layout(
                display="flex",
                flex_flow="column",
                border="solid 3px lightgreen",
                align_items="stretch",
                width="50%",
            ),
        )

        # Create a tab and put the 2 boxes
        tab = widgets.Tab(children=[box_data_train, self.form_output_test_cmd, self.form_output_convert_cmd]).add_class("parentstyle")
        tab_contents = ["Train", "Test", "Deployment"]
        tab.titles = tab_contents

        output_widgets = widgets.Output(layout=Layout(border="1px solid green"))

        # Special observe for MODEL_NAME dependent value updating (fine tune layers)
        def update(*args):
            if self.e_ta.value.count("fdmobile"):
                self.k_ta.value = 6  # bcs the tf2cv structure is combining blocks
                self.k_ta.max = 10
                self.m_ta.layout.visibility = "hidden"
                self.alpha_width_la.layout.visibility = "hidden"
            elif self.e_ta.value.count("shufflenet"):
                self.k_ta.value = 10  # bcs the tf2cv structure is combining blocks
                self.k_ta.max = 17
                self.m_ta.layout.visibility = "hidden"
                self.alpha_width_la.layout.visibility = "hidden"
            elif self.e_ta.value.count("mobilenet_v1"):
                self.k_ta.max = 86
                self.k_ta.value = 40
                self.m_ta.layout.visibility = "visible"
                self.alpha_width_la.layout.visibility = "visible"
            elif self.e_ta.value.count("mobilenet_v2"):
                self.k_ta.max = 154
                self.k_ta.value = 80
                self.m_ta.layout.visibility = "visible"
                self.alpha_width_la.layout.visibility = "visible"
            elif self.e_ta.value.count("mobilenet_v3"):
                self.k_ta.max = 228
                self.k_ta.value = 120
                self.m_ta.layout.visibility = "visible"
                self.alpha_width_la.layout.visibility = "visible"
            elif self.e_ta.value.count("mobilenet_v3_mini"):
                self.k_ta.max = 102
                self.k_ta.value = 50
                self.m_ta.layout.visibility = "visible"
                self.alpha_width_la.layout.visibility = "visible"
            elif self.e_ta.value.count("efficientnetB0"):
                self.k_ta.max = 238
                self.k_ta.value = 120
                self.m_ta.layout.visibility = "hidden"
                self.alpha_width_la.layout.visibility = "hidden"
            elif self.e_ta.value.count("efficientnetv2B0"):
                self.k_ta.max = 170
                self.k_ta.value = 150
                self.m_ta.layout.visibility = "hidden"
                self.alpha_width_la.layout.visibility = "hidden"

        self.e_ta.observe(update)

        def act_para(*, data_exist, url, zip_n, dataset_n, a_ta, b_ta, c_ta, d_ta, e_ta, f_ta, g_ta, h_ta, i_ta, j_ta, k_ta, a_cm, b_cm, c_cm, f_cm, g_cm, b_tt, m_ta, o_ta):

            # If any value is changed, clear the widgets
            with output_widgets:
                output_widgets.clear_output()

            if data_exist:
                self.form_data_prepare_cmd.layout.visibility = "hidden"
            else:
                self.form_data_prepare_cmd.layout.visibility = "visible"

            if j_ta.count("3:"):
                self.k_ta.layout.visibility = "visible"
                self.fine_tune_la.layout.visibility = "visible"
            else:
                self.k_ta.layout.visibility = "hidden"
                self.fine_tune_la.layout.visibility = "hidden"

        # ------------------#
        # widgets.Accordion's interactive input with action function `act_para()`
        # ------------------#
        out_inter = widgets.interactive_output(
            act_para,
            {
                "data_exist": self.a_de,
                "url": self.a_dp,
                "zip_n": self.b_dp,
                "dataset_n": self.b_de,
                "a_ta": self.a_ta,
                "b_ta": self.b_ta,
                "c_ta": self.c_ta,
                "d_ta": self.d_ta,
                "e_ta": self.e_ta,
                "f_ta": self.f_ta,
                "g_ta": self.g_ta,
                "h_ta": self.h_ta,
                "i_ta": self.i_ta,
                "j_ta": self.j_ta,
                "k_ta": self.k_ta,
                "a_cm": self.a_cm,
                "b_cm": self.b_cm,
                "c_cm": self.c_cm,
                "f_cm": self.f_cm,
                "g_cm": self.g_cm,
                "b_tt": self.b_tt,
                "m_ta": self.m_ta,
                "o_ta": self.o_ta,
            },
        )

        display(tab, out_inter)

        # ------------------#
        # for labelimg cmd, move to outside of act_para to prevent keep trigering
        # ------------------#
        # output_widgets = widgets.Output(layout=Layout(border = '1px solid green'))
        display(output_widgets)

        def on_button_clicked_train(b):
            with output_widgets:
                clear_output()
                print("Train. . .")
                self.run_train()

        self.l_ta.on_click(on_button_clicked_train)

        def on_button_clicked_choose_tflite(b):
            with output_widgets:
                clear_output()
                self.choose_tflite()

        self.e_cm.on_click(on_button_clicked_choose_tflite)

        def on_button_clicked_cpp(b):
            with output_widgets:
                clear_output()
                print("Convert to cpp & Vela. . .")
                self.convert_tflu()

        self.d_cm.on_click(on_button_clicked_cpp)

        def on_button_clicked_choose_tflite_test(b):
            with output_widgets:
                clear_output()
                self.choose_tflite()

        self.a_tt.on_click(on_button_clicked_choose_tflite_test)

        def on_button_clicked_test(b):
            with output_widgets:
                clear_output()
                self.run_test_tflite()

        self.c_tt.on_click(on_button_clicked_test)

    def choose_tflite(self):
        """
        Prompts the user to choose a TensorFlow Lite (.tflite) file for conversion.
        This method uses a file chooser to allow the user to select a .tflite file from the workspace directory.
        It restricts navigation to the current working directory and filters the files to only show .tflite files.
        Once a file is selected, it sets the chosen directory and file name to the corresponding attributes and prints them.
        Attributes:
            a_cm (ipywidgets.Text): Widget to display the chosen directory.
            b_cm (ipywidgets.Text): Widget to display the chosen .tflite file name.
            c_cm (ipywidgets.Text): Widget to display the path for the 'vela' directory.
        """

        path_ftflite = os.path.join(os.getcwd(), "workspace")
        f_tflite = FileChooser(path_ftflite)
        # Restrict navigation to /Users
        f_tflite.sandbox_path = os.getcwd()
        f_tflite.filter_pattern = ["*.tflite"]
        f_tflite.title = "<b><font color='lightblue'><font size=4>Choose the Tflite for Converting.</b>"
        display(f_tflite)

        def act_test():
            work_dir_name = os.getcwd().split("\\")[-1]
            m_src_dir = r".." + f_tflite.selected_path.split(work_dir_name)[-1]
            m_src_tflite = f_tflite.selected.split("\\")[-1]
            print(f"The chosen dir: {m_src_dir}")
            print(f"The chosen tflite: {m_src_tflite}")
            self.a_cm.value = m_src_dir
            self.b_cm.value = m_src_tflite
            self.c_cm.value = os.path.join(m_src_dir, "vela")

            print("Finish!")

        evt = interact_manual(act_test)
        evt.widget.children[0].description = "Set this file"  # because there are 3 parameter of the evt
        evt.widget.children[0].button_style = "primary"

    def run_train(self):
        """
        Executes the training process for a machine learning model based on the provided parameters.
        This method determines the appropriate training script to use and constructs the command to run it
        with the specified parameters. It also records the command in a text file for future reference.
        """

        sw_num = int(self.j_ta.value.split(":")[0])

        if self.o_ta.value.count("1:"):
            dl_model_set = 0
        elif self.o_ta.value.count("2:"):
            dl_model_set = 1
        else:
            dl_model_set = 2

        if self.e_ta.value.count("fdmobile") or self.e_ta.value.count("shufflenet"):
            python_file = "train_tf2cv.py"
            print("Need tf2cv to run this model!")
        else:
            python_file = "train.py"

        %run $python_file --data_exist $self.a_de.value --url $self.a_dp.value --zip_name $self.b_dp.value --dataset_name $self.b_de.value \
        --proj_name $self.a_ta.value --BATCH_SIZE $self.b_ta.value --IMG_SIZE $self.c_ta.value --VAL_PCT $self.d_ta.value \
        --MODEL_NAME $self.e_ta.value --TEST_PCT $self.f_ta.value --DATA_AUGM $self.g_ta.value --EPOCHS $self.h_ta.value \
        --LEARNING_RATE $self.i_ta.value --FINE_TUNE_LAYER $self.k_ta.value --switch_mode $sw_num --ALPHA_WIDTH $self.m_ta.value \
        --IMAGENET_MODEL_EN $dl_model_set

        path_cmd_record = os.path.join(os.getcwd(), "workspace", self.a_ta.value, "train_cmd_record.txt")
        with open(path_cmd_record, "w", encoding='utf-8') as file1:
            open_dir = f"cd {os.getcwd()} \n"
            cmd = f"python {python_file} --data_exist {self.a_de.value} --url {self.a_dp.value} --zip_name {self.b_dp.value} --dataset_name {self.b_de.value} \
--proj_name {self.a_ta.value} \
--BATCH_SIZE { self.b_ta.value} --IMG_SIZE {self.c_ta.value} --VAL_PCT {self.d_ta.value} --MODEL_NAME {self.e_ta.value} \
--TEST_PCT {self.f_ta.value} --DATA_AUGM {self.g_ta.value} --EPOCHS {self.h_ta.value} --LEARNING_RATE {self.i_ta.value} \
--FINE_TUNE_LAYER {self.k_ta.value} --switch_mode {sw_num} --ALPHA_WIDTH {self.m_ta.value} \
--IMAGENET_MODEL_EN {dl_model_set}"
            file1.writelines([open_dir, cmd])
        print("Finish !!")

    def run_test_tflite(self):
        """
        Runs a test on a TensorFlow Lite model.
        This method constructs the file path for the TensorFlow Lite model based on the provided configuration values,
        determines the appropriate Python script to use for running the model, and prints relevant information about the
        test.
        """

        tflite_location = os.path.join((self.a_cm.value).split("..\\")[-1], self.b_cm.value)
        print(f"The tflite file: {tflite_location} The number of test batch: {self.b_tt.value}")

        if self.b_cm.value.count("fdmobile") or self.b_cm.value.count("shufflenet"):
            python_file = "train_tf2cv.py"
            print("Need tf2cv to run this model!")
        else:
            python_file = "train.py"

        %run $python_file --data_exist $self.a_de.value --url $self.a_dp.value --zip_name $self.b_dp.value --dataset_name $self.b_de.value \
        --proj_name $self.a_ta.value --BATCH_SIZE $self.b_ta.value --IMG_SIZE $self.c_ta.value --VAL_PCT $self.d_ta.value \
        --MODEL_NAME $self.e_ta.value --TEST_PCT $self.f_ta.value --switch_mode 4 --ALPHA_WIDTH $self.m_ta.value\
        --TFLITE_F $tflite_location --TFLITE_TEST_BATCH_N $self.b_tt.value

        print("Finish !!")

    def convert_tflu(self):
        """
        Converts TensorFlow Lite model and generates necessary C++ source/header files for labels.
        Attributes:
            self.a_cm.value (str): Source directory path.
            self.b_cm.value (str): Source file name.
            self.c_cm.value (str): Generated directory path.
            self.f_cm.value (str): Labels dataset name.
            self.g_cm.value (bool): Flag to indicate whether to create the label C++ source/header files.
        """

        %run exebat.py --SRC_DIR $self.a_cm.value --SRC_FILE $self.b_cm.value --GEN_DIR $self.c_cm.value

        if self.g_cm.value:  # create the label C++ source/header files
            # Change to dataset folder
            old_cwd = os.getcwd()
            batch_cwd = os.path.join(old_cwd, "dataset")
            os.chdir(batch_cwd)

            %run gen_labels_cpp.py --labels_dataset_name $self.f_cm.value --source_folder_path $self.a_cm.value \
            --header_folder_path $self.a_cm.value

            os.chdir(old_cwd)
            print(f'Finish, the label file is at: {(old_cwd + self.a_cm.value.split("..")[1])}')


# Run Section
---
- A detailed description of every parameter and of what each step means is available here: [meaning](#id-train_evl_monitor)
- This step assumes the dataset has already been prepared. If it has not, run `image_dataset\create_data.ipynb` first.

In [3]:
# Build every widget, then render the Train / Test / Deployment tabs and wire up their buttons.
# `act` is kept at notebook level so later cells can read the chosen settings from its widgets.
act = TrainConfigAndCmdsWidgets()
act.show_main()

HTML(value="<b><font color='lightgreen'><font size=6>Please Choose the setting of data prepare & train</b>")

Tab(children=(Box(children=(Box(children=(Box(children=(Label(value='Data Exist'), Checkbox(value=True, indent…

Output()

Output(layout=Layout(border_bottom='1px solid green', border_left='1px solid green', border_right='1px solid g…