In [1]:
from cc_model.load_datasets import *
from cc_model.fast_wl import WL_fast
from cc_model.utils import nx_to_gt
from cc_model.pagerank import all_pagerank
from cc_model.fast_rewire import rewire_fast, sort_edges, get_block_indices
from cc_model.fast_graph import FastGraph

import networkx as nx
from pathlib import Path
import graph_tool.all as gt
import time

In [2]:
datasets = [#"karate", 
            #"phonecalls",
            #"HepPh"#, 
    #"AstroPh", 
            #"web-Google",# "soc-Pokec"
#            "deezer_HR", "deezer_HU", "deezer_RO","tw_musae_DE",
#            "tw_musae_ENGB","tw_musae_FR","lastfm_asia","fb_ath",
#            "fb_pol", "facebook_sc"
           ]

In [3]:
%load_ext snakeviz

In [4]:
dataset_path = Path("/home/felix/projects/colorful_configuration/datasets")

In [5]:
#edges, is_directed = load_dataset(dataset_path, "soc-Pokec")

In [6]:
from cc_model.fast_rewire import *

In [7]:
def check_is_undirected(edges):
    d = {(a,b) for a,b in edges}
    for a,b in edges:
        assert (b,a) in d

In [8]:
#check_is_undirected(edges)
from numba import njit

In [9]:
@njit
def get_edge_id2(labels, edges, out):


    d = {(0, 0) : 0}
    del d[(0, 0)]
    is_mono = {0 : True}
    for i in range(edges.shape[0]):
        e1 = edges[i,0]
        e2 = edges[i,1]
        tpl = (labels[e1], labels[e2])
        if tpl in d:
            out[i] = d[tpl]
        else:
            n = len(d)
            out[i] = n
            d[tpl] = n
            if labels[e1] == labels[e2]:
                is_mono[n] = True

    return out, is_mono

In [10]:
#%%snakeviz --new-tab
G_base = load_gt_dataset_cached(dataset_path, "karate",
                                        verbosity=1,
                                        force_reload=False)
edges = G_base.get_edges()


        
G = FastGraph(edges, G_base.is_directed())
G.calc_base_wl()

In [11]:
def get_dead_edges_full(edge_with_node_labels, edges, order, num_nodes):

    num_labelings = edge_with_node_labels.shape[1]//2
    
    dead_indicators = np.zeros((edges.shape[0], num_labelings), dtype=np.bool)
    for i in range(num_labelings):
        _get_dead_edges(edge_with_node_labels[:,i*2:i*2+2], edges, order, num_nodes, dead_indicators[:,i])
    return dead_indicators



In [12]:
def reset_edges_ordered(self):
    print("resetting2")
    self.edges_ordered, self.edges_classes, self.dead_arr, self.is_mono = sort_edges(self._edges, self.base_partitions)
    self.block_indices = get_block_indices(self.edges_classes, self.dead_arr)

In [13]:
FastGraph.reset_edges_ordered = reset_edges_ordered

In [14]:
#@njit
def get_edge_id1(edge_with_node_labels, order, out):
    #order = np.lexsort(edge_with_node_labels.T)
    return _get_edge_id(edge_with_node_labels, order, out)
    
@njit
def _get_edge_id(edge_with_node_labels, order, out):
    last_label_0 = edge_with_node_labels[order[0],0]
    last_label_1 = edge_with_node_labels[order[0],1]

    if last_label_0==last_label_1:
        is_mono = {0 : True}
    else:
        is_mono = {0 : False}
    num_edge_colors = 0
    for i in range(order.shape[0]):
        curr_edge = order[i]
        node_label_0 = edge_with_node_labels[curr_edge,0]
        node_label_1 = edge_with_node_labels[curr_edge,1]
        if node_label_0!=last_label_0 or node_label_1!=last_label_1:
            num_edge_colors += 1
            last_label_0=node_label_0
            last_label_1=node_label_1
            if node_label_0==node_label_1:
                is_mono[num_edge_colors] = True
            
        out[curr_edge] = num_edge_colors

    return out, is_mono

In [15]:
@njit(parallel=True)
def assign_node_labels(labels, edges, out):
    for i in prange(edges.shape[0]):
        node_0 = edges[i,0]
        node_1 = edges[i,1]
        out[i,0]=labels[node_0]
        out[i,1]=labels[node_1]
    

