# In [None]:
import os
import importlib
from omegaconf import OmegaConf
from IPython.display import display, clear_output
from PyQt5.QtWidgets import QApplication, QFileDialog
import ipywidgets as widgets
import warnings
warnings.filterwarnings("ignore")
from utils.test_image import create_test_dataset

# Saved models are laid out as
#   saved_model/<task>/<method>/<dataset>/<model name>/
# and each dropdown lists the sub-folders below the current selection of the
# dropdown to its left (sorted, first entry pre-selected). A level with no
# entries yields an empty dropdown (value None) instead of an IndexError.


def _list_subfolders(path):
    """Return the sorted names of the immediate sub-directories of ``path``.

    Raises FileNotFoundError (from ``os.listdir``) if ``path`` does not exist.
    """
    return sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


def _first_or_none(options):
    """Return the first element of ``options``, or None if it is empty."""
    return options[0] if options else None


# task folder
task_folders = _list_subfolders("saved_model")
task_folder_path = widgets.Dropdown(
    options=task_folders,
    value=_first_or_none(task_folders),
    description='Task Type:',
    disabled=False,
)

# method folder (left empty when no task is selected, rather than listing
# a non-existent 'saved_model/None')
method_type_folder_path = f'saved_model/{task_folder_path.value}'
method_folders = (_list_subfolders(method_type_folder_path)
                  if task_folder_path.value is not None else [])
method_folder_path = widgets.Dropdown(
    options=method_folders,
    value=_first_or_none(method_folders),
    description='Method Type:',
    disabled=False,
)

# dataset folder
dataset_type_folder_path = f'saved_model/{task_folder_path.value}/{method_folder_path.value}'
dataset_folders = (_list_subfolders(dataset_type_folder_path)
                   if method_folder_path.value is not None else [])
dataset_folder_path = widgets.Dropdown(
    options=dataset_folders,
    value=_first_or_none(dataset_folders),
    description='Dataset Type:',
    disabled=False,
)

# model type
model_type_folder_path = f'saved_model/{task_folder_path.value}/{method_folder_path.value}/{dataset_folder_path.value}'
model_name_folders = (_list_subfolders(model_type_folder_path)
                      if dataset_folder_path.value is not None else [])
model_name_list = widgets.Dropdown(
    options=model_name_folders,
    value=_first_or_none(model_name_folders),
    description='Model Name:',
    disabled=False,
)

# task change
def on_task_type_change(change):
    """Repopulate the method, dataset and model dropdowns after a task change.

    Registered below as an observer of ``task_folder_path``'s ``value``.
    ``change.new`` is the newly selected task folder (None when the task
    dropdown has no options). Each lower level lists the sorted sub-folders of
    the level above it and pre-selects the first one. A level whose parent
    selection is None is emptied instead of listing a non-existent
    ``.../None`` path.
    """
    def _subfolders(path, parent):
        # Nothing selected one level up -> nothing to list.
        if parent is None:
            return []
        # Sorted, matching the ordering used when the dropdowns are first built.
        return sorted(entry for entry in os.listdir(path)
                      if os.path.isdir(os.path.join(path, entry)))

    # method type
    method_type_folder_path = f'saved_model/{change.new}'
    method_folders = _subfolders(method_type_folder_path, change.new)
    method_folder_path.options = method_folders
    method_folder_path.value = method_folders[0] if method_folders else None

    # dataset type
    dataset_type_folder_path = f'saved_model/{change.new}/{method_folder_path.value}'
    dataset_folders = _subfolders(dataset_type_folder_path, method_folder_path.value)
    dataset_folder_path.options = dataset_folders
    dataset_folder_path.value = dataset_folders[0] if dataset_folders else None

    # model type
    model_type_folder_path = f'saved_model/{change.new}/{method_folder_path.value}/{dataset_folder_path.value}'
    model_name_folders = _subfolders(model_type_folder_path, dataset_folder_path.value)
    model_name_list.options = model_name_folders
    model_name_list.value = model_name_folders[0] if model_name_folders else None


task_folder_path.observe(on_task_type_change, names='value')

