In [None]:
%load_ext autoreload
%autoreload 2

from cluster import *
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

# Build a synthetic 2-D dataset of Gaussian blobs; the returned labels are discarded
X, _ = make_blobs(
    n_samples=150,
    n_features=2,
    centers=5,
    cluster_std=0.5,
    shuffle=True,
)

# Show the raw, unclustered points
plt.scatter(
    X[:, 0],
    X[:, 1],
    s=50,
    c='white',
    marker='o',
    edgecolor='black',
    label='Data',
)
plt.show()

# --- Hyperparameters for the K-means ISODATA algorithm ---
K_init = 10        # initial number of clusters
std_thres = 1.2    # split a cluster when its standard deviation exceeds this threshold
dist_thres = 1.5   # merge cluster centers that lie closer together than this distance
max_iter = 200     # maximum number of K-means iterations allowed to converge
max_merged = 2     # cap on merges per iteration; useful when a bad initialization keeps the model from converging
min_clusters = 3   # minimum number of cluster centers the algorithm may keep
max_splits = 2     # cap on cluster splits per iteration

# Run the clustering. The hyperparameters are passed positionally, so their order matters.
cluster_assignments, centers = Kmeans_ISODATA(
    X,
    K_init,
    std_thres,
    dist_thres,
    max_iter,
    max_merged,
    min_clusters,
    max_splits,
    verbose=True,
)

print(f'Shape of final centers and assignments: {centers.shape}, {cluster_assignments.shape}')

# Plot each cluster's points: one scatter call (and one legend entry) per center.
# The loop variable is named cluster_id rather than id so the builtin id() is not
# shadowed for the rest of the notebook session.
# NOTE(review): assumes cluster_assignments holds row indices into centers — confirm against Kmeans_ISODATA.
for cluster_id in range(len(centers)):
    # Boolean mask selects the samples assigned to this cluster
    data = X[cluster_assignments == cluster_id]
    plt.scatter(
        data[:, 0], data[:, 1],
        marker='o',
        edgecolor='black', s=50,
        label=f'Center: {cluster_id}'
    )

# Overlay the cluster centers as blue stars on top of the data
plt.scatter(
    centers[:, 0], centers[:, 1],
    c='blue', marker='*',
    edgecolor='black', s=50,
    label='Centers'
)
plt.legend()
plt.show()
