# SLIC Superpixels
stough, 202-

DIP 10.5

This shows [Simple Linear Iterative Clustering](https://www.pyimagesearch.com/2014/07/28/a-slic-superpixel-tutorial-using-python/) at work to show superpixel segmentation of images. We'll split this up to look at the algorithm in detail. [`skimage`](https://scikit-image.org/docs/dev/api/skimage.segmentation.html) includes numerous segmentation methods, but we're interested here in understanding the details.

In [1]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cdist
import scipy.stats as stats
import scipy.ndimage as ndimage

import skimage.filters as filters
import skimage.color as color

import sys  
sys.path.insert(0, '../dip_utils')

from matrix_utils import (arr_info,
                          make_linmap)
from vis_utils import (vis_rgb_cube,
                       vis_hists,
                       vis_pair,
                       vis_surface)

IMAGE = 'clown_fish.jpg'
# IMAGE = 'bellagio.jpg'

SUPERPIXELS = 300
MAXITER = 6 #20
T = 10 # convergence threshold.

## Pre-defined functions

In [2]:
#clean-up function for the label image.
def myfunc(x):
    return stats.mode(x, axis=None)[0]

def reinitM():
    global M, I, s, ishape

    # Should pick the lowest gradient pixel among neighbors of each cluster
    # to reinitialize it. So, get gradient image.
    gradI = filters.sobel(color.rgb2gray(I))

    # Now, loop over each cluster and see if we should move it slightly.
    for mi in range(len(M)):
        i, j = [int(q) for q in np.round(M[mi, 3:])]
        curbest = gradI[i,j] # minimum gradient so far.
        bi, bj = i, j # best i and j so far.

        # look over all our immediate neighbors for lowest gradient.
        for x in range(i-1, i+2):
            if x < 0 or x > I.shape[0]:
                continue
            for y in range(j-1, j+2):
                if y < 0 or y > I.shape[1]:
                    continue

                if gradI[x,y] < curbest:
                    curbest = gradI[x,y]
                    bi, bj = x, y

        M[mi, :] = np.concatenate((I[bi, bj], [bi, bj]), axis=0)

def convergedYet():
    global M, Mprev, iteration
    normdiff = np.sqrt(np.sum((M-Mprev)**2))
    print('iteration %d normdiff %f.' % (iteration, normdiff))
    return normdiff < T

# Now like in the paper, the borders between superpixels in white
def onEdge(x):
    return ((x[1] != x[4]) or (x[3]!=x[4]) or (x[5]!=x[4]) or (x[7]!=x[4]))

## Load the image

In [3]:
I = plt.imread('../dip_pics/' + IMAGE).astype('float')
ishape = I.shape

# s = the sampling interval.
s = int(np.round(np.sqrt((ishape[0]*ishape[1])//SUPERPIXELS)))

# cs = the color distance scaler square, see Eq. 10-91
# cs = 3*255**2 #Just the maximum color discrepancy: weights regular regions highly.
cs = 100**2 # Trying for image boundaries instead of spatial regularity.

In [4]:
vis_hists(I/255)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
s

13

In [6]:
256/13

19.692307692307693

In [7]:
# Make M, the initial cluster centers. This should be SUPERPIXELS x 5, for
# the spatial coords of the cluster centers and the average color of them.
M = np.concatenate([np.expand_dims(IC, axis=1) for IC in
                    [I[s::s, s::s, 0].ravel(),
                     I[s::s, s::s, 1].ravel(),
                     I[s::s, s::s, 2].ravel(),
                     ]], axis=1)
xm = np.meshgrid(np.arange(s, ishape[0], s), np.arange(s, ishape[1], s), indexing='ij')
M = np.append(M, np.concatenate([np.expand_dims(xi, axis=1) for xi in
                                 [x.ravel() for x in xm]], axis=1),
              axis=1)

In [8]:
M.shape

(288, 5)

In [9]:
M[:3, :]

array([[ 52., 114., 139.,  13.,  13.],
       [ 28.,  90., 141.,  13.,  26.],
       [ 23.,  57.,  59.,  13.,  39.]])

In [10]:
xm[0], xm[1]

(array([[ 13,  13,  13,  13,  13,  13,  13,  13,  13,  13,  13,  13,  13,
          13,  13,  13,  13,  13],
        [ 26,  26,  26,  26,  26,  26,  26,  26,  26,  26,  26,  26,  26,
          26,  26,  26,  26,  26],
        [ 39,  39,  39,  39,  39,  39,  39,  39,  39,  39,  39,  39,  39,
          39,  39,  39,  39,  39],
        [ 52,  52,  52,  52,  52,  52,  52,  52,  52,  52,  52,  52,  52,
          52,  52,  52,  52,  52],
        [ 65,  65,  65,  65,  65,  65,  65,  65,  65,  65,  65,  65,  65,
          65,  65,  65,  65,  65],
        [ 78,  78,  78,  78,  78,  78,  78,  78,  78,  78,  78,  78,  78,
          78,  78,  78,  78,  78],
        [ 91,  91,  91,  91,  91,  91,  91,  91,  91,  91,  91,  91,  91,
          91,  91,  91,  91,  91],
        [104, 104, 104, 104, 104, 104, 104, 104, 104, 104, 104, 104, 104,
         104, 104, 104, 104, 104],
        [117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 117, 117,
         117, 117, 117, 117, 117],
        [130, 130, 

In [11]:
len(xm), xm[1].shape

(2, (16, 18))

In [12]:
# Here we should reinitialize these M to the lowest-gradient point in the 3x3,
# but we won't yet.
reinitM()

# Maintain a copy of the previous M, so that we can tell if it is converging.
Mprev = M.copy()


# Initial distance measure D for every pixel, and label
# D = np.finfo.max*np.ones((ishape[0], ishape[1]))
D = 1.0e20*np.ones((ishape[0], ishape[1]))
L = -1*np.ones((ishape[0], ishape[1]))


# Eq. 10-91 shows that the distance we want is a combination of the
# spatial and color distance. The spatial distance doesn't change though,
# in that i+x, j+y is always a fixed distance from i,j. We should precompute
# all of those distances...but not now.


# The spatial distance is constant over all windows [-s, s] for every
# cluster, so we'll just precompute. We'll also get the coordinate of the
# whole window in xr, yr, for use later as well.
xr, yr = np.meshgrid(np.arange(-s, s + 1), np.arange(-s, s + 1), indexing='ij')
scoords = np.concatenate([np.expand_dims(x, axis=1) for x in [xr.ravel(), yr.ravel()]], axis=1)
sDist = np.sqrt(np.sum(scoords**2, axis = 1))
spaceDist = np.reshape(sDist, (2*s+1, 2*s+1)) # distance from origin of the window, as matrix.

In [13]:
arr_info(spaceDist)

((27, 27), dtype('float64'), 0.0, 18.384776310850235)

In [14]:
plt.figure()
plt.imshow(spaceDist)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x1f6f8c7ae88>

In [15]:
# xs are the x and y coords of every pixel. used for average computing later.
xs = np.meshgrid(np.arange(ishape[0]), np.arange(ishape[1]), indexing='ij')


iteration = 0

while True:
    # First, loop over clusters and assign pixels to them.
    for mi, mx in enumerate(M):
        # mi is 1x5 of the average color and position of the cluster.
        i, j = [int(q) for q in np.round(mx[3:5])]
        # look over all pixels in the 2s x 2s neighborhood to see if they belong
        # to this cluster.

        # This is the double for-loop in python, exhaustive scheme.
        # for x in range(i - s, i + s + 1):
        #     if x < 0 or x >= ishape[0]:
        #         continue
        #
        #     for y in range(j - s, j + s + 1):
        #         if y < 0 or y >= ishape[1]:
        #             continue
        #
        #         sDist = np.sqrt((x-i)**2 + (y-j)**2)
        #         cxy = I[x,y,:]
        #         cDist = np.sum((mx[:3]-cxy)**2)
        #         oDist = np.sqrt(cDist + cs*(sDist/s)**2)
        #         if oDist < D[x,y]:
        #             D[x,y] = oDist
        #             L[x,y] = mi

        # This will be the numpy approach--A lot faster.
        # A bit harder to work out, but there are
        # some simplifications. First, the spatial distance calculation is the same
        # for the window around every cluster, since we're using the rounded i,j of mi.
        # That's spaceDist above.

        # The first key is to determine what part of the image we're sampling (around i,j)
        xrange = xr+i # This is a matrix of the x-coordinates of the window surrounding mi
        yrange = yr+j

        # Now cut out the out of range parts of the grid.
        valid = np.logical_and(np.logical_and(xrange >= 0, xrange < ishape[0]),
                               np.logical_and(yrange >= 0, yrange < ishape[1]))
        # So now for example I[xrange[valid], yrange[valid], :] are all the valid pixels
        # in the window. To set a piece of an image correctly one might say:
        # bb = np.zeros(list(valid.shape) + [3])
        # bb[valid,:] = I[xrange[valid], yrange[valid],:]

        # Here we're computing the color distance to the centroid of every valid pixel in the
        # window.
        colDist = cdist(I[xrange[valid], yrange[valid],:],
                        np.expand_dims(mx[:3], axis=1).T, metric='euclidean').squeeze()
        # sDist need not be computed each round, as it is constant for any given center and the window size.
        oDist = np.sqrt(colDist**2 + cs * (spaceDist[valid] / s) ** 2)

        #Find out which pixels in the window should have their label changed.
        whichToSwitch = np.zeros(valid.shape).astype(bool)
        whichToSwitch[valid] = oDist < D[xrange[valid], yrange[valid]]

        #Switch those pixels up:
        D[xrange[whichToSwitch], yrange[whichToSwitch]] = oDist[whichToSwitch[valid]]
        L[xrange[whichToSwitch], yrange[whichToSwitch]] = mi



        # print('looped over mi %d' % mi)


    # Then, recompute clusters.
    for mi in range(len(M)):
        M[mi, :3] = np.mean(I[L==mi, :], axis=0)
        M[mi, 3] = np.mean(xs[0][L==mi])
        M[mi, 4] = np.mean(xs[1][L==mi])
        
    # Clean up the label image:
    ndimage.generic_filter(L, function=myfunc, size=5, output=L, mode='reflect')

    # IR = np.zeros(I.shape)
    # for mi in range(len(M)):
    #     IR[L == mi, :] = M[mi,:3]
    IR = M[L.astype(int), :3]

    theBorders = ndimage.generic_filter(L, function=onEdge,
                                        size=3, mode='reflect').astype(bool)
    IRwB = IR.copy()
    IRwB[theBorders, :] = 255

    f, ax = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)
    ax[0].imshow(I / 255)
    ax[0].set_title('Original Image')
    ax[1].imshow(IR / 255)
    ax[1].set_title('%d Superpixels after %d Iter' % (SUPERPIXELS, iteration))
    ax[2].imshow(IRwB / 255)
    ax[2].set_title('With border highlights')

    plt.tight_layout()

    # Then, test for convergence.
    iteration += 1
    if convergedYet() or iteration >= MAXITER:
        break
    Mprev = M.copy() # If we didn't break, then readjust Mprev




Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

iteration 1 normdiff 537.882416.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

iteration 2 normdiff 196.515533.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

iteration 3 normdiff 155.828760.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

iteration 4 normdiff 115.516046.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

iteration 5 normdiff 90.698880.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

iteration 6 normdiff 70.539990.


In [16]:
'''
# Clean up the label image:
ndimage.generic_filter(L, function=myfunc, size=5, output=L, mode='reflect')

# IR = np.zeros(I.shape)
# for mi in range(len(M)):
#     IR[L == mi, :] = M[mi,:3]
IR = M[L.astype(int), :3]




theBorders = ndimage.generic_filter(L, function=onEdge,
                                    size=3, mode='reflect').astype(bool)
IRwB = IR.copy()
IRwB[theBorders, :] = 255



f, ax = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)
ax[0].imshow(I / 255)
ax[0].set_title('Original Image')
ax[1].imshow(IR / 255)
ax[1].set_title('%d Superpixels after %d Iter' % (SUPERPIXELS, iteration))
ax[2].imshow(IRwB / 255)
ax[2].set_title('With border highlights')


plt.tight_layout()
'''

"\n# Clean up the label image:\nndimage.generic_filter(L, function=myfunc, size=5, output=L, mode='reflect')\n\n# IR = np.zeros(I.shape)\n# for mi in range(len(M)):\n#     IR[L == mi, :] = M[mi,:3]\nIR = M[L.astype(int), :3]\n\n\n\n\ntheBorders = ndimage.generic_filter(L, function=onEdge,\n                                    size=3, mode='reflect').astype(bool)\nIRwB = IR.copy()\nIRwB[theBorders, :] = 255\n\n\n\nf, ax = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)\nax[0].imshow(I / 255)\nax[0].set_title('Original Image')\nax[1].imshow(IR / 255)\nax[1].set_title('%d Superpixels after %d Iter' % (SUPERPIXELS, iteration))\nax[2].imshow(IRwB / 255)\nax[2].set_title('With border highlights')\n\n\nplt.tight_layout()\n"