In [None]:
!pip install faiss-cpu
import numpy as np
import pandas as pd
import faiss
import time
from pathlib import Path
import gc
from scipy.linalg import orth # For creating orthogonal matrices

class FCFC:
    """Iterative clustering whose assignment cost is produced by ``get_distance``.

    Each iteration builds a cost matrix with the module-level helper
    ``get_distance`` (which receives the current cluster sizes and ``lambda_``),
    assigns every point to its cheapest cluster, and then recomputes the
    centroids with ``get_centroid``.  Per-iteration objective, SSE and balance
    loss are recorded for monitoring.

    The constructor mirrors the faiss ``Kmeans`` signature; ``nredo``,
    ``min_points_per_centroid``, ``max_points_per_centroid``, ``gpu``,
    ``spherical``, ``update_index`` and ``frozen_centroids`` are stored but not
    used by this class.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1000000000,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False,
                 lambda_=1.0):
        """Store the configuration and reset all result attributes.

        Args:
            d: Expected number of feature columns of the training data.
            k: Number of clusters.
            niter: Number of assignment/update iterations.
            verbose: Print progress and a final summary when true.
            seed: Seed passed to ``np.random.seed`` at the start of ``train``.
            lambda_: Balance parameter forwarded to ``get_distance``.
        """
        self.d_features = d
        self.k = k
        self.niter = niter
        self.max_iter = niter  # alias used by the training loop

        # Stored for interface compatibility only.
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.seed = seed
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids
        self.lambda_ = lambda_

        # Results, populated by train().
        self.centroids = None               # (k, d_features)
        self.labels_ = None                 # (n_samples,)
        self.runtime_ = None                # wall-clock seconds spent in train()

        self.objective_history_ = None      # per-iteration sum of row minima of the cost matrix
        self.sse_history_ = None            # per-iteration sum of squared errors
        self.balance_loss_history_ = None   # per-iteration balance loss

        self.final_objective_ = None
        self.final_sse_ = None
        self.final_balance_loss_ = None
        self.final_cluster_sizes_ = None    # (k,)

        self.sse_ = 0
        self.balance_loss_ = 0

        # Alias of objective_history_, kept for compatibility.
        self.obj = None

    def train(self, x, weights=None, init_centroids_arg=None):
        """Run ``niter`` assignment/update iterations on ``x``.

        Args:
            x: Array of shape ``(n_samples, d)``.
            weights: Accepted for interface compatibility; not used.
            init_centroids_arg: Optional ``(k, d)`` array of starting centroids
                (copied).  When omitted, ``initial_centroid`` is called.

        Raises:
            ValueError: If ``x`` has the wrong number of columns or
                ``init_centroids_arg`` has the wrong shape.

        When ``niter`` is 0 no assignment is performed: the histories are
        empty, the ``final_*`` metrics are NaN, ``labels_`` is all zeros and
        ``final_cluster_sizes_`` keeps its placeholder value of ones.
        """
        np.random.seed(self.seed)
        start_time = time.time()

        K = self.k
        if x.shape[1] != self.d_features:
            raise ValueError(f"Input data feature dimension {x.shape[1]} "
                             f"does not match class initialized dimension {self.d_features}")
        n_samples, d_features_data = x.shape

        sse_history = np.zeros(self.max_iter)
        balance_loss_history = np.zeros(self.max_iter)
        objective_value_history = np.zeros(self.max_iter)

        if init_centroids_arg is not None:
            if init_centroids_arg.shape != (K, d_features_data):
                raise ValueError(f"Provided init_centroids shape {init_centroids_arg.shape} "
                                 f"is not ({K}, {d_features_data})")
            current_centroids = np.copy(init_centroids_arg)
        else:
            current_centroids = initial_centroid(x, K, n_samples)

        # Placeholder sizes for the very first cost matrix; replaced by the
        # real counts right after the first assignment.
        current_size_cluster = np.ones(K)
        current_labels = np.zeros(n_samples, dtype=int)
        ideal_size = n_samples / K

        for i in range(self.max_iter):
            # Assignment step: cheapest cluster per row of the cost matrix.
            D_matrix = get_distance(x, current_centroids, K, n_samples, d_features_data,
                                    current_size_cluster, self.lambda_)
            current_labels = np.argmin(D_matrix, axis=1)
            objective_value_history[i] = np.sum(np.min(D_matrix, axis=1))

            # Update step: sizes and centroids follow the new assignment.
            current_size_cluster = np.bincount(current_labels, minlength=K)
            current_centroids = get_centroid(x, current_labels, K, n_samples, d_features_data)

            # Monitoring: one vectorised pass instead of one boolean mask per cluster.
            residuals = x - np.asarray(current_centroids)[current_labels]
            sse_history[i] = np.sum(residuals ** 2)
            balance_loss_history[i] = np.sum((current_size_cluster - ideal_size) ** 2)

            if self.verbose and (i % 5 == 0 or i == self.max_iter - 1):
                print(f"Iter {i+1}/{self.max_iter}: Objective={objective_value_history[i]:.4f}, "
                      f"SSE={sse_history[i]:.4f}, BalanceLoss={balance_loss_history[i]:.4f}")

        self.runtime_ = time.time() - start_time

        self.centroids = current_centroids
        self.labels_ = current_labels
        self.final_cluster_sizes_ = current_size_cluster

        self.objective_history_ = objective_value_history
        self.sse_history_ = sse_history
        self.balance_loss_history_ = balance_loss_history

        # With zero iterations the histories are empty; indexing [-1] would
        # raise IndexError, so report the metrics as undefined instead.
        if self.max_iter > 0:
            self.final_objective_ = objective_value_history[-1]
            self.final_sse_ = sse_history[-1]
            self.final_balance_loss_ = balance_loss_history[-1]
        else:
            self.final_objective_ = float("nan")
            self.final_sse_ = float("nan")
            self.final_balance_loss_ = float("nan")

        self.obj = self.objective_history_

        self.sse_ = self.final_sse_
        self.balance_loss_ = self.final_balance_loss_

        if self.verbose:
            print(f"FCFC training completed in {self.runtime_:.4f} seconds.")
            print(f"Final Objective (sum of D(i,j)): {self.final_objective_:.4f}")
            print(f"Final SSE: {self.final_sse_:.4f}")
            print(f"Final Balance Loss: {self.final_balance_loss_:.4f}")
            print(f"Final cluster sizes: {self.final_cluster_sizes_}")

class Lloyd:
    """Thin wrapper around ``faiss.Kmeans`` that also reports SSE and balance loss.

    The constructor arguments are stored and forwarded unchanged to
    ``faiss.Kmeans`` when ``train`` is called.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1000000000,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False):
        """Store the faiss ``Kmeans`` configuration and reset result attributes."""
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.seed = seed
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        # Results, populated by train().
        self.centroids = None
        self.obj_history_ = None
        self.labels_ = None

        self.sse_ = None
        self.balance_loss_ = None
        self.runtime_ = None
        self.obj = None

    def compute_centroids_from_data(self, data_points, labels, num_clusters, data_dim):
        """Return the per-cluster mean of ``data_points`` as a float32 ``(k, d)`` array.

        Args:
            data_points: ``(n, data_dim)`` array-like.
            labels: ``(n,)`` integer cluster index per point, or ``None``.
            num_clusters: Number of rows of the result.
            data_dim: Number of columns of the result.

        ``labels is None`` yields all-zero centroids.  An empty cluster gets a
        data point drawn with the global ``np.random`` generator; if there are
        no data points at all its centroid stays zero.
        """
        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)

        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in compute_centroids_from_data.")
            return centroids

        data_points = np.asarray(data_points)
        n_points = len(data_points)
        sums = np.zeros((num_clusters, data_dim), dtype=np.float64)
        counts = np.zeros(num_clusters, dtype=int)
        if n_points > 0:
            labels = np.asarray(labels)
            # Unbuffered scatter-add: repeated labels accumulate correctly.
            np.add.at(sums, labels, data_points)
            np.add.at(counts, labels, 1)

        occupied = counts > 0
        centroids[occupied] = sums[occupied] / counts[occupied, None]

        for j in np.flatnonzero(~occupied):
            if self.verbose:
                print(f"Warning: Cluster {int(j)} is empty. Assigning random point.")
            # randint(0) would raise, so only draw when there is data to draw from.
            if n_points > 0:
                centroids[j] = data_points[np.random.randint(n_points)]

        return centroids

    def train(self, x_orig_data, weights=None, init_centroids=None):
        """Cluster ``x_orig_data`` with ``faiss.Kmeans`` and store the results.

        Args:
            x_orig_data: ``(n, d)`` array; a contiguous float32 copy is handed
                to faiss.
            weights: Optional per-point weights, forwarded to
                ``faiss.Kmeans.train`` as a flat float32 array.
            init_centroids: Optional starting centroids, forwarded to
                ``faiss.Kmeans.train`` as a float32 array.

        Raises:
            ValueError: If ``x_orig_data`` does not have ``d`` columns.

        Sets ``labels_`` (nearest centroid per point via ``kmeans.index``),
        ``centroids``, ``obj_history_``, ``obj`` (last objective value or
        ``None``), ``sse_``, ``balance_loss_`` and ``runtime_``.
        """
        start_time = time.time()
        np.random.seed(self.seed)

        if x_orig_data.shape[1] != self.d:
            raise ValueError(f"Input dimension {x_orig_data.shape[1]} != {self.d}")

        n, dim = x_orig_data.shape
        x = np.ascontiguousarray(x_orig_data, dtype='float32')
        # faiss reads these buffers as raw float32 memory.
        if weights is not None:
            weights = np.ascontiguousarray(weights, dtype='float32').ravel()
        if init_centroids is not None:
            init_centroids = np.ascontiguousarray(init_centroids, dtype='float32')

        kmeans = faiss.Kmeans(
            d=self.d,
            k=self.k,
            niter=self.niter,
            nredo=self.nredo,
            verbose=self.verbose,
            min_points_per_centroid=self.min_points_per_centroid,
            max_points_per_centroid=self.max_points_per_centroid,
            seed=self.seed,
            gpu=self.gpu,
            spherical=self.spherical,
            update_index=self.update_index,
            frozen_centroids=self.frozen_centroids
        )

        # Previously `weights` was accepted but silently dropped.
        kmeans.train(x, weights=weights, init_centroids=init_centroids)

        _, nearest = kmeans.index.search(x, 1)
        self.labels_ = nearest.flatten()

        self.centroids = kmeans.centroids
        has_history = kmeans.obj is not None and len(kmeans.obj) > 0
        self.obj_history_ = kmeans.obj if has_history else np.zeros(self.niter)
        self.obj = kmeans.obj[-1] if has_history else None
        self.runtime_ = time.time() - start_time

        # Print every 5th iteration's objective value
        if self.verbose and self.obj_history_ is not None and len(self.obj_history_) > 0:
            print("\n--- Objective Value (every 5 iterations) ---")
            last = len(self.obj_history_) - 1
            for i, val in enumerate(self.obj_history_):
                if (i + 1) % 5 == 0 or i == last:
                    print(f"  Iter {i+1:2d}: {val:.6f}")

        # SSE against the caller's data (not the float32 copy), accumulated in
        # float64 in a single vectorised pass.
        residuals = np.asarray(x_orig_data) - self.centroids[self.labels_]
        self.sse_ = np.sum(residuals ** 2, dtype=np.float64)

        sizes = np.bincount(self.labels_, minlength=self.k)
        self.balance_loss_ = np.sum((sizes - n / self.k) ** 2)

        if self.verbose:
            print(f"Lloyd training finished in {self.runtime_:.4f}s")
            print(f"Final obj: {self.obj}")
            print(f"Cluster sizes: {dict(zip(*np.unique(self.labels_, return_counts=True)))}")
            print(f"SSE: {self.sse_:.4f}")
            print(f"Balance Loss: {self.balance_loss_:.4f}")

