Skip to content

Commit

Permalink
fix #59
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Dec 22, 2019
1 parent dbd56fa commit b3c3dc7
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 154 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "POMDPModels"
uuid = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
repo = "https://github.com/JuliaPOMDP/POMDPModels.jl"
version = "0.4.2"
version = "0.4.3"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
Expand All @@ -18,13 +18,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
POMDPSimulators = "0.3"
POMDPs = "0.8.1"
BeliefUpdaters = "0.1"
ColorSchemes = "3"
Compose = "0.7"
Distributions = "0.21"
POMDPModelTools = "0.2"
POMDPSimulators = "0.3"
POMDPs = "0.8.1"
Parameters = "0"
StaticArrays = "0"
StatsBase = "0"
Expand Down
205 changes: 60 additions & 145 deletions src/TMazes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,203 +3,120 @@
g::Symbol = :north# goal north or south
end

struct TMaze <: POMDP{Union{TMazeState,TerminalState}, Int64, Int64}
@with_kw struct TMaze <: POMDP{Union{TMazeState,TerminalState}, Int64, Int64}
n::Int64 = 10 # corridor length
discount::Float64 = 0.99 # discount factor
end


# state space is length of corr + 3 cells at the end
# state space is (length of corr)*(north, south) + terminal
# |G|
# |S| | | | | | | | | |
# | | |x| | | | | | | |
# | |
# depending on where the goal is
function states(maze::TMaze)
space = statetype(maze)[]
for x in 1:(maze.n + 1), g in [:north, :south]
push!(space, TMazeState(x, g, false))
push!(space, TMazeState(x, g))
end
push!(space, ) # terminal
push!(space, terminalstate) # terminal
return space
end
stateindex(m::TMaze, s::TMazeState) = 2*s.x - (s.g==:north)
stateindex(m::TMaze, s::TerminalState) = 2*(m.n+1) + 1

# 4 actions: go North, East, South, West (1, 2, 3, 4)
actions(maze::TMaze) = 1:4
actionindex(maze::TMaze, i::Int) = i

# 5 observations: 2 for goal (left or right) + 2 for in corridor or at intersection + 1 term
observations(maze::TMaze) = 1:5
obsindex(maze::TMaze, i::Int) = i

# transition distribution (actions are deterministic)
mutable struct TMazeStateDistribution
current_state::TMazeState # deterministic
reset::Bool
reset_states::Vector{TMazeState}
reset_probs::Vector{Float64}
end
function create_transition_distribution(::TMaze)
rs = [TMazeState(1,:north,false), TMazeState(1,:south,false)]
rp = [0.5, 0.5]
TMazeStateDistribution(TMazeState(), false, rs, rp)
end
support(d::TMazeStateDistribution) = d.reset ? (return [d.current_state]) : (return zip(d.reset_states, d.reset_probs))

function pdf(d::TMazeStateDistribution, s::TMazeState)
if d.reset
in(s, d.reset_states) ? (return 0.5) : (return 0.0)
else
s == d.current_state ? (return 1.0) : (return 0.0)
end
end
function rand(rng::AbstractRNG, d::TMazeStateDistribution)
s = TMazeState()
if d.reset
rand(rng) < 0.5 ? (copy!(s, d.reset_states[1])) : (copy!(s, d.reset_states[2]))
return s
else
copy!(s, d.current_state)
return s
end
end
#rand(rng::AbstractRNG, d::TMazeStateDistribution)

struct TMazeInit
states::Vector{TMazeState}
probs::Vector{Float64}
end
support(d::TMazeInit) = d.states
function initialstate_distribution(maze::TMaze)
s = states(maze)
ns = length(s)
p = zeros(ns) .+ 1.0 / (ns-1)
p[end] = 0.0
#s1 = TMazeState(1, :north, false)
#s2 = TMazeState(1, :south, false)
#d = TMazeInit([s1, s2])
return TMazeInit(s, p)
end
function rand(rng::AbstractRNG, d::TMazeInit)
cat = Weights(d.probs)
idx = sample(rng, cat)
return d.states[idx]
end
function pdf(d::TMazeInit, s::TMazeState)
for i = 1:length(d.states)
if d.states[i] == s
return d.probs[i]
end
end
return 0.0
return SparseCat(s, p)
end

# observation distribution (deterministic)
mutable struct TMazeObservationDistribution
current_observation::Int64
end
create_observation_distribution(::TMaze) = TMazeObservationDistribution(1)
iterator(d::TMazeObservationDistribution) = [d.current_observation]

pdf(d::TMazeObservationDistribution, o::Int64) = Float64(o == d.current_observation)
support(d::TMazeObservationDistribution) = d.current_observation
rand(rng::AbstractRNG, d::TMazeObservationDistribution) = d.current_observation

