Skip to content

Commit

Permalink
Merge pull request #125 from JuliaRobotics/21Q3/fix/testpartialbw
Browse files Browse the repository at this point in the history
fix tests
  • Loading branch information
dehann committed Jul 27, 2021
2 parents d4bc643 + 60929eb commit f77d29b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/services/ManifoldKernelDensity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ function ManifoldKernelDensity( M::MB.AbstractManifold,
manis = convert(Tuple, M)
# find or have the bandwidth
_bw = bw === nothing ? getKDEManifoldBandwidths(arr, manis ) : bw
# NOTE workaround for partials and user did not specify a bw
if bw === nothing && partial !== nothing
mask = ones(Int, length(_bw)) .== 1
mask[partial] .= false
_bw[mask] .= 1.0
end
addopT, diffopT, _, _ = buildHybridManifoldCallbacks(manis)
bel = KernelDensityEstimate.kde!(arr, _bw, addopT, diffopT)
return ManifoldKernelDensity(M, bel, partial, u0, infoPerCoord)
Expand Down
15 changes: 8 additions & 7 deletions test/testMarginalProducts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,18 @@ d = 2
M = TranslationGroup(d)

pts4 = [randn(d) .- 10.0 for _ in 1:N]
(x->x[2]=-100.0).(pts4)
(x->x[2]-=90.0).(pts4)
pts5 = [randn(d) .+ 10.0 for _ in 1:N]
(x->x[1]=100.0).(pts5)
(x->x[1]+=90.0).(pts5)

P4 = marginal(manikde!(M, pts4), [1;])
P5 = marginal(manikde!(M, pts5), [d;])

# test duplication
pts4_ = [randn(d) .- 10.0 for _ in 1:N]
(x->x[2]=-100.0).(pts4_)
(x->x[2]-=90.0).(pts4_)
pts5_ = [randn(d) .+ 10.0 for _ in 1:N]
(x->x[1]=100.0).(pts5_)
(x->x[1]+=90.0).(pts5_)

P4_ = marginal(manikde!(M, pts4_), [1;])
P5_ = marginal(manikde!(M, pts5_), [d;])
Expand All @@ -324,10 +324,11 @@ P45__

## check the selection of labels and resulting Gaussian products are correct

println("getPoints(P45__) = ")
getPoints(P45__) .|> println
println()
# println("getPoints(P45__) = ")
# getPoints(P45__) .|> println
# println()

# sidx = 1
for sidx in 1:N

bw1 = getBW(P4)[:,1] .^2
Expand Down

0 comments on commit f77d29b

Please sign in to comment.