# Object Masking

## Load Datasets

In [None]:
import datetime
import itertools as it
import random
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

In [None]:
def load_images_from_folder(folder):
    """Load every readable image file in ``folder``.

    Regular files whose stem is an integer (``0.png``, ``12.jpg``, ...) come
    first, in numeric order; any other regular file follows, ordered by name.
    Sub-directories are skipped, and so are files for which ``cv2.imread``
    returns ``None``.

    Args:
        folder: Path of the directory to scan (not recursive).

    Returns:
        A list with one array per loaded image, as returned by
        ``cv2.imread(..., cv2.IMREAD_COLOR)``.
        NOTE(review): OpenCV documents that colour images are decoded in
        B, G, R channel order - confirm that downstream consumers expect that.
    """
    def sort_key(filename):
        # A non-numeric name (e.g. ``.DS_Store``) used to abort the whole load
        # with ValueError; sort such names after the numbered files instead.
        try:
            return (0, int(Path(filename).stem), '')
        except ValueError:
            return (1, 0, filename)

    filenames = [
        name for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    ]

    images = []
    for filename in sorted(filenames, key=sort_key):
        img = cv2.imread(os.path.join(folder, filename), cv2.IMREAD_COLOR)
        if img is not None:
            images.append(img)
    return images

In [None]:
# Dataset location plus the sampling / plotting settings of this notebook.
images_path = '../data/datasets/vegetable_dog/images'
images = load_images_from_folder(images_path)
sample_size = 10    # number of images to sample
sequential = True   # True: one random run of consecutive images; False: random picks
plot_dpi = 300      # figure resolution read by the plotting helpers

images_count = len(images)
# Never ask for more images than were loaded. With fewer than ``sample_size``
# images the sequential branch used to call random.randint with a negative
# upper bound (ValueError), while the random branch already clamped.
taken_count = min(sample_size, images_count)
if sequential:
    start = random.randint(0, images_count - taken_count)
    taken_images_index = [*range(start, start + taken_count)]
else:
    taken_images_index = random.sample(range(images_count), taken_count)
print('Taken {} from {} images'.format(taken_images_index, images_count))

sample_images = [images[i] for i in taken_images_index]

In [None]:
def display_samples(images):
    """Show ``images`` side by side in a single row of subplots.

    Args:
        images: Sequence of arrays accepted by ``Axes.imshow``. Nothing is
            drawn when it is empty.
            NOTE(review): imshow reads 3-channel arrays as RGB, while arrays
            coming straight from ``cv2.imread`` are B, G, R - confirm whether
            the colours need converting before display.
    """
    if not images:
        # plt.subplots rejects a grid with zero columns.
        return

    # squeeze=False keeps ``axes`` two-dimensional, so a single image no
    # longer yields a bare Axes object that cannot be indexed.
    # Only constrained_layout is requested: also calling fig.tight_layout()
    # replaces it with a one-off tight layout (Matplotlib warns about that).
    fig, axes = plt.subplots(
        1, len(images), figsize=(14.5, 6),
        constrained_layout=True, sharey=True, squeeze=False,
    )
    fig.set_dpi(plot_dpi)

    for i, (axis, image) in enumerate(zip(axes[0], images)):
        axis.set_title('Color Image {}'.format(i))
        axis.imshow(image)


display_samples(sample_images)

## Extract Objects by Masks

### DIS

#### Build Model

In [None]:
import sys
import torch
from torchvision import transforms

# Put ../thirdparty (relative to the current working directory) on the import path.
thirdparty_dir = os.path.join(os.getcwd(), '../thirdparty')
sys.path.append(thirdparty_dir)

# Use CUDA whenever PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Running inference on device '{}'".format(device))

In [None]:
from DIS.IS_Net.data_loader_cache import normalize, im_reader, im_preprocess 
from DIS.IS_Net.models.isnet import ISNetGTEncoder, ISNetDIS

