# Exploration Policies #20

Changes from 6 commits.
Modified file (`.travis.yml`; filename and diff polarity inferred from the content):

```diff
@@ -2,7 +2,7 @@ language: julia
 julia:
-  - 1.0
   - 1.2
+  - 1
 os:
   - linux
```

(The entry `1` selects the latest Julia 1.x release, so CI drops 1.0 and now tests on 1.2 and the newest 1.x.)
New file (documentation page):
# Exploration Policies

Exploration policies are often useful in reinforcement learning algorithms to choose an action different from the one given by the policy being learned.
This package provides two exploration policies: `EpsGreedyPolicy` and `SoftmaxPolicy`.
```@docs
EpsGreedyPolicy
SoftmaxPolicy
```
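For example, both can be constructed directly from a problem (a sketch; `SimpleGridWorld` comes from POMDPModels, as in this PR's tests):

```julia
using POMDPModels, POMDPPolicies

mdp = SimpleGridWorld()
epsgreedy = EpsGreedyPolicy(mdp, 0.1)  # take a random action 10% of the time
softmax = SoftmaxPolicy(mdp, 0.5)      # temperature 0.5
```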
## Interface

Exploration policies are subtypes of the abstract `ExplorationPolicy` type and follow this interface:
`action(exploration_policy::ExplorationPolicy, on_policy::Policy, k, s)`, where `k` is the current training step and `s` is the current state.
The `action` method is exported by [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl).
To use exploration policies in a solver, you must use this four-argument version of `action`, where `on_policy` is the policy being learned (e.g. a tabular policy or a neural network policy).
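Concretely, a call could look like this (a sketch continuing the example above; `FunctionPolicy` and `GWPos` are used the same way in this PR's tests):

```julia
on_policy = FunctionPolicy(s -> :up)  # stands in for the policy being learned
k = 1                                 # training step, passed to the exploration schedule
s = GWPos(1, 1)
a = action(epsgreedy, on_policy, k, s)
```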
## Schedules

Exploration policies often rely on a key parameter: for example, $\epsilon$ in $\epsilon$-greedy and the temperature in softmax.
Reinforcement learning algorithms often require a decay schedule for these parameters.
`POMDPPolicies.jl` exports an interface for implementing decay schedules as well as a few convenient schedules.
```@docs
LinearDecaySchedule
ConstantSchedule
```
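For instance, an epsilon-greedy policy with a linearly decaying $\epsilon$ (a sketch using the constructors from the source diff below):

```julia
# ε decays linearly from 1.0 to 0.1 over the first 1000 steps, then stays at 0.1
schedule = LinearDecaySchedule(start=1.0, stop=0.1, steps=1000)
decaying = EpsGreedyPolicy(mdp, schedule)
schedule(500)   # 0.55
schedule(2000)  # 0.1 (clamped at `stop`)
```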
To implement your own schedule, define a schedule type that is a subtype of `ExplorationSchedule`, as well as the function `update_value`, which returns the parameter value updated according to your schedule.
```@docs
ExplorationSchedule
update_value
```
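For example, a hypothetical exponential-decay schedule could be written as follows (a sketch, not part of the package; it follows the callable-struct pattern that `LinearDecaySchedule` uses in the source diff below, rather than `update_value`):

```julia
# Hypothetical custom schedule, for illustration only.
struct ExpDecaySchedule{R<:Real} <: ExplorationSchedule
    start::R
    decay::R  # multiplicative factor per step, e.g. 0.999
end
(s::ExpDecaySchedule)(k) = s.start * s.decay^k
```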
New file (package source):
```julia
# exploration schedule
"""
    ExplorationSchedule
Abstract type for exploration schedules.
Useful for defining the schedule of a parameter of an exploration policy.
The effect of a schedule is defined by the `update_value` function.
"""
abstract type ExplorationSchedule <: Function end
```
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aren't we getting rid of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have not updated the docstrings yet ;) |
||
update_value(::ExplorationSchedule, value) | ||
Returns an updated value according to the schedule. | ||
""" | ||
function update_value(::ExplorationSchedule, value) end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should be function update_value end so the standard method error will be thrown. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given my other comments, this function may cease to exist entirely though. |
||
""" | ||
LinearDecaySchedule | ||
A schedule that linearly decreases a value from `start_val` to `end_val` in `steps` steps. | ||
if the value is greater or equal to `end_val`, it stays constant. | ||
|
||
# Constructor | ||
|
||
`LinearDecaySchedule(;start, stop, steps)` | ||
""" | ||
@with_kw struct LinearDecaySchedule{R<:Real} <: ExplorationSchedule | ||
start::R | ||
stop::R | ||
steps::Int | ||
end | ||
|
||
function (schedule::LinearDecaySchedule)(k) | ||
rate = (schedule.start - schedule.stop) / schedule.steps | ||
val = schedule.start - k*rate | ||
val = max(schedule.stop, val) | ||
end | ||
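A quick sanity check of the functor semantics (the values follow directly from the formula above):

```julia
sched = LinearDecaySchedule(start=1.0, stop=0.0, steps=10)
sched(0)   # 1.0
sched(5)   # 0.5
sched(10)  # 0.0
sched(20)  # 0.0 (clamped at `stop`)
```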
```julia
"""
    ExplorationPolicy <: Policy
An abstract type for exploration policies.
Sampling from an exploration policy is done using `action(exploration_policy, on_policy, k, s)`.
"""
abstract type ExplorationPolicy <: Policy end

# """
#     exploration_parameter(::ExplorationPolicy)
# returns the exploration parameter of an exploration policy, e.g. epsilon for e-greedy or temperature for softmax
# """
# function exploration_parameter end
```
```julia
"""
    EpsGreedyPolicy <: ExplorationPolicy

Represents an epsilon-greedy policy: it samples a random action with probability `eps`, and returns an action from the given on-policy otherwise.
The evolution of epsilon can be controlled using a schedule (any function of the step `k`), which is useful when using this policy in reinforcement learning algorithms.

# Constructors

`EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Real; rng=Random.GLOBAL_RNG)`

`EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Function; rng=Random.GLOBAL_RNG)`
"""
struct EpsGreedyPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
    eps::T
    rng::R
    actions::A
end

function EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Function;
                         rng::AbstractRNG=Random.GLOBAL_RNG)
    return EpsGreedyPolicy(eps, rng, actions(problem))
end
function EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Real;
                         rng::AbstractRNG=Random.GLOBAL_RNG)
    # wrap the constant in a function so `eps` is always callable
    return EpsGreedyPolicy(x->eps, rng, actions(problem))
end
```
```julia
function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)
    if rand(p.rng) < p.eps(k)
        # explore: uniform random action
        return rand(p.rng, p.actions)
    else
        # exploit: delegate to the policy being learned
        return action(on_policy, s)
    end
end

# exploration_parameter(p::EpsGreedyPolicy, k) = p.eps(k)
```
```julia
# softmax
"""
    SoftmaxPolicy <: ExplorationPolicy

Represents a softmax policy, sampling a random action according to a softmax function.
The softmax function converts the action values of the on-policy into probabilities that are used for sampling.
A temperature parameter can be used to make the resulting distribution more or less spread out.

# Constructors

`SoftmaxPolicy(problem, temperature::Real; rng=Random.GLOBAL_RNG)`

`SoftmaxPolicy(problem, temperature::Function; rng=Random.GLOBAL_RNG)`
"""
struct SoftmaxPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
    temperature::T
    rng::R
    actions::A
end

function SoftmaxPolicy(problem, temperature::Function;
                       rng::AbstractRNG=Random.GLOBAL_RNG)
    return SoftmaxPolicy(temperature, rng, actions(problem))
end
function SoftmaxPolicy(problem, temperature::Real;
                       rng::AbstractRNG=Random.GLOBAL_RNG)
    return SoftmaxPolicy(x->temperature, rng, actions(problem))
end
```
```julia
function POMDPs.action(p::SoftmaxPolicy, on_policy::Policy, k, s)
    vals = actionvalues(on_policy, s)
    vals ./= p.temperature(k)
    maxval = maximum(vals)
    exp_vals = exp.(vals .- maxval)  # subtract the max for numerical stability
    exp_vals /= sum(exp_vals)
    # sample an action index from the softmax weights (sample/Weights from StatsBase)
    return p.actions[sample(p.rng, Weights(exp_vals))]
end

# exploration_parameter(p::SoftmaxPolicy, k) = p.temperature(k)
```
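For reference, the distribution sampled above is the standard softmax: with action values $q$ and temperature $T$, action $a_i$ is drawn with probability

$$P(a_i) = \frac{\exp(q_i / T)}{\sum_j \exp(q_j / T)}$$

Subtracting `maxval` before exponentiating scales the numerator and denominator by the same factor, so it leaves the distribution unchanged while avoiding overflow for large action values.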
New file (tests):
```julia
# assumed to be included from a runtests.jl that provides Test and the package exports
using POMDPModels

problem = SimpleGridWorld()

# epsilon-greedy
policy = EpsGreedyPolicy(problem, 0.5)
a = first(actions(problem))
@inferred action(policy, FunctionPolicy(s->a::Symbol), 1, GWPos(1,1))
policy = EpsGreedyPolicy(problem, 0.0)
@test action(policy, FunctionPolicy(s->a), 1, GWPos(1,1)) == a

# softmax
policy = SoftmaxPolicy(problem, 0.5)
on_policy = ValuePolicy(problem)
@inferred action(policy, on_policy, 1, GWPos(1,1))

# test linear schedule
policy = EpsGreedyPolicy(problem, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))
for i = 1:11
    action(policy, FunctionPolicy(s->a), i, GWPos(1,1))
    @test policy.eps(i) < 1.0
end
@test policy.eps(11) ≈ 0.0
```
> **Review comment:** Do we need this abstract type? I don't see what purpose it serves, and I am afraid someone will see it and think they need to use it. I think the schedule should just be a function, so people can write `eps = k->max(0, 0.1*(10000-k)/10000)`, for instance.
>
> **Reply:** Yep, not really needed here since we don't have an interface for schedules anymore.