class BCLS:
    """Balanced clustering via least-squares regression with an ALM solver.

    ``train`` alternates between a ridge-regression step (``W``, ``b``), an
    auxiliary variable ``Z`` carrying the balance penalty, and a one-hot
    indicator matrix ``Y``.  The constructor mirrors the faiss ``Kmeans``
    signature; ``nredo``, ``min_points_per_centroid``,
    ``max_points_per_centroid``, ``gpu``, ``spherical``, ``update_index`` and
    ``frozen_centroids`` are stored but not used.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1e9,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False,
                 lambda_=1.0):
        """Store the configuration and reset all result attributes.

        Args:
            d: Expected number of feature columns.
            k: Number of clusters.
            niter: Number of ALM iterations.
            verbose: Print progress and a final summary when true.
            seed: Seed for ``np.random.seed`` and the empty-cluster fallback.
            lambda_: Weight of the balance penalty (stored as ``lambda_bcls``).
        """
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        self.seed = seed
        self.lambda_bcls = lambda_

        # Results, populated by train().
        self.centroids = None       # (k, d) centroids in the original data space
        self.obj_history_ = None    # objective value per iteration
        self.labels_ = None         # (n,) 0-indexed assignments
        self.Y_final_ = None        # (n, k) one-hot indicator matrix

        self.sse_ = None
        self.balance_loss_ = None
        self.runtime_ = None

        # Alias of obj_history_, kept for compatibility.
        self.obj = None

    def init1(self, n_samples, num_clusters):
        """Return a random one-hot ``(n_samples, num_clusters)`` indicator matrix.

        Labels are drawn with the global ``np.random`` generator, so the result
        is reproducible when the caller seeds it first (``train`` does).
        """
        labels_1_indexed = np.random.randint(1, num_clusters + 1, size=n_samples)
        F = np.zeros((n_samples, num_clusters))
        F[np.arange(n_samples), labels_1_indexed - 1] = 1
        return F

    def _cluster_means(self, data_points, labels, num_clusters, data_dim,
                       seed_offset, empty_warning):
        """Vectorised per-cluster mean with a seeded random-point fallback.

        ``empty_warning`` is a format string with a ``{j}`` placeholder printed
        (when verbose) for every empty cluster; that cluster then receives the
        data point chosen by ``RandomState(seed + j + seed_offset)``.
        """
        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)
        data_points = np.asarray(data_points)
        n_points = len(data_points)

        sums = np.zeros((num_clusters, data_dim), dtype=np.float64)
        counts = np.zeros(num_clusters, dtype=int)
        if n_points > 0:
            labels = np.asarray(labels)
            # Unbuffered scatter-add: repeated labels accumulate correctly.
            np.add.at(sums, labels, data_points)
            np.add.at(counts, labels, 1)

        occupied = counts > 0
        centroids[occupied] = sums[occupied] / counts[occupied, None]

        for j in np.flatnonzero(~occupied):
            j = int(j)
            if self.verbose:
                print(empty_warning.format(j=j))
            if n_points > 0:
                # Dedicated generator: the fallback neither depends on nor
                # advances the global np.random state.
                rng_empty_fallback = np.random.RandomState(self.seed + j + seed_offset)
                centroids[j] = data_points[rng_empty_fallback.choice(n_points), :]
            # else: centroid stays zero
        return centroids

    def compute_centroids_from_data(self, data_points, labels, num_clusters, data_dim):
        """Return float32 ``(num_clusters, data_dim)`` centroids.

        Args:
            data_points: ``(n, data_dim)`` array.
            labels: ``(n,)`` 0-indexed integer labels, or ``None`` (which
                yields all-zero centroids).
            num_clusters: Number of clusters.
            data_dim: Number of features.

        Empty clusters receive a data point chosen with a generator seeded by
        ``seed + j + 1000``; with no data at all they stay zero.
        """
        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in compute_centroids_from_data. Returning zero centroids.")
            return np.zeros((num_clusters, data_dim), dtype=np.float32)

        return self._cluster_means(
            data_points, labels, num_clusters, data_dim, 1000,
            "Warning: Cluster {j} is empty. Assigning a random data point as its centroid.")

    def train(self, x_orig_data, weights=None, init_centroids=None):
        """Run the ALM iterations on ``x_orig_data`` (``n x d``).

        Args:
            x_orig_data: ``(n, d)`` array.
            weights: Accepted for interface compatibility; not used.
            init_centroids: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If ``x_orig_data`` does not have ``d`` columns.

        Sets ``labels_``, ``Y_final_``, ``obj_history_`` / ``obj``,
        ``centroids`` (original data space), ``sse_``, ``balance_loss_`` and
        ``runtime_``.  With ``niter == 0`` the random initial assignment is
        reported and the objective history is empty.
        """
        np.random.seed(self.seed)
        start_time = time.time()

        if x_orig_data.shape[1] != self.d:
            raise ValueError(f"Input data feature dimension {x_orig_data.shape[1]} "
                             f"does not match class initialized dimension {self.d}")

        ITER = self.niter
        gamma = 0.00001         # ridge regularisation on W
        lam = self.lambda_bcls  # weight of the balance penalty
        mu = 0.01               # ALM penalty parameter
        rho = 1.005             # growth factor applied to mu every iteration

        n, dim = x_orig_data.shape
        c = self.k

        Y = self.init1(n, c)

        # The regression is solved on mean-centred data.
        meanX = np.mean(x_orig_data, axis=0, keepdims=True)
        x_centered = x_orig_data - meanX

        Lambda_alm = np.zeros((n, c))  # multipliers for the constraint Y = Z

        # (X^T X + gamma I)^-1 is constant across iterations.
        gram = x_centered.T @ x_centered + gamma * np.eye(dim)
        try:
            P = np.linalg.inv(gram)
        except np.linalg.LinAlgError:
            if self.verbose:
                print("Warning: Standard inverse failed for P. Using pseudo-inverse.")
            P = np.linalg.pinv(gram)

        obj_history = np.zeros(ITER)
        ideal_size = n / c

        for iter_idx in range(ITER):
            # --- W (dim x c) and b (1 x c) ---
            W = P @ (x_centered.T @ Y)
            b = np.mean(Y, axis=0, keepdims=True)

            # Regression output XW + 1b, reused by the residual and the Y step.
            prediction = x_centered @ W + b
            E_recon = prediction - Y

            # --- Z ---
            # Z = ((mu + 2 n lam) I - 2 lam J) (mu Y + Lambda) / (mu^2 + 2 n lam mu),
            # with J the all-ones matrix.  J @ M just repeats the column sums
            # of M in every row, so no n x n matrix is needed.
            z_den = mu ** 2 + 2 * n * lam * mu
            if np.abs(z_den) < 1e-9:
                if self.verbose:
                    print(f"Warning: Denominator for Z is near zero at iter {iter_idx}")
                Z = Y
            else:
                M = mu * Y + Lambda_alm
                col_sums = np.sum(M, axis=0, keepdims=True)
                Z = ((mu + 2 * n * lam) * M - 2 * lam * col_sums) / z_den

            # --- Y: one-hot of the row-wise maximum of V ---
            V_update = (1 / (2 + mu)) * (2 * prediction + mu * Z - Lambda_alm)
            current_labels = np.argmax(V_update, axis=1)
            Y = np.zeros((n, c))
            Y[np.arange(n), current_labels] = 1

            # --- multipliers and penalty parameter ---
            Lambda_alm = Lambda_alm + mu * (Y - Z)
            mu = min(mu * rho, 1e5)

            # --- objective ---
            # The Z update above inverts (mu I + 2 lam J), i.e. the penalty
            # being optimised is lam * ||1^T Z||^2 = lam * sum_j size_j^2.
            # np.sum(Y)**2 would always equal n**2 for a one-hot Y and carry
            # no balance information, so the column sums are squared instead.
            cluster_sizes = np.sum(Y, axis=0)
            obj_history[iter_idx] = (np.sum(E_recon ** 2)
                                     + gamma * np.sum(W ** 2)
                                     + lam * np.sum(cluster_sizes ** 2))

            # Monitoring metrics are only needed on iterations that print.
            if self.verbose and (iter_idx % 10 == 0 or iter_idx == ITER - 1):
                temp_centroids_centered = self.compute_centroids_from_data(
                    x_centered, current_labels, c, dim)
                sse_iter = np.sum((x_centered - temp_centroids_centered[current_labels]) ** 2)
                balance_loss_iter = np.sum((cluster_sizes - ideal_size) ** 2)
                print(f"Iter {iter_idx+1}/{ITER}, BCLS Obj: {obj_history[iter_idx]:.4f}, "
                      f"Iter SSE (centered): {sse_iter:.2f}, Iter Bal (centered): {balance_loss_iter:.2f}")

        self.runtime_ = time.time() - start_time

        self.labels_ = np.argmax(Y, axis=1)
        self.Y_final_ = Y
        self.obj_history_ = obj_history
        self.obj = obj_history

        # Final centroids and SSE are reported in the ORIGINAL data space.
        self.centroids = self.compute_centroids_from_data(x_orig_data, self.labels_, c, dim)
        self.sse_ = np.sum((x_orig_data - self.centroids[self.labels_]) ** 2)

        final_cluster_sizes = np.bincount(self.labels_, minlength=c)
        self.balance_loss_ = np.sum((final_cluster_sizes - ideal_size) ** 2)

        if self.verbose:
            print(f"BCLS training completed in {self.runtime_:.4f} seconds.")
            # The history is empty when niter == 0.
            if len(self.obj_history_) > 0:
                print(f"Final BCLS objective value: {self.obj_history_[-1]:.4f}")
            unique_labels_final, counts_final = np.unique(self.labels_, return_counts=True)
            print(f"Final cluster sizes: {dict(zip(unique_labels_final, counts_final))}")
            if self.centroids is not None:
                print(f"Shape of final centroids (original space): {self.centroids.shape}")
            print(f"Final SSE (original space): {self.sse_:.4f}")
            print(f"Final Balance Loss: {self.balance_loss_:.4f}")

    def compute_centroids(self, x_transposed, F_indicator):
        """Return float32 ``(k, dim)`` centroids from transposed data and a one-hot matrix.

        DEPRECATED in favour of ``compute_centroids_from_data``; kept for
        callers that still use this input layout.

        Args:
            x_transposed: ``(dim, n_samples)`` data matrix.
            F_indicator: ``(n_samples, k)`` indicator matrix; the label of a
                row is the position of its maximum.

        Raises:
            ValueError: If the two inputs disagree on the number of samples.

        Empty clusters receive a data point chosen with a generator seeded by
        ``seed + j + 2000``.
        """
        num_clusters = F_indicator.shape[1]
        data_dim = x_transposed.shape[0]
        n_samples_check = x_transposed.shape[1]

        if F_indicator.shape[0] != n_samples_check:
            raise ValueError("Mismatch in number of samples between x_transposed and F_indicator.")

        labels = np.argmax(F_indicator, axis=1)
        return self._cluster_means(
            np.asarray(x_transposed).T, labels, num_clusters, data_dim, 2000,
            "Warning (compute_centroids): Cluster {j} is empty. Assigning random point.")

class CDKM_PurePy:
    """Pure-Python coordinate-descent k-means over integer label vectors.

    Each sweep visits the points in order and moves a point to the cluster
    with the largest entry of ``delta``, a per-cluster score built from
    ``||S_k||^2 / n_k`` (``S_k`` = sum of the points in cluster ``k``,
    ``n_k`` = its size).  The cluster sums and sizes are updated
    incrementally after every move.
    """

    def __init__(self, X: np.ndarray, c_true: int, debug: int = 0):
        """Copy the data as float64 and prepare the result containers.

        Args:
            X: ``(N, dim)`` data matrix (any array-like).
            c_true: Number of clusters.
            debug: When truthy, print the problem size and progress lines.
        """
        self.X = np.array(X, dtype=np.float64)  # private float64 copy, shape (N, dim)
        self.N, self.dim = self.X.shape
        self.c_true = c_true
        self.debug = debug

        self.Y = []             # one optimised label vector per replicate
        self.n_iter_ = []       # sweeps executed per replicate

        if debug:
            print(f"N = {self.N}, dim = {self.dim}, k = {self.c_true}")

    def opt(self, init_Y: np.ndarray, ITER: int):
        """Optimise every replicate in ``init_Y`` and append the results.

        Args:
            init_Y: ``(rep, N)`` array of integer labels; a single ``(N,)``
                vector is treated as one replicate.
            ITER: Maximum number of sweeps per replicate.
        """
        init_Y = np.atleast_2d(np.asarray(init_Y))
        for start_labels in init_Y:
            y = start_labels.copy()
            n_iter = self.opt_once(y, ITER)
            self.Y.append(y)
            self.n_iter_.append(n_iter)

    def opt_once(self, y: np.ndarray, ITER: int) -> int:
        """Optimise one label vector in place.

        Args:
            y: ``(N,)`` integer labels in ``[0, c_true)``; modified in place.
            ITER: Maximum number of sweeps over the data.

        Returns:
            The number of sweeps executed (0 when ``ITER`` is 0).
        """
        X = self.X
        N, dim, c_true = self.N, self.dim, self.c_true

        xnorm = np.sum(X**2, axis=1)  # squared norm of every point, shape (N,)
        Sx = np.zeros((dim, c_true))  # per-cluster sum of points
        n = np.zeros(c_true)          # per-cluster size

        for i in range(N):
            Sx[:, y[i]] += X[i]
            n[y[i]] += 1

        s = np.sum(Sx**2, axis=0)  # squared norm of each cluster sum vector

        sweeps_done = 0
        for sweep in range(ITER):
            sweeps_done = sweep + 1
            converge = True
            for i in range(N):
                c_old = y[i]
                # Never empty a cluster: a point that is alone stays put.
                if n[c_old] == 1:
                    continue

                xi = X[i]
                xiSx = xi @ Sx  # (c,)

                # Score of adding xi to each cluster: ||S+xi||^2/(n+1) - ||S||^2/n.
                # An empty cluster contributes 0 to the second term; a plain
                # s / n would yield 0/0 = NaN there and argmax would then pick
                # that cluster unconditionally.
                with_xi = (s + 2 * xiSx + xnorm[i]) / (n + 1)
                without_xi = np.divide(s, n, out=np.zeros_like(s), where=n > 0)
                delta = with_xi - without_xi

                # For the current cluster compare keeping xi against removing it.
                delta[c_old] = s[c_old] / n[c_old] - \
                    (s[c_old] - 2 * xiSx[c_old] + xnorm[i]) / (n[c_old] - 1)

                c_new = np.argmax(delta)

                if c_new != c_old:
                    converge = False
                    y[i] = c_new

                    Sx[:, c_old] -= xi
                    Sx[:, c_new] += xi

                    s[c_old] = np.sum(Sx[:, c_old]**2)
                    s[c_new] = np.sum(Sx[:, c_new]**2)

                    n[c_old] -= 1
                    n[c_new] += 1

                if self.debug and i % 10000 == 0:
                    print(f"i = {i}")

            if self.debug:
                print(f"iter = {sweep}")

            if converge:
                break

        return sweeps_done

    @property
    def y_pre(self):
        """List of optimised label vectors, one per replicate passed to ``opt``."""
        return self.Y


class CDKM:
    """Driver that runs ``CDKM_PurePy`` and stores centroids, labels and metrics."""

    def __init__(self, d, k, niter=200, nredo=10, verbose=False, seed=1234, debug=0):
        """Record the configuration; every result attribute starts as ``None``.

        Args:
            d: Expected number of feature columns.
            k: Number of clusters.
            niter: Maximum sweeps handed to the solver as ``ITER``.
            nredo: Number of random initial labelings requested from ``initial_Y``.
            verbose: Print a one-line summary after training.
            seed: Seed passed to ``np.random.seed`` at the start of ``train``.
            debug: Forwarded to ``CDKM_PurePy``.
        """
        # Configuration
        self.d, self.k = d, k
        self.niter, self.nredo = niter, nredo
        self.verbose = verbose
        self.seed = seed
        self.debug = debug

        # Results, filled in by train()
        self.centroids = None
        self.labels_ = None
        self.Y_final_ = None
        self.sse_ = None
        self.balance_loss_ = None
        self.runtime_ = None
        self.n_iter_ = None

    def train(self, x_orig_data, weights=None, init_centroids=None):
        """Cluster ``x_orig_data`` and populate the result attributes.

        Args:
            x_orig_data: ``(n, d)`` data matrix.
            weights: Accepted for interface compatibility; not used.
            init_centroids: Despite the name, this is passed to the solver as
                the initial label array.  When ``None``, ``initial_Y`` is
                called with mode ``"random"``.

        Raises:
            ValueError: If ``x_orig_data`` does not have ``d`` columns.

        Labels, ``Y_final_``, SSE and balance loss are taken from the first
        replicate returned by the solver.
        """
        np.random.seed(self.seed)
        t_begin = time.time()

        num_points, num_features = x_orig_data.shape
        if num_features != self.d:
            raise ValueError(f"Data dimension {num_features} does not match expected {self.d}.")

        # Starting labelings: caller-supplied, otherwise random.
        starting_labels = (
            init_centroids
            if init_centroids is not None
            else initial_Y(x_orig_data, self.k, self.nredo, "random")
        )

        solver = CDKM_PurePy(x_orig_data, self.k, debug=self.debug)
        solver.opt(starting_labels, ITER=self.niter)
        replicate_labels = solver.y_pre
        self.n_iter_ = solver.n_iter_

        centers = compute_cluster_centers_cdkm(x_orig_data, replicate_labels)
        first_replicate = replicate_labels[0]
        assignments = np.argmax(one_hot(first_replicate, self.k), axis=1)

        # Sum of squared distances from each point to its assigned centre.
        squared_error = np.sum((x_orig_data - centers[assignments]) ** 2)

        # Squared deviation of the cluster sizes from a perfectly even split.
        cluster_sizes = np.bincount(assignments, minlength=self.k)
        even_share = num_points / self.k
        imbalance = np.sum((cluster_sizes - even_share) ** 2)

        self.Y_final_ = one_hot(first_replicate, self.k)
        self.centroids = centers
        self.labels_ = assignments
        self.sse_ = squared_error
        self.balance_loss_ = imbalance
        self.runtime_ = time.time() - t_begin

        if self.verbose:
            print(f"CDKM finished in {self.runtime_:.4f}s; "
                  f"SSE = {self.sse_:.4f}; "
                  f"Balance Loss = {self.balance_loss_:.4f}; "
                  f"Iterations = {self.n_iter_}")

class BKNC:
    """BKNC clustering: alternating updates of an embedding F, a rotation R and labels Y.

    The constructor mirrors the Faiss-Kmeans-style signature shared by the
    other models in this file. Only ``d``, ``k``, ``niter``, ``lambda_``,
    ``seed`` and ``verbose`` influence ``train``; the remaining arguments are
    stored but never read by this class.

    After ``train`` the instance exposes ``labels_``, ``Y_``, ``F_``, ``R_``,
    ``centroids``, ``sse_``, ``balance_loss_``, ``obj_history_`` (also as
    ``obj``), ``final_obj_`` and ``runtime_``.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=False,
                 min_points_per_centroid=1, max_points_per_centroid=1e9,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False,
                 lambda_=1.0):
        """Store the configuration and reset result attributes.

        d: number of feature columns expected by ``train``.
        k: number of clusters.
        niter: number of outer iterations.
        lambda_: weight of the ``Y @ R.T`` term in the F update.
        seed: value passed to ``np.random.seed`` at the start of ``train``.
        verbose: print progress and a final summary.
        """
        self.d_features = d
        self.k = k
        self.niter = niter
        self.lambda_ = lambda_
        self.seed = seed
        self.verbose = verbose

        # Interface-compatibility parameters; not read by this class.
        self.nredo = nredo
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        # Results of the alternating optimisation.
        self.F_ = None          # (n_samples, k) matrix from the last outer iteration
        self.R_ = None          # (k, k) matrix from the last outer iteration
        self.Y_ = None          # (n_samples, k) one-hot labels
        self.labels_ = None     # (n_samples,) integer labels
        self.obj_history_ = []  # trace((X F)^T (X F)) per outer iteration
        self.final_obj_ = None
        self.runtime_ = 0

        # Attributes kept for compatibility with the other model classes.
        self.centroids = None   # per-cluster means of the input data
        self.obj = None         # alias of obj_history_

        # Final metrics.
        self.sse_ = None
        self.balance_loss_ = None

    def _initialize_Y_bknc(self, n_samples, c):
        """Return a random (n_samples, c) one-hot integer matrix.

        Labels are drawn from the global ``np.random`` state, so the result
        is reproducible only if the caller has seeded it.
        """
        labels = np.random.randint(0, c, size=n_samples)
        Y = np.zeros((n_samples, c), dtype=int)
        Y[np.arange(n_samples), labels] = 1
        return Y

    def _calculate_cluster_centroids(self, data, labels, num_clusters, data_dim):
        """Return the (num_clusters, data_dim) array of per-cluster means.

        data: (n_samples, data_dim) array.
        labels: (n_samples,) integer labels, or None (returns all zeros).

        An empty cluster receives a data point chosen by a private
        ``RandomState(self.seed + cluster_index)`` so that the fallback is
        reproducible and does not disturb the global RNG stream; if ``data``
        itself is empty the centroid stays zero.
        """
        centroids = np.zeros((num_clusters, data_dim))
        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in _calculate_cluster_centroids. Returning zero centroids.")
            return centroids

        for i in range(num_clusters):
            cluster_points = data[labels == i]
            if len(cluster_points) > 0:
                centroids[i] = np.mean(cluster_points, axis=0)
                continue
            if self.verbose:
                print(f"Warning: Cluster {i} is empty during centroid calculation. Assigning a random data point as its centroid.")
            if len(data) > 0:
                rng_fallback = np.random.RandomState(self.seed + i)
                centroids[i] = data[rng_fallback.choice(len(data))]
        return centroids

    def train(self, x, weights=None, init_centroids=None):
        """Run the alternating optimisation on ``x`` and populate the results.

        x: data matrix of shape (n_samples, d).
        weights: accepted for interface compatibility; not read.
        init_centroids: accepted for interface compatibility; not read.

        Raises ValueError if the column count of ``x`` differs from ``d`` or
        if ``niter`` is smaller than 1 (no iteration would produce labels).
        """
        np.random.seed(self.seed)
        start_time = time.time()

        if x.shape[1] != self.d_features:
            raise ValueError(f"Input data feature dimension {x.shape[1]} "
                             f"does not match class initialized dimension {self.d_features}")
        if self.niter < 1:
            raise ValueError(f"niter must be at least 1, got {self.niter}")

        X_m = x.T                    # (n_features, n_samples)
        n_samples = X_m.shape[1]
        c = self.k

        # Random starting point; both draws consume the global RNG seeded above.
        Y = self._initialize_Y_bknc(n_samples, c)
        R = orth(np.random.rand(c, c))

        obj_log = np.zeros(self.niter)

        # Loop-invariant left factor of the F update. Hoisted so the scaled
        # copy of the data is built once instead of on every inner iteration.
        two_Xt = 2 * X_m.T
        n_inner = 10

        for iter_num in range(self.niter):
            # F is re-drawn at random at the start of every outer iteration.
            F_current = orth(np.random.rand(n_samples, c))
            G = Y @ R.T

            # F update: repeatedly replace F by U @ Vh of the SVD of
            # 2 X^T X F + lambda * G.
            for _ in range(n_inner):
                M_calc_F = two_Xt @ (X_m @ F_current) + self.lambda_ * G
                U_f, _, Vh_f = np.linalg.svd(M_calc_F, full_matrices=False)
                F_current = U_f @ Vh_f

            # R update from the SVD of F^T Y.
            U_r, _, Vh_r = np.linalg.svd(F_current.T @ Y, full_matrices=False)
            R = U_r @ Vh_r

            # Y update: each sample takes the row index with the largest
            # entry in its column of R^T F^T.
            idx = np.argmax(R.T @ F_current.T, axis=0)
            Y = np.zeros((n_samples, c), dtype=int)
            Y[np.arange(n_samples), idx] = 1

            TempF_obj = X_m @ F_current
            obj_log[iter_num] = np.trace(TempF_obj.T @ TempF_obj)

            if self.verbose and (iter_num % 5 == 0 or iter_num == self.niter -1):
                print(f"Iter {iter_num+1}/{self.niter}, BKNC Obj: {obj_log[iter_num]:.4f}")

        self.F_ = F_current
        self.R_ = R
        self.Y_ = Y
        self.labels_ = idx
        self.obj_history_ = obj_log
        self.final_obj_ = obj_log[-1]
        self.obj = self.obj_history_

        # Metrics are computed on the original data with the final labels.
        self.centroids = self._calculate_cluster_centroids(x, self.labels_, self.k, self.d_features)

        # Sum of squared distances of every point to its cluster centroid.
        residuals = x - self.centroids[self.labels_]
        self.sse_ = np.sum(residuals ** 2)

        # Squared deviation of cluster sizes from the even size n/k.
        cluster_sizes = np.bincount(self.labels_, minlength=self.k)
        ideal_size = n_samples / self.k
        self.balance_loss_ = np.sum((cluster_sizes - ideal_size)**2)

        # Elapsed wall-clock time, including the metric computation above.
        self.runtime_ = time.time() - start_time

        if self.verbose:
            print(f"BKNC training completed in {self.runtime_:.4f} seconds.")
            print(f"Final BKNC objective (trace): {self.final_obj_:.4f}")
            unique_labels_final, counts_final = np.unique(self.labels_, return_counts=True)
            print(f"Final cluster sizes: {dict(zip(unique_labels_final, counts_final))}")
            if self.centroids is not None:
                print(f"Shape of calculated centroids: {self.centroids.shape}")
            print(f"Final SSE: {self.sse_:.4f}")
            print(f"Final Balance Loss: {self.balance_loss_:.4f}")

