<a href="https://colab.research.google.com/github/Mantis-Ryuji/Playground/blob/main/KMeans_Animation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.animation import FuncAnimation
from matplotlib.collections import LineCollection
from sklearn.datasets import make_moons
from IPython.display import HTML

### データ生成

In [2]:
# Two interleaving half-moons: 500 points with noise=0.3 added.
# random_state pins the sample so reruns of the notebook see the same data;
# the later np.random.seed(42) call runs after this line and does not affect it.
X, _ = make_moons(n_samples=500, noise=0.3, random_state=42)

### K-Means の実装とアニメーションの作成

In [4]:
# K-Means configuration and the state shared with the animation callbacks.
data_points = X                 # samples generated in the data-generation cell
num_points = X.shape[0]         # sample count (not referenced elsewhere in this cell)
num_clusters = 2
num_iterations = 10             # one animation frame per K-Means iteration

# Initial cluster centres: drawn uniformly from [-3, 5) on both axes after a
# fixed seed, so the starting layout is reproducible.
np.random.seed(42)
centroids = np.random.uniform(low=-3, high=5, size=(num_clusters, 2))
# One list per cluster holding the successive centroid positions.
centroid_history = [list() for _cluster in range(num_clusters)]

# One distinct colour per cluster, sampled evenly from the rainbow colormap.
colors = cm.rainbow(np.linspace(0, 1, num_clusters))
fig, ax = plt.subplots(figsize=(8, 8))

# Snapshot of the starting centroids. init() restores it before every
# rendering pass, so each pass replays the same K-Means run instead of
# continuing from wherever the previous pass left the shared state.
initial_centroids = centroids.copy()


def _style_axes(title):
    """Apply the shared title, background and tick styling to ``ax``."""
    ax.set_title(title, fontsize=15)
    ax.set_facecolor('lightgrey')
    ax.set_xticks([])
    ax.set_yticks([])


def init():
    """Reset the K-Means state and draw an empty styled frame.

    FuncAnimation calls this before rendering the frames (e.g. on every
    ``ani.save`` / ``to_jshtml`` call and on each repeat). Restoring the
    centroids and clearing the trajectories here makes every render identical.

    Returns:
        An empty list; the animation uses ``blit=False``, so no artists are
        reported.
    """
    centroids[:] = initial_centroids
    for history in centroid_history:
        history.clear()
    ax.cla()
    _style_axes('K-Means Clustering')
    return []


def update(frame):
    """Run one K-Means iteration and redraw the whole scene.

    The assignment step labels every point with its nearest current centroid.
    The update step moves each centroid to the mean of its points; a centroid
    with no assigned points keeps its position. ``centroids`` and
    ``centroid_history`` are modified in place.

    Args:
        frame: Zero-based iteration index supplied by FuncAnimation.

    Returns:
        An empty list; the animation uses ``blit=False``, so no artists are
        reported.
    """
    # Seed each trajectory with its starting point so the first move is drawn.
    for i, history in enumerate(centroid_history):
        if not history:
            history.append(centroids[i].copy())

    # Assignment step: distances has shape (num_clusters, num_points).
    distances = np.sqrt(np.sum((data_points - centroids[:, np.newaxis]) ** 2, axis=2))
    labels = np.argmin(distances, axis=0)

    # Update step. Labels are already fixed, so each centroid can be
    # overwritten in place without affecting the others.
    for i in range(num_clusters):
        members = labels == i
        if members.any():
            centroids[i] = data_points[members].mean(axis=0)
        centroid_history[i].append(centroids[i].copy())

    ax.cla()
    for i in range(num_clusters):
        cluster_points = data_points[labels == i]
        ax.scatter(cluster_points[:, 0], cluster_points[:, 1], c=[colors[i]],
                   label=f'Cluster {i}', alpha=0.6)

        # Point-to-centroid spokes as one LineCollection, shape (n, 2, 2),
        # instead of one Line2D artist per point.
        spokes = np.stack(
            [cluster_points, np.broadcast_to(centroids[i], cluster_points.shape)],
            axis=1,
        )
        ax.add_collection(LineCollection(spokes, colors=[colors[i]],
                                         linewidths=1.2, alpha=0.4, zorder=2))

        traj = np.array(centroid_history[i])
        if len(traj) > 1:
            ax.plot(traj[:, 0], traj[:, 1], '--', c='k', lw=2, zorder=10)

    ax.scatter(centroids[:, 0], centroids[:, 1], c='black',
               marker='x', s=500, linewidths=3, zorder=10)

    _style_axes(f'K-Means Clustering (Iteration {frame + 1}/{num_iterations})')
    ax.legend()
    return []

# One frame per K-Means iteration. blit=False because update() clears and
# redraws the whole axes on every frame.
ani = FuncAnimation(
    fig,
    update,
    init_func=init,
    frames=range(num_iterations),
    blit=False,
)
# Close the figure so the inline backend does not also show a static copy;
# `ani` still holds the figure for later rendering or saving.
plt.close(fig)

### 出力結果: KMean.gif