Showing 13 changed files with 207 additions and 47 deletions.
@@ -0,0 +1,3 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
@@ -0,0 +1,32 @@
# Exploration Policies

Exploration policies are often useful in reinforcement learning algorithms to choose an action that is different from the one given by the policy being learned (`on_policy`).

Exploration policies are subtypes of the abstract `ExplorationPolicy` type and implement the following interface:
`action(exploration_policy::ExplorationPolicy, on_policy::Policy, k, s)`. Here `k` is used to compute the value of the exploration parameter (see [Schedule](@ref)), and `s` is the current state or observation in which the agent is taking an action.

The `action` method is exported by [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl).
To use exploration policies in a solver, you must use the four-argument version of `action`, where `on_policy` is the policy being learned (e.g. a tabular policy or a neural network policy).
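
For illustration, a TD-learning-style loop might call the four-argument `action` as in the sketch below (`mdp`, `on_policy`, `s`, and `n_steps` are placeholder names for the solver's own model, learned policy, current state, and step budget; they are not part of this package):

```julia
exploration = EpsGreedyPolicy(mdp, k->0.05*0.9^(k/10))
for k in 1:n_steps
    a = action(exploration, on_policy, k, s)  # explore or exploit at state `s`
    # ... apply `a` in the environment, observe the next state,
    # and update `on_policy` with the observed transition ...
end
```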

This package provides two exploration policies: `EpsGreedyPolicy` and `SoftmaxPolicy`.

```@docs
EpsGreedyPolicy
SoftmaxPolicy
```

## Schedule

Exploration policies often rely on a key parameter: for example, $\epsilon$ in $\epsilon$-greedy or the temperature in softmax.
Reinforcement learning algorithms often require a decay schedule for these parameters.
Schedules can be passed to an exploration policy as functions. For example, one can define an epsilon-greedy policy with an exponential decay schedule as follows:
```julia
m # your mdp or pomdp model
exploration_policy = EpsGreedyPolicy(m, k->0.05*0.9^(k/10))
```

`POMDPPolicies.jl` exports a linear decay schedule object that can be used as well.
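
For example, to decay $\epsilon$ linearly from 1.0 to 0.1 over the first 10000 steps (arbitrary values, for illustration):

```julia
exploration_policy = EpsGreedyPolicy(m, LinearDecaySchedule(start=1.0, stop=0.1, steps=10000))
```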

```@docs
LinearDecaySchedule
```
@@ -0,0 +1,128 @@
"""
    LinearDecaySchedule
A schedule that linearly decreases a value from `start` to `stop` in `steps` steps.
After `steps` steps, the value stays constant at `stop`.
# Constructor
`LinearDecaySchedule(;start, stop, steps)`
"""
@with_kw struct LinearDecaySchedule{R<:Real} <: Function
    start::R
    stop::R
    steps::Int
end

function (schedule::LinearDecaySchedule)(k)
    rate = (schedule.start - schedule.stop) / schedule.steps
    val = schedule.start - k*rate
    val = max(schedule.stop, val)
end
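
# A quick illustration of the schedule's behavior (values chosen for the example):
#   schedule = LinearDecaySchedule(start=1.0, stop=0.0, steps=10)
#   schedule(5)   # == 0.5
#   schedule(20)  # == 0.0 (clamped at `stop`)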

"""
    ExplorationPolicy <: Policy
An abstract type for exploration policies.
Sampling from an exploration policy is done using `action(exploration_policy, on_policy, k, state)`.
`k` is a value that is used to determine the exploration parameter. It is usually a training step in a TD-learning algorithm.
"""
abstract type ExplorationPolicy <: Policy end

"""
    loginfo(::ExplorationPolicy, k)
Returns information about an exploration policy, e.g. epsilon for epsilon-greedy or temperature for softmax.
It is expected to return a `NamedTuple` (e.g. `(temperature=0.5,)`; note the trailing comma). `k` is the current training step that is used to compute the exploration parameter.
"""
function loginfo end

"""
    EpsGreedyPolicy <: ExplorationPolicy
Represents an epsilon-greedy policy: it samples a random action with probability `eps`, and otherwise returns an action from the given policy.
The evolution of epsilon can be controlled using a schedule. This feature is useful when using these policies in reinforcement learning algorithms.
# Constructor
`EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Union{Function, Real}; rng=Random.GLOBAL_RNG)`
If a function is passed for `eps`, `eps(k)` is called to compute the value of epsilon when calling `action(exploration_policy, on_policy, k, s)`.
# Fields
- `eps::Function`
- `rng::AbstractRNG`
- `actions::A` an indexable list of actions
"""
struct EpsGreedyPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
    eps::T
    rng::R
    actions::A
end

function EpsGreedyPolicy(problem, eps::Function;
                         rng::AbstractRNG=Random.GLOBAL_RNG)
    return EpsGreedyPolicy(eps, rng, actions(problem))
end
function EpsGreedyPolicy(problem, eps::Real;
                         rng::AbstractRNG=Random.GLOBAL_RNG)
    return EpsGreedyPolicy(x->eps, rng, actions(problem))
end
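
# Construction examples (illustrative; `mdp` stands for any MDP or POMDP model):
#   EpsGreedyPolicy(mdp, 0.05)             # constant epsilon
#   EpsGreedyPolicy(mdp, k -> 0.5*0.99^k)  # epsilon decayed by a schedule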

function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)
    if rand(p.rng) < p.eps(k)
        # explore: pick a uniformly random action
        return rand(p.rng, p.actions)
    else
        # exploit: defer to the policy being learned
        return action(on_policy, s)
    end
end

loginfo(p::EpsGreedyPolicy, k) = (eps=p.eps(k),)

# softmax
"""
    SoftmaxPolicy <: ExplorationPolicy
Represents a softmax policy, sampling a random action according to a softmax function.
The softmax function converts the action values of the on-policy into probabilities that are used for sampling.
A temperature parameter or function can be used to make the resulting distribution more or less wide.
# Constructor
`SoftmaxPolicy(problem, temperature::Union{Function, Real}; rng=Random.GLOBAL_RNG)`
If a function is passed for `temperature`, `temperature(k)` is called to compute the value of the temperature when calling `action(exploration_policy, on_policy, k, s)`.
# Fields
- `temperature::Function`
- `rng::AbstractRNG`
- `actions::A` an indexable list of actions
"""
struct SoftmaxPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
    temperature::T
    rng::R
    actions::A
end

function SoftmaxPolicy(problem, temperature::Function;
                       rng::AbstractRNG=Random.GLOBAL_RNG)
    return SoftmaxPolicy(temperature, rng, actions(problem))
end
function SoftmaxPolicy(problem, temperature::Real;
                       rng::AbstractRNG=Random.GLOBAL_RNG)
    return SoftmaxPolicy(x->temperature, rng, actions(problem))
end
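
# Construction examples (illustrative; `mdp` stands for any MDP or POMDP model):
#   SoftmaxPolicy(mdp, 1.0)                         # constant temperature
#   SoftmaxPolicy(mdp, k -> max(0.1, 1.0 - k/1000)) # annealed temperature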

function POMDPs.action(p::SoftmaxPolicy, on_policy::Policy, k, s)
    vals = actionvalues(on_policy, s)
    # non-mutating division so the vector returned by `actionvalues` is left untouched
    vals = vals ./ p.temperature(k)
    # subtract the maximum before exponentiating for numerical stability
    maxval = maximum(vals)
    exp_vals = exp.(vals .- maxval)
    exp_vals /= sum(exp_vals)
    return p.actions[sample(p.rng, Weights(exp_vals))]
end

loginfo(p::SoftmaxPolicy, k) = (temperature=p.temperature(k),)
@@ -0,0 +1,24 @@
using POMDPModels
# Test, POMDPs, and POMDPPolicies are assumed to be loaded by runtests.jl;
# they are listed here so the file can also be run standalone.
using Test
using POMDPs
using POMDPPolicies

problem = SimpleGridWorld()
# epsilon-greedy
policy = EpsGreedyPolicy(problem, 0.5)
a = first(actions(problem))
@inferred action(policy, FunctionPolicy(s->a::Symbol), 1, GWPos(1,1))
policy = EpsGreedyPolicy(problem, 0.0)
@test action(policy, FunctionPolicy(s->a), 1, GWPos(1,1)) == a

# softmax
policy = SoftmaxPolicy(problem, 0.5)
@test loginfo(policy, 1).temperature == 0.5
on_policy = ValuePolicy(problem)
@inferred action(policy, on_policy, 1, GWPos(1,1))

# test linear schedule
policy = EpsGreedyPolicy(problem, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))
for i in 1:11
    action(policy, FunctionPolicy(s->a), i, GWPos(1,1))
    @test policy.eps(i) < 1.0
    @test loginfo(policy, i).eps == policy.eps(i)
end
@test policy.eps(11) ≈ 0.0