Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build ManellicTree #245

Merged
merged 5 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
KernelDensityEstimate = "2472808a-b354-52ea-a80e-1658a3c6056d"
LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
Expand Down Expand Up @@ -60,10 +61,11 @@ ApproxManiProdGadflyExt = "Gadfly"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BSON", "DataFrames", "Rotations", "Test"]
test = ["BSON", "DataFrames", "Rotations", "TensorCast", "Test"]

[weakdeps]
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
3 changes: 2 additions & 1 deletion src/ApproxManifoldProducts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ using TensorCast
using StaticArrays
using Logging
using Statistics
using Distributions

import Random: rand

import Base: *, isapprox, convert
import Base: *, isapprox, convert, show
import LinearAlgebra: rotate!
import Statistics: mean, std, cov, var
import KernelDensityEstimate: getPoints, getBW
Expand Down
181 changes: 169 additions & 12 deletions src/services/ManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,69 @@



# """
# $TYPEDEF

# Ellipse descriptor pairing a base point with a D-by-D Float64 covariance of coordinates.
# NOTE(review): the fourth SMatrix parameter is left open as `<:Integer`, so `coord_cov`
# is not a concrete field type; that parameter is normally the element count (a value),
# so this bound may not match any concrete SMatrix -- confirm before relying on it.
struct HyperEllipse{P,D}
""" manifold point at which this ellipse is based """
point::P
""" Covariance of coords at either TBD this point or some other reference point? """
coord_cov::SMatrix{D,D,Float64,<:Integer}
end
# Elliptical structure for use in a (Manellic) Ball Tree.
# """
# struct HyperEllipse{P <:AbstractArray,D,DD}
# """ manifold point at which this ellipse is based """
# point::P
# """ Covariance of coords at either TBD this point or some other reference point? """
# coord_cov::SMatrix{D,D,Float64,DD}
# end

# ManellicTree

# Short for Manifold Ellipse Metric Tree
# starting as a balanced tree, relax to unbalanced in future.
# Storage for a Manellic tree over `N` entries.
# NOTE(review): this span carries two struct headers and two `data` fields (an
# {M,N,P} layout with `hyper_ellipse`, and an {M,D,N,HL,HT} layout with separate
# leaf/tree kernels); only one layout can be kept -- confirm which is intended.
struct ManellicTree{M,N,P}
struct ManellicTree{M,D<:AbstractVector,N,HL,HT}
# manifold handed to splitPointsEigen while the tree is built
manifold::M
data::MVector{N,P}
hyper_ellipse::MVector{N,<:HyperEllipse}
# the points themselves, indexed through `permute`
data::D
# index permutation into `data`; reordered in place as the tree is built
permute::MVector{N,Int}
# kernels stored for node indices at or beyond length(tree_kernels)
leaf_kernels::MVector{N,HL}
# kernels stored for node indices below length(tree_kernels)
tree_kernels::MVector{N,HT}
# NOTE(review): left_idx / right_idx are not written by the tree builder visible here
left_idx::MVector{N,Int}
right_idx::MVector{N,Int}
end


# Print a colourised, multi-line summary of a ManellicTree: the type parameters
# (one per line), then the first data point and permutation entry, then the first
# tree kernel. Nothing beyond the first entries is printed.
function Base.show(io::IO, mt::ManellicTree{M,D,N,HL,HT}) where {M,D,N,HL,HT}
    # header: type name followed by one type parameter per line
    printstyled(io, "ManellicTree{"; bold=true, color=:blue)
    println(io)
    params = (
        (" M = ", M),
        (" D = ", D),
        (" N = ", N),
        (" HL = ", HL),
        (" HT = ", HT),
    )
    for (label, value) in params
        printstyled(io, label, value; color=:magenta)
        println(io)
    end
    printstyled(io, "}"; bold=true, color=:blue)
    println(io, "(")

    # internal invariant: the size parameter must agree with the stored data
    @assert N == length(mt.data) "show(::ManellicTree,) noticed a data size issue, expecting N$(N) == length(.data)$(length(mt.data))"

    nonempty = 0 < N
    if nonempty
        println(io, " .data[1:]: ", mt.data[1], ", ...")
        println(io, " .permute[1:]: ", mt.permute[1], ", ...")
        printstyled(io, " .tkernels[1]: ", " __see below__"; color=:light_black)
        println(io)
        println(io, " ...,")
    end
    println(io, ")")

    # the first tree kernel is printed after the closing parenthesis
    if nonempty
        printstyled(io, " .tkernels[1] = "; color=:light_black)
        println(io, mt.tree_kernels[1])
    end
    # TODO add more stats: max depth, widest point, longest chain, max clique size, average nr children

    return nothing
end

# The text/plain display forwards to the two-argument `show(io, mt)`.
function Base.show(io::IO, ::MIME"text/plain", mt::ManellicTree)
    return show(io, mt)
end


