# Mean Shift
#### Basic Implementation
The following two functions find_peak and meanshift execute the basic mean shift algorithm.

In [282]:
import sys
import cv2
import math
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
from mpl_toolkits.mplot3d import Axes3D
import requests
import numpy as np
from pathlib import Path
from scipy.io import loadmat
from scipy import spatial
from sklearn.datasets import *
from IPython.core.display import clear_output
from util import log_progress
%matplotlib inline
pylab.rcParams['figure.figsize'] = 16, 12
sys.setrecursionlimit(10000)

##### Load  and visualize sample data
The matrix is loaded into a numpy array of dimensions (2000, 3)

In [283]:
"""
https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300/html/dataset/images/color/181091.html
https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300/html/dataset/images/color/55075.html
https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300/html/dataset/images/color/368078.html
"""
SAMPLE_DATA = loadmat("pts.mat")['data'].transpose()
img_path = Path('images')
IMG_A = cv2.imread(str(img_path / "a.jpg"))
IMG_B = cv2.imread(str(img_path / "b.jpg"))
IMG_C = cv2.imread(str(img_path / "c.jpg"))

##### Utility functions


In [284]:
cached_tree = None
def get_neighbours(data, point, r):
    global cached_tree
    tree = spatial.KDTree(data) if cached_tree is None else cached_tree
    ret = []
    return data[tree.query_ball_point(point, r)]

def get_neighbours_cdist(data, point, r):
    ret = []
    for i, d in enumerate(spatial.distance.cdist(np.array([point]), data)[0]):
        if d < r:
            ret.append(data[i])
    return np.array(ret)

In [285]:
def find_peak(data, point, r, t = 0.01):
    def calc_new_shift(data, point, r):
        neighbours = get_neighbours(data, point, r)
        return neighbours.mean(axis=0)
    
    dist = None
    last_peak = calc_new_shift(data, point, r)
    while dist is None or dist >= t:
        peak = calc_new_shift(data, point, r)
        dist = spatial.distance.euclidean(peak, last_peak)
        last_peak = peak
    return peak

def meanshift(data, r):
    global cached_tree
    cached_tree = None
    peaks, points, point_peaks = [], [], []
    for point in log_progress(data, 1, len(data)):
        peak = find_peak(data, point, r)
        # Match peak to possible neighbours. Use cdist because we have only few peaks
        neighbours = get_neighbours_cdist(np.array(peaks), peak, r/2.) if len(peaks) > 0 else []
        if len(neighbours) > 1:
            peak = neighbours[0]
        else:
            peaks.append(peak)
        points.append(point)
        point_peaks.append(np.where(peaks==peak)[0][0])
    return np.array(peaks), np.array(points), np.array(point_peaks)

In [343]:
def find_peak_opt(data, point, r, t = 0.01):
    def calc_new_shift(data, point, r):
        neighbours = get_neighbours(data, point, r)
        return neighbours.mean(axis=0)
    
    dist = None
    last_peak = calc_new_shift(data, point, r)
    while dist is None or dist >= t:
        peak = calc_new_shift(data, point, r)
        dist = spatial.distance.euclidean(peak, last_peak)
        last_peak = peak
    return peak

def meanshift_opt(data, r):
    global cached_tree
    cached_tree = None
    peaks, points, point_peaks = [], np.empty((0,3)), []
    for point in log_progress(data, 1, len(data)):
        if np.any(points == point, axis=1):
            continue
        peak = find_peak(data, point, r)
        # Match peak to possible neighbours. Use cdist because we have only few peaks
        neighbours = get_neighbours_cdist(np.array(peaks), peak, r/2.) if len(peaks) > 0 else []
        if len(neighbours) > 1:
            peak = neighbours[0]
        else:
            peaks.append(peak)
        # Basin of Attraction
        peak_idx = np.where(peaks==peak)[0][0]
        for p in get_neighbours(data, peak, r):
            print(np.any(points == p, axis=1))
            if not np.any(points == p, axis=1):
                points = np.append(points, [p], axis=0)
                point_peaks.append(peak_idx)
    return np.array(peaks), np.array(points), np.array(point_peaks)

In [287]:
def image_segment(image, r, scale=0.01):
    image = cv2.resize(image, None, fx = scale, fy = scale)
    orig_img = np.array(image)
    image = cv2.GaussianBlur(image, (5,5), 5.0)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    image = image.reshape(image.shape[0]*image.shape[1], image.shape[2])
    peaks, points, point_peaks = meanshift_opt(image, r)
    image = image.reshape(image.shape)
    print("Found {} peaks in {} points !".format(len(peaks), len(points)))
    converted_peaks = cv2.cvtColor(np.array([peaks[:, 0:3]], dtype=np.uint8), cv2.COLOR_LAB2BGR)[0]
    im = converted_peaks[point_peaks]
    im = im.reshape(orig_img.shape[0], orig_img.shape[1], orig_img.shape[2])
    plt.imshow(im)

##### Execute the meanshift function
Visualize the results

In [288]:
def visualize(image, r, func):
    peaks, points, point_peaks = func(image, r)
    print("Found {} peaks in {} points !".format(len(peaks), len(points)))
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(*peaks.transpose(), c='black', s=100)
    ax.scatter(*points.transpose(), c='blue', s=1)

In [344]:
#visualize(SAMPLE_DATA, r=5, func=meanshift_opt)
image_segment(IMG_A, r=10)

[]
[ True]
[False]
[False False]


  split = (maxval+minval)/2


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()