In [1]:
%matplotlib notebook

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs

# Algorithm

In [5]:
class kmedian:
    """K-medians clustering.

    Points are assigned to their closest centroid (Euclidean distance) and
    each centroid is then moved to the coordinate-wise median of the points
    assigned to it. The two steps alternate until the centroids stop moving
    or the iteration budget is used up.

    Attributes
    ----------
    K : int
        Number of clusters.
    centroids : numpy.ndarray of shape (K, n_features), or None before fit
        Final cluster centers.
    labels : numpy.ndarray of shape (n_samples,), or None before fit
        Index of the closest final centroid for every training sample.
    """

    def __init__(self, K):
        self.centroids = None
        self.labels = None
        self.K = K

    @staticmethod
    def _nearest_centroid(X, centroids):
        """Return, for every row of X, the index of its closest centroid."""
        # (m, 1, n) - (1, K, n) broadcasts to (m, K, n); taking the norm over
        # the last axis yields the (m, K) table of point-to-centroid distances.
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        return np.argmin(distances, axis=1)

    # Class methods
    def fit(self, X, max_iter=100, tols=1e-6, random_state=None):
        """Cluster the rows of X into K groups.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        max_iter : int, default 100
            Maximum number of assignment/update rounds.
        tols : float, default 1e-6
            Iteration stops once the Frobenius norm of the centroid
            displacement between two rounds is no larger than this value.
        random_state : int or None, default None
            Seed for the centroid initialisation. With None the global
            ``np.random`` state is used, so ``np.random.seed`` still applies.

        Raises
        ------
        ValueError
            If X is not two-dimensional or contains no samples.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        m = X.shape[0]
        if m == 0:
            raise ValueError("X must contain at least one sample")

        # Seed the centroids with actual samples so that every centroid starts
        # inside the data; distinct rows are drawn whenever enough exist.
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        seed_rows = rng.choice(m, size=self.K, replace=m < self.K)
        centroids = X[seed_rows]  # fancy indexing returns a copy
        labels = self._nearest_centroid(X, centroids)

        for iteration in range(max_iter):
            # Move each centroid to the coordinate-wise median of its members.
            # A cluster with no members keeps its previous position; taking
            # the median of an empty set would produce NaN.
            new_centroids = centroids.copy()
            for k in range(self.K):
                members = X[labels == k]
                if len(members) > 0:
                    new_centroids[k] = np.median(members, axis=0)

            shift = np.linalg.norm(new_centroids - centroids)
            centroids = new_centroids
            # Re-label against the updated centroids so that the stored labels
            # always agree with what predict() returns for the same points.
            labels = self._nearest_centroid(X, centroids)

            # Stopping criterion
            if shift <= tols:
                print("Converged after {} iterations".format(iteration+1))
                break

        self.centroids = centroids
        # Assigns the final cluster of each point
        self.labels = labels

    def predict(self, X_test):
        """Return the index of the closest fitted centroid for each sample.

        Parameters
        ----------
        X_test : array-like of shape (n_samples, n_features)
            Samples to label. A 1-D input is treated as one scalar sample
            per element.

        Returns
        -------
        numpy.ndarray of shape (n_samples,)
            Cluster indices, stored as floats.

        Raises
        ------
        RuntimeError
            If called before fit.
        """
        if self.centroids is None:
            raise RuntimeError("kmedian.predict called before fit")
        X_test = np.asarray(X_test, dtype=float)
        if X_test.ndim == 1:
            X_test = X_test[:, None]
        return self._nearest_centroid(X_test, self.centroids).astype(float)

### Make data

In [6]:
# Synthetic dataset: 100 two-feature samples generated around 2 centers,
# with a fixed random_state so the same points are produced on every run.
X, y = make_blobs(
    n_samples=100,
    n_features=2,
    centers=2,
    shuffle=True,
    random_state=1,
)

# Scatter the first feature against the second in a fresh figure
plt.figure()
plt.scatter(X[:, 0], X[:, 1])

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x2597a38f310>

In [8]:
# Fit a two-cluster model on the generated samples and display its centroids
kmedian_algorithm_test = kmedian(K=2)
kmedian_algorithm_test.fit(X, max_iter=100, tols=1e-6)
kmedian_algorithm_test.centroids

Converged after 3 iterations


array([[-1.57407655,  4.46695833],
       [-9.57528446, -3.91375762]])

In [9]:
plt.figure()
# Samples, coloured by the cluster label assigned during fit
plt.scatter(X[:, 0], X[:, 1], c=kmedian_algorithm_test.labels)
# Fitted centroids drawn on top as red crosses
plt.scatter(
    kmedian_algorithm_test.centroids[:, 0],
    kmedian_algorithm_test.centroids[:, 1],
    marker='x',
    s=100,
    c='red',
)

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x2597a812b80>

### Make a prediction

In [10]:
# Two query points (one per row) to label with the fitted model
test = np.array([
    [-7.28, -3.46],
    [-0.66, 4.49],
])
kmedian_algorithm_test.predict(test)

array([1., 0.])