Skip to content

Commit

Permalink
Merge fde04d9 into e6fb186
Browse files Browse the repository at this point in the history
  • Loading branch information
Shushman committed Oct 10, 2019
2 parents e6fb186 + fde04d9 commit bf84e96
Show file tree
Hide file tree
Showing 17 changed files with 48 additions and 103 deletions.
6 changes: 2 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ language: julia

julia:
- 1.0
- 1
- 1.2

os:
- linux
- osx
- windows

notifications:
email:
- maxim.bouton@gmail.com

# script:
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
Expand Down
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "POMDPModels"
uuid = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
repo = "https://github.com/JuliaPOMDP/POMDPModels.jl"
version = "0.3.5"
version = "0.4.0"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
Expand All @@ -18,7 +18,8 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
POMDPSimulators = "< 0.3.0"
POMDPSimulators = "0.3"
POMDPs = "0.8.1, 0.8"
julia = "1"

[extras]
Expand All @@ -27,7 +28,6 @@ POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
POMDPTesting = "92e6a534-49c2-5324-9027-86e3c861ab81"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TikzPictures = "37f6aa50-8035-52d0-81c2-5a1d08754b2d"

[targets]
test = ["NBInclude", "POMDPPolicies", "POMDPSimulators", "POMDPTesting", "Test", "TikzPictures"]
test = ["NBInclude", "POMDPPolicies", "POMDPSimulators", "POMDPTesting", "Test"]
17 changes: 5 additions & 12 deletions src/CryingBabies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ updater(problem::BabyPOMDP) = DiscreteUpdater(problem)

actions(::BabyPOMDP) = (true, false)
actionindex(::BabyPOMDP, a::Bool) = a + 1
n_actions(::BabyPOMDP) = 2
states(::BabyPOMDP) = (true, false)
stateindex(::BabyPOMDP, s::Bool) = s + 1
n_states(::BabyPOMDP) = 2
observations(::BabyPOMDP) = (true, false)
obsindex(::BabyPOMDP, o::Bool) = o + 1
n_observations(::BabyPOMDP) = 2


# start knowing baby is not not hungry
initialstate_distribution(::BabyPOMDP) = BoolDistribution(0.0)
Expand All @@ -39,7 +37,7 @@ function transition(pomdp::BabyPOMDP, s::Bool, a::Bool)
end
end

function observation(pomdp::BabyPOMDP, a::Bool, sp::Bool)
function observation(pomdp::BabyPOMDP, sp::Bool)
if sp # hungry
return BoolDistribution(pomdp.p_cry_when_hungry)
else
Expand All @@ -61,22 +59,17 @@ end

discount(p::BabyPOMDP) = p.discount

function generate_o(p::BabyPOMDP, s::Bool, rng::AbstractRNG)
d = observation(p, true, s) # obs distrubtion not action dependant
return rand(rng, d)
end

# some example policies
mutable struct Starve <: Policy end
struct Starve <: Policy end
action(::Starve, ::B) where {B} = false
updater(::Starve) = NothingUpdater()

mutable struct AlwaysFeed <: Policy end
struct AlwaysFeed <: Policy end
action(::AlwaysFeed, ::B) where {B} = true
updater(::AlwaysFeed) = NothingUpdater()

# feed when the previous observation was crying - this is nearly optimal
mutable struct FeedWhenCrying <: Policy end
struct FeedWhenCrying <: Policy end
updater(::FeedWhenCrying) = PreviousObservationUpdater()
function action(::FeedWhenCrying, b::Union{Nothing, Bool})
if b == nothing || b == false # not crying (or null)
Expand Down
10 changes: 5 additions & 5 deletions src/InvertedPendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
end

actions(ip::InvertedPendulum) = [-50., 0., 50.]
n_actions(ip::InvertedPendulum) = 3

function initialstate(ip::InvertedPendulum, rng::AbstractRNG)
sp = ((rand(rng)-0.5)*0.1, (rand(rng)-0.5)*0.1, )
Expand Down Expand Up @@ -53,10 +52,11 @@ function euler(m::InvertedPendulum,s::Tuple{Float64,Float64},a::Float64)
return (th_,w_)
end

