From 75eeec6579c76c57cfeb14151a3e8557cc1374f1 Mon Sep 17 00:00:00 2001 From: Matthias Untergassmair <53627988+mattuntergassmair@users.noreply.github.com> Date: Mon, 16 Dec 2019 15:46:59 -0800 Subject: [PATCH 1/2] TMaze bugfix: terminal states (#58) * TMaze bugfix: terminal states * Fix Tuple{TMazeState,Float64} bug * version correction * cleanup: remove mutable, etc * fix to make tests pass - will finish cleanup later --- Project.toml | 2 +- src/TMazes.jl | 46 +++++++++++++++++----------------------------- 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index 2f4a470..ac6b60c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "POMDPModels" uuid = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" repo = "https://github.com/JuliaPOMDP/POMDPModels.jl" -version = "0.4.1" +version = "0.4.2" [deps] BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" diff --git a/src/TMazes.jl b/src/TMazes.jl index 45f7646..0872379 100644 --- a/src/TMazes.jl +++ b/src/TMazes.jl @@ -13,7 +13,7 @@ function Base.copy!(s1::TMazeState, s2::TMazeState) return s1 end -@with_kw mutable struct TMaze <: POMDP{TMazeState, Int64, Int64} +@with_kw struct TMaze <: POMDP{TMazeState, Int64, Int64} n::Int64 = 10 # corridor length discount::Float64 = 0.99 # discount factor end @@ -34,8 +34,10 @@ function states(maze::TMaze) end # 4 actions: go North, East, South, West (1, 2, 3, 4) actions(maze::TMaze) = 1:4 +actionindex(maze::TMaze, i::Int) = i # 5 observations: 2 for goal (left or right) + 2 for in corridor or at intersection + 1 term observations(maze::TMaze) = 1:5 +obsindex(maze::TMaze, i::Int) = i # transition distribution (actions are deterministic) mutable struct TMazeStateDistribution @@ -49,7 +51,7 @@ function create_transition_distribution(::TMaze) rp = [0.5, 0.5] TMazeStateDistribution(TMazeState(), false, rs, rp) end -support(d::TMazeStateDistribution) = d.reset ? (return [(d.current_state, 1.0)]) : (return zip(d.reset_states, d.reset_probs)) +support(d::TMazeStateDistribution) = d.reset ? (return [d.current_state]) : (return zip(d.reset_states, d.reset_probs)) function pdf(d::TMazeStateDistribution, s::TMazeState) if d.reset @@ -70,11 +72,11 @@ function rand(rng::AbstractRNG, d::TMazeStateDistribution) end #rand(rng::AbstractRNG, d::TMazeStateDistribution) -mutable struct TMazeInit +struct TMazeInit states::Vector{TMazeState} probs::Vector{Float64} end -support(d::TMazeInit) = zip(d.states, d.probs) +support(d::TMazeInit) = d.states function initialstate_distribution(maze::TMaze) s = states(maze) ns = length(s) @@ -86,14 +88,9 @@ function initialstate_distribution(maze::TMaze) return TMazeInit(s, p) end function rand(rng::AbstractRNG, d::TMazeInit) - s = TMazeState() - #idx = nothing - #rand(rng) < 0.5 ? (idx = 1) : (idx = 2) - #copy!(s, d.states[idx]) cat = Weights(d.probs) idx = sample(rng, cat) - copy!(s, d.states[idx]) - return s + return d.states[idx] end function pdf(d::TMazeInit, s::TMazeState) for i = 1:length(d.states) @@ -102,7 +99,6 @@ function pdf(d::TMazeInit, s::TMazeState) end end return 0.0 - #in(s, d.states) ? (return 0.5) : (return 0.0) end # observation distribution (deterministic) @@ -112,7 +108,8 @@ end create_observation_distribution(::TMaze) = TMazeObservationDistribution(1) iterator(d::TMazeObservationDistribution) = [d.current_observation] -pdf(d::TMazeObservationDistribution, o::Int64) = o == d.current_observation ? (return 1.0) : (return 0.0) +pdf(d::TMazeObservationDistribution, o::Int64) = Float64(o == d.current_observation) +support(d::TMazeObservationDistribution) = d.current_observation rand(rng::AbstractRNG, d::TMazeObservationDistribution) = d.current_observation function transition(maze::TMaze, s::TMazeState, a::Int64) @@ -204,14 +201,7 @@ isterminal(m::TMaze, s::TMazeState) = s.term discount(m::TMaze) = m.discount -function stateindex(maze::TMaze, s::TMazeState) - s.term ? (return maze.n + 1) : (nothing) - if s.g == :north - return s.x + (s.x - 1) - else - return s.x + (s.x) - end -end +stateindex(maze::TMaze, s::TMazeState) = s.term ? (2*(maze.n+1) + 1) : (2*s.x - (s.g==:north)) function Base.convert(maze::TMaze, s::TMazeState) v = Array{Float64}(undef, 2) @@ -220,26 +210,24 @@ function Base.convert(maze::TMaze, s::TMazeState) return v end -mutable struct MazeBelief +struct MazeBelief last_obs::Int64 mem::Symbol # memory end MazeBelief() = MazeBelief(1, :none) -mutable struct MazeUpdater <: Updater end -POMDPs.initialize_belief(bu::MazeUpdater, d::Any) = b +struct MazeUpdater <: Updater end +POMDPs.initialize_belief(bu::MazeUpdater, d::Any) = d function POMDPs.update(bu::MazeUpdater, b::MazeBelief, a, o) - bp::MazeBelief=create_belief(bu) - bp.last_obs = o - bp.mem = b.mem + mem = b.mem if o == 1 - bp.mem = :north + mem = :north end if o == 2 - bp.mem = :south + mem = :south end - return bp + return MazeBelief(o, mem) end From 727f0712b226a49312f408c3a137b18024c95a3f Mon Sep 17 00:00:00 2001 From: Matthias Untergassmair Date: Tue, 17 Dec 2019 16:34:59 -0800 Subject: [PATCH 2/2] missed the Float64 bug --- src/TMazes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TMazes.jl b/src/TMazes.jl index 0872379..7d065bc 100644 --- a/src/TMazes.jl +++ b/src/TMazes.jl @@ -51,7 +51,7 @@ function create_transition_distribution(::TMaze) rp = [0.5, 0.5] TMazeStateDistribution(TMazeState(), false, rs, rp) end -support(d::TMazeStateDistribution) = d.reset ? (return [d.current_state]) : (return zip(d.reset_states, d.reset_probs)) +support(d::TMazeStateDistribution) = d.reset ? [d.current_state] : d.reset_states function pdf(d::TMazeStateDistribution, s::TMazeState) if d.reset