In [2]:
import numba
from numba.typed import Dict, List
import numba.types
import numpy as np

import adaptoctree.morton as morton
import adaptoctree.tree as tree
import adaptoctree.plotting as plotting



In [3]:
d = Dict.empty(
    key_type=numba.types.int64, value_type=numba.types.int64
)

In [4]:
d.keys()

KeysView(DictType[int64,int64]<iv=None>({}))

In [5]:
d.keys = [1, 2, 3]

In [8]:
List([1,2])

ListType[int64]([1, 2])

In [9]:
@numba.njit
def balance(tree, depth):
    
    def bfs(root, tree, depth):

        tree = set(tree)
        queue = [root]

        overlaps = set()

        while queue:
            for node in queue:
                level = morton.find_level(node)
                new_queue = set()
                for l in range(1, depth-level + 1):

                    descs = set(morton.find_descendents(node, l))

                    ints = descs.intersection(tree)
                    overlaps.update(ints)
                    
                    new_queue.update(ints)

            queue = list(new_queue)

        return overlaps


    def remove_overlaps(balanced, depth):

        unique = set(balanced)

        for node in balanced:
            if bfs(node, unique, depth):
                unique.remove(node)

        return unique


    balanced = set(tree)

    for l in range(depth, 0, -1):
        # nodes at current level
        Q = [x for x in balanced if morton.find_level(x) == l]
        
        for q in Q:
            parent = morton.find_parent(q)
            neighbours = morton.find_neighbours(q)
            parent_neighbours = morton.find_neighbours(parent)
            
            balanced.update(parent_neighbours)
            balanced.update(neighbours)
    
    return remove_overlaps(balanced, depth)


# Run with small number of particles, to allow Numba to compile

In [10]:
N = int(1e3)
particles = plotting.make_moon(N)
unbalanced = tree.build(particles)
depth = max(morton.find_level(unbalanced))

In [11]:
b = tree.balance(unbalanced, depth)

In [12]:
%timeit balanced = tree.balance(unbalanced, depth)

2.61 ms ± 71.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
b.intersection(np.array([98307]))

{98307}

In [29]:
a = balance(unbalanced, depth)

In [30]:
%timeit balance(unbalanced, depth)

151 ms ± 777 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [27]:
b == set(a)

True

# Run with larger number of particles

In [14]:
N = int(1e4)
particles = plotting.make_moon(N)
unbalanced = tree.build(particles)
depth = max(morton.find_level(unbalanced))

In [15]:
_ = balance(unbalanced, depth)

In [16]:
%timeit balance(unbalanced, depth)

238 ms ± 982 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
%timeit tree.balance(unbalanced, depth)

31.9 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
