In [36]:
import numpy as np
from numpy.linalg import cholesky, solve

# ---------- helpers (与之前一致) ----------
def symmetrize(S):
    return 0.5 * (S + S.T)

def omega_diag_from_cov(S):
    Omega = np.linalg.inv(S)
    return np.diag(Omega)

def chol_extend(L, S, A, v, jitter=0.0):
    if len(A) == 0:
        val = S[v, v] + jitter
        if val <= 0:
            raise np.linalg.LinAlgError("Non-positive variance on diagonal.")
        return np.array([[np.sqrt(val)]], dtype=float)
    S_vA = S[np.ix_([v], A)]
    y = solve(L, S_vA.T)
    diag_term = S[v, v] + jitter - float(y.T @ y)
    if diag_term <= 0:
        diag_term = 1e-12  # 数值兜底
    alpha = np.sqrt(diag_term)
    m = L.shape[0]
    L_new = np.zeros((m+1, m+1), dtype=float)
    L_new[:m, :m] = L
    L_new[m, :m] = y.T
    L_new[m, m] = alpha
    return L_new

def batch_var_given_A(S, A, C, L):
    """一次性算出所有候选 C 的 Var(v|A)。"""
    if len(A) == 0 or L.size == 0:
        return np.diag(S)[C].astype(float)
    S_AC = S[np.ix_(A, C)]             # (|A|, |C|)
    Y = solve(L, S_AC)                 # L Y = S_AC
    red = np.sum(Y * Y, axis=0)        # 列向量的平方和
    return np.diag(S)[C].astype(float) - red

# ---------- multi-Sigma 聚合贪心 ----------
def greedy_select_node_mig_multi(
    Sigma_list,               # list of (n,n) 协方差矩阵
    k,                        # 选择个数
    A_init=None,
    ridge=1e-6,
    weights=None,             # 可选，对每个 Sigma 的权重，默认均等
    clip_var=1e-15,           # 防止 log(<=0)
    verbose=False
):
    """
    在多个协方差上聚合节点自身 MIG，逐步贪心选择。
    返回: A (选择顺序), mig_hist (每步的总增益), eval_hist (每轮候选数)
    """
    S_list = []
    L_list  = []
    var_all_list = []
    n = Sigma_list[0].shape[0]
    A = [] if A_init is None else list(A_init)
    chosen = set(A)
    C = np.array([i for i in range(n) if i not in chosen], dtype=int)

    S_count = len(Sigma_list)
    if weights is None:
        w = np.ones(S_count, dtype=float)
    else:
        w = np.asarray(weights, dtype=float)
        assert w.shape == (S_count,)
    # 预处理每个 Sigma：对称+ridge、Omega_diag、初始 L
    for Sigma in Sigma_list:
        S = symmetrize(Sigma) + ridge * np.eye(n)
        S_list.append(S)
        # Var(v | V\{v}) = 1 / Omega_vv
        Omega_diag = omega_diag_from_cov(S)
        var_all = 1.0 / Omega_diag
        var_all_list.append(var_all)
        # 初始 Cholesky
        L = cholesky(S[np.ix_(A, A)]) if len(A) > 0 else np.zeros((0,0))
        L_list.append(L)

    mig_hist = []
    eval_hist = []

    for t in range(k):
        if C.size == 0:
            break

        # 聚合所有 Sigma 的 MIG：sum_s w_s * [log Var_s(v|A) - log Var_all_s(v)]
        agg_mig = np.zeros(C.size, dtype=float)
        for s, (S, L, var_all_s, ws) in enumerate(zip(S_list, L_list, var_all_list, w)):
            var_C_s = batch_var_given_A(S, A, C.tolist(), L)
            # 数值裁剪避免 log 非法
            var_C_s = np.maximum(var_C_s, clip_var)
            vals = np.log(var_C_s) - np.log(var_all_s[C])
            agg_mig += ws * vals

        idx = int(np.argmax(agg_mig))
        v_star = int(C[idx])
        mig_best = float(agg_mig[idx])

        if verbose:
            print(f"[Multi-BatchGreedy] step {t+1}: pick v={v_star}, sum-MIG={mig_best:.6f}")

        # 同步扩展每个 Sigma 的 Cholesky
        for i in range(S_count):
            try:
                L_list[i] = chol_extend(L_list[i], S_list[i], A, v_star, jitter=0.0)
            except np.linalg.LinAlgError:
                L_list[i] = chol_extend(L_list[i], S_list[i], A, v_star, jitter=1e-9)

        # 更新集合
        A.append(v_star)
        chosen.add(v_star)
        C = C[C != v_star]

        mig_hist.append(mig_best)
        eval_hist.append(len(C) + 1)  # 这一轮评估的候选数（每个 Sigma 都做了批量求解）

    return A, mig_hist, eval_hist

