In [1]:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

In [2]:
n_data = 1000
seed = 1
n_centers = 4

# Generate random Gaussian blobs and run K-means
blobs, blob_labels = make_blobs(n_samples=n_data, n_features=2, 
                                centers=n_centers, random_state=seed)
clusters_blob = KMeans(n_clusters=n_centers, random_state=seed).fit_predict(blobs)

# Generate data uniformly at random and run K-means
uniform = np.random.rand(n_data, 2)
clusters_uniform = KMeans(n_clusters=n_centers, random_state=seed).fit_predict(uniform)

In [3]:
figure = plt.figure(figsize=(20,10))
plt.subplot(221)
plt.scatter(blobs[:, 0], blobs[:, 1], c=blob_labels, edgecolors='k', cmap='PuRd')
plt.title("(a) Four randomly generated blobs", fontsize=14)
plt.axis('off')

plt.subplot(222)
plt.scatter(blobs[:, 0], blobs[:, 1], c=clusters_blob, edgecolors='k', cmap='PuRd')
plt.title("(b) Clusters found via K-means", fontsize=14)
plt.axis('off')

plt.subplot(223)
plt.scatter(uniform[:, 0], uniform[:, 1], edgecolors='k')
plt.title("(c) 1000 randomly generated pois", fontsize=14)
plt.axis('off')

plt.subplot(224)
plt.scatter(uniform[:, 0], uniform[:, 1], c=clusters_uniform, edgecolors='k', cmap='PuRd')
plt.title("(d) Clusters found via K-means", fontsize=14)
plt.axis('off')

<IPython.core.display.Javascript object>

(-0.053821532626414627,
 1.0542231655936423,
 -0.059205837049238999,
 1.0592196036676365)

In [21]:
figure.savefig('kmeans-example.png')

In [4]:
from mpl_toolkits.mplot3d import Axes3D
from sklearn import manifold, datasets

In [5]:
X, color = datasets.samples_generator.make_swiss_roll(n_samples=1500)

In [6]:
clusters_swiss_roll = KMeans(n_clusters=100, random_state=seed).fit_predict(X)

In [7]:
fig2 = plt.figure()
ax = fig2.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=clusters_swiss_roll, cmap='Spectral')

<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x17e32cc3d30>

In [32]:
fig2.savefig('kmeans-roll100.png')