# Mean Shift
#### Basic Implementation
The following two functions find_peak and meanshift execute the basic mean shift algorithm.

In [268]:
import sys
import cv2
import math
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
from mpl_toolkits.mplot3d import Axes3D
import requests
import numpy as np
from pathlib import Path
from scipy.io import loadmat
from scipy import spatial
from sklearn.datasets import *
from IPython.core.display import clear_output
from util import log_progress
%matplotlib inline
pylab.rcParams['figure.figsize'] = 16, 12
sys.setrecursionlimit(10000)

##### Load  and visualize sample data
The matrix is loaded into a numpy array of dimensions (2000, 3)

In [269]:
"""
https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300/html/dataset/images/color/181091.html
https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300/html/dataset/images/color/55075.html
https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300/html/dataset/images/color/368078.html
"""
SAMPLE_DATA = loadmat("pts.mat")['data'].transpose()
img_path = Path('images')
IMG_A = cv2.imread(str(img_path / "a.jpg"))
IMG_B = cv2.imread(str(img_path / "b.jpg"))
IMG_C = cv2.imread(str(img_path / "c.jpg"))

##### Utility functions


In [282]:
cached_tree = None
def get_neighbours(data, point, r):
    global cached_tree
    tree = spatial.KDTree(data) if cached_tree is None else cached_tree
    return tree.query_ball_point(point, r)

def get_neighbours_cdist(data, point, r):
    distances = spatial.distance.cdist(np.array([point]), data)[0]
    return data[np.where(distances < r)]

In [271]:
def find_peak(data, point, r, t = 0.01):
    def calc_new_shift(data, point, r):
        return data[get_neighbours(data, point, r)].mean(axis=0)
    
    dist = t
    while dist >= t:
        peak = calc_new_shift(data, point, r)
        dist = spatial.distance.euclidean(peak, point)
        point = peak
    return peak

def meanshift(data, r):
    peaks, points, point_peaks = [], [], []
    for point in log_progress(data, 1, len(data)):
        peak = find_peak(data, point, r)
        # Match peak to possible neighbours. Use cdist because we have only few peaks
        neighbours = get_neighbours_cdist(np.array(peaks), peak, r/2.) if len(peaks) > 0 else []
        if len(neighbours) > 1:
            peak = neighbours[0]
        else:
            peaks.append(peak)
        points.append(point)
        point_peaks.append(np.where(peaks==peak)[0][0])
    return np.array(peaks), np.array(points), np.array(point_peaks)

In [287]:
def find_peak_opt(data, point, r, t = 0.01):
    def calc_new_shift(data, point, r):
        return data[get_neighbours(data, point, r)].mean(axis=0)
    
    dist = t
    while dist >= t:
        peak = calc_new_shift(data, point, r)
        dist = spatial.distance.euclidean(peak, point)
        point = peak
    return peak

def meanshift_opt(data, r):
    peaks, point_peaks = [], np.zeros(data.shape[0], dtype='int16')-1
    for i, point in log_progress(enumerate(data), every=100, size=len(data)):
        if point_peaks[i] != -1:
            continue
        peak = find_peak(data, point, r)
        # Match peak to possible neighbours. Use cdist because we have only few peaks
        peak_neighbours = get_neighbours_cdist(np.array(peaks), peak, r/2.) if len(peaks) > 0 else []
        if len(peak_neighbours) > 1:
            peak = neighbours[0]
        else:
            peaks.append(peak)
        # Basin of Attraction
        neighbours = get_neighbours(data, peak, r)
        print(neighbours)
        point_peaks[neighbours] = np.where(peaks == peak)[0]
    return np.array(peaks), point_peaks

In [273]:
def image_segment(image, r, scale=0.05):
    # preprocess the image
    image = cv2.resize(image, None, fx = scale, fy = scale)
    orig_img = np.array(image)
    image = cv2.GaussianBlur(image, (5,5), 5.0)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    image = image.reshape(image.shape[0]*image.shape[1], image.shape[2])
    print("Image has {} points".format(image.shape))
    peaks, point_peaks = meanshift_opt(image, r)
    print("Found {} peaks !".format(len(peaks)))
    # convert back to show format
    converted_peaks = cv2.cvtColor(np.array([peaks[:, 0:3]], dtype=np.uint8), cv2.COLOR_LAB2BGR)[0]
    im = converted_peaks[point_peaks]
    im = im.reshape(orig_img.shape[0], orig_img.shape[1], orig_img.shape[2])
    plt.imshow(im)

##### Execute the meanshift function
Visualize the results

In [274]:
def visualize(image, r, func):
    peaks, _ = func(image, r)
    print("Found {} peaks in {} points !".format(len(peaks), image.shape))
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(*peaks.transpose(), c='black', s=100)
    ax.scatter(*image.transpose(), c='blue', s=1)

In [288]:
#visualize(SAMPLE_DATA, r=5, func=meanshift_opt)
image_segment(IMG_A, r=10)

Image has (384, 3) points


[190, 191, 302, 303, 347, 363, 110, 236, 268, 285, 315, 111, 252, 267, 189, 205, 219, 235, 251, 331, 378, 173, 286, 188, 283, 287, 206, 221, 207, 269, 362, 94, 203, 234, 299, 95, 250, 157, 222, 237, 172, 253, 270, 1, 264, 0, 223, 248, 249, 265, 266, 218, 271]


  split = (maxval+minval)/2


[190, 191, 302, 303, 347, 363, 110, 236, 268, 285, 315, 111, 252, 267, 189, 205, 219, 235, 251, 331, 378, 173, 286, 188, 283, 287, 206, 221, 207, 269, 362, 94, 203, 234, 299, 95, 250, 157, 222, 237, 172, 253, 270, 1, 264, 0, 223, 248, 249, 265, 266, 218, 271]
[346, 254, 141, 239, 232, 2, 187, 255, 202, 361, 376, 125, 280, 282, 3, 78, 281, 330, 156, 217, 375, 16, 17, 79, 360, 18, 171, 19, 314, 359, 109, 345, 298]
[35, 296, 77, 87, 297, 36, 279, 313, 328, 60, 76, 124, 312, 86, 92, 108, 28, 62, 155, 170, 327, 103, 247, 44, 61, 358, 20, 4, 12, 34, 139, 295, 342, 27, 52, 63, 70, 59, 311, 69, 75, 104, 43, 45, 91]
[47, 101, 121, 50, 30, 67, 293, 21, 73, 38, 55, 230, 277, 5, 14, 199, 324, 57, 134, 151, 261, 31, 41, 56, 308]
[167, 183, 214, 49, 100, 150, 15, 245, 22, 40, 83, 198, 9, 48, 133, 292, 166, 182, 229, 66]
[307, 99, 181, 197, 260, 165, 8, 148, 65, 7, 244, 82, 115, 164, 180, 228, 322, 291, 131, 306, 64]
[4, 12, 34, 139, 295, 27, 52, 63, 70, 59, 311, 69, 75, 104, 231, 43, 45, 91, 123, 51

IndexError: tuple index out of range