Skip to content
This repository was archived by the owner on May 21, 2022. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions src/envs/pendulum.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,59 @@
module PendulumEnv
# Ported from: https://github.com/openai/gym/blob/996e5115621bf57b34b9b79941e629a36a709ea1/gym/envs/classic_control/pendulum.py
# https://github.com/openai/gym/wiki/Pendulum-v0

using Reinforce: AbstractEnvironment
using LearnBase: IntervalSet
using RecipesBase
using Distributions
using Random: seed!

import Reinforce: reset!, actions, finished, step!
import Reinforce: reset!, actions, finished, step!, state

export
Pendulum,
reset!,
step!,
actions,
finished
finished,
state

const max_speed = 8.0
const max_torque = 2.0

"""
    angle_normalize(x)

Wrap the angle `x` (in radians) into the half-open interval `[-π, π)`.

`mod` (floored modulus, result has the sign of the divisor) is used instead of
`%`/`rem` (truncated, result has the sign of the dividend): with `rem`, any
`x < -π` would be returned unchanged instead of being wrapped.
"""
angle_normalize(x) = mod(x + π, 2π) - π

mutable struct PendulumState
θ::Float64
θvel::Float64
mutable struct PendulumState{T<:AbstractFloat} <: AbstractVector{T}
θ::T
θvel::T
end

mutable struct Pendulum <: AbstractEnvironment
state::PendulumState
# Zero-argument convenience constructor: both `θ` and `θvel` start at `0.0`.
PendulumState() = PendulumState(0., 0.)

# AbstractVector interface: a `PendulumState` always reports exactly two elements.
Base.size(::PendulumState) = (2,)

"""
    getindex(s::PendulumState, i::Int)

Return the `i`-th component of the state: index `1` is `s.θ`, index `2` is `s.θvel`.

Throws a `BoundsError` for any index outside `1:length(s)`. The lower bound is
checked as well, so that `s[0]` or a negative index cannot silently fall through
to `θvel`.
"""
function Base.getindex(s::PendulumState, i::Int)
    (1 <= i <= length(s)) || throw(BoundsError(s, i))
    return ifelse(i == 1, s.θ, s.θvel)
end

"""
    setindex!(s::PendulumState, x, i::Int)

Store `x` into the `i`-th component of the state: index `1` writes `s.θ`,
index `2` writes `s.θvel`.

Throws a `BoundsError` for any index outside `1:length(s)`. The lower bound is
checked as well, so that `s[0] = x` or a negative index cannot silently overwrite
`θvel`.
"""
function Base.setindex!(s::PendulumState, x, i::Int)
    (1 <= i <= length(s)) || throw(BoundsError(s, i))
    return setproperty!(s, ifelse(i == 1, :θ, :θvel), x)
end

mutable struct Pendulum{V<:AbstractVector} <: AbstractEnvironment
state::V
reward::Float64
a::Float64 # last action for rendering
steps::Int
maxsteps::Int
end
Pendulum(maxsteps=500) = Pendulum(PendulumState(0.,0.),0.,0.,0,maxsteps)

# Convenience constructor: default (zeroed) state, zero reward, zero last action,
# step counter at 0, and `maxsteps` (default 500) as the step limit.
Pendulum(maxsteps = 500) = Pendulum(PendulumState(),0., 0., 0, maxsteps)

function reset!(env::Pendulum)
env.state.θ = rand(Uniform(-π, π))
env.state.θvel = rand(Uniform(-1., 1.))
env.state = PendulumState(rand(Uniform(-π, π)), rand(Uniform(-1., 1.)))
env.reward = 0.0
env.a = 0.0
env.steps = 0
Expand All @@ -47,8 +63,7 @@ end
# Admissible actions: an `IntervalSet` from `-max_torque` to `max_torque`.
# The result is the same for every environment and state; `env` is used only for
# dispatch and `s` is ignored.
actions(env::Pendulum, s) = IntervalSet(-max_torque, max_torque)

function step!(env::Pendulum, s, a)
θ = env.state.θ
θvel = env.state.θvel
θ, θvel = env.state
g = 10.0
m = 1.0
l = 1.0
Expand All @@ -62,17 +77,15 @@ function step!(env::Pendulum, s, a)
newθvel = θvel + (-1.5g/l * sin(θ+π) + 3/(m*l^2)*a) * dt
newθ = θ + newθvel * dt
newθvel = clamp(newθvel, -max_speed, max_speed)
env.state.θ = newθ
env.state.θvel = newθvel
env.state = PendulumState(newθ, newθvel)

env.steps += 1
env.reward, state(env)
env.reward, env.state
end


# Build a 3-element `Float64` vector `[cos(θ), sin(θ), θvel]` from the
# environment's current `env.state`.
# NOTE(review): `θ` and `θvel` are first read via field access and then assigned
# again by destructuring `env.state`; the two forms look redundant (possibly an
# old and a new version of the same statement) — confirm which one is intended.
function state(env::Pendulum)
θ = env.state.θ
θvel = env.state.θvel
θ, θvel = env.state
Float64[cos(θ), sin(θ), θvel]
end

Expand Down Expand Up @@ -103,7 +116,7 @@ finished(env::Pendulum, s′) = env.steps >= env.maxsteps
seriestype := :scatter
markersize := 10
markercolor := :black
annotations := [(0, -0.2, "a: $(env.a)", :top)]
annotations := [(0, -0.2, "a: $(round(env.a, digits = 4))", :top)]
[0],[0]
end
end
Expand Down