# method change
def on_method_type_change(change):
    """Repopulate the dataset and model dropdowns after a method change.

    Registered below as an observer of ``method_folder_path``'s ``value``.
    ``change.new`` is the newly selected method folder (None when the method
    dropdown has no options). Lists are sorted and the first entry is
    pre-selected. A level whose parent selection is None is emptied instead
    of listing a non-existent ``.../None`` path.
    """
    def _subfolders(path, parent):
        # Nothing selected one level up -> nothing to list.
        if parent is None:
            return []
        # Sorted, matching the ordering used when the dropdowns are first built.
        return sorted(entry for entry in os.listdir(path)
                      if os.path.isdir(os.path.join(path, entry)))

    # dataset type
    dataset_type_folder_path = f'saved_model/{task_folder_path.value}/{change.new}'
    dataset_folders = _subfolders(dataset_type_folder_path, change.new)
    dataset_folder_path.options = dataset_folders
    dataset_folder_path.value = dataset_folders[0] if dataset_folders else None

    # model type
    model_type_folder_path = f'saved_model/{task_folder_path.value}/{change.new}/{dataset_folder_path.value}'
    model_name_folders = _subfolders(model_type_folder_path, dataset_folder_path.value)
    model_name_list.options = model_name_folders
    model_name_list.value = model_name_folders[0] if model_name_folders else None


method_folder_path.observe(on_method_type_change, names='value')

# dataset change
def on_dataset_type_change(change):
    """Repopulate the model dropdown after a dataset change.

    Registered below as an observer of ``dataset_folder_path``'s ``value``.
    ``change.new`` is the newly selected dataset folder (None when the dataset
    dropdown has no options, in which case the model dropdown is emptied
    instead of listing a non-existent ``.../None`` path).
    """
    model_type_folder_path = f'saved_model/{task_folder_path.value}/{method_folder_path.value}/{change.new}'
    if change.new is None:
        model_name_folders = []
    else:
        # Sorted, matching the ordering used when the dropdowns are first built.
        model_name_folders = sorted(
            folder for folder in os.listdir(model_type_folder_path)
            if os.path.isdir(os.path.join(model_type_folder_path, folder))
        )
    model_name_list.options = model_name_folders
    model_name_list.value = model_name_folders[0] if model_name_folders else None


dataset_folder_path.observe(on_dataset_type_change, names='value')


# ====================================
# Select Model
# ====================================
def load_model():
    """Load the selected model, its config and its test set into globals.

    Reads the four selection dropdowns and the global ``infer_pth`` flag.
    When ``infer_pth`` is True, ``load_model`` of the model's tool module is
    used; otherwise ``load_model_onnx``. Sets the globals ``model_name``,
    ``cfg``, ``model``, ``testset`` and ``params``, and resets ``test_img``
    to False.
    """
    global model_name, cfg, model, testset, test_img, params

    selected = model_name_list.value
    task = task_folder_path.value
    method = method_folder_path.value
    dataset = dataset_folder_path.value

    # The part of the model folder name before the first '-', lower-cased,
    # names both the config sub-folder and the ``...models.<name>.tool`` module.
    model_name = selected.split('-')[0].lower()

    cfg = OmegaConf.load(f'configs/{task}/{method}/{model_name}/{dataset}.yaml')

    test_img = False
    module_path = f"task.{task}.{method}.models.{model_name}.tool"
    loader_name = 'load_model' if infer_pth else 'load_model_onnx'
    loader = getattr(importlib.import_module(module_path), loader_name)
    # Each loader returns (test dataset, model, params dict).
    testset, model, params = loader(cfg=cfg, model_name=selected, target=dataset)

# ====================================
# Select Folder
# ====================================
def load_folder():
    """Let the user pick an image folder and use it as the test set.

    Opens a Qt directory picker. If a folder is chosen, the global ``testset``
    is replaced by ``create_test_dataset(cfg=cfg, datadir=...)`` and the
    global ``test_img`` is set to True. If the dialog is cancelled, the model
    button currently in use (.pth or .onnx, per ``infer_pth``) is
    re-highlighted and the folder button is un-highlighted.
    """
    global test_img
    global testset

    # Qt allows only one QApplication per process: reuse an existing one
    # (e.g. from a Qt plotting backend) and only create/exit our own.
    app = QApplication.instance()
    owns_app = app is None
    if owns_app:
        app = QApplication([])
    datadir = QFileDialog.getExistingDirectory()
    if owns_app:
        app.exit()

    # data
    if datadir:
        testset = create_test_dataset(cfg=cfg, datadir=datadir)
        test_img = True
        return

    # Dialog cancelled: restore the highlight of the model button in use.
    (button_pth if infer_pth else button_onnx).button_style = "success"
    button_folder.button_style = ""

