In [1]:
import sys
import os

# Make the parent of the notebook's working directory importable (presumably
# the project root containing `ml_algorithms`). The membership check keeps the
# cell idempotent: re-running it no longer appends duplicate sys.path entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
from ml_algorithms.unsupervised import KMeans

iris = load_iris()
X = iris.data
y_true = iris.target

# Seed NumPy's global RNG so the clustering (and therefore the animation) is
# reproducible under Restart & Run All.
# NOTE(review): assumes the project KMeans draws its randomness from np.random;
# if it accepts a seed / random_state argument, pass SEED there instead.
SEED = 42
np.random.seed(SEED)

# One cluster per iris species (3), rather than a hard-coded magic number.
kmeans = KMeans(n_clusters=len(iris.target_names), max_iter=500)
kmeans.fit(X)
# Centroid / label snapshots recorded by the project KMeans (private attributes),
# presumably one per iteration — the animation below treats each as one step.
historical_centroids = np.array(kmeans._historical_centroids)
historical_labels = np.array(kmeans._historical_labels)

# Fit a 2-D PCA on the data and project every centroid snapshot onto the same
# basis, so centroids can be drawn on top of the projected points.
pca = PCA(n_components=2)
X_2d = pca.fit_transform(X)
centroids_pca = np.array([pca.transform(c) for c in historical_centroids])

def best_label_mapping(true_labels, cluster_labels):
    """Relabel cluster ids so they agree as well as possible with the true classes.

    Builds the confusion (contingency) matrix between true classes and cluster
    ids. It then solves the assignment problem with ``linear_sum_assignment`` on
    the negated counts. The result maximises the number of points whose cluster
    is mapped to their own true class.

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        Ground-truth class labels.
    cluster_labels : array-like of shape (n_samples,)
        Cluster ids produced by the clustering algorithm.

    Returns
    -------
    numpy.ndarray of shape (n_samples,)
        ``cluster_labels`` with each cluster id replaced by the label it was
        matched to. An empty input yields an empty array.
    """
    true_labels = np.asarray(true_labels)
    cluster_labels = np.asarray(cluster_labels)
    if cluster_labels.size == 0:
        # Nothing to map; np.vectorize / confusion_matrix would raise here.
        return cluster_labels.copy()

    # confusion_matrix indexes rows/columns by *position* in this sorted label
    # set, not by label value. Pass it explicitly so positions can be
    # translated back to real labels. This matters when labels are not 0..k-1.
    labels = np.union1d(true_labels, cluster_labels)
    counts = confusion_matrix(true_labels, cluster_labels, labels=labels)

    # Maximising matched counts == minimising their negation.
    row_ind, col_ind = linear_sum_assignment(-counts)

    # counts is square, so every column position is assigned exactly one row.
    col_to_row = np.empty(len(labels), dtype=np.intp)
    col_to_row[col_ind] = row_ind

    # labels is sorted, so searchsorted gives each cluster id's column position.
    positions = np.searchsorted(labels, cluster_labels)
    return labels[col_to_row[positions]]

# Final-iteration assignment, relabelled to the true classes, and which points
# it gets wrong.
final_labels = best_label_mapping(y_true, historical_labels[-1])
misclassified = final_labels != y_true

fig, ax = plt.subplots(figsize=(8, 6))
# Colour by mapped labels (matching what `update` draws). Pin vmin/vmax to the
# true-label range: the colour norm is fixed from the array passed here. If
# the first iteration happened not to use every cluster, later frames' labels
# would otherwise be clipped onto the wrong colour.
scatter = ax.scatter(
    X_2d[:, 0],
    X_2d[:, 1],
    c=best_label_mapping(y_true, historical_labels[0]),
    cmap='viridis',
    s=30,
    vmin=y_true.min(),
    vmax=y_true.max(),
)
# Red rings around misclassified points and black X centroid markers; both
# start empty and are filled in per frame.
misclassified_scat = ax.scatter([], [], edgecolors='red', facecolors='none', s=80, linewidths=1.5)
centroid_scat = ax.scatter([], [], c='black', marker='X', s=200)

def update(frame):
    """Draw one recorded KMeans iteration (FuncAnimation callback).

    Recolours the points by that iteration's cluster assignment, after matching
    clusters to true classes. It moves the centroid markers to the projected
    centroids, moves the red rings onto points whose matched label differs from
    ``y_true``, and sets the step title.

    Parameters
    ----------
    frame : int
        Index into the recorded centroid/label history.

    Returns
    -------
    tuple
        The artists changed this frame, as FuncAnimation expects.
    """
    aligned = best_label_mapping(y_true, historical_labels[frame])
    wrong = aligned != y_true

    scatter.set_array(aligned)
    centroid_scat.set_offsets(centroids_pca[frame])
    misclassified_scat.set_offsets(X_2d[wrong])

    ax.set_title(f"KMeans Step {frame + 1}")
    return scatter, centroid_scat, misclassified_scat

ani = FuncAnimation(fig, update, frames=len(centroids_pca), interval=500, blit=True)

plt.xlabel("PCA Component 1")
plt.ylabel("PCA Component 2")
plt.grid(True)
plt.tight_layout()
plt.close()

from IPython.display import HTML
%matplotlib inline

HTML(ani.to_jshtml())