From 8376cd3349a3039dea39a5fb59c1591bd9ea9acf Mon Sep 17 00:00:00 2001
From: MaximeBouton
Date: Thu, 10 Jan 2019 17:15:29 -0800
Subject: [PATCH 1/4] progress

---
 src/DeepQLearning.jl |  3 ++-
 src/solver.jl        | 27 ++++++++++++++++++---------
 test/profile.jl      | 19 ++++++++++++++++++-
 3 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/src/DeepQLearning.jl b/src/DeepQLearning.jl
index f590584..92ac638 100755
--- a/src/DeepQLearning.jl
+++ b/src/DeepQLearning.jl
@@ -12,7 +12,8 @@ using POMDPPolicies
 using RLInterface
 using LinearAlgebra

-export DeepQLearningSolver,
+export AbstractDQNSolver,
+    DeepQLearningSolver,
     AbstractNNPolicy,
     NNPolicy,
     DQExperience,
diff --git a/src/solver.jl b/src/solver.jl
index 5e45ce5..85d30eb 100755
--- a/src/solver.jl
+++ b/src/solver.jl
@@ -51,7 +51,12 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
         active_q = solver.qnetwork
     end
     policy = NNPolicy(env.problem, active_q, ordered_actions(env.problem), length(obs_dimensions(env)))
-    target_q = deepcopy(solver.qnetwork)
+    return dqn_train!(solver, env, policy, replay)
+end
+
+function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnvironment, policy::AbstractNNPolicy, replay)
+    active_q = solver.qnetwork # shallow copy
+    target_q = deepcopy(active_q)
     optimizer = ADAM(Flux.params(active_q), solver.learning_rate)
     # start training
     reset!(policy)
@@ -87,7 +92,7 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
         avg100_steps = mean(episode_steps[max(1, length(episode_steps)-101):end])

         if t%solver.train_freq == 0
             hs = hiddenstates(active_q)
-            loss_val, td_errors, grad_val = batch_train!(solver, env, optimizer, active_q, target_q, replay)
+            loss_val, td_errors, grad_val = batch_train!(solver, env, policy, optimizer, target_q, replay)
             sethiddenstates!(active_q, hs)
         end
@@ -165,14 +170,17 @@ end

 function batch_train!(solver::DeepQLearningSolver,
                       env::AbstractEnvironment,
-                      optimizer,
-                      active_q,
+                      policy::NNPolicy,
+                      optimizer,
                       target_q,
                       s_batch, a_batch, r_batch, sp_batch, done_batch, importance_weights)
+    active_q = policy.qnetwork
     loss_tracked, td_tracked = q_learning_loss(solver, env, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, importance_weights)
     loss_val = loss_tracked.data
     td_vals = Flux.data.(td_tracked)
     Flux.back!(loss_tracked)
+    @show active_q
+    @show params(active_q)
     grad_norm = globalnorm(params(active_q))
     optimizer()
     return loss_val, td_vals, grad_norm
@@ -196,22 +204,22 @@ end

 function batch_train!(solver::DeepQLearningSolver,
                       env::AbstractEnvironment,
+                      policy::AbstractNNPolicy,
                       optimizer,
-                      active_q,
                       target_q,
                       replay::ReplayBuffer)
     s_batch, a_batch, r_batch, sp_batch, done_batch = sample(replay)
-    return batch_train!(solver, env, optimizer, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, ones(solver.batch_size))
+    return batch_train!(solver, env, policy, optimizer, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, ones(solver.batch_size))
 end

 function batch_train!(solver::DeepQLearningSolver,
                       env::AbstractEnvironment,
+                      policy::AbstractNNPolicy,
                       optimizer,
-                      active_q,
                       target_q,
                       replay::PrioritizedReplayBuffer)
     s_batch, a_batch, r_batch, sp_batch, done_batch, indices, weights = sample(replay)
-    loss_val, td_vals, grad_norm = batch_train!(solver, env, optimizer, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, weights)
+    loss_val, td_vals, grad_norm = batch_train!(solver, env, policy, optimizer, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, weights)
     update_priorities!(replay, indices, td_vals)
     return loss_val, td_vals, grad_norm
 end
@@ -219,10 +227,11 @@ end
 # for RNNs
 function batch_train!(solver::DeepQLearningSolver,
                       env::AbstractEnvironment,
+                      policy::NNPolicy,
                       optimizer,
-                      active_q,
                       target_q,
                       replay::EpisodeReplayBuffer)
+    active_q = policy.qnetwork
     s_batch, a_batch, r_batch, sp_batch, done_batch, trace_mask_batch = DeepQLearning.sample(replay)
     Flux.reset!(active_q)
     Flux.reset!(target_q)
diff --git a/test/profile.jl b/test/profile.jl
index 83c859f..149ab27 100644
--- a/test/profile.jl
+++ b/test/profile.jl
@@ -8,7 +8,24 @@ using Flux
 using Profile
 using BenchmarkTools

+
+using Revise
+using Random
+using POMDPs
+using DeepQLearning
+using Flux
 rng = MersenneTwister(1)
+include("test/test_env.jl")
+mdp = TestMDP((5,5), 4, 6)
+model = Chain(x->flattenbatch(x), Dense(100, 8, tanh), Dense(8, n_actions(mdp)))
+solver = DeepQLearningSolver(qnetwork = model, max_steps=10000, learning_rate=0.005,
+                             eval_freq=2000,num_ep_eval=100,
+                             log_freq = 500,
+                             double_q = false, dueling=true, prioritized_replay=false,
+                             rng=rng)
+
+policy = solve(solver, mdp)
+


 mdp = SimpleGridWorld()
@@ -16,7 +33,7 @@ model = Chain(Dense(2, 32, relu), LSTM(32,32), Dense(32, 32, relu), Dense(32, n_
 solver = DeepQLearningSolver(qnetwork = model, prioritized_replay=false, max_steps=1000, learning_rate=0.001,log_freq=500,
                              recurrence=true,trace_length=10, double_q=false, dueling=false, rng=rng, verbose=false)

-@btime policy = solve(solver, mdp)
+policy = solve(solver, mdp)

 @profile 1+1


From 87d462fb506119eb63603825a1da408325cdffc3 Mon Sep 17 00:00:00 2001
From: MaximeBouton
Date: Fri, 11 Jan 2019 16:00:11 -0800
Subject: [PATCH 2/4] fix #8, made the implementation more modular

---
 src/DeepQLearning.jl              |  1 +
 src/policy.jl                     | 18 ++++++++++++++
 src/solver.jl                     | 45 +++++++++++++++++++------------
 test/{profile.jl => prototype.jl} | 13 +++++++++
 test/runtests.jl                  |  1 +
 5 files changed, 61 insertions(+), 17 deletions(-)
 rename test/{profile.jl => prototype.jl} (85%)

diff --git a/src/DeepQLearning.jl b/src/DeepQLearning.jl
index 92ac638..74e9cf4 100755
--- a/src/DeepQLearning.jl
+++ b/src/DeepQLearning.jl
@@ -16,6 +16,7 @@ export AbstractDQNSolver,
     DeepQLearningSolver,
     AbstractNNPolicy,
     NNPolicy,
+    qnetwork,
     DQExperience,
     restore_best_model,
     ReplayBuffer,
diff --git a/src/policy.jl b/src/policy.jl
index dfc613d..cf1d9ed 100755
--- a/src/policy.jl
+++ b/src/policy.jl
@@ -1,5 +1,19 @@
 abstract type AbstractNNPolicy <: Policy end

+## NN Policy interface
+
+"""
+    getnetwork(policy)
+    return the value network of the policy
+"""
+function getnetwork end
+
+"""
+    reset!(policy)
+reset the hidden states of a policy
+"""
+function reset! end
+
 struct NNPolicy{P <: Union{MDP, POMDP}, Q, A} <: AbstractNNPolicy
     problem::P
     qnetwork::Q
@@ -7,6 +21,10 @@ struct NNPolicy{P <: Union{MDP, POMDP}, Q, A} <: AbstractNNPolicy
     n_input_dims::Int64
 end

+function getnetwork(policy::NNPolicy)
+    return policy.qnetwork
+end
+
 function reset!(policy::NNPolicy)
     Flux.reset!(policy.qnetwork)
 end
diff --git a/src/solver.jl b/src/solver.jl
index af8e160..22897cc 100755
--- a/src/solver.jl
+++ b/src/solver.jl
@@ -55,7 +55,7 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
 end

 function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnvironment, policy::AbstractNNPolicy, replay)
-    active_q = solver.qnetwork # shallow copy
+    active_q = getnetwork(policy) # shallow copy
     target_q = deepcopy(active_q)
     optimizer = ADAM(Flux.params(active_q), solver.learning_rate)
     # start training
@@ -69,6 +69,7 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnvironment, polic
     saved_mean_reward = -Inf
     scores_eval = -Inf
     model_saved = false
+    eval_next = false
     for t=1:solver.max_steps
         act, eps = exploration(solver.exploration_policy, policy, env, obs, t, solver.rng)
         ai = actionindex(env.problem, act)
@@ -79,6 +80,15 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnvironment, polic
         step += 1
         episode_rewards[end] += rew
         if done || step >= solver.max_episode_length
+            if eval_next # wait for episode to end before evaluating
+                scores_eval = evaluation(solver.evaluation_policy,
+                                 policy, env,
+                                 solver.num_ep_eval,
+                                 solver.max_episode_length,
+                                 solver.verbose)
+                eval_next = false
+            end
+
             obs = reset(env)
             reset!(policy)
             push!(episode_steps, step)
@@ -102,13 +112,16 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnvironment, polic
         end

         if t%solver.eval_freq == 0
-            saved_state = env.state
+            eval_next = true
+        end
+
+        if eval_next && (done || step >= solver.max_episode_length) # wait for episode to end before evaluating
             scores_eval = evaluation(solver.evaluation_policy,
-                                 policy, env,
-                                 solver.num_ep_eval,
-                                 solver.max_episode_length,
-                                 solver.verbose)
-            env.state = saved_state
+                                     policy, env,
+                                     solver.num_ep_eval,
+                                     solver.max_episode_length,
+                                     solver.verbose)
+            eval_next = false
         end

         if t%solver.log_freq == 0
@@ -127,7 +140,7 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnvironment, polic
         if solver.verbose
             @printf("Restore model with eval reward %1.3f \n", saved_mean_reward)
             saved_model = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
-            Flux.loadparams!(policy.qnetwork, saved_model)
+            Flux.loadparams!(getnetwork(policy), saved_model)
         end
     end
     return policy
@@ -152,8 +165,8 @@ function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnvironmen
     end
     policy = NNPolicy(env.problem, active_q, ordered_actions(env.problem), length(obs_dimensions(env)))
     weights = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
-    Flux.loadparams!(policy.qnetwork, weights)
-    Flux.testmode!(policy.qnetwork)
+    Flux.loadparams!(getnetwork(policy), weights)
+    Flux.testmode!(getnetwork(policy))
     return policy
 end

@@ -172,17 +185,15 @@ end

 function batch_train!(solver::DeepQLearningSolver,
                       env::AbstractEnvironment,
-                      policy::NNPolicy,
+                      policy::AbstractNNPolicy,
                       optimizer,
                       target_q,
                       s_batch, a_batch, r_batch, sp_batch, done_batch, importance_weights)
-    active_q = policy.qnetwork
+    active_q = getnetwork(policy)
     loss_tracked, td_tracked = q_learning_loss(solver, env, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, importance_weights)
     loss_val = loss_tracked.data
     td_vals = Flux.data.(td_tracked)
     Flux.back!(loss_tracked)
-    @show active_q
-    @show params(active_q)
     grad_norm = globalnorm(params(active_q))
     optimizer()
     return loss_val, td_vals, grad_norm
@@ -221,7 +232,7 @@ function batch_train!(solver::DeepQLearningSolver,
                       target_q,
                       replay::PrioritizedReplayBuffer)
     s_batch, a_batch, r_batch, sp_batch, done_batch, indices, weights = sample(replay)
-    loss_val, td_vals, grad_norm = batch_train!(solver, env, policy, optimizer, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, weights)
+    loss_val, td_vals, grad_norm = batch_train!(solver, env, policy, optimizer, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, weights)
     update_priorities!(replay, indices, td_vals)
     return loss_val, td_vals, grad_norm
 end
@@ -229,11 +240,11 @@ end
 # for RNNs
 function batch_train!(solver::DeepQLearningSolver,
                       env::AbstractEnvironment,
-                      policy::NNPolicy,
+                      policy::AbstractNNPolicy,
                       optimizer,
                       target_q,
                       replay::EpisodeReplayBuffer)
-    active_q = policy.qnetwork
+    active_q = getnetwork(policy)
     s_batch, a_batch, r_batch, sp_batch, done_batch, trace_mask_batch = DeepQLearning.sample(replay)
     Flux.reset!(active_q)
     Flux.reset!(target_q)
diff --git a/test/profile.jl b/test/prototype.jl
similarity index 85%
rename from test/profile.jl
rename to test/prototype.jl
index 149ab27..261c13b 100644
--- a/test/profile.jl
+++ b/test/prototype.jl
@@ -43,6 +43,19 @@ Profile.clear()
 ProfileView.view()


+### Try on SubHunt
+
+using Revise
+using POMDPs
+using SubHunt
+using RLInterface
+using DeepQLearning
+using Flux
+
+solver = DeepQLearningSolver(qnetwork= Chain(Dense(8, 32, relu), Dense(32,32,relu), Dense(32, 6)),
+                             max_steps=100_000)
+solve(solver, SubHuntPOMDP())
+
 ### get q_sa


diff --git a/test/runtests.jl b/test/runtests.jl
index 9d3968f..354197f 100755
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,6 +5,7 @@ using Flux
 using Random
 using RLInterface
 using Test
+srand(1) # for test consistency

 include("test_env.jl")


From c82bb222199830e41038b8fae920630130790a8a Mon Sep 17 00:00:00 2001
From: MaximeBouton
Date: Fri, 11 Jan 2019 16:01:45 -0800
Subject: [PATCH 3/4] qnetwork->getnetwork, no AbstractDQNSolver

---
 src/DeepQLearning.jl | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/DeepQLearning.jl b/src/DeepQLearning.jl
index 74e9cf4..56a2cc7 100755
--- a/src/DeepQLearning.jl
+++ b/src/DeepQLearning.jl
@@ -12,11 +12,10 @@ using POMDPPolicies
 using RLInterface
 using LinearAlgebra

-export AbstractDQNSolver,
-    DeepQLearningSolver,
+export DeepQLearningSolver,
     AbstractNNPolicy,
     NNPolicy,
-    qnetwork,
+    getnetwork,
     DQExperience,
     restore_best_model,
     ReplayBuffer,

From 28ad6e13f1c1591866dcaafdf3a68023887b8755 Mon Sep 17 00:00:00 2001
From: MaximeBouton
Date: Fri, 11 Jan 2019 16:02:55 -0800
Subject: [PATCH 4/4] Random.srand

---
 test/runtests.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index 354197f..dd22b3e 100755
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,7 +5,7 @@ using Flux
 using Random
 using RLInterface
 using Test
-srand(1) # for test consistency
+Random.srand(1) # for test consistency

 include("test_env.jl")
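
Note on the refactored interface (not part of the patch series): the batch_train! methods now dispatch on the policy instead of a raw network, and getnetwork(policy) becomes the supported way to reach the underlying Q-network, as used by Flux.loadparams! and restore_best_model above. The sketch below is illustrative only; the TestMDP setup and hyperparameters are borrowed from test/prototype.jl, and the variable names are arbitrary.

    using Random
    using POMDPs
    using DeepQLearning
    using Flux

    rng = MersenneTwister(1)
    include("test/test_env.jl")               # provides TestMDP, as in test/prototype.jl
    mdp = TestMDP((5,5), 4, 6)
    model = Chain(x->flattenbatch(x), Dense(100, 8, tanh), Dense(8, n_actions(mdp)))
    solver = DeepQLearningSolver(qnetwork=model, max_steps=10000, learning_rate=0.005,
                                 double_q=false, dueling=true, prioritized_replay=false,
                                 rng=rng)

    policy = solve(solver, mdp)               # returns an NNPolicy
    qnet = getnetwork(policy)                 # accessor added in PATCH 2/4, exported in PATCH 3/4
    DeepQLearning.reset!(policy)              # clears hidden states; effectively a no-op for this feedforward model

Going through getnetwork rather than the policy.qnetwork field keeps external code working for any AbstractNNPolicy, which is what allows the ReplayBuffer, PrioritizedReplayBuffer, and EpisodeReplayBuffer training paths above to share one signature.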