In [2]:
import numba
import numpy as np

import adaptoctree.morton as morton
import adaptoctree.morton as t

In [88]:
a = set([1])

In [93]:
help(a)

Help on set object:

class set(object)
 |  set() -> new empty set object
 |  set(iterable) -> new set object
 |  
 |  Build an unordered collection of unique elements.
 |  
 |  Methods defined here:
 |  
 |  __and__(self, value, /)
 |      Return self&value.
 |  
 |  __contains__(...)
 |      x.__contains__(y) <==> y in x.
 |  
 |  __eq__(self, value, /)
 |      Return self==value.
 |  
 |  __ge__(self, value, /)
 |      Return self>=value.
 |  
 |  __getattribute__(self, name, /)
 |      Return getattr(self, name).
 |  
 |  __gt__(self, value, /)
 |      Return self>value.
 |  
 |  __iand__(self, value, /)
 |      Return self&=value.
 |  
 |  __init__(self, /, *args, **kwargs)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  __ior__(self, value, /)
 |      Return self|=value.
 |  
 |  __isub__(self, value, /)
 |      Return self-=value.
 |  
 |  __iter__(self, /)
 |      Implement iter(self).
 |  
 |  __ixor__(self, value, /)
 |      Return self^=value.


In [185]:
@numba.njit
def numba_bfs(root, tree, depth):
    
    tree = set(tree)
    
    queue = np.array([root], dtype=np.int64)
    
    overlaps = set()
    
    sentinel = -1
    while queue[0] != sentinel:
        for node in queue:
            
            level = morton.find_level(node)
            new_queue = np.array([sentinel], dtype=np.int64)
            relative_depth = depth-level+1
            
            for l in range(1, relative_depth):
                descs = morton.find_descendents(node, l)
                ints = np.zeros_like(descs, dtype=np.int64)
                
                i = 0
                for d in descs:
                    if d in tree:
                        ints[i] = d
                        i += 1
                        
                ints = ints[:i]
                
                overlaps.update(ints)
                
                if new_queue[0] == sentinel:
                    new_queue = ints
                else:
                    new_queue = np.hstack((new_queue, ints))
        
        queue = new_queue
        
    return overlaps


@numba.njit
def numba_remove_overlaps(balanced, depth):

    unique = set(balanced)

    for node in balanced:
        if numba_bfs(node, unique, depth):
            unique.remove(node)

    return unique


@numba.njit
def numba_balance(tree, depth):

    for l in range(depth, 0, -1):

        # nodes at current level
        balanced = set(tree)

        levels = morton.find_level(tree)
        Q = tree[levels == l]    
        
        for q in Q:
            parent = morton.find_parent(q)
            
            neighbours = morton.find_neighbours(q)
            parent_neighbours = morton.find_neighbours(parent)
            
            tree = np.hstack((tree, neighbours))
            tree = np.hstack((tree, parent_neighbours))
            
            balanced.update(parent_neighbours)
            balanced.update(neighbours)
    
    return numba_remove_overlaps(balanced, depth)

In [151]:
levels = morton.find_level(tree)

tree[levels == 1]

array([     1,  32769,  65537,  98305, 131073, 163841, 196609, 229377])

In [147]:
def bfs(root, tree, depth):

    tree = set(tree)
    queue = [root]

    overlaps = set()

    while queue:
        for node in queue:
            level = morton.find_level(node)
            new_queue = []
            for l in range(1, depth-level + 1):

                descs = set(morton.find_descendents(node, l))

                ints = descs.intersection(tree)
                overlaps.update(ints)
                new_queue.extend(list(ints))

        queue = new_queue

    return overlaps


def remove_overlaps(balanced, depth):

    unique = set(balanced)

    for node in balanced:
        if bfs(node, unique, depth):
            unique.remove(node)

    return unique


def balance(tree, depth):

    balanced = set(tree)

    for l in range(depth, 0, -1):
        # nodes at current level
        Q = {x for x in balanced if morton.find_level(x) == l}

        for q in Q:
            parent = morton.find_parent(q)
            neighbours = set(morton.find_neighbours(q))
            parent_neighbours = set(morton.find_neighbours(parent))
            balanced.update(parent_neighbours)
            balanced.update(neighbours)
    return remove_overlaps(list(balanced), depth)

In [72]:
tree = np.array([      1,  458754,  229377,   65537,  163841,  491522,  917506,
        950274, 1114114, 1146882, 1572866, 1605634, 1048578,   65538,
             2,  131074,  196610,  524290,  655362,   32770, 1769474,
       2031618,  229378,  753666, 1015810, 1277954, 1802242, 1507330,
        262146, 1310722,  786434, 1835010,   32769, 1081346,  557058,
        819202, 1867778,  294914, 1343490, 1900546,  327682,  851970,
       1376258, 1638402,   98305,   98306,  622594, 1671170,  360450,
       1409026,  884738, 1933314,  131073,  393218, 1179650, 1703938,
       1441794, 1966082,  163842,  688130,  425986, 1212418, 1736706,
        196609,  720898,  983042, 1245186])

In [121]:
tree.sort()

In [132]:
depth = max(morton.find_level(tree))

In [128]:
bfs(tree[1], tree, depth)

set()

In [123]:
%timeit bfs(1, tree, depth)

14.9 µs ± 90.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [188]:
%timeit remove_overlaps(tree, depth)

220 µs ± 5.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [193]:
a = balance(tree, depth)

In [191]:
%timeit balance(tree, depth)

830 µs ± 6.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [126]:
%timeit numba_bfs(1, tree, depth)

2.67 µs ± 28.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [145]:
res = numba_remove_overlaps(tree, depth)

In [146]:
%timeit numba_remove_overlaps(tree, depth)

80.4 µs ± 2.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [194]:
b = numba_balance(tree, depth)

In [190]:
%timeit numba_balance(tree, depth)

291 µs ± 1.99 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [195]:
assert a == b