In [16]:
@njit
def _get_dead_edges(edge_with_node_labels, edges, order, num_nodes, out):
    #print(edge_with_node_labels.shape)
    start_edge = order[0]
    last_label_0 = edge_with_node_labels[start_edge, 0]
    last_label_1 = edge_with_node_labels[start_edge, 1]
    
    last_id_0 = edges[start_edge, 0]
    last_id_1 = edges[start_edge, 1]
    
    start_of_last_group = 0
    last_group_is_dead_0 = False
    last_group_is_dead_1 = False
    len_last_group = 0
    
    for i in range(order.shape[0]):
        curr_edge = order[i]
        curr_label_0 = edge_with_node_labels[curr_edge, 0]
        curr_label_1 = edge_with_node_labels[curr_edge, 1]
        
        curr_id_0 = edges[curr_edge, 0]
        curr_id_1 = edges[curr_edge, 1]
        
        if curr_label_0 != last_label_0 or curr_label_1 != last_label_1:
            if (last_group_is_dead_0 or last_group_is_dead_1) or len_last_group==1:
                for j in range(start_of_last_group, i):
                    out[order[j]] = True
            last_group_is_dead_0 = True
            last_group_is_dead_1 = True
            
            start_of_last_group = i
            len_last_group = 0
            last_label_0 = curr_label_0
            last_label_1 = curr_label_1
            
            last_id_0 = curr_id_0
            last_id_1 = curr_id_1
        if last_id_0 != curr_id_0:
            last_group_is_dead_0 = False
        if last_id_1 != curr_id_1:
            last_group_is_dead_1 = False
        len_last_group+=1
    if (last_group_is_dead_0 and last_group_is_dead_1) or len_last_group==1:
        for j in range(start_of_last_group, len(out)):
            out[order[j]] = True
            

    return out

In [17]:

def sort_edges(edges, labelings, directed = True):
    """Sort edges such that that edges of similar classes are consecutive

    additionally puts dead edges at the end

    """
    print("sort_edges2")
    # WARNING If network is undirected edges need to be sorted first
    if directed is False:
        raise ValueError()


    edges_classes = []
    is_mono = []
    edge_with_node_labels = np.empty((edges.shape[0], 2*labelings.shape[0]), dtype=labelings.dtype)

    edge_with_node_labels
    for i in range(labelings.shape[0]):
        assign_node_labels(labelings[i,:], edges , edge_with_node_labels[:,i*2:i*2+2])
    print(edge_with_node_labels.max())
    order = np.lexsort(edge_with_node_labels[:,::-1].T)
    #order = get_order(edge_with_node_labels)
    
    #print(edge_with_node_labels.shape)
    #print(len(order))
    #print(edges.shape)
    #print(edge_with_node_labels[order,:])
    for i in range(labelings.shape[0]):
        #assign_node_labels(labelings[i,:], edges , edge_with_node_labels)
        edge_class, mono = get_edge_id1(edge_with_node_labels[:,i*2:i*2+2], order, np.empty(len(edges), dtype=np.uint32))

        edges_classes.append(edge_class)
        is_mono.append(mono)
    
    
    dead_indicator = get_dead_edges_full(edge_with_node_labels, edges, order, labelings.shape[1]).T
    print(np.hstack((edges, edge_with_node_labels, dead_indicator.T))[order,:])
    #raise ValueError
    tmp = list(chain.from_iterable(zip(edges_classes, dead_indicator)))
    #print(tmp)
    #(list(tmp))
    edges_classes_arr = np.vstack(edges_classes)
    to_sort_arr = np.vstack(tmp)#[dead_ids]+ edges_classes)

    # sort edges such that each of the classes are in order
    edge_order = np.lexsort(to_sort_arr[::-1,:])
    print(edge_order)
    edges_ordered = edges[edge_order,:]
    print(np.hstack((edges_ordered, edges_classes_arr[:, edge_order].T, dead_indicator[:, edge_order].T)))
    
    return edges_ordered, edges_classes_arr[:, edge_order].T, dead_indicator[:, edge_order], is_mono    

In [18]:
def get_order(edge_with_node_labels):
    #print(edge_with_node_labels)
    my_ls =  my_lexsort(edge_with_node_labels.copy())
    #np_ls = np.lexsort(edge_with_node_labels[:,::-1].T)
    #print(np.hstack((edge_with_node_labels[my_ls,:], edge_with_node_labels[np_ls,:])))
    #np.testing.assert_array_equal(edge_with_node_labels[my_ls,:], edge_with_node_labels[np_ls,:])
    return my_ls

In [19]:
#%%snakeviz --new-tab
G.reset_edges_ordered()