# Settings shared by the DIS cells below.
hypar = {
    'model_path': '../data/models/is-net',  # directory the trained weights are loaded from
    'restore_model': 'isnet.pth',           # file name of the weights to load; '' skips loading in build_model
    'interm_sup': False,                    # whether to activate intermediate feature supervision
    'model_digit': 'full',                  # 'half' or 'full' floating-point precision
    'seed': 0,
    'cache_size': [1024, 1024],             # cached input spatial resolution; can be configured to a different size
    'input_size': [1024, 1024],             # model input spatial size; usually the same as hypar['cache_size'], which means the images are not resized further
    'crop_size': [1024, 1024],              # random crop size from the input; usually smaller than hypar['cache_size'], e.g. [920, 920], for data augmentation
    'model': ISNetDIS()
}                                           # parameters for inference

In [None]:
def build_model(hypar, device):
    """Build an inference-ready copy of the network described by ``hypar``.

    Args:
        hypar: Settings dict. Keys read here: ``'model'`` (a
            ``torch.nn.Module``), ``'model_digit'`` (``'half'`` converts the
            network to float16), ``'restore_model'`` (weights file name,
            ``''`` to skip loading) and ``'model_path'`` (directory of that
            file).
        device: Device the network is moved to and the weights are mapped to.

    Returns:
        The network on ``device``, in eval mode. ``hypar['model']`` itself is
        left untouched.
    """
    import copy

    # Work on a private copy. ``hypar['model']`` is a single instance, so
    # converting it in place made every network previously built from the
    # same dict change precision as well (a 'full' network silently turned
    # 'half' as soon as a second one was built).
    net = copy.deepcopy(hypar['model'])

    if hypar['model_digit'] == 'half':
        net.half()
        # BatchNorm layers are switched back to float32.
        for layer in net.modules():
            if isinstance(layer, torch.nn.BatchNorm2d):
                layer.float()

    net.to(device)
    if hypar['restore_model'] != '':
        weights_path = os.path.join(hypar['model_path'], hypar['restore_model'])
        # load_state_dict copies into the existing parameters, so the network
        # keeps the device and precision chosen above.
        net.load_state_dict(torch.load(weights_path, map_location=device))
    net.eval()
    return net

# Build the network from the current hypar settings.
dis_net_full = build_model(hypar, device)
# build_model has already called eval(); this call leaves the network in eval mode.
dis_net_full.eval()

#### Inference

