Skip to content
This repository has been archived by the owner on Apr 26, 2023. It is now read-only.

added obs_weight (JuliaPOMDP/POMDPs.jl#172) #71

Merged
merged 1 commit into from
May 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/POMDPToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import POMDPs: Simulator, simulate
import POMDPs: action, value, solve
import POMDPs: actions, action_index, state_index, obs_index, iterator, sampletype, states, n_actions, n_states, observations, n_observations, discount, isterminal
import POMDPs: generate_sr, initial_state
import POMDPs: implemented
import Base: rand, rand!, mean, ==
import DataStructures: CircularBuffer, isfull, capacity, push!, append!

using ProgressMeter
using StatsBase
Expand Down Expand Up @@ -178,6 +178,9 @@ include("model/underlying_mdp.jl")
# tools for distributions
include("distributions/distributions_jl.jl")

export obs_weight
include("model/obs_weight.jl")

export
weighted_iterator
include("distributions/weighted_iteration.jl")
Expand Down
87 changes: 87 additions & 0 deletions src/model/obs_weight.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# obs_weight is a shortcut function for getting the relative likelihood of an observation without having to construct the observation distribution. Useful for particle filtering
# maintained by @zsunberg

"""
obs_weight(pomdp, sp, o)
obs_weight(pomdp, a, sp, o)
obs_weight(pomdp, s, a, sp, o)

Return a weight proportional to the likelihood of receiving observation o from state sp (and a and s if they are present).

This is a useful shortcut for particle filtering so that the observation distribution does not have to be represented.
"""
function obs_weight end

@generated function obs_weight(p, s, a, sp, o)
ow_impl = :(obs_weight(p, a, sp, o))
o_impl = :(pdf(observation(p, s, a, sp), o))
if implemented(obs_weight, Tuple{p, a, sp, o})
return ow_impl
elseif implemented(observation, Tuple{p, s, a, sp})
return o_impl
else
return quote
try # trick to get the compiler to put the right backedges in
$ow_impl
$o_impl
catch
throw(MethodError(obs_weight, (p,s,a,sp,o)))
end
end
end
end

@generated function obs_weight(p, a, sp, o)
ow_impl = :(obs_weight(p, sp, o))
o_impl = :(pdf(observation(p, a, sp), o))
if implemented(obs_weight, Tuple{p, sp, o})
return ow_impl
elseif implemented(observation, Tuple{p, a, sp})
return o_impl
else
return quote
try # trick to get the compiler to put the right backedges in
$ow_impl
$o_impl
catch
throw(MethodError(obs_weight, (p, a, sp, o)))
end
end
end
end

@generated function obs_weight(p, sp, o)
impl = :(pdf(observation(p, sp), o))
if implemented(observation, Tuple{p, sp})
return impl
else
return quote
try # trick to get the compiler to put the right backedges in
$impl
catch
return :(throw(MethodError(obs_weight, (p, sp, o))))
end
end
end
end

function implemented(f::typeof(obs_weight), TT::Type)
m = which(f, TT)
if length(TT.parameters) == 5
P, S, A, _, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,S,A,S}) || implemented(obs_weight, Tuple{P,A,S,O})
elseif length(TT.parameters) == 4
P, A, S, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,A,S}) || implemented(obs_weight, Tuple{P,S,O})
elseif length(TT.parameters) == 3
P, S, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,S})
else
return method_exists(f, TT)
end
if m.module == POMDPToolbox && !reqs_met
return false
else
true
end
end
2 changes: 1 addition & 1 deletion src/simulators/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function run_parallel(process::Function, queue::AbstractVector;
warn("""
run_parallel(...) was started with only 1 process, so simulations will be run in serial.

To supress this warning, use run_parallel(..., proc_warn=false).
To suppress this warning, use run_parallel(..., proc_warn=false).

To use multiple processes, use addprocs() or the -p option (e.g. julia -p 4).
""")
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ include("test_info.jl")
include("test_k_previous_observations_belief.jl")
include("test_fully_observable_pomdp.jl")
include("test_underlying_mdp.jl")
include("test_obs_weight.jl")
18 changes: 18 additions & 0 deletions test/test_obs_weight.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import POMDPToolbox: obs_weight
import POMDPs: observation

struct P <: POMDP{Void, Void, Void} end

@test !@implemented obs_weight(::P, ::Void, ::Void, ::Void, ::Void)
@test !@implemented obs_weight(::P, ::Void, ::Void, ::Void)
@test !@implemented obs_weight(::P, ::Void, ::Void)

obs_weight(::P, ::Void, ::Void, ::Void) = 1.0
@test @implemented obs_weight(::P, ::Void, ::Void, ::Void)
@test @implemented obs_weight(::P, ::Void, ::Void, ::Void, ::Void)
@test !@implemented obs_weight(::P, ::Void, ::Void)

@test obs_weight(P(), nothing, nothing, nothing, nothing) == 1.0

observation(::P, ::Void) = nothing
@test @implemented obs_weight(::P, ::Void, ::Void)