class MyKMeans:
    """Variant of the BKNC procedure that reweights the label term.

    The optimisation loop has the same structure as ``BKNC.train``; the only
    difference is that the ``Y @ R.T`` term in the F update is weighted by
    ``lambda_reformed = (1 - lambda_) / lambda_`` instead of ``lambda_``.

    Only ``d``, ``k``, ``niter``, ``lambda_``, ``seed`` and ``verbose``
    influence ``train``; the remaining constructor arguments are stored for
    interface compatibility and never read by this class.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=False,
                 min_points_per_centroid=1, max_points_per_centroid=1e9,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False,
                 lambda_=1.0):
        """Store the configuration and reset result attributes.

        d: number of feature columns expected by ``train``.
        k: number of clusters.
        niter: number of outer iterations.
        lambda_: balance parameter; must be non-zero because
            ``lambda_reformed`` divides by it. With the default of 1.0 the
            label term has weight 0.
        seed: value passed to ``np.random.seed`` at the start of ``train``.
        verbose: print progress and a final summary.
        """
        self.d_features = d
        self.k = k
        self.niter = niter
        self.lambda_ = lambda_
        # Effective weight of the label term used inside train().
        self.lambda_reformed = (1-lambda_)/lambda_
        self.seed = seed
        self.verbose = verbose

        # Interface-compatibility parameters; not read by this class.
        self.nredo = nredo
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        # Results of the alternating optimisation.
        self.F_ = None          # (n_samples, k) matrix from the last outer iteration
        self.R_ = None          # (k, k) matrix from the last outer iteration
        self.Y_ = None          # (n_samples, k) one-hot labels
        self.labels_ = None     # (n_samples,) integer labels
        self.obj_history_ = []  # trace((X F)^T (X F)) per outer iteration
        self.final_obj_ = None
        self.runtime_ = 0

        # Attributes kept for compatibility with the other model classes.
        self.centroids = None   # per-cluster means of the input data
        self.obj = None         # alias of obj_history_

        # Final metrics.
        self.sse_ = None
        self.balance_loss_ = None

    def _initialize_Y_bknc(self, n_samples, c):
        """Return a random (n_samples, c) one-hot integer matrix.

        Labels are drawn from the global ``np.random`` state, so the result
        is reproducible only if the caller has seeded it.
        """
        labels = np.random.randint(0, c, size=n_samples)
        Y = np.zeros((n_samples, c), dtype=int)
        Y[np.arange(n_samples), labels] = 1
        return Y

    def _calculate_cluster_centroids(self, data, labels, num_clusters, data_dim):
        """Return the (num_clusters, data_dim) array of per-cluster means.

        data: (n_samples, data_dim) array.
        labels: (n_samples,) integer labels, or None (returns all zeros).

        An empty cluster receives a data point chosen by a private
        ``RandomState(self.seed + cluster_index)`` so that the fallback is
        reproducible and does not disturb the global RNG stream; if ``data``
        itself is empty the centroid stays zero.
        """
        centroids = np.zeros((num_clusters, data_dim))
        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in _calculate_cluster_centroids. Returning zero centroids.")
            return centroids

        for i in range(num_clusters):
            cluster_points = data[labels == i]
            if len(cluster_points) > 0:
                centroids[i] = np.mean(cluster_points, axis=0)
                continue
            if self.verbose:
                print(f"Warning: Cluster {i} is empty during centroid calculation. Assigning a random data point as its centroid.")
            if len(data) > 0:
                rng_fallback = np.random.RandomState(self.seed + i)
                centroids[i] = data[rng_fallback.choice(len(data))]
        return centroids

    def train(self, x, weights=None, init_centroids=None):
        """Run the alternating optimisation on ``x`` and populate the results.

        x: data matrix of shape (n_samples, d).
        weights: accepted for interface compatibility; not read.
        init_centroids: accepted for interface compatibility; not read.

        Raises ValueError if the column count of ``x`` differs from ``d`` or
        if ``niter`` is smaller than 1 (no iteration would produce labels).
        """
        np.random.seed(self.seed)
        start_time = time.time()

        if x.shape[1] != self.d_features:
            raise ValueError(f"Input data feature dimension {x.shape[1]} "
                             f"does not match class initialized dimension {self.d_features}")
        if self.niter < 1:
            raise ValueError(f"niter must be at least 1, got {self.niter}")

        X_m = x.T                    # (n_features, n_samples)
        n_samples = X_m.shape[1]
        c = self.k

        # Random starting point; both draws consume the global RNG seeded above.
        Y = self._initialize_Y_bknc(n_samples, c)
        R = orth(np.random.rand(c, c))

        obj_log = np.zeros(self.niter)

        # Loop-invariant left factor of the F update. Hoisted so the scaled
        # copy of the data is built once instead of on every inner iteration.
        two_Xt = 2 * X_m.T
        n_inner = 10

        for iter_num in range(self.niter):
            # F is re-drawn at random at the start of every outer iteration.
            F_current = orth(np.random.rand(n_samples, c))
            G = Y @ R.T

            # F update: repeatedly replace F by U @ Vh of the SVD of
            # 2 X^T X F + lambda_reformed * G.
            for _ in range(n_inner):
                M_calc_F = two_Xt @ (X_m @ F_current) + self.lambda_reformed * G
                U_f, _, Vh_f = np.linalg.svd(M_calc_F, full_matrices=False)
                F_current = U_f @ Vh_f

            # R update from the SVD of F^T Y.
            U_r, _, Vh_r = np.linalg.svd(F_current.T @ Y, full_matrices=False)
            R = U_r @ Vh_r

            # Y update: each sample takes the row index with the largest
            # entry in its column of R^T F^T.
            idx = np.argmax(R.T @ F_current.T, axis=0)
            Y = np.zeros((n_samples, c), dtype=int)
            Y[np.arange(n_samples), idx] = 1

            TempF_obj = X_m @ F_current
            obj_log[iter_num] = np.trace(TempF_obj.T @ TempF_obj)

            # NOTE(review): the log messages in this method still say "BKNC"
            # (inherited from the BKNC class); kept unchanged so existing
            # log parsing is not affected.
            if self.verbose and (iter_num % 5 == 0 or iter_num == self.niter -1):
                print(f"Iter {iter_num+1}/{self.niter}, BKNC Obj: {obj_log[iter_num]:.4f}")

        self.F_ = F_current
        self.R_ = R
        self.Y_ = Y
        self.labels_ = idx
        self.obj_history_ = obj_log
        self.final_obj_ = obj_log[-1]
        self.obj = self.obj_history_

        # Metrics are computed on the original data with the final labels.
        self.centroids = self._calculate_cluster_centroids(x, self.labels_, self.k, self.d_features)

        # Sum of squared distances of every point to its cluster centroid.
        residuals = x - self.centroids[self.labels_]
        self.sse_ = np.sum(residuals ** 2)

        # Squared deviation of cluster sizes from the even size n/k.
        cluster_sizes = np.bincount(self.labels_, minlength=self.k)
        ideal_size = n_samples / self.k
        self.balance_loss_ = np.sum((cluster_sizes - ideal_size)**2)

        # Elapsed wall-clock time, including the metric computation above.
        self.runtime_ = time.time() - start_time

        if self.verbose:
            print(f"BKNC training completed in {self.runtime_:.4f} seconds.")
            print(f"Final BKNC objective (trace): {self.final_obj_:.4f}")
            unique_labels_final, counts_final = np.unique(self.labels_, return_counts=True)
            print(f"Final cluster sizes: {dict(zip(unique_labels_final, counts_final))}")
            if self.centroids is not None:
                print(f"Shape of calculated centroids: {self.centroids.shape}")
            print(f"Final SSE: {self.sse_:.4f}")
            print(f"Final Balance Loss: {self.balance_loss_:.4f}")

# Helper functions (moved outside the class, or could be static methods)
def initial_Y(X, c, rep, way="random"):
    """Generate ``rep`` initial labelings for the rows of ``X``.

    X: data array; only ``X.shape[0]`` is used unless ``way`` is "k-means++".
    c: number of clusters; labels are in ``[0, c)``.
    rep: number of independent labelings to produce.
    way: "random" draws labels uniformly from the global ``np.random`` state;
        "k-means++" takes the labels of a single-iteration ``KMeans`` fit
        (``KMeans`` must be available in the module namespace).

    Returns an int32 array of shape (rep, n_samples).
    Raises ValueError for an unknown ``way``.
    """
    # Validate explicitly: an ``assert`` would vanish under ``python -O`` and
    # the function would silently return all-zero labels.
    if way not in ("random", "k-means++"):
        raise ValueError(f"Unknown initialisation way {way!r}; expected 'random' or 'k-means++'.")

    N = X.shape[0]
    Y = np.zeros((rep, N), dtype=np.int32)

    if way == "random":
        for rep_i in range(rep):
            Y[rep_i] = np.random.randint(0, c, N)
    else:
        for rep_i in range(rep):
            Y[rep_i] = KMeans(n_clusters=c, init="k-means++", n_init=1, max_iter=1).fit(X).labels_

    return Y
def one_hot(y: np.ndarray, k: int):
    """Return the (len(y), k) float32 indicator matrix of the labels ``y``.

    Row ``i`` has a 1.0 in column ``y[i]`` and 0.0 elsewhere.
    """
    num_rows = len(y)
    encoded = np.zeros((num_rows, k), dtype=np.float32)
    row_index = np.arange(num_rows)
    encoded[row_index, y] = 1.0
    return encoded
def compute_cluster_centers_cdkm(X, Y, k=None):
    """Return per-cluster means of ``X`` for the labeling ``Y[0]``.

    X: (n, d) data array.
    Y: sequence of label arrays, each of shape (n,); only ``Y[0]`` is used.
    k: number of clusters. When omitted it is inferred as ``max(Y[0]) + 1``,
        which drops any empty clusters whose index is above the largest
        label in use; pass ``k`` explicitly to always get ``k`` rows.

    Returns a (k, d) float array. A cluster without members gets an all-zero
    row (its weight is replaced by 1e-10 to avoid dividing by zero).
    """
    y = Y[0]  # shape (n,)
    n = X.shape[0]
    if k is None:
        k = np.max(y) + 1

    # One-hot membership matrix; Y0.T @ X then yields per-cluster sums.
    Y0 = np.zeros((n, k), dtype=np.float64)
    Y0[np.arange(n), y] = 1.0

    weights = np.sum(Y0, axis=0)  # (k,) cluster sizes
    weights[weights == 0] = 1e-10

    centers = (Y0.T @ X) / weights[:, None]
    return centers


def get_centroid(data, label, K, n, d_features):
    """Recompute the K centroids from the current assignment.

    data: (n, d_features) array of points.
    label: (n,) array of cluster indices.
    K: number of clusters.
    n: number of samples.
    d_features: number of features.

    Returns a (K, d_features) array. A cluster with no members is given a
    point drawn with the global ``np.random`` state; if ``n`` is 0 its
    centroid stays at zero.
    """
    centers = np.zeros((K, d_features))
    for cluster_id in range(K):
        mask = label == cluster_id
        if not np.any(mask):
            # Empty cluster: fall back to a random data point when possible.
            if n > 0:
                centers[cluster_id, :] = data[np.random.choice(n), :]
            continue
        # Summing the boolean mask counts the members of the cluster.
        centers[cluster_id, :] = np.sum(data[mask, :], axis=0) / np.sum(mask)
    return centers


def get_distance(data, centroids, K, n, d_features, size_cluster, lambda_param):
    """Build the size-penalised assignment cost matrix.

    Entry (i, j) is the squared Euclidean distance between point ``i`` and
    centroid ``j`` plus ``lambda_param * size_cluster[j]``.

    data: (n, d_features) points.
    centroids: (K, d_features) centroids.
    size_cluster: (K,) current cluster sizes.
    lambda_param: weight of the size penalty.

    Returns an (n, K) array.
    """
    costs = np.zeros((n, K))
    for col in range(K):
        offset = data - centroids[col, :]
        sq_dist = np.sum(offset**2, axis=1)
        penalty = lambda_param * size_cluster[col]
        costs[:, col] = sq_dist + penalty
    return costs


def initial_centroid(x_data, K, n_samples):
    """Pick K distinct rows of ``x_data`` as starting centroids.

    x_data: (n_samples, d_features) array.
    K: number of clusters.
    n_samples: number of rows to sample from.

    The rows are drawn without replacement from the global ``np.random``
    state. Raises ValueError when K exceeds n_samples.
    """
    if not K <= n_samples:
        raise ValueError("K (number of clusters) cannot be greater than n_samples.")
    chosen = np.random.choice(n_samples, K, replace=False)
    return x_data[chosen, :]

# Chunked data-loading helper
def load_data_chunked(path, dtype='float32', chunksize=1000):
    """Load a header-less CSV in chunks of ``chunksize`` rows and stack them.

    Each chunk is cast to ``dtype`` as it is read, so the full file is never
    held in pandas' default dtype at once.
    """
    reader = pd.read_csv(path, header=None, chunksize=chunksize)
    pieces = [piece.astype(dtype) for piece in reader]
    return np.concatenate(pieces, axis=0)

def run_experiment(model_class, model_name, dataset_path, dimensions, n_clusters, n_runs=10,
                   niter=10, lambda_=0.1, base_seed=1234):
    """Train ``model_class`` ``n_runs`` times on one dataset and collect metrics.

    model_class: class accepting ``d``, ``k``, ``niter``, ``seed`` and
        ``verbose`` keyword arguments (plus ``lambda_`` unless ``model_name``
        is "Lloyd" or "CDKM").
    model_name: label used to decide whether ``lambda_`` is passed.
    dataset_path: CSV file read with ``load_data_chunked``.
    dimensions: feature dimension handed to the model as ``d``.
    n_clusters: number of clusters handed to the model as ``k``.
    n_runs: number of repetitions; run ``r`` uses seed ``base_seed + r``.
    niter, lambda_, base_seed: model settings; the defaults reproduce the
        values that were previously hard-coded.

    Returns ``(sse_list, balance_loss_list)``. ``balance_loss_`` falls back
    to NaN for models that do not define it.
    """
    # Names of the models whose constructor is called without ``lambda_``.
    models_without_lambda = ("Lloyd", "CDKM")

    X_data = None
    model = None
    try:
        X_data = load_data_chunked(dataset_path)
        sse_list = []
        balance_loss_list = []

        for run in range(n_runs):
            model_kwargs = dict(d=dimensions, k=n_clusters, niter=niter,
                                seed=base_seed + run, verbose=False)
            if model_name not in models_without_lambda:
                model_kwargs["lambda_"] = lambda_
            model = model_class(**model_kwargs)

            if hasattr(model, 'train'):
                model.train(X_data)
            else:
                model.fit(X_data)

            sse_list.append(model.sse_)
            balance_loss_list.append(getattr(model, 'balance_loss_', np.nan))

        print(f"sse_list: {sse_list}")
        print(f"balance_loss_list: {balance_loss_list}")
        return sse_list, balance_loss_list

    finally:
        # Drop the references explicitly: if an exception propagates, the
        # traceback keeps this frame (and the dataset) alive otherwise.
        X_data = None
        model = None
        gc.collect()


if __name__ == '__main__':
    # Models to benchmark, as (class, display name) pairs.
    models = [
        (Lloyd, "Lloyd"),
        (CDKM, "CDKM"),
        (BCLS, "BCLS"),
        (FCFC, "FCFC"),
        (BKNC, "BKNC"),
        (MyKMeans, "MyKMeans"),
    ]

    # Dataset files paired with their feature dimension.
    datasets = [
        ("/content/sample_data/Huatuo_1024d_10k.csv", 1024),
        ("/content/sample_data/LiveChat_1024d_10k.csv", 1024),
        ("/content/sample_data/deep_96d_10k.csv", 96),
        ("/content/sample_data/glove_300d_10k.csv", 300),
        ("/content/sample_data/sift_128d_10k.csv", 128),
    ]
    k = 5
    n_runs = 10  # repetitions per (dataset, model) pair

    # One summary row per (dataset, model) pair.
    result_rows = []
    for dataset_path, dim in datasets:
        dataset_name = Path(dataset_path).stem
        for model_class, model_name in models:
            print(f"Running {model_name} on {dataset_name}...")
            sse_list, balance_loss_list = run_experiment(
                model_class, model_name, dataset_path, dim, k, n_runs=n_runs)

            # Mean and variance over the repetitions.
            sse_mean, sse_var = np.mean(sse_list), np.var(sse_list)
            balance_mean, balance_var = np.mean(balance_loss_list), np.var(balance_loss_list)

            result_rows.append(dict(
                Dataset=dataset_name,
                Model=model_name,
                SSE_Mean=sse_mean,
                SSE_Var=sse_var,
                BalanceLoss_Mean=balance_mean,
                BalanceLoss_Var=balance_var,
            ))

            print(f"SSE_Mean: {sse_mean}")
            print(f"SSE_Var: {sse_var}")
            print(f"BalanceLoss_Mean: {balance_mean}")
            print(f"BalanceLoss_Var: {balance_var}")

    results_df = pd.DataFrame(result_rows)

    results_df.to_csv('/content/sample_data/metrics_summary.csv', index=False)
    print("\nExperiment results saved to 'metrics_summary.csv'")

Running Lloyd on Huatuo-1024d_10k...
sse_list: [np.float32(2667.1853), np.float32(2658.881), np.float32(2664.398), np.float32(2654.715), np.float32(2666.353), np.float32(2654.886), np.float32(2663.17), np.float32(2687.1843), np.float32(2654.109), np.float32(2671.2185)]
balance_loss_list: [np.float64(1718538.7999999998), np.float64(3395060.8), np.float64(1823978.8), np.float64(2988678.8), np.float64(2776648.8), np.float64(1680162.8), np.float64(1018210.8), np.float64(8414936.8), np.float64(2122044.8), np.float64(9054240.8)]
SSE_Mean: 2664.210205078125
SSE_Var: 89.90210723876953
BalanceLoss_Mean: 3499250.2
BalanceLoss_Var: 7315438422508.842
Running CDKM on Huatuo-1024d_10k...
sse_list: [np.float64(2649.1871792725756), np.float64(2649.187062892103), np.float64(2655.154617549748), np.float64(2659.991362006629), np.float64(2657.138125999354), np.float64(2663.4083936770485), np.float64(2652.1866709480714), np.float64(2664.1451044247283), np.float64(2653.003185827281), np.float64(2652.4251577

In [None]:
import pandas as pd
import numpy as np

# 1. Read the summary CSV.
df = pd.read_csv('/content/sample_data/metrics_summary.csv')

# 2. Short dataset key: the text before the first '-' and the first '_'.
dataset_key = df['Dataset'].str.split('-').str[0].str.split('_').str[0]
# Rows where a dataset key appears for the first time, in file order.
first_seen = ~dataset_key.duplicated()

# 3. Map dataset keys and model names to their display names.
df['Dataset'] = dataset_key.replace({
    'deep': 'Deep',
    'sift': 'SIFT',
    'glove': 'GloVe'
})
df['Model'] = df['Model'].replace({'MyKMeans': 'Tub-means'})

# 4. Order the datasets by first appearance in the CSV.
dataset_order = df.loc[first_seen, 'Dataset'].tolist()
df['Dataset'] = pd.Categorical(df['Dataset'], categories=dataset_order, ordered=True)

# 5./6. Sort by dataset; the stable sort keeps the original row (model)
# order inside each dataset.
df = df.sort_values('Dataset', kind='mergesort')

# Table-generation logic starts here.
def format_row(dataset, model, sse_mean, sse_std, balance_mean, balance_std, last=False, multirow=False, multirow_count=None):
    """Format one LaTeX table row for a (dataset, model) pair.

    dataset: dataset name; only printed when ``multirow`` is true.
    model: model name.
    sse_mean, sse_std, balance_mean, balance_std: numbers rendered in
        ``%.2E`` scientific notation, in that column order.
    last: append `` \\midrule`` and a newline (last row of a dataset group).
    multirow: prefix the row with a ``\\multirow`` line naming the dataset
        (first row of a dataset group).
    multirow_count: number of rows the ``\\multirow`` cell spans.

    ``multirow`` and ``last`` are independent, so a group consisting of a
    single row gets both the ``\\multirow`` prefix and the closing rule.
    """
    def sci(value):
        return f"{value:.2E}"

    row = (f"& \\ {model} & {sci(sse_mean)} & {sci(sse_std)} && "
           f"{sci(balance_mean)} & {sci(balance_std)} \\\\")
    if multirow:
        row = f"\\multirow{{{multirow_count}}}{{*}}{{\\ {dataset}}}\n" + row
    if last:
        row += " \\midrule \n"
    return row

# 7. Generate the rows dataset by dataset.
lines = []
# observed=True is passed explicitly: 'Dataset' is categorical, and leaving
# the argument out makes pandas emit a FutureWarning about its default.
for dataset, group in df.groupby('Dataset', observed=True):
    # NOTE(review): the *_Var columns hold variances, although format_row
    # names the matching parameters *_std. Confirm which statistic the table
    # header promises; take np.sqrt here if it should be a standard deviation.
    records = list(zip(
        group['Model'].tolist(),
        group['SSE_Mean'].tolist(),
        group['SSE_Var'].tolist(),
        group['BalanceLoss_Mean'].tolist(),
        group['BalanceLoss_Var'].tolist(),
    ))
    n_rows = len(records)

    for i, (model, sse_mean, sse_var, balance_mean, balance_var) in enumerate(records):
        is_first = (i == 0)
        is_last = (i == n_rows - 1)  # based on the actual group length
        line = format_row(dataset, model, sse_mean, sse_var, balance_mean, balance_var,
                          last=is_last, multirow=is_first, multirow_count=n_rows if is_first else None)
        lines.append(line)

# 8. Print the rows and save them.
latex_table_body = "\n".join(lines)
print(latex_table_body)

with open("latex_table_rows.tex", "w", encoding="utf-8") as f:
    f.write(latex_table_body)


\multirow{6}{*}{\ Huatuo}
& \ Lloyd & 2.66E+03 & 8.99E+01 && 3.50E+06 & 7.32E+12 \\
& \ CDKM & 2.66E+03 & 2.67E+01 && 2.47E+06 & 1.92E+12 \\
& \ BCLS & 3.13E+03 & 2.28E+00 && 3.86E+05 & 6.39E+10 \\
& \ FCFC & 3.14E+03 & 0.00E+00 && 8.00E+07 & 6.22E-16 \\
& \ BKNC & 2.67E+03 & 5.19E+01 && 5.37E+05 & 1.34E+10 \\
& \ Tub-means & 3.11E+03 & 2.56E+02 && 4.62E+05 & 3.70E+10 \\ \midrule 

\multirow{6}{*}{\ LiveChat}
& \ Lloyd & 1.93E+03 & 2.37E+01 && 4.48E+06 & 5.04E+12 \\
& \ CDKM & 1.93E+03 & 4.99E-01 && 3.73E+06 & 7.43E+11 \\
& \ BCLS & 2.03E+03 & 1.58E-01 && 4.72E+05 & 5.40E+10 \\
& \ FCFC & 2.04E+03 & 2.07E-25 && 8.00E+07 & 3.33E-16 \\
& \ BKNC & 1.94E+03 & 2.14E+01 && 1.13E+06 & 7.25E+10 \\
& \ Tub-means & 2.04E+03 & 4.27E+00 && 1.72E+04 & 7.41E+08 \\ \midrule 

\multirow{6}{*}{\ Deep}
& \ Lloyd & 8.12E+03 & 2.51E+02 && 1.29E+06 & 2.76E+11 \\
& \ CDKM & 8.11E+03 & 5.09E+02 && 1.45E+06 & 1.88E+12 \\
& \ BCLS & 8.90E+03 & 9.15E+03 && 1.62E+07 & 4.63E+13 \\
& \ FCFC & 9.31E+03 & 3.31E-24 &

  for dataset, group in df.groupby('Dataset'):


In [None]:
!pip install faiss-cpu
import numpy as np
import pandas as pd
import faiss
import time
from pathlib import Path
import gc
from scipy.linalg import orth # For creating orthogonal matrices

class FCFC:
    """Iterative clustering whose assignment cost is produced by `get_distance`.

    The constructor mirrors the keyword interface of the other clustering
    wrappers in this file; several of the arguments are only stored and are
    never read by `train`.

    After `train` the instance exposes the final centroids, labels, cluster
    sizes, per-iteration histories (objective, SSE, balance loss) and two
    size-balance statistics (`normalized_entropy_`, `cluster_variance_`).
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                    min_points_per_centroid=1, max_points_per_centroid=1000000000,
                    seed=1234, gpu=False, spherical=False,
                    update_index=True, frozen_centroids=False,
                    lambda_=1.0):
        """Store the configuration and reset every result attribute.

        Args:
            d: Expected number of feature columns of the training data.
            k: Number of clusters.
            niter: Number of assignment/update iterations run by `train`.
            lambda_: Weight handed to `get_distance` on every iteration.
            seed: Value passed to `np.random.seed` at the start of `train`.
            verbose: Whether `train` prints progress and a final summary.
            nredo, min_points_per_centroid, max_points_per_centroid, gpu,
            spherical, update_index, frozen_centroids: Stored for interface
                compatibility only; `train` does not read them.
        """
        # --- configuration -------------------------------------------------
        self.d_features = d
        self.k = k
        self.niter = niter
        self.max_iter = niter  # alias used by the training loop

        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.seed = seed
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids
        self.lambda_ = lambda_

        # --- results filled in by train() ----------------------------------
        self.centroids = None   # final centroids
        self.labels_ = None     # final label of every sample
        self.runtime_ = None    # wall-clock seconds spent in the main loop

        self.objective_history_ = None      # per-iteration sum of row minima of the cost matrix
        self.sse_history_ = None            # per-iteration sum of squared errors
        self.balance_loss_history_ = None   # per-iteration balance loss

        self.final_objective_ = None
        self.final_sse_ = None
        self.final_balance_loss_ = None
        self.final_cluster_sizes_ = None

        self.sse_ = 0
        self.balance_loss_ = 0
        self.normalized_entropy_ = None
        self.cluster_variance_ = None

        # Compatibility alias; train() points it at objective_history_.
        self.obj = None

    def train(self, x, weights=None, init_centroids_arg=None):
        """Run `max_iter` assignment/update rounds on `x` and store the results.

        Args:
            x: 2-D array whose second dimension must equal `d_features`.
            weights: Accepted for interface compatibility; never read.
            init_centroids_arg: Optional starting centroids of shape
                (k, n_features). When omitted, `initial_centroid` is called.

        Raises:
            ValueError: If the feature dimension of `x` or the shape of
                `init_centroids_arg` does not match the configuration.
        """
        np.random.seed(self.seed)
        t_begin = time.time()

        n_clusters = self.k
        if x.shape[1] != self.d_features:
            raise ValueError(f"Input data feature dimension {x.shape[1]} "
                             f"does not match class initialized dimension {self.d_features}")
        n_points, n_dims = x.shape

        # One slot per iteration for each monitored quantity.
        total_iters = self.max_iter
        sse_trace = np.zeros(total_iters)
        balance_trace = np.zeros(total_iters)
        objective_trace = np.zeros(total_iters)

        # Starting centroids: caller-supplied copy, or the file's initialiser.
        if init_centroids_arg is None:
            centers = initial_centroid(x, n_clusters, n_points)
        else:
            if init_centroids_arg.shape != (n_clusters, n_dims):
                raise ValueError(f"Provided init_centroids shape {init_centroids_arg.shape} "
                                 f"is not ({n_clusters}, {n_dims})")
            centers = np.copy(init_centroids_arg)

        # Sizes start at one per cluster and are replaced after the first assignment.
        cluster_sizes = np.ones(n_clusters)
        labels = np.zeros(n_points, dtype=int)
        ideal_size = n_points / n_clusters

        for it in range(total_iters):
            # Assignment: every point takes the column with the smallest cost.
            # Per the original author's note the cost is
            # squared distance + lambda * current cluster size.
            cost = get_distance(x, centers, n_clusters, n_points, n_dims,
                                cluster_sizes, self.lambda_)
            row_minimum = np.min(cost, axis=1)
            labels = np.argmin(cost, axis=1)
            objective_trace[it] = np.sum(row_minimum)

            # Update: sizes and centroids follow the fresh assignment.
            cluster_sizes = np.bincount(labels, minlength=n_clusters)
            centers = get_centroid(x, labels, n_clusters, n_points, n_dims)

            # Monitoring only: SSE against the updated centroids.
            sse_now = 0
            for j in range(n_clusters):
                members = x[labels == j, :]
                if members.shape[0] > 0:
                    sse_now += np.sum(np.sum((members - centers[j, :])**2, axis=1))
            sse_trace[it] = sse_now

            # Monitoring only: squared deviation of each size from n/k.
            balance_trace[it] = np.sum((cluster_sizes - ideal_size)**2)

            report_now = it % 5 == 0 or it == self.max_iter - 1
            if self.verbose and report_now:
                print(f"Iter {it+1}/{self.max_iter}: Objective={objective_trace[it]:.4f}, "
                      f"SSE={sse_trace[it]:.4f}, BalanceLoss={balance_trace[it]:.4f}")

        self.runtime_ = time.time() - t_begin

        # Publish the final state and the histories.
        self.centroids = centers
        self.labels_ = labels
        self.final_cluster_sizes_ = cluster_sizes

        self.objective_history_ = objective_trace
        self.sse_history_ = sse_trace
        self.balance_loss_history_ = balance_trace

        self.final_objective_ = objective_trace[-1]
        self.final_sse_ = sse_trace[-1]
        self.final_balance_loss_ = balance_trace[-1]

        self.obj = self.objective_history_

        self.sse_ = self.final_sse_
        self.balance_loss_ = self.final_balance_loss_

        # Size-balance statistics: entropy of the size distribution divided by
        # log(k), and the summed absolute size deviation scaled by k/n.
        eps = 1e-10
        entropy_acc = 0.0
        deviation_acc = 0.0
        for jj in range(n_clusters):
            size_jj = cluster_sizes[jj]
            share = size_jj / n_points
            entropy_acc += share * np.log((size_jj + eps) / n_points)
            deviation_acc += np.sqrt((size_jj - n_points / n_clusters) ** 2)

        self.normalized_entropy_ = -entropy_acc / np.log(n_clusters)
        self.cluster_variance_ = (n_clusters / n_points) * deviation_acc

        if self.verbose:
            print(f"FCFC training completed in {self.runtime_:.4f} seconds.")
            print(f"Final Objective (sum of D(i,j)): {self.final_objective_:.4f}")
            print(f"Final SSE: {self.final_sse_:.4f}")
            print(f"Final Balance Loss: {self.final_balance_loss_:.4f}")
            print(f"Final cluster sizes: {self.final_cluster_sizes_}")