function transition(maze::TMaze, s::TMazeState, a::Int64)
d=create_transition_distribution(maze)
d.reset = false
# check if terminal
if s.term
# reset
d.reset = true
#copy!(d.current_state, s) # state doesn't change
return d
end
# check if move into terminal move north or south
if s.x == maze.n + 1
if a == 1 || a == 3
d.current_state = TMazeState(1,:none,true) # state now terminal
return d
elseif a == 4
copy!(d.current_state, s)
d.current_state.x -= 1
return d
function transition(m::TMaze, s::TMazeState, a::Int64)
if a == 1 || a == 3
if s.x == m.n + 1
Deterministic(terminalstate)
else
copy!(d.current_state, s)
return d
Deterministic(s)
end
elseif a == 2
xp = min(s.x + 1, m.n + 1)
return Deterministic(TMazeState(xp, s.g))
elseif a == 4
xp = max(s.x - 1, 1)
return Deterministic(TMazeState(xp, s.g))
end
# check if move along hallway
if a == 2
copy!(d.current_state, s)
d.current_state.x += 1
return d
end
if a == 4
copy!(d.current_state, s)
s.x > 1 ? (d.current_state.x -= 1) : (nothing)
return d
end
# if none of the above just stay in place
copy!(d.current_state, s)
return d
end

function reward(maze::TMaze, s::TMazeState, a::Int64)
# check terminal
s.term ? (return 0.0) : (nothing)
# check if at junction
if s.x == maze.n + 1
transition(m::TMaze, s::TerminalState, a::Int64) = Deterministic(s)

function reward(m::TMaze, s::TMazeState, a::Int64)
if s.x == m.n + 1
# if at junction check action
if (s.g == :north && a == 1) || (s.g == :south && a == 3)
return 4.0
elseif (s.g == :north && a == 3) || (s.g == :south && a == 1)
return -0.1
else
return -0.1
end
end
# if bump against wall
if s.x < maze.n + 1 && (a == 1 || a == 3)
elseif a == 1 || a == 3
# bump against wall
return -0.1
else
return 0.0
end
return 0.0
end

# observation mapping
# 1 2 3 4 5
# goal N goal S corridor junction terminal
function observation(maze::TMaze, sp::TMazeState)
d::TMazeObservationDistribution = create_observation_distribution(maze)
sp.term ? (d.current_observation = 5; return d) : (nothing)
x = sp.x; g = sp.g
#if x == 1
if x <= 2
g == :north ? (d.current_observation = 1) : (d.current_observation = 2)
return d
end
if 1 < x < (maze.n + 1)
d.current_observation = 3
return d
end
if x == maze.n + 1
d.current_observation = 4
return d
function observation(m::TMaze, sp::TMazeState)
if sp.x <= 2
if sp.g == :north
return Deterministic(1)
else
return Deterministic(2)
end
elseif sp.x == m.n+1
return Deterministic(4)
else
return Deterministic(3)
end
d.current_observation = 5
return d
end

isterminal(m::TMaze, s::TMazeState) = s.term
observation(m::TMaze, sp::TerminalState) = Deterministic(5)

discount(m::TMaze) = m.discount

stateindex(maze::TMaze, s::TMazeState) = s.term ? (2*(maze.n+1) + 1) : (2*s.x - (s.g==:north))
function POMDPs.convert_s(::Type{A}, s::Union{TMazeState,TerminalState}, m::TMaze) where A <: AbstractArray
return convert(A, [stateindex(m, s)])
end

function Base.convert(maze::TMaze, s::TMazeState)
v = Array{Float64}(undef, 2)
v[1] = s.x
s.g == :north ? (v[2] = 0.0) : (v[2] = 1.0)
return v
# inverse of stateindex(m::TMaze, s::TMazeState) = 2*s.x - (s.g==:north)
function POMDPs.convert_s(::Type{S}, v::AbstractVector, m::TMaze) where S <: Union{TMazeState,TerminalState}
i = first(v)
if i == 2*(m.n + 1) + 1
return terminalstate
end

if i%2 == 0
g = :south
else
g = :north
end
x = div(i-1, 2) + 1
@assert x <= m.n + 1
return TMazeState(x, g)
end


struct MazeBelief
last_obs::Int64
mem::Symbol # memory
Expand All @@ -220,8 +137,6 @@ function POMDPs.update(bu::MazeUpdater, b::MazeBelief, a, o)
return MazeBelief(o, mem)
end



mutable struct MazeOptimal <: Policy end
POMDPs.updater(p::MazeOptimal) = MazeUpdater()

Expand Down
18 changes: 12 additions & 6 deletions test/tmaze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,24 @@ simulate(sim, problem, policy, updater(policy), initialstate_distribution(proble

POMDPTesting.probability_check(problem)

function test_obs(s::TMazeState, o::Int64)
function test_obs(s, o)
ot = gen(DDNNode{:o}(), problem, s, MersenneTwister(1))
@test ot == o
end

test_obs(TMazeState(1, :north, false), 1) # north sign
test_obs(TMazeState(1, :south, false), 2) # south sign
test_obs(TMazeState(5, :south, false), 3) # corridor
test_obs(TMazeState(11, :south, false), 4) # junction
test_obs(TMazeState(11, :south, true), 5) # terminal
test_obs(TMazeState(1, :north), 1) # north sign
test_obs(TMazeState(1, :south), 2) # south sign
test_obs(TMazeState(5, :south), 3) # corridor
test_obs(TMazeState(11, :south), 4) # junction
test_obs(terminalstate, 5) # terminal

ov = convert_o(Array{Float64}, 1, problem)
@test ov == [1.]
o = convert_o(Int64, ov, problem)
@test o == 1

for s in states(problem)
v = convert_s(Vector{Float64}, s, problem)
s2 = convert_s(Union{TerminalState,TMazeState}, v, problem)
@test s2 == s
end

0 comments on commit b3c3dc7

Please sign in to comment.