Skip to content

Commit

Permalink
overload statistics mean var std cov
Browse files Browse the repository at this point in the history
  • Loading branch information
dehann committed Jul 15, 2022
1 parent b10dc8e commit 2943e6e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 39 deletions.
31 changes: 5 additions & 26 deletions src/ApproxManifoldProducts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ using StaticArrays
using Logging
using Statistics

import Random: rand

import Base: *, isapprox, convert
import LinearAlgebra: rotate!
import Statistics: mean
import Statistics: mean, std, cov, var
import KernelDensityEstimate: getPoints, getBW
import TransformUtils: rotate!

Expand All @@ -39,31 +41,8 @@ const CTs = CoordinateTransformations
# TODO temporary for initial version of on-manifold products
KDE.setForceEvalDirect!(true)

export
# new local features
AMP,
MKD,
AbstractManifold,
ManifoldKernelDensity,
get2DLambda,
get2DMu,
get2DMuMin,
resid2DLinear,
solveresid2DLinear!,
solveresid2DLinear,
*,
isapprox,

# APi and util functions
buildHybridManifoldCallbacks,
getKDEManifoldBandwidths,
manifoldProduct,
manikde!,
calcCovarianceBasic,
isPartial,
mean,
calcProductGaussians

# the exported API
include("ExportAPI.jl")

# internal features not exported
include("_BiMaps.jl")
Expand Down
32 changes: 32 additions & 0 deletions src/ExportAPI.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

export
# new local features
AMP,
MKD,
AbstractManifold,
ManifoldKernelDensity,
get2DLambda,
get2DMu,
get2DMuMin,
resid2DLinear,
solveresid2DLinear!,
solveresid2DLinear,
*,
isapprox,

# APi and util functions
buildHybridManifoldCallbacks,
getKDEManifoldBandwidths,
manifoldProduct,
manikde!,
calcCovarianceBasic,
isPartial,
calcProductGaussians

export getPoints, getBW, Ndim, Npts
export getKDERange, getKDEMax, getKDEMean, getKDEfit
export sample, rand, resample, kld, minkld
export calcMean
export mean, cov, std, var
export getInfoPerCoord, getBandwidth
export antimarginal
32 changes: 19 additions & 13 deletions src/services/ManifoldKernelDensity.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@

import Random: rand

export getPoints, getBW, Ndim, Npts
export getKDERange, getKDEMax, getKDEMean, getKDEfit
export sample, rand, resample, kld, minkld
export calcMean
export getInfoPerCoord, getBandwidth
export antimarginal


## ==========================================================================================
## helper functions to contruct MKD objects
## ==========================================================================================
Expand Down Expand Up @@ -86,14 +76,18 @@ manikde!( M::MB.AbstractManifold,
## a few utilities
## ==========================================================================================

function Statistics.mean(mkd::ManifoldKernelDensity, aspartial::Bool=true)
M = if aspartial && isPartial(mkd)
# TODO this should be a public method relating to getManifold
function _getManifoldFullOrPart(mkd::ManifoldKernelDensity, aspartial::Bool=true)
if aspartial && isPartial(mkd)
getManifoldPartial(mkd.manifold, mkd._partial)
else
mkd.manifold
end
end

mean(mkd.manifold, getPoints(mkd, aspartial))
function Statistics.mean(mkd::ManifoldKernelDensity, aspartial::Bool=true)
M = _getManifoldFullOrPart(mkd, aspartial)
mean(M, getPoints(mkd, aspartial))
end

"""
Expand All @@ -103,6 +97,18 @@ Alias for overloaded `Statistics.mean`.
"""
calcMean(mkd::ManifoldKernelDensity, aspartial::Bool=true) = mean(mkd, aspartial)

function Statistics.std(mkd::ManifoldKernelDensity, aspartial::Bool=true; kwargs...)
std(_getManifoldFullOrPart(mkd,aspartial), getPoints(mkd, aspartial); kwargs...)
end
function Statistics.var(mkd::ManifoldKernelDensity, aspartial::Bool=true; kwargs...)
var(_getManifoldFullOrPart(mkd,aspartial), getPoints(mkd, aspartial); kwargs...)
end

function Statistics.cov(mkd::ManifoldKernelDensity, aspartial::Bool=true; basis::Manifolds.AbstractBasis = Manifolds.DefaultOrthogonalBasis(), kwargs...)
return cov(_getManifoldFullOrPart(mkd,aspartial), getPoints(mkd, aspartial); basis, kwargs... )
end



_getFieldPartials(mkd::ManifoldKernelDensity{M,B,Nothing}, field::Function, aspartial::Bool=true) where {M,B} = field(mkd)

Expand Down

0 comments on commit 2943e6e

Please sign in to comment.