class Lloyd:
    """Thin wrapper around `faiss.Kmeans` that also records balance metrics.

    After `train` the instance exposes the centroids and labels produced by
    Faiss, the objective history, the sum of squared errors (SSE), the balance
    loss (squared deviation of cluster sizes from n/k) and two size-balance
    statistics (`normalized_entropy_`, `cluster_variance_`).
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1000000000,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False):
        """Store the `faiss.Kmeans` configuration and reset the results.

        Args:
            d: Expected number of feature columns of the training data.
            k: Number of clusters.
            niter, nredo, verbose, min_points_per_centroid,
            max_points_per_centroid, seed, gpu, spherical, update_index,
            frozen_centroids: Forwarded unchanged to `faiss.Kmeans` by `train`.
        """
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.seed = seed
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        self.centroids = None
        self.obj_history_ = None
        self.labels_ = None

        self.sse_ = None
        self.balance_loss_ = None
        self.normalized_entropy_ = None  # entropy of the size distribution / log(k)
        self.cluster_variance_ = None    # (k/n) * sum of |size - n/k|
        self.runtime_ = None
        self.obj = None

    def compute_centroids_from_data(self, data_points, labels, num_clusters, data_dim):
        """Return the per-cluster mean of `data_points` as a float32 array.

        Args:
            data_points: Indexable collection of samples, one row per sample.
            labels: Cluster index of every sample, or None.
            num_clusters: Number of rows of the returned array.
            data_dim: Number of columns of the returned array.

        Returns:
            (num_clusters, data_dim) float32 array. All zeros when `labels`
            is None. An empty cluster receives a randomly chosen sample
            (drawn from the global NumPy RNG).
        """
        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)

        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in compute_centroids_from_data.")
            return centroids

        counts = np.zeros(num_clusters, dtype=int)
        for i, point in enumerate(data_points):
            label = labels[i]
            centroids[label] += point
            counts[label] += 1

        for j in range(num_clusters):
            if counts[j] > 0:
                centroids[j] /= counts[j]
            else:
                if self.verbose:
                    print(f"Warning: Cluster {j} is empty. Assigning random point.")
                centroids[j] = data_points[np.random.randint(len(data_points))]

        return centroids

    @staticmethod
    def _sum_squared_error(points, centroids, labels, chunk_rows=65536):
        """Sum of squared distances between each point and its assigned centroid.

        Rows are processed in chunks so the temporary difference array stays
        small, and the sum is accumulated in float64 regardless of input dtype.
        """
        total = 0.0
        for start in range(0, len(points), chunk_rows):
            stop = start + chunk_rows
            diff = points[start:stop] - centroids[labels[start:stop]]
            total += float(np.sum(diff * diff, dtype=np.float64))
        return total

    @staticmethod
    def _size_balance_stats(sizes, n_points, n_clusters):
        """Return (normalized_entropy, cluster_variance) for the given sizes.

        normalized_entropy is the entropy of size/n divided by log(k); the
        log(k) divisor is zero for k == 1, so the value is not finite there.
        cluster_variance is (k/n) * sum(|size - n/k|).
        """
        eps = 1e-10  # keeps log() finite for empty clusters
        sizes = np.asarray(sizes, dtype=np.float64)
        fractions = sizes / n_points
        entropy = -np.sum(fractions * np.log((sizes + eps) / n_points))
        abs_deviation = np.sum(np.abs(sizes - n_points / n_clusters))
        return entropy / np.log(n_clusters), (n_clusters / n_points) * abs_deviation

    def train(self, x_orig_data, weights=None, init_centroids=None):
        """Cluster `x_orig_data` with `faiss.Kmeans` and compute the metrics.

        Args:
            x_orig_data: 2-D array whose second dimension must equal `d`.
            weights: Optional per-sample weights, forwarded to
                `faiss.Kmeans.train` (previously accepted but ignored).
            init_centroids: Optional initial centroids, forwarded to
                `faiss.Kmeans.train`.

        Raises:
            ValueError: If the feature dimension of the data differs from `d`.
        """
        start_time = time.time()
        np.random.seed(self.seed)

        if x_orig_data.shape[1] != self.d:
            raise ValueError(f"Input dimension {x_orig_data.shape[1]} != {self.d}")

        n = x_orig_data.shape[0]
        # Faiss works on contiguous float32 data.
        x = np.ascontiguousarray(x_orig_data, dtype='float32')

        kmeans = faiss.Kmeans(
            d=self.d,
            k=self.k,
            niter=self.niter,
            nredo=self.nredo,
            verbose=self.verbose,
            min_points_per_centroid=self.min_points_per_centroid,
            max_points_per_centroid=self.max_points_per_centroid,
            seed=self.seed,
            gpu=self.gpu,
            spherical=self.spherical,
            update_index=self.update_index,
            frozen_centroids=self.frozen_centroids
        )

        kmeans.train(x, weights=weights, init_centroids=init_centroids)

        # Nearest-centroid lookup gives the final label of every sample.
        _, nearest = kmeans.index.search(x, 1)
        self.labels_ = nearest.flatten()

        self.centroids = kmeans.centroids
        obj_trace = kmeans.obj
        has_trace = obj_trace is not None and len(obj_trace) > 0
        self.obj_history_ = obj_trace if has_trace else np.zeros(self.niter)
        self.obj = obj_trace[-1] if has_trace else None
        self.runtime_ = time.time() - start_time

        # Print every 5th iteration's objective value (and the last one).
        if self.verbose and self.obj_history_ is not None and len(self.obj_history_) > 0:
            print("\n--- Objective Value (every 5 iterations) ---")
            last_index = len(self.obj_history_) - 1
            for i, val in enumerate(self.obj_history_):
                if (i + 1) % 5 == 0 or i == last_index:
                    print(f"  Iter {i+1:2d}: {val:.6f}")

        # SSE is measured on the caller's data (not the float32 copy).
        self.sse_ = self._sum_squared_error(x_orig_data, self.centroids, self.labels_)

        sizes = np.bincount(self.labels_, minlength=self.k)
        self.balance_loss_ = np.sum((sizes - n / self.k) ** 2)

        self.normalized_entropy_, self.cluster_variance_ = self._size_balance_stats(
            sizes, n, self.k)

        if self.verbose:
            print(f"Lloyd training finished in {self.runtime_:.4f}s")
            print(f"Final obj: {self.obj}")
            print(f"Cluster sizes: {dict(zip(*np.unique(self.labels_, return_counts=True)))}")
            print(f"SSE: {self.sse_:.4f}")
            print(f"Balance Loss: {self.balance_loss_:.4f}")

class BCLS:
    """Balanced clustering solved with an augmented-Lagrangian (ALM) loop.

    Each iteration solves a ridge-regularised linear map (W, b) from the
    centred data to the one-hot indicator matrix Y, updates the auxiliary
    variable Z that carries the balance penalty, re-assigns every sample to
    the arg-max column, and updates the multipliers.

    After `train` the instance exposes labels, the one-hot matrix, centroids in
    the original data space, the objective history, SSE, balance loss and two
    size-balance statistics.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1e9,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False,
                 lambda_=1.0):
        """Store the configuration and reset every result attribute.

        Args:
            d: Expected number of feature columns of the training data.
            k: Number of clusters.
            niter: Number of ALM iterations.
            lambda_: Weight of the balance term (stored as `lambda_bcls`).
            seed: Seeds the global NumPy RNG in `train` and the private RNG
                used to fill empty clusters.
            verbose: Whether progress and warnings are printed.
            nredo, min_points_per_centroid, max_points_per_centroid, gpu,
            spherical, update_index, frozen_centroids: Stored for interface
                compatibility only; `train` does not read them.
        """
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        self.seed = seed
        self.lambda_bcls = lambda_  # used as `lam` inside train()

        # Results storage
        self.centroids = None       # centroids in the original data space
        self.obj_history_ = None    # objective value per iteration
        self.labels_ = None         # final 0-indexed cluster assignments
        self.Y_final_ = None        # final one-hot indicator matrix

        # Final metrics
        self.sse_ = None
        self.balance_loss_ = None
        self.normalized_entropy_ = None  # entropy of the size distribution / log(k)
        self.cluster_variance_ = None    # (k/n) * sum of |size - n/k|
        self.runtime_ = None

        # Compatibility alias; train() points it at the objective history.
        self.obj = None

    def init1(self, n_samples, num_clusters):
        """Return a random one-hot (n_samples x num_clusters) indicator matrix.

        Labels are drawn from the global NumPy RNG as 1..num_clusters and then
        shifted to 0-indexed columns.
        """
        labels_1_indexed = np.random.randint(1, num_clusters + 1, size=n_samples)
        F = np.zeros((n_samples, num_clusters))
        F[np.arange(n_samples), labels_1_indexed - 1] = 1
        return F

    def compute_centroids_from_data(self, data_points, labels, num_clusters, data_dim):
        """Return the per-cluster mean of `data_points` as a float32 array.

        Args:
            data_points: (n_samples, n_features) array, original or centred.
            labels: (n_samples,) 0-indexed cluster assignments, or None.
            num_clusters: Number of rows of the returned array.
            data_dim: Number of columns of the returned array.

        Returns:
            (num_clusters, data_dim) float32 array. All zeros when `labels`
            is None. An empty cluster receives a sample chosen by a private
            RNG seeded with `seed + j + 1000`; it stays zero if there is no data.
        """
        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)

        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in compute_centroids_from_data. Returning zero centroids.")
            return centroids

        counts = np.zeros(num_clusters, dtype=int)
        for i in range(len(data_points)):
            label = labels[i]
            centroids[label] += data_points[i, :]
            counts[label] += 1

        for j in range(num_clusters):
            if counts[j] > 0:
                centroids[j] /= counts[j]
                continue
            if self.verbose:
                print(f"Warning: Cluster {j} is empty. Assigning a random data point as its centroid.")
            if len(data_points) > 0:
                # Offset seed so each empty cluster gets its own reproducible pick.
                rng_empty_fallback = np.random.RandomState(self.seed + j + 1000)
                centroids[j] = data_points[rng_empty_fallback.choice(len(data_points)), :]
        return centroids

    @staticmethod
    def _solve_z(rhs, mu, lam):
        """Apply ((mu + 2*n*lam) * I - 2*lam * J) / (mu**2 + 2*n*lam*mu) to `rhs`.

        J is the n x n all-ones matrix, so J @ rhs is just the column sums of
        `rhs` repeated on every row. Using that identity avoids materialising
        any n x n matrix: the cost is O(n*c) instead of O(n^2 * c).
        """
        n = rhs.shape[0]
        denominator = mu ** 2 + 2 * n * lam * mu
        column_sums = rhs.sum(axis=0, keepdims=True)
        return ((mu + 2 * n * lam) * rhs - 2 * lam * column_sums) / denominator

    @staticmethod
    def _sum_squared_error(points, centroids, labels, chunk_rows=65536):
        """Sum of squared distances between each point and its assigned centroid.

        Rows are processed in chunks so the temporary difference array stays
        small, and the sum is accumulated in float64.
        """
        total = 0.0
        for start in range(0, len(points), chunk_rows):
            stop = start + chunk_rows
            diff = points[start:stop] - centroids[labels[start:stop]]
            total += float(np.sum(diff * diff, dtype=np.float64))
        return total

    @staticmethod
    def _size_balance_stats(sizes, n_points, n_clusters):
        """Return (normalized_entropy, cluster_variance) for the given sizes.

        normalized_entropy is the entropy of size/n divided by log(k); the
        log(k) divisor is zero for k == 1, so the value is not finite there.
        cluster_variance is (k/n) * sum(|size - n/k|).
        """
        eps = 1e-10  # keeps log() finite for empty clusters
        sizes = np.asarray(sizes, dtype=np.float64)
        fractions = sizes / n_points
        entropy = -np.sum(fractions * np.log((sizes + eps) / n_points))
        abs_deviation = np.sum(np.abs(sizes - n_points / n_clusters))
        return entropy / np.log(n_clusters), (n_clusters / n_points) * abs_deviation

    def train(self, x_orig_data, weights=None, init_centroids=None):
        """Run the ALM iterations on `x_orig_data` and store the results.

        Args:
            x_orig_data: (n_samples, d) array.
            weights: Accepted for interface compatibility; never read.
            init_centroids: Accepted for interface compatibility; never read
                (Y is always initialised randomly by `init1`).

        Raises:
            ValueError: If the feature dimension of the data differs from `d`.
        """
        np.random.seed(self.seed)  # makes init1() reproducible
        start_time = time.time()

        if x_orig_data.shape[1] != self.d:
            raise ValueError(f"Input data feature dimension {x_orig_data.shape[1]} "
                             f"does not match class initialized dimension {self.d}")

        ITER = self.niter
        # Algorithm hyper-parameters.
        gamma = 0.00001          # ridge regularisation for W
        lam = self.lambda_bcls   # weight of the balance term
        mu = 0.01                # ALM penalty parameter (grows each iteration)
        rho = 1.005              # growth factor for mu

        n, dim = x_orig_data.shape
        c = self.k

        Y = self.init1(n, c)  # n x c one-hot indicators

        # The regression is solved on centred data.
        meanX = np.mean(x_orig_data, axis=0, keepdims=True)
        x_centered = x_orig_data - meanX

        Lambda_alm = np.zeros((n, c))  # multipliers for the constraint Y - Z = 0

        # (X^T X + gamma I)^-1 does not change across iterations; compute once.
        gram = x_centered.T @ x_centered + gamma * np.eye(dim)
        try:
            P = np.linalg.inv(gram)
        except np.linalg.LinAlgError:
            if self.verbose:
                print("Warning: Standard inverse failed for P. Using pseudo-inverse.")
            P = np.linalg.pinv(gram)

        obj_history = np.zeros(ITER)

        for iter_idx in range(ITER):
            # --- Solve W (dim x c) and b (1 x c) ---
            W = P @ (x_centered.T @ Y)
            b = np.mean(Y, axis=0, keepdims=True)

            # Regression output X W + 1 b^T (b broadcasts over rows) and its residual.
            projection = x_centered @ W + b
            E_recon = projection - Y

            # --- Solve Z (n x c) ---
            z_denominator = mu ** 2 + 2 * n * lam * mu
            if np.abs(z_denominator) < 1e-9:  # avoid division by zero
                if self.verbose:
                    print(f"Warning: Denominator for Z is near zero at iter {iter_idx}")
                Z = Y
            else:
                Z = self._solve_z(mu * Y + Lambda_alm, mu, lam)

            # --- Solve Y: one-hot of the row-wise arg-max of V ---
            V_update = (1 / (2 + mu)) * (2 * projection + mu * Z - Lambda_alm)
            current_labels = np.argmax(V_update, axis=1)
            Y = np.zeros((n, c))
            Y[np.arange(n), current_labels] = 1

            # --- Update multipliers and penalty (capped to avoid blow-up) ---
            Lambda_alm = Lambda_alm + mu * (Y - Z)
            mu = min(mu * rho, 1e5)

            # --- Objective value (on centred data) ---
            # NOTE(review): np.sum(Y) equals n for a one-hot Y, so the balance
            # term below is the constant lam * n**2. The Z update corresponds
            # to a penalty on squared column sums, i.e.
            # lam * np.sum(np.sum(Y, axis=0) ** 2); confirm which is intended.
            sum_Y_elements = np.sum(Y)
            obj_history[iter_idx] = (np.sum(E_recon ** 2)
                                     + gamma * np.sum(W ** 2)
                                     + lam * (sum_Y_elements ** 2))

            # Monitoring metrics are only needed on reporting iterations.
            if self.verbose and (iter_idx % 10 == 0 or iter_idx == ITER - 1):
                temp_centroids_centered = self.compute_centroids_from_data(
                    x_centered, current_labels, c, dim)
                sse_iter = self._sum_squared_error(
                    x_centered, temp_centroids_centered, current_labels)
                cluster_sizes_iter = np.sum(Y, axis=0)
                balance_loss_iter = np.sum((cluster_sizes_iter - n / c) ** 2)
                print(f"Iter {iter_idx+1}/{ITER}, BCLS Obj: {obj_history[iter_idx]:.4f}, "
                      f"Iter SSE (centered): {sse_iter:.2f}, Iter Bal (centered): {balance_loss_iter:.2f}")

        # --- End of iterations ---
        self.runtime_ = time.time() - start_time

        self.labels_ = np.argmax(Y, axis=1)
        self.Y_final_ = Y
        self.obj_history_ = obj_history
        self.obj = obj_history  # compatibility alias

        # Centroids and SSE are reported in the ORIGINAL (uncentred) space.
        self.centroids = self.compute_centroids_from_data(x_orig_data, self.labels_, c, dim)
        self.sse_ = self._sum_squared_error(x_orig_data, self.centroids, self.labels_)

        final_cluster_sizes = np.bincount(self.labels_, minlength=c)
        ideal_size = n / c
        self.balance_loss_ = np.sum((final_cluster_sizes - ideal_size) ** 2)

        self.normalized_entropy_, self.cluster_variance_ = self._size_balance_stats(
            final_cluster_sizes, n, c)

        if self.verbose:
            # With niter == 0 there is no objective value to report.
            final_obj = self.obj_history_[-1] if ITER > 0 else float("nan")
            print(f"BCLS training completed in {self.runtime_:.4f} seconds.")
            print(f"Final BCLS objective value: {final_obj:.4f}")
            unique_labels_final, counts_final = np.unique(self.labels_, return_counts=True)
            print(f"Final cluster sizes: {dict(zip(unique_labels_final, counts_final))}")
            if self.centroids is not None:
                print(f"Shape of final centroids (original space): {self.centroids.shape}")
            print(f"Final SSE (original space): {self.sse_:.4f}")
            print(f"Final Balance Loss: {self.balance_loss_:.4f}")

    def compute_centroids(self, x_transposed, F_indicator):
        """Compute centroids from a transposed data matrix and a one-hot matrix.

        DEPRECATED in favour of `compute_centroids_from_data`; kept in case it
        is used elsewhere. The input layout differs from that method.

        Args:
            x_transposed: (dim, n_samples) data matrix.
            F_indicator: (n_samples, k) cluster indicator matrix; the label of
                a sample is the arg-max of its row.

        Returns:
            (k, dim) float32 centroids. An empty cluster receives a sample
            chosen by a private RNG seeded with `seed + j + 2000`.

        Raises:
            ValueError: If the two inputs disagree on the number of samples.
        """
        num_clusters = F_indicator.shape[1]
        data_dim = x_transposed.shape[0]
        n_samples_check = x_transposed.shape[1]

        if F_indicator.shape[0] != n_samples_check:
            raise ValueError("Mismatch in number of samples between x_transposed and F_indicator.")

        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)
        counts = np.zeros(num_clusters, dtype=int)

        labels = np.argmax(F_indicator, axis=1)

        for i in range(n_samples_check):
            cluster_label = labels[i]
            centroids[cluster_label] += x_transposed[:, i]  # column i is one sample
            counts[cluster_label] += 1

        for j in range(num_clusters):
            if counts[j] > 0:
                centroids[j] /= counts[j]
                continue
            if self.verbose:
                print(f"Warning (compute_centroids): Cluster {j} is empty. Assigning random point.")
            if n_samples_check > 0:
                rng_empty_fallback = np.random.RandomState(self.seed + j + 2000)
                centroids[j] = x_transposed[:, rng_empty_fallback.choice(n_samples_check)]
        return centroids

