diff --git a/ext/ApproxManiProdGadflyExt.jl b/ext/ApproxManiProdGadflyExt.jl index e65ddbb..13a4931 100644 --- a/ext/ApproxManiProdGadflyExt.jl +++ b/ext/ApproxManiProdGadflyExt.jl @@ -5,7 +5,7 @@ using Colors using Manifolds using ApproxManifoldProducts: BallTreeDensity, ManifoldKernelDensity -import ApproxManifoldProducts: plotCircBeliefs, plotKDECircular, plotMKD +import ApproxManifoldProducts: plotCircBeliefs, plotKDECircular, plotMKD, addtheta, difftheta include("CircularPlotting.jl") diff --git a/src/ApproxManifoldProducts.jl b/src/ApproxManifoldProducts.jl index ac398e9..6dd72e0 100644 --- a/src/ApproxManifoldProducts.jl +++ b/src/ApproxManifoldProducts.jl @@ -32,7 +32,7 @@ import Random: rand import Base: *, isapprox, convert, show, eltype import LinearAlgebra: rotate!, det import Statistics: mean, std, cov, var -import KernelDensityEstimate: getPoints, getBW, evalAvgLogL +import KernelDensityEstimate: getPoints, getBW, evalAvgLogL, entropy const MB = ManifoldsBase const CTs = CoordinateTransformations diff --git a/src/services/ManellicTree.jl b/src/services/ManellicTree.jl index 2fdab54..2f1c2e9 100644 --- a/src/services/ManellicTree.jl +++ b/src/services/ManellicTree.jl @@ -9,7 +9,14 @@ # end # end -function Base.show(io::IO, mt::ManellicTree{M,D,N,KT}) where {M,D,N,KT} +getPoints( + mt::ManellicTree +) = view(mt.data, mt.permute) + +function Base.show( + io::IO, + mt::ManellicTree{M,D,N,KT} +) where {M,D,N,KT} printstyled(io, "ManellicTree{"; bold=true,color = :blue) println(io) printstyled(io, " M = ", M, color = :magenta) @@ -277,10 +284,12 @@ end function evaluate( mt::ManellicTree{M,D,N}, p, + bw_scl::Real = 1 ) where {M,D,N} dim = manifold_dimension(mt.manifold) sumval = 0.0 + # FIXME, brute force for loop for i in 1:N ekr = mt.leaf_kernels[i] nscl = 1/sqrt((2*pi)^dim * det(cov(ekr.p))) @@ -295,12 +304,13 @@ end function evalAvgLogL( mt::ManellicTree{M,D,N}, - epts::AbstractVector + epts::AbstractVector, + bw_scl::Real = 1 ) where {M,D,N} # TODO really slow brute force evaluation eL = MVector{length(epts),Float64}(undef) for (i,p) in enumerate(epts) - eL[i] = evaluate(mt, p) + eL[i] = evaluate(mt, p, bw_scl) end # set numerical tolerance floor ind = findall(isapprox.(0,eL; atol=1e-14)) @@ -313,6 +323,18 @@ end # return -Inf # end +entropy( + mt::ManellicTree, + bw_scl::Real = 1, +) = -evalAvgLogL(mt, getPoints(mt), bw_scl) + +leaveOneOutLogL( + mt::ManellicTree, + bw_scl::Real = 1, +) = entropy(mt, bw_scl) + + + # ## Pseudo code diff --git a/test/manellic/testManellicTree.jl b/test/manellic/testManellicTree.jl index fc3dcc6..38b66c4 100644 --- a/test/manellic/testManellicTree.jl +++ b/test/manellic/testManellicTree.jl @@ -161,8 +161,9 @@ mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=bw, kernel= AMP.evaluate(mtree, SA[0.0;]) - AMP.evalAvgLogL(mtree, [randn(1) for _ in 1:5]) +@show AMP.entropy(mtree) + ## end \ No newline at end of file