In [None]:
#  The MIT License (MIT)
#
#  Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the 'Software'), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in
#  all copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
#  THE SOFTWARE.


## Import MIGraphX Python Library

In [None]:
import migraphx
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

## Fetch the U-Net ONNX Model

In [None]:
# Download the pretrained U-Net ONNX model. -nc (no-clobber) skips the download
# when unet_13_256.onnx already exists, so re-running this cell is cheap.
# NOTE(review): Dropbox /s/ share links can serve an HTML preview page instead of
# the raw file (unless ?dl=1 is used) — confirm the downloaded file is a valid ONNX model.
!wget -nc https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx

## Load ONNX Model

In [None]:
# Parse the downloaded ONNX file into a MIGraphX program (compiled in the next cell).
onnx_path = "unet_13_256.onnx"
model = migraphx.parse_onnx(onnx_path)

In [None]:
# Compile the parsed program for the "gpu" target before running inference.
gpu_target = migraphx.get_target("gpu")
model.compile(gpu_target)

## Print model parameters

In [None]:
# Show the compiled model's parameter names first, then their shapes.
for get_param_info in (model.get_parameter_names, model.get_parameter_shapes):
    print(get_param_info())

In [None]:
def preprocess(pil_img, newW, newH):
    """Resize an image and turn it into a normalized, batched CHW array.

    Args:
        pil_img: PIL image (anything with a ``resize((w, h))`` method whose
            result ``np.array`` can convert).
        newW: target width in pixels; must be positive.
        newH: target height in pixels; must be positive.

    Returns:
        (img_trans, img_print):
            img_trans -- array of shape (1, C, newH, newW). It is divided by 255
                when any value exceeds 1, and left unscaled otherwise.
            img_print -- the resized image, kept for display.

    Raises:
        AssertionError: if ``newW`` or ``newH`` is not positive.
    """
    assert newW > 0 and newH > 0, 'Target width and height must be positive'
    img_print = pil_img.resize((newW, newH))

    img_nd = np.array(img_print)
    # Single-channel images come back as (H, W); give them a channel axis.
    if img_nd.ndim == 2:
        img_nd = np.expand_dims(img_nd, axis=2)

    # HWC -> CHW
    img_trans = img_nd.transpose((2, 0, 1))
    # Heuristic: values above 1 are taken to be 0-255 pixel intensities.
    if img_trans.max() > 1:
        img_trans = img_trans / 255

    # Prepend the batch axis: (C, H, W) -> (1, C, H, W).
    img_trans = np.expand_dims(img_trans, 0)

    return img_trans, img_print

def plot_img_and_mask(img, mask):
    """Show the input image next to its predicted segmentation mask(s).

    Args:
        img: image for the first panel (anything ``imshow`` accepts).
        mask: predicted mask, interpreted by rank:
            4-D -- (N, C, H, W); only the first batch item is shown;
            3-D -- (C, H, W);
            2-D -- a single (H, W) mask.
            One panel is drawn per channel C.

    Raises:
        ValueError: if ``mask`` is not 2-D, 3-D or 4-D.
    """
    masks = np.asarray(mask)
    # Normalize to (C, H, W) so the channel axis is always axis 0.
    if masks.ndim == 4:
        masks = masks[0]
    elif masks.ndim == 2:
        masks = masks[np.newaxis]
    elif masks.ndim != 3:
        raise ValueError(f'mask must be 2-D, 3-D or 4-D, got shape {masks.shape}')
    classes = masks.shape[0]

    _, ax = plt.subplots(1, classes + 1)
    ax[0].set_title('Input image')
    ax[0].imshow(img)
    if classes > 1:
        for i in range(classes):
            ax[i + 1].set_title(f'Output mask (class {i + 1})')
            ax[i + 1].imshow(masks[i])
    else:
        ax[1].set_title('Output mask')
        ax[1].imshow(masks[0])
    # Hide ticks on every panel; plt.xticks/yticks would only affect the last axes.
    for panel in ax:
        panel.set_xticks([])
        panel.set_yticks([])
    plt.show()

In [None]:
IMG_SIZE = 256  # images are resized to IMG_SIZE x IMG_SIZE before inference

# convert("RGB") guarantees exactly 3 channels, so grayscale/RGBA/palette files
# still match the (1, 3, H, W) input; the `with` closes the file handle.
with Image.open("./car1.jpeg") as src_img:
    img, imPrint = preprocess(src_img.convert("RGB"), IMG_SIZE, IMG_SIZE)

# Row-major (C-contiguous) float32 copy of the batch. Casting and stride layout
# are handled by NumPy, and unlike writing through as_strided views it can never
# touch memory outside the array.
input_im = np.ascontiguousarray(img, dtype=np.float32)
assert input_im.shape == (1, 3, IMG_SIZE, IMG_SIZE), input_im.shape
print(input_im.strides)
print(input_im.shape)

# Render inline; PIL's Image.show() would launch an external viewer instead.
fig, ax = plt.subplots()
ax.imshow(imPrint)
ax.set_title('Resized input')
ax.axis('off')
plt.show()

In [None]:
# Run inference. The first call takes noticeably longer than the ones after it.
model_feed = {'inputs': input_im}
mask = model.run(model_feed)
output_mask = np.array(mask[0])  # first model output as a NumPy array
print(output_mask.shape)

In [None]:
def sigmoid(x):
    """Element-wise logistic sigmoid 1 / (1 + exp(-x)), computed without overflow.

    The naive formula evaluates exp(-x), which overflows for large negative x
    (below about -88.7 in float32). That emits a RuntimeWarning, or raises under
    np.errstate(over='raise'). Here exp is only ever applied to -|x| <= 0. The
    algebraically equivalent branch is then chosen by the sign of x, which also
    keeps full relative precision in both tails.

    Args:
        x: scalar or array-like of real values.

    Returns:
        Sigmoid of ``x`` with the same shape (a NumPy scalar for scalar input).
        float32 input stays float32.
    """
    x = np.asarray(x)
    e = np.exp(-np.abs(x))  # in (0, 1]: cannot overflow
    out = np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    return out[()]  # unwrap 0-d results to a scalar

In [None]:
# Map the model output (presumably raw logits) to probabilities, keep only
# pixels above a very high confidence cut-off, and plot the result.
# NOTE(review): 0.996 looks hand-tuned for this demo image — confirm.
MASK_THRESHOLD = 0.996
probs = sigmoid(output_mask)
full_mask = np.greater(probs, MASK_THRESHOLD)
plot_img_and_mask(imPrint, full_mask)

<b>NOTE:</b> The model weights used here were trained on car images with plain backgrounds, so the imperfect result on the "real-world" image shown above is expected. For better results, fine-tune the model on a dataset of real-world examples.