In [52]:
from functools import partial

import numpy as np
import ipympl
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, RadioButtons
from sklearn.cluster import AgglomerativeClustering
from sklearn import datasets
from sklearn import datasets as sk_datasets
%matplotlib notebook



In [53]:
# datasets
# Build four 2-D toy point sets. Each entry is an (X, y) pair so that later
# cells can unpack every dataset the same way.
#
# The generators are called through the `sk_datasets` alias because the last
# line of this cell rebinds the name `datasets` to a plain list; going through
# `datasets.make_*` would raise AttributeError the second time the cell runs.
n_samples = 1500
seed = 8  # single fixed seed so a fresh kernel reproduces the same points

noisy_circles = sk_datasets.make_circles(n_samples=n_samples, factor=.5, noise=.05,
                                         random_state=seed)
noisy_moons = sk_datasets.make_moons(n_samples=n_samples, noise=.05, random_state=seed)
blobs = sk_datasets.make_blobs(n_samples=n_samples, random_state=seed)
# Uniform random points carry no labels, hence the None in the y slot.
no_structure = np.random.RandomState(seed).rand(n_samples, 2), None

datasets = [noisy_circles, noisy_moons, blobs, no_structure]

In [54]:
# datasets visualization
# Plain (uncoloured) scatter of every dataset, one per panel of a 2x2 grid.
fig, axs = plt.subplots(2, 2)

for idx, (points, _labels) in enumerate(datasets):
    # divmod(idx, 2) walks the grid row by row: (0,0), (0,1), (1,0), (1,1).
    row, col = divmod(idx, 2)
    axs[row, col].scatter(points[:, 0], points[:, 1])
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [77]:
# Shared widget state, laid out as [n_clusters, linkage]. It is a mutable list
# so that both callbacks, which receive it as the `params` keyword argument,
# read and update the very same object.
params = [2, 'ward']


def _redraw_clusters(params):
    """Re-cluster every dataset with the current settings and repaint the grid.

    Reads the module-level ``datasets``, ``axs`` and ``fig``; each panel is
    cleared and redrawn with points coloured by their cluster label.

    Parameters
    ----------
    params : list
        ``[n_clusters, linkage]`` as passed to ``AgglomerativeClustering``.
    """
    n_clusters, linkage = params
    for panel, (points, _labels) in zip(axs.flat, datasets):
        panel.clear()
        model = AgglomerativeClustering(n_clusters=n_clusters, linkage=linkage).fit(points)
        panel.scatter(points[:, 0], points[:, 1], c=model.labels_)
    # draw_idle defers the repaint until the GUI event loop is free.
    fig.canvas.draw_idle()


def sliders_on_changed(val, params):
    """Slider callback: store the new cluster count and redraw.

    Parameters
    ----------
    val : float
        Value reported by the slider.
    params : list
        Shared ``[n_clusters, linkage]`` state; slot 0 is overwritten.
    """
    # The slider reports a float, but a cluster count must be an integer.
    params[0] = int(round(val))
    _redraw_clusters(params)


def linkage_on_clicked(val, params):
    """Radio-button callback: store the chosen linkage and redraw.

    Parameters
    ----------
    val : str
        Label of the selected radio button.
    params : list
        Shared ``[n_clusters, linkage]`` state; slot 1 is overwritten.
    """
    params[1] = val
    _redraw_clusters(params)

# Interactive figure: a 2x2 grid of clustered scatter plots plus two controls.
fig, axs = plt.subplots(2, 2)

# First paint: cluster each dataset with AgglomerativeClustering's default settings.
for idx, (points, _labels) in enumerate(datasets):
    row, col = divmod(idx, 2)
    model = AgglomerativeClustering()
    model.fit(points)
    axs[row, col].scatter(points[:, 0], points[:, 1], c=model.labels_)

# Shrink the plot grid so the widget axes added below have room on the left and bottom.
fig.subplots_adjust(left=0.25, bottom=0.35)

# Cluster-count slider along the bottom edge.
slider_ax = fig.add_axes([0.25, 0.1, 0.6, 0.03])
freq_slider = Slider(slider_ax, 'n_clusters', 2, 10, valinit=2, valstep=1)
freq_slider.on_changed(partial(sliders_on_changed, params=params))

# Linkage selector on the left edge.
radio_ax = fig.add_axes([0.025, 0.5, 0.15, 0.15])
linkage_type = RadioButtons(radio_ax, ('ward', 'complete', 'average', 'single'), active=0)
linkage_type.on_clicked(partial(linkage_on_clicked, params=params))

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …