### `sklearn.cluster.FeatureAgglomeration` applies hierarchical clustering to group together features that behave similarly.

In [29]:
from tempfile import mkdtemp
import joblib

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.datasets import load_digits
from sklearn.feature_extraction.image import grid_to_graph
from sklearn.cluster import FeatureAgglomeration

#### Load Data

In [9]:
# Load the digits dataset and flatten every image into a single feature row.
digits = load_digits()
# The target shape is passed positionally: the `newshape=` keyword of
# np.reshape is deprecated since NumPy 2.1, whereas the positional form is
# accepted by every NumPy version.
X = np.reshape(digits.images, (digits.images.shape[0], -1))
print('Shape of  X :',X.shape)

Shape of  X : (1797, 64)


In [10]:
# Dimensions of the un-flattened image array (samples, rows, columns).
image_shape = digits.images.shape
print('Image shape', image_shape)

Image shape (1797, 8, 8)


#### Connectivity Matrix | pixel-adjacency graph of the image grid, used to restrict merges to neighbouring features in estimators that accept connectivity information

In [12]:
# Build the pixel-to-pixel adjacency graph for the grid of a single image;
# the grid dimensions are taken from the first image in the dataset.
n_rows, n_cols = digits.images[0].shape
connectivity_matrix = grid_to_graph(n_rows, n_cols)

#### Feature Agglomeration | similar features are merged together into a smaller set of clustered features

`joblib.Memory` is a context object that caches a function's return value, so calling the function again with the same input arguments reuses the stored result instead of recomputing it.

In [14]:
# Fresh temporary directory that backs the joblib cache.
# NOTE: mkdtemp() does not remove the directory; delete it manually when done.
tmp = mkdtemp()
# `joblib.Memory` is the public alias of `joblib.memory.Memory`.
mem = joblib.Memory(location=tmp)

In [15]:
# Configure the estimator (nothing is computed until it is fitted).
feat_agg = FeatureAgglomeration(
    n_clusters=32,                      # number of feature clusters to keep
    connectivity=connectivity_matrix,   # adjacency graph built from the image grid
    memory=mem,                         # joblib cache used during fitting
)

#### Fit

In [16]:
# Learn the feature clusters on X and project X onto them in one step
# (64 input features -> 32 agglomerated features).
# The clustering tree is computed through the joblib cache passed as `memory`,
# which is why the first run logs a `[Memory] Calling ...ward_tree` message.
x_reduced = feat_agg.fit_transform(X)

________________________________________________________________________________
[Memory] Calling sklearn.cluster._agglomerative.ward_tree...
ward_tree(array([[0., ..., 0.],
       ...,
       [0., ..., 0.]]), connectivity=<64x64 sparse matrix of type '<class 'numpy.int32'>'
	with 288 stored elements in COOrdinate format>, n_clusters=None, return_distance=False)
________________________________________________________ward_tree - 0.1s, 0.0min


In [18]:
# Confirm the dimensionality after agglomeration.
reduced_shape = x_reduced.shape
print('Shape of the reduced Images :', reduced_shape)

Shape of the reduced Images : (1797, 32)


#### Inverse Transform

In [21]:
# Map the reduced representation back to the original feature space so the
# result can be viewed as 8x8 images again.
x_restored = feat_agg.inverse_transform(x_reduced)
restored_shape = x_restored.shape
print('Shape of the restored Images :', restored_shape)


Shape of the restored Images : (1797, 64)


#### Visualize

In [36]:
# Draw the first six restored digits on a 2x3 grid of panels.
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))
ax = ax.ravel()
# zip() stops at the six panels, pairing each with the matching restored row.
for panel, restored_row in zip(ax, x_restored):
    panel.imshow(restored_row.reshape(8, 8), cmap=plt.cm.binary)

<img src='./plots/digits-feat-agglomeration-transformed.png'>

In [28]:
# One cluster label per original feature (pixel).
feat_agg.labels_.shape[0]

64

In [35]:
# Lay the per-pixel cluster labels out on the 8x8 image grid and annotate
# every cell with its cluster id.
label_grid = feat_agg.labels_.reshape(8, 8)
heat_ax = sns.heatmap(label_grid, annot=True)
heat_ax.set_title('Labels');

<img src='./plots/digits-feat-agglomeration-labels.png'>