/
gym.jl
169 lines (155 loc) · 5.33 KB
/
gym.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
using .PyCall
"""
    GymEnv(name::String; seed::Union{Int,Nothing} = nothing)

Create a `GymEnv` wrapping the Python Gym environment `gym.make(name)`.

If the Python module `d4rl` is importable it is imported first, presumably so that the
environments it registers become available to `gym.make` — TODO confirm.
When `seed` is given, `pyenv.seed(seed)` is called before the first reset.

The observation and action spaces are converted once with `space_transform`. The Julia
type used to read observations (`PyArray`, `Float64`, `Int`, `PyVector` or `PyDict`) is
chosen from the converted observation space. The environment is reset before it is
returned, so its state is initialised.

Throws an error if the Python module `gym` cannot be imported, if `gym.make(name)` fails,
or if no observation type is known for the observation space.
"""
function GymEnv(name::String; seed::Union{Int,Nothing} = nothing)
    if !PyCall.pyexists("gym")
        error(
            "Cannot import module 'gym'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n",
        )
    end
    gym = pyimport_conda("gym", "gym")
    PyCall.pyexists("d4rl") && pyimport("d4rl")

    pyenv = try
        gym.make(name)
    catch e
        # `gym.make` fails for unknown ids, but also for registered environments that
        # cannot be built (e.g. a missing optional dependency). Surface the Python error
        # instead of discarding it, so the real cause is visible.
        error(
            "Gym environment $name not found or could not be created.\n\nRun `ReinforcementLearningEnvironments.list_gym_env_names()` to find supported environments.\n\nUnderlying error: $(sprint(showerror, e))\n",
        )
    end
    seed === nothing || pyenv.seed(seed)

    obs_space = space_transform(pyenv.observation_space)
    act_space = space_transform(pyenv.action_space)
    # Julia type that `state(env)` converts raw Python observations into.
    obs_type = if obs_space isa Space{<:Union{Array{<:Interval},Array{<:ZeroTo}}}
        PyArray
    elseif obs_space isa Interval
        Float64
    elseif obs_space isa ZeroTo
        Int
    elseif obs_space isa Space{<:Tuple}
        PyVector
    elseif obs_space isa Space{<:Dict}
        PyDict
    else
        error("don't know how to get the observation type from observation space of $obs_space")
    end

    # `PyNULL()` is a placeholder that `reset!` / `step` overwrite in place via `pycall!`.
    env = GymEnv{obs_type,typeof(act_space),typeof(obs_space),typeof(pyenv)}(
        pyenv,
        obs_space,
        act_space,
        PyNULL(),
    )
    reset!(env) # reset immediately to init env.state
    return env
end
# Return the `__name__` of the wrapped Python environment object's class.
function Base.nameof(env::GymEnv)
    pyclass = env.pyenv.__class__
    return pyclass.__name__
end
# A real copy is not possible because the wrapped Python env exposes no clone method.
# Emits a warning and returns the SAME `env` object (aliased, not an independent copy).
Base.copy(env::GymEnv) = (@warn("clone method is not exposed in GymEnv"); env)
# Advance the wrapped Gym environment by one step with `action`.
# The Python result of `pyenv.step(action)` is written in place into `env.state`
# (via `pycall!`), which `reward`, `is_terminated` and `state` then read.
# Returns `nothing`.
function (env::GymEnv)(action)
    # `space_transform` wraps a Gym `Tuple` space as `Space{<:Tuple}`, never as a bare
    # `Tuple`. Accept both, so composite actions reach Python as a tuple rather than an
    # array/list.
    if env.action_space isa Union{Tuple,Space{<:Tuple}}
        action = Tuple(action)
    end
    pycall!(env.state, env.pyenv.step, PyObject, action)
    return nothing
end
# Reset the wrapped Python environment. The return value of `pyenv.reset()` is stored
# in place into `env.state` (no new PyObject is allocated). Returns `nothing`.
function RLBase.reset!(env::GymEnv)
    reset_fn = env.pyenv.reset
    pycall!(env.state, reset_fn, PyObject)
    return nothing
end
# Both spaces are converted once in the constructor and returned here as stored.
function RLBase.action_space(env::GymEnv)
    return env.action_space
end
function RLBase.state_space(env::GymEnv)
    return env.observation_space
end
# Reward of the most recent step, or 0.0 if `env.state` does not hold a step result.
function RLBase.reward(env::GymEnv{T}) where {T}
    result = env.state
    # A 4-element Python tuple is treated as `(obs, reward, done, info)`.
    if pyisinstance(result, PyCall.@pyglobalobj :PyTuple_Type) && length(result) == 4
        _, r, _, _ = convert(Tuple{T,Float64,Bool,PyDict}, result)
        return r
    end
    return 0.0
end
# `done` flag of the most recent step. If `env.state` is not a 4-element Python tuple
# `(obs, reward, done, info)`, the environment is reported as not terminated.
function RLBase.is_terminated(env::GymEnv{T}) where {T}
    stepped = pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) &&
              length(env.state) == 4
    stepped || return false
    _, _, done, _ = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
    return done
end
# Current observation, converted to the observation type `T` chosen in the constructor.
# From a 4-element step tuple `(obs, reward, done, info)` only `obs` is returned;
# any other stored value is converted to `T` as a whole.
function RLBase.state(env::GymEnv{T}) where {T}
    raw = env.state
    if !(pyisinstance(raw, PyCall.@pyglobalobj :PyTuple_Type) && length(raw) == 4)
        return convert(T, raw)
    end
    obs, _, _, _ = convert(Tuple{T,Float64,Bool,PyDict}, raw)
    return obs
