In [21]:
# Importing necessary packages
import numpy as np
import matplotlib.pyplot as plt
import operator
from mpl_toolkits.mplot3d import Axes3D
from sklearn import mixture
from sklearn.cluster import KMeans, AgglomerativeClustering, MeanShift, estimate_bandwidth
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.datasets import load_iris
import scipy.stats as stats

In [26]:
def get_num_clust(feat_mat, alg, k_range=range(2, 11)):
    """Pick the number of clusters that maximises the mean silhouette coefficient.

    For every candidate ``k`` the chosen algorithm is run on ``feat_mat`` and
    the average silhouette score of the resulting labelling is recorded.

    Parameters
    ----------
    feat_mat : array-like
        Feature matrix passed unchanged to the clustering helper and to
        ``silhouette_score``.
    alg : {'agg', 'kmeans'}
        ``'agg'`` clusters with ``clust_agg``; ``'kmeans'`` with ``clust_Kmeans``.
    k_range : iterable of int, optional
        Candidate cluster counts (default 2..10 inclusive). The silhouette
        score needs at least 2 clusters, so every candidate should be >= 2.

    Returns
    -------
    int
        The candidate with the highest average silhouette score; on a tie the
        earliest candidate in ``k_range`` wins.

    Raises
    ------
    ValueError
        If ``alg`` is not recognised or ``k_range`` is empty.
    """
    if alg == 'agg':
        def cluster(k):
            return clust_agg(feat_mat, k)
    elif alg == 'kmeans':
        def cluster(k):
            # clust_Kmeans returns (labels, centers); only the labels are scored.
            return clust_Kmeans(feat_mat, k)[0]
    else:
        # Without this check an unknown name leaves no scores and fails later
        # inside max() with an unhelpful "empty sequence" message.
        raise ValueError("alg must be 'agg' or 'kmeans', got %r" % (alg,))

    silhouettes = {k: silhouette_score(feat_mat, cluster(k)) for k in k_range}
    if not silhouettes:
        raise ValueError("k_range must contain at least one candidate")
    # max() returns the first maximal key in insertion order, so ties resolve
    # to the earliest candidate.
    return max(silhouettes, key=silhouettes.get)

In [4]:
def get_bw(feat_mat):
    """Estimate a MeanShift bandwidth for ``feat_mat``.

    Thin wrapper around ``sklearn.cluster.estimate_bandwidth`` called with its
    default arguments; the estimate is returned unchanged.
    """
    return estimate_bandwidth(feat_mat)

In [5]:
def clust_MS(feat_mat, bw):
    """Run MeanShift clustering on ``feat_mat`` with bandwidth ``bw``.

    Returns
    -------
    tuple
        ``(labels, centers)``: the fitted model's ``labels_`` (one label per
        sample) and its ``cluster_centers_``.
    """
    model = MeanShift(bandwidth=bw)
    model.fit(feat_mat)
    return model.labels_, model.cluster_centers_

In [6]:
def clust_Kmeans(feat_mat, num_clust, n_init=100, random_state=None):
    """Run K-means on ``feat_mat`` with ``num_clust`` clusters.

    Uses random centroid initialisation, restarted ``n_init`` times; scikit-learn
    keeps the restart with the lowest inertia.

    Parameters
    ----------
    feat_mat : array-like
        Samples to cluster.
    num_clust : int
        Number of clusters.
    n_init : int, optional
        Number of random restarts (default 100).
    random_state : int, RandomState or None, optional
        Seed for the random initialisation. ``None`` (default) is unseeded, so
        repeated calls may give different labels; pass an int for repeatable
        results.

    Returns
    -------
    tuple
        ``(labels, centers)``: the cluster label of each sample and the
        cluster centres. The feature matrix itself is not returned.
    """
    kmeans = KMeans(init='random', n_clusters=num_clust, n_init=n_init,
                    random_state=random_state)
    kmeans.fit(feat_mat)
    return kmeans.labels_, kmeans.cluster_centers_

In [7]:
def clust_agg(feat_mat, num_clust):
    """Run agglomerative clustering on ``feat_mat`` with ``num_clust`` clusters.

    ``AgglomerativeClustering`` is used with its default settings apart from
    the cluster count. Returns the cluster label of each sample.
    """
    model = AgglomerativeClustering(n_clusters=num_clust)
    return model.fit_predict(feat_mat)

