Skip to content

Commit

Permalink
parallelizing BallTree construction
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianAment committed Jan 8, 2022
1 parent 33ccb17 commit f6acba9
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.3'
- '1'
- 'nightly'
os:
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
name = "NearestNeighbors"
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
version = "0.4.10"
version = "0.5.0"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Distances = "0.9, 0.10"
StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0"
julia = "1.0"
julia = "1.3"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 1 addition & 1 deletion src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Distances: Metric, result_type, eval_reduce, eval_end, eval_op, eval_star

using StaticArrays
import Base.show
using Base.Threads: @threads
using Base.Threads

export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
export knn, nn, inrange # TODOs? , allpairs, distmat, npairs
Expand Down
91 changes: 67 additions & 24 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@ struct BallTree{V <: AbstractVector,N,T,M <: Metric} <: NNTree{V,M}
reordered::Bool # If the data has been reordered
end

# When we create the bounding spheres we need some temporary arrays.
# We create a type to hold them to not allocate these arrays at every
# function call and to reduce the number of parameters in the tree builder.
struct ArrayBuffers{N,T <: AbstractFloat}
center::MVector{N,T}
end

function ArrayBuffers(::Type{Val{N}}, ::Type{T}) where {N, T}
ArrayBuffers(zeros(MVector{N,T}))
end

