In [None]:
import numpy as np
from scipy import spatial
from scipy.sparse import csr_matrix
import json
from collections import defaultdict
from itertools import combinations
import math
import copy

In [None]:
def distance_matrix(node_ids, embedding_dict):
    """Return the pairwise cosine-distance matrix of the given nodes.

    Args:
        node_ids: iterable of keys into ``embedding_dict``. Row/column ``i`` of
            the result corresponds to the ``i``-th id in iteration order.
        embedding_dict: mapping from node id to its embedding vector.

    Returns:
        The array produced by ``scipy.spatial.distance.cdist`` with
        ``metric='cosine'``, comparing every selected embedding with every
        other selected embedding.
    """
    # Stack the selected embeddings row by row, in the order the ids were given.
    stacked = np.array([embedding_dict[node_id] for node_id in node_ids])
    return spatial.distance.cdist(stacked, stacked, metric='cosine')

In [None]:
def ravasz(node_indices, embedding_dict, verbose=True):
    """Build a cluster hierarchy by repeatedly merging nodes with their most similar node.

    Each pass:
      1. turns the current cosine-distance matrix into a similarity matrix,
      2. moves every node, in order, into the community currently held by its
         most similar node (the update is sequential, so later nodes see the
         moves already made earlier in the same pass),
      3. records the resulting community label of every *original* node, and
      4. collapses each community into one node whose embedding is the mean of
         its members' embeddings.
    Passes repeat until a pass no longer reduces the number of nodes.

    Args:
        node_indices: list of integer node indices. They are used directly as
            row indices into the distance matrix and as partition keys, so
            they are expected to be ``0 .. n-1`` in order.
        embedding_dict: mapping from node index to embedding vector. The
            caller's dict is not modified.
        verbose: when True (the default) progress is printed, including
            several lines per node per pass; pass False to silence it.

    Returns:
        A ``defaultdict(list)`` mapping each original node index to its list
        of community labels, one per pass, followed by one trailing ``None``
        placeholder. Labels are only comparable within the same pass.
    """
    def log(*args):
        # Single switch for all progress output.
        if verbose:
            print(*args)

    def partition(node_indices):
        # Singleton partition: every node starts in its own community.
        return {v: v for v in node_indices}

    def similarity(node_indices, D):
        # Cosine similarity = 1 - cosine distance. The diagonal is set to -inf
        # so that a node can never pick itself as its most similar node.
        SS = 1 - D
        np.fill_diagonal(SS, -math.inf)
        return SS

    def reverse_index(P):
        # Group vertices by community, then renumber the communities
        # 0, 1, 2, ... in order of first appearance in P.
        comms = defaultdict(list)
        for v, comm in P.items():
            comms[comm].append(v)
        return dict(enumerate(comms.values()))

    def cluster_embedding(embedding_dict, comms):
        # The embedding of a merged community is the mean of its members'.
        return {
            comm: np.mean(np.array([embedding_dict[v] for v in nodes]), axis=0)
            for comm, nodes in comms.items()
        }

    P = partition(node_indices)
    # Community of every ORIGINAL node. Kept as an independent copy so that
    # updating it never writes into the working partition P while P is being
    # iterated.
    ori_graph_partition = dict(P)
    levels = defaultdict(list)
    level = 0
    for v in node_indices:
        levels[v].append(P[v])
    D = distance_matrix(node_indices, embedding_dict)
    while True:
        # Reserve a slot for this pass; slot `level` is overwritten below, so
        # every history always ends with exactly one unused None placeholder.
        for cur_levels in levels.values():
            cur_levels.append(None)
        log("clustering begin")
        log("initial nodes:", len(node_indices))
        log("calculating similarity matrix")
        similarity_matrix = similarity(node_indices, D)

        log("calculating reverse index of G")
        # Maps each current (collapsed) node to the original nodes it contains.
        ori_graph_comms_dict = reverse_index(ori_graph_partition)
        most_similar_nodes = set()
        for v in node_indices:
            log("finding most similar node")
            row = similarity_matrix[v]
            # Index of the first maximal entry in this node's similarity row.
            most_similar_node = max(range(len(row)), key=row.__getitem__)

            log("moving node: ", v, " from comm: ", P[v], " to comm: ", P[most_similar_node])
            most_similar_nodes.add(P[most_similar_node])
            log(len(ori_graph_comms_dict), len(ori_graph_partition))
            P[v] = P[most_similar_node]
        # Propagate the new community of each current node to all of the
        # original nodes it stands for, and record it for this pass.
        for v, c in P.items():
            for node in ori_graph_comms_dict[v]:
                ori_graph_partition[node] = c
                levels[node][level] = c
        ori_graph_comms_dict = reverse_index(ori_graph_partition)
        log("most similar nodes: ", len(most_similar_nodes), len(ori_graph_comms_dict))

        level += 1
        log("one iteration done")
        comms_dict = reverse_index(P)
        log("total nodes in comms:", sum(len(x) for x in ori_graph_comms_dict.values()))
        comm_node_ids = list(comms_dict.keys())
        log("clusters: ", len(comm_node_ids))
        # Stop once a pass no longer merges anything; checked before building
        # the next pass's inputs, which would otherwise be computed and dropped.
        if len(node_indices) == len(comm_node_ids):
            break
        # preserve the hierarchy: each community becomes one node of the next pass
        embedding_dict = cluster_embedding(embedding_dict, comms_dict)
        # construct new distances between clusters
        D = distance_matrix(comm_node_ids, embedding_dict)
        P = partition(comm_node_ids)
        node_indices = comm_node_ids
        log("pass done. ")
    return levels


