In [None]:
from mantisshrimp.all import *
import PIL, requests
import numpy as np
import torch
from torchvision import transforms

In [None]:
# Location of the pretrained weights handed to `load_model` further down.
# NOTE(review): the name/URL suggest a Mask R-CNN (ResNet50-FPN) checkpoint
# for the PennFudan dataset — confirm against the hosted file.
MASK_PENNFUNDAN_WEIGHTS_URL = "https://mantisshrimp-models.s3.us-east-2.amazonaws.com/pennfundan_maskrcnn_resnet50fpn.zip"
# img_url = "https://raw.githubusercontent.com/ai-fast-track/ice-streamlit/master/images/image2.png"

In [None]:
def load_model(class_map=None, url=None):
    """Build a Mask R-CNN model and load its weights from ``url``.

    Args:
        class_map: Sized collection of classes; its length is used as
            ``num_classes``. When omitted, ``datasets.pennfundan.class_map()``
            is used (the same value the notebook assigns to its global
            ``class_map``).
        url: Location of the state dict, fetched with
            ``torch.hub.load_state_dict_from_url``.

    Returns:
        The model with the downloaded state dict loaded, or ``None`` when no
        ``url`` is given.
    """
    if url is None:
        return None

    # Resolve the default at call time. A `class_map=class_map` default is
    # evaluated when the `def` runs, i.e. before the notebook has assigned
    # `class_map`, which breaks Restart & Run All.
    if class_map is None:
        class_map = datasets.pennfundan.class_map()

    model = mask_rcnn.model(num_classes=len(class_map))
    # Map tensors to CPU so loading does not depend on a GPU being present.
    state_dict = torch.hub.load_state_dict_from_url(
        url, map_location=torch.device("cpu")
    )
    model.load_state_dict(state_dict)
    return model

In [None]:
# Module-level state shared by the inference helpers below: `get_masks`
# reads both `class_map` and `model` as globals.
class_map = datasets.pennfundan.class_map()
# Fetches the weights from MASK_PENNFUNDAN_WEIGHTS_URL and loads them into a
# model sized by len(class_map).
model = load_model(class_map=class_map, url=MASK_PENNFUNDAN_WEIGHTS_URL)
print("class_map: ", class_map)

In [None]:
def predict(
    model, image, detection_threshold: float = 0.5, mask_threshold: float = 0.5
):
    """Run Mask R-CNN inference on a single in-memory image.

    Args:
        model: Model forwarded to ``mask_rcnn.predict``.
        image: Input image; anything ``np.array`` can convert, such as a
            numpy array or a PIL image.
        detection_threshold: Forwarded to ``mask_rcnn.predict``.
        mask_threshold: Forwarded to ``mask_rcnn.predict``.

    Returns:
        A tuple ``(img, pred)``: the ``"img"`` entry of the first inference
        sample and the first prediction.
    """
    # Convert the `image` argument explicitly. The body previously referenced
    # an undefined `img`, which raised NameError or silently picked up a
    # stale global of that name.
    img = np.array(image)

    tfms_ = tfms.A.Adapter([tfms.A.Normalize()])
    # Whenever you have images in memory (numpy arrays) you can use `Dataset.from_images`
    infer_ds = Dataset.from_images([img], tfms_)

    batch, samples = mask_rcnn.build_infer_batch(infer_ds)
    preds = mask_rcnn.predict(
        model=model,
        batch=batch,
        detection_threshold=detection_threshold,
        mask_threshold=mask_threshold,
    )
    return samples[0]["img"], preds[0]


def show_prediction(img, pred, bbox=False, class_map=None):
    """Draw a prediction with ``show_pred`` and return the figure as a PIL image.

    Args:
        img: Image forwarded to ``show_pred``.
        pred: Prediction forwarded to ``show_pred``.
        bbox: Forwarded to ``show_pred`` as ``bbox``.
        class_map: Forwarded to ``show_pred`` as ``class_map``.

    Returns:
        A ``PIL.Image`` built from the RGBA buffer of the current matplotlib
        figure.
    """
    show_pred(
        img=img,
        pred=pred,
        class_map=class_map,
        denormalize_fn=denormalize_imagenet,
        show=True,
        bbox=bbox,
    )

    # Grab image from the current matplotlib figure.
    # NOTE(review): assumes `show_pred` leaves its figure as pyplot's current
    # figure even with show=True — confirm.
    fig = plt.gcf()
    try:
        fig.canvas.draw()
        # np.array copies the canvas buffer, so the pixels survive the close.
        fig_arr = np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # This runs once per request; without closing, figures pile up in
        # pyplot's registry for the lifetime of the process.
        plt.close(fig)

    return PIL.Image.fromarray(fig_arr)


def get_masks(input_image):
    """Predict on ``input_image`` and return the rendered prediction.

    Uses the notebook-level ``model`` and ``class_map`` globals; the result is
    whatever ``show_prediction`` returns.
    """
    inferred_img, inferred_pred = predict(model=model, image=input_image)
    rendered = show_prediction(
        img=inferred_img,
        pred=inferred_pred,
        class_map=class_map,
    )
    return rendered

In [None]:

# Expose `get_masks` through a Gradio interface: image in (shape=(512, 512)), image out.
# NOTE(review): `gr` is not imported in any visible cell — confirm that
# `import gradio as gr` runs before this cell on a fresh kernel.
gr_interface = gr.Interface(fn=get_masks, inputs=gr.inputs.Image(shape=(512, 512)), outputs="image", title='IceVision - Instance Segmentation')
# NOTE(review): `share=True` presumably requests a publicly reachable link and
# `inline=False` skips inline rendering — confirm against the Gradio docs.
gr_interface.launch(inline=False,share=True)