"""
BallTree(data [, metric = Euclidean(), leafsize = 10]) -> balltree
Expand All @@ -33,14 +22,14 @@ function BallTree(data::AbstractVector{V},
leafsize::Int = 10,
reorder::Bool = true,
storedata::Bool = true,
parallel::Bool = true,
reorderbuffer::Vector{V} = Vector{V}()) where {V <: AbstractArray, M <: Metric}
reorder = !isempty(reorderbuffer) || (storedata ? reorder : false)

tree_data = TreeData(data, leafsize)
n_d = length(V)
n_p = length(data)

array_buffs = ArrayBuffers(Val{length(V)}, get_T(eltype(V)))
indices = collect(1:n_p)

# Bottom up creation of hyper spheres so need spheres even for leafs)
Expand Down Expand Up @@ -70,7 +59,8 @@ function BallTree(data::AbstractVector{V},
if n_p > 0
# Call the recursive BallTree builder
build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
1, length(data), tree_data, array_buffs, reorder)
1, length(data), tree_data, reorder, Val(parallel))

end

if reorder
Expand All @@ -86,6 +76,7 @@ function BallTree(data::AbstractVecOrMat{T},
leafsize::Int = 10,
storedata::Bool = true,
reorder::Bool = true,
parallel::Bool = true,
reorderbuffer::Matrix{T} = Matrix{T}(undef, 0, 0)) where {T <: AbstractFloat, M <: Metric}
dim = size(data, 1)
npoints = size(data, 2)
Expand All @@ -96,7 +87,7 @@ function BallTree(data::AbstractVecOrMat{T},
reorderbuffer_points = copy_svec(T, reorderbuffer, Val(dim))
end
BallTree(points, metric, leafsize = leafsize, storedata = storedata, reorder = reorder,
reorderbuffer = reorderbuffer_points)
parallel = parallel, reorderbuffer = reorderbuffer_points)
end

# Recursive function to build the tree.
Expand All @@ -110,16 +101,16 @@ function build_BallTree(index::Int,
low::Int,
high::Int,
tree_data::TreeData,
array_buffs::ArrayBuffers{N,T},
reorder::Bool) where {V <: AbstractVector, N, T}
reorder::Bool,
parallel::Val{false}) where {V <: AbstractVector, N, T}

n_points = high - low + 1 # Points left
if n_points <= tree_data.leafsize
if reorder
reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data)
end
# Create bounding sphere of points in leaf node by brute force
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high, array_buffs)
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high)
return
end

Expand All @@ -132,22 +123,74 @@ function build_BallTree(index::Int,

# Sort the data at the mid_idx boundary using the split_dim
# to compare
select_spec!(indices, mid_idx, low, high, data, split_dim)
select_spec!(indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads

build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, low, mid_idx - 1,
tree_data, array_buffs, reorder)
indices, indices_reordered, low, mid_idx - 1,
tree_data, reorder, parallel)

build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, mid_idx, high,
tree_data, array_buffs, reorder)
indices, indices_reordered, mid_idx, high,
tree_data, reorder, parallel)

# Finally create bounding hyper sphere from the two children's hyper spheres
hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)],
hyper_spheres[getright(index)],
array_buffs)
hyper_spheres[getright(index)])
return
end

# Parallelized recursive function to build the tree.
function build_BallTree(index::Int,
data::Vector{V},
data_reordered::Vector{V},
hyper_spheres::Vector{HyperSphere{N,T}},
metric::Metric,
indices::Vector{Int},
indices_reordered::Vector{Int},
low::Int,
high::Int,
tree_data::TreeData,
reorder::Bool,
parallel::Val{true}) where {V <: AbstractVector, N, T}

n_points = high - low + 1 # Points left
if n_points <= tree_data.leafsize
if reorder
reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data)
end
# Create bounding sphere of points in leaf node by brute force
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high)
return
end

# Find split such that one of the sub trees has 2^p points
# and the left sub tree has more points
mid_idx = find_split(low, tree_data.leafsize, n_points)

# Brute force to find the dimension with the largest spread
split_dim = find_largest_spread(data, indices, low, high)

# Sort the data at the mid_idx boundary using the split_dim
# to compare
select_spec!(indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads

@sync begin
@spawn build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, low, mid_idx - 1,
tree_data, reorder, parallel)

@spawn build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, mid_idx, high,
tree_data, reorder, parallel)
end

# Finally create bounding hyper sphere from the two children's hyper spheres
hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)],
hyper_spheres[getright(index)])
return
end


function _knn(tree::BallTree,
point::AbstractVector,
best_idxs::AbstractVector{Int},
Expand Down
64 changes: 22 additions & 42 deletions src/hyperspheres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ end

HyperSphere(center::SVector{N,T1}, r::T2) where {N, T1, T2} = HyperSphere(center, convert(T1, r))

Base.:(==)(A::HyperSphere, B::HyperSphere) = A.center == B.center && A.r == B.r

@inline function intersects(m::M,
s1::HyperSphere{N,T},
s2::HyperSphere{N,T}) where {T <: AbstractFloat, N, M <: Metric}
Expand All @@ -19,55 +21,22 @@ end
evaluate(m, s1.center, s2.center) + s1.r <= s2.r
end

@inline function interpolate(::M,
c1::V,
c2::V,
x,
d,
ab) where {V <: AbstractVector, M <: NormMetric}
alpha = x / d
@assert length(c1) == length(c2)
@inbounds for i in eachindex(ab.center)
ab.center[i] = (1 - alpha) .* c1[i] + alpha .* c2[i]
end
return ab.center, true
end

@inline function interpolate(::M,
c1::V,
::V,
::Any,
::Any,
::Any) where {V <: AbstractVector, M <: Metric}
return c1, false
end

function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high, ab) where {V}
n_dim = size(data, 1)
n_points = high - low + 1
# First find center of all points
fill!(ab.center, 0.0)
for i in low:high
for j in 1:length(ab.center)
ab.center[j] += data[indices[i]][j]
end
end
ab.center .*= 1 / n_points

# versions with no array buffer - still not allocating in sequential BallTree construction
using Statistics: mean
function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high) where {V}
# find center
center = mean(@views(data[indices[low:high]]))
# Then find r
r = zero(get_T(eltype(V)))
for i in low:high
r = max(r, evaluate(metric, data[indices[i]], ab.center))
r = max(r, evaluate(metric, data[indices[i]], center))
end
r += eps(get_T(eltype(V)))
return HyperSphere(SVector{length(V),eltype(V)}(ab.center), r)
return HyperSphere(SVector{length(V),eltype(V)}(center), r)
end

# Creates a bounding sphere from two other spheres
function create_bsphere(m::Metric,
s1::HyperSphere{N,T},
s2::HyperSphere{N,T},
ab) where {N, T <: AbstractFloat}
function create_bsphere(m::Metric, s1::HyperSphere{N,T}, s2::HyperSphere{N,T}) where {N, T <: AbstractFloat}
if encloses(m, s1, s2)
return HyperSphere(s2.center, s2.r)
elseif encloses(m, s2, s1)
Expand All @@ -79,7 +48,7 @@ function create_bsphere(m::Metric,
# neither s1 nor s2 contains the other)
dist = evaluate(m, s1.center, s2.center)
x = 0.5 * (s2.r - s1.r + dist)
center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist, ab)
center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist)
if is_exact_center
rad = 0.5 * (s2.r + s1.r + dist)
else
Expand All @@ -88,3 +57,14 @@ function create_bsphere(m::Metric,

return HyperSphere(SVector{N,T}(center), rad)
end

@inline function interpolate(::M, c1::V, c2::V, x, d) where {V <: AbstractVector, M <: NormMetric}
length(c1) == length(c2) || throw(DimensionMismatch("interpolate arguments have length $(length(c1)) and $(length(c2))"))
alpha = x / d
center = (1 - alpha) * c1 + alpha * c2
return center, true
end

@inline function interpolate(::M, c1::V, ::V, ::Any, ::Any) where {V <: AbstractVector, M <: Metric}
return c1, false
end
5 changes: 2 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ using LinearAlgebra

using Distances: Distances, Metric, evaluate, PeriodicEuclidean
struct CustomMetric1 <: Metric end
Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs.(a .- b))
Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs, (a .- b))
function NearestNeighbors.interpolate(::CustomMetric1,
a::V,
b::V,
x,
d,
ab) where {V <: AbstractVector}
d) where {V <: AbstractVector}
idx = (abs.(b .- a) .>= d - x)
c = copy(Array(a))
c[idx] = (1 - x / d) * a[idx] + (x / d) * b[idx]
Expand Down
28 changes: 20 additions & 8 deletions test/test_monkey.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import NearestNeighbors.MinkowskiMetric
# This contains a bunch of random tests that should hopefully detect if
# some edge case has been missed in the real tests


@testset "metric $metric" for metric in fullmetrics
nrep = 30
@testset "tree type $TreeType" for TreeType in trees_with_brute
@testset "element type $T" for T in (Float32, Float64)
@testset "knn monkey" begin
Expand All @@ -14,7 +13,7 @@ import NearestNeighbors.MinkowskiMetric
elseif TreeType == BallTree && isa(metric, Hamming)
continue
end
for i in 1:30
for i in 1:nrep
dim_data = rand(1:4)
size_data = rand(1000:1300)
data = rand(T, dim_data, size_data)
Expand All @@ -28,7 +27,7 @@ import NearestNeighbors.MinkowskiMetric
end

# Compares vs Brute Force
for i in 1:30
for i in 1:nrep
dim_data = rand(1:5)
size_data = rand(100:151)
data = rand(T, dim_data, size_data)
Expand All @@ -45,7 +44,7 @@ import NearestNeighbors.MinkowskiMetric

@testset "inrange monkey" begin
# Test against brute force
for i in 1:30
for i in 1:nrep
dim_data = rand(1:6)
size_data = rand(20:250)
data = rand(T, dim_data, size_data)
Expand All @@ -62,17 +61,30 @@ import NearestNeighbors.MinkowskiMetric
end

@testset "coupled monkey" begin
for i in 1:50
for i in 1:nrep
dim_data = rand(1:5)
size_data = rand(100:1000)
data = randn(T, dim_data, size_data)
tree = TreeType(data, metric; leafsize = rand(1:8))

lf = rand(1:8)
tree = TreeType(data, metric; leafsize = lf)

if TreeType == BallTree # this caught a race-condition in an early version of the parallel BallTree code
tree2 = TreeType(data, metric; leafsize = lf, parallel = false)
@test tree.data == tree2.data
@test tree.hyper_spheres[1] == tree2.hyper_spheres[1]
@test tree.indices == tree2.indices
@test tree.metric == tree2.metric
@test tree.tree_data == tree2.tree_data
@test tree.reordered == tree2.reordered
end

point = randn(dim_data)
idxs_ball = Int[]
r = 0.1
while length(idxs_ball) < 10
r *= 2.0
idxs_ball = inrange(tree, point, r, true)
idxs_ball = inrange(tree, point, r, true)
end
idxs_knn, dists = knn(tree, point, length(idxs_ball))

Expand Down

0 comments on commit f6acba9

Please sign in to comment.