In [None]:
# 更新的runtime
!pip install faiss-cpu
import numpy as np
import pandas as pd
import faiss
import time
from pathlib import Path
import gc
from scipy.linalg import orth # For creating orthogonal matrices
from collections import defaultdict
from typing import Dict, Tuple, List, Optional
import csv
import os

class Lloyd:
    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1000000000,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False):
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.seed = seed
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        self.centroids = None
        self.obj_history_ = None
        self.labels_ = None

        self.sse_ = None
        self.balance_loss_ = None
        self.runtime_ = None
        self.obj = None

    def compute_centroids_from_data(self, data_points, labels, num_clusters, data_dim):
        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)
        counts = np.zeros(num_clusters, dtype=int)

        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in compute_centroids_from_data.")
            return centroids

        for i in range(len(data_points)):
            label = labels[i]
            centroids[label] += data_points[i]
            counts[label] += 1

        for j in range(num_clusters):
            if counts[j] > 0:
                centroids[j] /= counts[j]
            else:
                if self.verbose:
                    print(f"Warning: Cluster {j} is empty. Assigning random point.")
                centroids[j] = data_points[np.random.randint(len(data_points))]

        return centroids

    def train(self, x_orig_data, weights=None, init_centroids=None):
        start_time = time.time()
        np.random.seed(self.seed)

        if x_orig_data.shape[1] != self.d:
            raise ValueError(f"Input dimension {x_orig_data.shape[1]} != {self.d}")

        n, dim = x_orig_data.shape
        x = np.ascontiguousarray(x_orig_data, dtype='float32')

        kmeans = faiss.Kmeans(
            d=self.d,
            k=self.k,
            niter=self.niter,
            nredo=self.nredo,
            verbose=self.verbose,
            min_points_per_centroid=self.min_points_per_centroid,
            max_points_per_centroid=self.max_points_per_centroid,
            seed=self.seed,
            gpu=self.gpu,
            spherical=self.spherical,
            update_index=self.update_index,
            frozen_centroids=self.frozen_centroids
        )

        kmeans.train(x, init_centroids=init_centroids)

        _, self.labels_ = kmeans.index.search(x, 1)
        self.labels_ = self.labels_.flatten()

        self.centroids = kmeans.centroids
        self.obj_history_ = kmeans.obj if kmeans.obj is not None and len(kmeans.obj) > 0 else np.zeros(self.niter)
        self.obj = kmeans.obj[-1] if kmeans.obj is not None and len(kmeans.obj) > 0 else None
        self.runtime_ = time.time() - start_time


        # Print every 5th iteration's objective value
        if self.verbose and self.obj_history_ is not None and len(self.obj_history_) > 0:
            print("\n--- Objective Value (every 5 iterations) ---")
            for i, val in enumerate(self.obj_history_):
                if (i + 1) % 5 == 0 or i == len(self.obj_history_) - 1:
                    print(f"  Iter {i+1:2d}: {val:.6f}")

        final_sse = 0
        for i in range(n):
            cluster_idx = self.labels_[i]
            final_sse += np.sum((x_orig_data[i] - self.centroids[cluster_idx]) ** 2)
        self.sse_ = final_sse

        final_balance_loss = 0
        if self.labels_ is not None:
            sizes = np.bincount(self.labels_, minlength=self.k)
            ideal = n / self.k
            final_balance_loss = np.sum((sizes - ideal) ** 2)
        self.balance_loss_ = final_balance_loss

        if self.verbose:
            print(f"Lloyd training finished in {self.runtime_:.4f}s")
            print(f"Final obj: {self.obj}")
            print(f"Cluster sizes: {dict(zip(*np.unique(self.labels_, return_counts=True)))}")
            print(f"SSE: {self.sse_:.4f}")
            print(f"Balance Loss: {self.balance_loss_:.4f}")

class DynamicKMeansCoreset:
    def __init__(self, d, k, threshold=100, **lloyd_kwargs):
        self.d = d
        self.k = k
        self.threshold = threshold
        self.lloyd_kwargs = lloyd_kwargs

        self.buffer = [] # 临时存储插入/删除的点ID
        self.coreset_tree = defaultdict(list) # 分层存储核心集的数据结构
        self.active_points: Dict[str, Tuple[np.ndarray, float]] = {} # 当前活跃的点(未删除)
        self.deleted_points = set() # 已删除的点ID集合
        self.model = Lloyd(d=d, k=k, **lloyd_kwargs)

    def insert(self, point_id: str, vector: np.ndarray, weight: float = 1.0):
        self.active_points[point_id] = (np.array(vector, dtype=np.float32), weight)
        self.buffer.append(point_id)
        if len(self.buffer) >= self.threshold:
            self._add_coreset(self.buffer)
            self.buffer.clear()

    def delete(self, point_id: str):
        if point_id in self.active_points:
            self.deleted_points.add(point_id)
            self.buffer.append(point_id)
            if len(self.buffer) >= self.threshold:
                self._add_coreset(self.buffer)
                self.buffer.clear()

    def _add_coreset(self, point_ids: List[str]):
        X, weights = [], []
        for pid in point_ids:
            if pid in self.active_points and pid not in self.deleted_points:
                vec, w = self.active_points[pid]
                X.append(vec)
                weights.append(w)

        if not X:
            return

        X = np.vstack(X).astype(np.float32)  # 将向量堆叠成矩阵
        weights = np.array(weights, dtype=np.float32)  # 将权重转为numpy数组
        self._merge_into_tree(0, (X, weights))  # 将新核心集合并到树中

    def _merge_into_tree(self, level: int, new_coreset: Tuple[np.ndarray, np.ndarray]):
        self.coreset_tree[level].append(new_coreset)
        if len(self.coreset_tree[level]) > 1:
            X1, w1 = self.coreset_tree[level].pop()
            X2, w2 = self.coreset_tree[level].pop()
            X = np.vstack([X1, X2])
            W = np.concatenate([w1, w2])
            reducer = Lloyd(d=self.d, k=self.k, **self.lloyd_kwargs)
            reducer.train(X, weights=W)
            reduced_X = reducer.centroids
            reduced_W = np.bincount(reducer.labels_, minlength=self.k).astype(np.float32)
            self._merge_into_tree(level + 1, (reduced_X, reduced_W))

    def retrain(self):
        if self.buffer:
            self._add_coreset(self.buffer)
            self.buffer.clear()
        X_all, W_all = [], []
        for level in self.coreset_tree:
            for X, W in self.coreset_tree[level]:
                X_all.append(X)
                W_all.append(W)
        if not X_all:
            return
        X = np.vstack(X_all)
        W = np.concatenate(W_all)
        mask = W > 0
        X = X[mask]
        W = W[mask]
        self.model.train(X, weights=W)

    def get_centroids(self) -> Optional[np.ndarray]:
        return self.model.centroids

    def get_labels(self) -> Optional[np.ndarray]:
        return self.model.labels_

    def get_metrics(self) -> Dict[str, float]:
        return {
            "sse": self.model.sse_,
            "balance_loss": self.model.balance_loss_,
            "runtime": self.model.runtime_,
            "objective": self.model.obj
        }

