In [1]:
# Backend switch for the graph cut: when True, networkx is imported and used
# for the minimum cut; when False, pygco is imported and used instead.
USE_NETWORKX = True

In [2]:
# Some imports and helper functions
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import imageio
import time
import cv2
if USE_NETWORKX:
    import networkx as nx
    from networkx.algorithms.flow import preflow_push
else:
    import pygco

def draw_mask_on_image(image, mask, color=(0, 255, 255)):
    """Blend a mask into a copy of an image and trace the mask's outline.

    Pixels where ``mask == 1`` are mixed 40% image / 60% ``color``; pixels
    that become set only after a 3x3 dilation of the mask are painted with
    the pure ``color``. The input image is not modified.

    Args:
        image: Image array to draw on (copied before drawing).
        mask: 2-D mask array; entries equal to 1 are treated as selected.
        color: Overlay color, one value per image channel.

    Returns:
        A new array with the overlay drawn on it.
    """
    overlay = image.copy()
    inside = mask == 1

    # Dilate with a 3x3 box; whatever is set in the dilated mask but not in
    # the original mask forms the outline ring.
    box = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    grown = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_DILATE, box)
    ring = grown > mask

    tint = np.array(color)
    overlay[inside] = (overlay[inside] * 0.4 + tint * 0.6).astype(np.uint8)
    overlay[ring] = color
    return overlay

In [3]:
# Load the input image and print its height and width (first two shape entries).
im = imageio.imread('lotus320.jpg')
h,w = im.shape[:2]
print( h, w)


320 427


In [4]:
# Display the full array shape of the loaded image.
im.shape

(320, 427, 3)

In [7]:
# POINTS: 1

# Reload the image so this cell does not depend on the earlier loading cell.
im = imageio.imread('lotus320.jpg')
h,w = im.shape[:2]

# Set up initial foreground and background
# regions for building the color model.
# Both masks start as all zeros with the same height and width as the image.
init_fg_mask = np.zeros([h, w])
init_bg_mask = np.zeros([h, w])

# Now set some rectangular region of the initial foreground mask to 1.
# This should be a part of the image that is fully foreground.
# The indices in the following line are just an example,
# and they need to be corrected so that only flower pixels are included
# init_fg_mask[10:20, 15:30] = 1

# Same for the background (replace the indices)
# init_bg_mask[60:90, 50:110] = 1

# YOUR CODE HERE

# Chosen seed rectangles: rows 100-199 / columns 150-299 for the foreground,
# rows 0-69 / columns 0-99 for the background.
init_fg_mask[100:200, 150:300] = 1
init_bg_mask[0:70, 0:100] = 1

#raise NotImplementedError()

# Show both seed regions overlaid on the image, side by side.
fig, axes = plt.subplots(1, 2, figsize=(9.5,5))
axes[0].set_title('Initial foreground mask')
axes[0].imshow(draw_mask_on_image(im, init_fg_mask))
axes[1].set_title('Initial background mask')
axes[1].imshow(draw_mask_on_image(im, init_bg_mask))
fig.tight_layout()
fig.show()

<IPython.core.display.Javascript object>

In [8]:
# POINTS: 3

def calculate_histogram(im, mask, n_bins):
    """Build a normalized 3-D color histogram from the masked pixels of an image.

    Each of the first three channel values is mapped to a bin index with
    ``int(value / 256 * n_bins)`` (this assumes channel values lie in
    [0, 256)). Every bin starts at 0.001 so that no bin is exactly zero
    after normalization.

    Args:
        im: Image array of shape (height, width, channels) with at least
            three channels.
        mask: 2-D array; pixels whose mask entry is non-zero are counted.
        n_bins: Number of bins per color channel.

    Returns:
        Array of shape (n_bins, n_bins, n_bins) whose entries sum to 1.
    """
    counts = np.full((n_bins, n_bins, n_bins), fill_value=0.001)

    bin_indices = (im.astype(np.float32) / 256 * n_bins).astype(int)

    # Only the part of the mask that overlaps the image is consulted.
    rows, cols = im.shape[0], im.shape[1]
    selected = np.asarray(mask)[:rows, :cols] != 0
    chosen = bin_indices[selected]

    # Unbuffered accumulation: repeated bin triples each contribute +1.
    np.add.at(counts, (chosen[:, 0], chosen[:, 1], chosen[:, 2]), 1)

    # Turn the counts into relative frequencies.
    counts /= np.sum(counts)
    return counts

In [16]:
# Build the foreground and background color models from the seed masks.
n_bins = 10
fg_histogram = calculate_histogram(im, init_fg_mask, n_bins)
bg_histogram = calculate_histogram(im, init_bg_mask, n_bins)

# 3x2 grid: one row per color channel, left column foreground, right column background.
fig, axes = plt.subplots(
    3, 2, figsize=(5,5), sharex=True, 
    sharey=True, num='Relative frequency of color bins')

