Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

manifolds v0.9 fixes, and wip Manellic tree eval #258

Merged
merged 4 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ CoordinateTransformations = "0.5, 0.6"
Distributions = "0.25"
DocStringExtensions = "0.7, 0.8, 0.9"
KernelDensityEstimate = "0.5.10"
Manifolds = "0.8, 0.9"
Manifolds = "0.9"
ManifoldsBase = "0.14, 0.15"
NLsolve = "3, 4"
Optim = "1"
Expand All @@ -53,7 +53,7 @@ Rotations = "1"
StaticArrays = "0.15, 1"
TensorCast = "0.2, 0.3, 0.4"
TransformUtils = "0.2.10"
julia = "1.7"
julia = "1.9"

[extensions]
ApproxManiProdGadflyExt = "Gadfly"
Expand Down
1 change: 1 addition & 0 deletions src/ApproxManifoldProducts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ include("services/ManifoldPartials.jl")
include("Interface.jl")

# regular features
include("services/KernelEval.jl")
include("CommonUtils.jl")
include("services/ManifoldKernelDensity.jl")
include("services/Euclidean.jl")
Expand Down
10 changes: 10 additions & 0 deletions src/Deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@

## ======================================================================================================
## Remove below before v0.10
## ======================================================================================================

# function setPointsMani!(dest::ProductRepr, src::ProductRepr)
# for (k,prt) in enumerate(dest.parts)
# setPointsMani!(prt, src.parts[k])
# end
# end

## ======================================================================================================
## Remove below before v0.8
## ======================================================================================================
Expand Down
2 changes: 2 additions & 0 deletions src/ExportAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ export calcMean
export mean, cov, std, var
export getInfoPerCoord, getBandwidth
export antimarginal

