Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add skip predicate to inrange, fixes #53 #56

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ idxs, dists = knn(kdtree, point, k, true)
A range search finds all neighbors within the range `r` of given point(s).
This is done with the method:
```jl
inrange(tree, points, r, sortres = false) -> idxs
inrange(tree, points, r, sortres = false, skip = always_false) -> idxs
```
Note that for performance reasons the distances are not returned. The arguments to `inrange` are the same as for `knn` except that `sortres` just sorts the returned index vector.

Expand Down
9 changes: 6 additions & 3 deletions benchmarks/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ SUITE["build tree"] = BenchmarkGroup()
SUITE["knn"] = BenchmarkGroup()
SUITE["inrange"] = BenchmarkGroup()

get_name(T::UnionAll) = get_name(T.body)
get_name(T::DataType) = T.name

for n_points in (EXTENSIVE_BENCHMARK ? (10^3, 10^5) : 10^5)
for dim in (EXTENSIVE_BENCHMARK ? (1, 3) : 3)
data = rand(MersenneTwister(1), dim, n_points)
Expand All @@ -19,17 +22,17 @@ for n_points in (EXTENSIVE_BENCHMARK ? (10^3, 10^5) : 10^5)
for (tree_type, SUITE_name) in ((KDTree, "kd tree"),
(BallTree, "ball tree"))
tree = tree_type(data; leafsize = leafsize, reorder = reorder)
SUITE["build tree"]["$(tree_type.name.name) $dim × $n_points, ls = $leafsize"] = @benchmarkable $(tree_type)($data; leafsize = $leafsize, reorder = $reorder)
SUITE["build tree"]["$(get_name(tree_type)) $dim × $n_points, ls = $leafsize"] = @benchmarkable $(tree_type)($data; leafsize = $leafsize, reorder = $reorder)
for input_size in (1, 1000)
input_data = rand(MersenneTwister(1), dim, input_size)
for k in (EXTENSIVE_BENCHMARK ? (1, 10) : 10)
SUITE["knn"]["$(tree_type.name.name) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($tree, $input_data, $k)
SUITE["knn"]["$(get_name(tree_type)) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($tree, $input_data, $k)
end
perc = 0.01
V = π^(dim / 2) / gamma(dim / 2 + 1) * (1 / 2)^dim
r = (V * perc * gamma(dim / 2 + 1))^(1/dim)
r_formatted = @sprintf("%3.2e", r)
SUITE["inrange"]["$(tree_type.name.name) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($tree, $input_data, $r)
SUITE["inrange"]["$(get_name(tree_type)) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($tree, $input_data, $r)
end
end
end
Expand Down
18 changes: 10 additions & 8 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ function _knn(tree::BallTree,
point::AbstractVector,
best_idxs::Vector{Int},
best_dists::Vector,
skip::Function)
skip::F) where {F}
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip)
return
end
Expand Down Expand Up @@ -189,17 +189,19 @@ end
function _inrange(tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Vector{Int}) where {V}
idx_in_ball::Vector{Int},
skip::F) where {V, F}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip) # Call the recursive range finder
return
end

function inrange_kernel!(tree::BallTree,
index::Int,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Vector{Int})
idx_in_ball::Vector{Int},
skip::F) where {F}
@NODE 1

if index > length(tree.hyper_spheres)
Expand All @@ -216,17 +218,17 @@ function inrange_kernel!(tree::BallTree,

# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true)
add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true, skip)
return
end

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses(tree.metric, sphere, query_ball)
addall(tree, index, idx_in_ball)
addall(tree, index, idx_in_ball, skip)
else
# Recursively call the left and right sub tree.
inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip)
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip)
end
end
14 changes: 10 additions & 4 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function _knn(tree::BruteTree{V},
point::AbstractVector,
best_idxs::Vector{Int},
best_dists::Vector,
skip::Function) where {V}
skip::F) where {V, F}

knn_kernel!(tree, point, best_idxs, best_dists, skip)
return
Expand Down Expand Up @@ -55,17 +55,23 @@ end
function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number,
idx_in_ball::Vector{Int})
inrange_kernel!(tree, point, radius, idx_in_ball)
idx_in_ball::Vector{Int},
skip::F) where {F}
inrange_kernel!(tree, point, radius, idx_in_ball, skip)
return
end


