Skip to content

Commit

Permalink
Merge 5a30cbc into a12c9b9
Browse files Browse the repository at this point in the history
  • Loading branch information
Omastto1 committed Apr 21, 2021
2 parents a12c9b9 + 5a30cbc commit d32ad0e
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 52 deletions.
4 changes: 1 addition & 3 deletions Project.toml
@@ -1,17 +1,15 @@
name = "FiniteHorizonPOMDPs"
uuid = "8a13bbfe-798e-11e9-2f1c-eba9ee5ef093"
authors = ["Tomas Omasta <omastto1@fel.cvut.cz> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755"
POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
BeliefUpdaters = "0.2"
POMDPLinter = "0.1"
POMDPModelTools = "0.3"
POMDPs = "0.9"
Expand Down
29 changes: 24 additions & 5 deletions README.md
Expand Up @@ -15,20 +15,39 @@ The goals are to

Notably, in accordance with goal (4), this package does **not** define something like `AbstractFiniteHorizonPOMDP`.

## Use
## Usage
Package offers interface for finite horizon POMDPs.
Solver currently supports only MDPs.
User can either implement:
- finite horizon MDP using both POMDPs.jl and FiniteHorizonPOMDPs.jl interface functions or
- infinite horizon MDP and transform it to finite horizon one using `fixhorizon` utility

```
using FiniteHorizonPOMDPs
import POMDPModels
gw = SimpleGridWorld() # initialize Infinite Horizon model
fhgw = fixhorizon(gw, 2) # use fixhorizon utility to transform it to Finite Horizon
```

## Interface

- `HorizonLength(::Type{<:Union{MDP,POMDP}) = InfiniteHorizon()`
- `FiniteHorizon`
- `InfiniteHorizon`
- `HorizonLength(::Type{<:Union{MDP,POMDP})`
- Checks whether MDP is Finite or Infinite Horizon and return corresponding struct (FiniteHorizon or InfiniteHorizon).

- `horizon(m::Union{MDP,POMDP})::Int`
- Returns the number of *steps* that will be taken in the (PO)MDP, given it is Finite Horizon.
- `stage(m::Union{MDP,POMDP}, ss)::Int`
- Returns the number of input variable's stage.
- `stage_states(m::Union{MDP,POMDP}, stage::Int)`
- Creates (PO)MDP's states for given stage.
- `stage_stateindex(m::Union{MDP,POMDP}, state)`
- `stage(m::Union{MDP,POMDP}, state)`
- Computes the index of the given state in the corresponding stage.
- `ordered_stage_states(w::FHWrapper, stage::Int)`
- Returns an `AbstractVector` of states from given stage ordered according to `stage_stateindex(mdp, s)`.
- `stage_observations(m::Union{MDP,POMDP}, stage::Int)`
- Creates (PO)MDP's observations for given stage.
- `stage_obsindex(m::Union{MDP,POMDP}, o::stage::Int)`
- Computes the index of the given observation in the corresponding stage.
- `ordered_stage_observations(w::FHWrapper, stage::Int)`
- Returns an `AbstractVector` of observations from given stage ordered according to `stage_obsindex(w,o)`.
9 changes: 4 additions & 5 deletions src/FiniteHorizonPOMDPs.jl
Expand Up @@ -3,20 +3,19 @@ module FiniteHorizonPOMDPs
using POMDPs
using POMDPModelTools
using Random: Random, AbstractRNG
using BeliefUpdaters


export
stage,
stage_states,
stage_stateindex,
HorizonLength,
FiniteHorizon,
InfiniteHorizon,
horizon,
stage,
stage_states,
stage_stateindex,
ordered_stage_states,
stage_observations,
stage_obsindex,
ordered_stage_states,
ordered_stage_observations

include("interface.jl")
Expand Down
67 changes: 49 additions & 18 deletions src/fixhorizon.jl
@@ -1,4 +1,3 @@
# TODO: Docstring
"""
fixhorizon(m::Union{MDP,POMDP}, horizon::Int)
Expand Down Expand Up @@ -60,6 +59,11 @@ Mark the state as terminal if its stage number if greater than horizon, else let
"""
POMDPs.isterminal(w::FHWrapper, ss::Tuple{<:Any,Int}) = stage(w, ss) > horizon(w) || isterminal(w.m, first(ss))

"""
POMDPs.gen(w::FHWrapper, ss::Tuple{<:Any,Int}, a, rng::AbstractRNG)
Implement the entire MDP/POMDP generative model by returning a NamedTuple.
"""
function POMDPs.gen(w::FHWrapper, ss::Tuple{<:Any,Int}, a, rng::AbstractRNG)
out = gen(w.m, first(ss), a, rng)
if haskey(out, :sp)
Expand All @@ -77,6 +81,12 @@ Wrap the transition result of Infinite Horizon MDP with stage number.
POMDPs.transition(w::FHWrapper, ss::Tuple{<:Any,Int}, a) = InStageDistribution(transition(w.m, first(ss), a), stage(w, ss)+1)
# TODO: convert_s

"""
POMDPs.actions(w::FHWrapper, ss::Tuple{<:Any,Int})
Return the actions of Infinite Horizon (PO)MDP.
This method assumes similar actions for all stages.
"""
POMDPs.actions(w::FHWrapper, ss::Tuple{<:Any,Int}) = actions(w.m, first(ss))

"""
Expand All @@ -93,12 +103,12 @@ Create a product of Infinite Horizon MDP's observations with all not-terminal st
"""
POMDPs.observations(w::FixedHorizonPOMDPWrapper) = Iterators.product(observations(w.m), 1:horizon(w))

stage_observations(w::FixedHorizonPOMDPWrapper, stage::Int) = Iterators.product(observations(w.m), stage)

stage_obsindex(w::FixedHorizonPOMDPWrapper, o::Tuple{<:Any,Int}) = obsindex(w.m, first(o))
"""
POMDPs.obsindex(w::FixedHorizonPOMDPWrapper, o::Tuple{<:Any, Int})::Int
# TODO: Write Docstring
function POMDPs.obsindex(w::FixedHorizonPOMDPWrapper, o::Tuple{<:Any, Int})
Compute the index of the given observation in the Finite Horizon observation space (meaning in observation space of all stages).
"""
function POMDPs.obsindex(w::FixedHorizonPOMDPWrapper, o::Tuple{<:Any, Int})::Int
s, k = o
return (k-1)*length(stage_observations(w, 1)) + obsindex(w.m, s)
end
Expand All @@ -109,7 +119,7 @@ end
Create a product of Infinite Horizon MDP's observations given destination state and action (and original state) with original state's stage.
"""
POMDPs.observation(w::FixedHorizonPOMDPWrapper, ss::Tuple{<:Any,Int}, a, ssp::Tuple{<:Any, Int}) = InStageDistribution(observation(w.m, first(ss), a, first(ssp)), stage(w, ss))
POMDPs.observation(w::FixedHorizonPOMDPWrapper, a, ssp::Tuple{<:Any, Int}) = InStageDistribution(observation(w.m, a, first(ssp)), last(ssp)-1)
POMDPs.observation(w::FixedHorizonPOMDPWrapper, a, ssp::Tuple{<:Any, Int}) = InStageDistribution(observation(w.m, a, first(ssp)), stage(w, ssp)-1)

"""
POMDPs.initialstate(w::FHWrapper)
Expand All @@ -125,6 +135,10 @@ POMDPs.initialobs(w::FixedHorizonPOMDPWrapper, ss::Tuple{<:Any,Int}) = initialob
stage(w::FHWrapper, ss::Tuple{<:Any,Int}) = last(ss)
stage_states(w::FHWrapper, stage::Int) = Iterators.product(states(w.m), stage)
stage_stateindex(w::FHWrapper, ss::Tuple{<:Any,Int}) = stateindex(w.m, first(ss))
stage_observations(w::FixedHorizonPOMDPWrapper, stage::Int) = Iterators.product(observations(w.m), stage)
stage_obsindex(w::FixedHorizonPOMDPWrapper, o::Tuple{<:Any,Int}) = obsindex(w.m, first(o))
ordered_stage_states(w::FHWrapper, stage::Int) = POMDPModelTools.ordered_vector(statetype(typeof(w)), s->stage_stateindex(w,s), stage_states(w, stage), "stage_state")
ordered_stage_observations(w::FHWrapper, stage::Int) = POMDPModelTools.ordered_vector(obstype(typeof(w)), o->stage_obsindex(w,o), stage_observations(w, stage), "stage_observation")

###############################
# Forwarded parts of POMDPs interface
Expand All @@ -135,37 +149,54 @@ POMDPs.reward(w::FHWrapper, ss::Tuple{<:Any,Int}, a) = reward(w.m, first(ss), a)
POMDPs.actions(w::FHWrapper) = actions(w.m)
POMDPs.actionindex(w::FHWrapper, a) = actionindex(w.m, a)
POMDPs.discount(w::FHWrapper) = discount(w.m)
POMDPModelTools.ordered_actions(w::FHWrapper) = ordered_actions(w.m)
# TODO: convert_a

#################################
# distribution with a fixed stage
#################################
"""
InStageDistribution{D}
# TODO: Define access functions for InStageDistribution - to access with method instead of .d or .stage
Wrap given distribution with a given stage
"""
struct InStageDistribution{D}
d::D
stage::Int
end

function BeliefUpdaters.DiscreteBelief(pomdp, b::InStageDistribution; check::Bool=true)
return DiscreteBelief(pomdp, b.d; check)
"""
distrib(d::InStageDistribution{D})::D
Return distrubution wrapped in InStageDistribution without stage
"""
function distrib(d::InStageDistribution)
return d.d
end

"""
stage(d::InStageDistribution)
Return stage of InStageDistribution
"""
stage(d::InStageDistribution) = d.stage

Base.rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:InStageDistribution}) = (rand(rng, s[].d), s[].stage)