resetting2
sort_edges2
26
[[ 0  1  0  0  0  1  0  1  0  1  1]
 [ 0  2  0  0  0  2  0  2  0  1  1]
 [ 0  3  0  0  0  3  0  3  0  1  1]
 [ 0 31  0  0  0  3  0 24  0  1  1]
 [ 0  4  0  0  0  4  0  4  0  1  1]
 [ 0 10  0  0  0  4  0  4  0  1  1]
 [ 0 19  0  0  0  4  0 15  0  1  1]
 [ 0  5  0  0  0  5  0  5  0  1  1]
 [ 0  6  0  0  0  5  0  5  0  1  1]
 [ 0  7  0  0  0  5  0  6  0  1  1]
 [ 0  8  0  0  0  6  0  7  0  1  1]
 [ 0 13  0  0  0  6  0 11  0  1  1]
 [ 0 12  0  0  0  7  0 10  0  1  1]
 [ 0 17  0  0  0  7  0 14  0  1  1]
 [ 0 21  0  0  0  7  0 14  0  1  1]
 [ 0 11  0  0  0  8  0  9  0  1  1]
 [ 1  2  0  0  1  2  1  2  0  1  1]
 [ 1  3  0  0  1  3  1  3  0  1  1]
 [ 1 19  0  0  1  4  1 15  0  1  1]
 [ 1  7  0  0  1  5  1  6  0  1  1]
 [ 1 30  0  0  1  5  1 23  0  1  1]
 [ 1 13  0  0  1  6  1 11  0  1  1]
 [ 1 17  0  0  1  7  1 14  0  1  1]
 [ 1 21  0  0  1  7  1 14  0  1  1]
 [ 2  3  0  0  2  3  2  3  0  1  1]
 [ 2 28  0  0  2  4  2 21  0  1  1]
 [ 2  7  0  0  2  5  2  6  0  1  1]
 [

In [20]:
np.iinfo(np.uint16).max

65535

In [21]:
np.random.seed(10)
arr1 = np.random.randint(100, size=10)
arr1[3]=30
arr1[4]=30
arr1[7]=30
arr2 = np.random.randint(100, size=10)
#arr2[4]=70
arr2[4]=11
arr2[7]=12
arr3 = np.random.randint(100, size=10)
input = np.array([arr1, arr2, arr3], dtype=np.uint32).T

In [22]:
order = my_lexsort(input)
print(order)
input[order,:]

NameError: name 'my_lexsort' is not defined

In [None]:
print(np.lexsort(np.array([arr3, arr2, arr1], dtype=np.uint32)))
print(order)

In [None]:
@njit
def find_crossing(arr, order):
    intervals = np.empty(len(arr)+1, dtype=np.uint32)
    intervals[0]=0
    last_val = arr[order[0]]
    index = 1
    for i in range(len(order)):
        val = arr[order[i]]
        
        if val != last_val:
            intervals[index]=i
            index+=1
            last_val=val
    intervals[index]=arr.shape[0]
    index+=1
    return intervals[:index]

@njit
def find_crossing2(arr):
    intervals = np.empty(len(arr)+1, dtype=np.uint32)
    intervals[0]=0
    last_val = arr[0]
    index = 1
    for i in range(arr.shape[0]):
        val = arr[i]
        
        if val != last_val:
            intervals[index]=i
            index+=1
            last_val=val
    intervals[index]=arr.shape[0]
    index+=1
    return intervals[:index]

@njit(parallel=True)
def my_lexsort(arrs):
    if arrs.shape[1]==1:
        return np.argsort(arrs[:,0])
    starting_index=0
    while np.all(arrs[:,starting_index])==arrs[0,starting_index]:
        starting_index+=1
    order = np.argsort(arrs[:,starting_index])
    
    #print(arrs)
    intervals = find_crossing(arrs[:,starting_index], order)
    arr_sorted = arrs[order,:]
    #print(arr_sorted)
    #print("inter", intervals)
    for i in prange(len(intervals)-1):
        lb = intervals[i]
        ub = intervals[i+1]
        if ub-lb >1:
            _lexsort(arr_sorted[lb:ub,starting_index+1:],  order[lb:ub])
            #order[lb:ub] = order[lb:ub][partial_order]
    return order

@njit(nogil=True)
def _lexsort(arrs, order):
    #print(arrs)
    new_order = np.argsort(arrs[:,0])
    #print(new_order)
    arrs[:,:] = arrs[new_order,:]
    order[:]=order[new_order]
    intervals = find_crossing2(arrs[:,0])
    #print("intervals", intervals)
    for i in range(len(intervals)-1):
        lb = intervals[i]
        ub = intervals[i+1]
        if ub-lb >1 and arrs.shape[1]>1:
            _lexsort(arrs[lb:ub,1:], order[lb:ub])
            #order[lb:ub] = order[lb:ub][partial_order]