In [1]:
import pandas as pd
import numpy as np
from sklearn.mixture import GaussianMixture

In [2]:
df = pd.read_csv('../data/fashion/fashion-mnist_train.csv')

In [3]:
df.head()

Unnamed: 0,label,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,pixel9,...,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783,pixel784
0,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,9,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,6,0,0,0,0,0,0,0,5,0,...,0,0,0,30,43,0,0,0,0,0
3,0,0,0,0,1,2,0,0,0,0,...,3,0,0,0,0,1,0,0,0,0
4,3,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [4]:
y = df['label']
del df['label']

In [18]:
def get_confusion_matrix(n_components, y, yhat):
    conf_mat = np.empty(shape=(n_components, n_components), dtype='int64')
    for cluster_id in range(n_components):
        true_labels = y[np.where(cluster_id == yhat)]
        conf_mat[cluster_id] = np.bincount(true_labels, minlength=n_components)
    return conf_mat

In [20]:
def purity(confusion_matrix):
    Pj = confusion_matrix.max(axis=1)
    Mj = confusion_matrix.sum(axis=1)
    return sum(Pj)/sum(Mj)

In [21]:
def gini(confusion_matrix):
    Mj = confusion_matrix.sum(axis=1, keepdims=True)
    gj = 1 - ((confusion_matrix/Mj)**2).sum(axis=1, keepdims=True)
    return (sum(gj*Mj)/sum(Mj))[0]

In [5]:
%%time
gmm = GaussianMixture(n_components=10, covariance_type='diag').fit(df)

CPU times: user 1min 11s, sys: 15.6 s, total: 1min 27s
Wall time: 45.3 s


In [10]:
yhat = gmm.predict(df)

In [22]:
model1_conf_mat = get_confusion_matrix(10, y.values, yhat)
model1_conf_mat

array([[   0,    0,    0,    0,    6,    0,    3,    0, 1235,    0],
       [ 121,    5,   60,    8,   33,    0,  206,    0,  411,   15],
       [ 670, 5815,  103, 4846,  802,    5,  448,    0,  304,    1],
       [ 673,   47,  826,   81,  375,   21,  900,    0,  289,   18],
       [ 135,    3,   66,   13,   11,  185,  180,    1,  143,   38],
       [   0,    0,    0,    0,    0,  903,    0,   24,    2, 1975],
       [4002,    7,   10,  297,   14,    0,  908,    0,    2,    1],
       [   0,    0,    0,    0,    0, 3929,    0, 5782,   12, 2034],
       [  41,    1,   29,    7,   11,  938,   66,  193, 2642, 1897],
       [ 358,  122, 4906,  748, 4748,   19, 3289,    0,  960,   21]])

In [23]:
purity(model1_conf_mat)

0.46421666666666667

In [24]:
gini(model1_conf_mat)

0.6323260716569258

Full covar matrix

In [7]:
%%time
gmm2 = GaussianMixture(n_components=10, covariance_type='full').fit(df)

CPU times: user 24min 17s, sys: 2min 28s, total: 26min 45s
Wall time: 15min 6s


In [25]:
yhat2 = gmm2.predict(df)

In [29]:
model2_conf_mat = get_confusion_matrix(10, y.values, yhat2)
model2_conf_mat

array([[ 635, 5795,   50, 4627,  540,    5,  383,    0,  242,    1],
       [   0,    0,    0,    0,    0, 1738,    0,  270,    5, 4414],
       [ 300,   20,  737,   33,  335,    0,  775,    0,  311,    4],
       [ 275,   34,  300,   74,  146,   16,  416,    0,   58,   10],
       [  36,    2,   21,    8,    4,  343,   57,    2,   65,   28],
       [  33,    3,   22,    5,   12,   22,   58,   14, 1829,   24],
       [2496,   21,    8,  101,    9,    1,  539,    0,    5,    0],
       [2209,  125, 4851, 1151, 4942,    2, 3742,    0,  976,    4],
       [   3,    0,    0,    0,    1, 3869,    7, 5714,  829, 1514],
       [  13,    0,   11,    1,   11,    4,   23,    0, 1680,    1]])

In [30]:
purity(model2_conf_mat)

0.4734

In [31]:
gini(model2_conf_mat)

0.6241150499261668

In [32]:
gmm2.covariances_

array([[[ 1.00000000e-06,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  1.00000000e-06,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  1.00000000e-06, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        ...,
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          1.00000000e-06,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  1.00000000e-06,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  1.00000000e-06]],

       [[ 1.00000000e-06,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  1.31551736e-01,  1.31550736e-01, ...,
          8.92948633e-02,  1.04557512e