In [None]:
def _renumber_dict(P):
    comm_set = set(P.values())
    renumber_dict = {comm: index for index, comm in enumerate(comm_set)}
    return renumber_dict
    # P = {v: renumber_dict[comm] for v, comm in P.items()}
    # return P
    
def levels_to_partitions(node_indices, levels, idx_to_id_dict):
    """Convert per-node label histories into one renumbered partition per level.

    Args:
        node_indices: the node indices to include; keys of ``levels`` and
            ``idx_to_id_dict``.
        levels: mapping from node index to its list of community labels, one
            per level, plus one trailing entry that is discarded (``ravasz``
            leaves a ``None`` placeholder there). This mapping is modified in
            place: the trailing entry is dropped and labels are renumbered.
        idx_to_id_dict: mapping from node index to external node id.

    Returns:
        ``(partitions, levels)`` where ``partitions[k]`` maps node id to the
        renumbered community label at level ``k`` and ``levels`` is the same,
        now-updated, mapping that was passed in. If the last level still has
        more than one community, a final all-zero level is appended to both so
        that the hierarchy has a single root. Empty input yields ``([], levels)``.
    """
    partitions = []
    node_indices = list(node_indices)
    if not node_indices:
        return partitions, levels

    # Drop the trailing entry of every node's history.
    for v in node_indices:
        levels[v] = levels[v][:-1]

    # All histories are expected to have the same length; read it from the
    # first requested node instead of assuming a node with index 0 exists.
    num_levels = len(levels[node_indices[0]])
    for level in range(num_levels):
        P = {idx_to_id_dict[v]: levels[v][level] for v in node_indices}
        renumber_dict = _renumber_dict(P)
        P = {node_id: renumber_dict[comm] for node_id, comm in P.items()}
        # Keep the per-node histories in sync with the renumbered labels.
        for v in node_indices:
            levels[v][level] = P[idx_to_id_dict[v]]
        partitions.append(P)

    # Guarantee a single root: add an all-zero top level when the last level
    # still contains several communities.
    if partitions and len(set(partitions[-1].values())) > 1:
        partitions.append({idx_to_id_dict[v]: 0 for v in node_indices})
        for v in node_indices:
            levels[v].append(0)
    return partitions, levels

def add_dummy_partition(partitions):
    """Prepend a partition in which every node is its own community.

    The node ids are taken from the current first partition and numbered
    ``0, 1, 2, ...`` in that partition's key order. ``partitions`` is modified
    in place and also returned.
    """
    singleton_partition = {
        node_id: position for position, node_id in enumerate(partitions[0])
    }
    partitions.insert(0, singleton_partition)
    return partitions

