In [12]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()  # for plot styling
import numpy as np

from sklearn.datasets import load_digits

from scipy.stats import mode
from sklearn.metrics import pairwise_distances_argmin, pairwise_distances, \
                            accuracy_score, adjusted_rand_score, fowlkes_mallows_score

## Базовая реализация k-means

In [13]:
def kmeans_find_clusters(X, n_clusters, iter_num=50):
    """Cluster ``X`` with Lloyd's k-means, seeded with randomly chosen data points.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data to cluster.
    n_clusters : int
        Number of clusters; must satisfy ``1 <= n_clusters <= n_samples``.
    iter_num : int, default=50
        Maximum number of center updates.

    Returns
    -------
    centers : ndarray of shape (n_clusters, n_features)
        Final cluster centers.
    labels : ndarray of shape (n_samples,)
        Index of the closest returned center for every sample.

    Raises
    ------
    ValueError
        If ``n_clusters`` is smaller than 1 or larger than the number of samples.

    Notes
    -----
    The initial centers are drawn with NumPy's global random state
    (``np.random.permutation``), so call ``np.random.seed`` for repeatable runs.
    """
    X = np.asarray(X)
    n_samples = X.shape[0]
    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            "n_clusters must be between 1 and the number of samples (%d), got %r"
            % (n_samples, n_clusters)
        )

    # 1. Randomly choose distinct data points as the initial centers
    ind = np.random.permutation(n_samples)[:n_clusters]
    centers = X[ind]

    # Assign once up front so labels exist even when iter_num == 0
    labels = pairwise_distances_argmin(X, centers)

    for _ in range(iter_num):
        # 2a. Find new centers from means of points. A cluster that lost all of
        # its points keeps its previous center: the mean of an empty selection
        # is NaN and would poison every later distance computation.
        new_centers = np.array([
            X[labels == i].mean(0) if np.any(labels == i) else centers[i]
            for i in range(n_clusters)
        ])

        # 2b. Check for convergence
        if np.all(centers == new_centers):
            break
        centers = new_centers

        # 2c. Re-assign labels so they always match the centers being returned
        labels = pairwise_distances_argmin(X, centers)

    return centers, labels

## 1) k-means++

In [14]:
def kmeans_plus_plus(X, n_clusters, iter_num=50):
    """Cluster ``X`` with Lloyd's k-means using k-means++ seeding.

    The first center is a uniformly random data point. Every further center is
    a data point sampled with probability proportional to its squared distance
    from the nearest center chosen so far.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data to cluster.
    n_clusters : int
        Number of clusters; must satisfy ``1 <= n_clusters <= n_samples``.
    iter_num : int, default=50
        Maximum number of center updates.

    Returns
    -------
    centers : ndarray of shape (n_clusters, n_features)
        Final cluster centers.
    labels : ndarray of shape (n_samples,)
        Index of the closest returned center for every sample.

    Raises
    ------
    ValueError
        If ``n_clusters`` is smaller than 1 or larger than the number of samples.

    Notes
    -----
    Sampling uses NumPy's global random state (``np.random.choice``), so call
    ``np.random.seed`` for repeatable runs.
    """
    X = np.asarray(X)
    n_samples = X.shape[0]
    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            "n_clusters must be between 1 and the number of samples (%d), got %r"
            % (n_samples, n_clusters)
        )

    # 1. Artfully choose clusters: start from one uniformly random point
    centers = np.array([X[np.random.choice(n_samples)]])

    for _ in range(1, n_clusters):
        dist_matrix = pairwise_distances(X, Y=centers)
        min_dists_squared = np.square(np.amin(dist_matrix, axis=1))
        total = np.sum(min_dists_squared)

        # When every point coincides with an already chosen center the weights
        # sum to zero; dividing would give NaN probabilities and make
        # np.random.choice raise. Fall back to uniform sampling (p=None).
        sample_probs = min_dists_squared / total if total > 0 else None

        new_centroid_idx = np.random.choice(n_samples, p=sample_probs)
        centers = np.vstack((centers, X[new_centroid_idx]))

    # Assign once up front so labels exist even when iter_num == 0
    labels = pairwise_distances_argmin(X, centers)

    for _ in range(iter_num):
        # 2a. Find new centers from means of points. A cluster without points
        # keeps its previous center, because the mean of an empty selection is NaN.
        new_centers = np.array([
            X[labels == i].mean(0) if np.any(labels == i) else centers[i]
            for i in range(n_clusters)
        ])

        # 2b. Check for convergence
        if np.all(centers == new_centers):
            break
        centers = new_centers

        # 2c. Re-assign labels so they always match the centers being returned
        labels = pairwise_distances_argmin(X, centers)

    return centers, labels

