In [16]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
from numba import njit, prange

In [17]:
# GrowCut parameters.
# theta[k] is the strength of label index k; label 0 ("none") carries no strength,
# so it can never take over a pixel.
theta = np.array([0.0, 0.5, 0.5])
# Added to the colour distance so that identical colours give a finite affinity.
eps = 1e-5

# Scribble colours (RGB) used in the mask image, listed in label-index order.
# 'index' is the label value written into the integer label map.
data = {
    name: {'color': np.array(rgb), 'index': position}
    for position, (name, rgb) in enumerate((
        ('none', [255, 255, 255]),
        ('background', [0, 0, 255]),
        ('foreground', [255, 0, 0]),
    ))
}

In [18]:
@njit(parallel=True)
def growcut(image, num_mask, theta, eps, n_iter=50, radius=2):
    """Propagate seed labels across ``image`` with a GrowCut-style automaton.

    In every sweep each pixel scans the (2 * radius + 1)^2 window around it
    (clipped at the image border) and picks the window pixel whose colour is
    closest to its own, using the affinity ``1 / (||c_p - c_q|| + eps)``.
    Later pixels in row-major order win ties (within 1e-5). Pairs where
    both pixels are unlabelled (label 0) are skipped. The pixel adopts the
    winner's label when the winner is labelled and
    ``theta[winner] * affinity > theta[own label]``.

    Parameters
    ----------
    image : ndarray, shape (H, W, C)
        Pixel colours. Converted to float64 internally, so uint8 input is safe.
    num_mask : ndarray of int, shape (H, W)
        Initial label per pixel (0 = unlabelled). Updated in place.
    theta : ndarray of float
        Strength per label index; must be indexable by every label value.
    eps : float
        Added to the colour distance so identical colours give a finite affinity.
    n_iter : int, optional
        Maximum number of sweeps; stops early once a sweep changes no label.
    radius : int, optional
        Half-width of the neighbourhood window.

    Returns
    -------
    ndarray
        ``num_mask`` itself, holding the final labels.
    """
    height, width, n_channels = image.shape
    # Cast once: with uint8 input the per-channel difference and its square
    # would wrap around modulo 256 and corrupt the colour distance.
    pixels = image.astype(np.float64)
    for _ in range(n_iter):
        # Synchronous update: all reads come from the previous sweep's labels,
        # so rows processed in parallel never observe each other's writes.
        prev = num_mask.copy()
        n_changed = 0
        for row_idx in prange(height):
            row_lo = max(row_idx - radius, 0)
            row_hi = min(row_idx + radius + 1, height)
            for col_idx in range(width):
                col_lo = max(col_idx - radius, 0)
                col_hi = min(col_idx + radius + 1, width)
                own = prev[row_idx, col_idx]
                win = own
                a_max = 0.0
                for nb_row in range(row_lo, row_hi):
                    for nb_col in range(col_lo, col_hi):
                        neighbour = prev[nb_row, nb_col]
                        if neighbour == 0 and own == 0:
                            continue
                        sq_dist = 0.0
                        for ch in range(n_channels):
                            diff = pixels[row_idx, col_idx, ch] - pixels[nb_row, nb_col, ch]
                            sq_dist += diff * diff
                        a = 1.0 / (np.sqrt(sq_dist) + eps)
                        # Tolerance so (near-)equal affinities count as ties.
                        if a - a_max > -1e-5:
                            win = neighbour
                            a_max = a
                if win != 0 and win != own and theta[win] * a_max > theta[own]:
                    num_mask[row_idx, col_idx] = win
                    n_changed += 1  # prange reduction
        # With synchronous updates an unchanged sweep is a fixed point.
        if n_changed == 0:
            break
    return num_mask

In [19]:
def _labels_from_mask(mask, labels):
    """Return, for each pixel of an RGB ``mask``, the index of the nearest label colour.

    ``labels`` maps a name to ``{'color': RGB array, 'index': int}`` (the
    layout of ``data``). Colours are ordered by ``'index'`` and the result is
    the position in that order. It therefore equals ``'index'`` only when the
    indices are 0..n-1, as they are in ``data``.
    """
    ordered = sorted(labels.values(), key=lambda spec: spec['index'])
    palette = np.stack([spec['color'] for spec in ordered]).astype(np.int64)
    # (H, W, 1, 3) - (n_labels, 3) broadcasts to (H, W, n_labels, 3); signed
    # ints keep the uint8 mask from wrapping around on subtraction.
    diffs = mask[:, :, np.newaxis, :].astype(np.int64) - palette
    return np.argmin(np.linalg.norm(diffs, axis=-1), axis=-1)


def growCut_segmentation():
    """Segment the global ``image`` with GrowCut, seeded by the global ``origin_mask``.

    Reads the notebook globals ``image``, ``origin_mask``, ``data``, ``theta``
    and ``eps``. Each mask pixel is classified as the nearest scribble colour
    in ``data``, the labels are grown with ``growcut``, and ``image`` is
    plotted with every pixel not labelled foreground set to black.
    """
    height, width, _ = image.shape
    # Nearest-neighbour resampling keeps scribble colours exact; INTER_AREA
    # would blend adjacent label colours into ambiguous in-between shades.
    mask = cv2.resize(origin_mask, (width, height), interpolation=cv2.INTER_NEAREST)
    seed_labels = _labels_from_mask(mask, data)

    # Pass a float image: uint8 colour differences would wrap around.
    num_mask = growcut(image.astype(np.float64), seed_labels, theta.copy(), eps)

    segmentation = image.copy()
    segmentation[num_mask != data['foreground']['index']] = 0
    fig, ax = plt.subplots()
    ax.imshow(segmentation)
    ax.set_title('GrowCut foreground segmentation')
    ax.axis('off')
    plt.show()


In [None]:
def _read_rgb(path):
    """Read an image file with OpenCV and return it as an RGB array.

    Raises FileNotFoundError when the file cannot be read, because
    ``cv2.imread`` signals failure by returning None rather than raising.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"could not read image: {path!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


image = _read_rgb('image.png')
origin_mask = _read_rgb('mask_image.png')

# Show the scribble mask next to the image. Separate axes are used because the
# mask may differ in size from the image (it is resized during segmentation).
fig, (mask_ax, image_ax) = plt.subplots(1, 2, figsize=(12, 6))
mask_ax.imshow(origin_mask)
mask_ax.set_title('Scribble mask')
image_ax.imshow(image)
image_ax.set_title('Input image')
for ax in (mask_ax, image_ax):
    ax.axis('off')
plt.show()

In [None]:
# Run GrowCut on the loaded image and scribble mask, then plot the image with
# every pixel not labelled foreground set to black.
growCut_segmentation()