Load only methods with @require (#68)
* fix render

* parametric action

* import OrdinaryDiffEq

* fixes #27

* Array -> AbstractArray

Co-authored-by: Jun Tian <find_my_way@foxmail.com>

* Array -> AbstractArray

Co-authored-by: Jun Tian <find_my_way@foxmail.com>
jbrea and findmyway committed Jun 10, 2020
1 parent 81f682b commit e4108dd
Showing 12 changed files with 141 additions and 124 deletions.
@@ -1,15 +1,18 @@
module ReinforcementLearningEnvironments

using ReinforcementLearningBase
using Random
using GR
using Requires

export RLEnvs
const RLEnvs = ReinforcementLearningEnvironments
export RLEnvs

using Requires

# built-in environments
include("environments/non_interactive/non_interactive.jl")
include("environments/classic_control/classic_control.jl")
include("environments/structs.jl")

# dynamic loading environments
function __init__()
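The point of this commit (per its title) is that optional environments are now pulled in only when their backing package is loaded, via Requires.jl. A minimal sketch of what such a conditionally loading `__init__` can look like — the UUID below is a placeholder and the included file name is an assumption, not copied from this diff:

```julia
using Requires

function __init__()
    # Replace the placeholder UUID with the real one from the General registry.
    # The wrapper file is only included once the user does `using ArcadeLearningEnvironment`.
    @require ArcadeLearningEnvironment = "00000000-0000-0000-0000-000000000000" include(
        "environments/atari.jl",
    )
end
```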
40 changes: 9 additions & 31 deletions src/ReinforcementLearningEnvironments/src/environments/atari.jl
@@ -1,22 +1,5 @@
module AtariWrapper

using ArcadeLearningEnvironment, GR, Random
using ReinforcementLearningBase

export AtariEnv

mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv
ale::Ptr{Nothing}
screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling
actions::Vector{Int64}
action_space::DiscreteSpace{UnitRange{Int}}
observation_space::MultiDiscreteSpace{Array{UInt8,N}}
noopmax::Int
frame_skip::Int
reward::Float32
lives::Int
seed::S
end
using .ArcadeLearningEnvironment


"""
AtariEnv(;kwargs...)
@@ -27,7 +10,7 @@ TODO: support seed! in single/multi thread
# Keywords
- `name::String="pong"`: name of the Atari environments. Use `getROMList` to show all supported environments.
- `name::String="pong"`: name of the Atari environments. Use `ReinforcementLearningEnvironments.list_atari_rom_names()` to show all supported environments.
- `grayscale_obs::Bool=true`: if `true`, a grayscale observation is returned; otherwise, an RGB observation is returned.
- `noop_max::Int=30`: max number of no-ops.
- `frame_skip::Int=4`: the frequency at which the agent experiences the game.
@@ -39,6 +22,7 @@ TODO: support seed! in single/multi thread
See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36)
"""
AtariEnv(name; kwargs...) = AtariEnv(; name = name, kwargs...)
function AtariEnv(;
name = "pong",
grayscale_obs = true,
@@ -53,7 +37,7 @@ function AtariEnv(;
)
frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
name in getROMList() ||
throw(ArgumentError("unknown ROM name! run `getROMList()` to see all the game names."))
throw(ArgumentError("unknown ROM name.\n\nRun `ReinforcementLearningEnvironments.list_atari_rom_names()` to see all the game names."))

if isnothing(seed)
seed = (MersenneTwister(), 0)
@@ -148,11 +132,10 @@ function RLBase.reset!(env::AtariEnv)
end


imshowgrey(x::Array{UInt8,2}) = imshowgrey(x[:], size(x))
imshowgrey(x::Array{UInt8,1}, dims) = imshow(reshape(x, dims), colormap = 2)
imshowcolor(x::Array{UInt8,3}) = imshowcolor(x[:], size(x))

function imshowcolor(x::Array{UInt8,1}, dims)
imshowgrey(x::AbstractArray{UInt8,2}) = imshowgrey(reshape(x, :), size(x))
imshowgrey(x::AbstractArray{UInt8,1}, dims) = imshow(reshape(x, dims), colormap = 2)
imshowcolor(x::AbstractArray{UInt8,3}) = imshowcolor(reshape(x, :), size(x))
function imshowcolor(x::AbstractArray{UInt8,1}, dims)
clearws()
setviewport(0, dims[1] / dims[2], 0, 1)
setwindow(0, 1, 0, 1)
@@ -171,8 +154,3 @@ function RLBase.render(env::AtariEnv)
end

list_atari_rom_names() = getROMList()

end

using .AtariWrapper
export AtariEnv
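The `Array -> AbstractArray` commits show up here: `imshowgrey`/`imshowcolor` now accept any array with the right element type and rank, not just dense `Array`s. A tiny standalone illustration of why the looser signature matters (the function names below are made up for the example):

```julia
f_strict(x::Array{UInt8,2})        = size(x)
f_loose(x::AbstractArray{UInt8,2}) = size(x)

img   = rand(UInt8, 4, 4)
patch = @view img[1:2, 1:2]   # a SubArray, not an Array

f_loose(patch)                 # works
# f_strict(patch)              # MethodError: no method matching f_strict(::SubArray{...})
```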
@@ -1,5 +1,4 @@
using Random
using OrdinaryDiffEq
import OrdinaryDiffEq

export AcrobotEnv

@@ -146,8 +145,8 @@ function (env::AcrobotEnv{T})(a) where {T<:Number}
# augmented state for derivative function
s_augmented = [env.state..., torque]

ode = ODEProblem(dsdt, s_augmented, (0.0, env.params.dt), env)
ns = solve(ode, RK4())
ode = OrdinaryDiffEq.ODEProblem(dsdt, s_augmented, (0.0, env.params.dt), env)
ns = OrdinaryDiffEq.solve(ode, OrdinaryDiffEq.RK4())
# only care about final timestep of integration returned by integrator
ns = ns.u[end]
ns = ns[1:4] # omit action
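Switching from `using OrdinaryDiffEq` to `import OrdinaryDiffEq` keeps the solver names out of the wrapper's namespace, which is why the calls above are now fully qualified. A small self-contained example of the same pattern:

```julia
import OrdinaryDiffEq   # nothing is brought into scope unqualified

# Exponential decay u' = -u, integrated over t ∈ (0, 1) with the classic RK4 scheme.
prob = OrdinaryDiffEq.ODEProblem((u, p, t) -> -u, 1.0, (0.0, 1.0))
sol  = OrdinaryDiffEq.solve(prob, OrdinaryDiffEq.RK4())
sol.u[end]   # ≈ exp(-1) ≈ 0.3679
```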
@@ -1,5 +1,3 @@
using Random

export CartPoleEnv

struct CartPoleEnvParams{T}
@@ -128,3 +126,30 @@ function (env::CartPoleEnv)(a)
end

Random.seed!(env::CartPoleEnv, seed) = Random.seed!(env.rng, seed)

function plotendofepisode(x, y, d)
if d
setmarkercolorind(7)
setmarkertype(-1)
setmarkersize(6)
polymarker([x], [y])
end
return nothing
end
function RLBase.render(env::CartPoleEnv)
s, a, d = env.state, env.action, env.done
x, xdot, theta, thetadot = s
l = 2 * env.params.halflength
clearws()
setviewport(0, 1, 0, 1)
xthreshold = env.params.xthreshold
setwindow(-xthreshold, xthreshold, -.1, l + .1)
fillarea([x-.5, x-.5, x+.5, x+.5], [-.05, 0, 0, -.05])
setlinecolorind(4)
setlinewidth(3)
polyline([x, x + l * sin(theta)], [0, l * cos(theta)])
setlinecolorind(2)
drawarrow(x + (a == 1) - .5, -.025, x + 1.4 * (a==1) - .7, -.025)
plotendofepisode(xthreshold - .2, l, d)
updatews()
end
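
A rough sketch of how the new GR-based renderer might be driven; the `CartPoleEnv()` default constructor call and the 1:2 action range are assumptions for illustration, not taken from this diff:

```julia
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = CartPoleEnv()
RLBase.reset!(env)
while !env.done
    env(rand(1:2))       # step with a random discrete action (assumed to be 1 or 2)
    RLBase.render(env)   # draws the cart, pole, and action arrow with GR
    sleep(0.02)          # slow down so the animation is visible
end
```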
@@ -1,5 +1,3 @@
using Random

export MountainCarEnv, ContinuousMountainCarEnv

struct MountainCarEnvParams{T}
@@ -36,12 +34,12 @@ function MountainCarEnvParams(;
)
end

mutable struct MountainCarEnv{A,T,R<:AbstractRNG} <: AbstractEnv
mutable struct MountainCarEnv{A,T,ACT,R<:AbstractRNG} <: AbstractEnv
params::MountainCarEnvParams{T}
action_space::A
observation_space::MultiContinuousSpace{Vector{T}}
state::Vector{T}
action::Union{Int,AbstractFloat}
action::ACT
done::Bool
t::Int
rng::R
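The new `ACT` type parameter (the "parametric action" commit) replaces the abstractly typed `action::Union{Int,AbstractFloat}` field, so the field's type is concrete for each concrete environment instance. A stripped-down illustration with made-up struct names:

```julia
struct LooseEnv
    action::Union{Int,AbstractFloat}   # abstract field type: every read is type-unstable
end

struct TightEnv{ACT}
    action::ACT                        # concrete per instance
end

typeof(TightEnv(1))     # TightEnv{Int64}
typeof(TightEnv(0.5))   # TightEnv{Float64}
```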
@@ -1,5 +1,3 @@
using Random

export PendulumEnv

struct PendulumEnvParams{T}
@@ -12,7 +10,7 @@ struct PendulumEnvParams{T}
max_steps::Int
end

mutable struct PendulumEnv{A,T,R<:AbstractRNG} <: AbstractEnv
mutable struct PendulumEnv{A,T,ACT,R<:AbstractRNG} <: AbstractEnv
params::PendulumEnvParams{T}
action_space::A
observation_space::MultiContinuousSpace{Vector{T}}
@@ -22,7 +20,7 @@ mutable struct PendulumEnv{A,T,R<:AbstractRNG} <: AbstractEnv
rng::R
reward::T
n_actions::Int
action::Union{Int,AbstractFloat}
action::ACT
end

"""
49 changes: 30 additions & 19 deletions src/ReinforcementLearningEnvironments/src/environments/gym.jl
@@ -1,22 +1,15 @@
module GymWrapper

using ReinforcementLearningBase
using PyCall

export GymEnv
using .PyCall

# TODO: support `seed`

struct GymEnv{T,Ta<:AbstractSpace,To<:AbstractSpace} <: AbstractEnv
pyenv::PyObject
observation_space::To
action_space::Ta
state::PyObject
end

function GymEnv(name::String)
if !PyCall.pyexists("gym")
error("Cannot import module 'gym'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n")
end
gym = pyimport("gym")
pyenv = gym.make(name)
pyenv = try gym.make(name)
catch e
error("Gym environment $name not found.\n\nRun `ReinforcementLearningEnvironments.list_gym_env_names()` to find supported environments.\n")
end
obs_space = convert(AbstractSpace, pyenv.observation_space)
act_space = convert(AbstractSpace, pyenv.action_space)
obs_type = if obs_space isa Union{MultiContinuousSpace,MultiDiscreteSpace}
@@ -32,7 +25,7 @@ function GymEnv(name::String)
else
error("don't know how to get the observation type from observation space of $obs_space")
end
env = GymEnv{obs_type,typeof(act_space),typeof(obs_space)}(
env = GymEnv{obs_type,typeof(act_space),typeof(obs_space),typeof(pyenv)}(
pyenv,
obs_space,
act_space,
@@ -115,7 +108,25 @@ function list_gym_env_names(;
[x.id for x in gym.envs.registry.all() if split(x.entry_point, ':')[1] in modules]
end

"""
install_gym(; packages = ["gym", "pybullet"])
"""
function install_gym(; packages = ["gym", "pybullet"])
# Use eventual proxy info
proxy_arg=String[]
if haskey(ENV, "http_proxy")
push!(proxy_arg, "--proxy")
push!(proxy_arg, ENV["http_proxy"])
end
# Import pip
if !PyCall.pyexists("pip")
# If it is not found, install it
println("Pip not found on your system. Downloading it.")
get_pip = joinpath(dirname(@__FILE__), "get-pip.py")
download("https://bootstrap.pypa.io/get-pip.py", get_pip)
run(`$(PyCall.python) $(proxy_arg) $get_pip --user`)
end
println("Installing required python packages using pip")
run(`$(PyCall.python) $(proxy_arg) -m pip install --user --upgrade pip setuptools`)
run(`$(PyCall.python) $(proxy_arg) -m pip install --user $(packages)`)
end

using .GymWrapper
export GymEnv
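Taken together, the gym wrapper now fails with actionable messages and can bootstrap its own Python dependencies. A hedged usage sketch (the environment name is just a common gym id, not something this diff guarantees):

```julia
using PyCall                              # triggers the @require block that defines GymEnv
using ReinforcementLearningEnvironments

# One-time setup if the Python side is missing; installs gym and pybullet via pip.
# ReinforcementLearningEnvironments.install_gym()

env = GymEnv("CartPole-v0")   # an unknown name now errors and points to list_gym_env_names()
```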
30 changes: 1 addition & 29 deletions src/ReinforcementLearningEnvironments/src/environments/mdp.jl
@@ -1,26 +1,11 @@
module POMDPWrapper

export POMDPEnv

using ReinforcementLearningBase
using POMDPs
using Random
using .POMDPs

RLBase.get_action_space(m::Union{<:POMDP,<:MDP}) = convert(AbstractSpace, actions(m))

#####
# POMDPEnv
#####

mutable struct POMDPEnv{M<:POMDP,S,O,I,R,RNG<:AbstractRNG} <: AbstractEnv
model::M
state::S
observation::O
info::I
reward::R
rng::RNG
end

Random.seed!(env::POMDPEnv, seed) = Random.seed!(env.rng, seed)

function POMDPEnv(model::POMDP; seed = nothing)
@@ -77,14 +62,6 @@ RLBase.get_action_space(env::POMDPEnv) = get_action_space(env.model)
# MDPEnv
#####

mutable struct MDPEnv{M<:MDP,S,I,R,RNG<:AbstractRNG} <: AbstractEnv
model::M
state::S
info::I
reward::R
rng::RNG
end

Random.seed!(env::MDPEnv, seed) = seed!(env.rng, seed)

function MDPEnv(model::MDP; seed = nothing)
@@ -132,8 +109,3 @@ end

RLBase.get_observation_space(env::MDPEnv) = get_observation_space(env.model)
RLBase.get_action_space(env::MDPEnv) = get_action_space(env.model)

end

using .POMDPWrapper
export POMDPEnv
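With the wrapper module gone, `POMDPEnv`/`MDPEnv` are defined directly once POMDPs.jl is loaded. A hedged sketch of wrapping a classic model — using `TigerPOMDP` from POMDPModels.jl is an assumption for illustration:

```julia
using POMDPs, POMDPModels                 # POMDPModels ships the classic TigerPOMDP
using ReinforcementLearningEnvironments

env = POMDPEnv(TigerPOMDP(); seed = 123)  # wrap any POMDPs.jl model as an AbstractEnv
```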
@@ -1,5 +1,3 @@
using Random

export PendulumNonInteractiveEnv

struct PendulumNonInteractiveEnvParams{Fl<:AbstractFloat}
@@ -1,26 +1,18 @@
module OpenSpielWrapper

export OpenSpielEnv

using ReinforcementLearningBase
using OpenSpiel
using Random
import .OpenSpiel: load_game, get_type, provides_information_state_tensor,
provides_observation_tensor, dynamics, new_initial_state, chance_mode,
is_chance_node, information_state_tensor, information_state_tensor_size,
num_distinct_actions, num_players, apply_action, current_player, player_reward,
legal_actions, legal_actions_mask, rewards, history, observation_tensor_size,
observation_tensor, chance_outcomes
using StatsBase: sample, weights

abstract type AbstractObservationType end

mutable struct OpenSpielEnv{O,D,S,G,R} <: AbstractEnv
state::S
game::G
rng::R
end

"""
OpenSpielEnv(name; observation_type=nothing, kwargs...)
# Arguments
- `name`::`String`, you can call `resigtered_names()` to see all the supported names. Note that the name can contains parameters, like `"goofspiel(imp_info=True,num_cards=4,points_order=descending)"`. Because the parameters part is parsed by the backend C++ code, the bool variable must be `True` or `False` (instead of `true` or `false`). Another approach is to just specify parameters in `kwargs` in the Julia style.
- `name`::`String`, you can call `ReinforcementLearningEnvironments.OpenSpiel.registered_names()` to see all the supported names. Note that the name can contain parameters, like `"goofspiel(imp_info=True,num_cards=4,points_order=descending)"`. Because the parameters part is parsed by the backend C++ code, the bool variable must be `True` or `False` (instead of `true` or `false`). Another approach is to just specify parameters in `kwargs` in the Julia style.
- `observation_type`::`Union{Symbol,Nothing}`. Supported values are [`:information`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L342-L367), [`:observation`](https://github.com/deepmind/open_spiel/blob/1ad92a54f3b800394b2bc7f178ccdff62d8369e1/open_spiel/spiel.h#L397-L408) or `nothing`. The default value is `nothing`, which means `:information` if the game `provides_information_state_tensor`, and `:observation` otherwise.
"""
function OpenSpielEnv(name; seed = nothing, observation_type = nothing, kwargs...)
@@ -114,11 +106,6 @@ end
(env::OpenSpielEnv)(::Simultaneous, player, action) =
@error "Simultaneous environments can not take in the actions from players seperately"

struct OpenSpielObs{O,D,S}
state::S
player::Int32
end

RLBase.observe(env::OpenSpielEnv{O,D,S}, player) where {O,D,S} =
OpenSpielObs{O,D,S}(env.state, player)

Expand Down Expand Up @@ -161,8 +148,3 @@ RLBase.get_state(obs::OpenSpielObs{:observation}) =

RLBase.get_history(obs::OpenSpielObs) = history(obs.state)
RLBase.get_history(env::OpenSpielEnv) = history(env.state)

end

using .OpenSpielWrapper
export OpenSpielEnv
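Likewise, `OpenSpielEnv` is now created directly after loading OpenSpiel.jl. A hedged usage sketch (the game names below are standard OpenSpiel ids, assumed to be available in your build):

```julia
using OpenSpiel                                  # triggers the @require block for this wrapper
using ReinforcementLearningEnvironments

env  = OpenSpielEnv("tic_tac_toe")               # default observation_type
env2 = OpenSpielEnv("kuhn_poker"; seed = 0)      # keyword seed as accepted by the constructor
```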