function POMDPs.pdf(d::InStageDistribution, ss::Tuple{<:Any, Int})
s, k = ss
if k == d.stage
return pdf(d.d, s)
if k == stage(d)
return pdf(distrib(d), s)
else
return 0.0
end
end

POMDPs.mean(d::InStageDistribution) = (mean(d.d), d.stage)
POMDPs.mode(d::InStageDistribution) = (mode(d.d), d.stage)
POMDPs.support(d::InStageDistribution) = Iterators.product(support(d.d), d.stage)
POMDPs.mean(d::InStageDistribution) = (mean(distrib(d)), stage(d))
POMDPs.mode(d::InStageDistribution) = (mode(distrib(d)), stage(d))
POMDPs.support(d::InStageDistribution) = Iterators.product(support(distrib(d)), stage(d))
POMDPs.rand(r::AbstractRNG, d::FiniteHorizonPOMDPs.InStageDistribution) = (rand(r, distrib(d)), stage(d))

ordered_stage_states(w::FHWrapper, stage::Int) = POMDPModelTools.ordered_vector(statetype(typeof(w)), s->stage_stateindex(w,s), stage_states(w, stage), "stage_state")
ordered_stage_observations(w::FHWrapper, stage::Int) = POMDPModelTools.ordered_vector(obstype(typeof(w)), o->stage_obsindex(w,o), stage_observations(w, stage), "stage_observation")
#################################
# POMDPModelTools ordered actions
#################################
POMDPModelTools.ordered_actions(w::FHWrapper) = ordered_actions(w.m)
65 changes: 44 additions & 21 deletions src/interface.jl
@@ -1,50 +1,73 @@
"""
stage(m::Union{MDP,POMDP}, ss::MDPState)::Int
HorizonLength(::Type{<:Union{MDP,POMDP})
HorizonLength(::Union{MDP,POMDP})
Check whether MDP is Finite or Infinite Horizon and return corresponding struct (FiniteHorizon or InfiniteHorizon).
"""
abstract type HorizonLength end