# ====================================
# Visualization
# ====================================
def result_plot(idx):
    """Plot the result for test-set item ``idx`` with the loaded model.

    Delegates to ``result_plot`` (when the global ``infer_pth`` is True) or
    ``result_plot_onnx`` of the selected model's tool module. It passes the
    globals ``model``, ``testset``, ``params`` and ``test_img``. Returns None.
    """
    module_path = (
        f"task.{task_folder_path.value}.{method_folder_path.value}"
        f".models.{model_name}.tool"
    )
    plot_fn_name = "result_plot" if infer_pth else "result_plot_onnx"
    plot_fn = getattr(importlib.import_module(module_path), plot_fn_name)
    plot_fn(idx, model, testset, params, test_img)

# ====================================
# Save All Result
# ====================================
def save_all_result():
    """Save results for the current test set with the loaded model.

    Delegates to ``result_save_plot`` (when the global ``infer_pth`` is True)
    or ``result_save_plot_onnx`` of the selected model's tool module. It
    passes the globals ``cfg``, ``model``, ``testset``, ``params`` and
    ``test_img``.
    """
    module_path = (
        f"task.{task_folder_path.value}.{method_folder_path.value}"
        f".models.{model_name}.tool"
    )
    saver_name = "result_save_plot" if infer_pth else "result_save_plot_onnx"
    saver = getattr(importlib.import_module(module_path), saver_name)
    saver(cfg, model, testset, params, test_img)
        
# ====================================
# SHOW TRAIN Result
# ====================================
def show_train_result():
    """Delegate to ``show_train_result(params)`` of the selected tool module."""
    tool_module = importlib.import_module(
        f"task.{task_folder_path.value}.{method_folder_path.value}"
        f".models.{model_name}.tool"
    )
    tool_module.show_train_result(params)

# ====================================
# Tracking Camera Result
# ====================================
def track_camera_result():
    """Run the loaded model's camera-tracking routine.

    Delegates to ``track_camera_run`` (when the global ``infer_pth`` is True)
    or ``track_camera_run_onnx`` of the selected model's tool module. It
    passes the globals ``cfg``, ``model`` and ``params``.
    """
    module_path = (
        f"task.{task_folder_path.value}.{method_folder_path.value}"
        f".models.{model_name}.tool"
    )
    runner_name = "track_camera_run" if infer_pth else "track_camera_run_onnx"
    runner = getattr(importlib.import_module(module_path), runner_name)
    runner(cfg, model, params)
        
# ====================================
# Tracking Video Result
# ====================================
def track_video_result():
    """Ask the user for an .mp4 file and run video tracking on it.

    Opens a Qt file picker, re-opening it while the chosen file does not have
    an ``.mp4`` extension (case-insensitive). Cancelling the dialog returns
    without running anything. The chosen path is stored in
    ``params['path_to_vid']``. Then ``track_video_run`` (when the global
    ``infer_pth`` is True) or ``track_video_run_onnx`` of the selected model's
    tool module is called with ``cfg``, ``model`` and ``params``.
    """
    # Qt allows only one QApplication per process: reuse an existing one and
    # only create/exit our own.
    app = QApplication.instance()
    owns_app = app is None
    if owns_app:
        app = QApplication([])
    try:
        while True:
            # Pick a file (not a folder); the filter pre-selects video files.
            file_path, _ = QFileDialog.getOpenFileName(
                None, "Select video", "", "Video files (*.mp4 *.MP4);;All files (*)")
            if not file_path:
                # Dialog cancelled: stop instead of re-opening the dialog forever.
                return
            # Only accept files with an .mp4 extension (treated as video files).
            if os.path.splitext(file_path)[1].lower() == '.mp4':
                break  # valid video file chosen
    finally:
        if owns_app:
            app.exit()

    params['path_to_vid'] = file_path
    task_path = f"task.{task_folder_path.value}.{method_folder_path.value}.models.{model_name}.tool"
    if infer_pth:
        importlib.import_module(task_path).track_video_run(
            cfg, model, params)
    else:
        importlib.import_module(task_path).track_video_run_onnx(
            cfg, model, params)


# ====================================
# widgets
# ====================================
def _centered_layout():
    """Return a new Layout that centers a box's children horizontally."""
    return widgets.Layout(justify_content='center')


# Buttons that load the selected model as .pth or .onnx, or pick an image folder.
button_pth = widgets.Button(description="Model .pth Change")
button_onnx = widgets.Button(description="Model .onnx Change")
button_folder = widgets.Button(description="Select folder")

# Row with the four selection dropdowns, and row with the two model buttons.
input_box = widgets.HBox(
    [task_folder_path, method_folder_path, dataset_folder_path, model_name_list],
    layout=_centered_layout(),
)
button_box = widgets.HBox([button_pth, button_onnx], layout=_centered_layout())