In [None]:
# Normalize the image; meant to be used inside transforms.Compose
class GOSNormalize:
    """Callable that normalizes an image with a fixed mean and std.

    Args:
        mean: Sequence of means passed to ``normalize`` (three values by
            default).
        std: Sequence of standard deviations passed to ``normalize`` (three
            values by default).
    """

    # Tuple defaults: the previous list defaults were shared by every
    # instance created without arguments, so mutating one instance's
    # ``mean``/``std`` in place changed them for all the others.
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return ``normalize(image, self.mean, self.std)``."""
        return normalize(image, self.mean, self.std)

# Transform applied to every input: GOSNormalize with mean 0.5 and std 1.0 for each of the three values.
transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])

def normalize_2_tensor_with_size(im, hypar):
    """Turn one image into a normalized tensor plus the shape of the image.

    Args:
        im: Image handed to ``im_preprocess``.
        hypar: Settings dict; only ``hypar['cache_size']`` is read.

    Returns:
        ``(tensor, shape)``: the output of ``im_preprocess`` divided by 255,
        passed through ``transform`` and given a leading dimension of size 1,
        and the shape reported by ``im_preprocess`` as a tensor, also with a
        leading dimension of size 1.
    """
    preprocessed, im_shape = im_preprocess(im, hypar['cache_size'])
    scaled = torch.divide(preprocessed, 255.0)
    shape_tensor = torch.from_numpy(np.array(im_shape))
    batched_input = transform(scaled).unsqueeze(0)
    batched_shape = shape_tensor.unsqueeze(0)
    return batched_input, batched_shape


# One (input tensor, shape tensor) pair per sampled image.
normalized_sample_images = [
    normalize_2_tensor_with_size(image, hypar) for image in sample_images
]

In [None]:
def dis_predict_masks(dis_net, images_tensor_with_size):
    """Run ``dis_net`` on each prepared image and return 8-bit masks.

    Args:
        dis_net: Network (``torch.nn.Module``) whose output, indexed with
            ``[0][0]``, is a ``B x 1 x H x W`` prediction tensor.
        images_tensor_with_size: Iterable of ``(input, shape)`` pairs;
            ``input`` is a batched image tensor and ``shape[0]`` holds the
            height and width the mask is resized to.

    Returns:
        ``(predicted_masks, time_ms)``: one ``uint8`` array (values 0-255)
        per input, and the duration of each forward pass in milliseconds.
    """
    import time

    # Match the inputs to the network that is actually being run, instead of
    # the global ``hypar``/``device``, which may describe a different network
    # by the time this function is called.
    first_param = next(dis_net.parameters(), None)
    if first_param is None:
        net_dtype, net_device = torch.float32, None
    else:
        net_dtype, net_device = first_param.dtype, first_param.device

    predicted_masks = []
    time_ms = []

    # Inference only: no_grad avoids recording an autograd graph per image.
    with torch.no_grad():
        for inputs_val, shapes_val in images_tensor_with_size:
            inputs_val = inputs_val.to(device=net_device, dtype=net_dtype)
            on_cuda = inputs_val.is_cuda

            # CUDA kernels run asynchronously; synchronize so the timer
            # covers the computation and not only the kernel launches.
            if on_cuda:
                torch.cuda.synchronize()
            begin = time.perf_counter()
            ds_val = dis_net(inputs_val)[0]  # list of predictions
            if on_cuda:
                torch.cuda.synchronize()
            # Full elapsed time; ``timedelta.microseconds`` alone dropped
            # the whole-second part of slow predictions.
            time_ms.append((time.perf_counter() - begin) * 1000)

            # ds_val[0] is B x 1 x H x W; keep the first sample as
            # 1 x 1 x H x W, the layout interpolate expects.
            pred_val = ds_val[0][:1]
            # Recover the original image size.
            height, width = int(shapes_val[0][0]), int(shapes_val[0][1])
            pred_val = torch.nn.functional.interpolate(
                pred_val, size=(height, width),
                mode='bilinear', align_corners=False,
            )[0, 0]

            # Min-max normalize to [0, 1]; a constant prediction would divide
            # by zero, so it maps to an all-zero mask instead.
            ma = torch.max(pred_val)
            mi = torch.min(pred_val)
            if ma > mi:
                pred_val = (pred_val - mi) / (ma - mi)
            else:
                pred_val = torch.zeros_like(pred_val)

            mask = (pred_val.cpu().numpy() * 255).astype(np.uint8)
            predicted_masks.append(mask)

    return predicted_masks, time_ms

# Predict masks with dis_net_full and keep the per-image timings for the chart below.
dis_full_predicted_masks, dis_full_time_ms = dis_predict_masks(dis_net_full, normalized_sample_images)
print(dis_full_time_ms)

In [None]:
from PIL import Image

def display_mask_results(original_images, masks):
    """Plot every image next to its mask, and the image cut out by the mask.

    One row per image: the left panel shows the image and the mask stitched
    side by side, the right panel the image with the mask as alpha channel.

    Args:
        original_images: Sequence of 3-channel ``uint8`` image arrays. Nothing
            is drawn when it is empty.
            NOTE(review): the arrays are displayed as RGB, while arrays coming
            straight from ``cv2.imread`` are B, G, R - confirm whether the
            colours need converting before display.
        masks: Sequence of single-channel ``uint8`` masks; ``masks[i]``
            belongs to ``original_images[i]``.
    """
    if not original_images:
        # plt.subplots rejects a grid with zero rows.
        return

    row_count = len(original_images)
    # squeeze=False keeps ``axes`` two-dimensional, so a single image no
    # longer produces a 1-D array that ``axes[i, 0]`` cannot index.
    # Only constrained_layout is requested: also calling fig.tight_layout()
    # replaces it with a one-off tight layout (Matplotlib warns about that).
    fig, axes = plt.subplots(
        row_count, 2, figsize=(14.5, 4 * row_count),
        gridspec_kw={'width_ratios': [2, 1]},
        constrained_layout=True, sharey=True, squeeze=False,
    )
    fig.set_dpi(plot_dpi)

    for i, image in enumerate(original_images):
        mask_3 = cv2.cvtColor(masks[i], cv2.COLOR_GRAY2RGB)
        (h0, w0), (h1, w1) = image.shape[:2], mask_3.shape[:2]
        # zeros rather than empty: when the two heights differ, the area not
        # covered by the shorter picture must not show uninitialized memory.
        stitch_image = np.zeros((max(h0, h1), w0 + w1, 3), dtype=np.uint8)
        stitch_image[:h0, :w0, :3] = image
        stitch_image[:h1, w0:, :3] = mask_3

        # Use the mask as the alpha channel of a copy of the image.
        pil_mask = Image.fromarray(mask_3).convert('L')
        pil_img_rgba = Image.fromarray(image).copy()
        pil_img_rgba.putalpha(pil_mask)

        axes[i, 0].set_title('Original Image {} / Mask'.format(i))
        axes[i, 0].imshow(stitch_image)
        axes[i, 1].set_title('Masked Image {}'.format(i))
        axes[i, 1].imshow(pil_img_rgba)


display_mask_results(sample_images, dis_full_predicted_masks)

#### Inference (Half Precision)

In [None]:
# Switch the shared settings to half precision and build a network from them.
# NOTE(review): build_model reads hypar['model'], the one instance created
# together with the dict; unless build_model copies it, dis_net_half and
# dis_net_full refer to the same module - confirm.
hypar['model_digit'] = 'half'
dis_net_half = build_model(hypar, device)
# build_model has already called eval(); this call leaves the network in eval mode.
dis_net_half.eval()

In [None]:
# Predict masks with dis_net_half on the same prepared inputs and keep the timings.
dis_half_predicted_masks, dis_half_time_ms = dis_predict_masks(dis_net_half, normalized_sample_images)
print(dis_half_time_ms)

In [None]:
# Same comparison plot as above, this time with the masks predicted by dis_net_half.
display_mask_results(sample_images, dis_half_predicted_masks)

#### Time Statistics

In [None]:
import matplotlib.cm as cm

def display_time_consumption(durations_with_name):
    """Draw grouped bars of per-prediction durations, one group per series.

    Every series also gets a dashed horizontal line at its average, in the
    colour of its bars, and the y axis uses a square-root scale.

    Args:
        durations_with_name: Sequence of ``(durations_ms, label)`` pairs.
            Nothing is drawn when it is empty; a series without durations is
            skipped.
    """
    n = len(durations_with_name)
    if n == 0:
        # Avoid dividing the bar width by zero.
        return

    total_width = 0.8
    width = total_width / n
    x_max = 0

    plt.figure(figsize=(14.5, 5), dpi=plot_dpi)
    plt.xlabel('DIS Predict Index')
    plt.ylabel('Time (ms)')
    plt.title('Time Consumption for Each Prediction')

    # The former ``cm.get_cmap()`` lookup is gone on purpose: the colour it
    # produced was never used, and the function no longer exists in recent
    # Matplotlib releases.
    for i, (y, t) in enumerate(durations_with_name):
        if len(y) == 0:
            # No bars to draw and no average to compute.
            continue
        x = np.arange(0, len(y))
        x_max = max(len(y), x_max)
        avg = np.average(y)

        # Shift each series so that the n bars of one index sit side by side.
        offset = i * width - (total_width - width) / 2
        bar = plt.bar(x + offset, y, width=width, label='{}, average {:.3f}ms'.format(t, avg))
        plt.axhline(avg, color=bar.patches[0].get_facecolor(), linestyle='--')

    # One legend for all series, built once after every bar group exists.
    plt.legend()
    plt.xticks(range(0, x_max, 1))
    plt.yscale('function', functions=(lambda v: v ** 0.5, lambda v: v ** 2))

# NOTE: hard-coded values that replace the dis_half_time_ms measured above (marked "test only").
dis_half_time_ms = [741.184, 857.849, 18.552, 938.649, 347.034, 197.042, 740.818, 152.959, 817.542, 331.542] # test only

# (durations in ms, legend label) pairs; the labels are fixed text and do not follow the `device` variable.
durations_with_name = [
    (dis_full_time_ms, 'DIS (FP32, CPU, M1)'),
    (dis_half_time_ms, 'DIS (FP16, CPU, M1, test only)')
]

display_time_consumption(durations_with_name)