
Commit

Merge 28ad6e1 into cf60925
MaximeBouton committed Jan 12, 2019
2 parents cf60925 + 28ad6e1 commit 457716a
Showing 5 changed files with 89 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/DeepQLearning.jl
@@ -15,6 +15,7 @@ using LinearAlgebra
export DeepQLearningSolver,
AbstractNNPolicy,
NNPolicy,
getnetwork,
DQExperience,
restore_best_model,
ReplayBuffer,
18 changes: 18 additions & 0 deletions src/policy.jl
@@ -1,12 +1,30 @@
abstract type AbstractNNPolicy <: Policy end

## NN Policy interface

"""
getnetwork(policy)
return the value network of the policy
"""
function getnetwork end

"""
reset!(policy)
reset the hidden states of a policy
"""
function reset! end

struct NNPolicy{P <: Union{MDP, POMDP}, Q, A} <: AbstractNNPolicy
problem::P
qnetwork::Q
action_map::Vector{A}
n_input_dims::Int64
end

function getnetwork(policy::NNPolicy)
return policy.qnetwork
end

function reset!(policy::NNPolicy)
Flux.reset!(policy.qnetwork)
end
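
The hunk above introduces a small AbstractNNPolicy interface (getnetwork / reset!) that the solver goes through instead of touching policy.qnetwork directly. A minimal sketch of a custom policy implementing that interface, assuming the same using POMDPs / Flux context as src/policy.jl; the MaskedNNPolicy name and its fields are hypothetical, not part of this commit:

# Hypothetical custom policy type; only the two interface methods matter here.
struct MaskedNNPolicy{P <: Union{MDP, POMDP}, Q, A} <: AbstractNNPolicy
    problem::P
    qnetwork::Q
    action_map::Vector{A}
end

# dqn_train! and batch_train! fetch the trainable network through getnetwork
getnetwork(policy::MaskedNNPolicy) = policy.qnetwork

# reset! clears any recurrent hidden states between episodes
reset!(policy::MaskedNNPolicy) = Flux.reset!(policy.qnetwork)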
56 changes: 38 additions & 18 deletions src/solver.jl
@@ -51,7 +51,12 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
active_q = solver.qnetwork
end
policy = NNPolicy(env.problem, active_q, ordered_actions(env.problem), length(obs_dimensions(env)))
target_q = deepcopy(solver.qnetwork)
return dqn_train!(solver, env, policy, replay)
end

function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnvironment, policy::AbstractNNPolicy, replay)
active_q = getnetwork(policy) # shallow copy
target_q = deepcopy(active_q)
optimizer = ADAM(Flux.params(active_q), solver.learning_rate)
# start training
reset!(policy)
@@ -64,6 +69,7 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
saved_mean_reward = -Inf
scores_eval = -Inf
model_saved = false
eval_next = false
for t=1:solver.max_steps
act, eps = exploration(solver.exploration_policy, policy, env, obs, t, solver.rng)
ai = actionindex(env.problem, act)
@@ -74,6 +80,15 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
step += 1
episode_rewards[end] += rew
if done || step >= solver.max_episode_length
if eval_next # wait for episode to end before evaluating
scores_eval = evaluation(solver.evaluation_policy,
policy, env,
solver.num_ep_eval,
solver.max_episode_length,
solver.verbose)
eval_next = false
end

obs = reset(env)
reset!(policy)
push!(episode_steps, step)
@@ -87,7 +102,7 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
avg100_steps = mean(episode_steps[max(1, length(episode_steps)-101):end])
if t%solver.train_freq == 0
hs = hiddenstates(active_q)
loss_val, td_errors, grad_val = batch_train!(solver, env, optimizer, active_q, target_q, replay)
loss_val, td_errors, grad_val = batch_train!(solver, env, policy, optimizer, target_q, replay)
sethiddenstates!(active_q, hs)
end

@@ -97,13 +112,16 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
end

if t%solver.eval_freq == 0
saved_state = env.state
eval_next = true
end

if eval_next && (done || step >= solver.max_episode_length) # wait for episode to end before evaluating
scores_eval = evaluation(solver.evaluation_policy,
policy, env,
solver.num_ep_eval,
solver.max_episode_length,
solver.verbose)
env.state = saved_state
policy, env,
solver.num_ep_eval,
solver.max_episode_length,
solver.verbose)
eval_next = false
end

if t%solver.log_freq == 0
@@ -122,7 +140,7 @@ function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnvironment)
if solver.verbose
@printf("Restore model with eval reward %1.3f \n", saved_mean_reward)
saved_model = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
Flux.loadparams!(policy.qnetwork, saved_model)
Flux.loadparams!(getnetwork(policy), saved_model)
end
end
return policy
@@ -147,8 +165,8 @@ function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnvironmen
end
policy = NNPolicy(env.problem, active_q, ordered_actions(env.problem), length(obs_dimensions(env)))
weights = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
Flux.loadparams!(policy.qnetwork, weights)
Flux.testmode!(policy.qnetwork)
Flux.loadparams!(getnetwork(policy), weights)
Flux.testmode!(getnetwork(policy))
return policy
end

@@ -167,10 +185,11 @@ end

function batch_train!(solver::DeepQLearningSolver,
env::AbstractEnvironment,
optimizer,
active_q,
policy::AbstractNNPolicy,
optimizer,
target_q,
s_batch, a_batch, r_batch, sp_batch, done_batch, importance_weights)
active_q = getnetwork(policy)
loss_tracked, td_tracked = q_learning_loss(solver, env, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, importance_weights)
loss_val = loss_tracked.data
td_vals = Flux.data.(td_tracked)
@@ -198,33 +217,34 @@ end

function batch_train!(solver::DeepQLearningSolver,
env::AbstractEnvironment,
policy::AbstractNNPolicy,
optimizer,
active_q,
target_q,
replay::ReplayBuffer)
s_batch, a_batch, r_batch, sp_batch, done_batch = sample(replay)
return batch_train!(solver, env, optimizer, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, ones(solver.batch_size))
return batch_train!(solver, env, policy, optimizer, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, ones(solver.batch_size))
end

function batch_train!(solver::DeepQLearningSolver,
env::AbstractEnvironment,
policy::AbstractNNPolicy,
optimizer,
active_q,
target_q,
replay::PrioritizedReplayBuffer)
s_batch, a_batch, r_batch, sp_batch, done_batch, indices, weights = sample(replay)
loss_val, td_vals, grad_norm = batch_train!(solver, env, optimizer, active_q, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, weights)
loss_val, td_vals, grad_norm = batch_train!(solver, env, policy, optimizer, target_q, s_batch, a_batch, r_batch, sp_batch, done_batch, weights)
update_priorities!(replay, indices, td_vals)
return loss_val, td_vals, grad_norm
end

# for RNNs
function batch_train!(solver::DeepQLearningSolver,
env::AbstractEnvironment,
policy::AbstractNNPolicy,
optimizer,
active_q,
target_q,
replay::EpisodeReplayBuffer)
active_q = getnetwork(policy)
s_batch, a_batch, r_batch, sp_batch, done_batch, trace_mask_batch = DeepQLearning.sample(replay)
Flux.reset!(active_q)
Flux.reset!(target_q)
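
In the solver changes above, the t%solver.eval_freq == 0 branch now only raises an eval_next flag, and the evaluation rollouts run once the current training episode terminates, replacing the old saved_state = env.state / env.state = saved_state bookkeeping. A self-contained toy sketch of that deferred-evaluation pattern; every name below is a placeholder for illustration, not a function from this package:

# Toy illustration of the eval_next pattern: raise a flag at eval_freq,
# run the "evaluation" only after the current episode has ended.
function deferred_eval_demo(; max_steps=20, eval_freq=6, max_episode_length=4)
    eval_next = false
    step = 0
    for t in 1:max_steps
        step += 1
        done = step >= max_episode_length     # toy episode termination
        if t % eval_freq == 0
            eval_next = true                  # request evaluation, do not run it yet
        end
        if done
            if eval_next                      # evaluate only between episodes
                println("t = $t: running evaluation rollouts")
                eval_next = false
            end
            step = 0                          # start the next episode
        end
    end
end

deferred_eval_demo()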
32 changes: 31 additions & 1 deletion test/profile.jl → test/prototype.jl
@@ -8,15 +8,32 @@ using Flux
using Profile
using BenchmarkTools


using Revise
using Random
using POMDPs
using DeepQLearning
using Flux
rng = MersenneTwister(1)
include("test/test_env.jl")
mdp = TestMDP((5,5), 4, 6)
model = Chain(x->flattenbatch(x), Dense(100, 8, tanh), Dense(8, n_actions(mdp)))
solver = DeepQLearningSolver(qnetwork = model, max_steps=10000, learning_rate=0.005,
eval_freq=2000,num_ep_eval=100,
log_freq = 500,
double_q = false, dueling=true, prioritized_replay=false,
rng=rng)

policy = solve(solver, mdp)


mdp = SimpleGridWorld()

model = Chain(Dense(2, 32, relu), LSTM(32,32), Dense(32, 32, relu), Dense(32, n_actions(mdp)))

solver = DeepQLearningSolver(qnetwork = model, prioritized_replay=false, max_steps=1000, learning_rate=0.001,log_freq=500,
recurrence=true,trace_length=10, double_q=false, dueling=false, rng=rng, verbose=false)
@btime policy = solve(solver, mdp)
policy = solve(solver, mdp)


@profile 1+1
@@ -26,6 +43,19 @@

ProfileView.view()

### Try on SubHunt

using Revise
using POMDPs
using SubHunt
using RLInterface
using DeepQLearning
using Flux

solver = DeepQLearningSolver(qnetwork= Chain(Dense(8, 32, relu), Dense(32,32,relu), Dense(32, 6)),
max_steps=100_000)
solve(solver, SubHuntPOMDP())



### get q_sa
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -5,6 +5,7 @@ using Flux
using Random
using RLInterface
using Test
Random.srand(1) # for test consistency

include("test_env.jl")

