Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made the result idxs and dists datatypes selectable #161

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ end

function _knn(tree::BallTree,
point::AbstractVector,
best_idxs::AbstractVector{Int},
best_idxs::AbstractVector{T},
best_dists::AbstractVector,
skip::F) where {F}
skip::F) where {F, T <: Integer}
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip)
return
end
Expand All @@ -161,9 +161,9 @@ end
function knn_kernel!(tree::BallTree{V},
index::Int,
point::AbstractArray,
best_idxs::AbstractVector{Int},
best_idxs::AbstractVector{T},
best_dists::AbstractVector,
skip::F) where {V, F}
skip::F) where {V, F, T <: Integer}
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip)
return
Expand Down Expand Up @@ -194,7 +194,7 @@ end
function _inrange(tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{Int}}) where {V}
idx_in_ball::Union{Nothing, Vector{T}}) where {V, T <: Integer}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
end
Expand Down
8 changes: 4 additions & 4 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ end

function _knn(tree::BruteTree{V},
point::AbstractVector,
best_idxs::AbstractVector{Int},
best_idxs::AbstractVector{T},
best_dists::AbstractVector,
skip::F) where {V, F}
skip::F) where {V, F, T <: Integer}

knn_kernel!(tree, point, best_idxs, best_dists, skip)
return
end

function knn_kernel!(tree::BruteTree{V},
point::AbstractVector,
best_idxs::AbstractVector{Int},
best_idxs::AbstractVector{T},
best_dists::AbstractVector,
skip::F) where {V, F}
skip::F) where {V, F, T <: Integer}
for i in 1:length(tree.data)
if skip(i)
continue
Expand Down
8 changes: 4 additions & 4 deletions src/kd_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ end

function _knn(tree::KDTree,
point::AbstractVector,
best_idxs::AbstractVector{Int},
best_idxs::AbstractVector{T},
best_dists::AbstractVector,
skip::F) where {F}
skip::F) where {F, T <: Integer}
init_min = get_min_distance(tree.hyper_rec, point)
knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, skip)
@simd for i in eachindex(best_dists)
Expand All @@ -171,10 +171,10 @@ end
function knn_kernel!(tree::KDTree{V},
index::Int,
point::AbstractVector,
best_idxs::AbstractVector{Int},
best_idxs::AbstractVector{T},
best_dists::AbstractVector,
min_dist,
skip::F) where {V, F}
skip::F) where {V, F, T <: Integer}
# At a leaf node. Go through all points in node and add those in range
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip)
Expand Down
10 changes: 5 additions & 5 deletions src/knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@ in the order of increasing distance to the point. `skip` is an optional predicat
to determine if a point that would be returned should be skipped based on its
index.
"""
function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F<:Function}
function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false; idxs_type::DataType = Int, dists_type = get_T(eltype(V))) where {V, T <: AbstractVector, F<:Function}
check_input(tree, points)
check_k(tree, k)
n_points = length(points)
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
dists = [Vector{dists_type}(undef, k) for _ in 1:n_points]
idxs = [Vector{idxs_type}(undef, k) for _ in 1:n_points]
for i in 1:n_points
knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip)
end
return idxs, dists
end

function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F}
fill!(idx, -1)
fill!(idx, 0)
fill!(dist, typemax(get_T(eltype(V))))
_knn(tree, point, idx, dist, skip)
if skip !== always_false
skipped_idxs = findall(==(-1), idx)
skipped_idxs = findall(==(0), idx)
deleteat!(idx, skipped_idxs)
deleteat!(dist, skipped_idxs)
end
Expand Down
4 changes: 2 additions & 2 deletions src/tree_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ end

# Checks the distance function and add those points that are among the k best.
# Uses a heap for fast insertion.
@inline function add_points_knn!(best_dists::AbstractVector, best_idxs::AbstractVector{Int},
@inline function add_points_knn!(best_dists::AbstractVector, best_idxs::AbstractVector{T},
tree::NNTree, index::Int, point::AbstractVector,
do_end::Bool, skip::F) where {F}
do_end::Bool, skip::F) where {F, T <: Integer}
for z in get_leaf_range(tree.tree_data, index)
idx = tree.reordered ? z : tree.indices[z]
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
Expand Down
7 changes: 4 additions & 3 deletions test/test_knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ import Distances.evaluate
@test idxs[1] == 8
@test idxs[2] == 3

idxs, dists = knn(tree, [SVector{2, Float64}(0.8,0.8), SVector{2, Float64}(0.1,0.8)], 1, true)
@test idxs[1][1] == 8
@test idxs[2][1] == 3
idxs, dists = knn(tree, [SVector{2, Float64}(0.8,0.8), SVector{2, Float64}(0.1,0.8)], 1, true; idxs_type = UInt32, dists_type = Float16)
@test typeof(idxs[1][1]) == UInt32
@test typeof(dists[2][1]) == Float16

idxs, dists = nn(tree, [SVector{2, Float64}(0.8,0.8), SVector{2, Float64}(0.1,0.8)])
@test idxs[1] == 8
Expand Down Expand Up @@ -91,3 +91,4 @@ end
@test nearest == [1, 3]
@test distance ≈ [0.02239688629947563, 0.13440059522389006]
end