class OnlineKMeans:
    def __init__(self, d, k, init_method='split', verbose=False, seed=0):
        self.d, self.k = d, k
        self.verbose = verbose
        np.random.seed(seed)
        self.centroids = None
        self.counts = None  # 每簇访问计数

    def initialize(self, first_point):
        self.centroids = first_point.reshape(1, -1)
        self.counts = np.array([1], dtype=int)

    def maybe_split(self):
        # 找方差最大簇，二分裂
        if self.centroids.shape[0] < self.k:
            variances = []
            for i in range(len(self.counts)):
                variances.append(np.random.rand())  # 实际应估方差，这里 placeholder
            j = np.argmax(variances)
            c = self.centroids[j]
            delta = np.random.randn(self.d) * 1e-3
            new_c1, new_c2 = c + delta, c - delta
            self.centroids[j] = new_c1
            self.centroids = np.vstack([self.centroids, new_c2])
            self.counts[j] = self.counts[j] // 2
            self.counts = np.append(self.counts, self.counts[j])
            if self.verbose:
                print(f"Split centroid {j} into 2; now {self.centroids.shape[0]} total.")

    def partial_fit(self, x_new):
        if self.centroids is None:
            self.initialize(x_new)
            return
        # 找最近质心
        dists = np.linalg.norm(self.centroids - x_new, axis=1)
        j = np.argmin(dists)
        self.counts[j] += 1
        eta = 1.0 / self.counts[j]
        self.centroids[j] += eta * (x_new - self.centroids[j])
        # 尝试分裂
        if self.centroids.shape[0] < self.k:
            self.maybe_split()

    def get_balance_loss(self):
        if self.counts is None or len(self.counts) == 0:
            return 0.0
        total = np.sum(self.counts)
        ideal = total / self.k
        return np.sum((self.counts - ideal) ** 2)

class MiniBatchKMeans(Lloyd):
    def __init__(self, d, k, batch_size=32, iters=100, **kwargs):
        super().__init__(d=d, k=k, **kwargs)
        self.batch_size = batch_size
        self.iters = iters

    def partial_fit(self, X_stream):
        # X_stream: shape (n_stream, d)
        # 初始化质心（如果为空则从 stream 抽 batch 初始化）
        if self.centroids is None:
            init = X_stream[np.random.choice(len(X_stream), self.k, replace=False)]
            self.centroids = init
            self.counts = np.zeros(self.k, dtype=int)

        for i in range(self.iters):
            idx = np.random.choice(len(X_stream), self.batch_size, replace=False)
            batch = X_stream[idx]
            index = faiss.IndexFlatL2(self.d)
            index.add(self.centroids.astype(np.float32))
            _, labels = index.search(batch.astype(np.float32), 1)

            labels = labels.flatten()
            for j in range(self.k):
                mask = (labels == j)
                m = np.sum(mask)
                if m > 0:
                    centroid = self.centroids[j]
                    mean_batch = batch[mask].mean(axis=0)
                    eta = 1.0 / (self.counts[j] + m)
                    self.centroids[j] = (1 - eta) * centroid + eta * mean_batch
                    self.counts[j] += m
        # 设置 labels 未定义，balance/SSE 可后续重算

    def get_balance_loss(self):
        if getattr(self, 'counts', None) is None:
            return None
        ideal = self.counts.sum() / self.k
        return np.sum((self.counts - ideal)**2)

def evaluate_test_sse(test_data, centroids):
    dists = np.linalg.norm(test_data[:, None, :] - centroids[None, :, :], axis=2)
    nearest = np.argmin(dists, axis=1)
    return np.sum((test_data - centroids[nearest])**2)

def run_test_update_only_comparison(datasets, k=10, threshold=200):
    for ds in datasets:
        name = Path(ds["train"]).stem.split("_")[0]
        dim = ds["dim"]
        train_data = pd.read_csv(ds["train"], header=None).values.astype(np.float32)
        test_data = pd.read_csv(ds["test"], header=None).values.astype(np.float32)

        print(f"\n====== Dataset: {name} (dim={dim}) ======")

        # 🧠 Step 1️⃣: 使用 Lloyd 方法训练 10K 数据，作为所有方法的初始化
        lloyd = Lloyd(d=dim, k=k, verbose=False)
        lloyd.train(train_data)
        print(f"[Lloyd Init] SSE: {lloyd.sse_:.2f} | BalLoss: {lloyd.balance_loss_:.2f} | Time: {lloyd.runtime_:.2f}s")

        base_centroids = lloyd.centroids.copy()
        base_labels = lloyd.labels_.copy()

        ### 🎯 方法1：DynamicCoreset（更新 test_data）
        dyn = DynamicKMeansCoreset(d=dim, k=k, threshold=threshold, verbose=False)
        for i, vec in enumerate(train_data):
            dyn.insert(f"{name}_train_{i}", vec)
        dyn.model.centroids = base_centroids.copy()
        dyn.model.labels_ = base_labels.copy()

        test_sse_dyn_before = evaluate_test_sse(test_data, base_centroids)
        for i, vec in enumerate(test_data):
            dyn.insert(f"{name}_test_{i}", vec)

        t0 = time.perf_counter()
        dyn.retrain()
        t1 = time.perf_counter()
        test_sse_dyn_after = evaluate_test_sse(test_data, dyn.get_centroids())

        print(f"[Coreset Update] SSE: {dyn.model.sse_:.2f} | BalLoss: {dyn.model.balance_loss_:.2f} | Time: {t1 - t0:.2f}s")
        print(f"[Coreset Test SSE] Before: {test_sse_dyn_before:.2f} → After: {test_sse_dyn_after:.2f}")

        ### 🟦 方法2：OnlineKMeans（更新 test_data）
        online = OnlineKMeans(d=dim, k=k)
        online.centroids = base_centroids.copy()
        online.counts = np.bincount(base_labels, minlength=k)

        test_sse_online_before = evaluate_test_sse(test_data, base_centroids)
        t2 = time.perf_counter()
        for vec in test_data:
            online.partial_fit(vec)
        t3 = time.perf_counter()
        test_sse_online_after = evaluate_test_sse(test_data, online.centroids)

        # 获取更新后的指标
        cent = online.centroids
        dists = np.linalg.norm(test_data[:,None,:] - cent[None,:,:], axis=2)
        online_sse = np.sum((test_data - cent[np.argmin(dists,axis=1)])**2)
        online_balance = model.get_balance_loss()

        print(f"[OnlineKMeans Update] SSE: {online_sse:.2f} | BalLoss: {online_balance:.2f} | Time: {t3 - t2:.2f}s")
        print(f"[OnlineKMeans Test SSE] Before: {test_sse_online_before:.2f} → After: {test_sse_online_after:.2f}")

        ### 🟨 方法3：MiniBatchKMeans（更新 test_data）
        mini = MiniBatchKMeans(d=dim, k=k, batch_size=128, iters=10, verbose=False)
        mini.centroids = base_centroids.copy()
        mini.counts = np.bincount(base_labels, minlength=k)

        test_sse_mini_before = evaluate_test_sse(test_data, mini.centroids)
        t4 = time.perf_counter()
        mini.partial_fit(test_data)
        t5 = time.perf_counter()
        test_sse_mini_after = evaluate_test_sse(test_data, mini.centroids)

        print(f"[MiniBatchKMeans Update] BalLoss: {mini.get_balance_loss():.2f} | Time: {t5 - t4:.2f}s")
        print(f"[MiniBatch Test SSE] Before: {test_sse_mini_before:.2f} → After: {test_sse_mini_after:.2f}")

