Skip to content

Commit

Permalink
Enable construction and queries for empty trees (#36)
Browse files Browse the repository at this point in the history
* Allow 0 point tree

* Clean up check_k function

* Add tests for empty trees
  • Loading branch information
JoshChristie authored and KristofferC committed Oct 20, 2016
1 parent 4793669 commit 88368b5
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
15 changes: 11 additions & 4 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function BallTree{V <: AbstractArray, M <: Metric}(data::Vector{V},
# Bottom up creation of hyper spheres so need spheres even for leafs)
hyper_spheres = Array(HyperSphere{length(V), eltype(V)}, tree_data.n_internal_nodes + tree_data.n_leafs)

if reorder
if reorder
indices_reordered = Vector{Int}(n_p)
if isempty(reorderbuffer)
data_reordered = Vector{V}(n_p)
Expand All @@ -61,9 +61,11 @@ function BallTree{V <: AbstractArray, M <: Metric}(data::Vector{V},
data_reordered = Vector{V}()
end

# Call the recursive BallTree builder
build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
1, length(data), tree_data, array_buffs, reorder)
if n_p > 0
# Call the recursive BallTree builder
build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
1, length(data), tree_data, array_buffs, reorder)
end

if reorder
data = data_reordered
Expand Down Expand Up @@ -201,6 +203,11 @@ function inrange_kernel!(tree::BallTree,
query_ball::HyperSphere,
idx_in_ball::Vector{Int})
@NODE 1

if index > length(tree.hyper_spheres)
return
end

sphere = tree.hyper_spheres[index]

# If the query ball in the bounding sphere for the current sub tree
Expand Down
1 change: 0 additions & 1 deletion src/hyperrectangles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ end

# Computes a bounding box around a point cloud
function compute_bbox{V <: AbstractVector}(data::Vector{V})
@assert length(data) != 0
T = eltype(V)
n_dim = length(V)
maxes = Vector{T}(n_dim)
Expand Down
12 changes: 6 additions & 6 deletions src/knn.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
check_k(tree, k) = (k > length(tree.data)|| k <= 0) && throw(ArgumentError("k > number of points in tree or ≦ 0"))
function check_k(tree, k)
if k > length(tree.data) || k < 0
throw(ArgumentError("k > number of points in tree or < 0"))
end
end

"""
knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances
Expand All @@ -11,7 +15,6 @@ to determine if a point that would be returned should be skipped.
function knn{V, T <: AbstractVector}(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::Function=always_false)
check_input(tree, points)
check_k(tree, k)

n_points = length(points)
dists = [Vector{get_T(eltype(V))}(k) for _ in 1:n_points]
idxs = [Vector{Int}(k) for _ in 1:n_points]
Expand All @@ -34,10 +37,7 @@ function knn_point!{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, s
end

function knn{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::Function=always_false)
if k > length(tree.data)|| k <= 0
throw(ArgumentError("k > number of points in tree or ≦ 0"))
end

check_k(tree, k)
idx = Vector{Int}(k)
dist = Vector{get_T(eltype(V))}(k)
knn_point!(tree, point, sortres, dist, idx, skip)
Expand Down
4 changes: 4 additions & 0 deletions test/test_inrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@

@test_throws ArgumentError inrange(tree, rand(3), -0.1)
@test_throws ArgumentError inrange(tree, rand(5), 1.0)

empty_tree = TreeType(rand(3,0), metric)
idxs = inrange(empty_tree, [0.5, 0.5, 0.5], 1.0)
@test idxs == []
end
end
end
6 changes: 6 additions & 0 deletions test/test_knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@ import Distances.evaluate
idxs, dists = knn(tree, [1//10, 8//10], 3, true)
@test idxs == [3, 2, 5]

@test_throws ArgumentError knn(tree, [0.1, 0.8], -1) # k < 0
@test_throws ArgumentError knn(tree, [0.1, 0.8], 10) # k > n_points
@test_throws ArgumentError knn(tree, [0.1], 10) # n_dim != trees dim

empty_tree = TreeType(rand(2,0), metric; leafsize=2)
idxs, dists = knn(empty_tree, [0.5, 0.5], 0, true)
@test idxs == Int[]
@test_throws ArgumentError knn(empty_tree, [0.1, 0.8], -1) # k < 0
@test_throws ArgumentError knn(empty_tree, [0.1, 0.8], 1) # k > n_points
end
end
end
Expand Down

0 comments on commit 88368b5

Please sign in to comment.