Skip to content

Commit b148fdd

Browse files
committed
use ArraysOfArrays for return value to reduce the number of allocated arrays
A long-standing gripe for me has been that indices and distances are returned as standard nested `Array`s. Typically, each inner array holds quite a small number of neighbors, which means that we allocate a large number of small arrays. Using ArraysOfArrays, these are stored contiguously in one large flat array instead. The difference in allocations can be readily seen (before and after this change, respectively):

```julia
julia> input = rand(3, 10^6);

julia> tree = KDTree(rand(3, 10^6));

julia> @time knn(tree, input, 5);
  1.538003 seconds (2.00 M allocations: 221.253 MiB, 10.03% gc time)

julia> @time knn(tree, input, 5);
  1.489310 seconds (98 allocations: 189.884 MiB, 0.29% gc time)
```
1 parent 0d814ad commit b148fdd

File tree

7 files changed

+55
-17
lines changed

7 files changed

+55
-17
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
33
version = "0.4.22"
44

55
[deps]
6+
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
67
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
78
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
89

910
[compat]
1011
Distances = "0.10.12"
1112
StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0"
1213
julia = "1.6"
14+
ArraysOfArrays = "0.6"
1315

1416
[extras]
1517
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

benchmark/Manifest.toml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
julia_version = "1.11.1"
3+
julia_version = "1.11.5"
44
manifest_format = "2.0"
55
project_hash = "c2d4f1e1a4db771bb121b0dd2aff4834a9af3804"
66

@@ -13,6 +13,22 @@ version = "0.4.5"
1313
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
1414
version = "1.1.2"
1515

16+
[[deps.ArraysOfArrays]]
17+
deps = ["Statistics"]
18+
git-tree-sha1 = "8e64c97ac7bffbd3327d8ddadf8dad26b87a2664"
19+
uuid = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
20+
version = "0.6.6"
21+
22+
[deps.ArraysOfArrays.extensions]
23+
ArraysOfArraysAdaptExt = "Adapt"
24+
ArraysOfArraysChainRulesCoreExt = "ChainRulesCore"
25+
ArraysOfArraysStaticArraysCoreExt = "StaticArraysCore"
26+
27+
[deps.ArraysOfArrays.weakdeps]
28+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
29+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
30+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
31+
1632
[[deps.Artifacts]]
1733
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
1834
version = "1.11.0"
@@ -167,10 +183,10 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
167183
version = "2023.12.12"
168184

169185
[[deps.NearestNeighbors]]
170-
deps = ["Distances", "StaticArrays"]
186+
deps = ["ArraysOfArrays", "Distances", "StaticArrays"]
171187
path = ".."
172188
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
173-
version = "0.4.21"
189+
version = "0.4.22"
174190

175191
[[deps.NetworkOptions]]
176192
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
@@ -184,7 +200,7 @@ version = "0.3.27+1"
184200
[[deps.OpenLibm_jll]]
185201
deps = ["Artifacts", "Libdl"]
186202
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
187-
version = "0.8.1+2"
203+
version = "0.8.5+0"
188204

189205
[[deps.OpenSpecFun_jll]]
190206
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]

src/NearestNeighbors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Distances
44
import Distances: PreMetric, Metric, UnionMinkowskiMetric, result_type, eval_reduce, eval_end, eval_op, eval_start, evaluate, parameters
55

66
using StaticArrays
7+
using ArraysOfArrays
78
import Base.show
89

910
export NNTree, BruteTree, KDTree, BallTree, DataFreeTree

src/inrange.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ function inrange(tree::NNTree,
2222
check_input(tree, points)
2323
check_radius(radius)
2424

25-
idxs = [Vector{Int}() for _ in 1:length(points)]
25+
idxs = VectorOfArrays{Int, 1}()
26+
idx = Int[]
2627

2728
for i in 1:length(points)
28-
inrange_point!(tree, points[i], radius, sortres, idxs[i])
29+
inrange_point!(tree, points[i], radius, sortres, idx)
30+
push!(idxs, idx)
31+
resize!(idx, 0)
2932
end
3033
return idxs
3134
end
@@ -79,11 +82,14 @@ function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Numb
7982
check_input(tree, points)
8083
check_radius(radius)
8184
n_points = size(points, 2)
82-
idxs = [Vector{Int}() for _ in 1:n_points]
85+
idxs = VectorOfArrays{Int, 1}()
86+
idx = Int[]
8387

8488
for i in 1:n_points
8589
point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim)))
86-
inrange_point!(tree, point, radius, sortres, idxs[i])
90+
inrange_point!(tree, point, radius, sortres, idx)
91+
push!(idxs, idx)
92+
resize!(idx, 0)
8793
end
8894
return idxs
8995
end

src/knn.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,17 @@ function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false,
2626
check_input(tree, points)
2727
check_k(tree, k)
2828
n_points = length(points)
29-
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
30-
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
29+
dists = VectorOfArrays{get_T(eltype(V)), 1}()
30+
idxs = VectorOfArrays{Int, 1}()
31+
dist = zeros(get_T(eltype(V)), k)
32+
idx = zeros(Int, k)
33+
3134
for i in 1:n_points
32-
knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip)
35+
knn_point!(tree, points[i], sortres, dist, idx, skip)
36+
push!(dists, dist)
37+
push!(idxs, idx)
38+
fill!(dist, 0)
39+
fill!(idx, 0)
3340
end
3441
return idxs, dists
3542
end
@@ -93,12 +100,18 @@ function knn_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, ::Val{di
93100
check_input(tree, points)
94101
check_k(tree, k)
95102
n_points = size(points, 2)
96-
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
97-
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
103+
dists = VectorOfArrays{Float64, 1}()
104+
idxs = VectorOfArrays{Int, 1}()
105+
dist = zeros(Float64, k)
106+
idx = zeros(Int, k)
98107

99108
for i in 1:n_points
100109
point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim)))
101-
knn_point!(tree, point, sortres, dists[i], idxs[i], skip)
110+
knn_point!(tree, point, sortres, dist, idx, skip)
111+
push!(dists, dist)
112+
push!(idxs, idx)
113+
fill!(dist, 0)
114+
fill!(idx, 0)
102115
end
103116
return idxs, dists
104117
end

test/test_inrange.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181
points = rand(SVector{3, Float64}, 100)
8282
kdtree = KDTree(points)
8383
idxs = inrange(kdtree, view(points, 1:10), 0.1)
84-
@test idxs isa Vector{Vector{Int}}
84+
@test eltype(idxs) <: AbstractVector{Int}
8585
end
8686

8787
@testset "mutating" begin

test/test_knn.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ end
133133
points = rand(SVector{3, Float64}, 100)
134134
kdtree = KDTree(points)
135135
idxs, dists = knn(kdtree, view(points, 1:10), 3)
136-
@test idxs isa Vector{Vector{Int}}
137-
@test dists isa Vector{Vector{Float64}}
136+
@test eltype(idxs) <: AbstractVector{Int}
137+
@test eltype(dists) <: AbstractVector{Float64}
138138
end
139139

140140
@testset "mutating" begin

0 commit comments

Comments
 (0)