end
# Forward seeding to the Python env's own `seed` method and return whatever it returns.
function Random.seed!(env::GymEnv, s)
    return env.pyenv.seed(s)
end
# Base.display(env::GymEnv) = env.pyenv.render()
###
### utils
###
"""
    space_transform(s::PyObject)

Convert the Gym space object `s` to a Julia space, based on its Python class name:

- `Box` → `Space` of element-wise `ClosedInterval`s built from `s.low` / `s.high`
- `Discrete` → `ZeroTo(s.n - 1)`
- `MultiBinary` → `Space` of `ZeroTo(Int8(1))` entries, shaped like `ones(Int8, s.n)`
- `MultiDiscrete` → `Space` of `ZeroTo(k - 1)` for each `k` in `s.nvec`
- `Tuple` / `Dict` → `Space` of the recursively converted sub-spaces

Throws an error for any other space class.
"""
function space_transform(s::PyObject)
    spacetype = s.__class__.__name__
    if spacetype == "Box"
        return Space(ClosedInterval.(s.low, s.high))
    elseif spacetype == "Discrete"
        # For GymEnv("CliffWalking-v0") `s.n` comes back as an unconverted numpy.int64
        # PyObject, so it is coerced with Python's `int` first.
        return ZeroTo(py"int($s.n)" - 1)
    elseif spacetype == "MultiBinary"
        return Space(ZeroTo.(ones(Int8, s.n)))
    elseif spacetype == "MultiDiscrete"
        nvec = s.nvec  # fetch/convert the Python array once instead of twice
        return Space(ZeroTo.(nvec .- one(eltype(nvec))))
    elseif spacetype == "Tuple"
        return Space(Tuple(space_transform(x) for x in s.spaces))
    elseif spacetype == "Dict"
        # Build the Dict straight from the generator rather than splatting it into varargs.
        return Space(Dict(k => space_transform(v) for (k, v) in s.spaces))
    else
        error("Don't know how to convert Gym Space of class [$(spacetype)]")
    end
end
"""
    list_gym_env_names(; modules = [...])

Return the ids (as a `Vector{String}`) of all registered Gym environments whose
entry-point module (the part of the `"module:Class"` entry point before `:`) is one of
`modules`. Registry entries whose entry point is not a string are skipped.

If `d4rl` is importable it is imported first, presumably so that its environments
appear in the registry — TODO confirm.
"""
function list_gym_env_names(;
    modules = [
        "gym.envs.algorithmic",
        "gym.envs.box2d",
        "gym.envs.classic_control",
        "gym.envs.mujoco",
        "gym.envs.mujoco.ant_v3",
        "gym.envs.mujoco.half_cheetah_v3",
        "gym.envs.mujoco.hopper_v3",
        "gym.envs.mujoco.humanoid_v3",
        "gym.envs.mujoco.swimmer_v3",
        "gym.envs.mujoco.walker2d_v3",
        "gym.envs.robotics",
        "gym.envs.toy_text",
        "gym.envs.unittest",
        "d4rl.pointmaze",
        "d4rl.hand_manipulation_suite",
        "d4rl.gym_mujoco.gym_envs",
        "d4rl.locomotion.ant",
        "d4rl.gym_bullet.gym_envs",
        "d4rl.pointmaze_bullet.bullet_maze", # yet to include flow and carla
    ],
)
    PyCall.pyexists("d4rl") && pyimport("d4rl")
    gym = pyimport("gym")
    names = String[]
    for spec in gym.envs.registry.all()
        entry_point = spec.entry_point
        # Gym also accepts a callable (or nothing) as entry point. Only "module:Class"
        # strings carry a module name to filter on, so skip the rest instead of raising
        # a MethodError from `split`.
        entry_point isa AbstractString || continue
        first(split(entry_point, ':')) in modules && push!(names, spec.id)
    end
    return names
end
"""
    install_gym(; packages = ["gym", "pybullet"])

Install the Python `packages` with pip into the user site of the Python interpreter
used by PyCall (`PyCall.python`). pip, setuptools and the packages are upgraded or
installed with `--user`. If `pip` is not importable, it is first bootstrapped by
downloading and running `get-pip.py`.

If the environment variable `http_proxy` is set, its value is forwarded to pip via
`--proxy`.

Returns the process of the final `pip install` call. Throws if any command exits with
a non-zero status.
"""
function install_gym(; packages = ["gym", "pybullet"])
    pyexe = PyCall.python
    # `--proxy` is a pip option. It must come after the pip subcommand / get-pip script;
    # placed before `-m` or the script, Python itself rejects it as an unknown option.
    proxy_arg = haskey(ENV, "http_proxy") ? ["--proxy", ENV["http_proxy"]] : String[]

    if !PyCall.pyexists("pip")
        @info "Pip not found on your system. Downloading it."
        # Use a temporary directory: the package source tree may not be writable.
        mktempdir() do dir
            get_pip = joinpath(dir, "get-pip.py")
            download("https://bootstrap.pypa.io/get-pip.py", get_pip)
            run(`$pyexe $get_pip --user $proxy_arg`)
        end
    end

    @info "Installing required python packages using pip"
    run(`$pyexe -m pip install $proxy_arg --user --upgrade pip setuptools`)
    return run(`$pyexe -m pip install $proxy_arg --user $packages`)
end