# Gradio demo for MMDetection
> This notebook contains an interactive demo built with Gradio.  
> Select an appropriate kernel before running the cells.

Based on [hysts' MMDetection Hugging Face Space](https://huggingface.co/spaces/hysts/mmdetection).

In [None]:
import gradio as gr
import os

In [None]:
# Work from the VAN-Detection repo root: the relative config paths ('configs/...'),
# the example images ('images/') and, presumably, the optional `van` package used by
# the model loader are resolved from there. The guard makes re-running this cell a
# no-op instead of resolving the relative path a second time from the new location.
if os.path.basename(os.getcwd()) != 'VAN-Detection':
    os.chdir('../model_repos/VAN-Detection/')

In [None]:
import wandb

# Shared W&B public-API client; the model loader uses it (`wandb_api.artifact(...)`)
# to download the trained checkpoints as artifacts.
wandb_api = wandb.Api()

In [None]:
# W&B artifact path for a trained model; `{model}` is replaced by the run id.
run_base_path = 'nkoch-aitastic/van-detection/run_{model}_model:latest'

# Selectable models, in dropdown order:
#   display name -> {'config': mmdet config path, 'model': W&B run id}.
# Alternative FCOS-ResNet50 config: 'configs/refs/fcos_r50_caffe_fpn_gn-head_1x_coco_adam.py'.
# TODO: upload the ATSS model ('configs/det_fpn/atss_van_b2_fpn_dyhead_coco.py', run '3d74nm6r').
model_dict = {
    display_name: {'config': config_path, 'model': run_id}
    for display_name, config_path, run_id in (
        ('FCOS-ResNet50', 'configs/refs/fcos_r50_caffe_fpn_gn-head_1x_coco.py', '3ouwht0k'),
        ('FCOS-VAN-B0', 'configs/det_fpn/fcos_van_b0_fpn_dcn_1x_coco_adam_scp_bof.py', 'xhgm8eyk'),
        ('FCOS-VAN-B2', 'configs/det_fpn/fcos_van_b2_fpn_coco_adam_scp.py', '111lxdne'),
    )
}

# Model selected when the app starts; must be a key of `model_dict`.
DEFAULT_MODEL_NAME = 'FCOS-VAN-B0'
DESCRIPTION = 'An interactive demonstration of various MMDetection based models'


In [None]:
# Model handling
import os

import huggingface_hub
import numpy as np
import torch
import torch.nn as nn
import yaml
from mmdet.apis import inference_detector, init_detector


def _load_model_dict(path: str) -> dict[str, dict[str, str]]:
    """Load a model registry from a YAML file and localise its config paths.

    Args:
        path: YAML file mapping model name -> {'config': ..., ...}.

    Returns:
        The parsed mapping with every 'config' rewritten by
        `_update_config_path`. An empty YAML document yields ``{}``
        (``yaml.safe_load`` returns ``None`` for it, which would otherwise
        crash the path rewrite).
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, encoding='utf-8') as f:
        dic = yaml.safe_load(f) or {}
    _update_config_path(dic)
    return dic


def _update_config_path(model_dict: dict[str, dict[str, str]]) -> None:
    for dic in model_dict.values():
        dic['config'] = dic['config'].replace(
            'https://github.com/open-mmlab/mmdetection/tree/master',
            'mmdet_configs')


class Model:
    """Load, switch between and run the MMDetection detectors listed in ``MODEL_DICT``.

    Each registry entry maps a display name to ``{'config': <mmdet config path>,
    'model': <W&B run id>}``; the checkpoint is downloaded from the W&B artifact
    ``run_base_path.format(model=<run id>)`` and built with ``init_detector``.
    """

    # The upstream Hugging Face space built per-task registries from YAML files via
    # `_load_model_dict`; this demo uses the single in-notebook `model_dict` instead.
    MODEL_DICT = model_dict

    # Checkpoint file preferred inside a downloaded artifact. Artifacts without it
    # fall back to the highest-numbered `epoch_<N>.pth` (see `_find_checkpoint`).
    CHECKPOINT_NAME = 'epoch_12.pth'

    def __init__(self, model_name: str, device: str | torch.device):
        """Load every registered model once, then make ``model_name`` active.

        Args:
            model_name: Key of ``MODEL_DICT`` to activate.
            device: Anything ``torch.device`` accepts, e.g. ``'cuda'`` or ``'cpu'``.
        """
        self.device = torch.device(device)
        self._load_all_models_once()
        self.model = self._load_model(model_name)
        self.model_name = model_name

    def _load_all_models_once(self) -> None:
        """Build every registered model once and discard the result.

        NOTE: nothing is cached; each detector is dropped immediately. What remains
        are the side effects: all W&B artifacts get downloaded, and the one-way
        ``van`` import fallback in `_load_model` is triggered if any model needs it.
        """
        for name in self.MODEL_DICT:
            self._load_model(name)

    @staticmethod
    def _find_checkpoint(directory: str, preferred: str = 'epoch_12.pth') -> str:
        """Return the path of the checkpoint file to load from an artifact directory.

        ``preferred`` is used when present (the historical behaviour). Otherwise the
        ``epoch_<N>.pth`` file with the largest ``N`` is chosen; if there is no such
        file, the alphabetically first ``*.pth`` file is used.

        Raises:
            FileNotFoundError: if ``directory`` contains no ``*.pth`` file.
        """
        preferred_path = os.path.join(directory, preferred)
        if os.path.isfile(preferred_path):
            return preferred_path

        candidates = sorted(f for f in os.listdir(directory) if f.endswith('.pth'))
        if not candidates:
            raise FileNotFoundError(f'no .pth checkpoint found in {directory!r}')

        def epoch_number(filename: str) -> int:
            # 'epoch_7.pth' -> 7; any other name ranks below every epoch file.
            prefix, _, number = filename[:-len('.pth')].rpartition('_')
            return int(number) if prefix == 'epoch' and number.isdigit() else -1

        # max() keeps the first of equal keys, so ties resolve alphabetically.
        return os.path.join(directory, max(candidates, key=epoch_number))

    def _load_model(self, name: str) -> nn.Module:
        """Download the W&B artifact for ``name`` and build its detector on ``self.device``."""
        entry = self.MODEL_DICT[name]
        artifact = wandb_api.artifact(run_base_path.format(model=entry['model']),
                                      type='model')
        checkpoint = self._find_checkpoint(artifact.download(), self.CHECKPOINT_NAME)
        # FIXME: most runs use the reference implementation of VAN by the original
        # authors, but some use the implementation of the mmdet team, and the two
        # are not compatible. When building the detector raises TypeError we import
        # `van` (switching to the original implementation) and retry. The switch is
        # one-way: it is never undone for models loaded afterwards.
        try:
            return init_detector(entry['config'], checkpoint, device=self.device)
        except TypeError:
            import van  # noqa: F401 -- imported only for its side effect (see FIXME)
            return init_detector(entry['config'], checkpoint, device=self.device)

    def set_model(self, name: str) -> None:
        """Make ``name`` the active detector; no-op if it is already active."""
        if name == self.model_name:
            return
        # Load first: if loading raises, the previous model/name pair stays
        # consistent instead of recording a name whose model never loaded.
        self.model = self._load_model(name)
        self.model_name = name

    def detect_and_visualize(
        self, image: np.ndarray, score_threshold: float
    ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
                                        list[list[np.ndarray]]]
               | dict[str, np.ndarray], np.ndarray]:
        """Run `detect` on an RGB image and return ``(raw results, RGB visualisation)``."""
        out = self.detect(image)
        vis = self.visualize_detection_results(image, out, score_threshold)
        return out, vis

    def detect(
        self, image: np.ndarray
    ) -> list[np.ndarray] | tuple[
            list[np.ndarray], list[list[np.ndarray]]] | dict[str, np.ndarray]:
        """Run the active detector on an RGB image (channels-last; flipped to BGR first)."""
        image = image[:, :, ::-1]  # RGB -> BGR
        out = inference_detector(self.model, image)
        return out

    def visualize_detection_results(
        self,
        image: np.ndarray,
        detection_results: list[np.ndarray]
        | tuple[list[np.ndarray], list[list[np.ndarray]]]
        | dict[str, np.ndarray],
        score_threshold: float = 0.3,
    ) -> np.ndarray:
        """Draw ``detection_results`` above ``score_threshold`` onto an RGB image.

        Drawing is delegated to the detector's ``show_result``; the returned image
        is RGB again.
        """
        image = image[:, :, ::-1]  # RGB -> BGR
        vis = self.model.show_result(image,
                                     detection_results,
                                     score_thr=score_threshold,
                                     bbox_color=None,
                                     text_color=(200, 200, 200),
                                     mask_color=None)
        return vis[:, :, ::-1]  # BGR -> RGB


class AppModel(Model):
    """`Model` with a single Gradio-friendly entry point."""

    def run(
        self, model_name: str, image: np.ndarray, score_threshold: float
    ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
                                        list[list[np.ndarray]]]
               | dict[str, np.ndarray], np.ndarray]:
        """Activate ``model_name`` (loading it if needed), then detect and visualise.

        Returns the ``(raw detections, RGB visualisation)`` pair produced by
        ``detect_and_visualize``.
        """
        self.set_model(model_name)
        detections_and_vis = self.detect_and_visualize(image, score_threshold)
        return detections_and_vis

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Some UI Helpers

import pathlib
import subprocess
import tarfile

import cv2
import numpy as np


def update_input_image(image: np.ndarray) -> dict:
    """Shrink the uploaded image so its longer side is at most 1500 px.

    Smaller images, and a cleared input (``None``), are passed through unchanged.
    """
    if image is not None:
        longest_side = max(image.shape[:2])
        factor = 1500 / longest_side
        if factor < 1:
            image = cv2.resize(image, None, fx=factor, fy=factor)
    return gr.Image.update(value=image)


def update_model_name(model_type: str) -> dict:
    """Refresh the model dropdown's choices for ``model_type``.

    Uses ``AppModel.<MODEL_TYPE>_MODEL_DICT`` when such a per-task registry exists
    and falls back to the single ``AppModel.MODEL_DICT`` otherwise (the per-task
    registries are currently commented out). The selected value is
    ``DEFAULT_MODEL_NAME`` when it is in the registry, else the first entry
    (``None`` for an empty registry).
    """
    registry = getattr(AppModel, f'{model_type.upper()}_MODEL_DICT',
                       AppModel.MODEL_DICT)
    model_names = list(registry)
    if DEFAULT_MODEL_NAME in registry:
        model_name = DEFAULT_MODEL_NAME
    else:
        model_name = model_names[0] if model_names else None
    return gr.Dropdown.update(choices=model_names, value=model_name)


def update_visualization_score_threshold(model_type: str) -> dict:
    """Show the score-threshold slider for every model type except panoptic segmentation."""
    is_panoptic = model_type == 'panoptic_segmentation'
    return gr.Slider.update(visible=not is_panoptic)


def update_redraw_button(model_type: str) -> dict:
    """Show the Redraw button for every model type except panoptic segmentation."""
    is_panoptic = model_type == 'panoptic_segmentation'
    return gr.Button.update(visible=not is_panoptic)


def set_example_image(example: list) -> dict:
    """Put the clicked example into the input image component.

    Examples are built as one-element ``[path]`` lists, so the first item is the
    image path.
    """
    selected_path = example[0]
    return gr.Image.update(value=selected_path)


In [None]:
#|output: false
# Close the server launched by a previous run of this cell (if any) before
# building a new one.
_previous_demo = globals().get('demo')
if _previous_demo is not None:
    _previous_demo.close()

# Fall back to CPU so the demo still starts on machines without CUDA.
model = AppModel(DEFAULT_MODEL_NAME, 'cuda' if torch.cuda.is_available() else 'cpu')


def _redraw(image, detection_results, score_threshold):
    """Redraw stored results with a new threshold.

    Returns None (nothing to show) until both an input image and the results of
    a Run exist; passing ``None`` results to ``show_result`` would raise.
    """
    if image is None or detection_results is None:
        return None
    return model.visualize_detection_results(image, detection_results, score_threshold)


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label='Input Image', type='numpy')
            with gr.Group():
                with gr.Row():
                    model_name = gr.Dropdown(list(model.MODEL_DICT.keys()),
                                             value=DEFAULT_MODEL_NAME,
                                             label='Model')
            with gr.Row():
                run_button = gr.Button(value='Run')
                # Holds the raw detections of the last Run so Redraw can reuse them.
                prediction_results = gr.State()
        with gr.Column():
            with gr.Row():
                visualization = gr.Image(label='Result', type='numpy')
            with gr.Row():
                visualization_score_threshold = gr.Slider(
                    0,
                    1,
                    step=0.05,
                    value=0.3,
                    label='Visualization Score Threshold')
            with gr.Row():
                redraw_button = gr.Button(value='Redraw')

    with gr.Row():
        paths = sorted(pathlib.Path('images').rglob('*.jpg'))
        example_images = gr.Dataset(components=[input_image],
                                    samples=[[path.as_posix()]
                                             for path in paths])

    input_image.change(fn=update_input_image,
                       inputs=input_image,
                       outputs=input_image)

    model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
    run_button.click(fn=model.run,
                     inputs=[
                         model_name,
                         input_image,
                         visualization_score_threshold,
                     ],
                     outputs=[
                         prediction_results,
                         visualization,
                     ])
    redraw_button.click(fn=_redraw,
                        inputs=[
                            input_image,
                            prediction_results,
                            visualization_score_threshold,
                        ],
                        outputs=visualization)
    example_images.click(fn=set_example_image,
                         inputs=example_images,
                         outputs=input_image)


demo.launch()

{'config': 'configs/refs/fcos_r50_caffe_fpn_gn-head_1x_coco.py', 'model': '3ouwht0k'}
dict_keys(['config', 'model'])


[34m[1mwandb[0m: Downloading large artifact run_3ouwht0k_model:latest, 367.90MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.0


load checkpoint from local path: ./artifacts/run_3ouwht0k_model:v0/epoch_12.pth
{'config': 'configs/det_fpn/fcos_van_b0_fpn_dcn_1x_coco_adam_scp_bof.py', 'model': 'xhgm8eyk'}
dict_keys(['config', 'model'])


[34m[1mwandb[0m: Downloading large artifact run_xhgm8eyk_model:latest, 137.37MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.0
2023-01-09 10:42:38,023 - root - INFO - ModulatedDeformConvPack bbox_head.cls_convs.3.conv is upgraded to version 2.
2023-01-09 10:42:38,025 - root - INFO - ModulatedDeformConvPack bbox_head.reg_convs.3.conv is upgraded to version 2.


load checkpoint from local path: ./artifacts/run_xhgm8eyk_model:v0/epoch_12.pth
{'config': 'configs/det_fpn/fcos_van_b2_fpn_coco_adam_scp.py', 'model': '111lxdne'}
dict_keys(['config', 'model'])


[34m[1mwandb[0m: Downloading large artifact run_111lxdne_model:latest, 391.94MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.0


load checkpoint from local path: ./artifacts/run_111lxdne_model:v0/epoch_12.pth
{'config': 'configs/det_fpn/fcos_van_b0_fpn_dcn_1x_coco_adam_scp_bof.py', 'model': 'xhgm8eyk'}
dict_keys(['config', 'model'])


[34m[1mwandb[0m: Downloading large artifact run_xhgm8eyk_model:latest, 137.37MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.0
2023-01-09 10:42:39,928 - root - INFO - ModulatedDeformConvPack bbox_head.cls_convs.3.conv is upgraded to version 2.
2023-01-09 10:42:39,930 - root - INFO - ModulatedDeformConvPack bbox_head.reg_convs.3.conv is upgraded to version 2.


load checkpoint from local path: ./artifacts/run_xhgm8eyk_model:v0/epoch_12.pth
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.






In [None]:
# demo.close()

Closing server running on port: 7860
