In [1]:
import numpy as np

import adaptoctree.morton as morton
import adaptoctree.tree as tree
import adaptoctree.plotting as plotting

# Test logic

In [2]:
N = int(1e3)
particles = plotting.make_moon(N)
# particles = np.random.rand(N, 3)

max_level = 16
max_num_particles = 50

In [3]:
max_bound, min_bound = morton.find_bounds(particles)
x0 = morton.find_center(max_bound, min_bound)
r0 = morton.find_radius(x0, max_bound, min_bound)

In [4]:
unbalanced = tree.build(particles, max_level, max_num_particles, 1)
tst = unbalanced
unbalanced = np.unique(unbalanced)
depth = max(morton.find_level(unbalanced))

In [5]:
len(unbalanced)

45

In [6]:
balanced = tree.balance(unbalanced, depth)

In [7]:
balanced = np.fromiter(balanced, np.int64, len(balanced))

In [8]:
len(balanced)

337

In [None]:
balanced_leaves = tree.assign_points_to_keys(particles, balanced, x0, r0)

In [10]:
point = particles[balanced_leaves == -1][0]

In [21]:
particles[balanced_leaves == -1]

array([[ 2.7705871 ,  0.60934245,  0.37054475],
       [ 2.76791634,  0.73993218,  0.86420982],
       [ 2.7978668 ,  0.89182525,  0.48380788],
       [ 2.91328523,  0.85162724,  0.40966737],
       [ 2.88711294,  1.38149122,  1.15062644],
       [ 3.04315565,  0.8588149 ,  0.16988802],
       [ 3.00331582,  0.84432623,  0.45818889],
       [ 2.99076279,  1.16006097,  0.65291726],
       [ 3.03093899,  1.17114679,  0.50633323],
       [ 3.06140007,  1.23246635,  0.44843694],
       [ 3.04663481,  0.8938601 ,  0.87341321],
       [ 2.7761803 ,  0.75385167,  0.69463556],
       [ 2.82678431,  1.0188243 ,  0.93792322],
       [ 2.9963841 ,  1.44644181,  1.0522973 ],
       [ 3.16347468,  0.66166477,  0.82403081],
       [ 3.09587133,  1.01343776,  0.8242614 ],
       [ 2.89017291,  0.74833487,  0.24994967],
       [ 2.98598436,  0.670948  ,  0.41519831],
       [ 3.11752562,  0.83196236,  0.51097764],
       [ 3.06325807,  0.94400289,  0.34834186],
       [ 2.9665164 ,  0.62909187,  0.438

In [77]:
key = morton.encode_point(point, 4, x0, r0)

In [78]:
key

16711684

In [71]:
set(morton.find_descendents(key, 2)).intersection(balanced)

{14712836,
 14778372,
 14843908,
 14909444,
 14942212,
 14974980,
 15007748,
 15040516,
 15073284,
 15106052,
 15138820,
 15171588,
 15237124,
 15302660,
 15368196,
 15433732,
 15466500,
 15499268,
 15532036,
 15564804,
 15597572,
 15630340,
 15663108,
 15695876,
 16613380,
 16744452}

In [79]:
morton.decode_key(key)

array([6, 7, 7, 4], dtype=int32)

In [81]:
key in set(balanced)

False

In [76]:
particles[balanced_leaves == -1].shape

(147, 3)

In [80]:
key in set(unbalanced)

False

In [35]:
balanced.shape

(380,)

In [100]:
# Test num particles constraint
_, counts = np.unique(balanced_leaves, return_counts=True)

In [101]:
assert np.all(counts <= max_num_particles)

AssertionError: 

In [12]:
counts < max_num_particles

array([ True])

In [97]:
# Check for balancing condition

for i in balanced:
    for j in balanced:
        if (i != j) and morton.are_neighbours(i, j, x0, r0):
            diff = morton.find_level(i) - morton.find_level(j)
            assert diff <= 1

In [99]:
# Check for overlaps

for i, ki in enumerate(balanced):
    for j, kj in enumerate(balanced):
        if j != i:
            assert ki not in morton.find_ancestors(kj)            

# Benchmarking

In [None]:
N = int(1e5)
particles = plotting.make_moon(N)
# particles = np.random.rand(N, 3)

max_level = 16
max_num_particles = 5


max_bound, min_bound = morton.find_bounds(particles)
x0 = morton.find_center(max_bound, min_bound)
r0 = morton.find_radius(x0, max_bound, min_bound)

unbalanced = tree.build(particles, max_level=max_level, max_points=5)
unbalanced, counts = np.unique(unbalanced, return_counts=True)

len(unbalanced)

In [None]:
unbalanced

In [None]:
%timeit tree.build(particles, max_level, max_num_particles)

In [None]:
%timeit tree.balance(unbalanced, depth)

In [None]:
tst = tree.balance(unbalanced, depth)

In [None]:
tst == set(unbalanced)

In [None]:
import numba as nb

reflected_int_set = nb.types.Set(nb.int64, reflected=False)

@nb.njit(reflected_int_set(reflected_int_set, reflected_int_set))
def f(set_1, set_2):
    return set_1
