# solver.jl
@with_kw mutable struct DeepQLearningSolver{E<:ExplorationPolicy} <: Solver
qnetwork::Any = nothing # intended to be a Flux model
learning_rate::Float32 = 1f-4
max_steps::Int64 = 1000
batch_size::Int64 = 32
train_freq::Int64 = 4
eval_freq::Int64 = 500
target_update_freq::Int64 = 500
num_ep_eval::Int64 = 100
double_q::Bool = true
dueling::Bool = true
recurrence::Bool = false
evaluation_policy::Any = basic_evaluation
exploration_policy::E
trace_length::Int64 = 40
prioritized_replay::Bool = true
prioritized_replay_alpha::Float32 = 0.6f0
prioritized_replay_epsilon::Float32 = 1f-6
prioritized_replay_beta::Float32 = 0.4f0
buffer_size::Int64 = 1000
max_episode_length::Int64 = 100
train_start::Int64 = 200
rng::AbstractRNG = MersenneTwister(0)
logdir::Union{Nothing, String} = "log/"
save_freq::Int64 = 3000
log_freq::Int64 = 100
verbose::Bool = true
end
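# Minimal usage sketch (illustrative only): `SimpleGridWorld`, `EpsGreedyPolicy`, and
# `LinearDecaySchedule` are assumed to come from POMDPModels and POMDPTools; any
# MDP/POMDP and any Flux model of matching input/output size can be substituted.
#
#     using DeepQLearning, POMDPs, Flux, POMDPModels, POMDPTools
#     mdp = SimpleGridWorld()                      # state converts to a 2-element vector
#     model = Chain(Dense(2, 32, relu), Dense(32, length(actions(mdp))))
#     exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=5000))
#     solver = DeepQLearningSolver(qnetwork=model, exploration_policy=exploration,
#                                  max_steps=10_000, learning_rate=1f-3)
#     policy = solve(solver, mdp)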
function POMDPs.solve(solver::DeepQLearningSolver, problem::MDP)
env = MDPCommonRLEnv{AbstractArray{Float32}}(problem) # ignores solver.rng because CommonRLEnv doesn't have rng support yet
return solve(solver, env)
end
function POMDPs.solve(solver::DeepQLearningSolver, problem::POMDP)
env = POMDPCommonRLEnv{AbstractArray{Float32}}(problem) # ignores solver.rng because CommonRLEnv doesn't have rng support yet
return solve(solver, env)
end
function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnv)
action_map = collect(actions(env))
action_indices = Dict(a=>i for (i, a) in enumerate(action_map))
# check recurrence
if isrecurrent(solver.qnetwork) && !solver.recurrence
error("DeepQLearningError: you passed in a recurrent model but the solver's `recurrence` flag is set to false")
end
replay = initialize_replay_buffer(solver, env, action_indices)
if solver.dueling
active_q = create_dueling_network(solver.qnetwork)
else
active_q = solver.qnetwork
end
policy = NNPolicy(env, active_q, action_map, length(obs_dimensions(env)))
return dqn_train!(solver, env, policy, replay)
end
function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::AbstractNNPolicy, replay)
if solver.logdir !== nothing
logger = TBLogger(solver.logdir)
solver.logdir = logger.logdir
end
active_q = getnetwork(policy) # shallow copy
target_q = deepcopy(active_q)
optimizer = Adam(solver.learning_rate)
# start training
resetstate!(policy)
reset!(env)
obs = observe(env)
done = false
step = 0
rtot = 0
episode_rewards = Float64[0.0]
episode_steps = Int64[]
saved_mean_reward = -Inf
scores_eval = -Inf
model_saved = false
eval_next = false
save_next = false
action_indices = Dict(a=>i for (i, a) in enumerate(actionmap(policy)))
for t=1:solver.max_steps
act = action(solver.exploration_policy, policy, t, obs)
ai = action_indices[act]
rew = act!(env, act)
op = observe(env)
done = terminated(env)
exp = DQExperience(obs, ai, Float32(rew), op, done)
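# store the transition; with prioritized replay, |r| is used as the transition's initial priority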
if solver.recurrence
add_exp!(replay, exp)
elseif solver.prioritized_replay
add_exp!(replay, exp, abs(exp.r))
else
add_exp!(replay, exp, 0f0)
end
obs = op
step += 1
episode_rewards[end] += rew
if done || step >= solver.max_episode_length
if eval_next # wait for episode to end before evaluating
scores_eval, steps_eval, info_eval = evaluation(solver.evaluation_policy,
policy, env,
solver.num_ep_eval,
solver.max_episode_length,
solver.verbose)
eval_next = false
# only save after evaluation
if save_next
model_saved, saved_mean_reward = save_model(solver, active_q, scores_eval, saved_mean_reward, model_saved)
save_next = false
end
if solver.logdir !== nothing
log_value(logger, "eval_reward", scores_eval, step = t)
log_value(logger, "eval_steps", steps_eval, step = t)
for (k, v) in info_eval
log_value(logger, k, v, step = t)
end
end
end
reset!(env)
obs = observe(env)
resetstate!(policy)
push!(episode_steps, step)
push!(episode_rewards, 0.0)
done = false
step = 0
rtot = 0
end
num_episodes = length(episode_rewards)
avg100_reward = mean(episode_rewards[max(1, length(episode_rewards)-99):end]) # average over the last 100 episodes
avg100_steps = mean(episode_steps[max(1, length(episode_steps)-99):end])
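# gradient step every `train_freq` environment steps; the online network's recurrent
# hidden state is saved and restored around it so acting is not disrupted by training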
if t%solver.train_freq == 0
hs = hiddenstates(active_q)
loss_val, grad_val = batch_train!(solver, env, policy, optimizer, target_q, replay)
sethiddenstates!(active_q, hs)
end
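# hard update: periodically copy the online network's weights into the target network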
if t%solver.target_update_freq == 0
weights = Flux.params(active_q)
Flux.loadparams!(target_q, weights)
end
if t % solver.eval_freq == 0
eval_next = true
end
if t % solver.save_freq == 0
save_next = true
end
if t % solver.log_freq == 0 && solver.logdir !== nothing
nt = loginfo(solver.exploration_policy, t)
for (k, v) in pairs(nt)
log_value(logger, String(k), v, step=t)
end
if solver.verbose
@printf("%5d / %5d eps %0.3f | avgR %1.3f | Loss %2.3e | Grad %2.3e | EvalR %1.3f \n",
t, solver.max_steps, nt[1], avg100_reward, loss_val, grad_val, scores_eval)
end
log_value(logger, "avg_reward", avg100_reward, step = t)
log_value(logger, "loss", loss_val, step = t)
log_value(logger, "grad_val", grad_val, step = t)
end
end # end training
if model_saved
if solver.verbose
@printf("Restore model with eval reward %1.3f \n", saved_mean_reward)
end
saved_model = BSON.load(joinpath(solver.logdir, "qnetwork.bson"))[:qnetwork]
Flux.loadparams!(getnetwork(policy), saved_model)
end
return policy
end
function initialize_replay_buffer(solver::DeepQLearningSolver, env::AbstractEnv, action_indices)
# init and populate replay buffer
if solver.recurrence
replay = EpisodeReplayBuffer(env, solver.buffer_size, solver.batch_size, solver.trace_length)
else
replay = PrioritizedReplayBuffer(env, solver.buffer_size, solver.batch_size)
end
populate_replay_buffer!(replay, env, action_indices, max_pop=solver.train_start)
return replay #XXX type unstable
end
function batch_train!(solver::DeepQLearningSolver,
env::AbstractEnv,
policy::AbstractNNPolicy,
optimizer,
target_q,
replay::PrioritizedReplayBuffer;
discount=default_discount(env)
)
s_batch, a_batch, r_batch, sp_batch, done_batch, indices, importance_weights = sample(replay)
active_q = getnetwork(policy)
p = Flux.params(active_q)
loss_val = nothing
td_vals = nothing
γ = convert(Float32, discount)
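# Double DQN: select the greedy action with the online network, evaluate it with the target network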
if solver.double_q
qp_values = active_q(sp_batch)
target_q_values = target_q(sp_batch)
best_a = [CartesianIndex(argmax(qp_values[:, i]), i) for i=1:solver.batch_size]
q_sp_max = target_q_values[best_a]
else
q_sp_max = dropdims(maximum(target_q(sp_batch), dims=1), dims=1)
end
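# TD target: r + γ * max_a' Q_target(s', a'), masked to zero for terminal transitions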
q_targets = r_batch .+ (1f0 .- done_batch) .* γ .* q_sp_max
gs = Flux.gradient(p) do
q_values = active_q(s_batch)
q_sa = q_values[a_batch]
td_vals = q_sa .- q_targets
loss_val = sum(huber_loss, importance_weights.*td_vals)
loss_val /= solver.batch_size
end
grad_norm = globalnorm(p, gs)
Flux.Optimise.update!(optimizer, p, gs)
if solver.prioritized_replay
update_priorities!(replay, indices, td_vals)
end
return loss_val, grad_norm
end
# recurrent (RNN) variant: trains on fixed-length traces sampled from an episode replay buffer
function batch_train!(solver::DeepQLearningSolver,
env::AbstractEnv,
policy::AbstractNNPolicy,
optimizer,
target_q,
replay::EpisodeReplayBuffer;
discount=default_discount(env)
)
active_q = getnetwork(policy)
s_batch, a_batch, r_batch, sp_batch, done_batch, trace_mask_batch = DeepQLearning.sample(replay)
Flux.reset!(active_q)
Flux.reset!(target_q)
p = Flux.params(active_q)
loss_val = nothing
td_vals = nothing
γ = convert(Float32, discount)
q_targets = [zeros(Float32, solver.batch_size) for i=1:solver.trace_length]
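# compute TD targets for every step of the trace; they are held fixed during the gradient pass below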
for i=1:solver.trace_length
if solver.double_q
qp_values = active_q(sp_batch[i])
best_a = [CartesianIndex(argmax(qp_values[:, j]), j) for j=1:solver.batch_size] # j indexes the batch, i the trace step
target_q_values = target_q(sp_batch[i])
q_sp_max = target_q_values[best_a]
else
q_sp_max = dropdims(maximum(target_q(sp_batch[i]), dims=1), dims=1)
end
q_targets[i] .= r_batch[i] .+ (1f0 .- done_batch[i]) .* γ .* q_sp_max
end
Flux.reset!(active_q)
gs = Flux.gradient(p) do
loss_val = 0f0
for i=1:solver.trace_length
q_values = active_q(s_batch[i])
q_sa = q_values[a_batch[i]]
td_vals = q_sa .- q_targets[i]
loss_val += sum(huber_loss, trace_mask_batch[i].*td_vals)/solver.batch_size
end
loss_val /= solver.trace_length
end
grad_norm = globalnorm(p, gs)
Flux.Optimise.update!(optimizer, p, gs)
return loss_val, grad_norm
end
function save_model(solver::DeepQLearningSolver, active_q, scores_eval::Float64, saved_mean_reward::Float64, model_saved::Bool)
if scores_eval >= saved_mean_reward
bson(joinpath(solver.logdir, "qnetwork.bson"), qnetwork=[w for w in Flux.params(active_q)])
if solver.verbose
@printf("Saving new model with eval reward %1.3f \n", scores_eval)
end
model_saved = true
saved_mean_reward = scores_eval
end
return model_saved, saved_mean_reward
end
function restore_best_model(solver::DeepQLearningSolver, problem::MDP)
env = convert(AbstractEnv, problem) # ignores solver.rng because CommonRLEnv doesn't have rng support yet
restore_best_model(solver, env)
end
function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnv)
if solver.dueling
active_q = create_dueling_network(solver.qnetwork)
else
active_q = solver.qnetwork
end
policy = NNPolicy(env, active_q, collect(actions(env)), length(obs_dimensions(env)))
weights = BSON.load(joinpath(solver.logdir, "qnetwork.bson"))[:qnetwork]
Flux.loadparams!(getnetwork(policy), weights)
Flux.testmode!(getnetwork(policy))
return policy
end
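# Usage sketch (assumes training previously saved `qnetwork.bson` to `solver.logdir`,
# and that `solver` is configured the same way as during training):
#
#     policy = restore_best_model(solver, mdp)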
POMDPLinter.@POMDP_require POMDPs.solve(solver::DeepQLearningSolver, mdp::Union{MDP, POMDP}) begin
P = typeof(mdp)
S = POMDPs.statetype(P)
A = POMDPs.actiontype(P)
@req POMDPs.discount(::P)
@req POMDPs.actions(::P)
as = POMDPs.actions(mdp)
@req length(::typeof(as))
if isa(mdp, POMDP)
O = obstype(mdp)
@req POMDPs.convert_o(::Type{AbstractArray}, ::O, ::P)
else
@req POMDPs.convert_s(::Type{AbstractArray}, ::S, ::P)
end
@req POMDPs.reward(::P,::S,::A,::S)
end