In [1]:
import networkx as nx
import numpy as np
import pickle
from tqdm.notebook import tqdm
import heapq
import matplotlib.pyplot as plt
from datetime import datetime

In [151]:
from collections import Counter

def neg_log_likelihood(labels):
    """Return the negative log-likelihood of `labels` under their empirical distribution.

    Each distinct label with count c out of n contributes c * log(c / n);
    the function returns the negated sum. An empty input yields 0.
    """
    if len(labels) == 0:
        return 0
    label_counts = Counter(labels)
    n_labels = len(labels)
    # Summation follows first-occurrence order of the labels.
    return -sum(count * np.log(count / n_labels) for count in label_counts.values())

def count_forward_and_backward_edges(i, S_vertices, S_edges, vertex_to_set_index):
    """Count edges stored for set `i` that leave it towards later / earlier sets.

    Edges whose head is inside set `i`, or whose head has no assigned set,
    are ignored. Returns `(forward_edges, backward_edges)` where "forward"
    means the head lives in a set with a larger index than `i`.
    """
    members = S_vertices[i]
    head_sets = [
        vertex_to_set_index[head]
        for _, head in S_edges[i]
        if head not in members and head in vertex_to_set_index
    ]
    forward_edges = sum(1 for j in head_sets if i < j)
    backward_edges = sum(1 for j in head_sets if j < i)
    return forward_edges, backward_edges

def compute_loss_nll(G, S_vertices, S_edges, lambdas, vertex_to_set_index):
    """Total clustering loss using the label negative log-likelihood per set.

    For every set, adds the NLL of its members' 'features' values plus
    `lambda_fwd` per forward edge and `lambda_back` per backward edge
    (`lambdas` is the pair `(lambda_fwd, lambda_back)`).
    """
    lambda_fwd, lambda_back = lambdas
    total = 0
    for set_index, members in enumerate(S_vertices):
        member_labels = [G.nodes[v]['features'] for v in members]
        n_forward, n_backward = count_forward_and_backward_edges(
            set_index, S_vertices, S_edges, vertex_to_set_index)
        total += neg_log_likelihood(member_labels) + lambda_fwd * n_forward + lambda_back * n_backward
    return total

def L2_loss(features, mean=0):
    """Return the sum of squared differences between `features` and `mean`.

    `features` may be an ndarray, a tuple/list of numbers, or a list of such
    vectors; `mean` may be a scalar or anything that broadcasts against
    `features`. The input is coerced with `np.asarray` so that plain
    tuples/lists work even when `mean` is not an ndarray (for example the
    default `mean=0`), which previously raised a TypeError.
    """
    return np.sum(np.square(np.asarray(features) - mean))

def compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index, centroids=None):
    """Total clustering loss using the squared L2 distance to each set's centroid.

    Empty sets contribute nothing. When `centroids` is None the centroid of a
    set is the mean of its members' 'features'; otherwise `centroids[i]` is
    used. Each forward / backward edge of a set adds `lambda_fwd` /
    `lambda_back` (`lambdas` is the pair `(lambda_fwd, lambda_back)`).
    """
    lambda_fwd, lambda_back = lambdas
    total = 0
    for set_index, members in enumerate(S_vertices):
        if len(members) == 0:
            continue
        member_features = [G.nodes[v]['features'] for v in members]
        centroid = np.mean(member_features, axis=0) if centroids is None else centroids[set_index]
        n_forward, n_backward = count_forward_and_backward_edges(
            set_index, S_vertices, S_edges, vertex_to_set_index)
        total += L2_loss(member_features, centroid) + lambda_fwd * n_forward + lambda_back * n_backward
    return total

def add_v(v, set_i, S_vertices, S_edges, vertex_to_set_index, G):
    """Put vertex `v` into set `set_i`, updating all three bookkeeping structures in place.

    The outgoing edges of `v` (taken from `G.vertex_edges_from`) are merged
    into the edge set of `set_i`.
    """
    target_vertices = S_vertices[set_i]
    target_vertices.add(v)
    vertex_to_set_index[v] = set_i
    S_edges[set_i].update(G.vertex_edges_from[v])

def remove_v(v, set_i, S_vertices, S_edges, vertex_to_set_index, G):
    """Take vertex `v` out of set `set_i`, updating all three bookkeeping structures in place.

    Raises KeyError if `v` (or one of its outgoing edges) is not currently
    registered in the given set.
    """
    S_vertices[set_i].remove(v)
    vertex_to_set_index.pop(v)
    set_edges = S_edges[set_i]
    for outgoing_edge in G.vertex_edges_from[v]:
        set_edges.remove(outgoing_edge)

def compute_vertex_edges(G):
    """Attach per-vertex edge lookups to `G` as attributes.

    After the call, `G.vertex_edges_from[v]` is the set of edges `(v, w)`
    leaving `v` and `G.vertex_edges_to[v]` is the set of edges `(u, v)`
    entering `v`. Vertices without edges map to empty sets.
    """
    G.vertex_edges_from = {node: set() for node in G.nodes()}
    G.vertex_edges_to = {node: set() for node in G.nodes()}
    for tail, head in G.edges():
        edge = (tail, head)
        G.vertex_edges_from[tail].add(edge)
        G.vertex_edges_to[head].add(edge)

In [18]:
def get_new_centroids(current_set_i, other_set_j, centroids, feature_v, size_from, size_to):
    """Centroids of the source and target sets after moving one vertex between them.

    Parameters
    ----------
    current_set_i, other_set_j : indices of the set the vertex leaves / joins.
    centroids : per-set centroids; an entry is None for an empty set.
    feature_v : feature vector of the moved vertex (tuple, list or ndarray).
    size_from, size_to : sizes of the two sets *before* the move.

    Returns
    -------
    (new_centroid_from, new_centroid_to); `new_centroid_from` is None when the
    source set becomes empty.

    All arithmetic is done on ndarrays. Returning the raw `feature_v` as the
    centroid of a previously empty set (as before) could hand a tuple to later
    centroid arithmetic, where `size * tuple` repeats the tuple instead of
    scaling it.
    """
    feature_v = np.asarray(feature_v)
    if size_from == 1:
        new_centroid_from = None
    else:
        new_centroid_from = (size_from*np.asarray(centroids[current_set_i]) - feature_v)/(size_from-1)
    if centroids[other_set_j] is None:
        # Copy so the centroid never aliases the node's own feature storage.
        new_centroid_to = np.array(feature_v)
    else:
        new_centroid_to = (size_to*np.asarray(centroids[other_set_j]) + feature_v)/(size_to+1)
    return new_centroid_from, new_centroid_to

