/
multi_thread_env.jl
184 lines (159 loc) · 4.78 KB
/
multi_thread_env.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
export MultiThreadEnv
using Base.Threads: @spawn
"""
    MultiThreadEnv(envs::Vector{<:AbstractEnv})

Wrap multiple instances of the same environment type into one environment.
Each environment will run in parallel by leveraging `Threads.@spawn`, so remember to
set the environment variable `JULIA_NUM_THREADS`!

Per-environment quantities are kept in batched buffers whose last (or only) dimension
indexes the wrapped environments: slice `j` belongs to `envs[j]`.
"""
struct MultiThreadEnv{E,S,R,AS,SS,L} <: AbstractEnv
    # The wrapped environments; `env[i]`, `length(env)` and iteration forward here.
    envs::Vector{E}
    # Batched states, refreshed in place by `state(env)`.
    states::S
    # Per-environment rewards, refreshed in place by `reward(env)`.
    rewards::R
    # Per-environment termination flags, refreshed in place by `is_terminated(env)`.
    terminals::BitArray{1}
    # Batched action space returned by `action_space(env)`.
    action_space::AS
    # Batched state space returned by `state_space(env)`.
    state_space::SS
    # For FULL_ACTION_SET environments: a `BitArray` of size
    # `(size(action_space(envs[1]))..., length(envs))`; otherwise `nothing`.
    legal_action_space_mask::L
end
# Markdown summary of the form `MultiThreadEnv(<count> x <name of the first env>)`.
function Base.show(io::IO, ::MIME"text/markdown", env::MultiThreadEnv)
    n = length(env)
    inner = nameof(env[1])
    print(io, "MultiThreadEnv(", n, " x ", inner, ")")
end

"""
    MultiThreadEnv(f, n::Int)

Create `n` environments by calling the zero-argument function `f()` once per
environment, then wrap them in a `MultiThreadEnv`.
"""
function MultiThreadEnv(f, n::Int)
    envs = map(_ -> f(), 1:n)
    return MultiThreadEnv(envs)
end
function MultiThreadEnv(envs::Vector{<:AbstractEnv})
    # Every batched buffer is shaped after `envs[1]`, so at least one env is required.
    isempty(envs) &&
        throw(ArgumentError("MultiThreadEnv requires at least one environment"))
    n = length(envs)

    # States: when the state space is a `Space`, stack each env's state space and
    # state along a new trailing dimension; otherwise keep per-env collections.
    S = state_space(envs[1])
    s = state(envs[1])
    if S isa Space
        S_batch = similar(S, size(S)..., n)
        s_batch = similar(s, size(s)..., n)
        for j in 1:n
            _fill_batch_slice!(S_batch, state_space(envs[j]), j)
            # Indexed by the state's own shape (the buffer was allocated from
            # `size(s)`), not by the state space's shape.
            _fill_batch_slice!(s_batch, state(envs[j]), j)
        end
    else
        S_batch = Space(state_space.(envs))
        s_batch = state.(envs)
    end

    # Action spaces follow the same stacking rule as state spaces.
    A = action_space(envs[1])
    if A isa Space
        A_batch = similar(A, size(A)..., n)
        for j in 1:n
            _fill_batch_slice!(A_batch, action_space(envs[j]), j)
        end
    else
        A_batch = Space(action_space.(envs))
    end

    r_batch = reward.(envs)
    t_batch = is_terminated.(envs)

    # Legal-action masks are only tracked for FULL_ACTION_SET environments; they are
    # stacked into a BitArray of size `(size(A)..., n)`.
    if ActionStyle(envs[1]) === FULL_ACTION_SET
        m_batch = BitArray(undef, size(A)..., n)
        for j in 1:n
            _fill_batch_slice!(m_batch, legal_action_space_mask(envs[j]), j)
        end
    else
        m_batch = nothing
    end

    return MultiThreadEnv(envs, s_batch, r_batch, t_batch, A_batch, S_batch, m_batch)
end

# Copy `item` into slice `j` of the trailing dimension of `batch`, i.e.
# `batch[i, j] = item[i]` for every Cartesian index `i` of `item`. `batch` is expected
# to have been allocated with leading dimensions equal to `size(item)`.
function _fill_batch_slice!(batch, item, j::Integer)
    for i in CartesianIndices(size(item))
        batch[i, j] = item[i]
    end
    return batch
end
# Delegate indexing (`env[i]`), `length(env)` and iteration over `env` to the wrapped
# `env.envs` vector.
MacroTools.@forward MultiThreadEnv.envs Base.getindex, Base.length, Base.iterate
# Step every wrapped environment concurrently, one task per environment.
# With a vector of actions, environment `i` receives `actions[i]`. With a
# higher-dimensional array, it receives the `i`-th slice along the last dimension.
function (env::MultiThreadEnv)(actions)
    lastdim = ndims(actions)
    @sync for i in eachindex(env.envs)
        @spawn begin
            aᵢ = lastdim == 1 ? actions[i] : selectdim(actions, lastdim, i)
            env[i](aᵢ)
        end
    end
end
# Reset the wrapped environments.
# `is_force = true`: reset every environment, sequentially on the calling task.
# Otherwise: reset only the terminated environments, each in its own task.
function RLBase.reset!(env::MultiThreadEnv; is_force = false)
    if is_force
        foreach(reset!, env.envs)
        return nothing
    end
    @sync for i in 1:length(env)
        is_terminated(env[i]) && @spawn reset!(env[i])
    end
    return nothing