class CDKM_PurePy:
    """Coordinate-descent k-means on hard labels, implemented with NumPy.

    Each sweep visits the samples in order and moves a sample to the cluster
    that maximises the change in sum_k ||S_k||^2 / n_k, where S_k is the sum
    of the samples in cluster k and n_k its size. A sample that is alone in
    its cluster is never moved.
    """

    def __init__(self, X: np.ndarray, c_true: int, debug: int = 0):
        """Store the data (converted to float64) and the number of clusters.

        Args:
            X: (N, dim) data matrix.
            c_true: Number of clusters.
            debug: When truthy, progress information is printed.
        """
        self.X = X.astype(np.float64)  # shape (N, dim)
        self.N, self.dim = self.X.shape
        self.c_true = c_true
        self.debug = debug

        self.Y = []             # one optimised label vector per replicate
        self.n_iter_ = []       # number of sweeps performed per replicate

        if debug:
            print(f"N = {self.N}, dim = {self.dim}, k = {self.c_true}")

    def opt(self, init_Y: np.ndarray, ITER: int):
        """Optimise every replicate in `init_Y` and append the results.

        Args:
            init_Y: (rep, N) array of integer labels; each row is copied
                before being optimised, so the caller's array is not modified.
            ITER: Maximum number of sweeps per replicate.
        """
        for initial_labels in init_Y:
            y = initial_labels.copy()
            n_iter = self.opt_once(y, ITER)
            self.Y.append(y)
            self.n_iter_.append(n_iter)

    def opt_once(self, y: np.ndarray, ITER: int) -> int:
        """Optimise the label vector `y` in place.

        Args:
            y: (N,) integer cluster assignment; modified in place.
            ITER: Maximum number of sweeps over the data.

        Returns:
            The number of sweeps actually performed (0 when ITER is 0).
        """
        X = self.X
        N, dim, c_true = self.N, self.dim, self.c_true

        xnorm = np.sum(X**2, axis=1)  # squared norm of every sample, shape (N,)

        # Per-cluster sample sums (dim x c) and sizes (c,). ufunc.at performs
        # an unbuffered accumulation, so repeated labels are summed correctly.
        cluster_sums = np.zeros((c_true, dim))
        np.add.at(cluster_sums, y, X)
        Sx = np.ascontiguousarray(cluster_sums.T)
        n = np.zeros(c_true)
        np.add.at(n, y, 1)

        s = np.sum(Sx**2, axis=0)  # squared norm of each cluster sum vector

        sweeps_done = 0  # explicit counter: stays 0 when ITER == 0
        for sweep in range(ITER):
            sweeps_done = sweep + 1
            converge = True
            for i in range(N):
                c_old = y[i]
                if n[c_old] == 1:
                    # Never empty a cluster by moving its only member.
                    continue

                xi = X[i]
                xiSx = xi @ Sx  # (c,)

                # Objective of each cluster after xi joins it ...
                joined = (s + 2 * xiSx + xnorm[i]) / (n + 1)
                # ... minus its current objective; an empty cluster contributes 0
                # (a plain s / n would be 0/0 = NaN there).
                current = np.divide(s, n, out=np.zeros_like(s), where=n > 0)
                delta = joined - current

                # For the sample's own cluster: current objective minus the
                # objective after xi leaves it.
                delta[c_old] = s[c_old] / n[c_old] - \
                    (s[c_old] - 2 * xiSx[c_old] + xnorm[i]) / (n[c_old] - 1)

                c_new = np.argmax(delta)

                if c_new != c_old:
                    converge = False
                    y[i] = c_new

                    Sx[:, c_old] -= xi
                    Sx[:, c_new] += xi

                    s[c_old] = np.sum(Sx[:, c_old]**2)
                    s[c_new] = np.sum(Sx[:, c_new]**2)

                    n[c_old] -= 1
                    n[c_new] += 1

                if self.debug and i % 10000 == 0:
                    print(f"i = {i}")

            if self.debug:
                print(f"iter = {sweep}")

            if converge:
                break

        return sweeps_done

    @property
    def y_pre(self):
        """List of optimised label vectors, one per replicate passed to `opt`."""
        return self.Y


