
Commit

Merge a83af24 into b24a90b
findmyway committed Jan 20, 2019
2 parents b24a90b + a83af24 commit 1547658
Showing 4 changed files with 17 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -18,4 +18,4 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test"]
+test = ["Test"]
4 changes: 3 additions & 1 deletion src/base/helper_functions.jl
@@ -110,4 +110,6 @@ The returned object is of type [`Reductions`](@ref)
 reverse_importance_weights(π, b, states, actions) = Reductions(
     (ρ, (s, a)) -> ρ == 0. ? 0. : ρ * π(s, a) / b(s, a),
     Iterators.reverse(zip(states, actions)),
-    (init=1.,))
+    (init=1.,))
+
+const is_using_gpu = false
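
For context, `reverse_importance_weights` accumulates the importance-sampling ratios of an off-policy trajectory from the final step backwards, starting from `init = 1.` and short-circuiting to zero once a ratio hits zero. A minimal usage sketch follows; the toy policies are hypothetical, and collecting a `Reductions` object like an ordinary Julia iterator is an assumption, not something defined in this commit.

# Toy policies over two actions; both return P(a | s).
π_target(s, a) = a == 1 ? 0.9 : 0.1
b_behaviour(s, a) = 0.5

states  = [1, 2, 3]
actions = [1, 2, 1]

# Assuming `Reductions` iterates a lazy running reduction, the first element
# corresponds to the last time step and each subsequent element extends the
# product π(s, a) / b(s, a) one step further back.
ρs = collect(reverse_importance_weights(π_target, b_behaviour, states, actions))
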
4 changes: 3 additions & 1 deletion src/dynamic_loading.jl
@@ -1,4 +1,6 @@
 import Flux:gpu
 using CuArrays
 
-gpu(x::SubArray) = CuArray{Float32}(x)
+gpu(x::SubArray) = CuArray{Float32}(x)
+
+const is_using_gpu = true
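
The `is_using_gpu` flag (false in helper_functions.jl, true here when CuArrays is loaded) lets downstream code match the element type that `gpu` produces: `CuArray{Float32}` on the GPU path versus plain `Float64` values on the CPU path. A sketch of how such a flag can be consumed, mirroring the constructor change in dqn.jl below; the name `FloatT` is illustrative, not part of this commit.

# Pick the floating-point type that matches the gpu(...) conversion above.
const FloatT = is_using_gpu ? Float32 : Float64

init_loss = zero(FloatT)    # e.g. the initial value stored in DQN's `loss` field
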
14 changes: 10 additions & 4 deletions src/learners/dqn.jl
@@ -2,28 +2,34 @@ using Flux
 
 const PolicyOrSelector = Union{AbstractPolicy, AbstractActionSelector}
 
-struct DQN{Tn<:NeuralNetworkQ, Tp<:PolicyOrSelector} <: AbstractModelFreeLearner
+mutable struct DQN{Tn<:NeuralNetworkQ, Tp<:PolicyOrSelector, Tf, Tl<:Union{Float32, Float64}} <: AbstractModelFreeLearner
     Q::Tn
     π::Tp
     γ::Float64
     batch_size::Int
-    DQN(Q::TQ, π::Tp; γ=0.99, batch_size=32) where {TQ, Tp} = new{TQ, Tp}(Q, π, γ, batch_size)
+    loss_fun::Tf
+    loss::Tl
+    function DQN(Q::TQ, π::Tp; γ=0.99, batch_size=32, loss_fun=Flux.mse) where {TQ, Tp}
+        init_loss = is_using_gpu ? Float32(0.) : Float64(0.)
+        new{TQ, Tp, typeof(loss_fun), typeof(init_loss)}(Q, π, γ, batch_size, loss_fun, init_loss)
+    end
 end
 
 (learner::DQN{<:NeuralNetworkQ, <:AbstractActionSelector})(s) = learner.Q(gpu(s)) |> learner.π
 (learner::DQN{<:NeuralNetworkQ, <:AbstractPolicy})(s) = learner.π(s)
 (learner::DQN{<:NeuralNetworkQ, <:AbstractPolicy})(s, ::Val{:dist}) = learner.π(s, Val(:dist))
 
 function update!(learner::DQN{<:NeuralNetworkQ, <:AbstractActionSelector}, buffer::CircularSARDBuffer)
-    Q, π, γ, batch_size = learner.Q, learner.π, learner.γ, learner.batch_size
+    Q, π, γ, batch_size, loss_fun = learner.Q, learner.π, learner.γ, learner.batch_size, learner.loss_fun
 
     if length(buffer) > batch_size
         (s, a, r, d, s′), _ = sample(buffer, batch_size)
         s, r, d, s′ = gpu(s), gpu(r), gpu(d), gpu(s′)
 
         q, q′ = Q(s, a), Q(s′, Val(:max))
         G = @. r + γ * q′ * (1 - d)
-        loss = Flux.mse(G, q)
+        loss = loss_fun(G, q)
+        learner.loss = loss.data
         update!(Q, loss)
         # update!(π, s, Q(s, Val(:argmax))) # π isa AbstractPolicy
     end
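
In short, `DQN` becomes mutable so it can record the loss of the most recent batch, the loss function is injectable via the `loss_fun` keyword (defaulting to `Flux.mse`), and the stored loss is typed `Float32` or `Float64` to match the GPU/CPU path. A construction sketch under assumptions: `q_approximator` and `selector` stand in for an already-built `NeuralNetworkQ` and `AbstractActionSelector`, and the mean-absolute-error loss below is only an illustrative alternative to `Flux.mse`.

using Statistics: mean

# Hypothetical alternative loss; any f(ŷ, y) returning a scalar that
# update!(Q, loss) can differentiate works here.
mae(ŷ, y) = mean(abs.(ŷ .- y))

learner = DQN(q_approximator, selector; γ = 0.99, batch_size = 32, loss_fun = mae)

# After update!(learner, buffer), the latest batch loss is readable as a
# plain number (Float32 on GPU, Float64 on CPU):
learner.loss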
