# Mean Shift

## Mean Shiftとは<a name="description"></a>

- 特徴空間上で、各点を指定した距離（バンド幅）内のデータ密度が極大になる方向へ繰り返し移動させ、同じ収束先に集まった点を1つのクラスタとする手法（クラスタ数は事前に指定せず、自動的に決まる）

## 使用方法<a name="example"></a>

### データ準備<a name="data"></a>

In [None]:
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

# Three 2-D centres handed to make_blobs; the second return value is discarded.
centers = [[1, 1], [-1, -1], [1, -1]]
data, _ = make_blobs(
    n_samples=10000,
    centers=centers,
    cluster_std=0.6,
    random_state=0,
)

# Draw every sample as a tiny black dot; the two columns of `data` are
# unpacked as the x and y coordinates.
plt.figure(figsize=(4, 4))
plt.scatter(*data.T, s=1, c='black', linewidth=0)

# Pass empty tick lists so neither axis shows ticks.
plt.xticks([])
plt.yticks([])

plt.show()

### 学習<a name="training"></a>

In [None]:
from sklearn.cluster import MeanShift, estimate_bandwidth

# Estimate the bandwidth from the data rather than hard-coding a value.
bandwidth = estimate_bandwidth(
    data,
    n_samples=500,
    quantile=0.2,
)

# Build the estimator with that bandwidth (bin seeding enabled) and fit it
# on the full data set.
ms = MeanShift(bin_seeding=True, bandwidth=bandwidth)
ms.fit(data)

### 可視化<a name="visualization"></a>

In [None]:
import numpy as np
from matplotlib.cm import ScalarMappable

# Pull the fitted results out of the estimator: one label per sample and one
# centre row per cluster.
labels = ms.labels_
cluster_centers = ms.cluster_centers_
n_clusters = len(np.unique(labels))

print('number of estimated clusters: {n}'.format(n=n_clusters))

# One RGBA row per cluster index, taken from the 'gist_rainbow' colormap.
colors = ScalarMappable(cmap='gist_rainbow').to_rgba(np.arange(n_clusters))

plt.figure(figsize=(4, 4))

for k, color in enumerate(colors):
    is_member = labels == k
    cluster_center = cluster_centers[k]

    # Use `color=` rather than `c=` for the RGBA row: a bare length-4 numeric
    # sequence passed as `c` is ambiguous to matplotlib (it warns, and would
    # value-map the numbers through the colormap if the cluster happened to
    # contain exactly four points).
    plt.scatter(data[is_member, 0], data[is_member, 1], color=color, s=1, linewidth=0)
    # Mark the cluster centre with a bold black cross on top of its members.
    plt.scatter(cluster_center[0], cluster_center[1], c='black', marker='x', s=150, linewidth=3)

# Pass empty tick sequences so neither axis shows ticks.
plt.xticks(())
plt.yticks(())

plt.show()