In [1]:
from cc_model.load_datasets import *
from cc_model.wl import *
from cc_model.utils import nx_to_gt
from cc_model.pagerank import all_pagerank
from cc_model.rewire import *

import networkx as nx
from pathlib import Path
import graph_tool.all as gt
import time

In [2]:
# Names of the datasets iterated over by compute_pagerank_on_all_Graphs;
# the commented-out entries are currently disabled.
datasets = ["karate",  #"phonecalls",
            #"HepPh", "AstroPh", "web-Google", "soc-Pokec"
#            "deezer_HR", "deezer_HU", "deezer_RO","tw_musae_DE",
#            "tw_musae_ENGB","tw_musae_FR","lastfm_asia","fb_ath",
#            "fb_pol", "facebook_sc"
           ]

In [3]:
# Directory handed to load_gt_dataset_cached below.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
dataset_path = Path("/home/felix/projects/colorful_configuration/datasets")

In [4]:
#edges, is_directed = load_dataset(dataset_path, "soc-Pokec")

In [5]:
# Module-level settings forwarded by run_pagerank to every all_pagerank call
# as the keyword arguments epsilon, max_iter and alpha.
# presumably: convergence tolerance, iteration cap, damping factor - TODO confirm in cc_model.pagerank
epsilon=1e-14
max_iter = 300
alpha=0.85

In [6]:
def run_pagerank(g, WL_round, outer_iter=10, verbosity=0):
    """Compare the PageRank of `g` with that of rewired samples of `g`.

    `outer_iter` samples are drawn from a LocalHistogramRewiring ensemble built
    from the vertex property ``color_{WL_round}`` of `g`.  PageRank is computed
    for `g` and for every sample using the module-level `epsilon`, `max_iter`
    and `alpha`.

    Returns a list with one entry per sample: the sum of absolute differences
    between the sample's PageRank vector and the PageRank vector of `g`.
    """
    mode = "theirs"
    pagerank_options = dict(epsilon=epsilon, max_iter=max_iter, alpha=alpha)

    reference = all_pagerank(g, mode, **pagerank_options)
    print(g.is_directed())

    colors = g.vp[f"color_{WL_round}"].get_array()
    ensemble = LocalHistogramRewiring(g, colors)

    sampled_pageranks = []
    for sample_index in range(outer_iter):
        if verbosity > 4:
            print("    ", sample_index)
        rewired = ensemble.get_sample()
        # rewiring must not change the directedness of the graph
        assert rewired.is_directed() == g.is_directed()
        sample_pagerank, err = all_pagerank(rewired, mode, return_err=True, **pagerank_options)

        if verbosity > 0:
            print("the error in pagerank iteration is:\r\n", err)
        sampled_pageranks.append(sample_pagerank)

    error_sum = []
    for sample_pagerank in sampled_pageranks:
        error_sum.append(np.sum(np.abs(reference - sample_pagerank)))
    if verbosity > 0:
        print("max", [np.max(np.abs(reference - sample_pagerank)) for sample_pagerank in sampled_pageranks])
    return error_sum

In [7]:
def get_MAE_for_iterations(g, edges, n_graphs):
    """Mean and standard deviation of the PageRank error for every WL round.

    Parameters
    ----------
    g : graph that is copied and handed to run_pagerank for each round
    edges : edge array passed to WL_fast to determine the number of WL rounds
    n_graphs : number of rewired samples drawn per round

    Returns
    -------
    (means, stds) : two lists with one entry per WL round, the mean and the
        standard deviation of the per-sample errors returned by run_pagerank.

    Notes
    -----
    * Reads the module-level `verbosity` flag.
    * WL_fast must already be defined when this function is called.
    * NOTE(review): run_pagerank reads the vertex property ``color_{WL_round}``
      from the graph; the labelings computed here are only used to count the
      rounds, so confirm those properties are present on `g`.
    """
    means = []
    stds = []
    labelings = WL_fast(edges)
    WL_iterations = len(labelings)
    if verbosity > 0:
        print(WL_iterations)
    for WL_round in range(WL_iterations):
        # work on a copy so the caller's graph is left untouched
        g_rewire = gt.Graph(g)
        if verbosity > 0:
            print(WL_round)

        MAEs = run_pagerank(g_rewire, WL_round, outer_iter=n_graphs, verbosity=0)
        means.append(np.mean(MAEs))
        stds.append(np.std(MAEs))
    return means, stds

In [8]:
def compute_pagerank_on_all_Graphs(n_graphs, verbosity=0):
    """Run get_MAE_for_iterations for every entry of the module-level `datasets`.

    Parameters
    ----------
    n_graphs : number of rewired samples per WL round
    verbosity : higher values print more progress information

    Returns
    -------
    (list_means, list_stds) : one list of per-round means / standard deviations
        per dataset, in the order of `datasets`.  A `None` dataset entry yields
        an empty list so that positions stay aligned with `datasets`.
    """
    list_means = []
    list_stds = []
    for dataset in datasets:
        if dataset is None:
            list_means.append([])
            list_stds.append([])
            continue
        if verbosity > 0:
            print(dataset)
        if verbosity > 3:
            print("reading graph")
        G = load_gt_dataset_cached(dataset_path, dataset, verbosity=verbosity)
        print(G.num_edges(), G.num_vertices())

        edges = G.get_edges()
        if not G.is_directed():
            # undirected graphs: add every edge in the reverse direction as well
            edges = np.vstack((edges, edges[:, [1, 0]]))

        if verbosity > 3:
            print("done reading graph")
            print("starting WL")
            print(repr(G))
        means, stds = get_MAE_for_iterations(G, edges, n_graphs=n_graphs)
        list_means.append(means)
        list_stds.append(stds)
    return list_means, list_stds

In [9]:
from collections import Counter

In [10]:
%load_ext snakeviz

In [11]:
#%%snakeviz --new-tab

# get_MAE_for_iterations reads this module-level `verbosity`; it is not
# passed down as an argument by compute_pagerank_on_all_Graphs.
verbosity=1
# 42 rewired samples per WL round for every dataset
list_means, list_stds = compute_pagerank_on_all_Graphs(42, verbosity=1)

karate
78 34


NameError: name 'WL_fast' is not defined

In [None]:
import matplotlib.pyplot as plt

