Skip to content

Commit

Permalink
fixed error in BasicPOMCP minimal example
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Nov 13, 2018
1 parent 09160cf commit 435d745
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
41 changes: 29 additions & 12 deletions src/pomdps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,35 @@ end

function initialize_belief(up::BasicParticleFilter, d::D) where D
# using weighted iterator here is more likely to be order n than just calling rand() repeatedly
# but, this implementation may change in the future
if @implemented(support(::D)) && @implemented(pdf(::D, ::typeof(first(support(d)))))
# if @implemented(weighted_iterator(::D))
S = typeof(first(support(d)))
particles = S[]
weights = Float64[]
for (s, w) in weighted_iterator(d)
push!(particles, s)
push!(weights, w)
# but, this implementation is problematic and may change in the future
try
if @implemented(support(::D)) &&
@implemented(iterate(::typeof(support(d)))) &&
@implemented(pdf(::D, ::typeof(first(support(d)))))
S = typeof(first(support(d)))
particles = S[]
weights = Float64[]
for (s, w) in weighted_iterator(d)
push!(particles, s)
push!(weights, w)
end
return resample(ImportanceResampler(up.n_init), WeightedParticleBelief(particles, weights), up.rng)
end
catch ex
if ex isa MethodError
@warn("""
Suppressing MethodError in initialize_belief in ParticleFilters.jl. Please file an issue here:
https://github.com/JuliaPOMDP/ParticleFilters.jl/issues/new
The error was
$(sprint(showerror, ex))
""", maxlog=1)
else
rethrow(ex)
end
return resample(ImportanceResampler(up.n_init), WeightedParticleBelief(particles, weights), up.rng)
else
return ParticleCollection(collect(rand(up.rng, d) for i in 1:up.n_init))
end

return ParticleCollection(collect(rand(up.rng, d) for i in 1:up.n_init))
end
13 changes: 9 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ using Test
using POMDPPolicies
using POMDPSimulators
using Random
import ParticleFilters: obs_weight
import POMDPs: observation
using Distributions
using NBInclude

struct P <: POMDP{Nothing, Nothing, Nothing} end
Expand All @@ -15,7 +14,7 @@ struct P <: POMDP{Nothing, Nothing, Nothing} end
@test !@implemented obs_weight(::P, ::Nothing, ::Nothing, ::Nothing)
@test !@implemented obs_weight(::P, ::Nothing, ::Nothing)
end
obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) = 1.0
ParticleFilters.obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) = 1.0

@testset "implemented" begin
@test @implemented obs_weight(::P, ::Nothing, ::Nothing, ::Nothing)
Expand All @@ -24,12 +23,13 @@ obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) = 1.0
@test obs_weight(P(), nothing, nothing, nothing, nothing) == 1.0
end

observation(::P, ::Nothing) = nothing
POMDPs.observation(::P, ::Nothing) = nothing
@test @implemented obs_weight(::P, ::Nothing, ::Nothing)

include("example.jl")
include("domain_specific_resampler.jl")

struct ContinuousPOMDP <: POMDP{Float64, Float64, Float64} end
@testset "infer" begin
p = TigerPOMDP()
filter = SIRParticleFilter(p, 10000)
Expand Down Expand Up @@ -70,6 +70,11 @@ include("domain_specific_resampler.jl")
wp2 = @inferred collect(weighted_particles(WeightedParticleBelief([1,2], [0.5, 0.5])))
@test wp1 == wp2
end

@testset "normal" begin
pf = SIRParticleFilter(ContinuousPOMDP(), 100)
ps = @inferred initialize_belief(pf, Normal())
end
end

@testset "alpha" begin
Expand Down

0 comments on commit 435d745

Please sign in to comment.