In [None]:
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.utils.validation import check_random_state
from sklearn.metrics.pairwise import euclidean_distances

In [None]:
class KMeans(BaseEstimator, ClusterMixin):
    """K-Means (Lloyd's algorithm) with an optional cluster-splitting pass.

    Parameters
    ----------
    n_clusters : int, default=8
        Number of clusters formed by ``fit``.  NOTE: ``split_clusters``
        increments this attribute once per split it performs.
    max_iter : int, default=300
        Maximum number of assignment/update iterations in ``fit``.
    tol : float, default=1e-4
        ``fit`` stops once the Frobenius norm of the change in the center
        matrix between two iterations drops below this value.
    random_state : None, int or numpy.random.RandomState
        Passed to ``check_random_state`` to pick the initial centers.

    Attributes
    ----------
    cluster_centers_ : ndarray of shape (n_centers, n_features)
        Cluster centers; ``split_clusters`` appends one row per split.
    labels_ : ndarray of shape (n_samples,)
        For each sample, the row index in ``cluster_centers_`` of its center.
    """

    def __init__(self, n_clusters=8, max_iter=300, tol=1e-4, random_state=None):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state

    def fit(self, X, y=None):
        """Cluster ``X`` and store ``cluster_centers_`` and ``labels_``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.  Anything ``np.asarray`` accepts (list of rows,
            ndarray, pandas DataFrame) is converted to a float array.
        y : ignored
            Present only for scikit-learn API compatibility.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            From ``choice(..., replace=False)`` if ``n_clusters > n_samples``.
        """
        X = np.asarray(X, dtype=float)
        random_state = check_random_state(self.random_state)
        n_samples, _ = X.shape  # also enforces a 2-D input

        # Initialize the centers with distinct, randomly chosen samples.
        self.cluster_centers_ = X[random_state.choice(n_samples, self.n_clusters, replace=False)]

        for _ in range(self.max_iter):
            labels = self._assign_labels(X)
            new_centers = self._compute_centers(X, labels)
            shift = np.linalg.norm(new_centers - self.cluster_centers_)
            self.cluster_centers_ = new_centers
            if shift < self.tol:
                break

        # Assign once more against the final centers.  This keeps labels_
        # consistent with cluster_centers_ even when the loop stops on
        # max_iter, and keeps labels_ defined when max_iter == 0.
        self.labels_ = self._assign_labels(X)
        return self

    def _compute_centers(self, X, labels):
        """Return each cluster's mean; an empty cluster keeps its current center."""
        centers = self.cluster_centers_.copy()
        for j in range(len(centers)):
            members = X[labels == j]
            # The mean of an empty slice is NaN (with a RuntimeWarning).
            # Keeping the previous center stops one empty cluster from
            # poisoning every later distance computation.
            if len(members):
                centers[j] = members.mean(axis=0)
        return centers

    def _assign_labels(self, X):
        """Return, for each row of ``X``, the index of its nearest center."""
        distances = euclidean_distances(X, self.cluster_centers_)
        return np.argmin(distances, axis=1)

    def fit_predict(self, X, y=None):
        """Fit the model on ``X`` and return ``labels_``.

        ``y`` is ignored and present only for scikit-learn API compatibility.
        """
        self.fit(X)
        return self.labels_

    def split_clusters(self, X, is_recursive=False, split_clusters=None):
        """Split every cluster that ``is_chain_connected`` flags into two.

        For each candidate label ``i``:

        1. The cluster's points go through ``remove_edge_points``.
        2. The result goes through ``get_density_matrix``.
        3. If ``is_chain_connected`` returns a truthy value, a fresh
           2-cluster ``KMeans`` splits the cluster's (unfiltered) points.

        The points in split half 0 keep label ``i``, and center ``i`` is
        replaced by that half's center.  The points in half 1 get a new
        label, equal to the index of the center appended to
        ``cluster_centers_``.  Both halves are then re-examined before the
        next candidate, in depth-first order.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data ``fit`` was called on.  Its row order must match
            ``labels_``.
        is_recursive : bool, default=False
            If True, examine only the labels in ``split_clusters``.
            Otherwise examine every label currently in ``labels_``.
        split_clusters : iterable of int, optional
            Candidate labels, used only when ``is_recursive`` is True.

        Returns
        -------
        ndarray
            ``self.labels_``, updated in place.  For every split,
            ``cluster_centers_`` gains one row and ``n_clusters`` is
            incremented.
        """
        X = np.asarray(X, dtype=float)
        candidates = split_clusters if is_recursive else np.unique(self.labels_)

        # Explicit LIFO work list instead of recursion, so a long chain of
        # splits cannot hit Python's recursion limit.  The list is reversed so
        # candidates are popped in their original order.  After a split,
        # ``i`` is pushed last so it is popped first; this matches the
        # depth-first order of a recursive implementation.
        pending = list(candidates)[::-1]
        while pending:
            i = pending.pop()
            member_idx = np.flatnonzero(self.labels_ == i)
            if len(member_idx) < 2:
                # A 2-way split needs at least two samples;
                # choice(..., replace=False) would raise ValueError otherwise.
                continue

            cluster_data = X[member_idx]
            filtered_cluster_data = self.remove_edge_points(cluster_data, i)
            density_matrix = self.get_density_matrix(filtered_cluster_data)
            if not self.is_chain_connected(density_matrix):
                continue

            kmeans_split = KMeans(n_clusters=2, random_state=self.random_state)
            split_labels = kmeans_split.fit_predict(cluster_data)

            # The new label is the index of the center appended below, so
            # labels_ stays a valid row index into cluster_centers_.
            new_label = len(self.cluster_centers_)
            # member_idx maps rows of cluster_data back to rows of X exactly,
            # including duplicate points.
            self.labels_[member_idx[split_labels == 1]] = new_label

            self.cluster_centers_[i] = kmeans_split.cluster_centers_[0]
            self.cluster_centers_ = np.vstack([
                self.cluster_centers_,
                kmeans_split.cluster_centers_[1],
            ])

            # NOTE(review): mutating a constructor parameter breaks
            # scikit-learn's get_params/clone contract; kept for backward
            # compatibility with callers that read n_clusters after splitting.
            self.n_clusters += 1

            pending.append(new_label)
            pending.append(i)

        return self.labels_

    def remove_edge_points(self, cluster_data, i):
        """Placeholder: filter out edge points of cluster ``i``.

        Not implemented yet; currently returns None.
        """
        pass

    def is_chain_connected(self, matrix):
        """Placeholder: decide whether a cluster should be split.

        Not implemented yet.  It currently returns None (falsy), so
        ``split_clusters`` never splits anything.
        """
        pass

    def get_density_matrix(self, filtered_cluster_data):
        """Placeholder: build the matrix consumed by ``is_chain_connected``.

        Not implemented yet; currently returns None.
        """
        pass