"If HorizonLength(m::Union{MDP,POMDP}) == FiniteHorizon(), horizon(m) should be implemented and return an integer"
struct FiniteHorizon <: HorizonLength end
"HorizonLength(m::Union{MDP,POMDP}) == InfiniteHorizon() indicates that horizon(m) should not be called."
struct InfiniteHorizon <: HorizonLength end

HorizonLength(m::Union{MDP,POMDP}) = HorizonLength(typeof(m))
HorizonLength(::Type{<:Union{MDP,POMDP}}) = InfiniteHorizon()

"""
Return the number of *steps* that will be taken in the (PO)MDP, given it is Finite Horizon.
A simulation of a (PO)MDP with `horizon(m) == d` should contain *d+1* states and *d* actions and rewards.
"""
function horizon end

Return number of state's stage
"""
stage(m::Union{MDP,POMDP}, ss)::Int
stage(m::Union{MDP,POMDP}, o)::Int
stage(d)::Int
Considering a variable or distribution containing its stage assignment, return the number of its stage.
"""
function stage end

"""
stage_states(m::Union{MDP,POMDP}, stage::Int)
Create Infinite Horizon MDP's states for given stage.
Create (PO)MDP's states for given stage.
"""
function stage_states end

"""
stage_stateindex(m::Union{MDP,POMDP}, ss::MDPState}::Int
stage_stateindex(m::Union{MDP,POMDP}, ss}::Int
Compute the index of the given state in Infinite Horizon for given stage state space.
Compute the index of the given state in the corresponding stage.
"""
function stage_stateindex end

