In [4]:
using KernelAbstractions, CUDA
using Test

In [10]:
@kernel function knn_kernel(
        @Const(points1),
        @Const(points2),
        @Const(lengths1),
        @Const(lengths2),
        @Const(D),
        @Const(K),
        dists,
        idxs)

    p1, n = @index(Global, NTuple)

    # MinK parameters
    size = 0                    # Size of mink (when less than K)
    max_key = dists[1, p1, n]   # Placeholder for max_key
    max_idx = idxs[1, p1, n]    # Placeholder for max_idx

    # Runs in current thread
    for p2 = 1:lengths2[n]
        dist = eltype(dists)(0)
        for d = 1:D
            dist += (points1[d, p1, n] - points2[d, p2, n])^2
        end
        # Add (dist, p2) to MinK data structure
        if size < K           # Runtime: O(1)
            dists[size + 1, p1, n] = dist
            idxs[size + 1, p1, n] = p2
            if size == 0 || dist > max_key
                max_key = dist
                max_idx = size + 1
            end
            size += 1
        elseif dist < max_key       # Runtime: O(K)
            # Current key replaces old max
            dists[max_idx, p1, n] = dist
            idxs[max_idx, p1, n] = p2
            # Find new max from all dists
            max_key, max_idx = dist, -1
            for i = 1:K
                if dists[i, p1, n] ≥ max_key
                    max_key, max_idx = dists[i, p1, n], i
                end
            end
        end
    end
end


In [9]:
"""
Really simple nearest neighbor test

Check that the nearest neighbor to p1=1 in p2 = [-2, -1, 0, 1, 2]
is p1=1 at index 4 with distance 0
"""
function test_knn_kernel()
    D, P1, N = 1, 1, 1
    K, P2 = 1, 5

    p1 = reshape([1f0], (D, P1, N))
    p2 = reshape([-2f0, -1f0, 0f0, 1f0, 2f0], (D, P2, N))
    lengths1 = [P1]
    lengths2 = [P2]
    dists = zeros(Float32, (K, P1, N))
    idxs = zeros(Int64, (K, P1, N))

    kernel! = knn_kernel(CPU(), 1, (P1, N))
    event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
    wait(event)

    @test dists[1, 1, 1] == 0f0
    @test idxs[1, 1, 1] == 4
    @test p2[1, 4, 1] == 1f0
end
test_knn_kernel()

# function main()
#     N = 2
#     D = 1
#     P1 = 2
#     P2 = 5
#     K = 3

#     p1 = CUDA.randn(Float32, (D, P1, N))
#     p2 = CUDA.randn(Float32, (D, P2, N))
#     lengths1 = CUDA.fill(P1, (N,))
#     lengths2 = CUDA.fill(P2, (N,))


#     dists = CUDA.zeros(Float32, (K, P1, N))
#     idxs = CUDA.zeros(Int64, (K, P1, N))
# end
# main()

[32m[1mTest Passed[22m[39m
  Expression: p2[1, 4, 1] == 1.0f0
   Evaluated: 1.0f0 == 1.0f0