In [44]:
def get_bic(feat_mat, n_components_range=range(1, 10),
            cv_types=('spherical', 'tied', 'diag', 'full'), random_state=None):
    """Fit Gaussian mixtures over a grid of settings and return their BIC scores.

    One ``GaussianMixture`` is fitted to ``feat_mat`` for every combination of
    covariance type and component count, and its BIC on ``feat_mat`` recorded.

    Parameters
    ----------
    feat_mat : array-like
        Data to fit and score.
    n_components_range : iterable of int, optional
        Component counts to try (default 1..9 inclusive).
    cv_types : sequence of str, optional
        ``covariance_type`` values to try, in order (default spherical, tied,
        diag, full).
    random_state : int, RandomState or None, optional
        Passed to every ``GaussianMixture``; ``None`` (default) is unseeded.

    Returns
    -------
    numpy.ndarray
        Flat array of BIC values ordered covariance-type-major: entry
        ``i * len(n_components_range) + j`` belongs to ``cv_types[i]`` and the
        ``j``-th component count. Lower is better.
    """
    # Materialise once so a one-shot iterator is not exhausted after the
    # first covariance type.
    n_components_range = list(n_components_range)
    bic = []
    for cv_type in cv_types:
        for n_components in n_components_range:
            # Fit a Gaussian mixture with EM on the argument (not a notebook
            # global) and score it on the same data.
            gmm = mixture.GaussianMixture(n_components=n_components,
                                          covariance_type=cv_type,
                                          random_state=random_state)
            gmm.fit(feat_mat)
            bic.append(gmm.bic(feat_mat))
    return np.array(bic)

In [47]:
def get_aic(feat_mat, n_components_range=range(1, 10),
            cv_types=('spherical', 'tied', 'diag', 'full'), random_state=None):
    """Fit Gaussian mixtures over a grid of settings and return their AIC scores.

    One ``GaussianMixture`` is fitted to ``feat_mat`` for every combination of
    covariance type and component count, and its AIC on ``feat_mat`` recorded.

    Parameters
    ----------
    feat_mat : array-like
        Data to fit and score.
    n_components_range : iterable of int, optional
        Component counts to try (default 1..9 inclusive).
    cv_types : sequence of str, optional
        ``covariance_type`` values to try, in order (default spherical, tied,
        diag, full).
    random_state : int, RandomState or None, optional
        Passed to every ``GaussianMixture``; ``None`` (default) is unseeded.

    Returns
    -------
    numpy.ndarray
        Flat array of AIC values ordered covariance-type-major: entry
        ``i * len(n_components_range) + j`` belongs to ``cv_types[i]`` and the
        ``j``-th component count. Lower is better.
    """
    # Materialise once so a one-shot iterator is not exhausted after the
    # first covariance type.
    n_components_range = list(n_components_range)
    aic = []
    for cv_type in cv_types:
        for n_components in n_components_range:
            # Fit a Gaussian mixture with EM on the argument (not a notebook
            # global) and score it on the same data. No dependency on any
            # global BIC array, so this works in a fresh kernel.
            gmm = mixture.GaussianMixture(n_components=n_components,
                                          covariance_type=cv_type,
                                          random_state=random_state)
            gmm.fit(feat_mat)
            aic.append(gmm.aic(feat_mat))
    return np.array(aic)

In [58]:
# Comparing algorithms on the iris dataset.
# Each method's labelling is scored against the true species with the
# adjusted Rand index (ARI).

SEED = 0  # seeds the two final GaussianMixture fits so their labels are reproducible

# Mirrors the loop order inside get_bic/get_aic: covariance types outer,
# component counts 1..9 inner, giving a flat array of 4 * 9 scores.
GMM_CV_TYPES = ['spherical', 'tied', 'diag', 'full']
GMM_N_COMPONENTS = range(1, 10)

# Loading dataset
iris = load_iris()
fm = iris.data

# True labels: species index (0, 1, 2) of each sample, taken from the dataset
# rather than hard-coding its 50/50/50 block layout.
true_labels = iris.target

# Agglomerative Clustering & Silhouette Score
numc_agg = get_num_clust(fm, 'agg')
labels_agg = clust_agg(fm, numc_agg)
ari_agg = adjusted_rand_score(true_labels, labels_agg)