# "info"-styled button whose description shows which model is loaded.
button_model = widgets.Button(
    description="Model: ",
    layout=widgets.Layout(width='auto'),
    disabled=False,
    button_style="info",
)

# Action buttons for a loaded model.
button_save = widgets.Button(description="Save All Result")
button_show_train_result = widgets.Button(description="Show Train Result")
button_track_camera = widgets.Button(description="Tracking Camera")
button_track_video = widgets.Button(description="Tracking Video")
hbox = widgets.HBox(
    [button_model, button_save, button_folder],
    layout=_centered_layout(),
)

# Output area into which the button callbacks are captured.
output = widgets.Output()

# Callback function for save button
@output.capture()
def on_button_save_clicked(b):
    """Run ``save_all_result`` while showing the save button as busy.

    The button gets the "danger" style for the duration of the save and is
    reset in a ``finally`` so an exception does not leave it stuck.
    """
    button_save.button_style = "danger"
    try:
        save_all_result()
    finally:
        button_save.button_style = ""

# Callback function for show result button
@output.capture()
def on_button_show_train_result(b):
    """Run ``show_train_result`` while showing its button as busy.

    The button gets the "danger" style while running and is reset in a
    ``finally`` so an exception does not leave it stuck.
    """
    button_show_train_result.button_style = "danger"
    try:
        show_train_result()
    finally:
        button_show_train_result.button_style = ""
    
# Callback function for the Tracking Camera button
@output.capture()
def on_track_camera_button_clicked(b):
    """Run ``track_camera_result`` while showing its button as busy.

    The button gets the "danger" style while running and is reset in a
    ``finally`` so an exception does not leave it stuck.
    """
    button_track_camera.button_style = "danger"
    try:
        track_camera_result()
    finally:
        button_track_camera.button_style = ""
    
# Callback function for the Tracking Video button
@output.capture()
def on_track_video_button_clicked(b):
    """Run ``track_video_result`` while showing the video button as busy.

    Styles ``button_track_video`` (previously the camera button was styled by
    mistake) with "danger" while running, and resets it in a ``finally`` so
    an exception does not leave it stuck.
    """
    button_track_video.button_style = "danger"
    try:
        track_video_result()
    finally:
        button_track_video.button_style = ""

# Callback function for model change button
@output.capture()
def on_button_clicked(b):
    """Handle the .pth, .onnx and "Select folder" buttons.

    ``button_pth`` / ``button_onnx`` set the global ``infer_pth`` flag and
    (re)load the selected model with ``load_model``. ``button_folder``
    replaces the test set with a user-chosen folder via ``load_folder``.
    Afterwards the output area is redrawn with an image dropdown bound to
    ``result_plot`` through ``interact`` (skipped with a message when the test
    set is empty). It then shows the action-button row, whose extra buttons
    depend on the selected task type.
    """
    global infer_pth
    clear_output(wait=True)

    # Highlight only the clicked button.
    for button in (button_pth, button_onnx, button_folder):
        button.button_style = "success" if b is button else ""

    if b is button_pth or b is button_onnx:
        infer_pth = b is button_pth
        load_model()
    elif b is button_folder:
        load_folder()

    # visualization
    file_options = [(file_path, i) for i, file_path in enumerate(testset.file_list)]
    if file_options:
        file_list = widgets.Dropdown(
            options=file_options,
            value=0,
            description='Image:',
        )
        widgets.interact(result_plot, idx=file_list)
    else:
        # Dropdown(value=0) would raise on an empty option list.
        print("No images found in the selected test set.")
    button_model.description = f"Model {'pth' if infer_pth else 'onnx'}: {model_name_list.value}"

    # Task-specific extra action buttons.
    extra_buttons = {
        'objectdetection': [button_track_camera, button_track_video],
        'classify': [button_show_train_result],
    }.get(task_folder_path.value.lower(), [])
    hbox.children = [button_model, button_save, button_folder, *extra_buttons]
    display(hbox)

# Attach callbacks to widgets
# (button, handler) pairs, registered in the same order as before.
for _button, _handler in (
    (button_pth, on_button_clicked),
    (button_onnx, on_button_clicked),
    (button_save, on_button_save_clicked),
    (button_show_train_result, on_button_show_train_result),
    (button_folder, on_button_clicked),
    (button_track_camera, on_track_camera_button_clicked),
    (button_track_video, on_track_video_button_clicked),
):
    _button.on_click(_handler)
del _button, _handler

# Show the dropdown row, the model-button row and the captured output area.
display(input_box, button_box, output)