def run_test_update_comparison(datasets, k=10, threshold=200):
    # 准备CSV文件存储结果
    csv_file = 'runtime_comparison.csv'
    headers = ['Dataset', 'Scale', 'Coreset', 'OnlineKMeans', 'MiniBatchKMeans']

    # 定义数据集和数据量的固定顺序
    DATASET_ORDER = ['Huatuo', 'LiveChat', 'deep', 'glove', 'sift']
    SCALE_ORDER = ['100', '500', '1k', '2k', '5k']

    # 如果文件不存在，写入表头
    if not os.path.exists(csv_file):
        with open(csv_file, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(headers)

    # 收集所有结果，稍后按顺序写入
    all_results = []

    # 按固定顺序处理数据集
    for dataset_name in DATASET_ORDER:
        # 找到对应的数据集配置
        ds = next((d for d in datasets if Path(d["train"]).stem.split("_")[0] == dataset_name), None)
        if not ds:
            continue

        name = dataset_name
        dim = ds["dim"]
        train_data = pd.read_csv(ds["train"], header=None).values.astype(np.float32)

        # 准备不同大小的测试数据集
        test_files = {
            '100': ds["test1"],
            '500': ds["test2"],
            '1k': ds["test3"],
            '2k': ds["test4"],
            '5k': ds["test5"]
        }

        # 使用Lloyd方法初始化
        lloyd = Lloyd(d=dim, k=k, verbose=False)
        lloyd.train(train_data)
        base_centroids = lloyd.centroids.copy()
        base_labels = lloyd.labels_.copy()

        # 按数据量顺序处理
        for size in SCALE_ORDER:
            test_path = test_files.get(size)
            if not test_path or not os.path.exists(test_path):
                print(f"警告：测试文件 {test_path} 不存在，跳过")
                continue

            test_data = pd.read_csv(test_path, header=None).values.astype(np.float32)
            print(f"\n====== 数据集: {name} (dim={dim}) | 测试数据量: {size} ======")

            # 存储运行时间的字典
            runtime_results = {
                'Dataset': name,
                'Scale': size,
                'Coreset': 0,
                'OnlineKMeans': 0,
                'MiniBatchKMeans': 0
            }

            # 方法1: DynamicCoreset
            dyn = DynamicKMeansCoreset(d=dim, k=k, threshold=threshold, verbose=False)
            for i, vec in enumerate(train_data):
                dyn.insert(f"{name}_train_{i}", vec)
            dyn.model.centroids = base_centroids.copy()
            dyn.model.labels_ = base_labels.copy()

            t0 = time.perf_counter()
            for i, vec in enumerate(test_data):
                dyn.insert(f"{name}_test_{i}", vec)
            dyn.retrain()
            t1 = time.perf_counter()
            runtime_results['Coreset'] = t1 - t0

            # 方法2: OnlineKMeans
            online = OnlineKMeans(d=dim, k=k)
            online.centroids = base_centroids.copy()
            online.counts = np.bincount(base_labels, minlength=k)

            t2 = time.perf_counter()
            for vec in test_data:
                online.partial_fit(vec)
            t3 = time.perf_counter()
            runtime_results['OnlineKMeans'] = t3 - t2

            # 方法3: MiniBatchKMeans
            mini = MiniBatchKMeans(d=dim, k=k, batch_size=32, iters=10, verbose=False)
            mini.centroids = base_centroids.copy()
            mini.counts = np.bincount(base_labels, minlength=k)

            t4 = time.perf_counter()
            mini.partial_fit(test_data)
            t5 = time.perf_counter()
            runtime_results['MiniBatchKMeans'] = t5 - t4

            # 打印结果
            print(f"[Coreset] 时间: {runtime_results['Coreset']:.4f}s")
            print(f"[OnlineKMeans] 时间: {runtime_results['OnlineKMeans']:.4f}s")
            print(f"[MiniBatchKMeans] 时间: {runtime_results['MiniBatchKMeans']:.4f}s")

            # 添加到结果列表
            all_results.append([
                runtime_results['Dataset'],
                runtime_results['Scale'],
                runtime_results['Coreset'],
                runtime_results['OnlineKMeans'],
                runtime_results['MiniBatchKMeans']
            ])

    # 按顺序写入所有结果
    with open(csv_file, mode='w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(all_results)

if __name__ == '__main__':
    datasets = [
        {
            "train": "/content/sample_data/Huatuo_1024d_10k.csv",
            "test1": "/content/sample_data/Huatuo_1024d_100.csv",
            "test2": "/content/sample_data/Huatuo_1024d_500.csv",
            "test3": "/content/sample_data/Huatuo_1024d_1k.csv",
            "test4": "/content/sample_data/Huatuo_1024d_2k.csv",
            "test5": "/content/sample_data/Huatuo_1024d_5k.csv",
            "dim": 1024
        },
        {
            "train": "/content/sample_data/LiveChat_1024d_10k.csv",
            "test1": "/content/sample_data/LiveChat_1024d_100.csv",
            "test2": "/content/sample_data/LiveChat_1024d_500.csv",
            "test3": "/content/sample_data/LiveChat_1024d_1k.csv",
            "test4": "/content/sample_data/LiveChat_1024d_2k.csv",
            "test5": "/content/sample_data/LiveChat_1024d_5k.csv",
            "dim": 1024
        },
        {
            "train": "/content/sample_data/deep_96d_10k.csv",
            "test1": "/content/sample_data/deep_96d_100.csv",
            "test2": "/content/sample_data/deep_96d_500.csv",
            "test3": "/content/sample_data/deep_96d_1k.csv",
            "test4": "/content/sample_data/deep_96d_2k.csv",
            "test5": "/content/sample_data/deep_96d_5k.csv",
            "dim": 96
        },
        {
            "train": "/content/sample_data/glove_300d_10k.csv",
            "test1": "/content/sample_data/glove_300d_100.csv",
            "test2": "/content/sample_data/glove_300d_500.csv",
            "test3": "/content/sample_data/glove_300d_1k.csv",
            "test4": "/content/sample_data/glove_300d_2k.csv",
            "test5": "/content/sample_data/glove_300d_5k.csv",
            "dim": 300
        },
        {
            "train": "/content/sample_data/sift_128d_10k.csv",
            "test1": "/content/sample_data/sift_128d_100.csv",
            "test2": "/content/sample_data/sift_128d_500.csv",
            "test3": "/content/sample_data/sift_128d_1k.csv",
            "test4": "/content/sample_data/sift_128d_2k.csv",
            "test5": "/content/sample_data/sift_128d_5k.csv",
            "dim": 128
        }
    ]

    # run_test_update_only_comparison(datasets, k=10, threshold=200)
    run_test_update_comparison(datasets, k=10, threshold=200)
    print(f"\n结果已保存")


[Coreset] 时间: 0.0082s
[OnlineKMeans] 时间: 0.0058s
[MiniBatchKMeans] 时间: 0.0087s

[Coreset] 时间: 0.0344s
[OnlineKMeans] 时间: 0.0276s
[MiniBatchKMeans] 时间: 0.0098s

[Coreset] 时间: 0.0710s
[OnlineKMeans] 时间: 0.0580s
[MiniBatchKMeans] 时间: 0.0121s

[Coreset] 时间: 0.1445s
[OnlineKMeans] 时间: 0.2645s
[MiniBatchKMeans] 时间: 0.0133s

[Coreset] 时间: 0.3669s
[OnlineKMeans] 时间: 0.3070s
[MiniBatchKMeans] 时间: 0.0139s

[Coreset] 时间: 0.0211s
[OnlineKMeans] 时间: 0.0148s
[MiniBatchKMeans] 时间: 0.0143s

[Coreset] 时间: 0.0344s
[OnlineKMeans] 时间: 0.0265s
[MiniBatchKMeans] 时间: 0.0095s

[Coreset] 时间: 0.0697s
[OnlineKMeans] 时间: 0.0567s
[MiniBatchKMeans] 时间: 0.0111s

[Coreset] 时间: 0.1605s
[OnlineKMeans] 时间: 0.0999s
[MiniBatchKMeans] 时间: 0.0123s

[Coreset] 时间: 0.3420s
[OnlineKMeans] 时间: 0.3136s
[MiniBatchKMeans] 时间: 0.0150s

[Coreset] 时间: 0.0029s
[OnlineKMeans] 时间: 0.0031s
[MiniBatchKMeans] 时间: 0.0054s

[Coreset] 时间: 0.0113s
[OnlineKMeans] 时间: 0.0153s
[MiniBatchKMeans] 时间: 0.0057s

[Coreset] 时间: 0.0234s
[OnlineKMeans] 时间

In [None]:
# 更新的balance loss
!pip install faiss-cpu
import numpy as np
import pandas as pd
import faiss
import time
from pathlib import Path
import gc
from scipy.linalg import orth # For creating orthogonal matrices
from collections import defaultdict
from typing import Dict, Tuple, List, Optional
import csv
import os

class Lloyd:
    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1000000000,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False):
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.seed = seed
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        self.centroids = None
        self.obj_history_ = None
        self.labels_ = None

        self.sse_ = None
        self.balance_loss_ = None
        self.runtime_ = None
        self.obj = None

    def compute_centroids_from_data(self, data_points, labels, num_clusters, data_dim):
        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)
        counts = np.zeros(num_clusters, dtype=int)

        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in compute_centroids_from_data.")
            return centroids

        for i in range(len(data_points)):
            label = labels[i]
            centroids[label] += data_points[i]
            counts[label] += 1

        for j in range(num_clusters):
            if counts[j] > 0:
                centroids[j] /= counts[j]
            else:
                if self.verbose:
                    print(f"Warning: Cluster {j} is empty. Assigning random point.")
                centroids[j] = data_points[np.random.randint(len(data_points))]

        return centroids

    def train(self, x_orig_data, weights=None, init_centroids=None):
        start_time = time.time()
        np.random.seed(self.seed)

        if x_orig_data.shape[1] != self.d:
            raise ValueError(f"Input dimension {x_orig_data.shape[1]} != {self.d}")

        n, dim = x_orig_data.shape
        x = np.ascontiguousarray(x_orig_data, dtype='float32')

        kmeans = faiss.Kmeans(
            d=self.d,
            k=self.k,
            niter=self.niter,
            nredo=self.nredo,
            verbose=self.verbose,
            min_points_per_centroid=self.min_points_per_centroid,
            max_points_per_centroid=self.max_points_per_centroid,
            seed=self.seed,
            gpu=self.gpu,
            spherical=self.spherical,
            update_index=self.update_index,
            frozen_centroids=self.frozen_centroids
        )

        kmeans.train(x, init_centroids=init_centroids)

        _, self.labels_ = kmeans.index.search(x, 1)
        self.labels_ = self.labels_.flatten()

        self.centroids = kmeans.centroids
        self.obj_history_ = kmeans.obj if kmeans.obj is not None and len(kmeans.obj) > 0 else np.zeros(self.niter)
        self.obj = kmeans.obj[-1] if kmeans.obj is not None and len(kmeans.obj) > 0 else None
        self.runtime_ = time.time() - start_time


        # Print every 5th iteration's objective value
        if self.verbose and self.obj_history_ is not None and len(self.obj_history_) > 0:
            print("\n--- Objective Value (every 5 iterations) ---")
            for i, val in enumerate(self.obj_history_):
                if (i + 1) % 5 == 0 or i == len(self.obj_history_) - 1:
                    print(f"  Iter {i+1:2d}: {val:.6f}")

        final_sse = 0
        for i in range(n):
            cluster_idx = self.labels_[i]
            final_sse += np.sum((x_orig_data[i] - self.centroids[cluster_idx]) ** 2)
        self.sse_ = final_sse

        final_balance_loss = 0
        if self.labels_ is not None:
            sizes = np.bincount(self.labels_, minlength=self.k)
            ideal = n / self.k
            final_balance_loss = np.sum((sizes - ideal) ** 2)
        self.balance_loss_ = final_balance_loss

        if self.verbose:
            print(f"Lloyd training finished in {self.runtime_:.4f}s")
            print(f"Final obj: {self.obj}")
            print(f"Cluster sizes: {dict(zip(*np.unique(self.labels_, return_counts=True)))}")
            print(f"SSE: {self.sse_:.4f}")
            print(f"Balance Loss: {self.balance_loss_:.4f}")

