In [1]:
from sklearn.mixture import GaussianMixture
from sklearn.datasets import fetch_mldata
from collections import Counter
import numpy as np
import os

In [2]:
def evaluation_metrics(pred_labels, true_labels):
    N = len(pred_labels)
    
    cluster_labels = {}
    for i in range(len(pred_labels)):
        cluster_labels.setdefault(pred_labels[i], []).append(true_labels[i])

    K = len(cluster_labels)

    # Store list of labels as a Counter
    for key,value in cluster_labels.items():
        cluster_labels[key] = Counter(value)

    # Calculate purity
    purity = 0
    for cluster in cluster_labels:
        purity += max(cluster_labels[cluster].values())
    purity /= N

    # Calculate gini index
    gini_index = 0
    for key,value in cluster_labels.items():
        gini = 0
        for k,v in value.items():
            gini += (v / sum(cluster_labels[key].values())) ** 2
        gini_index += 1 - gini
    gini_index /= K

    # Final result
    print('Purity -', round(purity, 4), 'Gini Index -', round(gini_index, 4), '\n')

In [3]:
# Fetch data
mnist_dataset = fetch_mldata('mnist original')

# Data and labels
mnist_data = mnist_dataset.data
mnist_labels = mnist_dataset.target

print(mnist_data.shape)
print(mnist_labels.shape)

(70000, 784)
(70000,)


In [4]:
print('Without Normalizing')
model = GaussianMixture(n_components=10, covariance_type='diag', init_params='kmeans', max_iter=200)
model.fit(mnist_data)

pred_train_labels = model.predict(mnist_data)
evaluation_metrics(pred_train_labels, mnist_labels)

Without Normalizing
Purity - 0.3765 Gini Index - 0.6833 



In [5]:
# Normalize data
norm_mnist_data = np.divide(mnist_data, 255)
print(norm_mnist_data.shape)

(70000, 784)


In [6]:
print('With Normalizing')
model = GaussianMixture(n_components=10, covariance_type='diag', init_params='kmeans', max_iter=200)
model.fit(norm_mnist_data)

pred_test_labels = model.predict(norm_mnist_data)
evaluation_metrics(pred_test_labels, mnist_labels)

With Normalizing
Purity - 0.4111 Gini Index - 0.7102 