class CDKM:
    """Front-end that runs `CDKM_PurePy` and records clustering metrics.

    After `train` the instance exposes centroids, labels, the one-hot label
    matrix, SSE, balance loss, per-replicate iteration counts and two
    size-balance statistics.
    """

    def __init__(self, d, k, niter=200, nredo=10, verbose=False, seed=1234, debug=0):
        """Store the configuration and reset every result attribute.

        Args:
            d: Expected number of feature columns of the training data.
            k: Number of clusters.
            niter: Maximum number of sweeps handed to the solver.
            nredo: Number of replicates requested from `initial_Y`.
            verbose: Whether `train` prints a summary line.
            seed: Value passed to `np.random.seed` at the start of `train`.
            debug: Forwarded to `CDKM_PurePy`.
        """
        # configuration
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.seed = seed
        self.debug = debug

        # results filled in by train()
        self.centroids = None
        self.labels_ = None
        self.Y_final_ = None
        self.sse_ = None
        self.balance_loss_ = None
        self.runtime_ = None
        self.n_iter_ = None

        self.normalized_entropy_ = None  # entropy of the size distribution / log(k)
        self.cluster_variance_ = None    # (k/n) * sum of |size - n/k|

    def train(self, x_orig_data, weights=None, init_centroids=None):
        """Cluster `x_orig_data` and store labels, centroids and metrics.

        Args:
            x_orig_data: 2-D array whose second dimension must equal `d`.
            weights: Accepted for interface compatibility; never read.
            init_centroids: When given, it is handed to the solver as the
                initial label array in place of the output of `initial_Y`.

        Raises:
            ValueError: If the feature dimension of the data differs from `d`.
        """
        np.random.seed(self.seed)
        t_begin = time.time()

        n_points, n_dims = x_orig_data.shape
        if n_dims != self.d:
            raise ValueError(f"Data dimension {n_dims} does not match expected {self.d}.")

        # Initial labels: caller-supplied, or random replicates from the helper.
        init_Y = (initial_Y(x_orig_data, self.k, self.nredo, "random")
                  if init_centroids is None else init_centroids)

        solver = CDKM_PurePy(x_orig_data, self.k, debug=self.debug)
        solver.opt(init_Y, ITER=self.niter)
        replicate_labels = solver.y_pre
        self.n_iter_ = solver.n_iter_

        centroids = compute_cluster_centers_cdkm(x_orig_data, replicate_labels)
        # Labels are taken from the first replicate only.
        labels = np.argmax(one_hot(replicate_labels[0], self.k), axis=1)

        # Sum of squared errors against the assigned centroids.
        residual = x_orig_data - centroids[labels]
        sse = np.sum(residual**2)

        # Balance loss: squared deviation of each cluster size from n/k.
        counts = np.bincount(labels, minlength=self.k)
        balance_loss = np.sum((counts - n_points / self.k)**2)

        self.Y_final_ = one_hot(replicate_labels[0], self.k)
        self.centroids = centroids
        self.labels_ = labels
        self.sse_ = sse
        self.balance_loss_ = balance_loss
        self.runtime_ = time.time() - t_begin

        # Size-balance statistics.
        eps = 1e-10
        n_clusters = self.k
        sizes = np.bincount(self.labels_, minlength=n_clusters)

        entropy_acc = 0.0
        deviation_acc = 0.0
        for jj in range(n_clusters):
            size_jj = sizes[jj]
            share = size_jj / n_points
            entropy_acc += share * np.log((size_jj + eps) / n_points)
            deviation_acc += np.sqrt((size_jj - n_points / n_clusters) ** 2)

        self.normalized_entropy_ = -entropy_acc / np.log(n_clusters)
        self.cluster_variance_ = (n_clusters / n_points) * deviation_acc

        if self.verbose:
            print(f"CDKM finished in {self.runtime_:.4f}s; "
                  f"SSE = {self.sse_:.4f}; "
                  f"Balance Loss = {self.balance_loss_:.4f}; "
                  f"Iterations = {self.n_iter_}")