class DynamicKMeansCoreset:
    def __init__(self, d, k, threshold=100, **lloyd_kwargs):
        self.d = d
        self.k = k
        self.threshold = threshold
        self.lloyd_kwargs = lloyd_kwargs

        self.buffer = [] # 临时存储插入/删除的点ID
        self.coreset_tree = defaultdict(list) # 分层存储核心集的数据结构
        self.active_points: Dict[str, Tuple[np.ndarray, float]] = {} # 当前活跃的点(未删除)
        self.deleted_points = set() # 已删除的点ID集合
        self.model = Lloyd(d=d, k=k, **lloyd_kwargs)

    def insert(self, point_id: str, vector: np.ndarray, weight: float = 1.0):
        self.active_points[point_id] = (np.array(vector, dtype=np.float32), weight)
        self.buffer.append(point_id)
        if len(self.buffer) >= self.threshold:
            self._add_coreset(self.buffer)
            self.buffer.clear()

    def delete(self, point_id: str):
        if point_id in self.active_points:
            self.deleted_points.add(point_id)
            self.buffer.append(point_id)
            if len(self.buffer) >= self.threshold:
                self._add_coreset(self.buffer)
                self.buffer.clear()

    def _add_coreset(self, point_ids: List[str]):
        X, weights = [], []
        for pid in point_ids:
            if pid in self.active_points and pid not in self.deleted_points:
                vec, w = self.active_points[pid]
                X.append(vec)
                weights.append(w)

        if not X:
            return

        X = np.vstack(X).astype(np.float32)  # 将向量堆叠成矩阵
        weights = np.array(weights, dtype=np.float32)  # 将权重转为numpy数组
        self._merge_into_tree(0, (X, weights))  # 将新核心集合并到树中

    def _merge_into_tree(self, level: int, new_coreset: Tuple[np.ndarray, np.ndarray]):
        self.coreset_tree[level].append(new_coreset)
        if len(self.coreset_tree[level]) > 1:
            X1, w1 = self.coreset_tree[level].pop()
            X2, w2 = self.coreset_tree[level].pop()
            X = np.vstack([X1, X2])
            W = np.concatenate([w1, w2])
            reducer = Lloyd(d=self.d, k=self.k, **self.lloyd_kwargs)
            reducer.train(X, weights=W)
            reduced_X = reducer.centroids
            reduced_W = np.bincount(reducer.labels_, minlength=self.k).astype(np.float32)
            self._merge_into_tree(level + 1, (reduced_X, reduced_W))

    def retrain(self):
        if self.buffer:
            self._add_coreset(self.buffer)
            self.buffer.clear()
        X_all, W_all = [], []
        for level in self.coreset_tree:
            for X, W in self.coreset_tree[level]:
                X_all.append(X)
                W_all.append(W)
        if not X_all:
            return
        X = np.vstack(X_all)
        W = np.concatenate(W_all)
        mask = W > 0
        X = X[mask]
        W = W[mask]
        self.model.train(X, weights=W)

    def get_centroids(self) -> Optional[np.ndarray]:
        return self.model.centroids

    def get_labels(self) -> Optional[np.ndarray]:
        return self.model.labels_

    def get_metrics(self) -> Dict[str, float]:
        return {
            "sse": self.model.sse_,
            "balance_loss": self.model.balance_loss_,
            "runtime": self.model.runtime_,
            "objective": self.model.obj
        }

class OnlineKMeans:
    def __init__(self, d, k, init_method='split', verbose=False, seed=0):
        self.d, self.k = d, k
        self.verbose = verbose
        np.random.seed(seed)
        self.centroids = None
        self.counts = None  # 每簇访问计数

    def initialize(self, first_point):
        self.centroids = first_point.reshape(1, -1)
        self.counts = np.array([1], dtype=int)

    def maybe_split(self):
        # 找方差最大簇，二分裂
        if self.centroids.shape[0] < self.k:
            variances = []
            for i in range(len(self.counts)):
                variances.append(np.random.rand())  # 实际应估方差，这里 placeholder
            j = np.argmax(variances)
            c = self.centroids[j]
            delta = np.random.randn(self.d) * 1e-3
            new_c1, new_c2 = c + delta, c - delta
            self.centroids[j] = new_c1
            self.centroids = np.vstack([self.centroids, new_c2])
            self.counts[j] = self.counts[j] // 2
            self.counts = np.append(self.counts, self.counts[j])
            if self.verbose:
                print(f"Split centroid {j} into 2; now {self.centroids.shape[0]} total.")

    def partial_fit(self, x_new):
        if self.centroids is None:
            self.initialize(x_new)
            return
        # 找最近质心
        dists = np.linalg.norm(self.centroids - x_new, axis=1)
        j = np.argmin(dists)
        self.counts[j] += 1
        eta = 1.0 / self.counts[j]
        self.centroids[j] += eta * (x_new - self.centroids[j])
        # 尝试分裂
        if self.centroids.shape[0] < self.k:
            self.maybe_split()

    def get_balance_loss(self):
        if self.counts is None or len(self.counts) == 0:
            return 0.0
        total = np.sum(self.counts)
        ideal = total / self.k
        return np.sum((self.counts - ideal) ** 2)