function inrange_kernel!(tree::BruteTree,
point::AbstractVector,
r::Number,
idx_in_ball::Vector{Int})
idx_in_ball::Vector{Int},
skip::F) where {F}
for i in 1:length(tree.data)
if skip(i)
continue
end

@POINT 1
d = evaluate(tree.metric, tree.data[i], point)
if d <= r
Expand Down
22 changes: 12 additions & 10 deletions src/inrange.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))

"""
inrange(tree::NNTree, points, radius [, sortres=false]) -> indices
inrange(tree::NNTree, points, radius [, sortres=false, skip=always_false]) -> indices

Find all the points in the tree which is closer than `radius` to `points`. If
`sortres = true` the resulting indices are sorted.
`sortres = true` the resulting indices are sorted. `skip` is an optional predicate
to determine if a point that would be returned should be skipped.
"""
function inrange(tree::NNTree,
points::Vector{T},
radius::Number,
sortres=false) where {T <: AbstractVector}
sortres=false,
skip::F=always_false) where {T <: AbstractVector, F}
check_input(tree, points)
check_radius(radius)

idxs = [Vector{Int}() for _ in 1:length(points)]

for i in 1:length(points)
inrange_point!(tree, points[i], radius, sortres, idxs[i])
inrange_point!(tree, points[i], radius, sortres, idxs[i], skip)
end
return idxs
end

function inrange_point!(tree, point, radius, sortres, idx)
_inrange(tree, point, radius, idx)
function inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F}
_inrange(tree, point, radius, idx, skip)
if tree.reordered
@inbounds for j in 1:length(idx)
idx[j] = tree.indices[idx[j]]
Expand All @@ -32,21 +34,21 @@ function inrange_point!(tree, point, radius, sortres, idx)
return
end

function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number}
function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip::F=always_false) where {V, T <: Number, F}
check_input(tree, point)
check_radius(radius)
idx = Int[]
inrange_point!(tree, point, radius, sortres, idx)
inrange_point!(tree, point, radius, sortres, idx, skip)
return idx
end

function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false) where {V, T <: Number}
function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false, skip::F=always_false) where {V, T <: Number, F}
dim = size(point, 1)
npoints = size(point, 2)
if isbits(T)
new_data = reinterpret(SVector{dim,T}, point, (length(point) ÷ dim,))
else
new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints]
end
inrange(tree, new_data, radius, sortres)
inrange(tree, new_data, radius, sortres, skip)
end
16 changes: 9 additions & 7 deletions src/kd_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ function _knn(tree::KDTree,
point::AbstractVector,
best_idxs::Vector{Int},
best_dists::Vector,
skip::Function)
skip::F) where {F}

init_min = get_min_distance(tree.hyper_rec, point)
knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, skip)
Expand Down Expand Up @@ -203,10 +203,11 @@ end
function _inrange(tree::KDTree,
point::AbstractVector,
radius::Number,
idx_in_ball = Int[])
idx_in_ball = Int[],
skip::F = always_false) where {F}
init_min = get_min_distance(tree.hyper_rec, point)
inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(init_min)), idx_in_ball,
init_min)
init_min, skip)
return
end

Expand All @@ -216,7 +217,8 @@ function inrange_kernel!(tree::KDTree,
point::AbstractVector,
r::Number,
idx_in_ball::Vector{Int},
min_dist)
min_dist,
skip::F) where {F}
@NODE 1
# Point is outside hyper rectangle, skip the whole sub tree
if min_dist > r
Expand All @@ -225,7 +227,7 @@ function inrange_kernel!(tree::KDTree,

# At a leaf node. Go through all points in node and add those in range
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_inrange!(idx_in_ball, tree, index, point, r, false)
add_points_inrange!(idx_in_ball, tree, index, point, r, false, skip)
return
end

Expand All @@ -247,7 +249,7 @@ function inrange_kernel!(tree::KDTree,
ddiff = max(zero(lo - p_dim), lo - p_dim)
end
# Call closer sub tree
inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist)
inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist, skip)

