From 22914f0d7a7bcba7e376c51cccc5c0337f6cc619 Mon Sep 17 00:00:00 2001 From: dehann Date: Sun, 22 Oct 2023 21:33:17 -0700 Subject: [PATCH 1/3] wip kernel distance --- Project.toml | 2 +- src/ApproxManifoldProducts.jl | 1 + src/ExportAPI.jl | 2 ++ src/KernelHilbertEmbeddings.jl | 11 ------- src/services/KernelEval.jl | 57 +++++++++++++++++++++++++++++++++ src/services/ManellicTree.jl | 29 ++++++++++++++--- test/testManiProductBigSmall.jl | 5 ++- 7 files changed, 88 insertions(+), 19 deletions(-) create mode 100644 src/services/KernelEval.jl diff --git a/Project.toml b/Project.toml index 843aa57..061babe 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ Rotations = "1" StaticArrays = "0.15, 1" TensorCast = "0.2, 0.3, 0.4" TransformUtils = "0.2.10" -julia = "1.7" +julia = "1.9" [extensions] ApproxManiProdGadflyExt = "Gadfly" diff --git a/src/ApproxManifoldProducts.jl b/src/ApproxManifoldProducts.jl index dd7f32c..48741cb 100644 --- a/src/ApproxManifoldProducts.jl +++ b/src/ApproxManifoldProducts.jl @@ -56,6 +56,7 @@ include("services/ManifoldPartials.jl") include("Interface.jl") # regular features +include("services/KernelEval.jl") include("CommonUtils.jl") include("services/ManifoldKernelDensity.jl") include("services/Euclidean.jl") diff --git a/src/ExportAPI.jl b/src/ExportAPI.jl index ea4a8df..2a992e4 100644 --- a/src/ExportAPI.jl +++ b/src/ExportAPI.jl @@ -29,3 +29,5 @@ export calcMean export mean, cov, std, var export getInfoPerCoord, getBandwidth export antimarginal + +export mmd!, mmd diff --git a/src/KernelHilbertEmbeddings.jl b/src/KernelHilbertEmbeddings.jl index 962650f..e11f4bb 100644 --- a/src/KernelHilbertEmbeddings.jl +++ b/src/KernelHilbertEmbeddings.jl @@ -1,17 +1,6 @@ # see: A Gretton, e.g. http://www.gatsby.ucl.ac.uk/~gretton/coursefiles/lecture4_introToRKHS.pdf -export - mmd!, # KED - mmd - -""" - $SIGNATURES - -Normal kernel used for Hilbert space embeddings. -""" -ker(M::MB.AbstractManifold, p, q, sigma::Real=0.001) = @fastmath exp( -sigma*(distance(M, p, q)^2) ) - # overwrite non-symmetric with alternate implementations # ker(M::MB.AbstractManifold, p, q, sigma::Real=0.001) = exp( -sigma*(distance(M, p, q)^2) ) diff --git a/src/services/KernelEval.jl b/src/services/KernelEval.jl new file mode 100644 index 0000000..3ecc4ec --- /dev/null +++ b/src/services/KernelEval.jl @@ -0,0 +1,57 @@ + +abstract type AbstractKernel end + +@kwdef struct MvNormalKernel{T,M} <: AbstractKernel + p::MvNormal{T,M} + sqrt_iΣ::M = sqrt(inv(p.Σ)) +end + + +Statistics.mean(m::MvNormalKernel) = m.p.μ + +function distanceMalahanobisCoordinates( + M::AbstractManifold, + K::AbstractKernel, + q, + basis=DefaultOrthonormalBasis() +) + p = mean(K) + i_p = inv(M,p) + pq = Manifolds.compose(M, i_p, q) + X = log(M, p, pq) + Xc = get_coordinates(M, p, X, basis) + return K.sqrt_iΣ*Xc +end + +function distanceMalahanobisSq( + M::AbstractManifold, + K::AbstractKernel, + q, + basis=DefaultOrthonormalBasis() +) + δc = distanceMalahanobisCoordinates(M,K,q,basis) + p = mean(K) + # ϵ = identity_element(M, q) + X = get_vector(M, p, δc, basis) + return inner(M, p, X, X) +end + +# function distance( +# M::AbstractManifold, +# p::AbstractVector, +# q::AbstractVector, +# kernel=(_p) -> MvNormalKernel( +# p=MvNormal(_p,SVector(ones(manifold_dimension(M))...)) +# ), +# distFnc::Function=distanceMalahanobisSq +# ) +# distFnc(M, kernel(p), q) +# end + +""" + $SIGNATURES + +Normal kernel used for Hilbert space embeddings. +""" +ker(M::AbstractManifold, p, q, sigma::Real=0.001) = exp( -sigma*(distance(M, p, q)) ) + diff --git a/src/services/ManellicTree.jl b/src/services/ManellicTree.jl index e4b5ed0..604041d 100644 --- a/src/services/ManellicTree.jl +++ b/src/services/ManellicTree.jl @@ -21,6 +21,7 @@ struct ManellicTree{M,D<:AbstractVector,N,HL,HT} manifold::M data::D + weights::MVector{N,<:Real} permute::MVector{N,Int} leaf_kernels::MVector{N,HL} tree_kernels::MVector{N,HT} @@ -46,8 +47,9 @@ function Base.show(io::IO, mt::ManellicTree{M,D,N,HL,HT}) where {M,D,N,HL,HT} println(io, "(") @assert N == length(mt.data) "show(::ManellicTree,) noticed a data size issue, expecting N$(N) == length(.data)$(length(mt.data))" if 0 < N - println(io, " .data[1:]: ", mt.data[1], ", ...") - println(io, " .permute[1:]: ", mt.permute[1], ", ...") + println(io, " .data[1:]: ", mt.data[1], " ... ", mt.data[end]) + println(io, " .weights[1:]: ", mt.weights[1], " ... ", mt.weights[end]) + println(io, " .permute[1:]: ", mt.permute[1], " ... ", mt.permute[end]) printstyled(io, " .tkernels[1]: ", " __see below__"; color=:light_black) println(io) println(io, " ...,") @@ -221,11 +223,12 @@ end function buildTree_Manellic!( M::AbstractManifold, r_PP::AbstractVector{P}; # vector of points referenced to the r_frame + len = length(r_PP), + weights::AbstractVector{<:Real} = ones(len).*(1/len), kernel = MvNormal, - kernel_bw = nothing # TODO + kernel_bw = nothing, # TODO ) where {P <: AbstractArray} # - len = length(r_PP) D = manifold_dimension(M) CV = SMatrix{D,D,Float64,D*D}(diagm(ones(D))) tknlT = kernel( @@ -241,10 +244,13 @@ function buildTree_Manellic!( end ) |> typeof + # kernel scale + # mtree = ManellicTree( M, r_PP, + MVector{len,Int}(weights), MVector{len,Int}(1:len), MVector{len,lknlT}(undef), MVector{len,tknlT}(undef), @@ -261,4 +267,19 @@ function buildTree_Manellic!( kernel, kernel_bw ) +end + +# TODO use geometric computing for faster evaluation +function evaluate( + mt::ManellicTree{M,D,N}, + p, +) where {M,D,N} + + sumval = 0.0 + for i in 1:N + sumval += mt.weights[i] * ker(mt.manifold, mt.leaf_kernels[i], p, 0.5) + end + + + end \ No newline at end of file diff --git a/test/testManiProductBigSmall.jl b/test/testManiProductBigSmall.jl index aa1d086..dbe7e0b 100644 --- a/test/testManiProductBigSmall.jl +++ b/test/testManiProductBigSmall.jl @@ -32,13 +32,12 @@ p = manikde!(M, X1) q = manikde!(M, X2) # check new MKD have right type info cached -@test (p._u0 |> typeof) == typeof(u0) -@test (p._u0 |> typeof) == typeof(u0) +@test_broken (p._u0 |> typeof) == typeof(u0) pq = manifoldProduct([p;q], M) # check new product also has right point type info cached -@test (pq._u0 |> typeof) == typeof(u0) +@test_broken (pq._u0 |> typeof) == typeof(u0) ## From b7e8fbb8b12d00818f7ef2baee3cfda4955320e8 Mon Sep 17 00:00:00 2001 From: dehann Date: Mon, 23 Oct 2023 01:02:53 -0700 Subject: [PATCH 2/3] bug fix on ker --- src/services/KernelEval.jl | 2 +- test/basic_se3.jl | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/services/KernelEval.jl b/src/services/KernelEval.jl index 3ecc4ec..ab38ff3 100644 --- a/src/services/KernelEval.jl +++ b/src/services/KernelEval.jl @@ -53,5 +53,5 @@ end Normal kernel used for Hilbert space embeddings. """ -ker(M::AbstractManifold, p, q, sigma::Real=0.001) = exp( -sigma*(distance(M, p, q)) ) +ker(M::AbstractManifold, p, q, sigma::Real=0.001) = exp( -sigma*(distance(M, p, q)^2) ) diff --git a/test/basic_se3.jl b/test/basic_se3.jl index 3c8eca3..9991568 100644 --- a/test/basic_se3.jl +++ b/test/basic_se3.jl @@ -41,9 +41,20 @@ A = ManifoldKernelDensity(M, pts1) B = ManifoldKernelDensity(M, pts2) C = ManifoldKernelDensity(M, pts3) + +@test 0.75 < AMP.ker(M, pts1[1], pts1[2], 0.001) < 1.25 +@test 0.75 < AMP.ker(M, pts2[1], pts2[2], 0.001) < 1.25 +@test 0.75 < AMP.ker(M, pts3[1], pts3[2], 0.001) < 1.25 + +@test 0 < AMP.ker(M, pts1[1], pts3[1], 0.001) < 0.25 +@test 0 < AMP.ker(M, pts1[1], pts3[2], 0.001) < 0.25 +@test 0 < AMP.ker(M, pts1[2], pts3[1], 0.001) < 0.25 + + @test isapprox(A, B) @test !isapprox(A, C) + ## show(A) From 8cde06d3d540fea4e3258cf8338108ff4903dd4c Mon Sep 17 00:00:00 2001 From: dehann Date: Wed, 25 Oct 2023 00:44:33 -0700 Subject: [PATCH 3/3] Manifolds 0.9, wip MT eval --- Project.toml | 2 +- src/Deprecated.jl | 10 ++++++++++ src/Interface.jl | 6 +----- src/services/ManellicTree.jl | 2 +- src/services/ManifoldPartials.jl | 6 +++--- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 7e300fb..af790c8 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ CoordinateTransformations = "0.5, 0.6" Distributions = "0.25" DocStringExtensions = "0.7, 0.8, 0.9" KernelDensityEstimate = "0.5.10" -Manifolds = "0.8, 0.9" +Manifolds = "0.9" ManifoldsBase = "0.14, 0.15" NLsolve = "3, 4" Optim = "1" diff --git a/src/Deprecated.jl b/src/Deprecated.jl index 64e7587..2d31a9f 100644 --- a/src/Deprecated.jl +++ b/src/Deprecated.jl @@ -1,4 +1,14 @@ +## ====================================================================================================== +## Remove below before v0.10 +## ====================================================================================================== + +# function setPointsMani!(dest::ProductRepr, src::ProductRepr) +# for (k,prt) in enumerate(dest.parts) +# setPointsMani!(prt, src.parts[k]) +# end +# end + ## ====================================================================================================== ## Remove below before v0.8 ## ====================================================================================================== diff --git a/src/Interface.jl b/src/Interface.jl index 8000671..4f04311 100644 --- a/src/Interface.jl +++ b/src/Interface.jl @@ -167,11 +167,7 @@ function setPointsMani!(dest::AbstractVector, src::AbstractVector{<:AbstractVect setPointsMani!(dest, src[1]) end -function setPointsMani!(dest::ProductRepr, src::ProductRepr) - for (k,prt) in enumerate(dest.parts) - setPointsMani!(prt, src.parts[k]) - end -end + diff --git a/src/services/ManellicTree.jl b/src/services/ManellicTree.jl index 604041d..edfde29 100644 --- a/src/services/ManellicTree.jl +++ b/src/services/ManellicTree.jl @@ -250,7 +250,7 @@ function buildTree_Manellic!( mtree = ManellicTree( M, r_PP, - MVector{len,Int}(weights), + MVector{len,Float64}(weights), MVector{len,Int}(1:len), MVector{len,lknlT}(undef), MVector{len,tknlT}(undef), diff --git a/src/services/ManifoldPartials.jl b/src/services/ManifoldPartials.jl index 15f4925..b668b4b 100644 --- a/src/services/ManifoldPartials.jl +++ b/src/services/ManifoldPartials.jl @@ -2,7 +2,7 @@ export getManifoldPartial # forcing ProductManifold to use ArrayPartition as accompanying representation -const _PartiableRepresentationProduct = Union{Nothing,<:ArrayPartition, <:ProductRepr} +const _PartiableRepresentationProduct = Union{Nothing,<:ArrayPartition} # forcing ProductManifold to use ArrayPartition as accompanying representation const _PartiableRepresentationFlat{T} = Union{Nothing,<:AbstractVector{T}} # More general representation for Manifold Factors or Groups @@ -39,7 +39,7 @@ function _getReprPartial( M::MB.AbstractManifold, return ret end -function _getReprPartial( M::Union{<:typeof(SpecialOrthogonal(2)), <:Rotations{2}}, +function _getReprPartial( M::Union{<:typeof(SpecialOrthogonal(2)), <:Rotations{ManifoldsBase.TypeParameter{Tuple{2}}}}, repr::AbstractMatrix{T}, partial::AbstractVector{Int}, # total partial from user over all Factors offset::Base.RefValue{Int}=Ref(0), @@ -77,7 +77,7 @@ function getManifoldPartial(M::Circle, return (M,repr) end -function getManifoldPartial(M::Rotations{2}, +function getManifoldPartial(M::Rotations{ManifoldsBase.TypeParameter{Tuple{2}}}, partial::AbstractVector{Int}, repr::_PartiableRepresentation=nothing, offset::Base.RefValue{Int}=Ref(0);