export mmd!, mmd
6 changes: 1 addition & 5 deletions src/Interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,7 @@ function setPointsMani!(dest::AbstractVector, src::AbstractVector{<:AbstractVect
setPointsMani!(dest, src[1])
end

function setPointsMani!(dest::ProductRepr, src::ProductRepr)
for (k,prt) in enumerate(dest.parts)
setPointsMani!(prt, src.parts[k])
end
end




Expand Down
11 changes: 0 additions & 11 deletions src/KernelHilbertEmbeddings.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@

# see: A Gretton, e.g. http://www.gatsby.ucl.ac.uk/~gretton/coursefiles/lecture4_introToRKHS.pdf

export
mmd!, # KED
mmd

"""
$SIGNATURES

Normal kernel used for Hilbert space embeddings.
"""
ker(M::MB.AbstractManifold, p, q, sigma::Real=0.001) = @fastmath exp( -sigma*(distance(M, p, q)^2) )

# overwrite non-symmetric with alternate implementations
# ker(M::MB.AbstractManifold, p, q, sigma::Real=0.001) = exp( -sigma*(distance(M, p, q)^2) )

Expand Down
57 changes: 57 additions & 0 deletions src/services/KernelEval.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

abstract type AbstractKernel end

@kwdef struct MvNormalKernel{T,M} <: AbstractKernel
p::MvNormal{T,M}
sqrt_iΣ::M = sqrt(inv(p.Σ))
end


Statistics.mean(m::MvNormalKernel) = m.p.μ

function distanceMalahanobisCoordinates(
M::AbstractManifold,
K::AbstractKernel,
q,
basis=DefaultOrthonormalBasis()
)
p = mean(K)
i_p = inv(M,p)
pq = Manifolds.compose(M, i_p, q)
X = log(M, p, pq)
Xc = get_coordinates(M, p, X, basis)
return K.sqrt_iΣ*Xc
end

function distanceMalahanobisSq(
M::AbstractManifold,
K::AbstractKernel,
q,
basis=DefaultOrthonormalBasis()
)
δc = distanceMalahanobisCoordinates(M,K,q,basis)
p = mean(K)
# ϵ = identity_element(M, q)
X = get_vector(M, p, δc, basis)
return inner(M, p, X, X)
end

# function distance(
# M::AbstractManifold,
# p::AbstractVector,
# q::AbstractVector,
# kernel=(_p) -> MvNormalKernel(
# p=MvNormal(_p,SVector(ones(manifold_dimension(M))...))
# ),
# distFnc::Function=distanceMalahanobisSq
# )
# distFnc(M, kernel(p), q)
# end

"""
$SIGNATURES

Normal kernel used for Hilbert space embeddings.
"""
ker(M::AbstractManifold, p, q, sigma::Real=0.001) = exp( -sigma*(distance(M, p, q)^2) )

29 changes: 25 additions & 4 deletions src/services/ManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
struct ManellicTree{M,D<:AbstractVector,N,HL,HT}
manifold::M
data::D
weights::MVector{N,<:Real}
permute::MVector{N,Int}
leaf_kernels::MVector{N,HL}
tree_kernels::MVector{N,HT}
Expand All @@ -46,8 +47,9 @@ function Base.show(io::IO, mt::ManellicTree{M,D,N,HL,HT}) where {M,D,N,HL,HT}
println(io, "(")
@assert N == length(mt.data) "show(::ManellicTree,) noticed a data size issue, expecting N$(N) == length(.data)$(length(mt.data))"
if 0 < N
println(io, " .data[1:]: ", mt.data[1], ", ...")
println(io, " .permute[1:]: ", mt.permute[1], ", ...")
println(io, " .data[1:]: ", mt.data[1], " ... ", mt.data[end])
println(io, " .weights[1:]: ", mt.weights[1], " ... ", mt.weights[end])
println(io, " .permute[1:]: ", mt.permute[1], " ... ", mt.permute[end])
printstyled(io, " .tkernels[1]: ", " __see below__"; color=:light_black)
println(io)
println(io, " ...,")
Expand Down Expand Up @@ -221,11 +223,12 @@ end
function buildTree_Manellic!(
M::AbstractManifold,
r_PP::AbstractVector{P}; # vector of points referenced to the r_frame
len = length(r_PP),
weights::AbstractVector{<:Real} = ones(len).*(1/len),
kernel = MvNormal,
kernel_bw = nothing # TODO
kernel_bw = nothing, # TODO
) where {P <: AbstractArray}
#
len = length(r_PP)
D = manifold_dimension(M)
CV = SMatrix{D,D,Float64,D*D}(diagm(ones(D)))
tknlT = kernel(
Expand All @@ -241,10 +244,13 @@ function buildTree_Manellic!(
end
) |> typeof

# kernel scale

#
mtree = ManellicTree(
M,
r_PP,
MVector{len,Float64}(weights),
MVector{len,Int}(1:len),
MVector{len,lknlT}(undef),
MVector{len,tknlT}(undef),
Expand All @@ -261,4 +267,19 @@ function buildTree_Manellic!(
kernel,
kernel_bw
)
end

# TODO use geometric computing for faster evaluation
function evaluate(
mt::ManellicTree{M,D,N},
p,
) where {M,D,N}

sumval = 0.0
for i in 1:N
sumval += mt.weights[i] * ker(mt.manifold, mt.leaf_kernels[i], p, 0.5)
end



end
6 changes: 3 additions & 3 deletions src/services/ManifoldPartials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
export getManifoldPartial

# forcing ProductManifold to use ArrayPartition as accompanying representation
const _PartiableRepresentationProduct = Union{Nothing,<:ArrayPartition, <:ProductRepr}
const _PartiableRepresentationProduct = Union{Nothing,<:ArrayPartition}
# forcing ProductManifold to use ArrayPartition as accompanying representation
const _PartiableRepresentationFlat{T} = Union{Nothing,<:AbstractVector{T}}
# More general representation for Manifold Factors or Groups
Expand Down Expand Up @@ -39,7 +39,7 @@ function _getReprPartial( M::MB.AbstractManifold,
return ret
end

function _getReprPartial( M::Union{<:typeof(SpecialOrthogonal(2)), <:Rotations{2}},
function _getReprPartial( M::Union{<:typeof(SpecialOrthogonal(2)), <:Rotations{ManifoldsBase.TypeParameter{Tuple{2}}}},
repr::AbstractMatrix{T},
partial::AbstractVector{Int}, # total partial from user over all Factors
offset::Base.RefValue{Int}=Ref(0),
Expand Down Expand Up @@ -77,7 +77,7 @@ function getManifoldPartial(M::Circle,
return (M,repr)
end

function getManifoldPartial(M::Rotations{2},
function getManifoldPartial(M::Rotations{ManifoldsBase.TypeParameter{Tuple{2}}},
partial::AbstractVector{Int},
repr::_PartiableRepresentation=nothing,
offset::Base.RefValue{Int}=Ref(0);
Expand Down
11 changes: 11 additions & 0 deletions test/basic_se3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,20 @@ A = ManifoldKernelDensity(M, pts1)
B = ManifoldKernelDensity(M, pts2)
C = ManifoldKernelDensity(M, pts3)


@test 0.75 < AMP.ker(M, pts1[1], pts1[2], 0.001) < 1.25
@test 0.75 < AMP.ker(M, pts2[1], pts2[2], 0.001) < 1.25
@test 0.75 < AMP.ker(M, pts3[1], pts3[2], 0.001) < 1.25

@test 0 < AMP.ker(M, pts1[1], pts3[1], 0.001) < 0.25
@test 0 < AMP.ker(M, pts1[1], pts3[2], 0.001) < 0.25
@test 0 < AMP.ker(M, pts1[2], pts3[1], 0.001) < 0.25


@test isapprox(A, B)
@test !isapprox(A, C)


##

show(A)
Expand Down
5 changes: 2 additions & 3 deletions test/testManiProductBigSmall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ p = manikde!(M, X1)
q = manikde!(M, X2)

# check new MKD have right type info cached
@test (p._u0 |> typeof) == typeof(u0)
@test (p._u0 |> typeof) == typeof(u0)
@test_broken (p._u0 |> typeof) == typeof(u0)

pq = manifoldProduct([p;q], M)

# check new product also has right point type info cached
@test (pq._u0 |> typeof) == typeof(u0)
@test_broken (pq._u0 |> typeof) == typeof(u0)

##

Expand Down
Loading