end
# Identity-keyed (`IdDict`) map from an environment to a `Symbol`-keyed set of arrays.
# NOTE(review): not referenced anywhere in the visible code — confirm whether other
# files use it before removing.
const MULTI_THREAD_ENV_CACHE = IdDict{AbstractEnv,Dict{Symbol,Array}}()
# Refresh and return the batched state buffer `env.states`: slice `i` along its last
# dimension is overwritten with `state(env[i])`.
function RLBase.state(env::MultiThreadEnv)
    buffer = env.states
    N = ndims(buffer)
    # Query every environment's state concurrently...
    per_env = Vector{Any}(undef, length(env))
    @sync for i in 1:length(env)
        @spawn per_env[i] = state(env[i])
    end
    # ...but write into the shared buffer from this task only. The buffer may be a
    # `BitArray` (it is built with `similar` from an env's state), and neighbouring
    # slices of a `BitArray` can share a storage word, so concurrent writes would race.
    for (i, sᵢ) in enumerate(per_env)
        selectdim(buffer, N, i) .= sᵢ
    end
    return buffer
end
# Refresh `env.rewards` in place with each wrapped environment's reward and return it.
function RLBase.reward(env::MultiThreadEnv)
    latest = map(reward, env.envs)
    env.rewards .= latest
    return env.rewards
end
# Refresh `env.terminals` in place with each wrapped environment's termination flag
# and return it.
function RLBase.is_terminated(env::MultiThreadEnv)
    flags = map(is_terminated, env.envs)
    env.terminals .= flags
    return env.terminals
end
# Refresh and return the batched mask `env.legal_action_space_mask`: slice `i` along
# its last dimension is overwritten with `legal_action_space_mask(env[i])`.
# Only available when the wrapped environments use FULL_ACTION_SET (otherwise the
# stored mask is `nothing`).
function RLBase.legal_action_space_mask(env::MultiThreadEnv)
    mask = env.legal_action_space_mask
    mask === nothing && throw(
        ArgumentError(
            "legal_action_space_mask is only tracked for FULL_ACTION_SET environments",
        ),
    )
    # Slice along the mask's own trailing (environment) dimension. The mask's rank is
    # `ndims(action_space) + 1`, which in general differs from the rank of `env.states`.
    N = ndims(mask)
    # Query every environment's mask concurrently...
    per_env = Vector{Any}(undef, length(env))
    @sync for i in 1:length(env)
        @spawn per_env[i] = legal_action_space_mask(env[i])
    end
    # ...but write into the shared `BitArray` from this task only: neighbouring slices
    # can share a storage word, so concurrent writes would race.
    for (i, mᵢ) in enumerate(per_env)
        selectdim(mask, N, i) .= mᵢ
    end
    return mask
end
# Batched action space assembled at construction time.
function RLBase.action_space(env::MultiThreadEnv)
    return env.action_space
end

# Batched state space assembled at construction time.
function RLBase.state_space(env::MultiThreadEnv)
    return env.state_space
end

# Per-environment legal action spaces, recomputed on every call and wrapped in a `Space`.
function RLBase.legal_action_space(env::MultiThreadEnv)
    return Space(map(legal_action_space, env.envs))
end
# RLBase.current_player(env::MultiThreadEnv) = current_player.(env.envs)
# Forward every `*Style` trait of the environment API to the first wrapped environment.
for trait in RLBase.ENV_API
    endswith(String(trait), "Style") || continue
    @eval RLBase.$trait(x::MultiThreadEnv) = $trait(x[1])
end
#####
# Patches
#####
# Unwrap an `EnrichedAction` and step the environments with its inner `action` field.
function (env::MultiThreadEnv)(action::EnrichedAction)
    inner = action.action
    return env(inner)
end
# Let the explorer pick one index per environment from the learner's output, then map
# index `idx` of environment `k` to the concrete action `A[k][idx]`.
function (π::QBasedPolicy)(env::MultiThreadEnv, ::MinimalActionSet, A)
    estimates = π.learner(env)
    chosen = π.explorer(estimates)
    return [A[k][idx] for (k, idx) in enumerate(chosen)]
end
# Same as the minimal-action-set variant, but the explorer is restricted by the batched
# legal-action mask before indices are mapped to concrete actions `A[k][idx]`.
function (π::QBasedPolicy)(env::MultiThreadEnv, ::FullActionSet, A)
    estimates = π.learner(env)
    mask = legal_action_space_mask(env)
    chosen = π.explorer(estimates, mask)
    return [A[k][idx] for (k, idx) in enumerate(chosen)]
end
# When every environment's action space is a `Base.OneTo` range, an index equals the
# action it denotes, so the explorer's output is returned as-is.
function (π::QBasedPolicy)(
    env::MultiThreadEnv,
    ::MinimalActionSet,
    ::Space{<:Vector{<:Base.OneTo{<:Integer}}},
)
    estimates = π.learner(env)
    return π.explorer(estimates)
end
# `Base.OneTo` action spaces with a full action set: indices are the actions, and the
# explorer is restricted by the batched legal-action mask.
function (π::QBasedPolicy)(
    env::MultiThreadEnv,
    ::FullActionSet,
    ::Space{<:Vector{<:Base.OneTo{<:Integer}}},
)
    estimates = π.learner(env)
    mask = legal_action_space_mask(env)
    return π.explorer(estimates, mask)
end