# Mean Shift
#### Basic Implementation
The following two functions find_peak and meanshift execute the basic mean shift algorithm.

In [281]:
import cv2
import math
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from scipy.io import loadmat
from scipy import spatial
from sklearn.datasets import *
from IPython.core.display import clear_output
from util import log_progress
%matplotlib inline
pylab.rcParams['figure.figsize'] = 16, 12

##### Load  and visualize sample data
The matrix is loaded into a numpy array of dimensions (2000, 3)

In [282]:
SAMPLE_DATA = loadmat("pts.mat")['data']
if False:  # Set to true to show the data
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(SAMPLE_DATA[0], SAMPLE_DATA[1], SAMPLE_DATA[2])

##### Utility functions


In [283]:
cached_tree = None
def get_neighbours(data, point, r):
    global cached_tree
    tree = spatial.KDTree(data) if cached_tree is None else cached_tree
    ret = []
    return data[tree.query_ball_point(point, r)]

def get_neighbours_cdist(data, point, r):
    ret = []
    for i, d in enumerate(spatial.distance.cdist(np.array([point]), data)[0]):
        if d < r:
            ret.append(data[i])
    return np.array(ret)

In [284]:
def find_peak(data, point, r):
    def calc_new_shift(data, point, r):
        neighbours = get_neighbours(data, point, r)
        return neighbours.mean(axis=0)
    
    dist = None
    last_peak = calc_new_shift(data, point, r)
    while dist is None or dist >= t:
        peak = calc_new_shift(data, point, r)
        dist = spatial.distance.euclidean(peak, last_peak)
        last_peak = peak
    return peak

def meanshift(data, r, t=0.01):
    peaks, points, point_peaks = [], [], []
    for point in log_progress(data, 1, len(data)):
        peak = find_peak(data, point, r)
        # Match peak to possible neighbours. Use cdist because we have only few peaks
        neighbours = get_neighbours_cdist(np.array(peaks), peak, r/2.) if len(peaks) > 0 else []
        if len(neighbours) > 1:
            peak = neighbours[0]
        else:
            peaks.append(peak)
        points.append(point)
        point_peaks.append(np.where(peaks==peak))
    return np.array(peaks), np.array(points), np.array(point_peaks)

##### Execute the meanshift function
Visualize the results

In [285]:
def visualize_meanshift(image, r):
    peaks, points, point_peaks = meanshift(image, r)
    print("Found {} peaks in {} points !".format(len(peaks), len(points),))
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(*peaks.transpose(), color='yellow', s=10)
    ax.scatter(*points.transpose(), color=point_peaks, s=1, colormap = ['r','g','b'])

In [286]:
visualize_meanshift(SAMPLE_DATA.transpose(), r=2)
#visualize_meanshift(cv2.imread())

NameError: name 't' is not defined