# covariance
function eigenCoords(
f_CVp::AbstractMatrix
Expand Down Expand Up @@ -50,16 +93,33 @@ end
# give vector of manifold points and split along largest covariance (i.e. major direction)
function splitPointsEigen(
M::AbstractManifold,
r_PP::AbstractVector{P}
r_PP::AbstractVector{P};
kernel = MvNormal,
kernel_bw = nothing,
) where {P <: AbstractArray}
#
len = length(r_PP)

# do calculations around mean point on manifold, i.e. support Riemannian
p = mean(M, r_PP)
cv = Manifolds.cov(M, r_PP)

r_XXp = log.(Ref(M), Ref(p), r_PP)
r_CCp = vee.(Ref(M), Ref(p), r_XXp)

D = manifold_dimension(M)
# FIXME, consider user provided bandwidth in estimating multisample covariance
cv = if 5 < len
SMatrix{D,D,Float64}(Manifolds.cov(M, r_PP))
elseif 1 < len < 5
SMatrix{D,D,Float64}(diagm(diag(Manifolds.cov(M, r_PP))))
else
# TODO case with user defined bandwidth for faster tree construction
bw = isnothing(kernel_bw) ? SMatrix{D,D,Float64}(diagm(eps(Float64)*ones(D))) : kernel_bw
return r_CCp, BitVector(ntuple(i->true,Val(len))), kernel(p, bw)
end
# S = SymmetricPositiveDefinite(2)
# @info "COV" cv LinearAlgebra.isposdef(cv) Manifolds.check_point(S,cv) len

# expecting largest variation on coord dimension `pidx[end]`
r_R_ax, Λ, pidx = eigenCoords(cv)
ax_R_r = r_R_ax'
Expand Down Expand Up @@ -103,5 +163,102 @@ function splitPointsEigen(
_flipmask_minormax!(mask, imask, ax_CC1; argminmax=argmax)

# return rotated coordinates and split mask
ax_CCp, mask
ax_CCp, mask, kernel(p, cv)
end


# Index of the left child of node `i` (twice the parent index).
_getleft(i::Integer) = 2i
# Index of the right child of node `i` (twice the parent index, plus one).
# The previous body was `2*1 + 1`, which ignored `i` and returned 3 for every node.
_getright(i::Integer) = 2*i + 1


"""
    buildTree_Manellic!(mtree, index, low, high; kernel=MvNormal, kernel_bw=nothing, leaf_size=1)

Recursively fill `mtree` for the segment `low:high` of `mtree.permute`, treating that
segment as the node numbered `index`.

The segment's points are split with `splitPointsEigen`. The returned kernel is stored in
`mtree.tree_kernels[index]` when `index < length(mtree.tree_kernels)`, and otherwise in
`mtree.leaf_kernels[index - N + 1]`, where `N = length(mtree.tree_kernels)`. The segment
of `mtree.permute` is reordered in place so that entries selected by the split mask come
first. While the segment holds more than `leaf_size` points, both halves are processed
recursively under the child indices `_getleft(index)` and `_getright(index)`.

# Keywords
- `kernel`, `kernel_bw`: forwarded unchanged to `splitPointsEigen`.
- `leaf_size`: segments with at most this many points are not split further.

# Returns
The same `mtree`, mutated.
"""
function buildTree_Manellic!(
    mtree::ManellicTree,
    index::Integer,
    low::Integer, # bottom index of segment
    high::Integer; # top index of segment
    kernel = MvNormal,
    kernel_bw = nothing,
    leaf_size = 1
)
    M = mtree.manifold
    # slice of the current index permutation; this is a view, so the reordering
    # below lands directly in mtree.permute (i.e. data is sorted as the tree is built)
    ido = view(mtree.permute, low:high)
    # split the slice of order-permuted data (the rotated coordinates are not needed here)
    _, mask, knl = splitPointsEigen(M, view(mtree.data, ido); kernel, kernel_bw)

    # store the kernel for this node
    # FIXME, replace with just the kernel choice, not hyper such and such needed?
    N = length(mtree.tree_kernels)
    if index < N
        mtree.tree_kernels[index] = knl
    else
        mtree.leaf_kernels[index-N+1] = knl
    end

    # reorder the permutation slice: masked ('small') entries first, the rest after.
    # vcat materialises a copy before the assignment, so reading from views of `ido`
    # is safe; it also avoids splatting a data-sized collection into a static-array
    # literal, which would compile a new specialisation for every segment length.
    ido .= vcat(view(ido, mask), view(ido, .!mask))

    npts = high - low + 1
    mid_idx = low + sum(mask)

    if leaf_size < npts
        # left subtree takes the masked entries
        buildTree_Manellic!(
            mtree, _getleft(index), low, mid_idx - 1; kernel, kernel_bw, leaf_size
        )
        # right subtree takes the remainder
        buildTree_Manellic!(
            mtree, _getright(index), mid_idx, high; kernel, kernel_bw, leaf_size
        )
    end

    return mtree
end



"""
    buildTree_Manellic!(M, r_PP; kernel=MvNormal, kernel_bw=nothing)

Allocate a `ManellicTree` for the points `r_PP` on manifold `M` and populate it by
calling the recursive `buildTree_Manellic!(mtree, index, low, high)` method from the
root (index 1) over the full range `1:length(r_PP)`.

The kernel element types of the tree are found by constructing sample kernels at
`r_PP[1]`: tree kernels use an identity covariance of size `manifold_dimension(M)`,
leaf kernels use `kernel_bw` when it is given and the same identity otherwise.
"""
function buildTree_Manellic!(
    M::AbstractManifold,
    r_PP::AbstractVector{P}; # vector of points referenced to the r_frame
    kernel = MvNormal,
    kernel_bw = nothing # TODO
) where {P <: AbstractArray}
    npts = length(r_PP)
    dim = manifold_dimension(M)
    unit_cov = SMatrix{dim,dim,Float64,dim*dim}(diagm(ones(dim)))
    leaf_cov = isnothing(kernel_bw) ? unit_cov : kernel_bw

    # sample kernels are built only to learn their concrete types
    TreeKernelT = typeof(kernel(r_PP[1], unit_cov))
    LeafKernelT = typeof(kernel(r_PP[1], leaf_cov))

    permutation = MVector{npts,Int}(1:npts)
    leaf_store = MVector{npts,LeafKernelT}(undef)
    tree_store = MVector{npts,TreeKernelT}(undef)
    left_store = MVector{npts,Int}(undef)
    right_store = MVector{npts,Int}(undef)

    mtree = ManellicTree(
        M, r_PP, permutation, leaf_store, tree_store, left_store, right_store
    )

    # start at the root, spanning all of the data
    return buildTree_Manellic!(mtree, 1, 1, npts; kernel=kernel, kernel_bw=kernel_bw)
end
45 changes: 40 additions & 5 deletions test/manellic/testManellicTree.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@

# using Revise
using Test
using ApproxManifoldProducts
using StaticArrays
import ApproxManifoldProducts: HyperEllipse, ManellicTree, eigenCoords, splitPointsEigen
using TensorCast
using Manifolds
using Distributions
import ApproxManifoldProducts: ManellicTree, eigenCoords, splitPointsEigen

##

Expand All @@ -20,7 +24,7 @@ function testEigenCoords(
r_R_ax*ax_C + SA[10;-100]
end
r_CV = Manifolds.cov(M, r_CC)
r_R_ax_, L, pidx = eigenCoords(r_CV)
r_R_ax_, L, pidx = ApproxManifoldProducts.eigenCoords(r_CV)

# spot check
@show _ax_ERR = log_lie(SpecialOrthogonal(2), (r_R_ax_')*r_R_ax)[1,2]
Expand All @@ -35,11 +39,15 @@ end
##

# Split test on a 2-dimensional translation group.
M = TranslationGroup(2)
# NOTE(review): the next two statements are repeated right after with `α`
# (two-value and three-value destructuring of the same call) -- confirm both are intended.
r_CC, R, pidx, r_CV = testEigenCoords(pi/3);
ax_CCp, mask = splitPointsEigen(M, r_CC)
α = pi/3
r_CC, R, pidx, r_CV = testEigenCoords(α);
ax_CCp, mask, knl = splitPointsEigen(M, r_CC)
# the split mask is expected to select half of the points
@test sum(mask) == (length(r_CC) ÷ 2)
# with default keywords the returned kernel is an MvNormal
@test knl isa MvNormal
Mr = SpecialOrthogonal(2)
# the angle coordinate of R is expected to match α to within 0.1
@test isapprox( α, vee(Mr, Identity(Mr), log_lie(Mr, R))[1] ; atol=0.1)

# using GLMakie
#
# fig = Figure()
# ax = Axis(fig[1,1])
# ptsl = ax_CCp[mask]
Expand All @@ -53,9 +61,36 @@ ax_CCp, mask = splitPointsEigen(M, r_CC)
# plot!(ax, (s->s[1]).(ptsr), (s->s[2]).(ptsr), color=:red)
# fig

## ensure that a view of a view writes through to the original memory

A = randn(3)
A_ = @view A[1:2]
A__ = @view A_[1:1]
# assigning through the innermost view must change the parent array itself
A__[1] = -100
@test A[1] ≈ -100 atol=1e-10

##

# Build a tree from the same points used for the split test above.
r_PP = r_CC # shortcut because we are in Euclidean space
mtree = ApproxManifoldProducts.buildTree_Manellic!(M, r_PP)


##

# stack the points into a matrix: one row per point (i), one column per coordinate (d)
@cast pts[i,d] := r_PP[i][d]

# first and second half of the points in tree-permuted order
# NOTE(review): the hard-coded 1:50 / 51:100 ranges assume 100 points -- confirm against testEigenCoords
ptsl = pts[mtree.permute[1:50],:]
ptsr = pts[mtree.permute[51:100],:]

##

# fig = Figure()
# ax = Axis(fig[1,1])

# plot!(ax, ptsl[:,1], ptsl[:,2], color=:blue)
# plot!(ax, ptsr[:,1], ptsr[:,2], color=:red)

# fig

##
end
Expand Down