diff --git a/src/GridWorlds.jl b/src/GridWorlds.jl index 244aed7..4198729 100644 --- a/src/GridWorlds.jl +++ b/src/GridWorlds.jl @@ -60,17 +60,19 @@ mutable struct GridWorld <: MDP{GridWorldState, Symbol} discount_factor::Float64 # disocunt factor end # we use key worded arguments so we can change any of the values we pass in -function GridWorld(;sx::Int64=10, # size_x - sy::Int64=10, # size_y - rs::Vector{GridWorldState}=[GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)], - rv::Vector{Float64}=[-10.,-5,10,3], - penalty::Float64=0.0, # penalty for trying to go out of bounds (will be added to reward) - tp::Float64=0.7, # tprob - discount_factor::Float64=0.95, - terminals=Set{GridWorldState}([rs[i] for i in filter(i->rv[i]>0.0, 1:length(rs))])) +function GridWorld(sx::Int64=10, # size_x + sy::Int64=10; # size_y + rs::Vector{GridWorldState}=[GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)], + rv::Vector{Float64}=[-10.,-5,10,3], + penalty::Float64=0.0, # penalty for trying to go out of bounds (will be added to reward) + tp::Float64=0.7, # tprob + discount_factor::Float64=0.95, + terminals=Set{GridWorldState}([rs[i] for i in filter(i->rv[i]>0.0, 1:length(rs))])) return GridWorld(sx, sy, rs, rv, penalty, tp, Set{GridWorldState}(terminals), discount_factor) end +GridWorld(;sx::Int64=10, sy::Int64=10, kwargs...) = GridWorld(sx, sy; kwargs...) + # convenience function function term_from_rs(rs, rv) terminals = Set{GridWorldState}() @@ -101,44 +103,6 @@ end # returns the action space actions(mdp::GridWorld, s=nothing) = [:up, :down, :left, :right] -################################################################# -# Distributions -################################################################# - -struct GridWorldDistribution - neighbors::SVector{5, GridWorldState} - probs::SVector{5, Float64} -end - - -# returns an iterator over the distirubtion -function POMDPs.iterator(d::GridWorldDistribution) - return d.neighbors -end - -function pdf(d::GridWorldDistribution, s::GridWorldState) - for (i, sp) in enumerate(d.neighbors) - if s == sp - return d.probs[i] - end - end - return 0.0 -end - -function rand(rng::AbstractRNG, d::GridWorldDistribution) - # assume the sum of d.probs is one - t = rand(rng) - n = length(d.neighbors) - i = 1 - c = d.probs[1] - while c < t && i < n - i += 1 - @inbounds c += d.probs[i] - end - new = d.neighbors[i] - return GridWorldState(new.x, new.y, new.done) -end - n_states(mdp::GridWorld) = mdp.size_x*mdp.size_y+1 n_actions(mdp::GridWorld) = 4 @@ -238,7 +202,7 @@ function transition(mdp::GridWorld, state::GridWorldState, action::Symbol) if state.done fill_probability!(probability, 1.0, 5) neighbors[5] = GridWorldState(x, y, true) - return GridWorldDistribution(neighbors, probability) + return SparseCat(neighbors, probability) end reward_states = mdp.reward_states @@ -247,7 +211,7 @@ function transition(mdp::GridWorld, state::GridWorldState, action::Symbol) if state in mdp.terminals fill_probability!(probability, 1.0, 5) neighbors[5] = GridWorldState(x, y, true) - return GridWorldDistribution(neighbors, probability) + return SparseCat(neighbors, probability) end # The following match the definition of neighbors @@ -289,7 +253,7 @@ function transition(mdp::GridWorld, state::GridWorldState, action::Symbol) end end - return GridWorldDistribution(neighbors, probability) + return SparseCat(neighbors, probability) end