class MiniBatchKMeans(Lloyd):
    def __init__(self, d, k, batch_size=256, iters=100, **kwargs):
        super().__init__(d=d, k=k, **kwargs)
        self.batch_size = batch_size
        self.iters = iters

    def partial_fit(self, X_stream):
        # X_stream: shape (n_stream, d)
        # 初始化质心（如果为空则从 stream 抽 batch 初始化）
        if self.centroids is None:
            init = X_stream[np.random.choice(len(X_stream), self.k, replace=False)]
            self.centroids = init
            self.counts = np.zeros(self.k, dtype=int)

        for i in range(self.iters):
            idx = np.random.choice(len(X_stream), self.batch_size, replace=False)
            batch = X_stream[idx]
            index = faiss.IndexFlatL2(self.d)
            index.add(self.centroids.astype(np.float32))
            _, labels = index.search(batch.astype(np.float32), 1)

            labels = labels.flatten()
            for j in range(self.k):
                mask = (labels == j)
                m = np.sum(mask)
                if m > 0:
                    centroid = self.centroids[j]
                    mean_batch = batch[mask].mean(axis=0)
                    eta = 1.0 / (self.counts[j] + m)
                    self.centroids[j] = (1 - eta) * centroid + eta * mean_batch
                    self.counts[j] += m
        # 设置 labels 未定义，balance/SSE 可后续重算

    def get_balance_loss(self):
        if getattr(self, 'counts', None) is None:
            return None
        ideal = self.counts.sum() / self.k
        return np.sum((self.counts - ideal)**2)

def evaluate_test_sse(test_data, centroids):
    dists = np.linalg.norm(test_data[:, None, :] - centroids[None, :, :], axis=2)
    nearest = np.argmin(dists, axis=1)
    return np.sum((test_data - centroids[nearest])**2)


