In [3]:
from collections import defaultdict

def calculate_rankings(pog_str):
    """
        pog_str: the raw POG that ensembles 
        '3=4=6>1=2=7>0>5=8=9>10>14=16=20>15=18>11=13>12=17=19>21'
    """
    clusters = pog_str.split('>')
    rank = {}
    adj_rank = {}
    
    rank_count = 1
    for idx, group in enumerate(clusters):
        group_clusters = group.split('=')
        for cluster in group_clusters:
            rank[int(cluster)] = rank_count
        rank_count += 1

    rank_vec = [rank[cluster] for cluster in sorted(rank.keys())]
    reverse_rank = defaultdict(list)
    for cluster, r in rank.items():
        reverse_rank[r].append(cluster)

    for r in sorted(reverse_rank.keys()):
        for cluster in reverse_rank[r]:
            adj_rank[cluster] = rank[cluster]+1/2*(len(reverse_rank[r])-1)
    adj_rank_vec = [adj_rank[cluster] for cluster in sorted(adj_rank.keys())]
    return rank_vec, adj_rank_vec

pog1 = '3=4=6>1=2=7>0>5=8=9>10>14=16=20>15=18>11=13>12=17=19>21'
pog2 = '3>4=6>0=1=2=7>8=10>5=9>14=15=21>16=20>11=12=13=17=18>19'
pog3 = '1=2=3=4=6>7>0>5=8>9=10>14=16=20>11=13=15=18>12=17>19=21'
pog4 = '6>2=3=5>1=4=7=9>0>8=10>15=17=21>14>13=16=20>12=19>11=18'
pog5 = '4>0=1=2=6=7=9=16>3>5=8=10=14>12=13=21>11=15=17=18=19=20'
pog6 = '3=6>1>0=2=4=9>7=8>5>10>14=16>17>15=20>18>19=21>12>11>13'
pog7 = '6>2=3=4>0=1=9>5=7>8=10>16>14>20>15>13=17=18>11=12=19>21'
pog8 = '0=1=2=3=4>5=6>9>7>8=10>14=15=16>21>11=12=13=17=18=19=20'
pog9 = '1=2>7>3>0>4>6>9>8>10>5>20>14=16>12=13>15>11=17>18>19=21'
pog10 = '2=3=4>0=1=6>7=9>8>5=10>15=17=21>14=16=18=20>11=12=13=19'

print(calculate_rankings(pog1)[1])

adj_ranks = []
for idx in range(10):
    pog_str = eval(f'pog{idx+1}')
    rank_vec, adj_rank_vec = calculate_rankings(pog_str)
    adj_ranks.append(adj_rank_vec)

import numpy
adj_ranks = numpy.array(adj_ranks)

adj_ranks

[3.0, 3.0, 3.0, 2.0, 2.0, 5.0, 2.0, 3.0, 5.0, 5.0, 5.0, 8.5, 10.0, 8.5, 7.0, 7.5, 7.0, 10.0, 7.5, 10.0, 7.0, 10.0]