class BKNC:
    """BKNC clustering by alternating updates of F, R and one-hot labels Y.

    The constructor takes the same arguments as the FCFC class in this file.
    Only ``d``, ``k``, ``niter``, ``lambda_``, ``seed`` and ``verbose`` are
    read by :meth:`train`; the remaining arguments are stored untouched.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=False,
                 min_points_per_centroid=1, max_points_per_centroid=1e9,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False,
                 lambda_=1.0):
        # d: dimensionality of data (n_features)
        # k: number of clusters (c in BKNC)
        self.d_features = d
        self.k = k              # c in BKNC
        self.niter = niter      # number of outer iterations
        self.lambda_ = lambda_  # weight of the label-coupling term G = Y R^T
        self.seed = seed
        self.verbose = verbose

        # Faiss-Kmeans style parameters - stored but not used by train()
        self.nredo = nredo
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        # BKNC specific results
        self.F_ = None          # F matrix of the last iteration (n_samples x k)
        self.R_ = None          # R matrix (k x k)
        self.Y_ = None          # one-hot labels (n_samples x k)
        self.labels_ = None     # final cluster assignments (n_samples,)
        self.obj_history_ = []  # trace(F' X_m' X_m F) per outer iteration
        self.final_obj_ = None
        self.runtime_ = 0

        # For compatibility with the FCFC structure
        self.centroids = None   # cluster means, filled by train()
        self.obj = None         # alias of obj_history_ after train()

        # Final metrics
        self.sse_ = None
        self.balance_loss_ = None
        self.normalized_entropy_ = None  # size entropy divided by log(k)
        self.cluster_variance_ = None    # scaled sum of absolute size deviations

    def _initialize_Y_bknc(self, n_samples, c):
        """Return a random (n_samples, c) one-hot integer label matrix.

        Labels are drawn from NumPy's global RNG, so the result is
        reproducible only when the caller has seeded it (train() does).
        """
        labels = np.random.randint(0, c, size=n_samples)
        Y = np.zeros((n_samples, c), dtype=int)
        Y[np.arange(n_samples), labels] = 1
        return Y

    def _calculate_cluster_centroids(self, data, labels, num_clusters, data_dim):
        """Return the per-cluster mean of ``data`` as (num_clusters, data_dim).

        data: (n_samples, n_features)
        labels: (n_samples,) or None (zero centroids are returned)
        An empty cluster receives a data point picked by a RandomState seeded
        with ``self.seed + cluster_index`` so the fallback is reproducible.
        """
        centroids = np.zeros((num_clusters, data_dim))
        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in _calculate_cluster_centroids. Returning zero centroids.")
            return centroids

        for i in range(num_clusters):
            cluster_points = data[labels == i]
            if len(cluster_points) > 0:
                centroids[i] = np.mean(cluster_points, axis=0)
                continue
            if self.verbose:
                print(f"Warning: Cluster {i} is empty during centroid calculation. Assigning a random data point as its centroid.")
            if len(data) > 0:
                rng_fallback = np.random.RandomState(self.seed + i)
                centroids[i] = data[rng_fallback.choice(len(data))]
            # With no data at all the row simply stays zero.
        return centroids

    def _finalize_metrics(self, x):
        """Fill centroids, SSE and the size-balance statistics from labels_."""
        n_samples = x.shape[0]
        c = self.k
        self.centroids = self._calculate_cluster_centroids(
            x, self.labels_, c, self.d_features)

        # Squared Euclidean distance of every point to its own centroid.
        residuals = x - self.centroids[self.labels_]
        self.sse_ = np.sum(residuals ** 2)

        # Squared deviation of the cluster sizes from the ideal size n / k.
        sizes = np.bincount(self.labels_, minlength=c)
        deviation = sizes - n_samples / c
        self.balance_loss_ = np.sum(deviation ** 2)

        # eps keeps log() finite for empty clusters (term is 0 * log(eps / n)).
        eps = 1e-10
        fractions = sizes / n_samples
        entropy = -np.sum(fractions * np.log((sizes + eps) / n_samples))
        self.normalized_entropy_ = entropy / np.log(c)
        self.cluster_variance_ = (c / n_samples) * np.sum(np.abs(deviation))

    def train(self, x, weights=None, init_centroids=None):
        """Run BKNC on ``x`` and store labels, centroids and metrics.

        x: data matrix (n_samples, n_features)
        weights: not used.
        init_centroids: not used.

        Raises ValueError if the feature dimension does not match
        ``self.d_features`` or if ``self.niter`` is smaller than 1.
        Seeds NumPy's global RNG with ``self.seed``.
        """
        np.random.seed(self.seed)
        start_time = time.time()

        if x.shape[1] != self.d_features:
            raise ValueError(f"Input data feature dimension {x.shape[1]} "
                             f"does not match class initialized dimension {self.d_features}")
        if self.niter < 1:
            # Without at least one iteration there is no F or labels to store.
            raise ValueError(f"niter must be at least 1, got {self.niter}")

        X_m = x.T  # (n_features, n_samples), the MATLAB-style layout
        n_samples = x.shape[0]
        c = self.k

        Y = self._initialize_Y_bknc(n_samples, c)
        R = orth(np.random.rand(c, c))  # random orthogonal start for R
        obj_log = np.zeros(self.niter)

        for iter_num in range(self.niter):
            # F restarts from a fresh random orthonormal basis every outer
            # iteration, then takes 10 fixed-point steps F <- U V^T.
            F_current = orth(np.random.rand(n_samples, c))
            G = Y @ R.T
            for _ in range(10):
                # Scaling the small product avoids copying an (n, d) array.
                M_calc_F = 2 * (X_m.T @ (X_m @ F_current)) + self.lambda_ * G
                U_f, _, Vh_f = np.linalg.svd(M_calc_F, full_matrices=False)
                F_current = U_f @ Vh_f

            # R update: orthogonal factor of F^T Y.
            U_r, _, Vh_r = np.linalg.svd(F_current.T @ Y, full_matrices=False)
            R = U_r @ Vh_r

            # Y update: each sample takes the arg-max row of R^T F^T.
            idx = np.argmax(R.T @ F_current.T, axis=0)
            Y = np.zeros((n_samples, c), dtype=int)
            Y[np.arange(n_samples), idx] = 1

            TempF_obj = X_m @ F_current
            obj_log[iter_num] = np.trace(TempF_obj.T @ TempF_obj)

            if self.verbose and (iter_num % 5 == 0 or iter_num == self.niter - 1):
                print(f"Iter {iter_num+1}/{self.niter}, BKNC Obj: {obj_log[iter_num]:.4f}")

        self.F_ = F_current
        self.R_ = R
        self.Y_ = Y  # one-hot version of labels_ from the last iteration
        self.labels_ = idx
        self.obj_history_ = obj_log
        self.final_obj_ = obj_log[-1]
        self.obj = self.obj_history_  # compatibility alias

        self._finalize_metrics(x)
        self.runtime_ = time.time() - start_time

        if self.verbose:
            print(f"BKNC training completed in {self.runtime_:.4f} seconds.")
            print(f"Final BKNC objective (trace): {self.final_obj_:.4f}")
            unique_labels_final, counts_final = np.unique(self.labels_, return_counts=True)
            print(f"Final cluster sizes: {dict(zip(unique_labels_final, counts_final))}")
            if self.centroids is not None:
                print(f"Shape of calculated centroids: {self.centroids.shape}")
            print(f"Final SSE: {self.sse_:.4f}")
            print(f"Final Balance Loss: {self.balance_loss_:.4f}")

class MyKMeans:
    """Variant of the BKNC procedure with a re-parameterised coupling weight.

    The update scheme is the same alternating F / R / Y loop used by BKNC in
    this file, but the label-coupling term G = Y R^T is weighted by
    ``lambda_reformed = (1 - lambda_) / lambda_`` instead of ``lambda_``.
    Only ``d``, ``k``, ``niter``, ``lambda_``, ``seed`` and ``verbose`` are
    read by :meth:`train`; the remaining arguments are stored untouched.
    """

    def __init__(self, d, k, niter=25, nredo=1, verbose=False,
                 min_points_per_centroid=1, max_points_per_centroid=1e9,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False,
                 lambda_=1.0):
        # d: dimensionality of data (n_features)
        # k: number of clusters
        self.d_features = d
        self.k = k
        self.niter = niter      # number of outer iterations
        self.lambda_ = lambda_
        # Weight actually applied to G in train(); lambda_ == 0 raises
        # ZeroDivisionError here, lambda_ == 1 switches the coupling off.
        self.lambda_reformed = (1-lambda_)/lambda_
        self.seed = seed
        self.verbose = verbose

        # Faiss-Kmeans style parameters - stored but not used by train()
        self.nredo = nredo
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        # Results of the alternating optimisation
        self.F_ = None          # F matrix of the last iteration (n_samples x k)
        self.R_ = None          # R matrix (k x k)
        self.Y_ = None          # one-hot labels (n_samples x k)
        self.labels_ = None     # final cluster assignments (n_samples,)
        self.obj_history_ = []  # trace(F' X_m' X_m F) per outer iteration
        self.final_obj_ = None
        self.runtime_ = 0

        # For compatibility with the FCFC structure
        self.centroids = None   # cluster means, filled by train()
        self.obj = None         # alias of obj_history_ after train()

        # Final metrics
        self.sse_ = None
        self.balance_loss_ = None
        self.normalized_entropy_ = None  # size entropy divided by log(k)
        self.cluster_variance_ = None    # scaled sum of absolute size deviations

    def _initialize_Y_bknc(self, n_samples, c):
        """Return a random (n_samples, c) one-hot integer label matrix.

        Labels are drawn from NumPy's global RNG, so the result is
        reproducible only when the caller has seeded it (train() does).
        """
        labels = np.random.randint(0, c, size=n_samples)
        Y = np.zeros((n_samples, c), dtype=int)
        Y[np.arange(n_samples), labels] = 1
        return Y

    def _calculate_cluster_centroids(self, data, labels, num_clusters, data_dim):
        """Return the per-cluster mean of ``data`` as (num_clusters, data_dim).

        data: (n_samples, n_features)
        labels: (n_samples,) or None (zero centroids are returned)
        An empty cluster receives a data point picked by a RandomState seeded
        with ``self.seed + cluster_index`` so the fallback is reproducible.
        """
        centroids = np.zeros((num_clusters, data_dim))
        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in _calculate_cluster_centroids. Returning zero centroids.")
            return centroids

        for i in range(num_clusters):
            cluster_points = data[labels == i]
            if len(cluster_points) > 0:
                centroids[i] = np.mean(cluster_points, axis=0)
                continue
            if self.verbose:
                print(f"Warning: Cluster {i} is empty during centroid calculation. Assigning a random data point as its centroid.")
            if len(data) > 0:
                rng_fallback = np.random.RandomState(self.seed + i)
                centroids[i] = data[rng_fallback.choice(len(data))]
            # With no data at all the row simply stays zero.
        return centroids

    def _finalize_metrics(self, x):
        """Fill centroids, SSE and the size-balance statistics from labels_."""
        n_samples = x.shape[0]
        c = self.k
        self.centroids = self._calculate_cluster_centroids(
            x, self.labels_, c, self.d_features)

        # Squared Euclidean distance of every point to its own centroid.
        residuals = x - self.centroids[self.labels_]
        self.sse_ = np.sum(residuals ** 2)

        # Squared deviation of the cluster sizes from the ideal size n / k.
        sizes = np.bincount(self.labels_, minlength=c)
        deviation = sizes - n_samples / c
        self.balance_loss_ = np.sum(deviation ** 2)

        # eps keeps log() finite for empty clusters (term is 0 * log(eps / n)).
        eps = 1e-10
        fractions = sizes / n_samples
        entropy = -np.sum(fractions * np.log((sizes + eps) / n_samples))
        self.normalized_entropy_ = entropy / np.log(c)
        self.cluster_variance_ = (c / n_samples) * np.sum(np.abs(deviation))

    def train(self, x, weights=None, init_centroids=None):
        """Cluster ``x`` and store labels, centroids and metrics.

        x: data matrix (n_samples, n_features)
        weights: not used.
        init_centroids: not used.

        Raises ValueError if the feature dimension does not match
        ``self.d_features`` or if ``self.niter`` is smaller than 1.
        Seeds NumPy's global RNG with ``self.seed``.
        """
        np.random.seed(self.seed)
        start_time = time.time()

        if x.shape[1] != self.d_features:
            raise ValueError(f"Input data feature dimension {x.shape[1]} "
                             f"does not match class initialized dimension {self.d_features}")
        if self.niter < 1:
            # Without at least one iteration there is no F or labels to store.
            raise ValueError(f"niter must be at least 1, got {self.niter}")

        X_m = x.T  # (n_features, n_samples), the MATLAB-style layout
        n_samples = x.shape[0]
        c = self.k

        Y = self._initialize_Y_bknc(n_samples, c)
        R = orth(np.random.rand(c, c))  # random orthogonal start for R
        obj_log = np.zeros(self.niter)

        for iter_num in range(self.niter):
            # F restarts from a fresh random orthonormal basis every outer
            # iteration, then takes 10 fixed-point steps F <- U V^T.
            F_current = orth(np.random.rand(n_samples, c))
            G = Y @ R.T
            for _ in range(10):
                # Scaling the small product avoids copying an (n, d) array.
                M_calc_F = 2 * (X_m.T @ (X_m @ F_current)) + self.lambda_reformed * G
                U_f, _, Vh_f = np.linalg.svd(M_calc_F, full_matrices=False)
                F_current = U_f @ Vh_f

            # R update: orthogonal factor of F^T Y.
            U_r, _, Vh_r = np.linalg.svd(F_current.T @ Y, full_matrices=False)
            R = U_r @ Vh_r

            # Y update: each sample takes the arg-max row of R^T F^T.
            idx = np.argmax(R.T @ F_current.T, axis=0)
            Y = np.zeros((n_samples, c), dtype=int)
            Y[np.arange(n_samples), idx] = 1

            TempF_obj = X_m @ F_current
            obj_log[iter_num] = np.trace(TempF_obj.T @ TempF_obj)

            if self.verbose and (iter_num % 5 == 0 or iter_num == self.niter - 1):
                print(f"Iter {iter_num+1}/{self.niter}, BKNC Obj: {obj_log[iter_num]:.4f}")

        self.F_ = F_current
        self.R_ = R
        self.Y_ = Y  # one-hot version of labels_ from the last iteration
        self.labels_ = idx
        self.obj_history_ = obj_log
        self.final_obj_ = obj_log[-1]
        self.obj = self.obj_history_  # compatibility alias

        self._finalize_metrics(x)
        self.runtime_ = time.time() - start_time

        if self.verbose:
            print(f"BKNC training completed in {self.runtime_:.4f} seconds.")
            print(f"Final BKNC objective (trace): {self.final_obj_:.4f}")
            unique_labels_final, counts_final = np.unique(self.labels_, return_counts=True)
            print(f"Final cluster sizes: {dict(zip(unique_labels_final, counts_final))}")
            if self.centroids is not None:
                print(f"Shape of calculated centroids: {self.centroids.shape}")
            print(f"Final SSE: {self.sse_:.4f}")
            print(f"Final Balance Loss: {self.balance_loss_:.4f}")

# Helper functions (moved outside the class, or could be static methods)
def initial_Y(X, c, rep, way="random"):
    """Build ``rep`` initial label vectors for the rows of ``X``.

    X: data matrix; only ``X.shape[0]`` is used by the "random" mode.
    c: number of clusters; labels lie in ``[0, c)``.
    rep: number of independent initialisations (rows of the result).
    way: "random" draws labels from NumPy's global RNG; "k-means++" takes
        the labels of a one-iteration ``KMeans`` fit.

    Returns an int32 array of shape (rep, N).
    Raises ValueError for any other ``way``. This used to be an ``assert``,
    which is stripped under ``python -O`` and would then return all zeros.
    """
    if way not in ("random", "k-means++"):
        raise ValueError(
            f"Unknown initialisation way {way!r}; expected 'random' or 'k-means++'.")

    N = X.shape[0]
    Y = np.zeros((rep, N), dtype=np.int32)

    if way == "random":
        for rep_i in range(rep):
            Y[rep_i] = np.random.randint(0, c, N)
    else:
        # NOTE(review): KMeans is not imported in the visible part of this
        # file; presumably sklearn.cluster.KMeans - confirm it is in scope.
        for rep_i in range(rep):
            Y[rep_i] = KMeans(n_clusters=c, init="k-means++", n_init=1, max_iter=1).fit(X).labels_

    return Y
def one_hot(y: np.ndarray, k: int):
    """Return the float32 one-hot encoding of labels ``y`` with ``k`` columns.

    Row ``i`` of the (len(y), k) result is 1.0 in column ``y[i]`` and 0.0
    everywhere else.
    """
    encoded = np.zeros((len(y), k), dtype=np.float32)
    row_ids = np.arange(encoded.shape[0])
    encoded[row_ids, y] = 1.0
    return encoded
def compute_cluster_centers_cdkm(X, Y, k=None):
    """Return the mean of the rows of ``X`` in each cluster.

    X: (n, d) data matrix.
    Y: sequence of label arrays, each of shape (n,); only ``Y[0]`` is used.
    k: optional number of clusters. When omitted it is inferred as
       ``max(label) + 1``, which drops empty clusters with the highest ids;
       pass ``k`` to always get one row per cluster.

    Returns a (k, d) float array. Empty clusters come out as zero rows,
    because their zero sum is divided by a tiny stand-in weight.
    Raises ValueError if ``k`` is smaller than ``max(label) + 1``.
    """
    y = np.asarray(Y[0])  # shape (n,)
    n = X.shape[0]
    inferred_k = int(np.max(y)) + 1
    if k is None:
        k = inferred_k
    elif k < inferred_k:
        raise ValueError(
            f"k={k} is smaller than the number of clusters implied by the labels ({inferred_k}).")

    membership = np.zeros((n, k), dtype=np.float64)
    membership[np.arange(n), y] = 1.0  # one-hot

    weights = np.sum(membership, axis=0)  # (k,) cluster sizes
    weights[weights == 0] = 1e-10         # avoid 0 / 0 for empty clusters

    return (membership.T @ X) / weights[:, None]

def get_centroid(data, label, K, n, d_features):
    """Recompute the K centroids from the current assignment.

    data: (n, d_features) samples
    label: (n,) cluster index of each sample
    K: number of clusters
    n: number of samples
    d_features: number of features

    A cluster with members gets their mean. An empty cluster gets a row of
    ``data`` drawn with NumPy's global RNG, or stays zero when ``n`` is 0.
    """
    centers = np.zeros((K, d_features))
    for cluster in range(K):
        mask = (label == cluster)
        if not np.any(mask):
            # Empty cluster: re-seed it from a random sample if there is one.
            if n > 0:
                centers[cluster, :] = data[np.random.choice(n), :]
            continue
        # Summing the boolean mask counts the members of the cluster.
        centers[cluster, :] = np.sum(data[mask, :], axis=0) / np.sum(mask)
    return centers


def get_distance(data, centroids, K, n, d_features, size_cluster, lambda_param):
    """Return the size-penalised assignment cost of every sample/cluster pair.

    Entry (i, j) of the (n, K) result is
    ``||data[i] - centroids[j]||^2 + lambda_param * size_cluster[j]``.

    data: (n, d_features)
    centroids: (K, d_features)
    size_cluster: (K,) current size of each cluster
    lambda_param: weight of the size penalty
    """
    costs = np.zeros((n, K))
    for j in range(K):
        offsets = data - centroids[j, :]
        sq_dist = np.sum(offsets ** 2, axis=1)  # squared Euclidean distance
        costs[:, j] = sq_dist + lambda_param * size_cluster[j]
    return costs


def initial_centroid(x_data, K, n_samples):
    """Pick K distinct rows of ``x_data`` to serve as starting centroids.

    x_data: (n_samples, d_features)
    K: number of clusters
    n_samples: number of samples

    Rows are drawn without replacement from NumPy's global RNG, so the
    result depends on whatever seed the caller has set.
    Raises ValueError when more centroids than samples are requested.
    """
    if K > n_samples:
        raise ValueError("K (number of clusters) cannot be greater than n_samples.")
    chosen_rows = np.random.choice(n_samples, size=K, replace=False)
    return x_data[chosen_rows, :]

# Optimised data-loading function
def load_data_chunked(path, dtype='float32', chunksize=1000):
    """Load a large header-less CSV in chunks to avoid running out of memory.

    Each chunk of ``chunksize`` rows is cast to ``dtype`` as it is read, and
    the converted chunks are concatenated row-wise into one array.
    """
    reader = pd.read_csv(path, header=None, chunksize=chunksize)
    converted = [piece.astype(dtype) for piece in reader]
    return np.concatenate(converted, axis=0)

def run_experiment(model_class, model_name, dataset_path, dimensions, n_clusters, n_runs=1):
    """Train a model on one dataset and return its CV and entropy.

    model_class: class to instantiate; must offer ``train`` or ``fit``.
    model_name: "Lloyd" and "CDKM" are built without ``lambda_``; every other
        name is built with ``lambda_=0.1``.
    dataset_path: CSV file read through ``load_data_chunked``.
    dimensions, n_clusters: forwarded as ``d`` and ``k``.
    n_runs: number of trainings; run ``r`` uses seed ``1234 + r``.

    Returns the pair ``(cluster_variance_, normalized_entropy_)`` of the
    LAST run only; earlier runs are discarded.
    Raises ValueError if ``n_runs`` is smaller than 1 (previously this
    failed later with an unbound ``model``).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    model_kwargs = dict(d=dimensions, k=n_clusters, niter=10, verbose=False)
    if model_name not in ("Lloyd", "CDKM"):
        model_kwargs["lambda_"] = 0.1

    X_data = None
    model = None
    try:
        X_data = load_data_chunked(dataset_path)
        for run in range(n_runs):
            model = model_class(seed=1234 + run, **model_kwargs)
            fit = model.train if hasattr(model, 'train') else model.fit
            fit(X_data)

        return model.cluster_variance_, model.normalized_entropy_

    finally:
        # Drop the large references before collecting so memory is released
        # between experiments even when training raised.
        del X_data, model
        gc.collect()


