Merge 3c424c2 into b9a5943
findmyway committed Dec 31, 2018
2 parents b9a5943 + 3c424c2 commit 9887353
Showing 4 changed files with 41 additions and 48 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -12,6 +12,6 @@

[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://ju-jl.github.io/Ju.jl/dev/)
[![BuildStatus](https://travis-ci.org/Ju-jl/Ju.jl.svg?branch=master)](https://travis-ci.org/Ju-jl/Ju.jl)
[![](https://img.shields.io/docker/pulls/tianjun2018/ju.svg)](https://cloud.docker.com/repository/docker/tianjun2018/ju)
[![](https://img.shields.io/docker/pulls/tianjun2018/ju.svg)](https://hub.docker.com/r/tianjun2018/ju)
[![Coverage Status](https://coveralls.io/repos/github/Ju-jl/Ju.jl/badge.svg)](https://coveralls.io/github/Ju-jl/Ju.jl)
[![](https://codecov.io/gh/Ju-jl/Ju.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/Ju-jl/Ju.jl)
src/core.jl: 23 changes (16 additions, 7 deletions)
@@ -12,19 +12,28 @@ function train!(env::AbstractSyncEnvironment{Tss, Tas, 1} where {Tss, Tas},
agent::AbstractAgent,
::Type{<:SARDBuffer};
callbacks)
s, a = observe(env).observation |> agent
obs, d = observe(env)
if isempty(buffer(agent))
s, a = agent(obs) # TODO: check buffer first
push!(buffer(agent), s, a)
else
s, a = buffer(agent).state[end], buffer(agent).action[end]
end

isstop = false
while !isstop
obs, r, d = env(a)
if d
reset!(env)
ns, na = agent(observe(env).observation)
else
ns, na = agent(obs)
empty!(buffer(agent))
s, a = agent(observe(env).observation)
push!(buffer(agent), s, a)
end
push!(buffer(agent), s, a, r, d, ns, na)

obs, r, d = env(a) # TODO: split into two steps: 1. env(a), 2. observe(env)
s, a = agent(obs)
push!(buffer(agent), r, d, s, a)
update!(agent)
s, a = ns, na

for cb in callbacks
res = cb(env, agent)
if res isa Bool && res
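The change to `train!` above makes the loop resumable: the initial `(s, a)` pair is taken from the agent's buffer when one is already stored, and the buffer is emptied and re-seeded whenever an episode terminates. Below is a minimal, self-contained sketch of that control flow; every name in it (`ToyEnv`, `observe_env`, `act!`, `policy`) is a hypothetical stand-in rather than Ju.jl's actual `env(a)` / `observe(env)` / `buffer(agent)` API.

```julia
# Toy sketch of the resumable loop: resume from the buffer when it already holds a
# (state, action) pair, and empty + re-seed it when an episode ends.
# Every name here is a hypothetical stand-in, not Ju.jl API.
mutable struct ToyEnv
    s::Int
    done::Bool
end

observe_env(env::ToyEnv) = (env.s, env.done)                # -> (observation, isdone)
act!(env::ToyEnv, a) = (env.s += a; env.done = env.s >= 5; (env.s, -1.0, env.done))
reset_env!(env::ToyEnv) = (env.s = 0; env.done = false)
policy(obs) = 1                                             # placeholder policy

function toy_train!(env, buf; nsteps = 12)
    obs, d = observe_env(env)
    if isempty(buf)                      # nothing pending: act once and seed the buffer
        s, a = obs, policy(obs)
        push!(buf, (s, a))
    else                                 # resume from the last stored (s, a) pair
        s, a = buf[end]
    end
    for _ in 1:nsteps
        if d                             # episode ended: reset env, clear and re-seed buffer
            reset_env!(env)
            empty!(buf)
            obs, _ = observe_env(env)
            s, a = obs, policy(obs)
            push!(buf, (s, a))
        end
        obs, r, d = act!(env, a)         # one environment transition
        s, a = obs, policy(obs)
        push!(buf, (s, a))               # Ju.jl additionally stores r and d here
        # update!(agent) would run at this point in the real loop
    end
    buf
end

env, buf = ToyEnv(0, false), Tuple{Int, Int}[]
toy_train!(env, buf)    # seeds the buffer and runs a few episodes
toy_train!(env, buf)    # resumes from buf[end] instead of acting again before the loop
```

Calling the sketch twice shows the point of the buffer check: the second call resumes from the stored pair instead of acting again before the loop, which appears to be the intent of the new `isempty(buffer(agent))` branch.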
src/learners/Q_learner.jl: 21 changes (0 additions, 21 deletions)

This file was deleted.

src/learners/temporal_difference_learner.jl: 43 changes (24 additions, 19 deletions)
@@ -156,31 +156,34 @@ end
See more details at Section (7.3) on Page 148 of the book *Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.*
"""
struct OffPolicyTDLearner{Tapp <: AbstractApproximator, Tpb <: AbstractPolicy, Tpt <: AbstractPolicy, method} <: AbstractModelFreeLearner
struct OffPolicyTDLearner{Tapp <: AbstractApproximator, Tpb <: PolicyOrSelector, Tpt <: PolicyOrSelector, Tα<:Union{Function, Float64}, method} <: AbstractModelFreeLearner
approximator::Tapp
π_behavior::Tpb
π_target::Tpt
γ::Float64
α::Float64
α::Tα
n::Int
function OffPolicyTDLearner(approximator::Tapp, π_behavior::Tpb, π_target::Tpt, γ::Float64, α::Float64, n::Int=0, method::Symbol=:SARSA_ImportanceSampling) where {Tapp<:AbstractApproximator, Tpb<:AbstractPolicy, Tpt<:AbstractPolicy}
new{Tapp, Tpb, Tpt, method}(approximator, π_behavior, π_target, γ, α, n)
function OffPolicyTDLearner(approximator::Tapp, π_behavior::Tpb, π_target::Tpt, γ::Float64, α::Tα, n::Int=0, method::Symbol=:SARSA_ImportanceSampling) where {Tapp<:AbstractApproximator, Tpb<:PolicyOrSelector, Tpt<:PolicyOrSelector, Tα<:Union{Function, Float64}}
new{Tapp, Tpb, Tpt, Tα, method}(approximator, π_behavior, π_target, γ, α, n)
end
end

const QLearner = OffPolicyTDLearner{Tapp, Tp, Tp, :QLearning} where {Tapp<:AbstractQApproximator, Tp<:AbstractPolicy}
QLearner(Q::AbstractQApproximator, π::AbstractPolicy, γ::Float64, α::Float64) = OffPolicyTDLearner(Q, π, π, γ, α, 0, :QLearning)
const QLearner = OffPolicyTDLearner{Tapp, Tp, Tp, Tα, :QLearning} where {Tapp<:AbstractQApproximator, Tp<:PolicyOrSelector, Tα<:Union{Function, Float64}}
QLearner(Q::AbstractQApproximator, π::PolicyOrSelector, γ::Float64, α::Union{Function, Float64}) = OffPolicyTDLearner(Q, π, π, γ, α, 0, :QLearning)

(learner::OffPolicyTDLearner)(s) = learner.π_behavior(s)
(learner::OffPolicyTDLearner)(s, ::Val{:dist}) = learner.π_behavior(s, Val(:dist))
(learner::OffPolicyTDLearner{<:AbstractApproximator, <:AbstractActionSelector})(s) = learner.approximator(s) |> learner.π_behavior
(learner::OffPolicyTDLearner{<:AbstractApproximator, <:AbstractPolicy})(s) = learner.π_behavior(s)
(learner::OffPolicyTDLearner{<:AbstractApproximator, <:AbstractPolicy})(s, ::Val{:dist}) = learner.π_behavior(s, Val(:dist))

function priority(learner::QLearner, s, a, r, d, s′)
α, γ, Q = learner.α, learner.γ, learner.Q
priority(learner::QLearner{<:AbstractQApproximator, <:AbstractPolicy, <:Float64}, s, a, r, d, s′) = priority(learner, learner.α, s, a, r, d, s′)
priority(learner::QLearner{<:AbstractQApproximator, <:AbstractPolicy, <:Function}, s, a, r, d, s′) = priority(learner, learner.α((s,a)), s, a, r, d, s′)
function priority(learner::QLearner, α, s, a, r, d, s′)
γ, Q = learner.γ, learner.approximator
error = d ? α * (r - Q(s, a)) : α * (r + γ * Q(s′, Val(:max)) - Q(s, a))
abs(error)
end
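With the new `Tα<:Union{Function, Float64}` parameter, the step size can be a constant or a function of the `(s, a)` pair, and the one-line `priority` methods above resolve it before forwarding to a shared implementation. A hypothetical sketch of the same dispatch technique, not Ju.jl code:

```julia
# Hypothetical illustration of the Union{Function, Float64} step-size pattern:
# dispatch resolves α to a number, then forwards to one shared method.
struct ToyLearner{Tα <: Union{Function, Float64}}
    α::Tα
    γ::Float64
end

# Resolve α first ...
td_error(l::ToyLearner{Float64}, s, a, r, d, s′, Q) = td_error(l, l.α, s, a, r, d, s′, Q)
td_error(l::ToyLearner{<:Function}, s, a, r, d, s′, Q) = td_error(l, l.α((s, a)), s, a, r, d, s′, Q)

# ... then share the actual TD-error computation.
function td_error(l::ToyLearner, α, s, a, r, d, s′, Q)
    d ? α * (r - Q[s, a]) : α * (r + l.γ * maximum(Q[s′, :]) - Q[s, a])
end

Q = zeros(3, 2)                                   # toy tabular Q-table
fixed = ToyLearner(0.1, 0.9)                      # constant step size
decay = ToyLearner(sa -> 1 / (1 + sa[1]), 0.9)    # step size depends on (s, a)
td_error(fixed, 1, 1, -1.0, false, 2, Q)          # uses α = 0.1
td_error(decay, 1, 1, -1.0, false, 2, Q)          # uses α = 1 / (1 + 1) = 0.5
```

Resolving `α` in the thin forwarding methods keeps the TD-error formula in a single place instead of duplicating it for each step-size type.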

function update!(learner::OffPolicyTDLearner{<:AbstractApproximator, <:AbstractPolicy, <:AbstractPolicy, :SARSA_ImportanceSampling}, buffer::EpisodeSARDBuffer)
function update!(learner::OffPolicyTDLearner{<:AbstractApproximator, <:PolicyOrSelector, <:PolicyOrSelector, <:Union{Function, Float64}, :SARSA_ImportanceSampling}, buffer::EpisodeSARDBuffer)
n = learner.n
update!(learner,
@view(buffer.state[max(1, end - n - 1) : end - 1]),
@@ -191,8 +194,10 @@ function update!(learner::OffPolicyTDLearner{<:AbstractApproximator, <:AbstractP
Val(buffer.isdone[end]))
end

function update!(learner::QLearner, s, a, r, d, s′)
Q, γ, α, π = learner.approximator, learner.γ, learner.α, learner.π_target
update!(learner::QLearner{<:AbstractQApproximator, <:PolicyOrSelector, <:Float64}, s, a, r, d, s′) = update!(learner, learner.α, s, a, r, d, s′)
update!(learner::QLearner{<:AbstractQApproximator, <:PolicyOrSelector, <:Function}, s, a, r, d, s′) = update!(learner, learner.α((s, a)), s, a, r, d, s′)
function update!(learner::QLearner, α, s, a, r, d, s′)
Q, γ, π = learner.approximator, learner.γ, learner.π_target
error = d ? α * (r - Q(s, a)) : α * (r + γ * Q(s′, Val(:max)) - Q(s, a))
update!(Q, s, a, error)
update!(π, s, Q(s, Val(:argmax)))
@@ -207,7 +212,7 @@ function update!(learner::QLearner, buffer::EpisodeSARDSBuffer)
update!(learner, buffer.state[end], buffer.action[end], buffer.reward[end], buffer.isdone[end], buffer.nextstate[end])
end
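For usage, a function-valued `α` receives the `(s, a)` pair as a single tuple, matching the `learner.α((s, a))` calls above. The sketch below builds a count-based decaying schedule; `TabularQApproximator` and `EpsilonGreedySelector` are assumed names, not confirmed Ju.jl exports, so the learner construction is left commented out.

```julia
# Count-based decaying step size, a common tabular schedule: α(s, a) = 1 / N(s, a).
# The closure takes the (s, a) pair as one tuple, matching learner.α((s, a)) above.
counts = Dict{Any, Int}()
α = sa -> 1 / (counts[sa] = get(counts, sa, 0) + 1)

α((1, 2))   # first visit to (1, 2)  -> 1.0
α((1, 2))   # second visit           -> 0.5

# Placeholder construction: `TabularQApproximator` and `EpsilonGreedySelector` are
# assumed names; substitute the approximator and policy/selector types the package exports.
# Q = TabularQApproximator(n_states, n_actions)
# selector = EpsilonGreedySelector(0.1)
# learner = QLearner(Q, selector, 0.9, α)      # γ = 0.9, function-valued α
```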

function update!(learner::OffPolicyTDLearner{<:AbstractQApproximator, <:AbstractPolicy, <:AbstractPolicy, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{true})
function update!(learner::OffPolicyTDLearner{<:AbstractQApproximator, <:AbstractPolicy, <:AbstractPolicy, <:Union{Function, Float64}, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{true})
n, γ, Q, α, π, b = learner.n, learner.γ, learner.approximator, learner.α, learner.π_target, learner.π_behavior
# Warning!!! The order of calculation is reversed here for speed. The impact is uncertain!!!
for (G, ρ, s, a) in zip(reverse_discounted_rewards(rewards, γ),
@@ -219,7 +224,7 @@ function update!(learner::OffPolicyTDLearner{<:AbstractQApproximator, <:Abstract
end
end

function update!(learner::OffPolicyTDLearner{<:AbstractVApproximator, <:AbstractPolicy, <:AbstractPolicy, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{true})
function update!(learner::OffPolicyTDLearner{<:AbstractVApproximator, <:AbstractPolicy, <:AbstractPolicy, <:Union{Function, Float64}, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{true})
n, γ, V, α, π, b = learner.n, learner.γ, learner.approximator, learner.α, learner.π_target, learner.π_behavior
# Warning!!! The order of calculation is reversed here for speed. The impact is uncertain!!!
for (G, ρ, s, a) in zip(reverse_discounted_rewards(rewards, γ),
Expand All @@ -229,7 +234,7 @@ function update!(learner::OffPolicyTDLearner{<:AbstractVApproximator, <:Abstract
end
end

function update!(learner::OffPolicyTDLearner{<:AbstractQApproximator, <:AbstractPolicy, <:AbstractPolicy, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{false})
function update!(learner::OffPolicyTDLearner{<:AbstractQApproximator, <:AbstractPolicy, <:AbstractPolicy, <:Union{Function, Float64}, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{false})
n, γ, Q, α, π, b = learner.n, learner.γ, learner.approximator, learner.α, learner.π_target, learner.π_behavior
if length(states) ≥ n
G = discounted_reward(rewards, γ) + γ^n * Q(nextstates[end], nextactions[end])
@@ -240,7 +245,7 @@ function update!(learner::OffPolicyTDLearner{<:Abstract
end
end
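The `update!` methods above combine an n-step discounted return with an importance-sampling ratio built from `π_target` and `π_behavior` (Sutton & Barto 2018, Section 7.3). A plain-Julia sketch of the textbook target, independent of the package's `reverse_discounted_rewards` and `discounted_reward` helpers:

```julia
# Hedged sketch of the quantity the n-step off-policy SARSA methods estimate;
# plain Julia, independent of Ju.jl types and helpers.
function nstep_target(rewards, γ, tail_value)
    G = tail_value                 # bootstrap with Q(s_{t+n}, a_{t+n}) or V(s_{t+n})
    for r in reverse(rewards)      # accumulate discounted rewards backwards
        G = r + γ * G
    end
    G
end

# Importance-sampling ratio ρ = ∏ π(a|s) / b(a|s) over the sampled actions.
is_ratio(π_probs, b_probs) = prod(π_probs ./ b_probs)

rewards = [-1.0, -1.0, -1.0]
G = nstep_target(rewards, 0.9, 2.0)          # -1 - 0.9 - 0.81 + 0.729 * 2 = -1.252
ρ = is_ratio([0.9, 1.0, 0.8], [0.5, 0.5, 0.5])
# tabular update: Q[s, a] += α * ρ * (G - Q[s, a])
```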

function update!(learner::OffPolicyTDLearner{<:AbstractVApproximator, <:AbstractPolicy, <:AbstractPolicy, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{false})
function update!(learner::OffPolicyTDLearner{<:AbstractVApproximator, <:AbstractPolicy, <:AbstractPolicy, <:Union{Function, Float64}, :SARSA_ImportanceSampling}, states, actions, rewards, nextstates, nextactions, ::Val{false})
n, γ, V, α, π, b = learner.n, learner.γ, learner.approximator, learner.α, learner.π_target, learner.π_behavior
if length(states) ≥ n
G = discounted_reward(rewards, γ) + γ^n * V(nextstates[end])
@@ -268,7 +273,7 @@ end
(learner::DoubleLearner)(s, ::Val{:dist}) = learner.Learner1(s, Val(:dist)) .+ learner.Learner2(s, Val(:dist))
(learner::DoubleLearner)(s) = learner.selector(learner(s, Val(:dist)))

function update!(learner::DoubleLearner{<:OffPolicyTDLearner{<:AbstractQApproximator, <:AbstractPolicy, <:AbstractPolicy, :QLearning}}, buffer::EpisodeSARDBuffer)
function update!(learner::DoubleLearner{<:OffPolicyTDLearner{<:AbstractQApproximator, <:AbstractPolicy, <:AbstractPolicy, <:Union{Function, Float64}, :QLearning}}, buffer::EpisodeSARDBuffer)
s, a, r, s′ = buffer.state[end-1], buffer.action[end-1], buffer.reward[end], buffer.state[end]

if rand() < 0.5
@@ -287,7 +292,7 @@ function update!(learner::DoubleLearner{<:OffPolicyTDLearner{<:AbstractQApproxim
end

"""
DifferentialTDLearner(approximator::Tapp, π::Tp, α::Float64, β::Float64, R̄::Float64=0., n::Int=0, method::Symbol=:SARSA) where {Tapp<:AbstractApproximator, Tp<:PolicyOrSelector}= new{Tapp, Tp, method}(approximator, π, α, β, R̄, n)
DifferentialTDLearner(approximator::Tapp, π::Tp, α::Float64, β::Float64, R̄::Float64=0., n::Int=1, method::Symbol=:SARSA) where {Tapp<:AbstractApproximator, Tp<:PolicyOrSelector}= new{Tapp, Tp, method}(approximator, π, α, β, R̄, n)
See more details at Section (10.3) on Page 251 of the book *Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.*
"""