"""
HorizonLength(::Type{<:Union{MDP,POMDP})
HorizonLength(::Union{MDP,POMDP})
ordered_stage_states(w::FHWrapper, stage::Int)
Check whether MDP is Finite or Infinite Horizon and return corresponding struct (FiniteHorizon or InfiniteHorizon).
Return an AbstractVector of states from given stage ordered according to stage_stateindex(mdp, s).
"""
abstract type HorizonLength end
function ordered_stage_states end

"If HorizonLength(m::Union{MDP,POMDP}) == FiniteHorizon(), horizon(m) should be implemented and return an integer"
struct FiniteHorizon <: HorizonLength end
"HorizonLength(m::Union{MDP,POMDP}) == InfiniteHorizon() indicates that horizon(m) should not be called."
struct InfiniteHorizon <: HorizonLength end
"""
stage_observations(m::Union{MDP,POMDP}, stage::Int)
HorizonLength(m::Union{MDP,POMDP}) = HorizonLength(typeof(m))
HorizonLength(::Type{<:Union{MDP,POMDP}}) = InfiniteHorizon()
Create (PO)MDP's observations for given stage.
"""
function stage_observations end

"""
Return the number of *steps* that will be taken in the (PO)MDP, given it is Finite Horizon.
stage_obsindex(m::Union{MDP,POMDP}, o::stage::Int)
A simulation of a (PO)MDP with `horizon(m) == d` should contain *d+1* states and *d* actions and rewards.
Compute the index of the given observation in the corresponding stage.
"""
function horizon end
function stage_obsindex end

"""
stage_observations(m::Union{MDP,POMDP}, stage::Int)
ordered_stage_observations(w::FHWrapper, stage::Int)
Infinite Horizon MDP's observation for given stage.
Return an AbstractVector of observations from given stage ordered according to stage_obsindex(w,o).
"""
function stage_observations end
function ordered_stage_observations end

0 comments on commit d32ad0e

Please sign in to comment.