if __name__ == '__main__':
    # Models to compare, as (class, display name) pairs
    models = [
        (Lloyd, "Lloyd"),
        (CDKM, "CDKM"),
        (BCLS, "BCLS"),
        (FCFC, "FCFC"),
        (BKNC, "BKNC"),
        (MyKMeans, "MyKMeans"),
    ]

    # Dataset paths paired with their feature dimension
    datasets = [
        ("/content/sample_data/Huatuo_1024d_10k.csv", 1024),
        ("/content/sample_data/LiveChat_1024d_10k.csv", 1024),
        ("/content/sample_data/deep_96d_10k.csv", 96),
        ("/content/sample_data/glove_300d_10k.csv", 300),
        ("/content/sample_data/sift_128d_10k.csv", 128),
    ]

    k = 10        # fixed number of clusters
    records = []  # one [dataset, method, CV, entropy] row per experiment

    for dataset_path, dim in datasets:
        # Dataset label: file stem up to (not including) the first '-'
        dataset_name = Path(dataset_path).stem.partition("-")[0]
        for model_class, model_name in models:
            print(f"Running {model_name} on {dataset_name} (d={dim}, k={k})...")
            cv, entro = run_experiment(model_class, model_name, dataset_path, dim, k)
            records.append([dataset_name, model_name, cv, entro])

    # Save the collected metrics as a CSV file
    results_df = pd.DataFrame(records, columns=["Dataset", "Method", "CV", "Entropy"])
    results_df.to_csv('cv_entropy_results.csv', index=False)
    print("\nSaved CV & Entropy results to 'cv_entropy_results.csv'")

Running Lloyd on Huatuo (d=1024, k=10)...
Running CDKM on Huatuo (d=1024, k=10)...
Running BCLS on Huatuo (d=1024, k=10)...
Running FCFC on Huatuo (d=1024, k=10)...
Running BKNC on Huatuo (d=1024, k=10)...
Running MyKMeans on Huatuo (d=1024, k=10)...
Running Lloyd on LiveChat (d=1024, k=10)...
Running CDKM on LiveChat (d=1024, k=10)...
Running BCLS on LiveChat (d=1024, k=10)...
Running FCFC on LiveChat (d=1024, k=10)...
Running BKNC on LiveChat (d=1024, k=10)...
Running MyKMeans on LiveChat (d=1024, k=10)...
Running Lloyd on deep_96d_10k (d=96, k=10)...
Running CDKM on deep_96d_10k (d=96, k=10)...
Running BCLS on deep_96d_10k (d=96, k=10)...
Running FCFC on deep_96d_10k (d=96, k=10)...
Running BKNC on deep_96d_10k (d=96, k=10)...
Running MyKMeans on deep_96d_10k (d=96, k=10)...
Running Lloyd on glove_300d_10k (d=300, k=10)...
Running CDKM on glove_300d_10k (d=300, k=10)...
Running BCLS on glove_300d_10k (d=300, k=10)...
Running FCFC on glove_300d_10k (d=300, k=10)...
Running BKNC on gl

In [None]:
import pandas as pd
import numpy as np

# 1. Read the CSV written by the experiment cell
df = pd.read_csv('cv_entropy_results.csv')

# 2. Raw dataset key: text before the first '-', then before the first '_'
raw_names = df['Dataset'].str.split('-').str[0].str.split('_').str[0]

# 3. Replace dataset and method names with their display names
display_names = raw_names.replace({
    'deep': 'Deep',
    'sift': 'SIFT',
    'glove': 'GloVe'
})
df['Method'] = df['Method'].replace({'MyKMeans': 'Tub-means'})

# 4. Use the order in which datasets first appear in the CSV as the
#    category order
dataset_order = display_names[~raw_names.duplicated()].tolist()
df['Dataset'] = pd.Categorical(display_names, categories=dataset_order, ordered=True)

# 5-6. Sort by dataset; a stable sort keeps the original row order inside
#      each dataset, which preserves the order of the methods
df = df.sort_values('Dataset', kind='mergesort')

# The original table-generation logic starts here
def format_row(dataset, model, cv, entro, last=False, multirow=False, multirow_count=None):
    """Format one LaTeX table row for a (dataset, model) result.

    dataset, model: names printed in the first two columns.
    cv, entro: numbers rendered with five decimals.
    last: append ``\\midrule`` and a newline after the row (end of a group).
    multirow: prefix the row with a ``\\multirow`` cell holding ``dataset``.
    multirow_count: rows spanned by that cell; ``None`` is treated as 1.

    The trailing ``0`` column is emitted as a constant. A row that is both
    the first and the last of its group gets the ``\\multirow`` prefix and
    the ``\\midrule`` suffix, so single-row groups keep their separator.
    """
    cells = f"& \\ {model} & {cv:.5f} & {entro:.5f} & 0 \\\\"

    if multirow:
        span = 1 if multirow_count is None else multirow_count
        cells = f"\\multirow{{{span}}}{{*}}{{\\ {dataset}}}\n" + cells
    if last and not multirow:
        return cells + " \\midrule \n"
    if last:
        # First and last at once: the group has a single row.
        return cells + " \\midrule \n"
    return cells

# 7. Generate the rows one dataset group at a time
lines = []
# 'Dataset' is categorical: passing observed explicitly silences the pandas
# warning about the changing default. Every category comes from the data, so
# the groups are the same either way.
for dataset, group in df.groupby('Dataset', observed=True):
    models = group['Method'].tolist()
    cv = group['CV'].tolist()
    entro = group['Entropy'].tolist()

    n_rows = len(models)  # actual number of rows in this group
    for i, (model_name, cv_value, entro_value) in enumerate(zip(models, cv, entro)):
        is_first = (i == 0)
        is_last = (i == n_rows - 1)
        lines.append(format_row(dataset, model_name, cv_value, entro_value,
                                last=is_last, multirow=is_first,
                                multirow_count=n_rows if is_first else None))

# 8. Output: print the table body and save it to a file
latex_table_body = "\n".join(lines)
print(latex_table_body)

# Explicit encoding so non-ASCII dataset or method names are written the
# same way on every platform.
with open("latex_table_rows.tex", "w", encoding="utf-8") as f:
    f.write(latex_table_body)

\multirow{6}{*}{\ Huatuo}
& \ Lloyd & 2.42516 & 0.98232 & 0 \\
& \ CDKM & 3.25267 & 0.96677 & 0 \\
& \ BCLS & 4.75672 & 0.92561 & 0 \\
& \ FCFC & 18.00000 & -0.00000 & 0 \\
& \ BKNC & 1.63684 & 0.99057 & 0 \\
& \ Tub-means & 0.57654 & 0.99740 & 0 \\ \midrule 

\multirow{6}{*}{\ LiveChat}
& \ Lloyd & 2.50875 & 0.97836 & 0 \\
& \ CDKM & 3.46105 & 0.96303 & 0 \\
& \ BCLS & 2.44696 & 0.97806 & 0 \\
& \ FCFC & 18.00000 & -0.00000 & 0 \\
& \ BKNC & 0.90911 & 0.99742 & 0 \\
& \ Tub-means & 0.17878 & 0.99990 & 0 \\ \midrule 

\multirow{6}{*}{\ Deep}
& \ Lloyd & 3.28267 & 0.96603 & 0 \\
& \ CDKM & 2.62874 & 0.97487 & 0 \\
& \ BCLS & 12.00000 & 0.54968 & 0 \\
& \ FCFC & 18.00000 & -0.00000 & 0 \\
& \ BKNC & 1.03890 & 0.99533 & 0 \\
& \ Tub-means & 0.62294 & 0.99871 & 0 \\ \midrule 

\multirow{6}{*}{\ GloVe}
& \ Lloyd & 2.42916 & 0.97282 & 0 \\
& \ CDKM & 2.66893 & 0.97861 & 0 \\
& \ BCLS & 11.15208 & 0.63001 & 0 \\
& \ FCFC & 10.18638 & 0.61696 & 0 \\
& \ BKNC & 0.94871 & 0.99594 & 0 \\
& \ Tub-

  for dataset, group in df.groupby('Dataset'):
