In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

## Finding the Optimum Number of Clusters

Both K-Means and Gaussian Mixture Models (GMMs) require the number of clusters (components) to be specified before fitting. If this number is poorly chosen, the resulting clusters may not reflect the real structure of the data.

One way to estimate a suitable number of clusters is an elbow plot. We fit the model for a range of cluster counts, plot a measure of fit against the count, and pick the value at the ‘elbow’: the inflexion point after which adding more clusters gives little further improvement. Other methods, such as the silhouette method, can also be used to choose the number of clusters.
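As a rough illustration of the silhouette method, the sketch below fits K-Means for each candidate k and records scikit-learn's mean silhouette score; higher is better, and the k with the highest score is one candidate answer. It assumes synthetic data from `make_blobs` in place of the notebook's `df`, and `silhouette_scores` is a hypothetical helper name, not part of any library.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score


def silhouette_scores(data, max_k):
    """Hypothetical helper: return {k: mean silhouette score} for k = 2..max_k.

    The silhouette is undefined for a single cluster, so the search starts at k = 2.
    """
    scores = {}
    for k in range(2, max_k + 1):
        labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(data)
        scores[k] = silhouette_score(data, labels)
    return scores


# Synthetic stand-in data (an assumption, not the notebook's df): 4 blobs
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=0)
scores = silhouette_scores(X, 8)
best_k = max(scores, key=scores.get)
print(best_k, round(scores[best_k], 3))
```

Unlike inertia, which keeps falling as k grows, the silhouette score can go down again once clusters start being split unnecessarily, so its maximum is easier to read off than an elbow.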

In [None]:
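def optimise_k_means(data, max_k):
    """Fit K-Means for k = 1..max_k and plot inertia against k (elbow plot)."""
    k_values = []
    inertias = []

    for k in tqdm(range(1, max_k + 1)):
        # n_init and random_state make the result reproducible and avoid version-dependent defaults
        kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)
        kmeans.fit(data)
        k_values.append(k)
        # inertia_ is the sum of squared distances of samples to their closest cluster centre
        inertias.append(kmeans.inertia_)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(k_values, inertias, 'o-')
    ax.set_xlabel("Number of Clusters")
    ax.set_ylabel("Inertia")
    ax.grid(True)
    plt.show()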

In [None]:
optimise_k_means(df, 15)

The plot above shows how the inertia (the sum of squared distances from each point to its nearest cluster centre) changes as the number of clusters increases. Inertia always decreases as clusters are added, so we look for the point where the slope changes sharply and the curve flattens out. That point gives an approximate number of clusters in the dataset.
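Reading the elbow by eye can be subjective. One simple heuristic, sketched below under the assumption that you have the k values and inertias as lists (for instance, collected in the same loop as `optimise_k_means`), picks the point farthest from the straight line joining the first and last points of the curve. `find_elbow` is a hypothetical helper, and the inertia values in the usage line are made up for illustration.

```python
import numpy as np


def find_elbow(k_values, inertias):
    """Hypothetical helper: return the k whose point lies farthest from the
    straight line joining the first and last points of the inertia curve."""
    x = np.asarray(k_values, dtype=float)
    y = np.asarray(inertias, dtype=float)
    # Normalise both axes to [0, 1] so the distance is not dominated by the inertia scale
    x_n = (x - x.min()) / (x.max() - x.min())
    y_n = (y - y.min()) / (y.max() - y.min())
    # Perpendicular distance from each point to the chord between the first and last points
    dx, dy = x_n[-1] - x_n[0], y_n[-1] - y_n[0]
    distances = np.abs(dy * (x_n - x_n[0]) - dx * (y_n - y_n[0])) / np.hypot(dx, dy)
    return int(x[np.argmax(distances)])


# Illustrative, made-up inertias for k = 1..8
elbow = find_elbow(list(range(1, 9)), [1000, 400, 150, 100, 85, 75, 68, 62])
print(elbow)  # 3
```

This is only a heuristic; it is worth checking its answer against the plot and, where possible, against another criterion such as the silhouette score.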

## Fitting the Clustering Models

In [None]:
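# Use only the original feature columns, so label columns added to df are never treated as features
features = df.drop(columns=['KMeans', 'GMM'], errors='ignore')

# Create the KMeans model with the selected number of clusters
kmeans = KMeans(n_clusters=4, n_init=10, random_state=42)  # 4 is an example; use the value chosen from the elbow plot

# Fit the model to our dataset
kmeans.fit(features)

# Assign the cluster labels back to the df
df['KMeans'] = kmeans.labels_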
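To connect the fitted model back to the inertia used in the elbow plot, the sketch below recomputes `inertia_` by hand from `labels_` and `cluster_centers_`. It uses synthetic `make_blobs` data as an assumed stand-in for `df`.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic stand-in for the notebook's df (an assumption for illustration)
X, _ = make_blobs(n_samples=200, centers=3, random_state=1)

kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# Inertia: for every sample, the squared distance to the centre of its assigned cluster, summed
assigned_centres = kmeans.cluster_centers_[kmeans.labels_]
manual_inertia = np.sum((X - assigned_centres) ** 2)

print(manual_inertia, kmeans.inertia_)
```

The two numbers should agree up to floating-point rounding.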

In [None]:
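# Use only the original feature columns; the 'KMeans' label column must not be used as a feature
features = df.drop(columns=['KMeans', 'GMM'], errors='ignore')

# Create the GMM with the selected number of clusters/components
gmm = GaussianMixture(n_components=4, random_state=42)  # 4 is an example

# Fit the model to our dataset
gmm.fit(features)

# Predict the labels
gmm_labels = gmm.predict(features)

# Assign the labels back to the df
df['GMM'] = gmm_labels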
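Unlike K-Means, a GMM gives each point a probability of belonging to every component, and `predict` simply returns the most probable one. The sketch below shows this with scikit-learn's `predict_proba`, again on assumed synthetic `make_blobs` data rather than the notebook's `df`.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

# Synthetic stand-in for the notebook's df (an assumption for illustration)
X, _ = make_blobs(n_samples=300, centers=4, random_state=0)

gmm = GaussianMixture(n_components=4, random_state=0).fit(X)

# Each row holds the posterior probability of the sample belonging to each component
probs = gmm.predict_proba(X)
print(probs.shape)  # (300, 4)

# predict() returns the most probable component for each sample
hard_labels = gmm.predict(X)
print(np.array_equal(hard_labels, probs.argmax(axis=1)))
```

These probabilities can be useful for spotting points that sit between clusters, where no single label is a confident assignment.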