def compute_delta_edge_loss(v, current_set_i, other_set_j, G, lambdas, vertex_to_set_index):
    """Change in edge penalty if `v` is moved from `current_set_i` to `other_set_j`.

    Only edges incident to `v` can change. For each neighbour the change
    depends on where the neighbour's set lies relative to the two sets:
    strictly between them (the edge flips direction), in the current set
    (the edge becomes a cross-set edge) or in the target set (the edge
    becomes internal). Neighbours elsewhere contribute nothing.
    `lambdas` is the pair `(lambda_fwd, lambda_back)`.
    """
    lambda_fwd, lambda_back = lambdas
    moving_to_later_set = other_set_j > current_set_i

    def neighbor_delta(neighbor_set, cost_if_v_earlier, cost_if_v_later):
        # cost_if_v_earlier / cost_if_v_later: penalty of this edge when v's
        # set index is smaller / larger than the neighbour's set index.
        if current_set_i < neighbor_set < other_set_j:
            return cost_if_v_later - cost_if_v_earlier
        if other_set_j < neighbor_set < current_set_i:
            return cost_if_v_earlier - cost_if_v_later
        if neighbor_set == current_set_i:
            return cost_if_v_later if moving_to_later_set else cost_if_v_earlier
        if neighbor_set == other_set_j:
            return -cost_if_v_earlier if moving_to_later_set else -cost_if_v_later
        return 0

    delta_edge_loss = 0
    # Outgoing edge v -> head: forward while v sits in the earlier set.
    for _, head in G.vertex_edges_from[v]:
        delta_edge_loss += neighbor_delta(vertex_to_set_index[head], lambda_fwd, lambda_back)
    # Incoming edge tail -> v: backward while v sits in the earlier set.
    for tail, _ in G.vertex_edges_to[v]:
        delta_edge_loss += neighbor_delta(vertex_to_set_index[tail], lambda_back, lambda_fwd)
    return delta_edge_loss

In [200]:
def compute_centroids(G, S_vertices):
    """Return one centroid per set: the mean 'features' vector of its members, or None if the set is empty."""
    centroids = []
    for members in S_vertices:
        if len(members) > 0:
            member_features = [G.nodes[v]['features'] for v in members]
            centroids.append(np.mean(member_features, axis=0))
        else:
            centroids.append(None)
    return centroids