function generate_s(ip::InvertedPendulum,
s::Tuple{Float64,Float64},
a::Float64,
rng::AbstractRNG)
function gen(::DDNNode{:sp},
ip::InvertedPendulum,
s::Tuple{Float64,Float64},
a::Float64,
rng::AbstractRNG)
a_offset = 20*(rand(rng)-0.5)
a_ = a + a_offset

Expand Down
1 change: 0 additions & 1 deletion src/LightDark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ isterminal(::LightDark1D, s::LightDark1DState) = s.status < 0


actions(::LightDark1D) = -1:1
n_actions(p::LightDark1D) = length(actions(p))


struct LDNormalStateDist
Expand Down
10 changes: 5 additions & 5 deletions src/MountainCar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
end

actions(::MountainCar) = [-1., 0., 1.]
n_actions(mc::MountainCar) = 3

reward(mc::MountainCar,
s::Tuple{Float64,Float64},
Expand All @@ -20,10 +19,11 @@ initialstate(mc::MountainCar, ::AbstractRNG) = (-0.5,0.,)
isterminal(::MountainCar,s::Tuple{Float64,Float64}) = s[1] >= 0.5
discount(mc::MountainCar) = mc.discount

function generate_s( mc::MountainCar,
s::Tuple{Float64,Float64},
a::Float64,
::AbstractRNG)
function gen(::DDNNode{:sp},
mc::MountainCar,
s::Tuple{Float64,Float64},
a::Float64,
::AbstractRNG)
x,v = s
v_ = v + a*0.001+cos(3*x)*-0.0025
v_ = max(min(0.07,v_),-0.07)
Expand Down
16 changes: 5 additions & 11 deletions src/POMDPModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ import Base: ==, hash
import Random: rand, rand!
import Distributions: pdf

import POMDPs: initialstate, generate_s, generate_o, generate_sor, support, discount, isterminal
import POMDPs: actions, n_actions, actionindex, action, dimensions
import POMDPs: states, n_states, stateindex, transition
import POMDPs: observations, observation, n_observations, obsindex
import POMDPs: initialstate, initialstate_distribution
import POMDPs: gen, support, discount, isterminal
import POMDPs: actions, actionindex, action, dimensions
import POMDPs: states, stateindex, transition
import POMDPs: observations, observation, obsindex
import POMDPs: initialstate, initialstate_distribution, initialobs
import POMDPs: updater, update
import POMDPs: reward
import POMDPs: convert_s, convert_a, convert_o
Expand Down Expand Up @@ -113,11 +113,5 @@ export
GridWorldStateSpace,
GridWorldDistribution,
static_reward
# plot

@deprecate GridWorld LegacyGridWorld
@deprecate GridWorld(sx::Int64, sy::Int64; kwargs...) LegacyGridWorld(sx, sy; kwargs...)
@deprecate GridWorld(; kwargs...) LegacyGridWorld(; kwargs...)
export GridWorld

end # module
26 changes: 2 additions & 24 deletions src/TMazes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ end
discount::Float64 = 0.99 # discount factor
end

n_states(m::TMaze) = 2 * (m.n + 1) + 1 # 2*(corr length + 1 (junction)) + 1 (term)
n_actions(::TMaze) = 4
n_observations(::TMaze) = 5

# state space is length of corr + 3 cells at the end
# |G|
Expand Down Expand Up @@ -80,7 +77,7 @@ end
support(d::TMazeInit) = zip(d.states, d.probs)
function initialstate_distribution(maze::TMaze)
s = states(maze)
ns = n_states(maze)
ns = length(s)
p = zeros(ns) .+ 1.0 / (ns-1)
p[end] = 0.0
#s1 = TMazeState(1, :north, false)
Expand Down Expand Up @@ -182,7 +179,7 @@ end
# observation mapping
# 1 2 3 4 5
# goal N goal S corridor junction terminal
function observation(maze::TMaze, a::Int64, sp::TMazeState)
function observation(maze::TMaze, sp::TMazeState)
d::TMazeObservationDistribution = create_observation_distribution(maze)
sp.term ? (d.current_observation = 5; return d) : (nothing)
x = sp.x; g = sp.g
Expand All @@ -202,9 +199,6 @@ function observation(maze::TMaze, a::Int64, sp::TMazeState)
d.current_observation = 5
return d
end
function observation(maze::TMaze, s::TMazeState, a::Int64, sp::TMazeState)
return observation(maze, a, sp)
end

isterminal(m::TMaze, s::TMazeState) = s.term

Expand All @@ -219,22 +213,6 @@ function stateindex(maze::TMaze, s::TMazeState)
end
end

function generate_o(maze::TMaze, s::TMazeState, rng::AbstractRNG)
s.term ? (return 5) : (nothing)
x = s.x; g = s.g
#if x == 1
if x <= 2
g == :north ? (return 1) : (return 2)
end
if 1 < x < (maze.n + 1)
return 3
end
if x == maze.n + 1
return 4
end
return 5
end

function Base.convert(maze::TMaze, s::TMazeState)
v = Array{Float64}(undef, 2)
v[1] = s.x
Expand Down
10 changes: 3 additions & 7 deletions src/Tabular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ pdf(d::DiscreteDistribution, sp::Int64) = d.p[sp] # T(s', a, s)
rand(rng::AbstractRNG, d::DiscreteDistribution) = sample(rng, Weights(d.p))

# MDP and POMDP common methods

n_states(prob::TabularProblem) = size(prob.T, 1)
n_actions(prob::TabularProblem) = size(prob.T, 2)

states(p::TabularProblem) = 1:n_states(p)
actions(p::TabularProblem) = 1:n_actions(p)
states(p::TabularProblem) = 1:size(p.T, 1)
actions(p::TabularProblem) = 1:size(p.T, 2)

stateindex(::TabularProblem, s::Int64) = s
actionindex(::TabularProblem, a::Int64) = a
Expand All @@ -59,7 +55,7 @@ transition(p::TabularProblem, s::Int64, a::Int64) = DiscreteDistribution(view(p.

reward(prob::TabularProblem, s::Int64, a::Int64) = prob.R[s, a]

initialstate_distribution(p::TabularProblem) = DiscreteDistribution(ones(n_states(p))./n_states(p))
initialstate_distribution(p::TabularProblem) = DiscreteDistribution(ones(length(states(p)))./length(states(p)))

# POMDP only methods
n_observations(p::TabularProblem) = size(p.O, 1)
Expand Down
10 changes: 3 additions & 7 deletions src/TigerPOMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ const TIGER_LEFT = true
const TIGER_RIGHT = false


n_states(::TigerPOMDP) = 2
n_actions(::TigerPOMDP) = 3
n_observations(::TigerPOMDP) = 2

# Resets the problem after opening door; does nothing after listening
function transition(pomdp::TigerPOMDP, s::Bool, a::Int64)
p = 1.0
Expand Down Expand Up @@ -81,7 +77,7 @@ end

discount(pomdp::TigerPOMDP) = pomdp.discount_factor

function generate_o(p::TigerPOMDP, s::Bool, rng::AbstractRNG)
d = observation(p, 0, s) # obs distrubtion not action dependant
function initialobs(p::TigerPOMDP, s::Bool, rng::AbstractRNG)
d = observation(p, 0, s) # listen
return rand(rng, d)
end
end
13 changes: 6 additions & 7 deletions src/gridworld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ function POMDPs.states(mdp::SimpleGridWorld)
push!(ss, GWPos(-1,-1))
return ss
end
POMDPs.n_states(mdp::SimpleGridWorld) = prod(mdp.size) + 1

function POMDPs.stateindex(mdp::SimpleGridWorld, s::AbstractVector{Int})
if all(s.>0)
return LinearIndices(mdp.size)[s...]
else
return n_states(mdp)
return prod(mdp.size) + 1 # TODO: Change
end
end

Expand All @@ -59,7 +59,7 @@ POMDPs.initialstate_distribution(mdp::SimpleGridWorld) = GWUniform(mdp.size)

POMDPs.actions(mdp::SimpleGridWorld) = (:up, :down, :left, :right)
Base.rand(rng::AbstractRNG, t::NTuple{L,Symbol}) where L = t[rand(rng, 1:length(t))] # don't know why this doesn't work out of the box
POMDPs.n_actions(mdp::SimpleGridWorld) = 4


const dir = Dict(:up=>GWPos(0,1), :down=>GWPos(0,-1), :left=>GWPos(-1,0), :right=>GWPos(1,0))
const aind = Dict(:up=>1, :down=>2, :left=>3, :right=>4)
Expand All @@ -76,16 +76,15 @@ function POMDPs.transition(mdp::SimpleGridWorld, s::AbstractVector{Int}, a::Symb
return Deterministic(GWPos(-1,-1))
end

destinations = MVector{n_actions(mdp)+1, GWPos}(undef)
destinations = MVector{length(actions(mdp))+1, GWPos}(undef)
destinations[1] = s

# probs = MVector{n_actions(mdp)+1, Float64}()
probs = @MVector(zeros(n_actions(mdp)+1))
probs = @MVector(zeros(length(actions(mdp))+1))
for (i, act) in enumerate(actions(mdp))
if act == a
prob = mdp.tprob # probability of transitioning to the desired cell
else
prob = (1.0 - mdp.tprob)/(n_actions(mdp) - 1) # probability of transitioning to another cell
prob = (1.0 - mdp.tprob)/(length(actions(mdp)) - 1) # probability of transitioning to another cell
end

dest = s + dir[act]
Expand Down
2 changes: 0 additions & 2 deletions src/legacy/GridWorlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ end

actions(mdp::LegacyGridWorld) = [:up, :down, :left, :right]

n_states(mdp::LegacyGridWorld) = mdp.size_x*mdp.size_y+1
n_actions(mdp::LegacyGridWorld) = 4

function reward(mdp::LegacyGridWorld, state::GridWorldState, action::Symbol)
if state.done
Expand Down
4 changes: 2 additions & 2 deletions test/crying.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ let
r = simulate(sim, problem, policy, updater(policy), ib, true)
@test r -100.0 atol=0.01

# test generate_o
o = generate_o(problem, true, MersenneTwister(1))
# test gen(::o,...)
o = gen(DDNNode(:o), problem, true, MersenneTwister(1))
@test o == 1
# test vec
ov = convert_s(Array{Float64}, true, problem)
Expand Down
4 changes: 0 additions & 4 deletions test/legacy_gridworld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,4 @@ let

POMDPTesting.trans_prob_consistency_check(problem)

# test gridworld deprecation - these can be removed once the deprecation period is over
@test GridWorld() isa LegacyGridWorld
@test GridWorld(sx=9) isa LegacyGridWorld
@test GridWorld(9, 9, tp=0.8) isa LegacyGridWorld
end
6 changes: 3 additions & 3 deletions test/lightdark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ let
p = LightDark1D()
@test discount(p) == 0.9
s0 = LightDark1DState(0,0)
s0, _, r = generate_sor(p, s0, +1, rng)
s0, _, r = gen(DDNOut(:sp, :o, :r), p, s0, +1, rng)
@test s0.y == 1.0
@test r == 0
s1, _, r = generate_sor(p, s0, 0, rng)
s1, _, r = gen(DDNOut(:sp, :o, :r), p, s0, 0, rng)
@test s1.status != 0
@test r == -10.0
s2 = LightDark1DState(0, 5)
obs = generate_o(p, nothing, nothing, s2, rng)
obs = gen(DDNNode(:o), p, nothing, nothing, s2, rng)
@test abs(obs-6.0) <= 1.1


Expand Down
4 changes: 2 additions & 2 deletions test/tiger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ let

simulate(sim, pomdp1, policy, updater(policy), initialstate_distribution(pomdp1))

# test generate_o
o = generate_o(pomdp1, true, MersenneTwister(1))
# test gen(:o, ...)
o = initialobs(pomdp1, true, MersenneTwister(1))
@test o == 1
# test vec
ov = convert_o(Array{Float64}, true, pomdp1)
Expand Down

0 comments on commit bf84e96

Please sign in to comment.