# Kmeans & Silhouette Score
numc_kmeans = get_num_clust(fm, 'kmeans')
(labels_kmeans, centers_kmeans) = clust_Kmeans(fm, numc_kmeans)
ari_kmeans = adjusted_rand_score(true_labels, labels_kmeans)

# MeanShift Clustering & Estimate Bandwidth
bw = get_bw(fm)
(labels_ms, centers_ms) = clust_MS(fm, bw)
ari_ms = adjusted_rand_score(true_labels, labels_ms)

# GMM and BIC: flat index of the lowest score
bic = get_bic(fm)
low_bc = int(np.argmin(bic))

# GMM and AIC: flat index of the lowest score
aic = get_aic(fm)
low_aic = int(np.argmin(aic))

# Getting BIC and AIC labels: refit the best-scoring setting decoded from the
# flat index instead of hard-coding it, so the labels always match the selection.
bic_cv_idx, bic_n_idx = divmod(low_bc, len(GMM_N_COMPONENTS))
gmm = mixture.GaussianMixture(n_components=GMM_N_COMPONENTS[bic_n_idx],
                              covariance_type=GMM_CV_TYPES[bic_cv_idx],
                              random_state=SEED)
labels_bic = gmm.fit_predict(fm)
ari_bic = adjusted_rand_score(true_labels, labels_bic)

aic_cv_idx, aic_n_idx = divmod(low_aic, len(GMM_N_COMPONENTS))
gmm = mixture.GaussianMixture(n_components=GMM_N_COMPONENTS[aic_n_idx],
                              covariance_type=GMM_CV_TYPES[aic_cv_idx],
                              random_state=SEED)
labels_aic = gmm.fit_predict(fm)
ari_aic = adjusted_rand_score(true_labels, labels_aic)

In [49]:
# Flat index of the smallest BIC value. get_bic appends 9 scores
# (n_components 1-9) per covariance type, so index // 9 selects the
# covariance type and index % 9 + 1 the component count.
low_bc

28

In [38]:
# BIC of every fitted GMM: 4 covariance types x 9 component counts,
# covariance type outermost; the minimum is selected via argmin.
bic

array([1804.08543789, 1012.23517979,  853.8093405 ,  784.93215612,
        747.01453574,  705.89437061,  738.9873041 ,  753.50480169,
        714.74597226,  829.97815451,  688.09722028,  633.84624673,
        618.03822133,  625.28622866,  605.17631374,  619.32773739,
        673.05024299,  647.44959123, 1522.12015273,  857.55149412,
        812.46761503,  705.13592649,  700.97348956,  708.10605017,
        715.5800191 ,  734.48531487,  735.65113149,  829.97815451,
        574.01783272,  580.86127847,  629.77902872,  681.45898657,
        714.99771734,  772.3872465 ,  807.81588265,  861.76317353])

In [50]:
# Flat index of the smallest AIC value. get_aic appends 9 scores
# (n_components 1-9) per covariance type, so index // 9 selects the
# covariance type and index % 9 + 1 the component count.
low_aic

32

In [51]:
# AIC of every fitted GMM: 4 covariance types x 9 component counts,
# covariance type outermost; the minimum is selected via argmin.
aic

array([1789.03226142,  979.11819156,  802.6285405 ,  715.42597274,
        659.70091295,  600.43649931,  615.55528766,  577.35685422,
        548.14337637,  787.82926039,  630.89514969,  561.57004121,
        530.7297978 ,  563.58719142,  503.30500762,  507.60151994,
        476.55077496,  478.80069804, 1498.03507037,  806.37069412,
        666.35669131,  599.75042534,  568.6812067 ,  547.77676717,
        517.98801329,  516.60482595,  515.57886024,  787.82926039,
        486.70940919,  448.39147183,  452.21036647,  458.67197481,
        394.0100179 ,  445.55394866,  450.99260977,  431.37778658])

In [61]:
# Adjusted Rand index of each method against the true labels, one per line in
# this order: agglomerative, k-means, mean shift, GMM (BIC pick), GMM (AIC pick).
for ari in (ari_agg, ari_kmeans, ari_ms, ari_bic, ari_aic):
    print(ari)

0.5681159420289855
0.5399218294207123
0.5583714437541352
0.5681159420289855
0.5353270220432224