In [None]:
def get_level_transition(levels):
    """Turn per-node label histories into a nested tree of communities.

    Every community is a dict with ``"title"`` and ``"key"`` both equal to
    ``"L-<level>-<label>"``; communities above level 0 also carry a
    ``"children"`` list holding the communities of the level below that feed
    into them, in order of first appearance.

    Args:
        levels: mapping from node to its list of community labels, one per
            level. All lists are expected to have the same length.

    Returns:
        The tree node of community ``0`` at the top level.

    Raises:
        KeyError: if the top level has no community labelled ``0`` (including
            when ``levels`` is empty).
    """
    def make_title(level, label):
        return "L-{}-{}".format(level, label)

    nested_comms = {}

    def node_for(title):
        # Fetch the tree node for `title`, creating a childless one on first use.
        node = nested_comms.get(title)
        if node is None:
            node = nested_comms[title] = {"title": title, "key": title}
        return node

    depth = len(next(iter(levels.values()), ()))

    # Create the level-0 leaves up front so that a single-level hierarchy
    # (no transitions at all) still has a node to return.
    if depth > 0:
        for transitions in levels.values():
            node_for(make_title(0, transitions[0]))

    # (parent_title, child_title) pairs already linked; avoids adding the same
    # child twice without rescanning the parent's children list.
    linked = set()
    for i in range(depth - 1):
        for transitions in levels.values():
            child_title = make_title(i, transitions[i])
            parent_title = make_title(i + 1, transitions[i + 1])
            child = node_for(child_title)
            parent = node_for(parent_title)
            children = parent.setdefault('children', [])
            if (parent_title, child_title) not in linked:
                linked.add((parent_title, child_title))
                children.append(child)

    final_level = depth - 1
    return nested_comms[make_title(final_level, 0)]

def dfs(hierarchy, leaf_children_dict):
    """Shift every community in the tree up one level and hang new leaves below.

    Walks ``hierarchy`` depth-first and modifies it in place:
      * each node's ``title``/``key`` ``"L-<level>-<label>"`` is rewritten with
        ``level + 1``;
      * each node that had no ``children`` receives one ``"L-0-<x>"`` child for
        every ``x`` in ``leaf_children_dict[<label>]`` (``<label>`` looked up
        as a string).

    Returns ``None``.
    """
    title_parts = hierarchy['title'].split("-")
    level_label = title_parts[1]
    cluster_label = title_parts[2]

    shifted_title = "L-{}-{}".format(str(int(level_label) + 1), cluster_label)
    hierarchy['title'] = shifted_title
    hierarchy['key'] = shifted_title

    if 'children' not in hierarchy:
        # Former leaf: attach the level-0 clusters that belong to it.
        hierarchy['children'] = [
            {
                "title": "L-0-{}".format(leaf_label),
                "key": "L-0-{}".format(leaf_label),
            }
            for leaf_label in leaf_children_dict[cluster_label]
        ]
        return

    for subtree in hierarchy['children']:
        dfs(subtree, leaf_children_dict)
    return

def add_dummy_hierarchy(partitions, hierarchies):
    """Extend the hierarchy tree with the level described by ``partitions[0]``.

    For every node id, the label it has in ``partitions[0]`` is grouped under
    the (stringified) label it has in ``partitions[1]``. ``dfs`` then shifts
    the existing tree up one level and attaches those groups as new leaves.
    ``hierarchies`` is modified in place and returned.
    """
    dummy_partition, parent_partition = partitions[0], partitions[1]

    # parent label (as str) -> labels of the dummy clusters that belong to it
    children_by_parent = defaultdict(list)
    for node_id in dummy_partition:
        parent_key = str(parent_partition[node_id])
        children_by_parent[parent_key].append(dummy_partition[node_id])

    dfs(hierarchies, children_by_parent)
    return hierarchies