# Summing the 3-D histogram over two of its axes leaves the marginal
# distribution over the bins of the remaining axis (0, 1, 2 are titled
# red, green, blue below).
x = np.arange(n_bins)
axes[0,0].bar(x, np.sum(fg_histogram, (1, 2)))
axes[0,0].set_title('red (foreground)')
axes[1,0].bar(x, np.sum(fg_histogram, (0, 2)))
axes[1,0].set_title('green (foreground)')
axes[2,0].bar(x, np.sum(fg_histogram, (0, 1)))
axes[2,0].set_title('blue (foreground)')

axes[0,1].bar(x, np.sum(bg_histogram, (1, 2)))
axes[0,1].set_title('red (background)')
axes[1,1].bar(x, np.sum(bg_histogram, (0, 2)))
axes[1,1].set_title('green (background)')
axes[2,1].bar(x, np.sum(bg_histogram, (0, 1)))
axes[2,1].set_title('blue (background)')
fig.tight_layout()
fig.show()

<IPython.core.display.Javascript object>

In [34]:
def foreground_pmap(img, fg_histogram, bg_histogram, p_fg=0.5):
    """Compute the per-pixel posterior probability of being foreground.

    Each pixel's color is mapped to a histogram bin with
    ``int(value / 256 * n_bins)`` per channel (this assumes channel values
    lie in [0, 256)), the class-conditional likelihoods are looked up in
    the two histograms, and Bayes' rule combines them with the prior.

    Args:
        img: Image array of shape (height, width, channels) with at least
            three channels.
        fg_histogram: Foreground color histogram indexed by the three
            channel bins; its first-axis length defines the bin count.
        bg_histogram: Background color histogram with the same layout.
        p_fg: Prior probability of a pixel being foreground. Defaults to
            0.5 (no preference between foreground and background).

    Returns:
        Array of shape (height, width) holding P(foreground | color).
    """
    n_bins = len(fg_histogram)
    binned_im = (img.astype(np.float32) / 256 * n_bins).astype(int)
    bins = (binned_im[:, :, 0], binned_im[:, :, 1], binned_im[:, :, 2])

    # Prior probability of the background is the complement of the foreground prior.
    p_bg = 1 - p_fg

    # Likelihood of each pixel's color bin under the two color models.
    p_rgb_given_fg = fg_histogram[bins]
    p_rgb_given_bg = bg_histogram[bins]

    # Bayes' rule: P(fg | rgb) = P(fg) P(rgb | fg) / P(rgb).
    weighted_fg = p_fg * p_rgb_given_fg
    weighted_bg = p_bg * p_rgb_given_bg
    return weighted_fg / (weighted_fg + weighted_bg)

In [35]:
# Compute the per-pixel foreground posterior and show it next to the input image.
foreground_prob = foreground_pmap(im, fg_histogram, bg_histogram)
fig, axes = plt.subplots(1, 2, figsize=(10,5), sharey=True)
axes[0].imshow(im)
axes[0].set_title('Input image')
im_plot = axes[1].imshow(foreground_prob, cmap='viridis')
axes[1].set_title('Foreground posterior probability')
fig.tight_layout()
fig.colorbar(im_plot, ax=axes)
# Boolean map of pixels whose foreground posterior exceeds 0.5.
foreground_map = (foreground_prob > 0.5)

<IPython.core.display.Javascript object>

In [37]:
def unary_potentials(probability_map, unary_weight):
    """Turn a probability map into unary potentials: -weight * log(p).

    Probabilities are clamped from below to the smallest positive normal
    float before taking the logarithm, so a probability of exactly 0
    yields a large finite potential instead of ``inf`` (or ``nan`` when
    ``unary_weight`` is 0). Values above that floor are unaffected.

    Args:
        probability_map: Array of probabilities.
        unary_weight: Scalar factor applied to the negative log-probability.

    Returns:
        Float array with the same shape as ``probability_map``.
    """
    floor = np.finfo(float).tiny
    safe_probabilities = np.maximum(
        np.asarray(probability_map, dtype=float), floor)
    return -unary_weight * np.log(safe_probabilities)
    
    
# Unary potentials for labelling a pixel foreground (from P(fg)) and
# background (from 1 - P(fg)), both with the same weight.
unary_weight = 1
unary_fg = unary_potentials(foreground_prob, unary_weight)
unary_bg = unary_potentials(1 - foreground_prob, unary_weight)
fig, axes = plt.subplots(1, 2, figsize=(10,5), sharey=True)
axes[0].imshow(unary_fg)
axes[0].set_title('Unary potentials (foreground)')
im_plot = axes[1].imshow(unary_bg)
axes[1].set_title('Unary potentials (background)')
fig.tight_layout()
# NOTE: the colorbar is built from the background panel's image (im_plot) only.
fig.colorbar(im_plot, ax=axes)
fig.show()

<IPython.core.display.Javascript object>

In [38]:
def pairwise_potential_prefactor(img, x1, y1, x2, y2, pairwise_weight):
    """Return the pairwise cost factor for the edge (x1, y1) -- (x2, y2).

    The factor is the constant ``pairwise_weight``; ``img`` and the pixel
    coordinates are accepted but currently not used.
    """
    # A contrast-sensitive variant would additionally multiply by
    # np.exp(-1e-1*np.sum((img[y1,x1]-img[y2,x2])**2)); it is left disabled.
    prefactor = pairwise_weight
    return prefactor
    

