Skip to content

Commit

Permalink
Merge pull request #25 from JuliaPOMDP/tiny_fixes
Browse files Browse the repository at this point in the history
Tiny fixes
  • Loading branch information
zsunberg committed Mar 24, 2018
2 parents 1888fb3 + 3cdeafd commit dfcd739
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 14 deletions.
10 changes: 0 additions & 10 deletions src/GridWorlds.jl
Expand Up @@ -74,16 +74,6 @@ end

GridWorld(;sx::Int64=10, sy::Int64=10, kwargs...) = GridWorld(sx, sy; kwargs...)

# convenience function
function term_from_rs(rs, rv)
terminals = Set{GridWorldState}()
for (i,v) in enumerate(rv)
if v > 0.0
push!(terminals, rs[i])
end
end
end


#################################################################
# State and Action Spaces
Expand Down
4 changes: 2 additions & 2 deletions src/Random.jl
@@ -1,4 +1,4 @@
function RandomMDP(ns::Int64, na::Int64, discount::Float64; rng::AbstractRNG=MersenneTwister())
function RandomMDP(ns::Int64, na::Int64, discount::Float64; rng::AbstractRNG=Base.GLOBAL_RNG)
# random dynamics
T = rand(rng, ns, na, ns)
# normalize
Expand All @@ -12,7 +12,7 @@ end
RandomMDP() = RandomMDP(100, 5, 0.9)


function RandomPOMDP(ns::Int64, na::Int64, no::Int64, discount::Float64; rng::AbstractRNG=MersenneTwister())
function RandomPOMDP(ns::Int64, na::Int64, no::Int64, discount::Float64; rng::AbstractRNG=Base.GLOBAL_RNG)
# random dynamics
T = rand(rng, ns, na, ns)
# random observation model
Expand Down
4 changes: 2 additions & 2 deletions test/crying.jl
Expand Up @@ -8,10 +8,10 @@ problem = BabyPOMDP()

# starve policy
# when the baby is never fed, the reward for starting in the hungry state should be -100
sim = RolloutSimulator(eps=0.0001, initial_state=true)
sim = RolloutSimulator(eps=0.0001)
ib = nothing
policy = Starve()
r = simulate(sim, problem, policy, updater(policy), ib)
r = simulate(sim, problem, policy, updater(policy), ib, true)
@test r -100.0 atol=0.01

# test generate_o
Expand Down
4 changes: 4 additions & 0 deletions test/random.jl
Expand Up @@ -25,3 +25,7 @@ ov = convert_o(Array{Float64}, 1, pomdp)
@test ov == [1.]
o = convert_o(Int, ov, pomdp)
@test o == 1

# to catch anything in the default constructors
RandomPOMDP()
RandomMDP()

0 comments on commit dfcd739

Please sign in to comment.