Skip to content

Commit

Permalink
Merge pull request #21 from JuliaPOMDP/fast_distrib
Browse files Browse the repository at this point in the history
got rid of GridWorldDistribution
  • Loading branch information
zsunberg committed Oct 9, 2017
2 parents 64ca404 + f62dc0a commit 252a71a
Showing 1 changed file with 13 additions and 49 deletions.
62 changes: 13 additions & 49 deletions src/GridWorlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,19 @@ mutable struct GridWorld <: MDP{GridWorldState, Symbol}
discount_factor::Float64 # discount factor
end
# we use keyword arguments so we can change any of the values we pass in
function GridWorld(;sx::Int64=10, # size_x
sy::Int64=10, # size_y
rs::Vector{GridWorldState}=[GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)],
rv::Vector{Float64}=[-10.,-5,10,3],
penalty::Float64=0.0, # penalty for trying to go out of bounds (will be added to reward)
tp::Float64=0.7, # tprob
discount_factor::Float64=0.95,
terminals=Set{GridWorldState}([rs[i] for i in filter(i->rv[i]>0.0, 1:length(rs))]))
"""
    GridWorld(sx, sy; rs, rv, penalty, tp, discount_factor, terminals)

Build a `GridWorld` from its x/y sizes plus optional keyword settings.

# Keywords
- `rs`: reward states.
- `rv`: reward values, indexed in parallel with `rs` when computing the default
  `terminals`.
- `penalty`: penalty for trying to go out of bounds (will be added to reward).
- `tp`: tprob (presumably the probability of the intended transition -- confirm
  against `transition`).
- `discount_factor`: discount factor.
- `terminals`: terminal states; by default every entry of `rs` whose matching
  `rv` entry is strictly positive. Converted to a `Set{GridWorldState}` before
  being passed to the positional constructor.
"""
function GridWorld(
    sx::Int64,  # size_x
    sy::Int64;  # size_y
    rs::Vector{GridWorldState}=[
        GridWorldState(4, 3),
        GridWorldState(4, 6),
        GridWorldState(9, 3),
        GridWorldState(8, 8),
    ],
    rv::Vector{Float64}=[-10.0, -5.0, 10.0, 3.0],
    penalty::Float64=0.0,
    tp::Float64=0.7,
    discount_factor::Float64=0.95,
    terminals=Set{GridWorldState}([rs[i] for i in 1:length(rs) if rv[i] > 0.0]),
)
    # Whatever collection the caller supplied, the struct receives a Set.
    terminal_set = Set{GridWorldState}(terminals)
    return GridWorld(sx, sy, rs, rv, penalty, tp, terminal_set, discount_factor)
end

"""
    GridWorld(; sx=10, sy=10, kwargs...)

Keyword-only entry point: `sx` and `sy` default to 10 and every other keyword is
forwarded unchanged to `GridWorld(sx, sy; kwargs...)`.
"""
function GridWorld(; sx::Int64=10, sy::Int64=10, kwargs...)
    return GridWorld(sx, sy; kwargs...)
end

# convenience function
function term_from_rs(rs, rv)
terminals = Set{GridWorldState}()
Expand Down Expand Up @@ -101,44 +103,6 @@ end
"""
    actions(mdp::GridWorld, s=nothing)

Return the action space as a freshly allocated vector of the four symbols
`:up`, `:down`, `:left`, `:right`. The optional `s` argument is accepted but
not used.
"""
function actions(mdp::GridWorld, s=nothing)
    return Symbol[:up, :down, :left, :right]
end

#################################################################
# Distributions
#################################################################

# Discrete distribution over five `GridWorldState`s.
# `pdf` and `rand` pair the two fields by index: `probs[i]` is the probability
# attached to `neighbors[i]`.
struct GridWorldDistribution
# candidate states of the distribution
neighbors::SVector{5, GridWorldState}
# probability of each entry of `neighbors`; `rand` assumes these sum to one
probs::SVector{5, Float64}
end


"""
    POMDPs.iterator(d::GridWorldDistribution)

Return an iterator over the distribution: the `neighbors` container stored in `d`.
"""
POMDPs.iterator(d::GridWorldDistribution) = d.neighbors

"""
    pdf(d::GridWorldDistribution, s::GridWorldState)

Return the probability that `d` stores for state `s`.

The neighbors are scanned in order; the probability paired with the first
neighbor equal to `s` is returned, or `0.0` when no neighbor matches.
"""
function pdf(d::GridWorldDistribution, s::GridWorldState)
    for (candidate, mass) in zip(d.neighbors, d.probs)
        s == candidate && return mass
    end
    return 0.0
end

"""
    rand(rng::AbstractRNG, d::GridWorldDistribution)

Sample a state from `d` using a single draw from `rng`.

The probabilities are accumulated in order until the running total reaches the
drawn value (or the last neighbor is hit), and a new `GridWorldState` built from
the selected neighbor's `x`, `y` and `done` fields is returned. The entries of
`d.probs` are assumed to sum to one.
"""
function rand(rng::AbstractRNG, d::GridWorldDistribution)
    threshold = rand(rng)
    n_neighbors = length(d.neighbors)
    chosen_index = 1
    cumulative = d.probs[1]
    for k in 2:n_neighbors
        # Stop as soon as the running total has reached the drawn value.
        cumulative < threshold || break
        @inbounds cumulative += d.probs[k]
        chosen_index = k
    end
    picked = d.neighbors[chosen_index]
    return GridWorldState(picked.x, picked.y, picked.done)
end

"""
    n_states(mdp::GridWorld)

Return `mdp.size_x * mdp.size_y + 1` (the extra one presumably accounts for a
single terminal state -- confirm against the state space definition).
"""
function n_states(mdp::GridWorld)
    grid_cells = mdp.size_x * mdp.size_y
    return grid_cells + 1
end
"""
    n_actions(mdp::GridWorld)

Return the number of actions, which is always 4.
"""
function n_actions(mdp::GridWorld)
    return 4
end

Expand Down Expand Up @@ -238,7 +202,7 @@ function transition(mdp::GridWorld, state::GridWorldState, action::Symbol)
if state.done
fill_probability!(probability, 1.0, 5)
neighbors[5] = GridWorldState(x, y, true)
return GridWorldDistribution(neighbors, probability)
return SparseCat(neighbors, probability)
end

reward_states = mdp.reward_states
Expand All @@ -247,7 +211,7 @@ function transition(mdp::GridWorld, state::GridWorldState, action::Symbol)
if state in mdp.terminals
fill_probability!(probability, 1.0, 5)
neighbors[5] = GridWorldState(x, y, true)
return GridWorldDistribution(neighbors, probability)
return SparseCat(neighbors, probability)
end

# The following match the definition of neighbors
Expand Down Expand Up @@ -289,7 +253,7 @@ function transition(mdp::GridWorld, state::GridWorldState, action::Symbol)
end
end

return GridWorldDistribution(neighbors, probability)
return SparseCat(neighbors, probability)
end


Expand Down

0 comments on commit 252a71a

Please sign in to comment.