## Датасет digits (sklearn `load_digits`, не MNIST)

In [15]:
# Load scikit-learn's bundled digits dataset; the last expression displays the
# shape of the feature matrix as (n_samples, n_features).
digits = load_digits()
digits.data.shape

(1797, 64)

In [16]:
# Seed NumPy's global RNG once so both runs are repeatable. Both functions draw
# from that shared stream, so the plain k-means initialization depends on the
# draws k-means++ made before it; reordering these calls changes the results.
np.random.seed(891642)
centers_pp, labels_pp = kmeans_plus_plus(digits.data, n_clusters=10, iter_num=50)
centers, labels = kmeans_find_clusters(digits.data, n_clusters=10, iter_num=50)

In [17]:
# Disabled visualization: when uncommented, draws the 10 plain k-means centers
# as 8x8 images on a 2x5 grid of axes.
#fig, ax = plt.subplots(2, 5, figsize=(8, 3))
#centers_img = centers.reshape(10, 8, 8)
#for axi, center in zip(ax.flat, centers_img):
#    axi.set(xticks=[], yticks=[])
#    axi.imshow(center, interpolation='nearest', cmap=plt.cm.binary)

## 2) k-means++ vs. k-means: accuracy comparison

In [25]:
# baseline accuracy for comparison
# NOTE(review): the origin of this figure is not shown in this notebook - confirm.
# 0.7935447968836951

def majority_vote_labels(cluster_labels, true_labels):
    """Relabel every cluster with the most frequent true label among its members.

    Parameters
    ----------
    cluster_labels : array-like of shape (n_samples,)
        Cluster index assigned to every sample.
    true_labels : array-like of shape (n_samples,)
        Ground-truth label of every sample.

    Returns
    -------
    ndarray of shape (n_samples,)
        For each sample, the majority true label of its cluster. The array has
        the dtype of ``cluster_labels``. Ties go to the smallest label.
    """
    cluster_labels = np.asarray(cluster_labels)
    true_labels = np.asarray(true_labels)
    mapped = np.zeros_like(cluster_labels)

    # Iterating over the cluster ids that actually occur avoids hard-coding the
    # number of clusters and never takes the mode of an empty cluster.
    for cluster_id in np.unique(cluster_labels):
        members = cluster_labels == cluster_id
        values, counts = np.unique(true_labels[members], return_counts=True)
        # values is sorted and argmax returns the first maximum, so ties
        # resolve to the smallest label.
        mapped[members] = values[np.argmax(counts)]
    return mapped


labels_dig = majority_vote_labels(labels, digits.target)
labels_dig_pp = majority_vote_labels(labels_pp, digits.target)

print('k-means:', accuracy_score(digits.target, labels_dig))
print('k-means++:', accuracy_score(digits.target, labels_dig_pp))

k-means: 0.7646076794657763
k-means++: 0.867557039510295


## 3) Additional metrics

In [30]:
# Adjusted Rand index of the relabelled predictions against the true digits,
# reported for plain k-means first and k-means++ second.
for title, predicted in (('ARI k-means:', labels_dig),
                         ('ARI k-means++:', labels_dig_pp)):
    print(title, adjusted_rand_score(digits.target, predicted))

ARI k-means: 0.6287824025276249
ARI k-means++: 0.7369974725247017


In [35]:
# Fowlkes-Mallows index of the relabelled predictions against the true digits,
# reported for plain k-means first and k-means++ second.
for title, predicted in (('FMI k-means:', labels_dig),
                         ('FMI k-means++:', labels_dig_pp)):
    print(title, fowlkes_mallows_score(digits.target, predicted))

FMI k-means: 0.6697787210148078
FMI k-means++: 0.76351462380623
