Skip to content
This repository has been archived by the owner on May 21, 2022. It is now read-only.

Commit

Permalink
interface: introduce maxsteps
Browse files Browse the repository at this point in the history
In the case of CartPole v0 and v1,
there is a limit on the number of steps in a single episode.

close #16
  • Loading branch information
iblislin committed Feb 16, 2018
1 parent 5e9034c commit fc6bb5e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/Reinforce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ The `ismdp` query returns true when the environment is MDP, and false otherwise.
"""
ismdp(env::AbstractEnvironment) = false

"""
maxsteps(env)::Int
Return the maximum number of steps in a single episode.
Default is `0` (unlimited).
"""
function maxsteps(env::AbstractEnvironment)
    # Generic fallback: an environment reports 0 unless it defines its own method.
    return 0
end

# ----------------------------------------------------------------
# Implement this interface for a new policy
Expand Down
10 changes: 8 additions & 2 deletions src/envs/cartpole.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ const x_threshold = 2.4
"""
    CartPole

Mutable cart-pole environment.

# Fields
- `state::Vector{Float64}`: state vector of the environment.
- `reward::Float64`: reward value stored on the environment.
- `maxsteps::Int`: max step in each episode.
"""
mutable struct CartPole <: AbstractEnvironment
    state::Vector{Float64}
    reward::Float64
    maxsteps::Int  # max step in each episode
end
"""
    CartPole(; maxsteps = 0)

Create a `CartPole` whose four state components are each drawn uniformly from
`[-0.05, 0.05)` and whose reward is `0.0`.

# Keywords
- `maxsteps`: max step in each episode, stored in the `maxsteps` field.
  Defaults to `0`.
"""
# `.-` broadcasts the scalar over the vector; plain `-` between an array and a
# scalar is not supported on newer Julia versions.
CartPole(; maxsteps = 0) = CartPole(0.1rand(4) .- 0.05, 0.0, maxsteps)

"""
    reset!(env::CartPole)

Replace `env.state` with four fresh values drawn uniformly from
`[-0.05, 0.05)`, set `env.reward` to `0.0`, and return `nothing`.
"""
function reset!(env::CartPole)
    # `.-` broadcasts the scalar over the vector; plain `-` between an array
    # and a scalar is not supported on newer Julia versions.
    env.state = 0.1rand(4) .- 0.05
    env.reward = 0.0
    return
end

# see https://github.com/FluxML/model-zoo/pull/23#issuecomment-366030179
"""
    CartPoleV0()

Create a `CartPole` with `maxsteps = 200`.
"""
CartPoleV0() = CartPole(maxsteps = 200)

"""
    CartPoleV1()

Create a `CartPole` with `maxsteps = 500`.
"""
CartPoleV1() = CartPole(maxsteps = 500)

"""
    actions(env::CartPole, s)

Return the action set `DiscreteSet(1:2)`. The state argument `s` is not used.
"""
actions(env::CartPole, s) = DiscreteSet(1:2)

"""
    maxsteps(env::CartPole)

Return the step limit stored in `env.maxsteps`.
"""
maxsteps(env::CartPole) = env.maxsteps

function step!(env::CartPole, s, a)
s = state(env)
Expand Down
4 changes: 3 additions & 1 deletion src/episodes/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ function Base.start(ep::Episode)
end

"""
    Base.done(ep::Episode, i)

Return `true` when iteration over the episode `ep` should stop.

The episode stops when `maxsteps(ep.env)` is non-zero and `ep.niter` has reached
it; otherwise the result is whatever `finished(ep.env, state(ep.env))` returns.
The iteration state `i` is not used.
"""
function Base.done(ep::Episode, i)
    limit = maxsteps(ep.env)
    # A limit of 0 means "unlimited", so only a non-zero limit can end the episode.
    if limit != 0 && ep.niter >= limit
        return true
    end
    return finished(ep.env, state(ep.env))
end

# take one step in the environment after querying the policy for an action
Expand Down

0 comments on commit fc6bb5e

Please sign in to comment.