In [None]:
# One error-bar curve per dataset: mean error per WL round with the std as error bar.
plt.figure(figsize=(10,6))
for means, stds,label in zip(list_means, list_stds, datasets):
    # the tiny offset is presumably there so exact zeros survive the log-scaled y axis
    plt.errorbar(x=np.arange(len(means)),y=np.array(means)+1e-20, yerr=stds, label=label)
plt.ylabel("MAE of pagerank")
plt.xlabel("Iteration")
plt.yscale("log")
plt.legend()
plt.title("Convergence of pagerank for synthetic networks ")

In [None]:
G = load_gt_dataset_cached(dataset_path, "karate", verbosity=verbosity, force_reload=False)

In [None]:
def get_labelings(G):
    """Run WL on `G` and return the labelings it produces as one integer array."""
    _, labelings = WL(G, add_labelings=True, verbosity=1)
    return np.array(labelings, dtype=int)

def get_sorted_labelings(G):
    """Compute the WL labelings of `G` and reorder the nodes lexicographically.

    Returns `(order, ordered_labelings)`, where `order` is the node permutation
    (it is also printed) and `ordered_labelings` holds the labelings with their
    columns permuted by `order`, so equal labels end up next to each other.
    """
    labelings = get_labelings(G)
    # np.lexsort uses the LAST key as primary key, hence the reversed row order
    order = np.lexsort(labelings[::-1])
    print(order)
    return order, labelings[:, order]

In [None]:
@njit
def to_in_neighbors(edges):
    """Group the sources of all edges by their target node (CSR-like layout).

    Parameters
    ----------
    edges : 2D integer array, one row `(source, target)` per edge

    Returns
    -------
    starting_positions : int32 array of length num_nodes + 1; the in-neighbors
        of node v are in_neighbors[starting_positions[v]:starting_positions[v+1]]
    in_neighbors : int32 array with one entry (the source) per edge
    max_in_degree : largest number of incoming edges of any node

    The number of nodes is taken from the largest node id that occurs in either
    column.  Sizing by the target column alone drops high-id nodes that have no
    incoming edge, and callers then index per-node arrays with those ids.
    """
    num_nodes = np.int64(edges.max()) + 1

    in_degrees = np.zeros(num_nodes, dtype=np.int64)
    for i in range(edges.shape[0]):
        in_degrees[edges[i, 1]] += 1

    starting_positions = np.empty(num_nodes + 1, dtype=np.int32)
    starting_positions[0] = 0
    starting_positions[1:] = in_degrees.cumsum()
    # next free slot inside each node's segment of in_neighbors
    current_index = starting_positions.copy()

    in_neighbors = np.zeros(edges.shape[0], dtype=np.int32)

    for l, r in edges:
        in_neighbors[current_index[r]] = l
        current_index[r] += 1

    return starting_positions, in_neighbors, in_degrees.max()

In [None]:
%%time
to_in_neighbors(G.get_edges())