array([[ 3. ,  3. ,  3. ,  2. ,  2. ,  5. ,  2. ,  3. ,  5. ,  5. ,  5. ,
         8.5, 10. ,  8.5,  7. ,  7.5,  7. , 10. ,  7.5, 10. ,  7. , 10. ],
       [ 4.5,  4.5,  4.5,  1. ,  2.5,  5.5,  2.5,  4.5,  4.5,  5.5,  4.5,
        10. , 10. , 10. ,  7. ,  7. ,  7.5, 10. , 10. ,  9. ,  7.5,  7. ],
       [ 3. ,  3. ,  3. ,  3. ,  3. ,  4.5,  3. ,  2. ,  4.5,  5.5,  5.5,
         8.5,  8.5,  8.5,  7. ,  8.5,  7. ,  8.5,  8.5,  9.5,  7. ,  9.5],
       [ 4. ,  4.5,  3. ,  3. ,  4.5,  3. ,  1. ,  4.5,  5.5,  4.5,  5.5,
        10.5,  9.5,  9. ,  7. ,  7. ,  9. ,  7. , 10.5,  9.5,  9. ,  7. ],
       [ 5. ,  5. ,  5. ,  3. ,  1. ,  5.5,  5. ,  5. ,  5.5,  5. ,  5.5,
         8.5,  6. ,  6. ,  5.5,  8.5,  5. ,  8.5,  8.5,  8.5,  8.5,  6. ],
       [ 4.5,  2. ,  4.5,  1.5,  4.5,  5. ,  1.5,  4.5,  4.5,  4.5,  6. ,
        13. , 12. , 14. ,  7.5,  9.5,  7.5,  8. , 10. , 11.5,  9.5, 11.5],
       [ 4. ,  4. ,  3. ,  3. ,  3. ,  4.5,  1. ,  4.5,  5.5,  4. ,  5.5,
        12. , 12. , 11. ,  7. , 

In [2]:
import numpy as np
from sklearn.cluster import KMeans

def pog_ensemble(pog_ranks: np.ndarray, sigma: float, epsilon: float, 
                 k: int, eta: float) -> np.ndarray:
    """
    聚合多个POG向量，使用Welsch函数计算权重，直到结果收敛。
    
    Parameters:
    pog_ranks (np.ndarray): 形状为 (M, N) 的POG矩阵，其中M是POG的数量，N是每个POG的维度。
    sigma (float): Welsch函数的超参数sigma，用于计算权重。
    epsilon (float): 收敛阈值，当误差小于epsilon时停止迭代。
    
    Returns:
    np.ndarray: 聚合后的POG向量。
    """
    # 获取POG矩阵的维度
    M, N = pog_ranks.shape

    # 初始化R*为POG向量的平均值
    R_star = np.mean(pog_ranks, axis=0)
    
    # 迭代直到收敛
    while True:
        # 计算每个POG的alpha_m
        alphas = np.zeros(M)
        for m in range(M):
            dist = np.linalg.norm(pog_ranks[m] - R_star)  # 计算欧几里得距离 ||R_m - R*||_2
            alphas[m] = 1 - np.exp(-dist**2 / sigma**2)  # 使用Welsch函数计算alpha_m

        # 计算权重w_m
        sum_alpha = np.sum(alphas)
        w_m = alphas / sum_alpha  # 归一化权重
        # 计算误差err
        R_star_new = np.sum(w_m[:, np.newaxis] * pog_ranks, axis=0)  # 计算加权聚合结果
        
        err = np.linalg.norm(R_star_new - R_star)  # 计算误差 ||R^* - Σw_m R_m||_2

        # 如果误差小于epsilon，则停止迭代
        if err < epsilon:
            break
        
        # 更新R*为新的聚合结果
        R_star = R_star_new
    
    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(R_star.reshape(-1, 1))
    labels = kmeans.labels_
    # 按照labels，将相同类别的聚类组合。聚类的编号为R_star中的索引
    # 例如[1 1 1 4 4 3 4 1 3 3 3 2 2 2 0 0 0 2 2 2 0 2]
    # 将 0，1，2，7合并为一类，3，4，6合并为一类，5，8，9，10合并为一类，11，12，13，17，18，19，21合并为一类，14，15，16，20合并为一类
    
    pog = {}
    for idx, label in enumerate(labels):
        if label in pog:
            pog[label].append(idx)
        else:
            pog[label] = [idx]
    avg_rank = {}
    for key in pog:
        avg_rank[key] = np.mean([R_star[i] for i in pog[key]])
    # 根据avg_rank，对pog的键值对进行排序，avg小的键值对排在前面，输出一个和pog一样的字典，只是键值对按照avg_rank排序
    sorted_pog = dict(sorted(pog.items(), key=lambda x: avg_rank[x[0]]))
    # 1=2=3=4=6 > 0=7 > 5=8=9=10 > 14=15=16=20 > 17=21 > 13=18 > 11=12 > 19 > 22
    # 计算每个聚类组的平均排名
    return R_star, sorted_pog

# 示例调用：
pog_ranks = adj_ranks  # 假设有10条POG，每条POG是长度为22的向量

# 计算聚合后的POG
aggregated_pog, sorted_pog = pog_ensemble(pog_ranks, sigma=0.1, epsilon=1e-6, k=10, eta=0.5)

print("Aggregated POG:", aggregated_pog)
print("Sorted POG:", sorted_pog)
sorted_pog_str = ''
for i in sorted_pog.keys():
    sorted_pog_str += '='.join([str(j) for j in sorted_pog[i]]) + '>'
sorted_pog_str += '22'

print(f'aggregat_pog: {sorted_pog_str}')
print('ground_truth: 1=2=3=4=6>0=7>5=8=9=10>14=15=16=20>17=21>13=18>11=12>19>22')

# print(aggregated_pog.reshape(-1,1))

Aggregated POG: [ 3.8   3.35  3.25  2.45  3.05  5.1   2.75  3.75  5.25  4.75  5.75 10.7
 10.2  10.1   7.6   8.5   7.7   9.65 10.15 10.8   8.7   9.45]
Sorted POG: {np.int32(4): [3, 6], np.int32(6): [1, 2, 4], np.int32(1): [0, 7], np.int32(3): [5, 8, 9], np.int32(9): [10], np.int32(2): [14, 16], np.int32(5): [15, 20], np.int32(7): [17, 21], np.int32(0): [12, 13, 18], np.int32(8): [11, 19]}
aggregat_pog: 3=6>1=2=4>0=7>5=8=9>10>14=16>15=20>17=21>12=13=18>11=19>22
ground_truth: 1=2=3=4=6>0=7>5=8=9=10>14=15=16=20>17=21>13=18>11=12>19>22
