Skip to content

Commit

Permalink
Using NearestNeigbors.jl nearestneighborsimplices
Browse files Browse the repository at this point in the history
Improves algorithmic complexity a lot when having many samples.
As part of this, the nearestneighborsimplices interface was simplified
to a single function.
  • Loading branch information
rasmushenningsson committed Mar 31, 2020
1 parent 4310312 commit 209c354
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 76 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ version = "0.2.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
NearestNeighbors = "0.4"
julia = "1.0.5"

[extras]
Expand Down
4 changes: 1 addition & 3 deletions src/PrincipalMomentAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module PrincipalMomentAnalysis
using LinearAlgebra
using SparseArrays
using Statistics
using NearestNeighbors

include("svd.jl")
include("simplices.jl")
Expand All @@ -16,9 +17,6 @@ export
groupsimplices,
timeseriessimplices,
neighborsimplices,
neighborsimplices2,
sparseneighborsimplices,
sparseneighborsimplices2,
normalizemean!,
normalizemean,
normalizemeanstd!,
Expand Down
99 changes: 27 additions & 72 deletions src/simplices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,37 +145,7 @@ function timeseriessimplices(time::AbstractVector; groupby::AbstractVector=false
SimplexGraph(G,w)
end

"""
neighborsimplices2(D2; k, r, symmetric, groupby)
Create simplex graph connecting nearest neighbors in given symmetric matrix where element `i,j` equals the squared distance between samples `i` and `j`.

# Inputs
* `D2`: Matrix of squared distances.
* `k`: Number of nearest neighbors to connect. Default: `0`.
* `r`: Connected all neighbors with disctance `≤r`. Default: `0.0`.
* `symmetric`: Make the simplex graph symmetric. Default: `false`.
* `groupby`: Only connected samples within the specified groups. Default: Disabled.
"""
function neighborsimplices2(D2::AbstractMatrix; k::Integer=0, r2::Real=0.0, symmetric=false, groupby=falses(size(D2,1)))
@assert issymmetric(D2)
@assert all(x->x>=0.0, D2)
N = size(D2,1)

uniqueGroups = unique(groupby)
groupInds = Dict( g=>findall(groupby.==g) for g in uniqueGroups )

G = falses(N,N)
for j=1:N
gInds = groupInds[groupby[j]]
ind = gInds[sortperm(D2[gInds,j])]
kk = max(k+1, searchsortedlast(D2[ind,j], r2)) # k+1 to include current node
G[ind[1:min(kk,length(ind))], j] .= true
end

symmetric && (G .|= G')
SimplexGraph(G)
end

"""
neighborsimplices(A::AbstractMatrix; k, r, dim, symmetric, normalizedist, groupby)
Expand Down Expand Up @@ -208,62 +178,47 @@ julia> sg = neighborsimplices([0 0 2 2; 0 1 1 0]; r=0.45); sg.G
1 0 1 1
```
"""
function neighborsimplices(A::AbstractMatrix; k::Integer=0, r::Real=0.0, dim::Integer=typemax(Int), normalizedist=true, kwargs...)
@assert r>=0
P,N = size(A)
r2 = r*r

function neighborsimplices(A::AbstractMatrix; dim::Integer=typemax(Int), kwargs...)
if dim<minimum(size(A))
F = svdbyeigen(A; nsv=dim)
A = Diagonal(F.S)*F.Vt
end
K = A'A

if r2>0 && normalizedist
r2 *= 4*maximum(sum(x->x^2, A, dims=1))
end

d = diag(K)
D2 = Symmetric(max.(0., d .+ d' .- 2K)) # matrix of squared distances
neighborsimplices2(D2,k=k,r2=r2;kwargs...)
_neighborsimplices(A; kwargs...)
end

function _neighborsimplices(A::AbstractMatrix; k::Integer=0, r::Real=0.0, symmetric=false, normalizedist=true, groupby=nothing)#falses(size(A,2)))
k==0 && r==0.0 && return SimplexGraph(sparse(I(size(A,2)))) # trivial case, avoid some computations.
r>0.0 && normalizedist && (r *= 2*sqrt(maximum(sum(x->x^2, A, dims=1))))

N = size(A,2)
G = falses(N,N)


function sparseneighborsimplices2(D2::Symmetric; k::Integer=0, r2::Real=0.0, symmetric=false)
@assert all(x->x>=0.0, D2)
N = size(D2,1)

I,J = Int[],Int[]
for j=1:N
ind = sortperm(D2[:,j])
kk = max(k+1, searchsortedlast(D2[ind,j], r2)) # k+1 to include current node
rows = ind[1:min(kk,length(ind))]
append!(I,rows)
append!(J,Iterators.repeated(j,length(rows)))
if groupby===nothing
_neighborsimplices!(G, A, identity, k, r)
else
for g in unique(groupby)
ind = findall( groupby.==g )
_neighborsimplices!(G, A[:,ind], x->ind[x], k, r)
end
end
G = sparse(I,J,trues(length(I)))

symmetric && (G .|= G')
SimplexGraph(G)
end
function sparseneighborsimplices(A::AbstractMatrix; k::Integer=0, r::Real=0.0, dim::Integer=typemax(Int), normalizedist=true, kwargs...)
@assert r>=0
P,N = size(A)
r2 = r*r

if dim<minimum(size(A))
F = svdbyeigen(A; nsv=dim)
A = Diagonal(F.S)*F.Vt
end
K = A'A
function _neighborsimplices!(G, A, indfun, k, r)
tree = BallTree(A)

if r2>0 && normalizedist
r2 *= 4*maximum(sum(x->x^2, A, dims=1))
if k>0
indices,_ = knn(tree, A, min(k+1,size(A,2)))
for (j,I) in enumerate(indices)
G[indfun.(I),indfun(j)] .= true
end
end

d = diag(K)
D2 = Symmetric(max.(0., d .+ d' .- 2K)) # matrix of squared distances
sparseneighborsimplices2(D2,k=k,r2=r2;kwargs...)
if r>0.0
indices = inrange(tree, A, r)
for (j,I) in enumerate(indices)
G[indfun.(I),indfun(j)] .= true
end
end
end
9 changes: 8 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ using Test
using LinearAlgebra
using Statistics

@test [] == detect_ambiguities(Base, Core, PrincipalMomentAnalysis)
# change back to this simple oneliner if we can just ignore problems in StaticArrays that is imported by NearestNeighbors
# @test [] == detect_ambiguities(Base, Core, PrincipalMomentAnalysis)
@testset "Method ambiguities" begin
ambiguities = detect_ambiguities(Base, Core, PrincipalMomentAnalysis)
filter!(x->x[1].module==PrincipalMomentAnalysis || x[2].module==PrincipalMomentAnalysis, ambiguities)
@test ambiguities == []
end


const simplices2kernelmatrix = PrincipalMomentAnalysis.simplices2kernelmatrix
const simplices2kernelmatrixroot = PrincipalMomentAnalysis.simplices2kernelmatrixroot
Expand Down

0 comments on commit 209c354

Please sign in to comment.