<a href="https://colab.research.google.com/github/ai-fast-track/icevision-gradio/blob/master/IceApp_masks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install icedata

In [None]:
pip install gradio

In [3]:
from icevision.all import *
import icedata
import PIL, requests
import torch
from torchvision import transforms
import gradio as gr

In [None]:
# Class map for the Penn-Fudan dataset; later handed to `draw_pred` when
# rendering predictions.
class_map = icedata.pennfudan.class_map()
# Mask R-CNN (ResNet-50 FPN) taken from icedata's Penn-Fudan trained models.
# NOTE(review): presumably this fetches pretrained weights on first call - confirm.
model = icedata.pennfudan.trained_models.mask_rcnn_resnet50_fpn()

In [5]:
def predict(
    model, image, detection_threshold: float = 0.5, mask_threshold: float = 0.5
):
    """Run Mask R-CNN inference on one in-memory image.

    Args:
        model: Model forwarded to ``mask_rcnn.predict``.
        image: Image held in memory (e.g. a numpy array), wrapped into a
            one-element dataset via ``Dataset.from_images``.
        detection_threshold: Forwarded unchanged to ``mask_rcnn.predict``.
        mask_threshold: Forwarded unchanged to ``mask_rcnn.predict``.

    Returns:
        A ``(img, pred)`` tuple: the ``"img"`` entry of the first inference
        sample and the first prediction.
    """
    # Normalization is the only transform applied before inference.
    normalize_only = tfms.A.Adapter([tfms.A.Normalize()])
    single_image_ds = Dataset.from_images([image], normalize_only)

    infer_batch, infer_samples = mask_rcnn.build_infer_batch(single_image_ds)

    thresholds = {
        "detection_threshold": detection_threshold,
        "mask_threshold": mask_threshold,
    }
    predictions = mask_rcnn.predict(model=model, batch=infer_batch, **thresholds)

    # Exactly one image went in, so only the first sample/prediction is returned.
    first_sample = infer_samples[0]
    return first_sample["img"], predictions[0]

def get_masks(input_image, display_list, detection_threshold, mask_threshold):
    """Predict on ``input_image`` and return the annotated result as a PIL image.

    Args:
        input_image: Image passed on to ``predict``.
        display_list: Collection of selected overlay names; ``"Label"``,
            ``"BBox"`` and ``"Mask"`` each switch on the matching overlay.
        detection_threshold: Detection threshold; a value of 0 is replaced by 0.5.
        mask_threshold: Mask threshold; a value of 0 is replaced by 0.5.

    Returns:
        A ``PIL.Image`` built from the array returned by ``draw_pred``.
    """
    # Translate the selected names into the boolean flags `draw_pred` expects.
    overlay_flags = {
        "display_label": "Label" in display_list,
        "display_bbox": "BBox" in display_list,
        "display_mask": "Mask" in display_list,
    }

    # A threshold of exactly 0 falls back to 0.5.
    detection_threshold = 0.5 if detection_threshold == 0 else detection_threshold
    mask_threshold = 0.5 if mask_threshold == 0 else mask_threshold

    raw_img, prediction = predict(
        model=model,
        image=input_image,
        detection_threshold=detection_threshold,
        mask_threshold=mask_threshold,
    )
    annotated = draw_pred(
        img=raw_img,
        pred=prediction,
        class_map=class_map,
        denormalize_fn=denormalize_imagenet,
        **overlay_flags,
    )
    return PIL.Image.fromarray(annotated)

In [None]:
# --- Input widgets ---------------------------------------------------------
# Checkbox group choosing which overlays `get_masks` draws.
display_chkbox = gr.inputs.CheckboxGroup(
    ["Label", "BBox", "Mask"],
    label="Display",
)
# Both sliders cover [0, 1] in steps of 0.1 and start at 0.5.
detection_threshold_slider = gr.inputs.Slider(
    minimum=0,
    maximum=1,
    step=0.1,
    default=0.5,
    label="Detection Threshold",
)
mask_threshold_slider = gr.inputs.Slider(
    minimum=0,
    maximum=1,
    step=0.1,
    default=0.5,
    label="Mask Threshold",
)

# --- Output widget ---------------------------------------------------------
# `get_masks` returns a PIL image, hence type="pil".
outputs = gr.outputs.Image(type="pil")

# --- App wiring and launch -------------------------------------------------
# Input order must match the positional parameters of `get_masks`.
gr_interface = gr.Interface(
    fn=get_masks,
    inputs=[
        "image",
        display_chkbox,
        detection_threshold_slider,
        mask_threshold_slider,
    ],
    outputs=outputs,
    title='IceApp - Masks',
)
gr_interface.launch(inline=False, share=True, debug=True)
