diff --git a/src/POMDPToolbox.jl b/src/POMDPToolbox.jl index b3fd996..ffbae3c 100644 --- a/src/POMDPToolbox.jl +++ b/src/POMDPToolbox.jl @@ -10,8 +10,8 @@ import POMDPs: Simulator, simulate import POMDPs: action, value, solve import POMDPs: actions, action_index, state_index, obs_index, iterator, sampletype, states, n_actions, n_states, observations, n_observations, discount, isterminal import POMDPs: generate_sr, initial_state +import POMDPs: implemented import Base: rand, rand!, mean, == -import DataStructures: CircularBuffer, isfull, capacity, push!, append! using ProgressMeter using StatsBase @@ -178,6 +178,9 @@ include("model/underlying_mdp.jl") # tools for distributions include("distributions/distributions_jl.jl") +export obs_weight +include("model/obs_weight.jl") + export weighted_iterator include("distributions/weighted_iteration.jl") diff --git a/src/model/obs_weight.jl b/src/model/obs_weight.jl new file mode 100644 index 0000000..1d1cc55 --- /dev/null +++ b/src/model/obs_weight.jl @@ -0,0 +1,87 @@ +# obs_weight is a shortcut function for getting the relative likelihood of an observation without having to construct the observation distribution. Useful for particle filtering +# maintained by @zsunberg + +""" + obs_weight(pomdp, sp, o) + obs_weight(pomdp, a, sp, o) + obs_weight(pomdp, s, a, sp, o) + +Return a weight proportional to the likelihood of receiving observation o from state sp (and a and s if they are present). + +This is a useful shortcut for particle filtering so that the observation distribution does not have to be represented. +""" +function obs_weight end + +@generated function obs_weight(p, s, a, sp, o) + ow_impl = :(obs_weight(p, a, sp, o)) + o_impl = :(pdf(observation(p, s, a, sp), o)) + if implemented(obs_weight, Tuple{p, a, sp, o}) + return ow_impl + elseif implemented(observation, Tuple{p, s, a, sp}) + return o_impl + else + return quote + try # trick to get the compiler to put the right backedges in + $ow_impl + $o_impl + catch + throw(MethodError(obs_weight, (p,s,a,sp,o))) + end + end + end +end + +@generated function obs_weight(p, a, sp, o) + ow_impl = :(obs_weight(p, sp, o)) + o_impl = :(pdf(observation(p, a, sp), o)) + if implemented(obs_weight, Tuple{p, sp, o}) + return ow_impl + elseif implemented(observation, Tuple{p, a, sp}) + return o_impl + else + return quote + try # trick to get the compiler to put the right backedges in + $ow_impl + $o_impl + catch + throw(MethodError(obs_weight, (p, a, sp, o))) + end + end + end +end + +@generated function obs_weight(p, sp, o) + impl = :(pdf(observation(p, sp), o)) + if implemented(observation, Tuple{p, sp}) + return impl + else + return quote + try # trick to get the compiler to put the right backedges in + $impl + catch + return :(throw(MethodError(obs_weight, (p, sp, o)))) + end + end + end +end + +function implemented(f::typeof(obs_weight), TT::Type) + m = which(f, TT) + if length(TT.parameters) == 5 + P, S, A, _, O = TT.parameters + reqs_met = implemented(observation, Tuple{P,S,A,S}) || implemented(obs_weight, Tuple{P,A,S,O}) + elseif length(TT.parameters) == 4 + P, A, S, O = TT.parameters + reqs_met = implemented(observation, Tuple{P,A,S}) || implemented(obs_weight, Tuple{P,S,O}) + elseif length(TT.parameters) == 3 + P, S, O = TT.parameters + reqs_met = implemented(observation, Tuple{P,S}) + else + return method_exists(f, TT) + end + if m.module == POMDPToolbox && !reqs_met + return false + else + true + end +end diff --git a/src/simulators/parallel.jl b/src/simulators/parallel.jl index 7de3512..e813fed 100644 --- a/src/simulators/parallel.jl +++ b/src/simulators/parallel.jl @@ -127,7 +127,7 @@ function run_parallel(process::Function, queue::AbstractVector; warn(""" run_parallel(...) was started with only 1 process, so simulations will be run in serial. - To supress this warning, use run_parallel(..., proc_warn=false). + To suppress this warning, use run_parallel(..., proc_warn=false). To use multiple processes, use addprocs() or the -p option (e.g. julia -p 4). """) diff --git a/test/runtests.jl b/test/runtests.jl index 03ced34..674b129 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,3 +29,4 @@ include("test_info.jl") include("test_k_previous_observations_belief.jl") include("test_fully_observable_pomdp.jl") include("test_underlying_mdp.jl") +include("test_obs_weight.jl") diff --git a/test/test_obs_weight.jl b/test/test_obs_weight.jl new file mode 100644 index 0000000..42caff8 --- /dev/null +++ b/test/test_obs_weight.jl @@ -0,0 +1,18 @@ +import POMDPToolbox: obs_weight +import POMDPs: observation + +struct P <: POMDP{Void, Void, Void} end + +@test !@implemented obs_weight(::P, ::Void, ::Void, ::Void, ::Void) +@test !@implemented obs_weight(::P, ::Void, ::Void, ::Void) +@test !@implemented obs_weight(::P, ::Void, ::Void) + +obs_weight(::P, ::Void, ::Void, ::Void) = 1.0 +@test @implemented obs_weight(::P, ::Void, ::Void, ::Void) +@test @implemented obs_weight(::P, ::Void, ::Void, ::Void, ::Void) +@test !@implemented obs_weight(::P, ::Void, ::Void) + +@test obs_weight(P(), nothing, nothing, nothing, nothing) == 1.0 + +observation(::P, ::Void) = nothing +@test @implemented obs_weight(::P, ::Void, ::Void)