Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POMDPs v0.8 Compatibility #26

Merged
merged 10 commits into from
Sep 20, 2019
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[compat]
Distributions = ">= 0.17"
POMDPs = "0.7.3, 0.9.0"
POMDPs = "0.7.3, 0.8.0"
julia = "1"

[extras]
Expand Down
7 changes: 4 additions & 3 deletions src/POMDPModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ include("visualization.jl")

# info interface
export
generate_sri,
generate_sori,
add_infonode,
action_info,
solve_info,
update_info
update_info,
generate_sri,
generate_sori
include("info.jl")

export
Expand Down
64 changes: 46 additions & 18 deletions src/info.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,6 @@
# functions for passing out info from simulations, similar to the info return from openai gym
# maintained by @zsunberg

"""
Return a tuple containing the next state and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step.

By default, returns `nothing` as info.
"""
function generate_sri(p::MDP, s, a, rng::AbstractRNG)
return generate_sr(p, s, a, rng)..., nothing
end

"""
Return a tuple containing the next state, observation, and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step.

By default, returns `nothing` as info.
"""
function generate_sori(p::POMDP, s, a, rng::AbstractRNG)
return generate_sor(p, s, a, rng)..., nothing
end

"""
a, ai = action_info(policy, x)

Expand Down Expand Up @@ -51,3 +33,49 @@ By default, returns `nothing` as info.
function update_info(up::Updater, b, a, o)
return update(up, b, a, o), nothing
end

"""
add_infonode(ddn::DDNStructure)

Create a new DDNStructure object with a new node labeled :info with parents :s and :a
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be valuable to have the example from the conversation in here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 it is in the docstring now

"""
function add_infonode(ddn) # for DDNStructure, but it is not declared in v0.7.3, so there is not annotation
add_node(ddn, :info, ConstantDDNNode(nothing), (:s, :a))
end

function add_infonode(ddn::POMDPs.DDNStructureV7{nodenames}) where nodenames
return POMDPs.DDNStructureV7{(nodenames..., :info)}()
end

###############################################################
# Note all generate functions will be deprecated in POMDPs v0.8
###############################################################


if DDNStructure(MDP) isa POMDPs.DDNStructureV7
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we check on the POMDPs.jl version instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I initially though, but it didn't seem like there was a good way to do that, so this seems like a reliable enough proxy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if Pkg.installed()["POMDPs"] < v"0.8.0"
something like that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to reply to your comment on this below.

The problem with Pkg.installed() is that it can take a really long time (or at least it could in the past)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and it adds Pkg as a dependency (which is probably fine)

"""
Return a tuple containing the next state and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step.

By default, returns `nothing` as info.
"""
function generate_sri(p::MDP, s, a, rng::AbstractRNG)
return generate_sr(p, s, a, rng)..., nothing
end

"""
Return a tuple containing the next state, observation, and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step.

By default, returns `nothing` as info.
"""
function generate_sori(p::POMDP, s, a, rng::AbstractRNG)
return generate_sor(p, s, a, rng)..., nothing
end

POMDPs.gen(::DDNOut{(:sp,:o,:r,:i)}, m, s, a, rng) = generate_sori(m, s, a, rng)
POMDPs.gen(::DDNOut{(:sp,:o,:r,:info)}, m, s, a, rng) = generate_sori(m, s, a, rng)
POMDPs.gen(::DDNOut{(:sp,:r,:i)}, m, s, a, rng) = generate_sri(m, s, a, rng)
POMDPs.gen(::DDNOut{(:sp,:r,:info)}, m, s, a, rng) = generate_sri(m, s, a, rng)
else
@deprecate generate_sri(args...) gen(DDNOut(:sp,:r,:info), args...)
@deprecate generate_sori(args...) gen(DDNOut(:sp,:o,:r,:info), args...)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ using SparseArrays
include("test_tabular.jl")
end

end
end
2 changes: 1 addition & 1 deletion test/test_fully_observable_pomdp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
let
mdp = GridWorld()
mdp = SimpleGridWorld()

pomdp = FullyObservablePOMDP(mdp)

Expand Down
12 changes: 9 additions & 3 deletions test/test_info.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,20 @@ let
rng = MersenneTwister(7)

mdp = LegacyGridWorld()
POMDPs.DDNStructure(::Type{typeof(mdp)}) = DDNStructure(MDP) |> add_infonode
@test :info in nodenames(DDNStructure(mdp))
s = initialstate(mdp, rng)
a = rand(rng, actions(mdp))
@inferred generate_sri(mdp, s, a, rng)
sp, r, i = @inferred gen(DDNOut(:sp,:r,:info), mdp, s, a, rng)
@test i === nothing

pomdp = TigerPOMDP()
POMDPs.DDNStructure(::Type{typeof(pomdp)}) = DDNStructure(POMDP) |> add_infonode
@test :info in nodenames(DDNStructure(pomdp))
s = initialstate(pomdp, rng)
a = rand(rng, actions(pomdp))
@inferred generate_sori(pomdp, s, a, rng)
sp, o, r, i = @inferred gen(DDNOut(:sp,:o,:r,:info), pomdp, s, a, rng)
@test i === nothing

up = VoidUpdater()
policy = RandomPolicy(rng, pomdp)
Expand All @@ -43,6 +49,6 @@ let
d = initialstate_distribution(pomdp)
b = initialize_belief(up, d)
a = action(policy, b)
sp, o = generate_so(pomdp, rand(rng, d), a, rng)
sp, o, r = gen(DDNOut(:sp,:o,:r), pomdp, rand(rng, d), a, rng)
@inferred update_info(up, b, a, o)
end