Merge pull request #17 from JuliaML/ib/cartpole
interface: introduce maxsteps for CartPole v0 and v1

close #16
iblislin committed Feb 18, 2018
2 parents ad8eb0d + 44bee14 commit 3968119
Showing 9 changed files with 264 additions and 125 deletions.
13 changes: 13 additions & 0 deletions NEWS.md
@@ -0,0 +1,13 @@
# v0.1.0

- New interface for controlling episode termination: `maxsteps(env)::Int` ([#17]).
  The termination condition is now `finished(...) || maxsteps(...)`.

- New field for the `CartPole` environment: `maxsteps`.
  A constructor keyword argument has been added: `CartPole(; maxsteps = 42)` ([#17]).

  There are also helper constructors for CartPole v0 and v1:
  - `CartPoleV0()`: equivalent to `CartPole(maxsteps = 200)`
  - `CartPoleV1()`: equivalent to `CartPole(maxsteps = 500)`

[#17]: https://github.com/JuliaML/Reinforce.jl/pull/17
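
As a quick illustration of the entries above, the sketch below constructs the new CartPole variants and queries their step limits. It relies only on the constructors and the `maxsteps` accessor introduced in this pull request.

```julia
using Reinforce

env0 = CartPoleV0()             # equivalent to CartPole(maxsteps = 200)
env1 = CartPoleV1()             # equivalent to CartPole(maxsteps = 500)
env  = CartPole(maxsteps = 42)  # custom limit via the new keyword argument

maxsteps(env0)   # -> 200
maxsteps(env1)   # -> 500
maxsteps(env)    # -> 42
```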
28 changes: 21 additions & 7 deletions README.md
@@ -17,26 +17,40 @@ Packages which build on Reinforce:

---

New environments are created by subtyping `AbstractEnvironment` and implementing a few methods:

- `reset!(env)`
- `actions(env, s) -> A`
- `step!(env, s, a) -> (r, s′)`
- `finished(env, s′)`

and optional overrides:

- `state(env) -> s`
- `reward(env) -> r`

which fall back to `env.state` and `env.reward`, respectively, when not overridden.

- `ismdp(env) -> Bool`

An environment may be fully observable (MDP) or partially observable (POMDP).
In the case of a partially observable environment, the state `s` is really
an observation `o`. To maintain consistency, we call everything a state,
and assume that an environment is free to maintain additional (unobserved)
internal state. The `ismdp` query returns true when the environment is an MDP,
and false otherwise.

- `maxsteps(env) -> Int`

An episode terminates when `finished()` returns true or when the step count
reaches `maxsteps()` (i.e. `maxsteps() || finished()`).
The default value is `0`, which means the episode length is unlimited.
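
For concreteness, here is a minimal sketch of a custom environment built against this interface. The environment, its dynamics, and its reward are invented purely for illustration; only the method names and the `state`/`reward` field convention come from the description above.

```julia
using Reinforce
import Reinforce: reset!, actions, step!, finished, maxsteps

# A toy 1-D random walk: the agent steps left or right, and the episode
# ends once it wanders `goal` cells away from the origin.
mutable struct RandomWalk <: AbstractEnvironment
    state::Int
    reward::Float64
    goal::Int
end
RandomWalk(; goal = 5) = RandomWalk(0, 0.0, goal)

reset!(env::RandomWalk) = (env.state = 0; env.reward = 0.0; env)
actions(env::RandomWalk, s) = (-1, 1)             # any collection of valid actions
finished(env::RandomWalk, s′) = abs(s′) >= env.goal
maxsteps(env::RandomWalk) = 100                   # optional; 0 means unlimited

function step!(env::RandomWalk, s, a)
    env.state = s + a
    env.reward = finished(env, env.state) ? 1.0 : 0.0
    env.reward, env.state                         # (r, s′)
end
```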

---

A minimal example for testing purposes is `test/foo.jl`.

TODO: more details and examples

---
103 changes: 56 additions & 47 deletions src/Reinforce.jl
@@ -15,92 +15,100 @@ using LearningStrategies
import LearningStrategies: pre_hook, iter_hook, finished, post_hook

export
    AbstractEnvironment,
    reset!,
    step!,
    reward,
    state,
    finished,
    actions,
    ismdp,
    maxsteps,

    AbstractPolicy,
    RandomPolicy,
    OnlineGAE,
    OnlineActorCritic,
    EpisodicActorCritic,
    action,

    AbstractState,
    StateVector,
    History,
    state!,

    Episode,
    Episodes,
    run_episode


# ----------------------------------------------------------------
# Implement this interface for a new environment

abstract type AbstractEnvironment end


"""
    reset!(env)
Reset an environment.
"""
function reset! end


"""
    r, s′ = step!(env, s, a)
Move the simulation forward, collecting a reward and getting the next state.
"""
function step! end


# note for developers: you should also implement Base.done(env) for episodic environments
finished(env::AbstractEnvironment, s′) = false


"""
    A′ = actions(env, s′)
Return a list/set/description of valid actions from state `s′`.
"""
# actions(env::AbstractEnvironment) = actions(env, state(env))
function actions end


# note for developers: you don't need to implement these if you have state/reward fields

"""
    s = state(env)
Return the current state of the environment.
"""
state(env::AbstractEnvironment) = env.state

"""
    r = reward(env)
Return the current reward of the environment.
"""
reward(env::AbstractEnvironment) = env.reward

"""
    ismdp(env)::Bool
An environment may be fully observable (MDP) or partially observable (POMDP).
In the case of a partially observable environment,
the state `s` is really an observation `o`.
To maintain consistency, we call everything a state, and assume that an
environment is free to maintain additional (unobserved) internal state.
The `ismdp` query returns true when the environment is MDP, and false otherwise.
"""
ismdp(env::AbstractEnvironment) = false

"""
    maxsteps(env)::Int
Return the maximum number of steps in a single episode.
Default is `0` (unlimited).
"""
maxsteps(env::AbstractEnvironment) = 0

# ----------------------------------------------------------------
# Implement this interface for a new policy
@@ -135,11 +143,11 @@ include("envs/mountain_car.jl")
# a keyboard action space

struct KeyboardAction
    key
end

mutable struct KeyboardActionSet{T} <: AbstractSet{T}
    keys::Vector
end

LearnBase.randtype(s::KeyboardActionSet) = KeyboardAction
@@ -151,23 +159,24 @@ Base.length(s::KeyboardActionSet) = 1
# a mouse/pointer action space

struct MouseAction
    x::Int
    y::Int
    button::Int
end

mutable struct MouseActionSet{T} <: AbstractSet{T}
    screen_width::Int
    screen_height::Int
    button::DiscreteSet{Vector{Int}}
end

LearnBase.randtype(s::MouseActionSet) = MouseAction
Base.rand(s::MouseActionSet) =
    MouseAction(rand(1:s.screen_width), rand(1:s.screen_height), rand(s.button))
Base.in(a::MouseAction, s::MouseActionSet) =
    a.x in 1:s.screen_width && a.y in 1:s.screen_height && a.button in s.button
Base.length(s::MouseActionSet) = 1


# ----------------------------------------------------------------

end # module Reinforce
96 changes: 51 additions & 45 deletions src/envs/cartpole.jl
@@ -20,63 +20,69 @@ const x_threshold = 2.4
mutable struct CartPole <: AbstractEnvironment
    state::Vector{Float64}
    reward::Float64
    maxsteps::Int  # max steps in each episode
end
CartPole(; maxsteps = 0) = CartPole(0.1rand(4)-0.05, 0.0, maxsteps)

# see https://github.com/FluxML/model-zoo/pull/23#issuecomment-366030179
CartPoleV0() = CartPole(maxsteps = 200)
CartPoleV1() = CartPole(maxsteps = 500)

reset!(env::CartPole) = (env.state = 0.1rand(4)-0.05; env.reward = 0.0; return)
actions(env::CartPole, s) = DiscreteSet(1:2)
maxsteps(env::CartPole) = env.maxsteps

function step!(env::CartPole, s, a)
    s = state(env)
    x, xvel, θ, θvel = s

    force = (a == 1 ? -1 : 1) * force_mag
    tmp = (force + mass_pole_length * sin(θ) * (θvel^2)) / total_mass
    θacc = (gravity * sin(θ) - tmp * cos(θ)) /
           (pole_length * (4/3 - mass_pole * (cos(θ)^2) / total_mass))
    xacc = tmp - mass_pole_length * θacc * cos(θ) / total_mass

    # update state
    s[1] = x += τ * xvel
    s[2] = xvel += τ * xacc
    s[3] = θ += τ * θvel
    s[4] = θvel += τ * θacc

    env.reward = finished(env, s) ? 0.0 : 1.0
    env.reward, s
end

function finished(env::CartPole, s′)
    x, xvel, θ, θvel = s′
    !(-x_threshold <= x <= x_threshold &&
      -θ_threshold <= θ <= θ_threshold)
end


# ------------------------------------------------------------------------

@recipe function f(env::CartPole)
    x, xvel, θ, θvel = state(env)
    legend := false
    xlims := (-x_threshold, x_threshold)
    ylims := (-Inf, 2pole_length)
    grid := false
    ticks := nothing

    # pole
    @series begin
        linecolor := :red
        linewidth := 10
        [x, x + 2pole_length * sin(θ)], [0.0, 2pole_length * cos(θ)]
    end

    # cart
    @series begin
        seriescolor := :black
        seriestype := :shape
        hw = 0.5
        l, r = x-hw, x+hw
        t, b = 0.0, -0.1
        [l, r, r, l], [t, t, b, b]
    end
end
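
To tie the pieces together, here is a hedged usage sketch that drives the environment above with random actions, stopping on either `finished` or the `maxsteps` limit added in this pull request. It uses only methods visible in this diff and assumes `rand` can draw from the `DiscreteSet` returned by `actions`.

```julia
using Reinforce

function rollout(env)
    reset!(env)
    s = state(env)
    total_reward, nsteps = 0.0, 0
    while true
        a = rand(actions(env, s))      # draw a random action from the action set
        r, s = step!(env, s, a)
        total_reward += r
        nsteps += 1
        # mirror the termination condition from the README/NEWS entries
        finished(env, s) && break
        maxsteps(env) > 0 && nsteps >= maxsteps(env) && break
    end
    total_reward, nsteps
end

rollout(CartPoleV1())   # CartPole(maxsteps = 500)
```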