Exploration Policies #20

Merged
merged 10 commits on Mar 19, 2020
2 changes: 1 addition & 1 deletion .travis.yml
@@ -2,7 +2,7 @@ language: julia

julia:
- 1.0
- 1.2
- 1

os:
- linux
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,19 +1,20 @@
name = "POMDPPolicies"
uuid = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
version = "0.2.1"
version = "0.3.0"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
POMDPModelTools = "0.2"
BeliefUpdaters = "0.1"
POMDPModelTools = "0.2"
POMDPs = "0.7.3, 0.8"
StatsBase = "0.26,0.27,0.28,0.29,0.30,0.31,0.32"
julia = "1"
2 changes: 1 addition & 1 deletion README.md
@@ -11,4 +11,4 @@ A collection of default policy types for [POMDPs.jl](https://github.com/JuliaPOM
```julia
using Pkg
Pkg.add("POMDPPolicies")
```
```
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -0,0 +1,2 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
9 changes: 2 additions & 7 deletions docs/make.jl
@@ -2,15 +2,10 @@ using Documenter, POMDPPolicies

makedocs(
modules = [POMDPPolicies],
format = :html,
format = Documenter.HTML(),
sitename = "POMDPPolicies.jl"
)

deploydocs(
repo = "github.com/JuliaPOMDP/POMDPPolicies.jl.git",
julia = "1.0",
osname = "linux",
target = "build",
deps = nothing,
make = nothing
)
)
32 changes: 32 additions & 0 deletions docs/src/exploration_policies.md
@@ -0,0 +1,32 @@
# Exploration Policies

Exploration policies are often useful in reinforcement learning algorithms to choose an action that is different from the action given by the policy being learned (`on_policy`).

Exploration policies are subtypes of the abstract `ExplorationPolicy` type and implement the following interface:
`action(exploration_policy::ExplorationPolicy, on_policy::Policy, k, s)`, where `k` is used to compute the value of the exploration parameter (see [Schedule](@ref)) and `s` is the current state or observation in which the agent is taking an action.

The `action` method is exported by [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl).
To use exploration policies in a solver, you must use the four-argument version of `action`, where `on_policy` is the policy being learned (e.g. a tabular policy or a neural network policy).
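
For illustration, here is a minimal sketch (it assumes the `SimpleGridWorld` model from POMDPModels and uses a `FunctionPolicy` as a stand-in for the learned policy):

```julia
using POMDPs, POMDPPolicies, POMDPModels

m = SimpleGridWorld()                  # any MDP or POMDP model
on_policy = FunctionPolicy(s -> :up)   # stand-in for the policy being learned
exploration_policy = EpsGreedyPolicy(m, 0.1)

k = 1                                  # e.g. the current training step
a = action(exploration_policy, on_policy, k, GWPos(1, 1))
```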

This package provides two exploration policies: `EpsGreedyPolicy` and `SoftmaxPolicy`.

```@docs
EpsGreedyPolicy
SoftmaxPolicy
```

## Schedule

Exploration policies often rely on a key parameter, such as $\epsilon$ in $\epsilon$-greedy or the temperature in softmax.
Reinforcement learning algorithms often require a decay schedule for these parameters.
Schedules can be passed to an exploration policy as functions. For example, one can define an epsilon-greedy policy with an exponential decay schedule as follows:
```julia
m # your mdp or pomdp model
exploration_policy = EpsGreedyPolicy(m, k->0.05*0.9^(k/10))
```

`POMDPPolicies.jl` also exports a `LinearDecaySchedule` object that can be used for this purpose, as in the sketch below.

```@docs
LinearDecaySchedule
```
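
For example, mirroring the package tests, one can build an epsilon-greedy policy whose epsilon decays linearly from 1.0 to 0.0 over 10 steps (a sketch, reusing the model `m` from above):

```julia
exploration_policy = EpsGreedyPolicy(m, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))
loginfo(exploration_policy, 5)  # returns (eps = 0.5,)
```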
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -6,6 +6,7 @@ It currently provides:
- an alpha vector policy type
- a random policy
- a stochastic policy type
- exploration policies
- a vector policy type
- a wrapper to collect statistics and errors about policies

5 changes: 0 additions & 5 deletions docs/src/stochastic.md
@@ -5,7 +5,6 @@ Types for representing randomized policies:
- `StochasticPolicy` samples actions from an arbitrary distribution.
- `UniformRandomPolicy` samples actions uniformly (see `RandomPolicy` for a similar use)
- `CategoricalTabularPolicy` samples actions from a categorical distribution with weights given by a `ValuePolicy`.
- `EpsGreedyPolicy` uses epsilon-greedy action selection.

```@docs
StochasticPolicy
@@ -14,7 +13,3 @@ StochasticPolicy
```@docs
CategoricalTabularPolicy
```

```@docs
EpsGreedyPolicy
```
12 changes: 10 additions & 2 deletions src/POMDPPolicies.jl
@@ -4,6 +4,7 @@ using LinearAlgebra
using Random
using StatsBase # for Weights
using SparseArrays # for sparse vectors in alpha_vector.jl
using Parameters

using POMDPs
import POMDPs: action, value, solve, updater
@@ -52,11 +53,18 @@ include("vector.jl")
export
StochasticPolicy,
UniformRandomPolicy,
CategoricalTabularPolicy,
EpsGreedyPolicy
CategoricalTabularPolicy

include("stochastic.jl")

export LinearDecaySchedule,
EpsGreedyPolicy,
SoftmaxPolicy,
ExplorationPolicy,
loginfo

include("exploration_policies.jl")

export
PolicyWrapper,
payload
128 changes: 128 additions & 0 deletions src/exploration_policies.jl
@@ -0,0 +1,128 @@
"""
LinearDecaySchedule
A schedule that linearly decreases a value from `start` to `stop` over `steps` steps.
After `steps` steps, the value stays constant at `stop`.

# Constructor

`LinearDecaySchedule(;start, stop, steps)`
"""
@with_kw struct LinearDecaySchedule{R<:Real} <: Function
start::R
stop::R
steps::Int
end

function (schedule::LinearDecaySchedule)(k)
rate = (schedule.start - schedule.stop) / schedule.steps
val = schedule.start - k*rate
val = max(schedule.stop, val)
end
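
# Illustrative values (not part of the library): with
#   sched = LinearDecaySchedule(start=1.0, stop=0.1, steps=100)
# the decay rate is (1.0 - 0.1)/100 = 0.009 per step, so
#   sched(0)   == 1.0
#   sched(50)  == 0.55
#   sched(200) == 0.1   (clamped at `stop` once `steps` is exceeded)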


"""
ExplorationPolicy <: Policy
An abstract type for exploration policies.
Sampling from an exploration policy is done using `action(exploration_policy, on_policy, k, state)`.
`k` is a value used to compute the exploration parameter, usually the training step in a TD-learning algorithm.
"""
abstract type ExplorationPolicy <: Policy end

"""
loginfo(::ExplorationPolicy, k)
Returns information about an exploration policy, e.g. epsilon for epsilon-greedy or the temperature for softmax.
It is expected to return a `NamedTuple` (e.g. `(temperature=0.5,)`). `k` is the current training step used to compute the exploration parameter.
"""
function loginfo end

"""
EpsGreedyPolicy <: ExplorationPolicy

Represents an epsilon-greedy policy, sampling a random action with probability `eps` or returning an action from a given policy otherwise.
The evolution of epsilon can be controlled using a schedule, which is useful when using these policies in reinforcement learning algorithms.

# Constructor:

`EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Union{Function, Real}; rng=Random.GLOBAL_RNG)`

If a function is passed for `eps`, `eps(k)` is called to compute the value of epsilon when calling `action(exploration_policy, on_policy, k, s)`.


# Fields

- `eps::Function`
- `rng::AbstractRNG`
- `actions::A` an indexable list of actions
"""
struct EpsGreedyPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
eps::T
rng::R
actions::A
end

function EpsGreedyPolicy(problem, eps::Function;
rng::AbstractRNG=Random.GLOBAL_RNG)
return EpsGreedyPolicy(eps, rng, actions(problem))
end
function EpsGreedyPolicy(problem, eps::Real;
rng::AbstractRNG=Random.GLOBAL_RNG)
return EpsGreedyPolicy(x->eps, rng, actions(problem))
end


function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)
if rand(p.rng) < p.eps(k)
return rand(p.rng, p.actions)
else
return action(on_policy, s)
end
end

loginfo(p::EpsGreedyPolicy, k) = (eps=p.eps(k),)
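
# Usage sketch (illustrative; assumes a model `m` supporting `actions(m)` and a
# learned policy `on_policy`, e.g. a ValuePolicy):
#   p = EpsGreedyPolicy(m, k -> 0.05 * 0.9^(k / 10))
#   a = action(p, on_policy, k, s)  # random action with probability eps(k), else action(on_policy, s)
#   loginfo(p, k)                   # returns (eps = ...,) for logging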

# softmax
"""
SoftmaxPolicy <: ExplorationPolicy

Represents a softmax policy, sampling a random action according to a softmax function.
The softmax function converts the action values of the on-policy into probabilities used for sampling.
A temperature parameter or function can be used to make the resulting distribution more or less spread out.

# Constructor

`SoftmaxPolicy(problem, temperature::Union{Function, Real}; rng=Random.GLOBAL_RNG)`

If a function is passed for `temperature`, `temperature(k)` is called to compute the value of the temperature when calling `action(exploration_policy, on_policy, k, s)`.

# Fields

- `temperature::Function`
- `rng::AbstractRNG`
- `actions::A` an indexable list of actions

"""
struct SoftmaxPolicy{T<:Function, R<:AbstractRNG, A} <: ExplorationPolicy
temperature::T
rng::R
actions::A
end

function SoftmaxPolicy(problem, temperature::Function;
rng::AbstractRNG=Random.GLOBAL_RNG)
return SoftmaxPolicy(temperature, rng, actions(problem))
end
function SoftmaxPolicy(problem, temperature::Real;
rng::AbstractRNG=Random.GLOBAL_RNG)
return SoftmaxPolicy(x->temperature, rng, actions(problem))
end

function POMDPs.action(p::SoftmaxPolicy, on_policy::Policy, k, s)
vals = actionvalues(on_policy, s)
vals = vals ./ p.temperature(k) # scale by the temperature without mutating the array returned by actionvalues
maxval = maximum(vals)
exp_vals = exp.(vals .- maxval)
exp_vals /= sum(exp_vals)
return p.actions[sample(p.rng, Weights(exp_vals))]
end

loginfo(p::SoftmaxPolicy, k) = (temperature=p.temperature(k),)
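
# Usage sketch (illustrative; assumes `on_policy` implements `actionvalues`, e.g. a ValuePolicy):
#   p = SoftmaxPolicy(m, 0.5)
#   a = action(p, on_policy, k, s)  # samples an action from softmax(actionvalues(on_policy, s) ./ 0.5)
#   loginfo(p, k)                   # returns (temperature = 0.5,)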
26 changes: 0 additions & 26 deletions src/stochastic.jl
@@ -57,29 +57,3 @@ function action(policy::CategoricalTabularPolicy, s)
policy.stochastic.distribution = Weights(policy.value.value_table[stateindex(policy.value.mdp, s),:])
return policy.value.act[sample(policy.stochastic.rng, policy.stochastic.distribution)]
end

"""
EpsGreedyPolicy

represents an epsilon greedy policy, sampling a random action with a probability `eps` or sampling from a given stochastic policy otherwise.

constructor:

`EpsGreedyPolicy(mdp::Union{MDP,POMDP}, eps::Float64; rng=Random.GLOBAL_RNG)`
"""
mutable struct EpsGreedyPolicy <: Policy
eps::Float64
val::ValuePolicy
uni::StochasticPolicy
end

EpsGreedyPolicy(mdp::Union{MDP,POMDP}, eps::Float64;
rng=Random.GLOBAL_RNG) = EpsGreedyPolicy(eps, ValuePolicy(mdp), UniformRandomPolicy(mdp, rng))

function action(policy::EpsGreedyPolicy, s)
if rand(policy.uni.rng) > policy.eps
return action(policy.val, s)
else
return action(policy.uni, s)
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -28,3 +28,6 @@ end
@testset "pretty_printing" begin
include("test_pretty_printing.jl")
end
@testset "exploration policies" begin
include("test_exploration_policies.jl")
end
24 changes: 24 additions & 0 deletions test/test_exploration_policies.jl
@@ -0,0 +1,24 @@
using POMDPModels

problem = SimpleGridWorld()
# e greedy
policy = EpsGreedyPolicy(problem, 0.5)
a = first(actions(problem))
@inferred action(policy, FunctionPolicy(s->a::Symbol), 1, GWPos(1,1))
policy = EpsGreedyPolicy(problem, 0.0)
@test action(policy, FunctionPolicy(s->a), 1, GWPos(1,1)) == a

# softmax
policy = SoftmaxPolicy(problem, 0.5)
@test loginfo(policy, 1).temperature == 0.5
on_policy = ValuePolicy(problem)
@inferred action(policy, on_policy, 1, GWPos(1,1))

# test linear schedule
policy = EpsGreedyPolicy(problem, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))
for i=1:11
action(policy, FunctionPolicy(s->a), i, GWPos(1,1))
@test policy.eps(i) < 1.0
@test loginfo(policy, i).eps == policy.eps(i)
end
@test policy.eps(11) ≈ 0.0
4 changes: 0 additions & 4 deletions test/test_stochastic_policy.jl
@@ -12,8 +12,4 @@ policy = CategoricalTabularPolicy(problem)
sim = RolloutSimulator(max_steps=10)
simulate(sim, problem, policy)

policy = EpsGreedyPolicy(problem, 0.5)
sim = RolloutSimulator(max_steps=10)
simulate(sim, problem, policy)

end