In [None]:
def cluster(node_ids, embedding_dict):
    """Cluster nodes hierarchically from their embeddings.

    Node ids are mapped to positional indices, clustered with ``ravasz``, and
    the result is converted back to id-keyed partitions plus a nested
    hierarchy, each extended with a bottom level of one-node clusters.

    Args:
        node_ids: sequence of node ids; the position of an id is its index.
        embedding_dict: mapping from node id to embedding vector. Every key
            must occur in ``node_ids``.

    Returns:
        ``(partitions, hierarchies)``: the list of per-level partitions
        (node id -> cluster label) and the root of the nested cluster tree.
    """
    index_to_id_dict = dict(enumerate(node_ids))
    id_to_index_dict = {node_id: index for index, node_id in enumerate(node_ids)}

    # Re-key the embeddings by positional index.
    index_embedding_dict = {}
    for node_id, embedding in embedding_dict.items():
        index_embedding_dict[id_to_index_dict[node_id]] = embedding

    node_indices = list(index_to_id_dict)
    levels = ravasz(node_indices, index_embedding_dict)

    # levels_to_partitions modifies its input, so hand it a private copy.
    levels_copy = copy.deepcopy(levels)
    partitions, renumbered_levels = levels_to_partitions(node_indices, levels_copy, index_to_id_dict)
    partitions = add_dummy_partition(partitions)

    hierarchies = get_level_transition(renumbered_levels)
    hierarchies = add_dummy_hierarchy(partitions, hierarchies)
    return partitions, hierarchies

In [None]:
# Load the samples; the context manager closes the file handle as soon as the
# JSON has been parsed.
with open('data/pairwise_samples.json') as dataset_file:
    dataset = json.load(dataset_file)
node_ids = list(dataset.keys())
# One embedding lookup per text variant, all keyed by node id.
full_embedding_dict = { v: dataset[v]['full_embedding'] for v in node_ids }
writer_summ_embedding_dict = { v: dataset[v]['writer_summary_embedding'] for v in node_ids }
llm_summ_embedding_dict = { v: dataset[v]['llm_summary_embedding'] for v in node_ids }

In [55]:
def save_json(data, filepath=r'new_data.json'):
    """Write ``data`` to ``filepath`` as JSON indented by four spaces.

    An existing file at ``filepath`` is overwritten.
    """
    with open(filepath, 'w') as out_file:
        json.dump(data, out_file, indent=4)

In [57]:
from cluster import ravasz_cluster

# NOTE(review): the clustering entry point comes from the local `cluster`
# module, not from the `cluster()` function defined earlier in this notebook —
# confirm the two are meant to be equivalent.

# Run the clustering once per embedding variant and save both results under
# data/cluster_result/<variant>_{partitions,hierarchies}.json.
embedding_variants = {
    "full": full_embedding_dict,
    "writer_summ": writer_summ_embedding_dict,
    "llm_summ": llm_summ_embedding_dict,
}
for variant, variant_embedding_dict in embedding_variants.items():
    # After the loop `partitions`/`hierarchies` hold the last (llm summary) run.
    partitions, hierarchies = ravasz_cluster(node_ids, variant_embedding_dict)
    save_json(partitions, "data/cluster_result/{}_partitions.json".format(variant))
    save_json(hierarchies, "data/cluster_result/{}_hierarchies.json".format(variant))

clustering begin
initial nodes: 76
calculating similarity matrix
calculating reverse index of G
finding most similar node
moving node:  0  from comm:  0  to comm:  25
76 76
finding most similar node
moving node:  1  from comm:  1  to comm:  48
76 76
finding most similar node
moving node:  2  from comm:  2  to comm:  68
76 76
finding most similar node
moving node:  3  from comm:  3  to comm:  11
76 76
finding most similar node
moving node:  4  from comm:  4  to comm:  65
76 76
finding most similar node
moving node:  5  from comm:  5  to comm:  51
76 76
finding most similar node
moving node:  6  from comm:  6  to comm:  70
76 76
finding most similar node
moving node:  7  from comm:  7  to comm:  27
76 76
finding most similar node
moving node:  8  from comm:  8  to comm:  51
76 76
finding most similar node
moving node:  9  from comm:  9  to comm:  74
76 76
finding most similar node
moving node:  10  from comm:  10  to comm:  73
76 76
finding most similar node
moving node:  11  from comm: 

In [None]:
# Size of the partition at index 2 of the most recent clustering result
# (`partitions` was last assigned by the llm-summary run above).
# NOTE(review): presumably each partition is a dict of node id -> cluster
# label, so this counts nodes — confirm against `ravasz_cluster`.
len(partitions[2])