In [None]:
from numba.types import bool_
import numpy
@njit
def primesfrom2to(n):
    """Return an array of all primes p with 2 <= p < n (intended for n >= 6).

    The sieve only stores candidates that are not multiples of 2 or 3:
    sieve index i stands for the number (3*i+1)|1.  The primes 2 and 3 are
    therefore prepended explicitly at the end.
    """
    # number of candidates below n that are coprime to 6 (plus the unused slot 0)
    size = int(n//3 + (n%6==2))
    sieve = numpy.ones(size, dtype=bool_)
    for i in range(1,int(n**0.5)//3+1):
        if sieve[i]:
            # candidate number represented by sieve index i
            k=3*i+1|1
            # cross out the multiples of k that are themselves candidates
            sieve[       k*k//3     ::2*k] = False
            sieve[k*(k-2*(i&1)+4)//3::2*k] = False
    # translate the surviving indices back to numbers; index 0 is skipped
    arr =  ((3*numpy.nonzero(sieve)[0][1:]+1)|1)
    output = np.empty(len(arr)+2, dtype=arr.dtype)
    output[0]=2
    output[1]=3
    output[2:]=arr
    return output

In [None]:
%%time
primesfrom2to(int(10_000_000))

In [None]:

def WL_fast(edges, max_iter=30):
    """Build the in-neighbor lists of `edges` and run _WL_fast on them.

    Returns whatever _WL_fast returns: a list with one label array per round.
    """
    starting_positions, in_neighbors, _max_in_degree = to_in_neighbors(edges)
    labelings = _WL_fast(starting_positions, in_neighbors, max_iter)
    return labelings

@njit
def _WL_fast(startings, neighbors, max_iter=30):
    """WL color refinement using floating point sums of log-primes, similar to
    https://github.com/rmgarnett/fast_wl/blob/master/wl_transformation.m

    Parameters
    ----------
    startings : array of length num_nodes + 1; the neighbors of node i are
        neighbors[startings[i]:startings[i+1]]
    neighbors : concatenated neighbor lists
    max_iter : upper bound on the number of refinement rounds

    Returns
    -------
    list of uint32 arrays, one per round, holding the color of every node.
    Refinement stops once a round no longer increases the number of colors;
    that final, unchanged round is included in the output.
    """
    num_nodes = len(startings)-1
    n = num_nodes
    if n >= 6:
        # The n-th prime is below n*(ln n + ln ln n) for n >= 6, so this bound
        # guarantees at least n primes.  Subtracting 1 inside the bracket gives
        # a lower bound on the n-th prime and can leave too few primes
        # (e.g. n = 6 would yield only 5).
        correction = np.ceil(np.log(n) + np.log(np.log(n)))
    else:
        correction = 5
    primes = primesfrom2to(num_nodes*correction)
    log_primes = np.log(primes)
    labels = np.zeros(num_nodes, dtype=np.uint32)
    vals = np.ones(num_nodes)

    out = []

    # `num_colors` below is the index of the highest color (count - 1).  Every
    # node starts with color 0, so the matching initial value is 0; starting
    # at 1 would stop the refinement as soon as round one finds two colors.
    last_num_colors = 0
    for _ in range(max_iter):
        # add the log-prime of every neighbor's current color to the node value
        for i in range(num_nodes):
            lb = startings[i]
            ub = startings[i+1]
            for j in range(lb, ub):
                vals[i] += log_primes[labels[neighbors[j]]]

        # nodes with (numerically) equal values receive the same new color
        order = np.argsort(vals)
        last_val = vals[order[0]]
        num_colors = 0
        for node_id in order:
            val = vals[node_id]
            if last_val/val < 1-1e-16:
                num_colors += 1
                last_val = val

            labels[node_id] = num_colors
            # restart the value from the prime of the node's own new color
            vals[node_id] = primes[num_colors]
        out.append(labels.copy())
        if last_num_colors == num_colors:
            break
        else:
            last_num_colors = num_colors

    return out

In [None]:
G = load_gt_dataset_cached(dataset_path, "web-Google", verbosity=verbosity, force_reload=False)

In [None]:
@njit
def is_sorted(arr):
    """Return True if `arr` is strictly increasing.

    Neighbouring elements are compared directly instead of testing the sign of
    their difference: for unsigned dtypes the difference wraps around, so a
    decreasing pair would be reported as increasing.
    """
    return np.all(arr[1:] > arr[:-1])

In [None]:
def WL_fast3(edges, max_iter=30):
    """Swap the two edge columns, build the neighbor lists and call _WL_fast3."""
    swapped = np.empty_like(edges)
    swapped[:, 0], swapped[:, 1] = edges[:, 1], edges[:, 0]
    neighbor_lists = to_in_neighbors(swapped)
    startings, neighbors, max_degree = neighbor_lists
    return _WL_fast3(startings, neighbors, max_degree, max_iter)
@njit
def _WL_fast3(startings, neighbors, max_degree, max_iter=30):
    """WL using floating point operations with primes similar to 
    https://github.com/rmgarnett/fast_wl/blob/master/wl_transformation.m

    NOTE: this variant is disabled.  The first statement raises ValueError, so
    the remainder of the body (a dense/sparse incremental update experiment)
    is never executed.
    """
    raise ValueError("Did not work out as expected")
    # ---- everything below is unreachable ----
    num_nodes = len(startings)-1
    ln=np.log
    n=num_nodes
    if n >=6:
        # maybe remove the -1
        correction = np.ceil(ln(n)+ln(ln(n)))
    else:
        correction = 5
    primes = primesfrom2to(num_nodes*correction)
    log_primes = np.log(primes)
    deltas = log_primes.copy()
    labels = np.zeros(num_nodes, dtype=np.uint32)
    vals = np.ones(num_nodes)
    out = []
    order = np.arange(num_nodes)
    order_updates = order.copy()
    #skipped_updates = np.zeros(num_nodes, dtype=np.uint32)
    num_updates = num_nodes
    
    partitions = np.zeros(num_nodes+1, dtype=np.uint32)
    num_partitions = 1
    partitions[0] = 0
    partitions[1] = num_nodes
    
    affected_nodes = set()
    affected_nodes.add(0)
    affected_partitions = set()
    affected_partitions.add(0)
    #print("len", len(log_primes))
    last_num_colors = 1
    
    num_colors_sparse = 0
    for x in range(max_iter):


        #print(vals)
        #print(labels[order])
        #print(vals[order])
        print("num_updates", num_updates)
        # dense branch when more than a quarter of the nodes changed, sparse branch otherwise
        if num_updates > num_nodes//4:
            for index in range(num_updates):# loop over all nodes that changed in last iter
                i = order_updates[index]
                lb = startings[i]
                ub = startings[i+1]
                for j in range(lb, ub): # propagate label of i to neighbor j
                    vals[neighbors[j]]+=deltas[labels[i]]

            # sort partitions such that the same values come after one another
            for i in range(num_partitions):           
                lb = partitions[i]
                ub = partitions[i+1]
                if ub-lb > 1:
                    partition_order = np.argsort(vals[order[lb:ub]])
                    order[lb:ub] = order[lb:ub][partition_order]
                    
                    num_colors = 0
            num_partitions=0
            last_index = 0
            num_updates = 0
            affected_nodes.clear()

            last_val = vals[order[0]]
            for i in range(len(order)):

                node_id = order[i]

                #i = order_index + skipped_updates[order_index]
                #print(order_index, i)
                val = vals[node_id]
                if val!=last_val:
                    num_partitions+=1
                    last_index = i
                    num_colors += 1
                    partitions[num_colors]=i
                    last_val = val
                    deltas[last_index] = log_primes[last_index]-log_primes[labels[node_id]]
                    #print(last_index, labels[node_id], deltas[last_index])
                #assert last_index-labels[node_id]>0

                if labels[node_id]!=last_index: #there is a need for updates 
                    order_updates[num_updates]=node_id
                    num_updates+=1
                else:
                    vals[node_id] += last_index-labels[node_id]

                labels[node_id] = last_index
        else:
            print("sparse")
            #print(partitions)
            # sparse implementation
            #affected_nodes.clear()
            for index in range(num_updates):# loop over all nodes that changed in last iter
                i = order_updates[index]
                lb = startings[i]
                ub = startings[i+1]
                affected_nodes.add(i)
                for j in range(lb, ub): # propagate label of i to neighbor j
                    vals[neighbors[j]]+=deltas[labels[i]]
                    affected_nodes.add(neighbors[j])
            
            affected_partitions.clear()
            #assert is_sorted(partitions[:num_partitions])
            for node_id in affected_nodes:
                partition = np.searchsorted(partitions[:num_partitions], node_id+1)
                affected_partitions.add(partition)
            print(len(affected_partitions))
            for p in affected_partitions:
                #assert p>0
                #assert p < num_partitions+1
                lb = partitions[p-1]
                ub = partitions[p]
                #print(lb, ub)
                #assert ub-lb > 0
                if ub-lb > 1:
                    partition_order = np.argsort(vals[order[lb:ub]])
                    order[lb:ub] = order[lb:ub][partition_order]
                
                
            num_colors = 0
            num_partitions=0
            last_index = 0
            num_updates = 0
            affected_nodes.clear()

            last_val = vals[order[0]]
            for i in range(len(order)):

                node_id = order[i]

                #i = order_index + skipped_updates[order_index]
                #print(order_index, i)
                val = vals[node_id]
                if val!=last_val:
                    num_partitions+=1
                    last_index = i
                    num_colors += 1
                    partitions[num_colors]=i
                    last_val = val
                    deltas[last_index] = log_primes[last_index]-log_primes[labels[node_id]]
                    #print(last_index, labels[node_id], deltas[last_index])
                #assert last_index-labels[node_id]>0

                if labels[node_id]!=last_index: #there is a need for updates 
                    order_updates[num_updates]=node_id
                    num_updates+=1
                    vals[node_id] += last_index-labels[node_id]
                    affected_nodes.add(node_id)

                labels[node_id] = last_index
            #+primes[num_colors]
        #print(vals[order])
        out.append(labels.copy())
        #print()
        if last_num_colors == num_colors:
            break
        else:
            last_num_colors = num_colors
            
    return out

# Another drawback: each iteration costs O(edges),
# even though only a small number of nodes change color.


# The WL implementation above could be sped up significantly by sorting
# the array only once instead of re-sorting the whole array in every
# iteration. This should give quite a boost, as sorting is currently
# the main cost of the algorithm.

In [None]:
def WL_fast2(edges, max_iter=30):
    """Swap the two edge columns, build the neighbor lists and call _WL_fast2."""
    swapped = np.empty_like(edges)
    swapped[:, 0], swapped[:, 1] = edges[:, 1], edges[:, 0]
    neighbor_lists = to_in_neighbors(swapped)
    startings, neighbors, max_degree = neighbor_lists
    return _WL_fast2(startings, neighbors, max_degree, max_iter)
@njit
def is_sorted_fast(vals, order):
    """Return True if `vals` is non-decreasing when visited in the order `order`.

    Stops at the first descent instead of materialising `vals[order]`.
    """
    previous = vals[order[0]]
    for position in range(1, len(order)):
        current = vals[order[position]]
        if current < previous:
            return False
        previous = current
    return True

@njit
def _WL_fast2(startings, neighbors, max_degree, max_iter=30):
    """WL using floating point operations with primes similar to 
    https://github.com/rmgarnett/fast_wl/blob/master/wl_transformation.m

    Incremental variant: only nodes whose label changed in the previous round
    push a correction (`deltas`) to their neighbors, and only the partitions of
    `order` are re-sorted instead of the whole array.

    Parameters
    ----------
    startings : array of length num_nodes + 1; the neighbors of node i are
        neighbors[startings[i]:startings[i+1]]
    neighbors : concatenated neighbor lists
    max_degree : not referenced in the body
    max_iter : upper bound on the number of rounds

    Returns
    -------
    list with one copy of `labels` per round.  A label is the position in
    `order` at which the node's block starts, not a consecutive color index.

    NOTE(review): `last_num_colors` starts at 1 while `num_colors` counts block
    boundaries starting from 0 - confirm the stop criterion cannot trigger
    prematurely after the first round.
    NOTE(review): after a round `num_partitions` equals the number of
    boundaries found, and partitions[num_colors+1] is not rewritten - confirm
    the last partition is covered by the re-sorting loop.
    """
    num_nodes = len(startings)-1
    ln=np.log
    n=num_nodes
    if n >=6:
        # maybe remove the -1
        correction = np.ceil(ln(n)+ln(ln(n)))
    else:
        correction = 5
    primes = primesfrom2to(num_nodes*correction)
    log_primes = np.log(primes)
    # deltas[label]: amount a node with this label still has to add to its neighbors' values
    deltas = log_primes.copy()
    labels = np.zeros(num_nodes, dtype=np.uint32)
    vals = np.ones(num_nodes)
    out = []
    # current node ordering; nodes of the same partition are contiguous
    order = np.arange(num_nodes)
    # first `num_updates` entries: nodes whose label changed in the previous round
    order_updates = order.copy()
    #skipped_updates = np.zeros(num_nodes, dtype=np.uint32)
    num_updates = num_nodes
    
    # partition boundaries as positions in `order`; initially one partition holding all nodes
    partitions = np.zeros(num_nodes+1, dtype=np.uint32)
    num_partitions = 1
    partitions[0] = 0
    partitions[1] = num_nodes
    

    #print("len", len(log_primes))
    last_num_colors = 1
    
    # assigned but never read
    num_colors_sparse = 0
    for x in range(max_iter):



        #print("num_updates", num_updates)
        for index in range(num_updates):# loop over all nodes that changed in last iter
            i = order_updates[index]
            lb = startings[i]
            ub = startings[i+1]
            for j in range(lb, ub): # propagate label of i to neighbor j
                vals[neighbors[j]]+=deltas[labels[i]]

        # sort partitions such that the same values come after one another
        for i in range(num_partitions):           
            lb = partitions[i]
            ub = partitions[i+1]
            if ub-lb > 1:
                # skip the argsort when the partition is already in order
                if not is_sorted_fast(vals, order[lb:ub]):
                    partition_order = np.argsort(vals[order[lb:ub]])
                    order[lb:ub] = order[lb:ub][partition_order]

        num_colors = 0
        num_partitions=0
        last_index = 0
        num_updates = 0

        # walk the nodes in order; every change of value opens a new block
        last_val = vals[order[0]]
        for i in range(len(order)):

            node_id = order[i]


            val = vals[node_id]
            if val!=last_val:
                num_partitions+=1
                last_index = i
                num_colors += 1
                partitions[num_colors]=i
                last_val = val
                deltas[last_index] = log_primes[last_index]-log_primes[labels[node_id]]


            if labels[node_id]!=last_index: #there is a need for updates 
                order_updates[num_updates]=node_id
                num_updates+=1
                vals[node_id] += last_index-labels[node_id]

            labels[node_id] = last_index
       
        out.append(labels.copy())
        #print()
        if last_num_colors == num_colors:
            break
        else:
            last_num_colors = num_colors
            
    return out

# Another drawback: each iteration costs O(edges),
# even though only a small number of nodes change color.


# The WL implementation above could be sped up significantly by sorting
# the array only once instead of re-sorting the whole array in every
# iteration. This should give quite a boost, as sorting is currently
# the main cost of the algorithm.

In [None]:
for i in range(0,0):
    print(i)

In [None]:
# Edge list of G; for an undirected graph every edge is appended a second
# time with its endpoints swapped.
edges = G.get_edges()
if not G.is_directed():
    edges2 = np.column_stack((edges[:, 1], edges[:, 0]))
    edges = np.concatenate((edges, edges2), axis=0)

In [None]:
%load_ext snakeviz

In [None]:

%%snakeviz --new-tab

ret = WL_fast2(edges)

In [None]:
def get_maxs(l):
    """Return the maximum of every array in `l` as a list."""
    maxima = []
    for arr in l:
        maxima.append(arr.max())
    return maxima

def get_uniques(l):
    """Return the number of distinct values of every array in `l` as a list."""
    counts = []
    for arr in l:
        counts.append(len(np.unique(arr)))
    return counts
get_uniques(ret)

In [None]:
%%time
WL_iterations, labelings = WL(G, verbosity=1)

In [None]:
order, ordered_labelings = get_sorted_labelings(G)

In [None]:
ordered_labelings

In [None]:
def is_block(arr):
    """Check whether a labeling is block ordered.

    A labeling is block ordered if all occurrences of a value are adjacent:

    AACCBBBB is block ordered
    ABACCC is not block ordered

    Returns True for block ordered (and for empty) input, False otherwise.
    """
    previous_seen = set()
    # Track explicitly whether a previous value exists.  Seeding the previous
    # value with 0 would skip a leading block of zeros without recording it,
    # so a later reappearance of 0 would go unnoticed.
    have_last = False
    last_val = None
    for val in arr:
        if have_last and val == last_val:
            continue
        if val in previous_seen:
            return False
        previous_seen.add(val)
        last_val = val
        have_last = True
    return True

# Every row of the sorted labelings is expected to be block ordered.
for i in range(ordered_labelings.shape[0]):
    arr = ordered_labelings[i,:]
    #print(np.nonzero(arr==4))
    # number of positions at which the label differs from its predecessor
    print(np.count_nonzero(arr[1:]-arr[0:-1]))
    assert is_block(arr)
    print(is_block(arr))

In [None]:
inv_order = np.empty_like(order)
inv_order[order]=np.arange(len(inv_order))

In [None]:
inv_order

In [None]:
inv_order2 = np.empty_like(order)
for old_label, new_label in enumerate(order):
    inv_order2[new_label] = old_label

In [None]:
np.count_nonzero(np.abs(inv_order-inv_order2))

In [None]:
from numba import njit

@njit
def relabel_edges(inv_order, edges):
    """Replace both endpoints of every edge by their entry in `inv_order`.

    `edges` is modified in place and also returned; its shape is printed.
    """
    print(edges.shape)
    num_edges = edges.shape[0]
    for row in range(num_edges):
        source = edges[row][0]
        target = edges[row][1]
        edges[row][0] = inv_order[source]
        edges[row][1] = inv_order[target]
    return edges

In [None]:
order

In [None]:
inv_order

In [None]:
inv_order3 = np.arange(len(order))[order]

In [None]:
edges = relabel_edges(inv_order, G.get_edges())

In [None]:
WL_fast(edges)

In [None]:
# Cross-check the relabeling: collect, per relabeled id, the original ids it
# came from (validate) and, per original id, the relabeled ids (validate2).
# NOTE(review): `defaultdict` is not imported in the visible cells (only
# `Counter` is) - confirm one of the star imports provides it.
validate = defaultdict(set)
validate2 = defaultdict(set)
# first endpoints
for val, key in zip(edges[:,0], G.get_edges()[:,0]):
    validate[val].add(key)
    validate2[key].add(val)
# second endpoints
for val, key in zip(edges[:,1], G.get_edges()[:,1]):
    validate[val].add(key)
    validate2[key].add(val)

In [None]:
g2 = gt.Graph(directed=G.is_directed())
g2.add_edge_list(edges)
g2.num_edges()

In [None]:
def get_edge_id(labels, edges):
    """Encode the pair (label of source, label of target) of every edge as one integer.

    Parameters
    ----------
    labels : 1D integer array with one non-negative label per node
    edges : 2D integer array, one row `(source, target)` per edge

    Returns
    -------
    1D array with one id per edge; two edges share an id exactly when both
    their source labels and their target labels agree.  The shape of the
    result is printed.
    """
    # Labels range over 0..max_label, so the base has to be max_label + 1.
    # With base max_label the pairs (0, max_label) and (1, 0) would collide.
    base = labels.max() + 1
    edge_id = base*(labels[edges[:,0]]) + labels[edges[:,1]]
    print(edge_id.shape)
    return edge_id

@njit
def get_edge_id2(labels, edges, out):
    """Assign consecutive class ids to edges by the labels of their endpoints.

    Parameters
    ----------
    labels : 1D integer array with one label per node
    edges : 2D integer array, one row `(source, target)` per edge
    out : 1D array with one slot per edge that receives the class ids

    Returns
    -------
    out : the filled `out` array; ids are handed out in order of first
        appearance of the pair (label of source, label of target)
    is_mono : dict whose keys are the ids of the classes in which both
        endpoints carry the same label
    """
    # Both dicts are created with a dummy entry that fixes their key/value
    # types for numba, and the dummy is removed again right away.  Leaving
    # the dummy in `is_mono` would mark class 0 as monochromatic regardless
    # of its labels.
    d = {(0, 0) : 0}
    del d[(0, 0)]
    is_mono = {0 : True}
    del is_mono[0]
    for i in range(len(edges)):
        e1, e2 = edges[i,:]
        tpl = (labels[e1], labels[e2])
        if tpl not in d:
            n = len(d)
            d[tpl] = n
            if labels[e1] == labels[e2]:
                is_mono[n] = True
        out[i] = d[tpl]
    return out, is_mono

In [None]:
def get_shape(arr):
    """Return the `shape` attribute of `arr`."""
    shape = arr.shape
    return shape

def get_shapes(arr):
    """Return the shapes of all elements of `arr` as a list."""
    return [get_shape(element) for element in arr]

def get_lens(arr):
    """Return the lengths of all elements of `arr` as a list."""
    return [len(element) for element in arr]

In [None]:

def get_dead_edges(labels, edges, dead_colors):
    """Flag every edge that has at least one endpoint with a dead color.

    `dead_colors` is indexed by label; `labels` is indexed by node id.
    Returns a boolean array with one entry per edge.
    """
    source_labels = labels[edges[:, 0]]
    target_labels = labels[edges[:, 1]]
    return np.logical_or(dead_colors[source_labels], dead_colors[target_labels])



def get_dead_edges_full(edges, labelings, edges_classes):
    """Mark, for every labeling, the edges that are dead.

    An edge is dead in iteration i if
      * one of its endpoints has a color that occurs exactly once in
        labelings[i], or
      * its edge class occurs at most once in edges_classes[i].

    Returns a boolean array with one row per labeling and one column per edge.
    """
    per_iteration = []
    for i, labeling in enumerate(labelings):
        # colors carried by exactly one node
        color_counts = np.bincount(labeling.ravel(), minlength=labeling.max())
        dead_by_color = get_dead_edges(labelings[i, :], edges, color_counts == 1)

        # edge classes that contain at most one edge
        classes = edges_classes[i]
        class_counts = np.bincount(classes, minlength=classes.max())
        dead_by_class = (class_counts <= 1)[classes]

        per_iteration.append(np.logical_or(dead_by_color, dead_by_class))

    dead_edges_final = np.array(per_iteration)
    return dead_edges_final

In [None]:
from itertools import chain
def sort_edges(edges, labelings, directed = True):
    """Order the edges so that edges of the same class are consecutive.

    Within every level, dead edges are placed after the live ones.

    Returns a 4-tuple:
      * the reordered edges,
      * the edge classes, reordered, with shape (num_edges, num_labelings),
      * the dead indicator, reordered, with shape (num_labelings, num_edges),
      * a list with the `is_mono` dict of every labeling.

    Raises ValueError for `directed == False`: undirected edges would have to
    be sorted first, which is not implemented.
    """
    if directed == False:
        raise ValueError()

    edges_classes = []
    is_mono = []
    num_labelings = labelings.shape[0]
    for level in range(num_labelings):
        class_buffer = np.empty(len(edges), dtype=np.uint32)
        level_classes, level_mono = get_edge_id2(labelings[level, :], edges, class_buffer)
        edges_classes.append(level_classes)
        is_mono.append(level_mono)

    dead_indicator = get_dead_edges_full(edges, labelings, edges_classes)

    # interleave the keys: classes of level 0, dead flags of level 0, classes of level 1, ...
    sort_keys = []
    for level_classes, level_dead in zip(edges_classes, dead_indicator):
        sort_keys.append(level_classes)
        sort_keys.append(level_dead)
    print(list(sort_keys))

    edges_classes_arr = np.vstack(edges_classes)
    to_sort_arr = np.vstack(sort_keys)

    # np.lexsort takes the last key as primary key, so the rows are reversed
    edge_order = np.lexsort(to_sort_arr[::-1])
    edges_ordered = edges[edge_order]
    return edges_ordered, edges_classes_arr[:, edge_order].T, dead_indicator[:, edge_order], is_mono

In [None]:
labelings2 = get_labelings(g2)#[:,order]

In [None]:
for i in range(labelings2.shape[0]):
    arr = labelings2[i,:]
    #print(np.nonzero(arr==4))
    #print(np.count_nonzero(arr[1:]-arr[0:-1]))
    print(is_block(arr))

In [None]:
edges_ordered, edges_classes, dead_arr, is_mono = sort_edges(edges, labelings2)

In [None]:
edges_classes

In [None]:
for i in range(edges_classes.shape[1]):
    print(is_block(edges_classes[:18561,i]))

In [None]:
arr = np.bincount(edges_classes[:18561,2])
inds = arr >0
print(arr[inds].min())
print(is_block(arr))

In [None]:
edges_classes[:,2].min()

In [None]:
dead_arr

In [None]:
#@njit
def _get_block_indices(arr_in, is_dead, out):
    """Returns the indices of block changes in arr
    input [4,4,2,2,3,5]
    output = [0,2,4,5,6]
    lower inclusive, upper exclusive
    
    """
    indices = np.arange(len(arr_in))[~is_dead]
    arr = arr_in[~is_dead]
    #print(arr)
    #print(indices)
    #print()
    if len(arr)==0:
        return out[:0, :]
    last_val = arr[0]
    out[0,0] = indices[0]
    n=0
    last_index=0
    for i, val in zip(indices, arr):

        if val == last_val:
            last_index=i
            continue
        else:
            last_val=val
            out[n,1]=last_index+1
            out[n+1,0]=i
            n+=1
            last_index=i
    out[n,1] = last_index+1
    if out[n,1]-out[n,0]>1:
        n+=1
    return out[:n,:]

def check_blocks(out_arr):
    """Assert that consecutive entries of `out_arr` are more than 1 apart.

    On failure the assertion message lists the offending differences and the
    entries at which they end.
    """
    block_lengths = out_arr[1:] - out_arr[:-1]
    too_short = block_lengths <= 1
    assert np.all(block_lengths>1), f"{block_lengths[too_short]} {out_arr[1:][too_short]}"
        

#@njit
def get_block_indices(edges_classes, dead_arrs):
    """Compute the block boundaries of every column of `edges_classes`.

    Column i of `edges_classes` is paired with row i of `dead_arrs`.  For
    every pair two diagnostic lines are printed (dead edges plus edges covered
    by blocks, and edges covered by blocks alone).

    Returns a list with one `(num_blocks, 2)` array of `[start, stop)` rows
    per column.
    """
    blocks_per_level = []
    for level_classes, level_dead in zip(edges_classes.T, dead_arrs):
        scratch = np.empty((len(level_classes), 2), dtype=np.int32)
        blocks = _get_block_indices(level_classes, level_dead, scratch)

        covered = np.sum(blocks[:, 1] - blocks[:, 0])
        print(level_dead.sum() + covered)
        print("block", covered)
        blocks_per_level.append(blocks)

    return blocks_per_level
    

In [None]:
x=get_block_indices(edges_classes, dead_arr)

In [None]:
@njit
def rewire_mono(edges, n_rewire):
    """Try `n_rewire` random edge swaps on `edges`, in place.

    One attempt picks two different rows and a random orientation for the
    second one, then replaces
        a_left - a_right ,  b_left - b_right
    by
        a_left - b_right ,  b_left - a_right
    The attempt is dropped if it would change nothing, create a self loop, or
    duplicate an existing row in either orientation (checked by a linear scan).
    """
    num_edges = len(edges)

    for _ in range(n_rewire):
        first = np.random.randint(0, num_edges)
        shift = np.random.randint(1, num_edges)
        left_col = np.random.randint(0, 2)
        right_col = 1 - left_col
        # shift is never 0, so the second row always differs from the first
        second = (first + shift) % (num_edges)

        a_left = edges[first, 0]
        a_right = edges[first, 1]
        b_left = edges[second, left_col]
        b_right = edges[second, right_col]

        # swap would do nothing
        if (a_right == b_right) or (a_left == b_left):
            continue

        # swap would create a self loop
        if (a_left == b_right) or (a_right == b_left):
            continue

        already_present = False
        for row in range(len(edges)):
            left = edges[row, 0]
            right = edges[row, 1]
            if left == a_left and right == b_right:
                already_present = True
            elif left == b_left and right == a_right:
                already_present = True
            elif left == a_right and right == b_left:
                already_present = True
            elif left == b_right and right == a_left:
                already_present = True
            if already_present:
                break

        if not already_present:
            edges[first, 1] = b_right
            edges[second, 0] = b_left
            edges[second, 1] = a_right
In [None]:
from numba.typed import List,Dict

In [None]:
@njit
def rewire_mono2(edges, n_rewire):
    """Try `n_rewire` random edge swaps on `edges`, in place.

    Same swap as rewire_mono, but the duplicate check uses per-node neighbor
    lists (`neigh`) that are kept up to date, instead of scanning all edges.
    """
    delta = len(edges)
    # insert and delete a dummy entry, presumably so numba can infer the dict's types
    neigh = Dict()
    neigh[0] = List([-1])
    del neigh[0]
    # neighbor lists are filled in both directions for every edge
    for l,r in edges:
        if l not in neigh:
            # empty list created from a dummy element that is popped again
            tmp = List([-1])
            tmp.pop()
            neigh[l] = tmp
        if r not in neigh:
            tmp = List([-1])
            tmp.pop()
            neigh[r] = tmp
        neigh[l].append(r)
        neigh[r].append(l)
    
    # start:
    # e1_l <-> e1_r
    # e2_l <-> e2_r
    # after
    # e1_l <-> e2_r
    # e2_l <-> e1_r
    
    for _ in range(n_rewire):
        index1 = np.random.randint(0, delta)
        # offset is at least 1, so index2 always differs from index1
        offset = np.random.randint(1, delta)
        # random orientation of the second edge
        i2_1 = np.random.randint(0, 2)
        i2_2 = 1 - i2_1
        index2 = (index1 + offset) % (delta)
        e1_l, e1_r = edges[index1,:]
        e2_l = edges[index2, i2_1]
        e2_r = edges[index2, i2_2]
        
        
        if (e1_r == e2_r) or (e1_l == e2_l): # swap would do nothing
            continue
            
        if (e1_l == e2_r) or (e1_r == e2_l): # no self loops after swab
            continue
        
        # reject the swap if one of the two new edges already exists
        can_flip = True
        if e2_r in neigh[e1_l] or e1_r in neigh[e2_l]:
            can_flip = False

        if can_flip:
            edges[index1, 1] = e2_r
            edges[index2, 0] = e2_l
            edges[index2, 1] = e1_r
            # drop the two old edges from the neighbor lists ...
            neigh[e1_l].remove(e1_r)
            neigh[e1_r].remove(e1_l)
            
            neigh[e2_l].remove(e2_r)
            neigh[e2_r].remove(e2_l)
            
            # ... and register the two new ones
            neigh[e1_l].append(e2_r)
            neigh[e2_r].append(e1_l)
            
            neigh[e2_l].append(e1_r)
            neigh[e1_r].append(e2_l)
            

In [None]:
np.random.seed(1)

edges = np.array([[1,2],[3,4]])
for _ in range(10):
    rewire_mono2(edges,1)
    print(edges)

In [None]:
np.random.randint(0, 2, size=100)

In [None]:
def rewire_bipartite(edges, lower, upper, n_rewire):
    """Rewire the rows `lower:upper` of `edges` as a two class graph, in place.

    A one class _directed_ graph can be treated as a two class graph as well.
    The rewired rows are printed afterwards.

    Raises ValueError if the range holds fewer than two edges.
    """
    if upper-lower < 2:
        raise ValueError

    block = edges[lower:upper]
    _rewire_bipartite(block, n_rewire)
    print(edges[lower:upper])
    
@njit
def _rewire_bipartite(edges, n_rewire):
    """Try `n_rewire` random swaps of the right endpoints of two rows, in place.

    An attempt is dropped if it would change nothing, create a self loop, or
    produce a row that already exists (checked by a linear scan).
    """
    # possible optimisation: the left side is always in a block,
    # which would allow limiting the search range
    num_edges = len(edges)

    for _ in range(n_rewire):
        first = np.random.randint(0, num_edges)
        shift = np.random.randint(1, num_edges)
        # shift is never 0, so the second row always differs from the first
        second = (first + shift) % (num_edges)

        a_left = edges[first, 0]
        a_right = edges[first, 1]
        b_left = edges[second, 0]
        b_right = edges[second, 1]

        # swap would do nothing
        if (a_right == b_right) or (a_left == b_left):
            continue

        # swap would create a self loop
        if (a_left == b_right) or (a_right == b_left):
            continue

        already_present = False
        for row in range(len(edges)):
            left = edges[row, 0]
            right = edges[row, 1]
            if left == a_left and right == b_right:
                already_present = True
            elif left == b_left and right == a_right:
                already_present = True
            if already_present:
                break

        if not already_present:
            edges[first, 1] = b_right
            edges[second, 1] = a_right



In [None]:
@njit
def _rewire_bipartite_large(edges, n_rewire):
    """Try `n_rewire` random swaps of the right endpoints of two rows, in place.

    Same swap as _rewire_bipartite, but the duplicate check uses the list of
    right endpoints stored per left endpoint (`neigh`) instead of a scan over
    all edges.
    """
    # can do further optimization because the left side is always in a block
    #  => can limit search range
    
    delta = len(edges)
    # insert and delete a dummy entry, presumably so numba can infer the dict's types
    neigh = Dict()
    neigh[0] = List([-1])
    del neigh[0]
    # only left endpoints become keys; the values collect their right endpoints
    for l,r in edges:
        if l not in neigh:
            tmp = List([-1])
            tmp.pop()
            neigh[l] = tmp
        neigh[l].append(r)
    
    for _ in range(n_rewire):
        index1 = np.random.randint(0, delta)
        # offset is at least 1, so index2 always differs from index1
        offset = np.random.randint(1, delta)
        index2 = (index1 + offset) % (delta)
        e1_l, e1_r = edges[index1,:]
        e2_l, e2_r = edges[index2 ,:]
        
        if (e1_r == e2_r) or (e1_l == e2_l): # swap would do nothing
            continue
            
        if (e1_l == e2_r) or (e1_r == e2_l): # no self loops after swab
            continue
        
        # reject the swap if one of the two new edges already exists
        can_flip = True
        if e2_r in neigh[e1_l] or e1_r in neigh[e2_l]:
            can_flip = False

        if can_flip:
            edges[index1, 1] = e2_r
            edges[index2, 1] = e1_r
            
            # keep the neighbor lists in sync with the edge array
            neigh[e1_l].remove(e1_r)
            neigh[e2_l].remove(e2_r)
            neigh[e1_l].append(e2_r)
            neigh[e2_l].append(e1_r)

In [None]:
@njit
def rewire_numba(edges, edge_class, current_mono, block, is_directed):
    """Rewire every block of `edges` in place.

    Parameters
    ----------
    edges : edge array, assumed to be ordered so that each block is contiguous
    edge_class : class id per edge; the class of a block is read at its first row
    current_mono : container of class ids that are rewired with the "mono" routines
    block : array with one row `[lower, upper)` per block
    is_directed : for directed graphs the bipartite routines are always used

    Blocks with fewer than 50 edges use the scanning routines (rewire_mono,
    _rewire_bipartite), larger ones the neighbor-list routines (rewire_mono2,
    _rewire_bipartite_large).  The number of swap attempts per block is drawn
    uniformly from [delta, 2*delta), where delta is the block length.
    """
    # assumes edges to be ordered
    
    

    #block = block_indices[depth]
    #edge_class = edges_classes[:,depth]
    #curr_dead = dead_arr[depth,:]
    #current_mono = is_mono[depth]
    #print(edge_class)
    #print(np.bincount(edges.ravel()))
    #original_edges = edges.copy()
    # total_a / total_b / deltas are only used by the commented-out diagnostics below
    total_a = 0
    total_b = 0
    #print("block", (block[:,1]-block[:,0]).sum())
    deltas=[]
    #print(block.shape, len(block), block.dtype)
    for i in range(len(block)):
        lower = block[i,0]
        upper = block[i,1]
        delta=upper-lower
        
        deltas.append(delta)
        current_class = edge_class[lower]

        if not is_directed and current_class in current_mono:
            total_a += int(delta)
            #print(f"---{delta}")
            if delta< 50:
                rewire_mono(edges[lower:upper], np.random.randint(delta, 2*delta))
            else:
                rewire_mono2(edges[lower:upper], np.random.randint(delta, 2*delta))
        else:
            total_b += int(delta)
            #print(f"-{delta}")
            if delta< 50:
                _rewire_bipartite(edges[lower:upper], np.random.randint(delta, 2*delta))
            else:
                _rewire_bipartite_large(edges[lower:upper], np.random.randint(delta, 2*delta))

    
    #print(np.max(deltas))
    #print(block[:,1]-block[:,0]-np.array(deltas))
    #print("both", total_a+total_b)
    #print("mono", total_a)
    #print("bipa", total_b)
    #print(edges-original_edges!=0)
    #print(np.bincount(edges.ravel()))

In [None]:
get_lens(is_mono)

In [None]:
for i in range(len(dead_arr)-1):
    print(np.all(np.logical_or(~dead_arr[i,:], dead_arr[i+1,:])))

In [None]:
%load_ext snakeviz

In [None]:
%%snakeviz --new-tab

# Rewire the ordered edge list once per depth, from depth 8 down to depth 0,
# and time every call.
# NOTE(review): the depth range is hard-coded - confirm that edges_classes,
# is_mono and block_indices all hold at least 9 levels.
block_indices = get_block_indices(edges_classes, dead_arr)
all_edges = []
for depth  in range(8,-1,-1):
    inner_edges = []
    for inner_iter  in range(1):
        #print(depth)#, inner_iter, dead_arr[depth,:].sum())
        t0 = time.time()

        # edges_ordered is modified in place, so each stored copy also
        # contains the rewiring done at the previously processed depths
        rewire_numba(edges_ordered, edges_classes[:,depth], is_mono[depth], block_indices[depth], G.is_directed())
        inner_edges.append(edges_ordered.copy())
        t1 = time.time()
        total_n = t1-t0
        print(depth, total_n)
    all_edges.append(inner_edges)

In [None]:
all_edges

In [None]:
%%time
a,b = WL_fast(all_edges[0][0])

In [None]:
b

In [None]:
is_sorted(all_edges[0][0][:,0])

In [None]:
def is_sorted(arr):
    """Return True if `arr` is strictly increasing.

    Neighbouring elements are compared directly instead of testing the sign of
    their difference: for unsigned dtypes the difference wraps around, so a
    decreasing pair would be reported as increasing.
    """
    return np.all(arr[1:] > arr[:-1])

In [None]:
(block_indices[1][:,1]-block_indices[1][:,0]).sum()

In [None]:
dead_arr

In [None]:
np.random.seed(3)
lower = 0
upper = 20
print(edges_ordered[lower:upper])



rewire_bipartite(edges_ordered.copy(), lower, upper, 10)

In [None]:
edges_ordered

In [None]:
dead_arr

In [None]:
# identify dead_edges

#edges_ordered, edge_classes






In [None]:
edges_classes2 = np.hstack([relabel_edges(ordered_labelings[i,:], edges_ordered.copy()) for i in range(ordered_labelings.shape[0])])

In [None]:
edges_classes2

In [None]:
def is_sorted(arr):
    """Return True if `arr` is strictly increasing.

    Neighbouring elements are compared directly instead of testing the sign of
    their difference: for unsigned dtypes the difference wraps around, so a
    decreasing pair would be reported as increasing.
    """
    return np.all(arr[1:] > arr[:-1])

In [None]:
[is_sorted(a) for a in labelings]

In [None]:
2 & 4