Skip to content

Commit

Permalink
Merge pull request #271 from JuliaRobotics/24Q1/enh/manelicentropy
Browse files Browse the repository at this point in the history
add entropy test and plot bugfix
  • Loading branch information
dehann committed Jan 21, 2024
2 parents 61671c2 + 622fc6a commit c4f4169
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ext/ApproxManiProdGadflyExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Colors
using Manifolds
using ApproxManifoldProducts: BallTreeDensity, ManifoldKernelDensity

import ApproxManifoldProducts: plotCircBeliefs, plotKDECircular, plotMKD
import ApproxManifoldProducts: plotCircBeliefs, plotKDECircular, plotMKD, addtheta, difftheta


include("CircularPlotting.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/ApproxManifoldProducts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import Random: rand
import Base: *, isapprox, convert, show, eltype
import LinearAlgebra: rotate!, det
import Statistics: mean, std, cov, var
import KernelDensityEstimate: getPoints, getBW, evalAvgLogL
import KernelDensityEstimate: getPoints, getBW, evalAvgLogL, entropy

const MB = ManifoldsBase
const CTs = CoordinateTransformations
Expand Down
28 changes: 25 additions & 3 deletions src/services/ManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
# end
# end

function Base.show(io::IO, mt::ManellicTree{M,D,N,KT}) where {M,D,N,KT}
getPoints(
mt::ManellicTree
) = view(mt.data, mt.permute)

function Base.show(
io::IO,
mt::ManellicTree{M,D,N,KT}
) where {M,D,N,KT}
printstyled(io, "ManellicTree{"; bold=true,color = :blue)
println(io)
printstyled(io, " M = ", M, color = :magenta)
Expand Down Expand Up @@ -277,10 +284,12 @@ end
function evaluate(
mt::ManellicTree{M,D,N},
p,
bw_scl::Real = 1
) where {M,D,N}

dim = manifold_dimension(mt.manifold)
sumval = 0.0
# FIXME, brute force for loop
for i in 1:N
ekr = mt.leaf_kernels[i]
nscl = 1/sqrt((2*pi)^dim * det(cov(ekr.p)))
Expand All @@ -295,12 +304,13 @@ end

function evalAvgLogL(
mt::ManellicTree{M,D,N},
epts::AbstractVector
epts::AbstractVector,
bw_scl::Real = 1
) where {M,D,N}
# TODO really slow brute force evaluation
eL = MVector{length(epts),Float64}(undef)
for (i,p) in enumerate(epts)
eL[i] = evaluate(mt, p)
eL[i] = evaluate(mt, p, bw_scl)
end
# set numerical tolerance floor
ind = findall(isapprox.(0,eL; atol=1e-14))
Expand All @@ -313,6 +323,18 @@ end
# return -Inf
# end

entropy(
mt::ManellicTree,
bw_scl::Real = 1,
) = -evalAvgLogL(mt, getPoints(mt), bw_scl)

leaveOneOutLogL(
mt::ManellicTree,
bw_scl::Real = 1,
) = entropy(mt, bw_scl)




# ## Pseudo code

Expand Down
3 changes: 2 additions & 1 deletion test/manellic/testManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=bw, kernel=

AMP.evaluate(mtree, SA[0.0;])


AMP.evalAvgLogL(mtree, [randn(1) for _ in 1:5])

@show AMP.entropy(mtree)

##
end

0 comments on commit c4f4169

Please sign in to comment.