In [1]:
import numpy as np
from sklearn import datasets

import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML

In [2]:
# Nine qualitative colours. Points are coloured by indexing this array with
# their predicted cluster id, so K must not exceed len(palette).
palette = np.array([
    '#332288', '#88CCEE', '#44AA99',
    '#117733', '#999933', '#DDCC77',
    '#CC6677', '#882255', '#AA4499',
])

In [3]:
# Synthetic data: N points around K random blob centers
# (make_blobs defaults to 2 features, matching the 2-D plot below).
N, K = 1500, 7
SEED = 0
# Seed the global NumPy state too: KMeans draws its initial centers from it,
# so the whole notebook is reproducible under Restart & Run All.
np.random.seed(SEED)
blobs = datasets.make_blobs(n_samples=N, centers=K, random_state=SEED)
data, _ = blobs

In [68]:
# Hyperparameters.
# The model scores "x came from center c" by its information content
#     -log p(x | c) = beta * ||x - c||^2 + const,
# i.e. an isotropic Gaussian with sigma^2 = 1 / (2 * beta). Larger beta gives
# sharper, more k-means-like assignments; beta = 0.75 gives sigma ~ 0.82.
beta = 20


def d(x, y):
    """Pairwise information content between points and centers.

    x: (n, dim) points, y: (k, dim) centers -> (n, k) array whose entries are
    beta * ||x_n - y_j||^2. The distance is squared because that is the
    Gaussian negative log-likelihood: it makes the responsibility-weighted
    mean the MLE of each center, and it matches the sigma = 1 / sqrt(2 * beta)
    used for the circles drawn in the animation.
    """
    diff = x[:, None, :] - y
    return beta * np.sum(diff ** 2, axis=-1)


eps = 0.0001  # small numerical-stability constant

In [69]:
class KMeans:
    """Soft k-means: alternate responsibilities (E-step) and weighted means (M-step).

    Attributes
    ----------
    data : ndarray, (n_samples, n_features)
    k : int, number of clusters.
    centers : ndarray, (k, n_features), current cluster centers.
    probs : ndarray, (n_samples, k), responsibilities; each row sums to 1.
    """

    def __init__(self, data, k=None, dist=None, rng=None):
        """Pick k distinct data points as initial centers, then run one E-step.

        Parameters
        ----------
        data : array-like, (n_samples, n_features)
        k : int, optional
            Number of clusters; defaults to the notebook-level ``K``.
        dist : callable, optional
            ``dist(points, centers) -> (len(points), len(centers))`` information
            content (negative log-likelihood up to a constant). Defaults to the
            notebook-level ``d``, looked up on every E-step.
        rng : numpy Generator or RandomState, optional
            Randomness for the initial centers; defaults to the global ``np.random``.

        Raises
        ------
        ValueError
            If ``k`` is not between 1 and the number of samples.
        """
        self.data = np.asarray(data)
        self.k = K if k is None else k
        self._dist = dist
        n_samples = len(self.data)
        if not 1 <= self.k <= n_samples:
            raise ValueError(f'k must be in [1, {n_samples}], got {self.k}')
        rng = np.random if rng is None else rng
        # Sample WITHOUT replacement: duplicate initial centers receive identical
        # responsibilities, hence identical updates, and never separate.
        idx = rng.choice(n_samples, size=self.k, replace=False)
        self.centers = self.data[idx]
        self.infer_probs()

    def infer_probs(self):
        """E-step: set ``self.probs[n, j] = softmax_j(-content[n, j])``."""
        dist = d if self._dist is None else self._dist
        # information content of the event "x_n came from cluster j", (n_samples, k)
        content = dist(self.data, self.centers)
        # Subtract each row's minimum before exponentiating (log-sum-exp trick).
        # The softmax is unchanged, but the largest term becomes exp(0) = 1, so
        # the denominator is >= 1: nothing underflows to 0/0 and no eps is needed.
        # A plain exp(-content) can sit far below a small additive eps (e.g.
        # beta = 20 at distance 0.5 gives exp(-10) ~ 4.5e-5), which drags every
        # probability towards 0 and breaks the row normalisation.
        shifted = content - content.min(axis=1, keepdims=True)
        raw_probs = np.exp(-shifted)
        self.probs = raw_probs / raw_probs.sum(axis=1, keepdims=True)

    def est_centers(self):
        """M-step: move each center to the responsibility-weighted mean of the data.

        The weighted mean is the maximum-likelihood center when the information
        content is proportional to the squared Euclidean distance. A cluster whose
        total responsibility is 0 keeps its previous center instead of becoming NaN.
        """
        weights = self.probs.sum(axis=0)            # (k,)
        weighted_sums = self.probs.T @ self.data    # (k, n_features)
        has_mass = weights > 0
        means = weighted_sums / np.where(has_mass, weights, 1.0)[:, None]
        self.centers = np.where(has_mass[:, None], means, self.centers)

In [73]:
# Animate soft k-means: each frame runs one M-step + E-step, recolours points
# by their most likely cluster, and moves the center markers and circles.
assert K <= len(palette), 'palette has fewer colours than clusters'

plt.close()
fig, ax = plt.subplots(figsize=(12, 12))
ax.set(title=f'Soft k-means (K={K}, beta={beta})', xlabel='$x_1$', ylabel='$x_2$')
scat_data = ax.scatter([], [], s=10, marker='o')
scat_means = ax.scatter([], [], s=100, marker='+', color='black')

# Circle enclosing `coverage` of the probability mass of an isotropic 2-D
# Gaussian with sigma^2 = 1 / (2 * beta), i.e. density ~ exp(-beta * ||x - c||^2)
# (NOTE(review): only meaningful when `d` is the squared distance).
# In 2-D, P(r <= R) = 1 - exp(-R^2 / (2 sigma^2)), so R = sigma * sqrt(-2 ln(1 - coverage));
# for 95% that is ~2.45 sigma -- 1.96 is the 1-D value and covers only ~85% here.
coverage = 0.95
sigma = 1 / np.sqrt(2 * beta)
radius = sigma * np.sqrt(-2 * np.log(1 - coverage))
radii = [plt.Circle([0, 0], radius, edgecolor='black', facecolor=(0, 0, 0, .0125))
         for _ in range(K)]
# Add the circles once here: init() can run more than once (e.g. again on save),
# and adding inside it would duplicate the patches.
for circle in radii:
    ax.add_patch(circle)

model = KMeans(data)


def init():
    """Blit init: fix the view, place the static data, return the animated artists."""
    ax.axis([-15, 15, -15, 15])
    scat_data.set_offsets(data)
    return [scat_data, scat_means, *radii]


def animate(i):
    """Advance one EM iteration and update the artists (``i`` is the frame index, unused)."""
    model.est_centers()
    model.infer_probs()
    centers = model.centers
    # palette is an ndarray, so fancy indexing maps cluster ids straight to colours
    scat_data.set_color(palette[model.probs.argmax(axis=1)])
    scat_means.set_offsets(centers)
    for circle, center in zip(radii, centers):
        circle.set_center(center)
    return [scat_data, scat_means, *radii]


anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=20, interval=100, blit=True)
video = anim.to_html5_video()
plt.close(fig)  # render only the video, not an extra static copy of the figure
HTML(video)