In [252]:
def find_node_clustering_greedy_local_search(G, k, lambdas, max_iter = 50, random_seed=None, given_partition=None, verbose=False):
    """Cluster the vertices of `G` into `k` ordered sets by greedy single-vertex moves.

    Starting from `given_partition` (a mapping vertex -> set index) or a
    uniformly random assignment, the vertices are swept repeatedly; each vertex
    is moved to the set that most decreases the objective (squared L2 distance
    to the set centroid plus the forward/backward edge penalties), with
    centroids updated incrementally after every move.

    Parameters
    ----------
    G : graph exposing `nodes()`, `nodes[v]['features']`, `vertex_edges_from`
        and `vertex_edges_to` (see `compute_vertex_edges`).
    k : number of sets.
    lambdas : pair `(lambda_fwd, lambda_back)` of edge penalties.
    max_iter : maximum number of sweeps over all vertices.
    random_seed : if not None, seeds the global numpy RNG.
    given_partition : optional initial assignment.
    verbose : print the running loss at the start of each sweep.

    Returns
    -------
    (total_loss, S_vertices, vertex_to_set_index); `total_loss` is maintained
    incrementally from the per-move deltas.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    S_vertices = [set() for _ in range(k)] # for each set i, which vertices are in that set
    vertex_to_set_index = dict() # for each vertex v, which set does v belong to

    S_edges = [set() for _ in range(k)] # for each set i, which edges are within that set or from that set to another set

    for v in G.nodes():
        if given_partition is None:
            set_index = np.random.randint(k)
        else:
            set_index = given_partition[v]
        add_v(v, set_index, S_vertices, S_edges, vertex_to_set_index, G)

    centroids = compute_centroids(G, S_vertices)

    total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index, centroids)
    n_sweeps = 0
    loss_decreasing = True
    # Strict '<' so that at most max_iter sweeps run (the previous '<=' allowed
    # max_iter + 1, unlike the other solvers in this notebook).
    while loss_decreasing and n_sweeps < max_iter:
        n_sweeps += 1
        if verbose:
            print("iter:",n_sweeps, "loss:", total_loss)
        loss_decreasing = False
        for v in G.nodes():
            best_delta_loss = 0
            current_set_i = vertex_to_set_index[v]
            best_set_i = current_set_i

            # Work on an ndarray so centroid arithmetic is well defined even
            # when features are stored as tuples/lists.
            feature_v = np.asarray(G.nodes[v]['features'])
            size_from = len(S_vertices[current_set_i])
            if size_from == 1:
                # A singleton has zero L2 loss, so removing v frees nothing.
                current_set_L2_loss = 0
            else:
                # Removing a point x from a set of size n with centroid c
                # lowers its L2 loss by n/(n-1) * ||x - c||^2.
                current_set_L2_loss = size_from/(size_from-1) * L2_loss(feature_v, centroids[current_set_i])

            for other_set_j in range(k):
                if other_set_j == current_set_i:
                    continue
                size_to = len(S_vertices[other_set_j])

                if size_to == 0:
                    # Empty target: its centroid is None and v would become a
                    # singleton with zero L2 loss. Evaluating L2_loss against
                    # None would raise a TypeError.
                    other_set_L2_loss = 0
                else:
                    # Adding x to a set of size m with centroid c raises its
                    # L2 loss by m/(m+1) * ||x - c||^2.
                    other_set_L2_loss = size_to/(size_to+1) * L2_loss(feature_v, centroids[other_set_j])
                delta_edge_loss = compute_delta_edge_loss(v, current_set_i, other_set_j, G, lambdas, vertex_to_set_index)

                delta_loss = other_set_L2_loss - current_set_L2_loss + delta_edge_loss

                if delta_loss < best_delta_loss:
                    best_delta_loss = delta_loss
                    best_set_i = other_set_j

            if best_delta_loss < 0:
                loss_decreasing = True
                total_loss += best_delta_loss
                # Centroids must be updated from the pre-move sizes, i.e.
                # before the membership structures are changed.
                centroids[current_set_i], centroids[best_set_i] = get_new_centroids(current_set_i, best_set_i, centroids, feature_v, size_from, len(S_vertices[best_set_i]))
                remove_v(v, current_set_i, S_vertices, S_edges, vertex_to_set_index, G)
                add_v(v, best_set_i, S_vertices, S_edges, vertex_to_set_index, G)
    return total_loss, S_vertices, vertex_to_set_index

In [155]:
def add_subtree_to_best_set(v, parent_set_index, 
        best_loss_if_v_in_set_i, best_direction_with_parent_set, best_loss_if_v_in_earlier_set, 
        best_loss_if_v_in_later_set, S_vertices, S_edges, vertex_to_set_index, G):
    """Assign `v` and all of its descendants to their optimal sets from the DP tables.

    Parameters
    ----------
    v : root of the subtree to assign.
    parent_set_index : set chosen for the parent of `v`, or None when `v` is a
        root (then the set with the smallest subtree loss is taken).
    best_loss_if_v_in_set_i : node -> {set index -> best subtree loss}.
    best_direction_with_parent_set : node -> {parent set -> 0 (same set as the
        parent), 1 (some earlier set) or 2 (some later set)}.
    best_loss_if_v_in_earlier_set, best_loss_if_v_in_later_set :
        node -> {set index -> (loss, argmin set)} prefix / suffix minima.
    S_vertices, S_edges, vertex_to_set_index, G : bookkeeping structures
        updated in place through `add_v`.

    The traversal uses an explicit stack rather than recursion so that deep
    trees cannot hit Python's recursion limit. Vertices are visited in the
    same preorder as the recursive formulation.

    Raises
    ------
    ValueError : if a stored direction code is not 0, 1 or 2.
    """
    pending = [(v, parent_set_index)]
    while pending:
        node, parent_set = pending.pop()
        if parent_set is None:
            best_set_index = min(best_loss_if_v_in_set_i[node], key=best_loss_if_v_in_set_i[node].get)
        else:
            best_direction = best_direction_with_parent_set[node][parent_set]
            if best_direction == 0:
                best_set_index = parent_set
            elif best_direction == 1:
                # Best set strictly before the parent's set.
                best_set_index = best_loss_if_v_in_earlier_set[node][parent_set-1][1]
            elif best_direction == 2:
                # Best set strictly after the parent's set.
                best_set_index = best_loss_if_v_in_later_set[node][parent_set+1][1]
            else:
                raise ValueError(f"unexpected direction code {best_direction!r} for node {node!r}")
        add_v(node, best_set_index, S_vertices, S_edges, vertex_to_set_index, G)
        children = [child_node for _, child_node in G.vertex_edges_from[node]]
        # Push in reverse so the first child is popped (and fully processed) first.
        pending.extend((child_node, best_set_index) for child_node in reversed(children))

In [6]:
def compute_empty_set_centroids_and_l2_losses(G, k, vertex_to_set_index, centroids):
    """Give empty sets a centroid and tabulate every node's L2 loss to every centroid.

    Sets whose centroid is None are re-seeded with the features of the nodes
    that currently lie farthest (largest L2 loss) from their own set's
    centroid, one distinct node per empty set. `centroids` is modified in
    place and also returned.

    Returns `(centroids, l2_losses_dict)` with
    `l2_losses_dict[set_i][node]` = L2 loss of `node` against `centroids[set_i]`.
    """
    def losses_against(centroid):
        return {node: L2_loss(G.nodes[node]["features"], centroid) for node in G.nodes()}

    empty_clusters = [set_i for set_i in range(k) if centroids[set_i] is None]
    l2_losses_dict = dict()
    current_l2_losses = []  # (loss, node) of every node w.r.t. its own set
    for set_i in range(k):
        if centroids[set_i] is None:
            continue
        per_node_losses = losses_against(centroids[set_i])
        l2_losses_dict[set_i] = per_node_losses
        current_l2_losses.extend(
            (l2_loss, node)
            for node, l2_loss in per_node_losses.items()
            if set_i == vertex_to_set_index[node])

    # Positions (into current_l2_losses) of the nodes farthest from their own centroid.
    farthest_positions = heapq.nlargest(len(empty_clusters), range(len(current_l2_losses)), current_l2_losses.__getitem__)
    for rank, empty_cluster_i in enumerate(empty_clusters):
        seed_node = current_l2_losses[farthest_positions[rank]][1]
        centroids[empty_cluster_i] = np.array(G.nodes[seed_node]["features"])
        l2_losses_dict[empty_cluster_i] = losses_against(centroids[empty_cluster_i])

    return centroids, l2_losses_dict

In [253]:
def find_node_clustering_on_tree(G, k, lambdas, max_iter = 50, random_seed=None, given_partition=None, verbose=False):
    """Cluster the vertices of `G` into `k` ordered sets by alternating centroid updates and a tree DP.

    Starting from `given_partition` (vertex -> set index) or a random
    assignment, each iteration
      1. computes the centroid of every set (empty sets are re-seeded by
         `compute_empty_set_centroids_and_l2_losses`),
      2. with those centroids fixed, runs a bottom-up dynamic program from
         every root (in-degree 0) to find the assignment minimising
         L2 loss + lambda_fwd * (#forward edges) + lambda_back * (#backward edges),
      3. rebuilds the sets from the DP tables and re-evaluates the loss with
         freshly computed centroids.
    The loop stops when the loss does not strictly decrease or after
    `max_iter` iterations.

    NOTE(review): the DP visits nodes via BFS layers from each root and follows
    `G.vertex_edges_from`, so it presumably assumes `G` is a forest (every
    node reachable from exactly one root through one parent) - confirm before
    using it on general DAGs.

    Parameters
    ----------
    G : directed graph with `nodes[v]['features']`, `in_degree` and the
        `vertex_edges_from` attribute created by `compute_vertex_edges`.
    k : number of sets.
    lambdas : pair `(lambda_fwd, lambda_back)`.
    max_iter : maximum number of DP/centroid rounds.
    random_seed : if not None, seeds the global numpy RNG.
    given_partition : optional initial assignment.
    verbose : print losses and set sizes per iteration.

    Returns
    -------
    (current_total_loss, S_vertices, vertex_to_set_index) of the last
    iteration that was run (also when that iteration did not improve the loss).
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    lambda_fwd, lambda_back = lambdas

    S_vertices = [set() for i in range(k)] # for each set i, which vertices are in that set
    vertex_to_set_index = dict() # for each vertex v, which set does v belong to
    S_edges = [set() for i in range(k)] # for each set i, which edges are within that set or from that set to another set

    # vertices without incoming edges are the starting points of the DP
    root_vertices = [v for v in G.nodes() if G.in_degree(v) == 0]

    for v in G.nodes():
        if given_partition is None:
            set_index = np.random.randint(k)
        else:
            set_index = given_partition[v]
        add_v(v, set_index, S_vertices, S_edges, vertex_to_set_index, G)

    current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index)
    if verbose:
        print(f"start loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    iter = 0
    loss_decreasing = True

    while loss_decreasing and iter < max_iter:
        iter += 1
        loss_decreasing = False

        # compute centroids
        centroids = compute_centroids(G, S_vertices)
        centroids, l2_losses_dict = compute_empty_set_centroids_and_l2_losses(G, k, vertex_to_set_index, centroids)

        # reset all sets
        S_vertices = [set() for i in range(k)] # for each set i, which vertices are in that set
        vertex_to_set_index = dict() # for each vertex v, which set does v belong to
        S_edges = [set() for i in range(k)] # for each set i, which edges are within that set or from that set to another set

        # DP tables, all keyed by node:
        #   best_loss_if_v_in_set_i[v][i]       best loss of v's subtree if v is placed in set i
        #   best_loss_if_v_in_earlier_set[v][i] (loss, set) minimum over sets 0..i   (prefix minimum)
        #   best_loss_if_v_in_later_set[v][i]   (loss, set) minimum over sets i..k-1 (suffix minimum)
        #   best_direction_with_parent_set[v][i] 0/1/2: v goes to the same / an earlier / a later set
        #                                        than its parent when the parent is in set i
        best_loss_if_v_in_set_i = dict()
        best_loss_if_v_in_earlier_set = dict()
        best_loss_if_v_in_later_set = dict()
        best_direction_with_parent_set = dict()

        total_best_loss = 0

        for root in root_vertices:
            layers = [layer for layer in nx.bfs_layers(G, root)]
            # process the deepest layer first so children are solved before their parent
            for layer in layers[::-1]:
                for node in layer:
                    best_loss_if_v_in_set_i[node] = dict()
                    best_loss_if_v_in_earlier_set[node] = dict()
                    # sentinel: there is no set before set 0
                    best_loss_if_v_in_earlier_set[node][-1] = (np.inf, -1)
                    best_loss_if_v_in_later_set[node] = dict()
                    # sentinel: there is no set after set k-1
                    best_loss_if_v_in_later_set[node][k] = (np.inf, -1)
                    best_direction_with_parent_set[node] = dict()

                    best_loss, best_set_index = np.inf, -1
                    for set_i in range(k):
                        loss = l2_losses_dict[set_i][node]

                        for _, child_node in G.vertex_edges_from[node]:
                            # cheapest option for the child given that node is in set_i:
                            #   0: same set (no edge penalty)
                            #   1: a strictly earlier set (the edge node->child is backward)
                            #   2: a strictly later set (the edge node->child is forward)
                            # ties resolve to the lowest option index
                            min_direction, min_value = min(
                                enumerate([best_loss_if_v_in_set_i[child_node][set_i],
                                        best_loss_if_v_in_earlier_set[child_node][set_i-1][0]+lambda_back,
                                        best_loss_if_v_in_later_set[child_node][set_i+1][0]+lambda_fwd]),
                                key=lambda x: x[1])
                            loss += min_value

                            # Store assuming parent goes to set_i whether we want to put child_node in set_i, 
                            #  or in some earlier set, or in some later set
                            best_direction_with_parent_set[child_node][set_i] = min_direction

                        best_loss_if_v_in_set_i[node][set_i] = loss
                        # running minimum over sets 0..set_i
                        if loss < best_loss:
                            best_loss, best_set_index = loss, set_i
                        best_loss_if_v_in_earlier_set[node][set_i] = (best_loss, best_set_index)
                    # second pass from the last set downwards builds the suffix minima
                    best_loss = np.inf
                    best_set_index = -1
                    for set_i in range(k-1, -1, -1):
                        if best_loss_if_v_in_set_i[node][set_i] < best_loss:
                            best_loss, best_set_index = best_loss_if_v_in_set_i[node][set_i], set_i
                        best_loss_if_v_in_later_set[node][set_i] = (best_loss, best_set_index)

            # reconstruct the optimal assignment of this root's tree from the DP tables
            add_subtree_to_best_set(root, None, 
                best_loss_if_v_in_set_i, best_direction_with_parent_set, best_loss_if_v_in_earlier_set, 
                best_loss_if_v_in_later_set, S_vertices, S_edges, vertex_to_set_index, G)
            # `best_loss` is the value left over from the last node processed above, i.e. the
            # minimum over all sets for that node. NOTE(review): this relies on the root being
            # the last node processed (first BFS layer, visited last) - confirm.
            total_best_loss += best_loss
        if verbose:
            print("Total best loss with fixed centroids: ", total_best_loss)

        previous_loss = current_total_loss
        # centroids are not passed here, so the loss is evaluated against the means of the new sets
        current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index)
        if current_total_loss < previous_loss:
            loss_decreasing = True
        if verbose:
            print(f"iter: {iter}, loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    return current_total_loss, S_vertices, vertex_to_set_index

In [262]:
def find_node_clustering_iterative_mincut(G, k, lambdas, max_iter = 50, random_seed=None, given_partition=None, verbose=False):
    """Cluster the vertices of `G` into `k` ordered sets by repeated pairwise minimum cuts.

    Starting from `given_partition` (vertex -> set index) or a random
    assignment, every round visits each pair of sets (set_i < set_j) and
    re-splits their union optimally for fixed centroids. The split is found as
    a minimum s-t cut in an auxiliary graph `H` whose terminal edges carry the
    L2 cost (plus penalties for edges to sets strictly between the pair) and
    whose inner edges carry the forward/backward penalties. Centroids of the
    two sets are recomputed after each cut. Rounds continue while the total
    loss strictly decreases, up to `max_iter` rounds.

    Parameters
    ----------
    G : directed graph with `nodes[v]['features']` and the `vertex_edges_from`
        / `vertex_edges_to` attributes created by `compute_vertex_edges`.
    k : number of sets.
    lambdas : pair `(lambda_fwd, lambda_back)`.
    max_iter : maximum number of rounds over all pairs.
    random_seed : if not None, seeds the global numpy RNG.
    given_partition : optional initial assignment.
    verbose : print loss and set sizes per round.

    Returns
    -------
    (current_total_loss, S_vertices, vertex_to_set_index)
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    lambda_fwd, lambda_back = lambdas

    S_vertices = [set() for _ in range(k)] # for each set i, which vertices are in that set
    vertex_to_set_index = dict() # for each vertex v, which set does v belong to
    S_edges = [set() for _ in range(k)] # for each set i, which edges are within that set or from that set to another set

    for v in G.nodes():
        if given_partition is None:
            set_index = np.random.randint(k)
        else:
            set_index = given_partition[v]
        add_v(v, set_index, S_vertices, S_edges, vertex_to_set_index, G)

    centroids = compute_centroids(G, S_vertices)

    current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index)
    if verbose:
        print(f"start loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    n_rounds = 0
    loss_decreasing = True

    while loss_decreasing and n_rounds < max_iter:
        n_rounds += 1
        loss_decreasing = False

        for set_i in range(k-1):
            for set_j in range(set_i+1, k):
                if len(S_vertices[set_i]) == 0 and len(S_vertices[set_j]) == 0:
                    continue

                # If one of the two sets is empty, use the features of the vertex of the
                # other set that is farthest from its centroid as a placeholder centroid.
                # This must happen BEFORE any terminal weight is computed: the weights of
                # set_i's vertices need centroids[set_j], which would still be None when
                # set_j is the empty one.
                for empty_set, other_set in ((set_i, set_j), (set_j, set_i)):
                    if len(S_vertices[empty_set]) == 0:
                        max_distance_vertex = max(
                            S_vertices[other_set],
                            key=lambda u: L2_loss(G.nodes[u]["features"], centroids[other_set]))
                        centroids[empty_set] = np.array(G.nodes[max_distance_vertex]["features"])

                # create graph H with vertices from set_i and set_j and compute mincut
                H = nx.DiGraph()
                for set_ij in [set_i, set_j]:
                    for v in S_vertices[set_ij]:
                        # edges between v and vertices of sets strictly between set_i and set_j
                        # change direction depending on the side v ends up on
                        edges_to_v = 0
                        edges_from_v = 0
                        for _, u in G.vertex_edges_from[v]:
                            if vertex_to_set_index[u] > set_i and vertex_to_set_index[u] < set_j:
                                edges_from_v += 1
                        for u, _ in G.vertex_edges_to[v]:
                            if vertex_to_set_index[u] > set_i and vertex_to_set_index[u] < set_j:
                                edges_to_v += 1
                        feature_v = G.nodes[v]["features"]
                        # cutting _s -> v puts v on the sink side (set_j); cutting v -> _t puts v in set_i
                        s_v_weight = L2_loss(feature_v, centroids[set_j]) + lambda_fwd * edges_to_v + lambda_back * edges_from_v
                        v_t_weight = L2_loss(feature_v, centroids[set_i]) + lambda_back * edges_to_v + lambda_fwd * edges_from_v
                        H.add_edge("_s", v, weight=s_v_weight)
                        H.add_edge(v, "_t", weight=v_t_weight)
                    for u, v in S_edges[set_ij]:
                        if v in S_vertices[set_i] or v in S_vertices[set_j]:
                            # u in set_i and v in set_j makes u->v a forward edge; the opposite split makes it backward
                            H.add_edge(u, v, weight=lambda_fwd)
                            H.add_edge(v, u, weight=lambda_back)

                _, partition = nx.minimum_cut(H, "_s", "_t", capacity="weight")

                # source side becomes set_i, sink side becomes set_j
                for partition_i, set_ij in enumerate([set_i, set_j]):
                    S_edges[set_ij] = set()
                    S_vertices[set_ij] = set()
                    for v in partition[partition_i]:
                        if v != "_s" and v != "_t":
                            add_v(v, set_ij, S_vertices, S_edges, vertex_to_set_index, G)
                    if len(S_vertices[set_ij]) > 0:
                        centroids[set_ij] = np.mean([G.nodes[v]["features"] for v in S_vertices[set_ij]], axis=0)
                    else:
                        centroids[set_ij] = None

        previous_loss = current_total_loss
        current_total_loss = compute_loss_L2(G, S_vertices, S_edges, lambdas, vertex_to_set_index, centroids)
        if current_total_loss < previous_loss:
            loss_decreasing = True
        if verbose:
            print(f"iter: {n_rounds}, loss: {current_total_loss}, sets: {[len(S_vertices[s_i]) for s_i in range(k)]}")

    return current_total_loss, S_vertices, vertex_to_set_index

In [248]:
def create_synthetic_tree(n_vertices, n_features, n_clusters, centroid_variance=10, node_variance=0.1, random_seed=None):
    """Build a random directed tree whose node features are drawn around cluster centroids.

    Cluster centroids are `centroid_variance` times uniform [0, 1) samples.
    Node `i` belongs to cluster `int(i / (n_vertices / n_clusters))`; its
    'features' tuple is sampled from a normal distribution centred on that
    centroid, with `node_variance` passed as the normal's `scale`. Every node
    except node 0 receives one incoming edge from a uniformly chosen
    lower-numbered node.

    Returns `(G, synthetic_centroids)`; `compute_vertex_edges` has already
    been applied to `G`.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    synthetic_centroids = centroid_variance*np.random.uniform(size=(n_clusters, n_features))

    G = nx.DiGraph()
    G.add_nodes_from(range(n_vertices))
    for node in range(n_vertices):
        cluster = int(node / (n_vertices / n_clusters))
        sampled_features = np.random.normal(loc=synthetic_centroids[cluster], scale=node_variance, size=n_features)
        G.nodes[node]["features"] = tuple(sampled_features)

        # attach every non-root node below a random earlier node to form a tree
        if node != 0:
            parent = np.random.randint(0, node)
            G.add_edge(parent, node)

    compute_vertex_edges(G)
    return G, synthetic_centroids

def add_synthetic_edges(G, probability = 0.05, random_seed=None):
    """Add an edge from each node to each later node (in `G.nodes()` order) independently with `probability`.

    Modifies `G` in place and refreshes its per-vertex edge lookups via
    `compute_vertex_edges`.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    ordered_nodes = list(G.nodes())
    for position, source in enumerate(ordered_nodes[:-1]):
        for target in ordered_nodes[position+1:]:
            if np.random.uniform() < probability:
                G.add_edge(source, target)
    compute_vertex_edges(G)

In [218]:
def plot_synthetic_test(ps, rand_scores, test_lambdas, filename=None, show_legend=True):
    """Plot the adjusted Rand index against the perturbation probability, one curve per lambda_back.

    Parameters
    ----------
    ps : x values (probabilities).
    rand_scores : mapping lambda_back -> list of scores aligned with `ps`.
    test_lambdas : the lambda_back values to draw, in plotting order.
    filename : if not None, the figure is also saved to this path.
    show_legend : when True, the legend and the "ARI" y-axis label are drawn.

    Note: this sets the global matplotlib rcParams 'text.usetex' and
    'font.size', which affects subsequent figures as well.
    """
    plt.rcParams['text.usetex'] = True
    plt.rcParams.update({'font.size': 22})
    fig, ax = plt.subplots(1,1)
    fig.set_size_inches(4, 3)
    for lambda_back in test_lambdas:
        # Raw strings: "\l" is not a valid escape sequence, so the non-raw
        # form triggers an invalid-escape warning. The rendered text is unchanged.
        if lambda_back == 100000:
            label = r"$\lambda_b=10^5$"
        else:
            label = rf"$\lambda_b={lambda_back}$"
        plt.plot(ps, rand_scores[lambda_back], label=label)
    plt.xlabel("Probability p")
    plt.xticks(np.arange(0,1.0+1e-8,0.2), ["0", ".2", ".4", ".6", ".8", "1"])
    plt.yticks(np.arange(0,1.0+1e-8,0.2), ["0", ".2", ".4", ".6", ".8", "1"])
    ax.spines[['right','top']].set_visible(False)
    plt.grid(dashes=(2, 6))
    if show_legend:
        plt.ylabel("ARI")
        plt.legend(handlelength=1.0, borderpad=0.2, labelspacing=0.1, handletextpad=0.1, borderaxespad=0.1)
    plt.tight_layout()

    if filename is not None:
        plt.savefig(filename, dpi=100, pad_inches=0, bbox_inches='tight')
    plt.show()

In [None]:
from sklearn.metrics.cluster import adjusted_rand_score

# --- configuration of the synthetic ARI sweep ---
random_seed = 1234
np.random.seed(random_seed)

n_vertices = 1000
n_features = 10
n_clusters = 5

centroid_variance = 1.0
node_variance = 0.1

# planted cluster of every node, using the same formula as create_synthetic_tree
ground_truth = [int(i / (n_vertices / n_clusters)) for i in range(n_vertices)]

step_size = 0.05
test_lambdas = [0, 100000]
n_repeated_runs = 10

show_legend = True

# (algorithm, run on DAG instead of tree)
experiments = [("tree_DP", False), ("greedy_local_search", False), ("iterative_mincut", True), ("greedy_local_search", True)]

for alg, test_on_DAG in experiments:
    print(alg, test_on_DAG)
    G, synthetic_centroids = create_synthetic_tree(n_vertices, n_features, n_clusters, 
                            centroid_variance=centroid_variance, node_variance=node_variance,random_seed=random_seed)
    # add random edges to make the tree into a more complex DAG
    if test_on_DAG:
        add_synthetic_edges(G, probability=0.01, random_seed=random_seed)
        print("DAG size:",len(G.nodes()), len(G.edges()))

    # pick the solver once per experiment
    if alg == "iterative_mincut":
        solver = find_node_clustering_iterative_mincut
    elif alg == "greedy_local_search":
        solver = find_node_clustering_greedy_local_search
    elif alg == "tree_DP":
        solver = find_node_clustering_on_tree

    p = 0
    rand_scores = dict()
    ps = []
    while p <= 1.0+1e-8:
        # NOTE(review): features are overwritten in place on the same graph at every step,
        # so the perturbations of earlier (smaller) p values accumulate - confirm this is intended.
        for node in G.nodes():
            if np.random.uniform() < p: # assign to random set with probability p
                random_set_i = np.random.randint(0, n_clusters)
                G.nodes[node]["features"] = tuple(np.random.normal(loc=synthetic_centroids[random_set_i], scale=node_variance, size=n_features))

        compute_vertex_edges(G)

        for lambda_back in test_lambdas:
            # run the algorithm multiple times and take the best result
            best_loss = np.inf
            for run_i in range(n_repeated_runs):
                seed = random_seed + run_i
                loss, S_vertices, vertex_to_set_index = solver(G, n_clusters, lambdas=(0, lambda_back), random_seed=seed)

                if loss < best_loss:
                    best_loss = loss
                    best_vertex_to_set_index = vertex_to_set_index

            clustering = [best_vertex_to_set_index[node] for node in G.nodes()]
            # compute adjusted rand index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
            rand_score = adjusted_rand_score(ground_truth, clustering)

            rand_scores.setdefault(lambda_back, []).append(rand_score)
            print(f"lambda: {lambda_back}, p: {p}, adjusted rand score: {rand_score}, loss: {best_loss}")
        ps.append(p)
        p += step_size

    print("-----------------------")
    filename = f"synthetic_{'DAG' if test_on_DAG else 'tree'}_rand_index_{alg}_k{n_clusters}_top{n_repeated_runs}.pdf"
    plot_synthetic_test(ps, rand_scores, test_lambdas, filename, show_legend=show_legend)
    # only the first figure carries the legend and y-axis label
    show_legend = False

Tests on Reddit data

In [288]:
# Load the pre-built Reddit graph from a path relative to the working directory.
# SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
# NOTE(review): the solver cells below use G without calling compute_vertex_edges, so the
# pickled graph presumably already carries vertex_edges_from / vertex_edges_to - confirm.
with open("reddit_graph.pkl", "rb") as file:
    G = pickle.load(file)

In [None]:
# Time one greedy-local-search run (k=20, lambda_fwd=0, lambda_back=100000, at most 75 sweeps)
# on the graph currently bound to G.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_greedy_local_search(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")
print()
# Prints the full vertex set of every cluster, so the output can be very large.
for s in S_vertices:
    print(s)

iter: 1 loss: 3940665656.576416
iter: 2 loss: 50159011.38318025
iter: 3 loss: 15856699.188728653
iter: 4 loss: 11056192.979447143
iter: 5 loss: 9256008.28808104
iter: 6 loss: 8055877.264409343
iter: 7 loss: 7455769.235399244
iter: 8 loss: 6855700.741376637
iter: 9 loss: 6455657.720226466
iter: 10 loss: 6155629.525961521
iter: 11 loss: 6155606.325383788
iter: 12 loss: 5955590.379249708
iter: 13 loss: 5955577.783016718
iter: 14 loss: 5955566.366590925
iter: 15 loss: 5955559.315084102
iter: 16 loss: 5955552.776681843
iter: 17 loss: 5855546.90841427
iter: 18 loss: 5655544.340509312
iter: 19 loss: 5655542.688244269
iter: 20 loss: 5455541.560956437
iter: 21 loss: 5455539.601228041
iter: 22 loss: 5455538.604204335
iter: 23 loss: 5455538.130013668
iter: 24 loss: 5455537.972450399
iter: 25 loss: 5455537.596910045
iter: 26 loss: 5455537.261278185
iter: 27 loss: 5455536.916264038
iter: 28 loss: 5455536.770827867
iter: 29 loss: 5455536.722737536
iter: 30 loss: 5455536.695093758
iter: 31 loss: 5455

In [None]:
# Time one tree-DP run (k=20, lambda_fwd=0, lambda_back=100000, at most 75 iterations)
# on the graph currently bound to G.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_on_tree(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")
print()
# Prints the full vertex set of every cluster, so the output can be very large.
for s in S_vertices:
    print(s)

start loss: 3940665656.576416, sets: [3679, 3712, 3631, 3792, 3823, 3749, 3816, 3671, 3693, 3743, 3752, 3868, 3803, 3705, 3674, 3747, 3739, 3739, 3751, 3691]
Total best loss with fixed centroids:  65291.85737448931
iter: 1, loss: 60142.331115722656, sets: [4076, 2354, 2902, 2248, 1108, 10755, 6411, 1639, 3358, 952, 3646, 2203, 7220, 4310, 2907, 2040, 6704, 944, 3954, 5047]
Total best loss with fixed centroids:  58652.51109749079
iter: 2, loss: 57612.9521484375, sets: [4560, 2056, 4557, 2356, 2192, 7521, 3446, 2214, 5324, 1549, 2967, 2741, 6957, 3632, 4289, 3248, 4889, 3191, 1904, 5185]
Total best loss with fixed centroids:  56920.111476421356
iter: 3, loss: 56293.97341918945, sets: [3941, 2215, 5437, 3049, 3107, 5966, 3063, 2730, 5353, 2219, 2125, 2999, 6489, 4371, 4982, 3869, 3898, 3583, 1032, 4350]
Total best loss with fixed centroids:  56108.17658709362
iter: 4, loss: 55967.83241271973, sets: [3431, 2397, 5799, 3139, 3308, 5309, 2962, 2853, 5407, 2383, 2106, 3083, 5762, 5225, 5086, 

Tests on DBLP data

In [None]:
# Load the pre-built DBLP graph from a path relative to the working directory; this rebinds G.
# SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
# NOTE(review): the solver cells below use G without calling compute_vertex_edges, so the
# pickled graph presumably already carries vertex_edges_from / vertex_edges_to - confirm.
with open("dblp_graph.pkl", "rb") as file:
    G = pickle.load(file)

In [None]:
# Time one greedy-local-search run (k=20, lambda_fwd=0, lambda_back=100000, at most 75 sweeps)
# on the graph currently bound to G.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_greedy_local_search(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")
print()
# Prints the full vertex set of every cluster, so the output can be very large.
for s in S_vertices:
    print(s)

iter: 1 loss: 3454825708.736328
iter: 2 loss: 155923445.87618545
iter: 3 loss: 52422572.51298591
iter: 4 loss: 37022352.47483971
iter: 5 loss: 32222281.250284005
iter: 6 loss: 30722246.015972916
iter: 7 loss: 29922224.832657736
iter: 8 loss: 29622211.70438918
iter: 9 loss: 29322206.53553498
iter: 10 loss: 29322203.689445738
iter: 11 loss: 29222202.978724536
iter: 12 loss: 29022202.89079858
iter: 13 loss: 29022202.262213916
iter: 14 loss: 29022202.1923197
iter: 15 loss: 28922201.445758622
iter: 16 loss: 28722201.17078541
iter: 17 loss: 28722200.724163294
iter: 18 loss: 28722200.72305447
iter: 19 loss: 28722200.72103421
iter: 20 loss: 28722200.71969105
iter: 21 loss: 28722200.71859103
iter: 22 loss: 28722200.7185615
Elapsed time: 0:32:31.521309

{2067406853, 3003219974, 2066317327, 2169487378, 195829777, 1505107989, 1843290139, 2171199516, 2009997344, 1958174753, 2118434861, 2158714928, 1531256877, 2164932656, 2612846646, 2999803960, 2789220419, 2776252482, 2773418060, 2105245773, 292451

In [None]:
# Time one iterative-mincut run (k=20, lambda_fwd=0, lambda_back=100000, at most 75 rounds)
# on the graph currently bound to G.
current_time = datetime.now()
loss, S_vertices, vertex_to_set_index = find_node_clustering_iterative_mincut(G, 20, lambdas=(0, 100000), max_iter=75, random_seed=0, verbose=True)
print(f"Elapsed time: {datetime.now() - current_time}")
print()
# Prints the full vertex set of every cluster, so the output can be very large.
for s in S_vertices:
    print(s)

start loss: 3454825708.736328, sets: [1492, 1493, 1489, 1589, 1545, 1587, 1512, 1488, 1534, 1518, 1510, 1581, 1583, 1509, 1494, 1527, 1528, 1564, 1569, 1469]
iter: 1, loss: 275724269.8200531, sets: [17109, 2981, 2240, 1416, 1284, 886, 597, 571, 465, 331, 304, 287, 358, 285, 385, 196, 271, 218, 208, 189]
iter: 2, loss: 38423291.71966553, sets: [11530, 3004, 2734, 1995, 1325, 2043, 708, 763, 617, 457, 876, 497, 628, 448, 749, 454, 565, 386, 428, 374]
iter: 3, loss: 16622992.587921143, sets: [9536, 2491, 3081, 2309, 1363, 2220, 759, 827, 688, 534, 1110, 661, 730, 554, 861, 576, 678, 497, 601, 505]
iter: 4, loss: 10622831.793273926, sets: [8504, 2109, 3203, 2676, 1458, 2162, 831, 875, 726, 567, 1191, 764, 737, 614, 869, 644, 700, 579, 773, 599]
iter: 5, loss: 9522751.41204834, sets: [7932, 1910, 3321, 2839, 1467, 2093, 921, 892, 764, 596, 1218, 825, 719, 623, 888, 691, 723, 662, 826, 671]
iter: 6, loss: 8422705.4163208, sets: [7577, 1804, 3404, 2972, 1455, 2067, 976, 891, 790, 636, 1213, 8

Additional tests on synthetic datasets:

In [279]:
from sklearn.metrics.cluster import adjusted_rand_score

# --- configuration of the single-p runtime comparison ---
random_seed = 1234
np.random.seed(random_seed)

n_vertices = 1000
n_features = 10
n_clusters = 5

centroid_variance = 1.0
node_variance = 0.1

# planted cluster of every node, using the same formula as create_synthetic_tree
ground_truth = [int(i / (n_vertices / n_clusters)) for i in range(n_vertices)]

test_lambdas = [100000]
n_repeated_runs = 1

p = 0.2

# (algorithm, run on DAG instead of tree)
experiments = [("tree_DP", False), ("greedy_local_search", False), ("iterative_mincut", True), ("greedy_local_search", True)]

for alg, test_on_DAG in experiments:
    print(alg, test_on_DAG)
    G, synthetic_centroids = create_synthetic_tree(n_vertices, n_features, n_clusters, 
                            centroid_variance=centroid_variance, node_variance=node_variance,random_seed=random_seed)
    # add random edges to make the tree into a more complex DAG
    if test_on_DAG:
        add_synthetic_edges(G, probability=0.01, random_seed=random_seed)
        print("DAG size:",len(G.nodes()), len(G.edges()))

    # pick the solver once per experiment
    if alg == "iterative_mincut":
        solver = find_node_clustering_iterative_mincut
    elif alg == "greedy_local_search":
        solver = find_node_clustering_greedy_local_search
    elif alg == "tree_DP":
        solver = find_node_clustering_on_tree

    # resample the features of roughly a fraction p of the nodes from a random cluster
    for node in G.nodes():
        if np.random.uniform() < p: # assign to random set with probability p
            random_set_i = np.random.randint(0, n_clusters)
            G.nodes[node]["features"] = tuple(np.random.normal(loc=synthetic_centroids[random_set_i], scale=node_variance, size=n_features))

    compute_vertex_edges(G)

    for lambda_back in test_lambdas:
        start_time = datetime.now()
        # run the algorithm multiple times and take the best result
        best_loss = np.inf
        for run_i in range(n_repeated_runs):
            seed = random_seed + run_i
            loss, S_vertices, vertex_to_set_index = solver(G, n_clusters, lambdas=(0, lambda_back), random_seed=seed, verbose=True)

            if loss < best_loss:
                best_loss = loss
                best_vertex_to_set_index = vertex_to_set_index
        print(f"lambda_back={lambda_back}, loss={best_loss}, runtime={datetime.now() - start_time}")
        print()

    print("-----------------------")

tree_DP False
start loss: 40100637.534804136, sets: [182, 223, 204, 209, 182]
Total best loss with fixed centroids:  607.8545713009472
iter: 1, loss: 377.5089160907829, sets: [78, 368, 130, 184, 240]
Total best loss with fixed centroids:  352.35580817128135
iter: 2, loss: 340.6536630132808, sets: [83, 374, 30, 245, 268]
Total best loss with fixed centroids:  328.7709639841275
iter: 3, loss: 317.5872542551611, sets: [192, 316, 14, 221, 257]
Total best loss with fixed centroids:  307.80077296492107
iter: 4, loss: 297.1272876842463, sets: [278, 232, 32, 201, 257]
Total best loss with fixed centroids:  280.970785122488
iter: 5, loss: 269.95051691836045, sets: [322, 87, 119, 195, 277]
Total best loss with fixed centroids:  265.41550375850824
iter: 6, loss: 260.7670667046778, sets: [290, 59, 140, 194, 317]
Total best loss with fixed centroids:  258.2573005838414
iter: 7, loss: 256.9405916014889, sets: [266, 56, 149, 193, 336]
Total best loss with fixed centroids:  256.4557022449643
iter: 8, 