In [1]:
import networkx as nx
import numpy as np
import pickle
from tqdm.notebook import tqdm
import heapq
import matplotlib.pyplot as plt
from datetime import datetime

In [151]:
from collections import Counter

def neg_log_likelihood(labels):
    """Negative log-likelihood of `labels` under their own empirical distribution.

    A label that occurs `count` times among `size` labels contributes
    ``-count * log(count / size)``.  An empty collection yields 0.

    Parameters
    ----------
    labels : sized collection of hashable labels.

    Returns
    -------
    The summed negative log-likelihood (0 for empty input).
    """
    if len(labels) == 0:
        return 0
    frequencies = Counter(labels)
    size = np.sum(list(frequencies.values()))
    # Same accumulation order as the labels' first occurrences.
    log_likelihood = sum(count * np.log(count / size) for count in frequencies.values())
    return -log_likelihood

def count_forward_and_backward_edges(i, S_vertices, S_edges, vertex_to_set_index):
    """Count the edges stored for set `i` that leave it, split by direction.

    An edge whose head lies in a set with a larger index is "forward", one
    whose head lies in a set with a smaller index is "backward".  Edges whose
    head is inside set `i`, or whose head has no entry in
    `vertex_to_set_index`, are ignored.

    Returns
    -------
    (forward_edges, backward_edges)
    """
    own_vertices = S_vertices[i]
    n_forward = 0
    n_backward = 0
    for _, head in S_edges[i]:
        if head not in own_vertices and head in vertex_to_set_index:
            head_set = vertex_to_set_index[head]
            if head_set > i:
                n_forward += 1
            elif head_set < i:
                n_backward += 1
    return n_forward, n_backward

def compute_loss_nll(G, S_vertices, S_edges, lambdas, vertex_to_set_index):
    """Total objective of a partition using the label negative log-likelihood.

    For every set the loss adds the negative log-likelihood of the members'
    `'features'` values (used as labels) plus `lambda_fwd` per forward edge
    and `lambda_back` per backward edge leaving that set.

    Parameters
    ----------
    G : graph whose `nodes[v]['features']` holds the label of vertex v.
    S_vertices, S_edges : per-set vertex sets and edge sets.
    lambdas : pair `(lambda_fwd, lambda_back)`.
    vertex_to_set_index : mapping vertex -> set index.
    """
    lambda_fwd, lambda_back = lambdas
    total = 0
    for i, members in enumerate(S_vertices):
        member_labels = [G.nodes[v]['features'] for v in members]
        n_forward, n_backward = count_forward_and_backward_edges(i, S_vertices, S_edges, vertex_to_set_index)
        total += neg_log_likelihood(member_labels) + lambda_fwd * n_forward + lambda_back * n_backward
    return total

def L2_loss(features, mean=0):
    """Sum of squared deviations of `features` from `mean`.

    Parameters
    ----------
    features : array-like (ndarray, tuple, list, or list of feature vectors).
    mean : scalar or array broadcastable against `features`; defaults to 0.

    Returns
    -------
    The scalar ``sum((features - mean) ** 2)`` over all entries.
    """
    # Coerce first: a plain tuple/list minus the scalar default would raise
    # TypeError, although the same input works when `mean` is an ndarray.
    return np.sum(np.square(np.asarray(features) - mean))

def compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index, centroids=None):
    """Total objective of a partition using squared distance to set centroids.

    Every non-empty set contributes the squared distance of its members'
    features to the set centroid, plus `lambda_fwd` per forward edge and
    `lambda_back` per backward edge leaving the set.  Empty sets contribute
    nothing.

    Parameters
    ----------
    G : graph whose `nodes[v]['features']` holds the feature vector of v.
    S_vertices, S_edges : per-set vertex sets and edge sets.
    lambdas : pair `(lambda_fwd, lambda_back)`.
    vertex_to_set_index : mapping vertex -> set index.
    centroids : optional per-set centroids; when None the mean of each set's
        members is used.
    """
    lambda_fwd, lambda_back = lambdas
    total = 0
    for i, members in enumerate(S_vertices):
        if len(members) == 0:
            continue
        member_features = [G.nodes[v]['features'] for v in members]
        centroid = np.mean(member_features, axis=0) if centroids is None else centroids[i]
        n_forward, n_backward = count_forward_and_backward_edges(i, S_vertices, S_edges, vertex_to_set_index)
        total += L2_loss(member_features, centroid) + lambda_fwd * n_forward + lambda_back * n_backward
    return total

def add_v(v, set_i, S_vertices, S_edges, vertex_to_set_index, G):
    """Put vertex `v` into set `set_i`, updating all three bookkeeping structures.

    `S_vertices[set_i]` gains `v`, `vertex_to_set_index[v]` is set to
    `set_i`, and every edge in `G.vertex_edges_from[v]` is merged into
    `S_edges[set_i]`.  All containers are modified in place.
    """
    target_vertices = S_vertices[set_i]
    target_vertices.add(v)
    vertex_to_set_index[v] = set_i
    S_edges[set_i].update(G.vertex_edges_from[v])

def remove_v(v, set_i, S_vertices, S_edges, vertex_to_set_index, G):
    """Take vertex `v` out of set `set_i`, updating the bookkeeping in place.

    Removes `v` from `S_vertices[set_i]`, drops its entry from
    `vertex_to_set_index`, and removes each edge of
    `G.vertex_edges_from[v]` from `S_edges[set_i]`.

    Raises
    ------
    KeyError : if `v` or one of its outgoing edges is not currently recorded.
    """
    S_vertices[set_i].remove(v)
    vertex_to_set_index.pop(v)
    for outgoing_edge in G.vertex_edges_from[v]:
        S_edges[set_i].remove(outgoing_edge)

def compute_vertex_edges(G):
    """Attach per-vertex edge lookups to `G` as attributes.

    After the call `G.vertex_edges_from[v]` is the set of edges `(v, w)`
    leaving `v` and `G.vertex_edges_to[v]` is the set of edges `(u, v)`
    entering `v`; every vertex of `G.nodes()` gets an entry in both maps.
    Any previous maps are replaced.
    """
    G.vertex_edges_from = outgoing = {v: set() for v in G.nodes()}
    G.vertex_edges_to = incoming = {v: set() for v in G.nodes()}
    for tail, head in G.edges():
        edge = (tail, head)
        outgoing[tail].add(edge)
        incoming[head].add(edge)

In [18]:
def get_new_centroids(current_set_i, other_set_j, centroids, feature_v, size_from, size_to):
    """Centroids of the source and target set after moving one vertex.

    Parameters
    ----------
    current_set_i, other_set_j : indices of the set the vertex leaves / joins.
    centroids : per-set centroids (None marks an empty set); not modified.
    feature_v : feature vector of the moved vertex (any array-like).
    size_from, size_to : sizes of the two sets *before* the move.

    Returns
    -------
    (new_centroid_from, new_centroid_to); `new_centroid_from` is None when the
    source set becomes empty.  Non-None centroids are always ndarrays.
    """
    feature_v = np.asarray(feature_v)
    if size_from == 1:
        # the moved vertex was the only member: the source set is now empty
        new_centroid_from = None
    else:
        new_centroid_from = (size_from*centroids[current_set_i] - feature_v)/(size_from-1)
    if centroids[other_set_j] is None:
        # Target set was empty.  Return an ndarray copy, never the raw feature
        # tuple: a tuple centroid would later be concatenated (tuple + tuple)
        # or raise in the running-mean arithmetic above.
        new_centroid_to = np.array(feature_v)
    else:
        new_centroid_to = (size_to*centroids[other_set_j] + feature_v)/(size_to+1)
    return new_centroid_from, new_centroid_to