# ---------------- 使用示例 ----------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 800
    k = 800
    # 制造三份不同的协方差（比如来自三组采样/三种模型）
    U1 = rng.standard_normal((n, n)); Sigma1 = U1 @ U1.T
    U2 = rng.standard_normal((n, n)); Sigma2 = U2 @ U2.T
    U3 = rng.standard_normal((n, n)); Sigma3 = U3 @ U3.T

    A, migs, evals = greedy_select_node_mig_multi([Sigma1, Sigma2, Sigma3], k, verbose=True)
    print("Selected first 10:", A[:10])
    print("Mean aggregated MIG:", float(np.mean(migs)))


[Multi-BatchGreedy] step 1: pick v=673, sum-MIG=37.399499
[Multi-BatchGreedy] step 2: pick v=131, sum-MIG=36.867520
[Multi-BatchGreedy] step 3: pick v=559, sum-MIG=36.285669
[Multi-BatchGreedy] step 4: pick v=387, sum-MIG=36.152993
[Multi-BatchGreedy] step 5: pick v=492, sum-MIG=35.937583
[Multi-BatchGreedy] step 6: pick v=462, sum-MIG=35.907226
[Multi-BatchGreedy] step 7: pick v=69, sum-MIG=35.734945
[Multi-BatchGreedy] step 8: pick v=43, sum-MIG=35.465306
[Multi-BatchGreedy] step 9: pick v=485, sum-MIG=35.410189
[Multi-BatchGreedy] step 10: pick v=84, sum-MIG=35.348914
[Multi-BatchGreedy] step 11: pick v=738, sum-MIG=35.276162
[Multi-BatchGreedy] step 12: pick v=595, sum-MIG=35.232747
[Multi-BatchGreedy] step 13: pick v=187, sum-MIG=35.218326
[Multi-BatchGreedy] step 14: pick v=481, sum-MIG=35.190252
[Multi-BatchGreedy] step 15: pick v=333, sum-MIG=35.118876
[Multi-BatchGreedy] step 16: pick v=501, sum-MIG=34.989797
[Multi-BatchGreedy] step 17: pick v=163, sum-MIG=34.930926
[Multi-Ba

  diag_term = S[v, v] + jitter - float(y.T @ y)


[Multi-BatchGreedy] step 92: pick v=52, sum-MIG=32.911085
[Multi-BatchGreedy] step 93: pick v=193, sum-MIG=32.903005
[Multi-BatchGreedy] step 94: pick v=228, sum-MIG=32.878902
[Multi-BatchGreedy] step 95: pick v=90, sum-MIG=32.842608
[Multi-BatchGreedy] step 96: pick v=142, sum-MIG=32.821281
[Multi-BatchGreedy] step 97: pick v=158, sum-MIG=32.813245
[Multi-BatchGreedy] step 98: pick v=212, sum-MIG=32.775224
[Multi-BatchGreedy] step 99: pick v=150, sum-MIG=32.761585
[Multi-BatchGreedy] step 100: pick v=753, sum-MIG=32.738059
[Multi-BatchGreedy] step 101: pick v=717, sum-MIG=32.714606
[Multi-BatchGreedy] step 102: pick v=480, sum-MIG=32.706853
[Multi-BatchGreedy] step 103: pick v=67, sum-MIG=32.694708
[Multi-BatchGreedy] step 104: pick v=569, sum-MIG=32.688653
[Multi-BatchGreedy] step 105: pick v=78, sum-MIG=32.677320
[Multi-BatchGreedy] step 106: pick v=110, sum-MIG=32.665917
[Multi-BatchGreedy] step 107: pick v=40, sum-MIG=32.650667
[Multi-BatchGreedy] step 108: pick v=600, sum-MIG=32.