In [9]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
from sklearn.metrics.pairwise import rbf_kernel

class RBFClustering:
    """Hard-assignment clustering around RBF centres, animated as a GIF.

    Each epoch assigns every sample to the centre with the highest RBF
    similarity ``exp(-gamma * ||x - c||**2)`` and then moves each centre to the
    mean of the samples assigned to it. For ``gamma > 0`` the kernel decreases
    monotonically with distance, so the assignment is "nearest centre" and the
    procedure is Lloyd's k-means. ``gamma`` additionally sets the radius of the
    circles drawn in the animation.

    Parameters
    ----------
    n_centers : int
        Number of cluster centres; must be >= 1 and no larger than the number
        of samples passed to ``fit``.
    gamma : float
        RBF kernel coefficient; must be > 0.
    n_epochs : int
        Number of assign/update iterations (one animation frame per epoch).
    radius_factor : float
        Multiplier applied to ``1 / gamma`` to obtain the circle radius drawn
        around each centre.
    random_state : int or None
        Seed used to pick the initial centres from the rows of ``X``. A private
        ``np.random.RandomState`` is used, so the global NumPy RNG is not
        reseeded. ``None`` draws fresh entropy.

    Raises
    ------
    ValueError
        If ``n_centers < 1`` or ``gamma <= 0``.
    """

    def __init__(self, n_centers=3, gamma=1.0, n_epochs=100, radius_factor=3, random_state=42):
        if n_centers < 1:
            raise ValueError(f"n_centers must be >= 1, got {n_centers}")
        if gamma <= 0:
            # gamma <= 0 breaks both the nearest-centre equivalence and the 1/gamma radius.
            raise ValueError(f"gamma must be > 0, got {gamma}")
        self.n_centers = n_centers
        self.gamma = gamma
        self.n_epochs = n_epochs
        self.centers = None
        self.frames = []  # RGBA frames for the animation; rebuilt on every fit()
        self.radius_factor = radius_factor  # Factor to scale the radius for visualization
        self.random_state = random_state

    def fit(self, X, save_path="RBF_Clustering_with_Radii.gif"):
        """Run ``n_epochs`` assign/update iterations on ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data. It is converted to float, so integer input does not
            truncate the centre means. The animation plots only the first two
            features.
        save_path : str or None
            Path of the GIF animation to write. If None, no frames are
            rendered and no file is written.

        Returns
        -------
        RBFClustering
            ``self``, with ``centers`` holding the final cluster centres.
        """
        X = np.asarray(X, dtype=float)
        self.frames = []  # start a fresh animation instead of extending a previous fit()
        self.centers = self._init_centers(X)

        for epoch in range(self.n_epochs):
            labels = self._assign_labels(X)
            self._update_centers(X, labels)
            if save_path is not None:
                self._save_plot(X, labels, epoch)

        if save_path is not None:
            # NOTE(review): depending on the imageio version/plugin, `duration` may be read
            # as seconds or as milliseconds per frame — confirm the resulting GIF timing.
            imageio.mimsave(save_path, self.frames, duration=0.3)
            print(f"GIF saved successfully as {save_path}")
        return self

    def _init_centers(self, X):
        """Return ``n_centers`` distinct rows of ``X`` as the starting centres."""
        rng = np.random.RandomState(self.random_state)
        chosen = rng.choice(X.shape[0], self.n_centers, replace=False)
        # Fancy indexing returns a copy, so in-place centre updates never modify X.
        return X[chosen]

    def _assign_labels(self, X):
        """Return, for each row of ``X``, the index of its most RBF-similar centre.

        For gamma > 0, argmax_j exp(-gamma * d_ij**2) equals argmin_j d_ij**2.
        The exponential underflows to 0.0 for every centre once gamma * d**2
        exceeds roughly 745 (float64), after which an argmax would silently
        return centre 0. Comparing squared distances directly avoids that.
        Ties resolve to the lowest centre index, as with argmax.
        """
        diff = X[:, np.newaxis, :] - self.centers[np.newaxis, :, :]
        sq_dists = (diff ** 2).sum(axis=2)
        return np.argmin(sq_dists, axis=1)

    def _update_centers(self, X, labels):
        """Move each centre to the mean of its assigned points; empty clusters keep their centre."""
        for i in range(self.n_centers):
            members = X[labels == i]
            if len(members) > 0:
                self.centers[i] = members.mean(axis=0)

    def _save_plot(self, X, labels, epoch):
        """Render one frame (points, centres, radii) and append it to ``self.frames``."""
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            # Data points, coloured by cluster
            ax.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', alpha=0.7, edgecolors='k', s=100)

            # RBF centres
            ax.scatter(self.centers[:, 0], self.centers[:, 1], c='red', s=200, marker='X', label='Centers', edgecolors='black', linewidth=2)

            # Circles visualising each centre's radius.
            # NOTE(review): the natural length scale of exp(-gamma * d**2) is
            # 1/sqrt(2*gamma); 1/gamma is kept here to preserve the existing visuals.
            radius = self.radius_factor * (1 / self.gamma)
            for center in self.centers:
                ax.add_patch(plt.Circle(center, radius, color='red', fill=False, linestyle='--', linewidth=2))

            ax.set_title(f'RBF Clustering (Epoch {epoch+1}/{self.n_epochs})', fontsize=14)
            ax.set_xlabel('Feature 1', fontsize=12)
            ax.set_ylabel('Feature 2', fontsize=12)
            ax.legend(loc='best', fontsize=12)
            ax.grid(True, linestyle='--', alpha=0.3)

            fig.canvas.draw()
            # buffer_rgba() is the public Agg accessor; np.array copies the pixels so the
            # frame stays valid after the figure is closed.
            self.frames.append(np.array(fig.canvas.buffer_rgba()))
        finally:
            # Always release the figure, even if rendering fails part-way.
            plt.close(fig)

# Build three unit-variance Gaussian blobs of 50 points each, offset to (1, 1), (-1, -1) and (1, -1)
np.random.seed(42)
X = np.vstack([np.random.randn(50, 2) + offset for offset in ([1, 1], [-1, -1], [1, -1])])

# Fit the clustering model; fit() renders every epoch and writes the animation to disk
model = RBFClustering(n_centers=3, gamma=1.0, n_epochs=100)
model.fit(X, save_path="RBF_Clustering_with_Radii.gif")


GIF saved successfully as RBF_Clustering_with_Radii.gif