def compute_delta_edge_loss(v, current_set_i, other_set_j, G, lambdas, vertex_to_set_index):
    # compute the change in loss from edges if we move v from current_set_i to other_set_j
    lambda_fwd, lambda_back = lambdas
    delta_edge_loss = 0
    for _, u in G.vertex_edges_from[v]:
        if vertex_to_set_index[u] > current_set_i and vertex_to_set_index[u] < other_set_j:
            delta_edge_loss += lambda_back-lambda_fwd
        elif vertex_to_set_index[u] < current_set_i and vertex_to_set_index[u] > other_set_j:
            delta_edge_loss += lambda_fwd-lambda_back
        elif vertex_to_set_index[u] == current_set_i:
            if other_set_j > current_set_i:
                delta_edge_loss += lambda_back
            else:
                delta_edge_loss += lambda_fwd
        elif vertex_to_set_index[u] == other_set_j:
            if other_set_j > current_set_i:
                delta_edge_loss -= lambda_fwd
            else:
                delta_edge_loss -= lambda_back
    for u, _ in G.vertex_edges_to[v]:
        if vertex_to_set_index[u] > current_set_i and vertex_to_set_index[u] < other_set_j:
            delta_edge_loss += lambda_fwd-lambda_back
        elif vertex_to_set_index[u] < current_set_i and vertex_to_set_index[u] > other_set_j:
            delta_edge_loss += lambda_back-lambda_fwd
        elif vertex_to_set_index[u] == current_set_i:
            if other_set_j > current_set_i:
                delta_edge_loss += lambda_fwd
            else:
                delta_edge_loss += lambda_back
        elif vertex_to_set_index[u] == other_set_j:
            if other_set_j > current_set_i:
                delta_edge_loss -= lambda_back
            else:
                delta_edge_loss -= lambda_fwd
    return delta_edge_loss

In [200]:
def compute_centroids(G, S_vertices):
    """Return one centroid per set: the mean feature vector, or None if the set is empty.

    Parameters
    ----------
    G : graph whose `nodes[v]['features']` holds the feature vector of v.
    S_vertices : list of vertex sets.

    Returns
    -------
    List with the same length as `S_vertices`.
    """
    centroids = []
    for members in S_vertices:
        if len(members) > 0:
            member_features = [G.nodes[v]['features'] for v in members]
            centroids.append(np.mean(member_features, axis=0))
        else:
            centroids.append(None)
    return centroids

