schedules are functors, policies take functions as arg
MaximeBouton committed Mar 11, 2020
1 parent 71bfd25 commit 8d7e28f
Showing 4 changed files with 45 additions and 42 deletions.
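In short: exploration parameters such as epsilon and the softmax temperature are no longer mutable fields updated through `update_value`; they are stored as functions of a step index `k`, schedules are callable objects (functors), and `action` on an exploration policy takes the step index as an extra argument. A minimal sketch of the new call pattern, mirroring test/test_exploration_policies.jl below; it assumes a POMDPPolicies build that contains this commit, plus `SimpleGridWorld` and `GWPos` from POMDPModels as used in the tests:

```julia
using POMDPs, POMDPModels, POMDPPolicies

mdp = SimpleGridWorld()
on_policy = FunctionPolicy(s -> :up)  # stand-in for the greedy/on policy

# a schedule is a callable object, so it can be passed wherever a function is expected
policy = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))

for k in 1:10
    a = action(policy, on_policy, k, GWPos(1, 1))  # note the new step argument k
end
```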
2 changes: 1 addition & 1 deletion .travis.yml
@@ -2,7 +2,7 @@ language: julia
 
 julia:
 - 1.0
-- 1.3
+- 1
 
 os:
 - linux
1 change: 1 addition & 0 deletions src/POMDPPolicies.jl
@@ -61,6 +61,7 @@ export ExplorationSchedule,
     EpsGreedyPolicy,
     SoftmaxPolicy,
     ExplorationPolicy,
+    exploration_parameter,
     LinearDecaySchedule,
     ConstantSchedule
 
68 changes: 35 additions & 33 deletions src/exploration_policies.jl
@@ -7,7 +7,7 @@ Abstract type for exploration schedule.
 It is useful to define the schedule of a parameter of an exploration policy.
 The effect of a schedule is defined by the `update_value` function.
 """
-abstract type ExplorationSchedule end
+abstract type ExplorationSchedule <: Function end
 
 """
     update_value(::ExplorationSchedule, value)
@@ -23,28 +23,20 @@ if the value is greater or equal to `end_val`, it stays constant.
 # Constructor
-`LinearDecaySchedule(;start_val, end_val, steps)`
+`LinearDecaySchedule(;start, stop, steps)`
 """
 @with_kw struct LinearDecaySchedule{R<:Real} <: ExplorationSchedule
-    start_val::R
-    end_val::R
+    start::R
+    stop::R
     steps::Int
 end
 
-function update_value(schedule::LinearDecaySchedule, value)
-    rate = (schedule.start_val - schedule.end_val) / schedule.steps
-    new_value = max(value - rate, schedule.end_val)
+function (schedule::LinearDecaySchedule)(k)
+    rate = (schedule.start - schedule.stop) / schedule.steps
+    val = schedule.start - k*rate
+    val = max(schedule.stop, val)
 end
 
 """
     ConstantSchedule
 A schedule that keeps the value constant
 """
 struct ConstantSchedule <: ExplorationSchedule
 end
 
-update_value(::ConstantSchedule, value) = value
 
 
 """
     ExplorationPolicy <: Policy
@@ -53,7 +45,11 @@ Sampling from an exploration policy is done using `action(exploration_policy, on
 """
 abstract type ExplorationPolicy <: Policy end
 
 
+# """
+#     exploration_parameter(::ExplorationPolicy)
+# returns the exploration parameter of an exploration policy, e.g. epsilon for e-greedy or temperature for softmax
+# """
+# function exploration_parameter end
 
 """
     EpsGreedyPolicy <: ExplorationPolicy
@@ -65,29 +61,32 @@ constructor:
 `EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Float64; rng=Random.GLOBAL_RNG, schedule=ConstantSchedule)`
 """
-mutable struct EpsGreedyPolicy{T<:Real, S<:ExplorationSchedule, R<:AbstractRNG, A} <: ExplorationPolicy
+struct EpsGreedyPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
     eps::T
-    schedule::S
     rng::R
     actions::A
 end
 
+function EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Function;
+                         rng::AbstractRNG=Random.GLOBAL_RNG)
+    return EpsGreedyPolicy(eps, rng, actions(problem))
+end
 function EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Real;
-                         rng::AbstractRNG=Random.GLOBAL_RNG,
-                         schedule::ExplorationSchedule=ConstantSchedule())
-    return EpsGreedyPolicy(eps, schedule, rng, actions(problem))
+                         rng::AbstractRNG=Random.GLOBAL_RNG)
+    return EpsGreedyPolicy(x->eps, rng, actions(problem))
 end
 
 
-function POMDPs.action(p::EpsGreedyPolicy{T}, on_policy::Policy, s) where T<:Real
-    p.eps = update_value(p.schedule, p.eps)
-    if rand(p.rng) < p.eps
+function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)
+    if rand(p.rng) < p.eps(k)
         return rand(p.rng, p.actions)
     else
        return action(on_policy, s)
    end
 end
 
+# exploration_parameter(p::EpsGreedyPolicy, k) = p.eps(k)
 
 # softmax
 """
     SoftmaxPolicy <: ExplorationPolicy
@@ -96,25 +95,28 @@ represents a softmax policy, sampling a random action according to a softmax fun
 The softmax function converts the action values of the on policy into probabilities that are used for sampling.
 A temperature parameter can be used to make the resulting distribution more or less wide.
 """
-mutable struct SoftmaxPolicy{T<:Real, S<:ExplorationSchedule, R<:AbstractRNG, A} <: ExplorationPolicy
+struct SoftmaxPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
     temperature::T
-    schedule::S
     rng::R
     actions::A
 end
 
+function SoftmaxPolicy(problem, temperature::Function;
+                       rng::AbstractRNG=Random.GLOBAL_RNG)
+    return SoftmaxPolicy(temperature, rng, actions(problem))
+end
 function SoftmaxPolicy(problem, temperature::Real;
-                       rng::AbstractRNG=Random.GLOBAL_RNG,
-                       schedule::ExplorationSchedule=ConstantSchedule())
-    return SoftmaxPolicy(temperature, schedule, rng, actions(problem))
+                       rng::AbstractRNG=Random.GLOBAL_RNG)
+    return SoftmaxPolicy(x->temperature, rng, actions(problem))
 end
 
-function POMDPs.action(p::SoftmaxPolicy, on_policy::Policy, s)
-    p.temperature = update_value(p.schedule, p.temperature)
+function POMDPs.action(p::SoftmaxPolicy, on_policy::Policy, k, s)
     vals = actionvalues(on_policy, s)
-    vals ./= p.temperature
+    vals ./= p.temperature(k)
     maxval = maximum(vals)
     exp_vals = exp.(vals .- maxval)
     exp_vals /= sum(exp_vals)
     return p.actions[sample(p.rng, Weights(exp_vals))]
 end
 
+# exploration_parameter(p::SoftmaxPolicy, k) = p.temperature(k)
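For reference, a quick sketch of the functor behavior introduced above (assuming a POMDPPolicies build that includes this change): calling a `LinearDecaySchedule` at step `k` returns the linearly decayed value, clamped at `stop`. Because `ExplorationSchedule <: Function`, a schedule can be passed directly wherever the new constructors expect an `eps::Function` or `temperature::Function`.

```julia
using POMDPPolicies

sched = LinearDecaySchedule(start=1.0, stop=0.0, steps=10)

sched(1)   # 0.9, i.e. start - 1*(start - stop)/steps
sched(5)   # 0.5
sched(11)  # 0.0, clamped at stop once k >= steps
```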
16 changes: 8 additions & 8 deletions test/test_exploration_policies.jl
@@ -4,19 +4,19 @@ problem = SimpleGridWorld()
 # e greedy
 policy = EpsGreedyPolicy(problem, 0.5)
 a = first(actions(problem))
-@inferred action(policy, FunctionPolicy(s->a::Symbol), GWPos(1,1))
-policy.eps = 0.0
-@test action(policy, FunctionPolicy(s->a), GWPos(1,1)) == a
+@inferred action(policy, FunctionPolicy(s->a::Symbol), 1, GWPos(1,1))
+policy = EpsGreedyPolicy(problem, 0.0)
+@test action(policy, FunctionPolicy(s->a), 1, GWPos(1,1)) == a
 
 # softmax
 policy = SoftmaxPolicy(problem, 0.5)
 on_policy = ValuePolicy(problem)
-@inferred action(policy, on_policy, GWPos(1,1))
+@inferred action(policy, on_policy, 1, GWPos(1,1))
 
 # test linear schedule
-policy = EpsGreedyPolicy(problem, 1.0, schedule=LinearDecaySchedule(start_val=1.0, end_val=0.0, steps=10))
+policy = EpsGreedyPolicy(problem, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))
 for i=1:11
-    action(policy, FunctionPolicy(s->a), GWPos(1,1))
-    @test policy.eps < 1.0
+    action(policy, FunctionPolicy(s->a), i, GWPos(1,1))
+    @test policy.eps(i) < 1.0
 end
-@test policy.eps ≈ 0.0
+@test policy.eps(11) ≈ 0.0
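The same pattern applies to `SoftmaxPolicy`: the temperature is any function of the step index and is evaluated inside `action`. A hedged sketch, again assuming a build that includes this commit; `ValuePolicy(mdp)` is constructed with its default value table, as in the test above:

```julia
using POMDPs, POMDPModels, POMDPPolicies

mdp = SimpleGridWorld()
on_policy = ValuePolicy(mdp)               # on policy backed by a default value table
policy = SoftmaxPolicy(mdp, k -> 1.0 / k)  # temperature given as a function of the step index

a = action(policy, on_policy, 1, GWPos(1, 1))
```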