def run_test_update_comparison(datasets, k=10, threshold=200):
    # 准备CSV文件存储结果
    csv_file = 'balance_loss_comparison.csv'
    headers = ['Dataset', 'Scale', 'Coreset', 'OnlineKMeans', 'MiniBatch']

    # 定义数据集和数据量的固定顺序
    DATASET_ORDER = ['Huatuo', 'LiveChat', 'deep', 'glove', 'sift']
    SCALE_ORDER = ['100', '500', '1k', '2k', '5k']

    # 如果文件不存在，写入表头
    if not os.path.exists(csv_file):
        with open(csv_file, mode='w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(headers)

    # 收集所有结果
    all_results = []

    # 按固定顺序处理数据集
    for dataset_name in DATASET_ORDER:
        # 找到对应的数据集配置
        ds = next((d for d in datasets if Path(d["train"]).stem.split("_")[0] == dataset_name), None)
        if not ds:
            continue

        name = dataset_name
        dim = ds["dim"]
        train_data = pd.read_csv(ds["train"], header=None).values.astype(np.float32)

        # 准备不同大小的测试数据集
        test_files = {
            '100': ds["test1"],
            '500': ds["test2"],
            '1k': ds["test3"],
            '2k': ds["test4"],
            '5k': ds["test5"]
        }

        # 使用Lloyd方法初始化获取基础平衡损失
        lloyd = Lloyd(d=dim, k=k, verbose=False)
        lloyd.train(train_data)
        base_balance_loss = lloyd.balance_loss_
        base_centroids = lloyd.centroids.copy()
        base_labels = lloyd.labels_.copy()

        # 按数据量顺序处理
        for size in SCALE_ORDER:
            test_path = test_files.get(size)
            if not test_path or not os.path.exists(test_path):
                print(f"警告：测试文件 {test_path} 不存在，跳过")
                continue

            test_data = pd.read_csv(test_path, header=None).values.astype(np.float32)
            print(f"\n====== 数据集: {name} (维度={dim}) | 测试数据量: {size} ======")

            # 存储结果的字典
            results = {
                'Dataset': name,
                'Scale': size,
                'Coreset': 0,
                'OnlineKMeans': 0,
                'MiniBatch': 0
            }

            # 方法1: DynamicCoreset
            dyn = DynamicKMeansCoreset(d=dim, k=k, threshold=threshold, verbose=False)
            # 插入训练数据
            for i, vec in enumerate(train_data):
                dyn.insert(f"{name}_train_{i}", vec)
            # 插入测试数据
            for i, vec in enumerate(test_data):
                dyn.insert(f"{name}_test_{i}", vec)
            dyn.retrain()

            # 计算整体平衡损失（训练+测试）
            combined_data = np.vstack([train_data, test_data])
            dists = np.linalg.norm(combined_data[:, None, :] - dyn.model.centroids[None, :, :], axis=2)
            labels = np.argmin(dists, axis=1)
            sizes = np.bincount(labels, minlength=k)
            ideal = len(combined_data) / k
            results['Coreset'] = np.sum((sizes - ideal) ** 2)

            # 方法2: OnlineKMeans
            online = OnlineKMeans(d=dim, k=k)
            online.centroids = base_centroids.copy()
            online.counts = np.bincount(base_labels, minlength=k)

            # 先用训练数据更新
            for vec in train_data:
                online.partial_fit(vec)
            # 再用测试数据更新
            for vec in test_data:
                online.partial_fit(vec)

            # 计算整体平衡损失
            dists = np.linalg.norm(combined_data[:, None, :] - online.centroids[None, :, :], axis=2)
            labels = np.argmin(dists, axis=1)
            sizes = np.bincount(labels, minlength=k)
            results['OnlineKMeans'] = np.sum((sizes - ideal) ** 2)

            # 方法3: MiniBatchKMeans
            mini = MiniBatchKMeans(d=dim, k=k, batch_size=32, iters=10, verbose=False)
            mini.centroids = base_centroids.copy()
            mini.counts = np.bincount(base_labels, minlength=k)

            # 先用训练数据更新
            mini.partial_fit(train_data)
            # 再用测试数据更新
            mini.partial_fit(test_data)

            # 计算整体平衡损失
            dists = np.linalg.norm(combined_data[:, None, :] - mini.centroids[None, :, :], axis=2)
            labels = np.argmin(dists, axis=1)
            sizes = np.bincount(labels, minlength=k)
            results['MiniBatch'] = np.sum((sizes - ideal) ** 2)

            # 打印结果
            print(f"初始平衡损失: {base_balance_loss:.2f}")
            print(f"[Coreset] 平衡损失: {results['Coreset']:.2f}")
            print(f"[OnlineKMeans] 平衡损失: {results['OnlineKMeans']:.2f}")
            print(f"[MiniBatchKMeans] 平衡损失: {results['MiniBatch']:.2f}")

            # 添加到结果列表
            all_results.append([
                results['Dataset'],
                results['Scale'],
                results['Coreset'],
                results['OnlineKMeans'],
                results['MiniBatch']
            ])

    # 按顺序写入所有结果
    with open(csv_file, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(all_results)

if __name__ == '__main__':
    datasets = [
        {
            "train": "/content/sample_data/Huatuo_1024d_10k.csv",
            "test1": "/content/sample_data/Huatuo_1024d_100.csv",
            "test2": "/content/sample_data/Huatuo_1024d_500.csv",
            "test3": "/content/sample_data/Huatuo_1024d_1k.csv",
            "test4": "/content/sample_data/Huatuo_1024d_2k.csv",
            "test5": "/content/sample_data/Huatuo_1024d_5k.csv",
            "dim": 1024
        },
        {
            "train": "/content/sample_data/LiveChat_1024d_10k.csv",
            "test1": "/content/sample_data/LiveChat_1024d_100.csv",
            "test2": "/content/sample_data/LiveChat_1024d_500.csv",
            "test3": "/content/sample_data/LiveChat_1024d_1k.csv",
            "test4": "/content/sample_data/LiveChat_1024d_2k.csv",
            "test5": "/content/sample_data/LiveChat_1024d_5k.csv",
            "dim": 1024
        },
        {
            "train": "/content/sample_data/deep_96d_10k.csv",
            "test1": "/content/sample_data/deep_96d_100.csv",
            "test2": "/content/sample_data/deep_96d_500.csv",
            "test3": "/content/sample_data/deep_96d_1k.csv",
            "test4": "/content/sample_data/deep_96d_2k.csv",
            "test5": "/content/sample_data/deep_96d_5k.csv",
            "dim": 96
        },
        {
            "train": "/content/sample_data/glove_300d_10k.csv",
            "test1": "/content/sample_data/glove_300d_100.csv",
            "test2": "/content/sample_data/glove_300d_500.csv",
            "test3": "/content/sample_data/glove_300d_1k.csv",
            "test4": "/content/sample_data/glove_300d_2k.csv",
            "test5": "/content/sample_data/glove_300d_5k.csv",
            "dim": 300
        },
        {
            "train": "/content/sample_data/sift_128d_10k.csv",
            "test1": "/content/sample_data/sift_128d_100.csv",
            "test2": "/content/sample_data/sift_128d_500.csv",
            "test3": "/content/sample_data/sift_128d_1k.csv",
            "test4": "/content/sample_data/sift_128d_2k.csv",
            "test5": "/content/sample_data/sift_128d_5k.csv",
            "dim": 128
        }
    ]

    run_test_update_comparison(datasets, k=10, threshold=200)
    print(f"\n结果已保存")


初始平衡损失: 837824.90
[Coreset] 平衡损失: 18520505.60
[OnlineKMeans] 平衡损失: 861825.60
[MiniBatchKMeans] 平衡损失: 860715.60

初始平衡损失: 837824.90
[Coreset] 平衡损失: 21121399.60
[OnlineKMeans] 平衡损失: 937327.60
[MiniBatchKMeans] 平衡损失: 945901.60

初始平衡损失: 837824.90
[Coreset] 平衡损失: 4364881.60
[OnlineKMeans] 平衡损失: 1027969.60
[MiniBatchKMeans] 平衡损失: 1044507.60

初始平衡损失: 837824.90
[Coreset] 平衡损失: 5409673.60
[OnlineKMeans] 平衡损失: 1242217.60
[MiniBatchKMeans] 平衡损失: 1268505.60

初始平衡损失: 837824.90
[Coreset] 平衡损失: 10848703.60
[OnlineKMeans] 平衡损失: 1935787.60
[MiniBatchKMeans] 平衡损失: 1955811.60

初始平衡损失: 1284214.90
[Coreset] 平衡损失: 88950887.60
[OnlineKMeans] 平衡损失: 1336655.60
[MiniBatchKMeans] 平衡损失: 1324987.60

初始平衡损失: 1284214.90
[Coreset] 平衡损失: 46579175.60
[OnlineKMeans] 平衡损失: 1467303.60
[MiniBatchKMeans] 平衡损失: 1437917.60

初始平衡损失: 1284214.90
[Coreset] 平衡损失: 60329683.60
[OnlineKMeans] 平衡损失: 1613771.60
[MiniBatchKMeans] 平衡损失: 1588573.60

初始平衡损失: 1284214.90
[Coreset] 平衡损失: 51309661.60
[OnlineKMeans] 平衡损失: 1938925.60
[MiniBatchK

In [None]:
!pip install faiss-cpu
import numpy as np
import pandas as pd
import faiss
import time
from pathlib import Path
import gc
from scipy.linalg import orth # For creating orthogonal matrices
from collections import defaultdict
from typing import Dict, Tuple, List, Optional
import csv
import os

class Lloyd:
    def __init__(self, d, k, niter=25, nredo=1, verbose=True,
                 min_points_per_centroid=1, max_points_per_centroid=1000000000,
                 seed=1234, gpu=False, spherical=False,
                 update_index=True, frozen_centroids=False):
        self.d = d
        self.k = k
        self.niter = niter
        self.nredo = nredo
        self.verbose = verbose
        self.min_points_per_centroid = min_points_per_centroid
        self.max_points_per_centroid = max_points_per_centroid
        self.seed = seed
        self.gpu = gpu
        self.spherical = spherical
        self.update_index = update_index
        self.frozen_centroids = frozen_centroids

        self.centroids = None
        self.obj_history_ = None
        self.labels_ = None

        self.sse_ = None
        self.balance_loss_ = None
        self.runtime_ = None
        self.obj = None

    def compute_centroids_from_data(self, data_points, labels, num_clusters, data_dim):
        centroids = np.zeros((num_clusters, data_dim), dtype=np.float32)
        counts = np.zeros(num_clusters, dtype=int)

        if labels is None:
            if self.verbose:
                print("Warning: Labels are None in compute_centroids_from_data.")
            return centroids

        for i in range(len(data_points)):
            label = labels[i]
            centroids[label] += data_points[i]
            counts[label] += 1

        for j in range(num_clusters):
            if counts[j] > 0:
                centroids[j] /= counts[j]
            else:
                if self.verbose:
                    print(f"Warning: Cluster {j} is empty. Assigning random point.")
                centroids[j] = data_points[np.random.randint(len(data_points))]

        return centroids

    def train(self, x_orig_data, weights=None, init_centroids=None):
        start_time = time.time()
        np.random.seed(self.seed)

        if x_orig_data.shape[1] != self.d:
            raise ValueError(f"Input dimension {x_orig_data.shape[1]} != {self.d}")

        n, dim = x_orig_data.shape
        x = np.ascontiguousarray(x_orig_data, dtype='float32')

        kmeans = faiss.Kmeans(
            d=self.d,
            k=self.k,
            niter=self.niter,
            nredo=self.nredo,
            verbose=self.verbose,
            min_points_per_centroid=self.min_points_per_centroid,
            max_points_per_centroid=self.max_points_per_centroid,
            seed=self.seed,
            gpu=self.gpu,
            spherical=self.spherical,
            update_index=self.update_index,
            frozen_centroids=self.frozen_centroids
        )

        kmeans.train(x, init_centroids=init_centroids)

        _, self.labels_ = kmeans.index.search(x, 1)
        self.labels_ = self.labels_.flatten()

        self.centroids = kmeans.centroids
        self.obj_history_ = kmeans.obj if kmeans.obj is not None and len(kmeans.obj) > 0 else np.zeros(self.niter)
        self.obj = kmeans.obj[-1] if kmeans.obj is not None and len(kmeans.obj) > 0 else None
        self.runtime_ = time.time() - start_time


        # Print every 5th iteration's objective value
        if self.verbose and self.obj_history_ is not None and len(self.obj_history_) > 0:
            print("\n--- Objective Value (every 5 iterations) ---")
            for i, val in enumerate(self.obj_history_):
                if (i + 1) % 5 == 0 or i == len(self.obj_history_) - 1:
                    print(f"  Iter {i+1:2d}: {val:.6f}")

        final_sse = 0
        for i in range(n):
            cluster_idx = self.labels_[i]
            final_sse += np.sum((x_orig_data[i] - self.centroids[cluster_idx]) ** 2)
        self.sse_ = final_sse

        final_balance_loss = 0
        if self.labels_ is not None:
            sizes = np.bincount(self.labels_, minlength=self.k)
            ideal = n / self.k
            final_balance_loss = np.sum((sizes - ideal) ** 2)
        self.balance_loss_ = final_balance_loss

        if self.verbose:
            print(f"Lloyd training finished in {self.runtime_:.4f}s")
            print(f"Final obj: {self.obj}")
            print(f"Cluster sizes: {dict(zip(*np.unique(self.labels_, return_counts=True)))}")
            print(f"SSE: {self.sse_:.4f}")
            print(f"Balance Loss: {self.balance_loss_:.4f}")

class DynamicKMeansCoreset:
    def __init__(self, d, k, threshold=100, **lloyd_kwargs):
        self.d = d
        self.k = k
        self.threshold = threshold
        self.lloyd_kwargs = lloyd_kwargs

        self.buffer = [] # 临时存储插入/删除的点ID
        self.coreset_tree = defaultdict(list) # 分层存储核心集的数据结构
        self.active_points: Dict[str, Tuple[np.ndarray, float]] = {} # 当前活跃的点(未删除)
        self.deleted_points = set() # 已删除的点ID集合
        self.model = Lloyd(d=d, k=k, **lloyd_kwargs)

    def insert(self, point_id: str, vector: np.ndarray, weight: float = 1.0):
        self.active_points[point_id] = (np.array(vector, dtype=np.float32), weight)
        self.buffer.append(point_id)
        if len(self.buffer) >= self.threshold:
            self._add_coreset(self.buffer)
            self.buffer.clear()

    def delete(self, point_id: str):
        if point_id in self.active_points:
            self.deleted_points.add(point_id)
            self.buffer.append(point_id)
            if len(self.buffer) >= self.threshold:
                self._add_coreset(self.buffer)
                self.buffer.clear()

    def _add_coreset(self, point_ids: List[str]):
        X, weights = [], []
        for pid in point_ids:
            if pid in self.active_points and pid not in self.deleted_points:
                vec, w = self.active_points[pid]
                X.append(vec)
                weights.append(w)

        if not X:
            return

        X = np.vstack(X).astype(np.float32)  # 将向量堆叠成矩阵
        weights = np.array(weights, dtype=np.float32)  # 将权重转为numpy数组
        self._merge_into_tree(0, (X, weights))  # 将新核心集合并到树中

    def _merge_into_tree(self, level: int, new_coreset: Tuple[np.ndarray, np.ndarray]):
        self.coreset_tree[level].append(new_coreset)
        if len(self.coreset_tree[level]) > 1:
            X1, w1 = self.coreset_tree[level].pop()
            X2, w2 = self.coreset_tree[level].pop()
            X = np.vstack([X1, X2])
            W = np.concatenate([w1, w2])
            reducer = Lloyd(d=self.d, k=self.k, **self.lloyd_kwargs)
            reducer.train(X, weights=W)
            reduced_X = reducer.centroids
            reduced_W = np.bincount(reducer.labels_, minlength=self.k).astype(np.float32)
            self._merge_into_tree(level + 1, (reduced_X, reduced_W))

    def retrain(self):
        if self.buffer:
            self._add_coreset(self.buffer)
            self.buffer.clear()
        X_all, W_all = [], []
        for level in self.coreset_tree:
            for X, W in self.coreset_tree[level]:
                X_all.append(X)
                W_all.append(W)
        if not X_all:
            return
        X = np.vstack(X_all)
        W = np.concatenate(W_all)
        mask = W > 0
        X = X[mask]
        W = W[mask]
        self.model.train(X, weights=W)

    def get_centroids(self) -> Optional[np.ndarray]:
        return self.model.centroids

    def get_labels(self) -> Optional[np.ndarray]:
        return self.model.labels_

    def get_metrics(self) -> Dict[str, float]:
        return {
            "sse": self.model.sse_,
            "balance_loss": self.model.balance_loss_,
            "runtime": self.model.runtime_,
            "objective": self.model.obj
        }

class OnlineKMeans:
    def __init__(self, d, k, init_method='split', verbose=False, seed=0):
        self.d, self.k = d, k
        self.verbose = verbose
        np.random.seed(seed)
        self.centroids = None
        self.counts = None  # 每簇访问计数

    def initialize(self, first_point):
        self.centroids = first_point.reshape(1, -1)
        self.counts = np.array([1], dtype=int)

    def maybe_split(self):
        # 找方差最大簇，二分裂
        if self.centroids.shape[0] < self.k:
            variances = []
            for i in range(len(self.counts)):
                variances.append(np.random.rand())  # 实际应估方差，这里 placeholder
            j = np.argmax(variances)
            c = self.centroids[j]
            delta = np.random.randn(self.d) * 1e-3
            new_c1, new_c2 = c + delta, c - delta
            self.centroids[j] = new_c1
            self.centroids = np.vstack([self.centroids, new_c2])
            self.counts[j] = self.counts[j] // 2
            self.counts = np.append(self.counts, self.counts[j])
            if self.verbose:
                print(f"Split centroid {j} into 2; now {self.centroids.shape[0]} total.")

    def partial_fit(self, x_new):
        if self.centroids is None:
            self.initialize(x_new)
            return
        # 找最近质心
        dists = np.linalg.norm(self.centroids - x_new, axis=1)
        j = np.argmin(dists)
        self.counts[j] += 1
        eta = 1.0 / self.counts[j]
        self.centroids[j] += eta * (x_new - self.centroids[j])
        # 尝试分裂
        if self.centroids.shape[0] < self.k:
            self.maybe_split()

    def get_balance_loss(self):
        if self.counts is None or len(self.counts) == 0:
            return 0.0
        total = np.sum(self.counts)
        ideal = total / self.k
        return np.sum((self.counts - ideal) ** 2)

class MiniBatchKMeans(Lloyd):
    def __init__(self, d, k, batch_size=256, iters=100, **kwargs):
        super().__init__(d=d, k=k, **kwargs)
        self.batch_size = batch_size
        self.iters = iters

    def partial_fit(self, X_stream):
        # X_stream: shape (n_stream, d)
        # 初始化质心（如果为空则从 stream 抽 batch 初始化）
        if self.centroids is None:
            init = X_stream[np.random.choice(len(X_stream), self.k, replace=False)]
            self.centroids = init
            self.counts = np.zeros(self.k, dtype=int)

        for i in range(self.iters):
            idx = np.random.choice(len(X_stream), self.batch_size, replace=False)
            batch = X_stream[idx]
            index = faiss.IndexFlatL2(self.d)
            index.add(self.centroids.astype(np.float32))
            _, labels = index.search(batch.astype(np.float32), 1)

            labels = labels.flatten()
            for j in range(self.k):
                mask = (labels == j)
                m = np.sum(mask)
                if m > 0:
                    centroid = self.centroids[j]
                    mean_batch = batch[mask].mean(axis=0)
                    eta = 1.0 / (self.counts[j] + m)
                    self.centroids[j] = (1 - eta) * centroid + eta * mean_batch
                    self.counts[j] += m
        # 设置 labels 未定义，balance/SSE 可后续重算

    def get_balance_loss(self):
        if getattr(self, 'counts', None) is None:
            return None
        ideal = self.counts.sum() / self.k
        return np.sum((self.counts - ideal)**2)

def evaluate_test_sse(test_data, centroids):
    dists = np.linalg.norm(test_data[:, None, :] - centroids[None, :, :], axis=2)
    nearest = np.argmin(dists, axis=1)
    return np.sum((test_data - centroids[nearest])**2)


def run_test_update_comparison(datasets, k=10, threshold=200):
    # 准备CSV文件存储结果
    csv_file = 'sse_comparison.csv'
    headers = ['Dataset', 'Scale', 'Coreset', 'OnlineKMeans', 'MiniBatch']

    # 定义数据集和数据量的固定顺序
    DATASET_ORDER = ['Huatuo', 'LiveChat', 'deep', 'glove', 'sift']
    SCALE_ORDER = ['100', '500', '1k', '2k', '5k']

    # 如果文件不存在，写入表头
    if not os.path.exists(csv_file):
        with open(csv_file, mode='w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(headers)

    # 收集所有结果
    all_results = []

    # 按固定顺序处理数据集
    for dataset_name in DATASET_ORDER:
        # 找到对应的数据集配置
        ds = next((d for d in datasets if Path(d["train"]).stem.split("_")[0] == dataset_name), None)
        if not ds:
            continue

        name = dataset_name
        dim = ds["dim"]
        train_data = pd.read_csv(ds["train"], header=None).values.astype(np.float32)

        # 准备不同大小的测试数据集
        test_files = {
            '100': ds["test1"],
            '500': ds["test2"],
            '1k': ds["test3"],
            '2k': ds["test4"],
            '5k': ds["test5"]
        }

        # 使用Lloyd方法初始化获取基础平衡损失
        lloyd = Lloyd(d=dim, k=k, verbose=False)
        lloyd.train(train_data)
        base_balance_loss = lloyd.balance_loss_
        base_centroids = lloyd.centroids.copy()
        base_labels = lloyd.labels_.copy()

        # 按数据量顺序处理
        for size in SCALE_ORDER:
            test_path = test_files.get(size)
            if not test_path or not os.path.exists(test_path):
                print(f"警告：测试文件 {test_path} 不存在，跳过")
                continue

            test_data = pd.read_csv(test_path, header=None).values.astype(np.float32)
            print(f"\n====== 数据集: {name} (维度={dim}) | 测试数据量: {size} ======")

            # 存储结果的字典
            results = {
                'Dataset': name,
                'Scale': size,
                'Coreset': 0,
                'OnlineKMeans': 0,
                'MiniBatch': 0
            }

            # 方法1: DynamicCoreset
            dyn = DynamicKMeansCoreset(d=dim, k=k, threshold=threshold, verbose=False)
            # 插入训练数据
            for i, vec in enumerate(train_data):
                dyn.insert(f"{name}_train_{i}", vec)
            # 插入测试数据
            for i, vec in enumerate(test_data):
                dyn.insert(f"{name}_test_{i}", vec)
            dyn.retrain()

            # 计算整体SSE（训练+测试）
            combined_data = np.vstack([train_data, test_data])
            dists = np.linalg.norm(combined_data[:, None, :] - dyn.model.centroids[None, :, :], axis=2)
            nearest_centroids = np.argmin(dists, axis=1)
            sse = np.sum((combined_data - dyn.model.centroids[nearest_centroids]) ** 2)
            results['Coreset'] = sse

            # 方法2: OnlineKMeans
            online = OnlineKMeans(d=dim, k=k)
            online.centroids = base_centroids.copy()
            online.counts = np.bincount(base_labels, minlength=k)

            # 先用训练数据更新
            for vec in train_data:
                online.partial_fit(vec)
            # 再用测试数据更新
            for vec in test_data:
                online.partial_fit(vec)

            # 计算整体SSE
            dists = np.linalg.norm(combined_data[:, None, :] - online.centroids[None, :, :], axis=2)
            nearest_centroids = np.argmin(dists, axis=1)
            sse = np.sum((combined_data - online.centroids[nearest_centroids]) ** 2)
            results['OnlineKMeans'] = sse

            # 方法3: MiniBatchKMeans
            mini = MiniBatchKMeans(d=dim, k=k, batch_size=32, iters=10, verbose=False)
            mini.centroids = base_centroids.copy()
            mini.counts = np.bincount(base_labels, minlength=k)

            # 先用训练数据更新
            mini.partial_fit(train_data)
            # 再用测试数据更新
            mini.partial_fit(test_data)

            # 计算整体SSE
            dists = np.linalg.norm(combined_data[:, None, :] - mini.centroids[None, :, :], axis=2)
            nearest_centroids = np.argmin(dists, axis=1)
            sse = np.sum((combined_data - mini.centroids[nearest_centroids]) ** 2)
            results['MiniBatch'] = sse

            # 打印结果
            print(f"初始SSE: {base_balance_loss:.2f}")
            print(f"[Coreset] SSE: {results['Coreset']:.2f}")
            print(f"[OnlineKMeans] SSE: {results['OnlineKMeans']:.2f}")
            print(f"[MiniBatchKMeans] SSE: {results['MiniBatch']:.2f}")

            # 添加到结果列表
            all_results.append([
                results['Dataset'],
                results['Scale'],
                results['Coreset'],
                results['OnlineKMeans'],
                results['MiniBatch']
            ])

    # 按顺序写入所有结果
    with open(csv_file, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(all_results)

if __name__ == '__main__':
    datasets = [
        {
            "train": "/content/sample_data/Huatuo_1024d_10k.csv",
            "test1": "/content/sample_data/Huatuo_1024d_100.csv",
            "test2": "/content/sample_data/Huatuo_1024d_500.csv",
            "test3": "/content/sample_data/Huatuo_1024d_1k.csv",
            "test4": "/content/sample_data/Huatuo_1024d_2k.csv",
            "test5": "/content/sample_data/Huatuo_1024d_5k.csv",
            "dim": 1024
        },
        {
            "train": "/content/sample_data/LiveChat_1024d_10k.csv",
            "test1": "/content/sample_data/LiveChat_1024d_100.csv",
            "test2": "/content/sample_data/LiveChat_1024d_500.csv",
            "test3": "/content/sample_data/LiveChat_1024d_1k.csv",
            "test4": "/content/sample_data/LiveChat_1024d_2k.csv",
            "test5": "/content/sample_data/LiveChat_1024d_5k.csv",
            "dim": 1024
        },
        {
            "train": "/content/sample_data/deep_96d_10k.csv",
            "test1": "/content/sample_data/deep_96d_100.csv",
            "test2": "/content/sample_data/deep_96d_500.csv",
            "test3": "/content/sample_data/deep_96d_1k.csv",
            "test4": "/content/sample_data/deep_96d_2k.csv",
            "test5": "/content/sample_data/deep_96d_5k.csv",
            "dim": 96
        },
        {
            "train": "/content/sample_data/glove_300d_10k.csv",
            "test1": "/content/sample_data/glove_300d_100.csv",
            "test2": "/content/sample_data/glove_300d_500.csv",
            "test3": "/content/sample_data/glove_300d_1k.csv",
            "test4": "/content/sample_data/glove_300d_2k.csv",
            "test5": "/content/sample_data/glove_300d_5k.csv",
            "dim": 300
        },
        {
            "train": "/content/sample_data/sift_128d_10k.csv",
            "test1": "/content/sample_data/sift_128d_100.csv",
            "test2": "/content/sample_data/sift_128d_500.csv",
            "test3": "/content/sample_data/sift_128d_1k.csv",
            "test4": "/content/sample_data/sift_128d_2k.csv",
            "test5": "/content/sample_data/sift_128d_5k.csv",
            "dim": 128
        }
    ]

    run_test_update_comparison(datasets, k=10, threshold=200)
    print(f"\n结果已保存")


初始SSE: 837824.90
[Coreset] SSE: 2764.94
[OnlineKMeans] SSE: 2502.68
[MiniBatchKMeans] SSE: 2502.71

初始SSE: 837824.90
[Coreset] SSE: 2897.14
[OnlineKMeans] SSE: 2601.31
[MiniBatchKMeans] SSE: 2601.40

初始SSE: 837824.90
[Coreset] SSE: 2868.90
[OnlineKMeans] SSE: 2726.15
[MiniBatchKMeans] SSE: 2726.32

初始SSE: 837824.90
[Coreset] SSE: 3146.81
[OnlineKMeans] SSE: 2973.99
[MiniBatchKMeans] SSE: 2974.35

初始SSE: 837824.90
[Coreset] SSE: 3882.81
[OnlineKMeans] SSE: 3716.28
[MiniBatchKMeans] SSE: 3717.31

初始SSE: 1284214.90
[Coreset] SSE: 2074.40
[OnlineKMeans] SSE: 1879.87
[MiniBatchKMeans] SSE: 1879.90

初始SSE: 1284214.90
[Coreset] SSE: 2134.75
[OnlineKMeans] SSE: 1953.95
[MiniBatchKMeans] SSE: 1954.02

初始SSE: 1284214.90
[Coreset] SSE: 2222.50
[OnlineKMeans] SSE: 2045.94
[MiniBatchKMeans] SSE: 2046.09

初始SSE: 1284214.90
[Coreset] SSE: 2418.15
[OnlineKMeans] SSE: 2232.59
[MiniBatchKMeans] SSE: 2232.92

初始SSE: 1284214.90
[Coreset] SSE: 3024.90
[OnlineKMeans] SSE: 2792.67
[MiniBatchKMeans] SSE: 279