# TODO: We could potentially also keep track of the max distance
# between the point and the hyper rectangle and add the whole sub tree
Expand All @@ -259,5 +261,5 @@ function inrange_kernel!(tree::KDTree,
ddiff_pow = eval_pow(M, ddiff)
diff_tot = eval_diff(M, split_diff_pow, ddiff_pow)
new_min = eval_reduce(M, min_dist, diff_tot)
inrange_kernel!(tree, far, point, r, idx_in_ball, new_min)
inrange_kernel!(tree, far, point, r, idx_in_ball, new_min, skip)
end
8 changes: 4 additions & 4 deletions src/knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ function check_k(tree, k)
end

"""
knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances
knn(tree::NNTree, points, k [, sortres=false, skip=always_false]) -> indices, distances

Performs a lookup of the `k` nearest neigbours to the `points` from the data
in the `tree`. If `sortres = true` the result is sorted such that the results are
in the order of increasing distance to the point. `skip` is an optional predicate
to determine if a point that would be returned should be skipped.
"""
function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::Function=always_false) where {V, T <: AbstractVector}
function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F}
check_input(tree, points)
check_k(tree, k)
n_points = length(points)
Expand All @@ -36,15 +36,15 @@ function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, id
end
end

function knn(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::Function=always_false) where {V, T <: Number}
function knn(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F}
check_k(tree, k)
idx = Vector{Int}(k)
dist = Vector{get_T(eltype(V))}(k)
knn_point!(tree, point, sortres, dist, idx, skip)
return idx, dist
end

function knn(tree::NNTree{V}, point::Matrix{T}, k::Int, sortres=false, skip::Function=always_false) where {V, T <: Number}
function knn(tree::NNTree{V}, point::Matrix{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F}
dim = size(point, 1)
npoints = size(point, 2)
if isbits(T)
Expand Down
25 changes: 17 additions & 8 deletions src/tree_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ end
tree::NNTree, index::Int, point::AbstractVector,
do_end::Bool, skip::F) where {F}
for z in get_leaf_range(tree.tree_data, index)
if skip(tree.indices[z])
continue
end

@POINT 1
idx = tree.reordered ? z : tree.indices[z]
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
if dist_d <= best_dists[1]
if skip(tree.indices[z])
continue
end

best_dists[1] = dist_d
best_idxs[1] = idx
percolate_down!(best_dists, best_idxs, dist_d, idx)
Expand All @@ -116,8 +116,13 @@ end
# This will probably prevent SIMD and other optimizations so some care is needed
# to evaluate if it is worth it.
@inline function add_points_inrange!(idx_in_ball::Vector{Int}, tree::NNTree,
index::Int, point::AbstractVector, r::Number, do_end::Bool)
index::Int, point::AbstractVector, r::Number,
do_end::Bool, skip::F) where {F}
for z in get_leaf_range(tree.tree_data, index)
if skip(tree.indices[z])
continue
end

@POINT 1
idx = tree.reordered ? z : tree.indices[z]
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
Expand All @@ -129,18 +134,22 @@ end

# Add all points in this subtree since we have determined
# they are all within the desired range
function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int})
function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int}, skip::F) where {F}
tree_data = tree.tree_data
@NODE 1
if isleaf(tree.tree_data.n_internal_nodes, index)
for z in get_leaf_range(tree.tree_data, index)
if skip(tree.indices[z])
continue
end

@POINT_UNCHECKED 1
idx = tree.reordered ? z : tree.indices[z]
push!(idx_in_ball, idx)
end
return
else
addall(tree, getleft(index), idx_in_ball)
addall(tree, getright(index), idx_in_ball)
addall(tree, getleft(index), idx_in_ball, skip)
addall(tree, getright(index), idx_in_ball, skip)
end
end
13 changes: 13 additions & 0 deletions test/test_inrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,16 @@
end
end
end

@testset "inrange skip" begin
@testset "tree type" for TreeType in trees_with_brute
data = rand(2, 1000)
tree = TreeType(data)
id = 123

idxs = inrange(tree, data[:, id], 2, true)
@test id in idxs
idxs = inrange(tree, data[:, id], 2, true, i -> i == id)
@test !(id in idxs)
end
end