In [252]:
def find_node_clustering_greedy_local_search(G, k, lambdas, max_iter = 50, random_seed=None, given_partition=None, verbose=False):
    """Cluster the vertices of `G` into `k` ordered sets by greedy single-vertex moves.

    Starting from a random (or given) partition, every pass visits each vertex
    and moves it to the set that lowers the objective the most.  The objective
    is the one computed by `compute_loss_L2`: squared distance of each vertex
    to its set centroid plus `lambda_fwd` per edge into a later set and
    `lambda_back` per edge into an earlier set.  Passes repeat until a whole
    pass makes no improving move or `max_iter` passes have run.

    Parameters
    ----------
    G : graph exposing `nodes()`, `nodes[v]['features']` and the
        `vertex_edges_from` / `vertex_edges_to` maps built by
        `compute_vertex_edges`.
    k : int, number of sets.
    lambdas : pair `(lambda_fwd, lambda_back)`.
    max_iter : int, maximum number of passes over all vertices.
    random_seed : if not None, seeds NumPy's *global* RNG before the random
        initial assignment.
    given_partition : optional mapping vertex -> initial set index; when None
        every vertex starts in a uniformly random set.
    verbose : print the running loss at the start of every pass.

    Returns
    -------
    (total_loss, S_vertices, vertex_to_set_index), where `total_loss` is the
    initial loss plus the accumulated deltas of all accepted moves.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    S_vertices = [set() for _ in range(k)] # for each set i, which vertices are in that set
    vertex_to_set_index = dict() # for each vertex v, which set does v belong to
    S_edges = [set() for _ in range(k)] # for each set i, which edges are within that set or from that set to another set

    for v in G.nodes():
        if given_partition is None:
            set_index = np.random.randint(k)
        else:
            set_index = given_partition[v]
        add_v(v, set_index, S_vertices, S_edges, vertex_to_set_index, G)

    centroids = compute_centroids(G, S_vertices)

    total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index, centroids)
    n_passes = 0
    loss_decreasing = True
    # `<` (not `<=`) so that at most `max_iter` passes run, matching the other
    # clustering routines in this notebook.
    while loss_decreasing and n_passes < max_iter:
        n_passes += 1
        if verbose:
            print("iter:",n_passes, "loss:", total_loss)
        loss_decreasing = False
        for v in G.nodes():
            best_delta_loss = 0
            current_set_i = vertex_to_set_index[v]
            best_set_i = current_set_i

            # ndarray so that a centroid seeded from this vector is an array too
            feature_v = np.asarray(G.nodes[v]['features'])
            size_from = len(S_vertices[current_set_i])
            if size_from == 1:
                # a singleton set has zero spread, so leaving it saves nothing
                current_set_L2_loss = 0
            else:
                # decrease of the set's squared error when v is removed
                current_set_L2_loss = size_from/(size_from-1) * L2_loss(feature_v, centroids[current_set_i])

            for other_set_j in range(k):
                if other_set_j == current_set_i:
                    continue
                size_to = len(S_vertices[other_set_j])

                if size_to == 0:
                    # Empty target: its centroid is None and v would become the
                    # centroid itself, so joining adds no squared error.
                    other_set_L2_loss = 0
                else:
                    # increase of the set's squared error when v is added
                    other_set_L2_loss = size_to/(size_to+1) * L2_loss(feature_v, centroids[other_set_j])
                delta_edge_loss = compute_delta_edge_loss(v, current_set_i, other_set_j, G, lambdas, vertex_to_set_index)

                delta_loss = other_set_L2_loss - current_set_L2_loss + delta_edge_loss

                if delta_loss < best_delta_loss:
                    best_delta_loss = delta_loss
                    best_set_i = other_set_j

            if best_delta_loss < 0:
                loss_decreasing = True
                total_loss += best_delta_loss
                centroids[current_set_i], centroids[best_set_i] = get_new_centroids(current_set_i, best_set_i, centroids, feature_v, size_from, len(S_vertices[best_set_i]))
                remove_v(v, current_set_i, S_vertices, S_edges, vertex_to_set_index, G)
                add_v(v, best_set_i, S_vertices, S_edges, vertex_to_set_index, G)
    return total_loss, S_vertices, vertex_to_set_index

In [155]:
def add_subtree_to_best_set(v, parent_set_index, 
        best_loss_if_v_in_set_i, best_direction_with_parent_set, best_loss_if_v_in_earlier_set, 
        best_loss_if_v_in_later_set, S_vertices, S_edges, vertex_to_set_index, G):
    """Assign `v` and all its descendants to the sets chosen by the tree DP.

    A vertex without a parent (`parent_set_index is None`) takes the set with
    the smallest entry in `best_loss_if_v_in_set_i[v]`.  Any other vertex looks
    up `best_direction_with_parent_set[v][parent_set]`:

    * 0 - same set as the parent,
    * 1 - the best set with a smaller index (`best_loss_if_v_in_earlier_set`),
    * 2 - the best set with a larger index (`best_loss_if_v_in_later_set`).

    Each vertex is registered through `add_v`; children are found through
    `G.vertex_edges_from`.  The traversal uses an explicit stack instead of
    recursion so that deep trees cannot hit Python's recursion limit.

    Raises
    ------
    ValueError : if a stored direction is not 0, 1 or 2.
    """
    pending = [(v, parent_set_index)]
    while pending:
        node, parent_set = pending.pop()
        if parent_set is None:
            best_set_index = min(best_loss_if_v_in_set_i[node], key=best_loss_if_v_in_set_i[node].get)
        else:
            best_direction = best_direction_with_parent_set[node][parent_set]
            if best_direction == 0:
                best_set_index = parent_set
            elif best_direction == 1:
                best_set_index = best_loss_if_v_in_earlier_set[node][parent_set-1][1]
            elif best_direction == 2:
                best_set_index = best_loss_if_v_in_later_set[node][parent_set+1][1]
            else:
                raise ValueError(f"unexpected direction {best_direction!r} for vertex {node!r}")
        add_v(node, best_set_index, S_vertices, S_edges, vertex_to_set_index, G)
        # a child's choice depends only on the set its parent just received
        for _, child_node in G.vertex_edges_from[node]:
            pending.append((child_node, best_set_index))

In [6]:
def compute_empty_set_centroids_and_l2_losses(G, k, vertex_to_set_index, centroids):
    """Re-seed empty sets and tabulate every vertex's squared distance to every centroid.

    Sets whose centroid is None are given a new centroid: the features of the
    vertices that are currently farthest (largest squared distance) from the
    centroid of their own set, one vertex per empty set in decreasing order of
    distance.  `centroids` is updated in place.

    Returns
    -------
    (centroids, l2_losses_dict) where `l2_losses_dict[set_i][node]` is the
    squared distance between `node`'s features and the centroid of `set_i`.
    """
    def losses_to(centroid):
        # squared distance of every vertex of G to one fixed centroid
        return {node: L2_loss(G.nodes[node]["features"], centroid) for node in G.nodes()}

    l2_losses_dict = dict()
    empty_clusters = []
    current_l2_losses = []  # (distance to own centroid, vertex)
    for set_i in range(k):
        if centroids[set_i] is None:
            empty_clusters.append(set_i)
            continue
        per_node = losses_to(centroids[set_i])
        l2_losses_dict[set_i] = per_node
        current_l2_losses.extend(
            (l2_loss, node) for node, l2_loss in per_node.items()
            if set_i == vertex_to_set_index[node])

    # the worst-fitting vertices donate their features as new centroids
    farthest = heapq.nlargest(len(empty_clusters), current_l2_losses)
    for rank, empty_cluster_i in enumerate(empty_clusters):
        seed_node = farthest[rank][1]
        centroids[empty_cluster_i] = np.array(G.nodes[seed_node]["features"])
        l2_losses_dict[empty_cluster_i] = losses_to(centroids[empty_cluster_i])

    return centroids, l2_losses_dict

In [253]:
def find_node_clustering_on_tree(G, k, lambdas, max_iter = 50, random_seed=None, given_partition=None, verbose=False):
    """Cluster the vertices of `G` into `k` ordered sets by alternating centroid updates and a tree DP.

    Every iteration (1) fixes one centroid per set (empty sets are re-seeded by
    `compute_empty_set_centroids_and_l2_losses`), (2) clears the partition, and
    (3) for every vertex with in-degree 0 runs a bottom-up dynamic program over
    the BFS layers below it.  The DP picks a set for each vertex so as to
    minimise the squared distance to the fixed centroids plus `lambda_fwd` for
    each parent->child edge whose child sits in a later set and `lambda_back`
    for each one whose child sits in an earlier set.  Iterations stop when the
    recomputed loss no longer strictly decreases or after `max_iter` rounds.

    Parameters
    ----------
    G : directed graph exposing `nodes()`, `in_degree()`,
        `nodes[v]['features']` and `vertex_edges_from`.
        NOTE(review): only vertices reached from an in-degree-0 vertex are
        re-assigned, so this presumably expects a tree/forest - confirm.
    k : int, number of sets.
    lambdas : pair `(lambda_fwd, lambda_back)`.
    max_iter : int, maximum number of iterations.
    random_seed : if not None, seeds NumPy's global RNG.
    given_partition : optional mapping vertex -> initial set index; when None
        the initial set of each vertex is drawn uniformly at random.
    verbose : print losses and set sizes as the iterations proceed.

    Returns
    -------
    (current_total_loss, S_vertices, vertex_to_set_index) of the last iteration.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    lambda_fwd, lambda_back = lambdas

    S_vertices = [set() for i in range(k)] # for each set i, which vertices are in that set
    vertex_to_set_index = dict() # for each vertex v, which set does v belong to
    S_edges = [set() for i in range(k)] # for each set i, which edges are within that set or from that set to another set
    
    # vertices without incoming edges are the starting points of the DP
    root_vertices = [v for v in G.nodes() if G.in_degree(v) == 0]

    # initial partition: random unless the caller supplied one
    for v in G.nodes():
        if given_partition is None:
            set_index = np.random.randint(k)
        else:
            set_index = given_partition[v]
        add_v(v, set_index, S_vertices, S_edges, vertex_to_set_index, G)

    # no centroids passed, so compute_loss_L2 uses the mean of each set
    current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index)
    if verbose:
        print(f"start loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    iter = 0
    loss_decreasing = True

    while loss_decreasing and iter < max_iter:
        iter += 1
        loss_decreasing = False

        # compute centroids
        centroids = compute_centroids(G, S_vertices)
        # fills in centroids for empty sets and returns l2_losses_dict[set][node]
        centroids, l2_losses_dict = compute_empty_set_centroids_and_l2_losses(G, k, vertex_to_set_index, centroids)

        # reset all sets
        S_vertices = [set() for i in range(k)] # for each set i, which vertices are in that set
        vertex_to_set_index = dict() # for each vertex v, which set does v belong to
        S_edges = [set() for i in range(k)] # for each set i, which edges are within that set or from that set to another set

        # DP tables, all keyed by vertex:
        #   best_loss_if_v_in_set_i[v][i]       - best subtree loss with v placed in set i
        #   best_loss_if_v_in_earlier_set[v][i] - (min, argmin) of the above over sets 0..i
        #   best_loss_if_v_in_later_set[v][i]   - (min, argmin) of the above over sets i..k-1
        #   best_direction_with_parent_set[v][i] - 0/1/2: with the parent in set i, v goes to
        #                                          the same / an earlier / a later set
        best_loss_if_v_in_set_i = dict()
        best_loss_if_v_in_earlier_set = dict()
        best_loss_if_v_in_later_set = dict()
        best_direction_with_parent_set = dict()
        
        total_best_loss = 0

        for root in root_vertices:
            layers = [layer for layer in nx.bfs_layers(G, root)]
            # deepest layer first, so children are solved before their parent
            for layer in layers[::-1]:
                for node in layer:
                    best_loss_if_v_in_set_i[node] = dict()
                    best_loss_if_v_in_earlier_set[node] = dict()
                    # sentinels so that set_i-1 == -1 and set_i+1 == k can be looked up
                    best_loss_if_v_in_earlier_set[node][-1] = (np.inf, -1)
                    best_loss_if_v_in_later_set[node] = dict()
                    best_loss_if_v_in_later_set[node][k] = (np.inf, -1)
                    best_direction_with_parent_set[node] = dict()

                    best_loss, best_set_index = np.inf, -1
                    for set_i in range(k):
                        loss = l2_losses_dict[set_i][node]

                        for _, child_node in G.vertex_edges_from[node]:
                            # index in this list is the direction code: 0 same, 1 earlier, 2 later
                            min_direction, min_value = min(
                                enumerate([best_loss_if_v_in_set_i[child_node][set_i],
                                        best_loss_if_v_in_earlier_set[child_node][set_i-1][0]+lambda_back,
                                        best_loss_if_v_in_later_set[child_node][set_i+1][0]+lambda_fwd]),
                                key=lambda x: x[1])
                            loss += min_value

                            # Store assuming parent goes to set_i whether we want to put child_node in set_i, 
                            #  or in some earlier set, or in some later set
                            best_direction_with_parent_set[child_node][set_i] = min_direction
                    
                        best_loss_if_v_in_set_i[node][set_i] = loss
                        # running minimum over sets 0..set_i (prefix minimum)
                        if loss < best_loss:
                            best_loss, best_set_index = loss, set_i
                        best_loss_if_v_in_earlier_set[node][set_i] = (best_loss, best_set_index)
                    # second sweep, from the last set down: suffix minimum over sets set_i..k-1
                    best_loss = np.inf
                    best_set_index = -1
                    for set_i in range(k-1, -1, -1):
                        if best_loss_if_v_in_set_i[node][set_i] < best_loss:
                            best_loss, best_set_index = best_loss_if_v_in_set_i[node][set_i], set_i
                        best_loss_if_v_in_later_set[node][set_i] = (best_loss, best_set_index)

            # walk back down from the root, materialising the choices recorded above
            add_subtree_to_best_set(root, None, 
                best_loss_if_v_in_set_i, best_direction_with_parent_set, best_loss_if_v_in_earlier_set, 
                best_loss_if_v_in_later_set, S_vertices, S_edges, vertex_to_set_index, G)
            # `best_loss` is left over from the last node handled in the loops above, i.e. a
            # node of layers[0]; presumably that layer is just `root`, making this the root's
            # minimum over all sets - relies on bfs_layers yielding the source first.
            total_best_loss += best_loss
        if verbose:
            print("Total best loss with fixed centroids: ", total_best_loss)
                
        # re-evaluate with the means of the new sets; continue only on strict improvement
        previous_loss = current_total_loss
        current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index)
        if current_total_loss < previous_loss:
            loss_decreasing = True
        if verbose:
            print(f"iter: {iter}, loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    return current_total_loss, S_vertices, vertex_to_set_index

In [262]:
def find_node_clustering_iterative_mincut(G, k, lambdas, max_iter = 50, random_seed=None, given_partition=None, verbose=False):
    """Cluster the vertices of `G` into `k` ordered sets by repeated pairwise minimum cuts.

    Starting from a random (or given) partition, every iteration visits each
    pair `set_i < set_j` and re-splits the union of the two sets with an s-t
    minimum cut: the source side becomes `set_i`, the sink side `set_j`.  Cut
    capacities encode the squared distance of each vertex to the two current
    centroids plus the `lambda_fwd` / `lambda_back` penalties of its edges.
    Iterations stop when the loss no longer strictly decreases or after
    `max_iter` rounds.

    Parameters
    ----------
    G : directed graph exposing `nodes()`, `nodes[v]['features']` and the
        `vertex_edges_from` / `vertex_edges_to` maps built by
        `compute_vertex_edges`.
    k : int, number of sets.
    lambdas : pair `(lambda_fwd, lambda_back)`.
    max_iter : int, maximum number of iterations.
    random_seed : if not None, seeds NumPy's global RNG.
    given_partition : optional mapping vertex -> initial set index.
    verbose : print the loss and set sizes after every iteration.

    Returns
    -------
    (current_total_loss, S_vertices, vertex_to_set_index) of the last
    iteration.  NOTE: the last iteration is returned even when it is the one
    whose loss failed to decrease.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    lambda_fwd, lambda_back = lambdas

    S_vertices = [set() for i in range(k)] # for each set i, which vertices are in that set
    vertex_to_set_index = dict() # for each vertex v, which set does v belong to
    S_edges = [set() for i in range(k)] # for each set i, which edges are within that set or from that set to another set

    for v in G.nodes():
        if given_partition is None:
            set_index = np.random.randint(k)
        else:
            set_index = given_partition[v]
        add_v(v, set_index, S_vertices, S_edges, vertex_to_set_index, G)

    centroids = compute_centroids(G, S_vertices)

    current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index)
    if verbose:
        print(f"start loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    n_passes = 0
    loss_decreasing = True

    while loss_decreasing and n_passes < max_iter:
        n_passes += 1
        loss_decreasing = False

        for set_i in range(k-1):
            for set_j in range(set_i+1, k):
                if len(S_vertices[set_i]) == 0 and len(S_vertices[set_j]) == 0:
                    continue

                # If one of the two sets is empty, use the features of the vertex of the
                # other set that is farthest from its centroid as the empty set's centroid.
                # This must happen before any terminal weight is computed: the weights of
                # set_i's vertices already read centroids[set_j].
                for empty_set, donor_set in ((set_i, set_j), (set_j, set_i)):
                    if len(S_vertices[empty_set]) == 0:
                        max_distance_vertex = max(S_vertices[donor_set], key=lambda v: L2_loss(G.nodes[v]["features"], centroids[donor_set]))
                        centroids[empty_set] = np.array(G.nodes[max_distance_vertex]["features"])

                # create graph H with vertices from set_i and set_j and compute mincut
                H = nx.DiGraph()
                for set_ij in [set_i, set_j]:
                    for v in S_vertices[set_ij]:
                        feature_v = G.nodes[v]["features"]
                        # Only neighbours in sets strictly between set_i and set_j change
                        # direction depending on which side v ends up on.
                        edges_to_v = 0
                        edges_from_v = 0
                        for _, u in G.vertex_edges_from[v]:
                            if set_i < vertex_to_set_index[u] < set_j:
                                edges_from_v += 1
                        for u, _ in G.vertex_edges_to[v]:
                            if set_i < vertex_to_set_index[u] < set_j:
                                edges_to_v += 1
                        # cutting _s->v puts v on the sink side (set_j); cutting v->_t puts it in set_i
                        s_v_weight = L2_loss(feature_v, centroids[set_j]) + lambda_fwd * edges_to_v + lambda_back * edges_from_v
                        v_t_weight = L2_loss(feature_v, centroids[set_i]) + lambda_back * edges_to_v + lambda_fwd * edges_from_v
                        H.add_edge("_s", v, weight=s_v_weight)
                        H.add_edge(v, "_t", weight=v_t_weight)
                    for u, v in S_edges[set_ij]:
                        if v in S_vertices[set_i] or v in S_vertices[set_j]:
                            # u in set_i, v in set_j -> forward edge; the reverse split -> backward edge
                            H.add_edge(u, v, weight=lambda_fwd)
                            H.add_edge(v, u, weight=lambda_back)

                _, partition = nx.minimum_cut(H, "_s", "_t", capacity="weight")

                # partition[0] (source side) becomes set_i, partition[1] becomes set_j
                for partition_i, set_ij in enumerate([set_i, set_j]):
                    S_edges[set_ij] = set()
                    S_vertices[set_ij] = set()
                    for v in partition[partition_i]:
                        if v != "_s" and v != "_t":
                            add_v(v, set_ij, S_vertices, S_edges, vertex_to_set_index, G)
                    if len(S_vertices[set_ij]) > 0:
                        centroids[set_ij] = np.mean([G.nodes[v]["features"] for v in S_vertices[set_ij]], axis=0)
                    else:
                        centroids[set_ij] = None

        previous_loss = current_total_loss
        current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index, centroids)
        if current_total_loss < previous_loss:
            loss_decreasing = True
        if verbose:
            print(f"iter: {n_passes}, loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    return current_total_loss, S_vertices, vertex_to_set_index

In [248]:
def create_synthetic_tree(n_vertices, n_features, n_clusters, centroid_variance=10, node_variance=0.1, random_seed=None):
    """Build a random directed tree whose vertices carry clustered Gaussian features.

    Vertices are `0 .. n_vertices-1`; vertex `i` belongs to cluster
    `int(i / (n_vertices / n_clusters))` and its `"features"` tuple is drawn
    from a normal distribution centred on that cluster's centroid.  Every
    vertex except 0 receives one incoming edge from a uniformly chosen
    lower-numbered vertex.

    Parameters
    ----------
    centroid_variance : factor applied to uniform [0, 1) draws to obtain the
        cluster centroids.
    node_variance : passed as `scale` to the normal distribution.
    random_seed : if not None, seeds NumPy's global RNG.

    Returns
    -------
    (G, synthetic_centroids); `G` already has the maps created by
    `compute_vertex_edges`.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    synthetic_centroids = centroid_variance*np.random.uniform(size=(n_clusters, n_features))

    G = nx.DiGraph()
    G.add_nodes_from(range(n_vertices))
    for position, node in enumerate(G.nodes()):
        cluster = int(position / (n_vertices / n_clusters))
        sample = np.random.normal(loc=synthetic_centroids[cluster], scale=node_variance, size=n_features)
        G.nodes[node]["features"] = tuple(sample)

        # hang every vertex but the first below a random earlier vertex
        if node != 0:
            parent = np.random.randint(0, node)
            G.add_edge(parent, node)

    compute_vertex_edges(G)
    return G, synthetic_centroids

def add_synthetic_edges(G, probability = 0.05, random_seed=None):
    """Add random extra edges to `G` in place, always from an earlier to a later vertex.

    For every pair of positions `i < j` in `list(G.nodes())` the edge
    `nodes[i] -> nodes[j]` is added with the given probability; afterwards the
    per-vertex edge maps are rebuilt with `compute_vertex_edges`.

    Parameters
    ----------
    probability : chance of adding each candidate edge.
    random_seed : if not None, seeds NumPy's global RNG.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    nodes = list(G.nodes())
    for position, tail in enumerate(nodes[:-1]):
        for head in nodes[position+1:]:
            if np.random.uniform() < probability:
                G.add_edge(tail, head)
    compute_vertex_edges(G)

In [218]:
def plot_synthetic_test(ps, rand_scores, test_lambdas, filename=None):
    """Plot the adjusted Rand index against the noise probability, one curve per lambda.

    Parameters
    ----------
    ps : x values (noise probabilities).
    rand_scores : mapping lambda_back -> list of scores, aligned with `ps`.
    test_lambdas : the lambda_back values to draw, in legend order.
    filename : if not None, the figure is also saved to this path.

    Note: switches `text.usetex` on globally in matplotlib's rcParams.
    """
    plt.rcParams['text.usetex'] = True
    fig, ax = plt.subplots(1,1)
    fig.set_size_inches(4, 3)
    for lambda_back in test_lambdas:
        # raw f-string: "\l" is not a valid escape sequence in a normal string literal
        ax.plot(ps, rand_scores[lambda_back], label=rf"$\lambda_b={lambda_back}$")
    ax.set_xlabel("Probability p")
    ax.set_ylabel("Adjusted Rand Index")
    ax.legend()
    fig.tight_layout()
    if filename is not None:
        fig.savefig(filename, dpi=100, pad_inches=0)
    plt.show()

In [None]:
from sklearn.metrics.cluster import adjusted_rand_score

# Synthetic robustness experiment: for each algorithm, replace the features of a
# fraction p of the vertices by features of a random cluster and measure how well
# the clustering still recovers the ground-truth clusters.

random_seed = 1234
np.random.seed(random_seed)

n_vertices = 1000
n_features = 10
n_clusters = 5

centroid_variance = 1.0
node_variance = 0.1

# same vertex -> cluster rule as create_synthetic_tree
ground_truth = [int(i / (n_vertices / n_clusters)) for i in range(n_vertices)]

step_size = 0.05
test_lambdas = [0, 100000]
n_repeated_runs = 10

# Noise levels 0, step_size, 2*step_size, ... up to 1.0, computed by multiplication
# so that floating-point error does not accumulate from step to step.
n_steps = int((1.0 + 1e-8) / step_size)
noise_levels = [round(step * step_size, 10) for step in range(n_steps + 1)]

for alg, test_on_DAG in ["tree_DP", False], ["greedy_local_search", False], ["iterative_mincut", True], ["greedy_local_search", True]:
    print(alg, test_on_DAG)
    G, synthetic_centroids = create_synthetic_tree(n_vertices, n_features, n_clusters, 
                            centroid_variance=centroid_variance, node_variance=node_variance,random_seed=random_seed)
    # add random edges to make the tree into a more complex DAG
    if test_on_DAG:
        add_synthetic_edges(G, probability=0.01, random_seed=random_seed)
        print("DAG size:",len(G.nodes()), len(G.edges()))

    # Unperturbed features, restored before every noise level so that p is the
    # actual replacement probability instead of compounding over the sweep.
    clean_features = {node: G.nodes[node]["features"] for node in G.nodes()}
    # Dedicated generator: the clustering routines reseed NumPy's global RNG on
    # every call, which would otherwise restart the noise stream at each level.
    noise_rng = np.random.default_rng(random_seed)

    rand_scores = {lambda_back: [] for lambda_back in test_lambdas}
    ps = []
    for p in noise_levels:
        for node in G.nodes():
            G.nodes[node]["features"] = clean_features[node]
            if noise_rng.uniform() < p: # assign to random set with probability p
                random_set_i = noise_rng.integers(0, n_clusters)
                G.nodes[node]["features"] = tuple(noise_rng.normal(loc=synthetic_centroids[random_set_i], scale=node_variance, size=n_features))

        for lambda_back in test_lambdas:
            # run the algorithm multiple times and take the best result
            best_loss = np.inf
            best_vertex_to_set_index = None
            for run_i in range(n_repeated_runs):
                seed = random_seed + run_i
                if alg == "iterative_mincut":
                    loss, S_vertices, vertex_to_set_index = find_node_clustering_iterative_mincut(G, n_clusters, lambdas=(0, lambda_back), random_seed=seed)
                elif alg == "greedy_local_search":
                    loss, S_vertices, vertex_to_set_index = find_node_clustering_greedy_local_search(G, n_clusters, lambdas=(0, lambda_back), random_seed=seed)
                elif alg == "tree_DP":
                    loss, S_vertices, vertex_to_set_index = find_node_clustering_on_tree(G, n_clusters, lambdas=(0, lambda_back), random_seed=seed)
                else:
                    raise ValueError(f"unknown algorithm: {alg!r}")

                if best_vertex_to_set_index is None or loss < best_loss:
                    best_loss = loss
                    best_vertex_to_set_index = vertex_to_set_index

            clustering = [best_vertex_to_set_index[node] for node in G.nodes()]
            # compute adjusted rand index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
            rand_score = adjusted_rand_score(ground_truth, clustering)

            rand_scores[lambda_back].append(rand_score)
            print(f"lambda: {lambda_back}, p: {p}, adjusted rand score: {rand_score}, loss: {best_loss}")
        ps.append(p)

    print("-----------------------")
    filename = f"synthetic_{'DAG' if test_on_DAG else 'tree'}_rand_index_{alg}_k={n_clusters}_top{n_repeated_runs}.pdf"
    plot_synthetic_test(ps, rand_scores, test_lambdas, filename)

Tests on Reddit data

In [288]:
# Load the graph object stored in reddit_graph.pkl into G.
# NOTE(review): pickle.load can run arbitrary code while unpickling - only load trusted files.
# Presumably the pickled graph already carries vertex_edges_from / vertex_edges_to,
# which the clustering functions read - TODO confirm.
with open("reddit_graph.pkl", "rb") as file:
    G = pickle.load(file)

In [289]:
# Time one tree-DP run on G: k=20 sets, (lambda_fwd, lambda_back) = (0, 100000),
# at most 75 iterations, seed 0, with progress output.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_on_tree(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")

start loss: 3940665656.568115, sets: [3679, 3712, 3631, 3792, 3823, 3749, 3816, 3671, 3693, 3743, 3752, 3868, 3803, 3705, 3674, 3747, 3739, 3739, 3751, 3691]
Total best loss with fixed centroids:  65291.857263445854
iter: 1, loss: 60142.32177734375, sets: [4076, 2354, 2902, 2248, 1108, 10755, 6411, 1639, 3358, 952, 3646, 2203, 7220, 4310, 2907, 2040, 6704, 944, 3954, 5047]
Total best loss with fixed centroids:  58652.51103082299
iter: 2, loss: 57612.94982910156, sets: [4560, 2056, 4557, 2356, 2192, 7521, 3446, 2214, 5324, 1549, 2967, 2741, 6957, 3632, 4289, 3248, 4889, 3191, 1904, 5185]
Total best loss with fixed centroids:  56920.111352309585
iter: 3, loss: 56293.962341308594, sets: [3941, 2215, 5437, 3049, 3107, 5966, 3063, 2730, 5353, 2219, 2125, 2999, 6489, 4371, 4982, 3869, 3898, 3583, 1032, 4350]
Total best loss with fixed centroids:  56108.176543105394
iter: 4, loss: 55967.83592224121, sets: [3431, 2397, 5799, 3139, 3308, 5309, 2962, 2853, 5407, 2383, 2106, 3083, 5762, 5225, 508

In [290]:
# Print the full vertex set of every cluster, one set per line.
# NOTE(review): this dumps every vertex id; printing len(s) would keep the output readable.
for s in S_vertices:
    print(s)

{'gbioxcv', 'gbgsmjq', 'gbjkj3o', 'gbgs9hz', 'gbgtu46', 'gbj8784', 'gbgrub0', 'gbgts0w', 'gbkgfsc', 'gbgujso', 'gbh80f7', 'gbijv7z', 'gbgtir5', 'gbgr1xl', 'gbgrtsi', 'gbif3sl', 'gbgs7ry', 'gbh718k', 'gbhbuiq', 'gbi8umf', 'gbhg306', 'gbgr02m', 'gbh84x0', 'gbj8pcf', 'gbgvx2q', 'gbgwg0m', 'gbgthdh', 'gbh40a7', 'gbgv6dn', 'gbgt20k', 'gebtdg2', 'gct4sim', 'gbh8b0m', 'gbh1m6v', 'gbgzrcr', 'gbgtgog', 'gbguoxk', 'gbgstcl', 'gbh052c', 'gbjalc3', 'gbhzmj2', 'gbgt4m9', 'gbja32y', 'gbgr9xm', 'gbgutz4', 'gbha3rk', 'gbh99zk', 'gbgtxsn', 'gbhb78k', 'gbh2dvx', 'gbhcqq2', 'gbi14da', 'gbh4av4', 'gbgs2ar', 'gbgvtwb', 'gbgu8jp', 'gbgxirb', 'gbhcdtj', 'gbgsqwc', 'gbhtemo', 'gbh8xia', 'gbguvd4', 'gbjtyu8', 'gbgswb5', 'gbgwtrr', 'gbhbiby', 'gbhb178', 'gbhwop3', 'gbgth1g', 'gbjken2', 'gbgrape', 'gbj75e0', 'gbh7oi6', 'gbjg7lr', 'gbgvvk0', 'gbgzun6', 'gbgsbje', 'gbgsjki', 'gbgqvca', 'gbham8c', 'gbjmhlp', 'gbgwmvi', 'gbgryl0', 'gbi95h9', 'gbhulab', 'gbkr5ik', 'gbgs4um', 'gbivgbq', 'gbi0s9c', 'gbiixag', 'gbgvd44'

In [52]:
# Time one greedy-local-search run on G: k=20 sets, (lambda_fwd, lambda_back) = (0, 100000),
# max_iter=75, seed 0, with progress output.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_greedy_local_search(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")


iter: 1 loss: 3940665656.576172
iter: 2 loss: 50159011.38350522
iter: 3 loss: 15856699.188847966
iter: 4 loss: 11056192.979513206
iter: 5 loss: 9256008.28814538
iter: 6 loss: 8055877.264499885
iter: 7 loss: 7455769.235527731
iter: 8 loss: 6855700.741513879
iter: 9 loss: 6455657.720367758
iter: 10 loss: 6155629.526096124
iter: 11 loss: 6155606.32551593
iter: 12 loss: 5955590.379381934
iter: 13 loss: 5955577.783150409
iter: 14 loss: 5955566.366719464
iter: 15 loss: 5955559.315207875
iter: 16 loss: 5955552.77680256
iter: 17 loss: 5855546.908533862
iter: 18 loss: 5655544.340628187
iter: 19 loss: 5655542.688361539
iter: 20 loss: 5455541.561074893
iter: 21 loss: 5455539.601348644
iter: 22 loss: 5455538.604323449
iter: 23 loss: 5455538.130133378
iter: 24 loss: 5455537.972569693
iter: 25 loss: 5455537.597025823
iter: 26 loss: 5455537.261393904
iter: 27 loss: 5455536.916380471
iter: 28 loss: 5455536.7709437655
iter: 29 loss: 5455536.7228535535
iter: 30 loss: 5455536.695209776
iter: 31 loss: 545

In [55]:
# Print the full vertex set of every cluster, one set per line.
# NOTE(review): this dumps every vertex id; printing len(s) would keep the output readable.
for s in S_vertices:
    print(s)

{'gbi0vop', 'gbgtir5', 'gbjl0o6', 'gbh99e3', 'gbgr9br', 'gbig03a', 'gbi7ypx', 'gbh53a3', 'gbgtqwq', 'gbhd6f3', 'gbh0i0o', 'gbgsszl', 'gbh9qja', 'gbgtfoq', 'gbgs2r7', 'gbgr51w', 'gc0o37c', 'gbgu5n9', 'gbh1lcf', 'gbgtbzj', 'gbgt06a', 'gbgu5ns', 'gbh93n8', 'gbhg1rw', 'gbgu5zt', 'gbi2x4f', 'gbh1i9j', 'gbh2boe', 'gbgudc2', 'gbgs4ka', 'gbgu4ef', 'gbgswio', 'gbgr3cs', 'gbh3m9v', 'gbktxp8', 'gbh6rqc', 'gbhkjvs', 'gbgvv6x', 'gbhh6gy', 'gbgv72a', 'gbgy0w2', 'gbh5deu', 'gbgttrk', 'gbgrjkc', 'gbgsdf0', 'gbh93ab', 'gbgr7si', 'gbgrhdm', 'gbh1c40', 'gbgr331', 'gbh1x53', 'gbgrov0', 'gbh3pqd', 'gbh3owg', 'gbgvbc4', 'gbhk87g', 'gbgsqna', 'gbhbfos', 'gbh7qjz', 'gbhvh2a', 'gbh80z7', 'gbluukj', 'gbhy1an', 'gbgukpw', 'gbhkit4', 'gbgto9t', 'gbizo7m', 'gbgw263', 'gbgx7ly', 'gbhr4gt', 'gbhxgn0', 'gbgr6fm', 'gbgx6a1', 'gbi0bdb', 'gbi61ol', 'gbgwx87', 'gbgtm56', 'gbhv3rw', 'gbhz3m9', 'gbgr9o5', 'gbgzvaa', 'gbh31yx', 'gbgqzrp', 'gbgto0b', 'gbh7rzr', 'gbhe89x', 'gbhqtzq', 'gbh1xvt', 'gbhtf99', 'gbhftjy', 'gbgt9am'

Tests on DBLP data

In [282]:
# Load the graph object stored in dblp_graph.pkl into G (replaces the previously loaded graph).
# NOTE(review): pickle.load can run arbitrary code while unpickling - only load trusted files.
with open("dblp_graph.pkl", "rb") as file:
    G = pickle.load(file)

In [59]:
# Time one greedy-local-search run on G: k=20 sets, (lambda_fwd, lambda_back) = (0, 100000),
# max_iter=75, seed 0, with progress output.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_greedy_local_search(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")

iter: 1 loss: 3282025708.7355957
iter: 2 loss: 163423470.91921353
iter: 3 loss: 48722627.92323614
iter: 4 loss: 32522438.904299807
iter: 5 loss: 30322360.294559043
iter: 6 loss: 28822327.731860504
iter: 7 loss: 28422308.10202741
iter: 8 loss: 28022293.32947411
iter: 9 loss: 27522282.163654473
iter: 10 loss: 27222274.913778238
iter: 11 loss: 27022268.568848014
iter: 12 loss: 26822265.43805484
iter: 13 loss: 26522261.466446005
iter: 14 loss: 26222258.061951708
iter: 15 loss: 26022254.870151993
iter: 16 loss: 25722252.667581216
iter: 17 loss: 25722250.032851156
iter: 18 loss: 25722248.60209114
iter: 19 loss: 25522248.179399494
iter: 20 loss: 25422247.701586682
iter: 21 loss: 25322246.919700235
iter: 22 loss: 25322245.65809606
iter: 23 loss: 25322245.048429642
iter: 24 loss: 25322244.811622374
iter: 25 loss: 25322243.256455068
iter: 26 loss: 25322242.549229804
iter: 27 loss: 25322242.43581447
iter: 28 loss: 25222242.644007344
iter: 29 loss: 25222242.417132054
iter: 30 loss: 25122241.930745

In [60]:
# Print the full vertex set of every cluster, one set per line.
# NOTE(review): this dumps every vertex id; printing len(s) would keep the output readable.
for s in S_vertices:
    print(s)

{1566474244, 2295005189, 2141650953, 2138243089, 2074574866, 2962980889, 2137096228, 2555576367, 111607860, 2097184821, 2136440888, 2017132602, 1987510339, 1999994948, 2109571154, 2158952538, 1993343070, 2168717408, 2069790826, 85229681, 2111766645, 193233024, 2094858374, 2104557718, 2140766383, 2003108016, 1519648961, 2408186052, 2055667921, 2142535891, 2101281000, 2154594550, 1877442813, 2008350976, 2081980673, 2027913476, 196542726, 1991377165, 2130313487, 2167341328, 2062516497, 2163900690, 2030993695, 2120745256, 2166128942, 1866400055, 2061402433, 2963898694, 2170782029, 1994326354, 2166817130, 2101936509, 2127724926, 2151907713, 2406810007, 2149417376, 2158952869, 2009727399, 2623898022, 2408513978, 2092335550, 167379394, 2097119693, 2121236959, 2086863338, 2051834357, 2400190969, 2153579005, 1565557263, 2118418963, 2124120603, 1542652451, 2185757220, 2167439930, 1969357402, 1769669223, 2132771435, 2161443453, 2128282242, 2115273357, 2583200403, 2083291795, 2140209831, 239494827

In [283]:
# Time one iterative-mincut run on G: k=20 sets, (lambda_fwd, lambda_back) = (0, 100000),
# at most 75 iterations, seed 0, with progress output.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_iterative_mincut(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")

start loss: 3282025708.736328, sets: [1492, 1493, 1489, 1589, 1545, 1587, 1512, 1488, 1534, 1518, 1510, 1581, 1583, 1509, 1494, 1527, 1528, 1564, 1569, 1469]
iter: 1, loss: 164923997.09239197, sets: [14151, 3449, 2912, 1769, 1640, 949, 775, 618, 563, 509, 480, 339, 372, 363, 340, 307, 319, 233, 244, 249]
iter: 2, loss: 19422994.998291016, sets: [9181, 3139, 2604, 2062, 1537, 1797, 927, 1111, 912, 860, 628, 904, 845, 678, 649, 602, 691, 421, 472, 561]
iter: 3, loss: 10622664.827667236, sets: [6949, 2797, 2450, 2044, 1582, 2238, 1164, 1261, 1114, 998, 683, 1182, 1067, 872, 769, 836, 884, 568, 542, 581]
iter: 4, loss: 9222518.667480469, sets: [6008, 2635, 2333, 1962, 1619, 2428, 1356, 1214, 1198, 1072, 748, 1347, 1232, 979, 839, 889, 1005, 604, 543, 570]
iter: 5, loss: 7122463.028869629, sets: [5648, 2553, 2216, 1893, 1638, 2492, 1434, 1145, 1244, 1100, 799, 1458, 1344, 1066, 893, 894, 1066, 574, 515, 609]
iter: 6, loss: 7222422.231231689, sets: [5477, 2495, 2118, 1824, 1609, 2508, 1490, 

In [284]:
# Print the full vertex set of every cluster, one set per line.
# NOTE(review): this dumps every vertex id; printing len(s) would keep the output readable.
for s in S_vertices:
    print(s)

{2141650953, 2138243089, 2074574866, 2029584418, 2137096228, 2113765418, 2555576367, 111607860, 2097184821, 2151383095, 2017132602, 1995866178, 1987510339, 53444680, 2109571154, 2158952538, 2168717408, 2119008358, 2069790826, 85229681, 2111766645, 2106294397, 2168914046, 193233024, 2094858374, 2099216531, 2104557718, 2077491357, 2037842078, 2140766383, 2003108016, 2111176881, 2122678453, 1519648961, 2405761224, 2055667921, 2142535891, 1951891667, 2105540829, 2113601767, 2101281000, 2095841525, 2154594550, 2139095291, 1877442813, 1580826879, 2008350976, 2081980673, 2027913476, 196542726, 2130313487, 2167341328, 2062516497, 2163900690, 1072529691, 2963112220, 2030993695, 2120745256, 2166128942, 1527152942, 1599373617, 1994916152, 2114027834, 1994326354, 2126217565, 2100658528, 2154463588, 2000257387, 2105180523, 2164359548, 2101936509, 2127724926, 2151907713, 2115240329, 2145845649, 2293170582, 2149417376, 2158199200, 2009727399, 2403959208, 2135327149, 2408513978, 2092335550, 2097119693

Additional tests on synthetic datasets:

In [279]:
from sklearn.metrics.cluster import adjusted_rand_score

# Synthetic benchmark: build a synthetic tree (optionally densified into a DAG),
# perturb a fraction of the node features, then run each clustering algorithm
# and report the best loss and the wall-clock runtime per lambda_back value.

random_seed = 1234
# NOTE(review): the global NumPy RNG is seeded only once, here. Whether every
# algorithm below sees the same perturbed features depends on whether
# create_synthetic_tree / add_synthetic_edges reseed internally — confirm.
np.random.seed(random_seed)

n_vertices = 1000
n_features = 10
n_clusters = 5

centroid_variance = 1.0
# NOTE(review): this value is passed as `scale` to np.random.normal below, which
# NumPy interprets as a standard deviation, not a variance — confirm intent.
node_variance = 0.1

# Label of vertex index i when the indices are split into n_clusters equal,
# contiguous blocks.
# NOTE(review): ground_truth, best_vertex_to_set_index and the imported
# adjusted_rand_score are never used in this cell; presumably an ARI evaluation
# against ground_truth was intended — confirm and add it, or drop them.
ground_truth = [int(i / (n_vertices / n_clusters)) for i in range(n_vertices)]

test_lambdas = [100000]
n_repeated_runs = 1

p = 0.2  # probability that a node's features are resampled around a random centroid
extra_edge_probability = 0.01  # passed to add_synthetic_edges when testing on a DAG

# Dispatch table: algorithm name -> clustering routine. All routines are called
# with the same arguments and return (loss, S_vertices, vertex_to_set_index).
clustering_algorithms = {
    "iterative_mincut": find_node_clustering_iterative_mincut,
    "greedy_local_search": find_node_clustering_greedy_local_search,
    "tree_DP": find_node_clustering_on_tree,
}

for alg, test_on_DAG in ["tree_DP", False], ["greedy_local_search", False], ["iterative_mincut", True], ["greedy_local_search", True]:
    # Fail loudly on a misspelled algorithm name instead of silently reusing the
    # loss left over from a previous run.
    if alg not in clustering_algorithms:
        raise ValueError(f"unknown clustering algorithm: {alg!r}")
    find_clustering = clustering_algorithms[alg]

    print(alg, test_on_DAG)
    G, synthetic_centroids = create_synthetic_tree(n_vertices, n_features, n_clusters,
                            centroid_variance=centroid_variance, node_variance=node_variance, random_seed=random_seed)
    # add random edges to make the tree into a more complex DAG
    if test_on_DAG:
        add_synthetic_edges(G, probability=extra_edge_probability, random_seed=random_seed)
        print("DAG size:",len(G.nodes()), len(G.edges()))

    # With probability p, overwrite a node's features with a fresh sample drawn
    # around the centroid of a uniformly chosen set.
    for node in G.nodes():
        if np.random.uniform() < p:
            random_set_i = np.random.randint(0, n_clusters)
            G.nodes[node]["features"] = tuple(np.random.normal(loc=synthetic_centroids[random_set_i], scale=node_variance, size=n_features))

    compute_vertex_edges(G)

    for lambda_back in test_lambdas:
        start_time = datetime.now()
        # run the algorithm multiple times and take the best result
        best_loss = np.inf
        best_vertex_to_set_index = None  # reset so a stale assignment never leaks across configurations
        for run_i in range(n_repeated_runs):
            seed = random_seed + run_i
            # lambdas = (lambda_fwd, lambda_back): forward edges are not penalised here.
            loss, S_vertices, vertex_to_set_index = find_clustering(G, n_clusters, lambdas=(0, lambda_back), random_seed=seed, verbose=True)

            if loss < best_loss:
                best_loss = loss
                best_vertex_to_set_index = vertex_to_set_index
        print(f"lambda_back={lambda_back}, loss={best_loss}, runtime={datetime.now() - start_time}")
        print()

    print("-----------------------")

tree_DP False
start loss: 40100637.534804136, sets: [182, 223, 204, 209, 182]
Total best loss with fixed centroids:  607.8545713009472
iter: 1, loss: 377.5089160907829, sets: [78, 368, 130, 184, 240]
Total best loss with fixed centroids:  352.35580817128135
iter: 2, loss: 340.6536630132808, sets: [83, 374, 30, 245, 268]
Total best loss with fixed centroids:  328.7709639841275
iter: 3, loss: 317.5872542551611, sets: [192, 316, 14, 221, 257]
Total best loss with fixed centroids:  307.80077296492107
iter: 4, loss: 297.1272876842463, sets: [278, 232, 32, 201, 257]
Total best loss with fixed centroids:  280.970785122488
iter: 5, loss: 269.95051691836045, sets: [322, 87, 119, 195, 277]
Total best loss with fixed centroids:  265.41550375850824
iter: 6, loss: 260.7670667046778, sets: [290, 59, 140, 194, 317]
Total best loss with fixed centroids:  258.2573005838414
iter: 7, loss: 256.9405916014889, sets: [266, 56, 149, 193, 336]
Total best loss with fixed centroids:  256.4557022449643
iter: 8, 