In [39]:
# POINTS: 4

def coords_to_index(x, y, width):
    """Map (x, y) pixel coordinates to a row-major flat index for rows of ``width`` pixels."""
    row_start = width * y
    return row_start + x

def pairwise_potentials(im, pairwise_weight):
    """List the 4-neighborhood edges of the pixel grid together with their costs.

    Every pixel contributes one directed entry per in-bounds neighbor, so
    each pair of adjacent pixels appears twice (once from each side).

    Args:
        im: Image array whose first two dimensions are height and width.
        pairwise_weight: Weight handed to ``pairwise_potential_prefactor``.

    Returns:
        ``(edges, costs)``: ``edges`` holds (center_index, neighbor_index)
        pairs of flat pixel indices and ``costs`` the matching edge costs,
        both as numpy arrays in the same order.
    """
    # (dx, dy) steps in the order up, right, down, left.
    neighbor_steps = ((0, -1), (1, 0), (0, 1), (-1, 0))

    im = im.astype(np.float32) / 255
    h, w = im.shape[:2]

    edges = []
    costs = []
    for y in range(h):
        for x in range(w):
            center_index = coords_to_index(x, y, w)
            for dx, dy in neighbor_steps:
                x_neigh = x + dx
                y_neigh = y + dy
                # Skip neighbors that fall outside the image.
                if not (0 <= x_neigh < w and 0 <= y_neigh < h):
                    continue
                edges.append(
                    (center_index, coords_to_index(x_neigh, y_neigh, w)))
                costs.append(pairwise_potential_prefactor(
                    im, x, y, x_neigh, y_neigh, pairwise_weight))

    return np.array(edges), np.array(costs)

# Build the grid edges and their costs for the loaded image with a pairwise weight of 5.
pairwise_edges, pairwise_costs = pairwise_potentials(im, pairwise_weight=5)

In [40]:
def graph_cut(unary_fg, unary_bg, pairwise_edges, pairwise_costs):
    """Label every pixel as foreground (1) or background (0) with a graph cut.

    With ``USE_NETWORKX`` set, an undirected graph is built whose nodes are
    the flat pixel indices plus a source and a sink. Each pixel is linked to
    the source with capacity ``unary_bg`` and to the sink with capacity
    ``unary_fg``, and neighboring pixels are linked with the pairwise
    costs. Pixels on the source side of the minimum cut are labelled 1.
    Otherwise the same energy is handed to pygco.

    Args:
        unary_fg: Array of unary potentials for the foreground label.
        unary_bg: Array of unary potentials for the background label,
            same shape as ``unary_fg``.
        pairwise_edges: Sequence of (i, j) flat pixel index pairs.
        pairwise_costs: Sequence of costs, one per entry of ``pairwise_edges``.

    Returns:
        Integer array with the shape of ``unary_fg`` containing 0/1 labels.
    """
    if USE_NETWORKX:

        graph = nx.Graph()
        # Fresh sentinel objects are used as source/sink so they can never
        # collide with the integer pixel indices used as node keys.
        s = object()
        t = object()

        edges = []
        for i, cost in enumerate(unary_bg.flat):
            edges.append((s, i, cost))
        for i, cost in enumerate(unary_fg.flat):
            edges.append((i, t, cost))
        for (i,j), cost in zip(pairwise_edges, pairwise_costs):
            edges.append((i, j, cost))

        graph.add_weighted_edges_from(edges, 'capacity')

        # minimum_cut returns (cut_value, (source_side, sink_side)).
        nodes_connected_to_s = nx.minimum_cut(
            graph, s, t, flow_func=preflow_push)[1][0]

        # Explicit integer index array so an empty foreground set is still
        # a valid index for the flat assignment below.
        fg_pixel_indices = np.array(
            list(set(nodes_connected_to_s) - {s}), dtype=np.intp)
        # The np.int alias was removed from NumPy; the builtin int is the
        # equivalent dtype.
        labels = np.zeros_like(unary_fg, dtype=int)
        labels.flat[fg_pixel_indices] = 1
        return labels
    else:
        # Column 0 holds the cost of label 0 (background), column 1 of label 1.
        unaries = np.stack([unary_bg.flat, unary_fg.flat], axis=-1)
        labels = pygco.cut_general_graph(
            pairwise_edges, pairwise_costs, unaries, 
            1-np.eye(2), n_iter=-1, algorithm='swap')
        return labels.reshape(unary_fg.shape)

# Run the graph cut and compare it with plain thresholding of the posterior at 0.5.
graph_cut_result = graph_cut(unary_fg, unary_bg, pairwise_edges, pairwise_costs)
fig, axes = plt.subplots(1, 2, figsize=(9.5,5), sharey=True)
axes[0].set_title('Thresholding of per-pixel foreground probability at 0.5')
axes[0].imshow(draw_mask_on_image(im, foreground_prob>0.5))
axes[1].set_title('Graph cut result')
axes[1].imshow(draw_mask_on_image(im, graph_cut_result))
fig.tight_layout()
fig.show()

<IPython.core.display.Javascript object>