
Commit

Merge 71bfd25 into ed9220a
MaximeBouton committed Mar 3, 2020
2 parents ed9220a + 71bfd25 commit 4894b31
Showing 12 changed files with 198 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -2,7 +2,7 @@ language: julia

julia:
- 1.0
- 1.2
- 1.3

os:
- linux
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,19 +1,20 @@
name = "POMDPPolicies"
uuid = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
version = "0.2.1"
version = "0.3.0"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
POMDPModelTools = "0.2"
BeliefUpdaters = "0.1"
POMDPModelTools = "0.2"
POMDPs = "0.7.3, 0.8"
StatsBase = "0.26,0.27,0.28,0.29,0.30,0.31,0.32"
julia = "1"
2 changes: 1 addition & 1 deletion README.md
@@ -11,4 +11,4 @@ A collection of default policy types for [POMDPs.jl](https://github.com/JuliaPOM
```julia
using Pkg
Pkg.add("POMDPPolicies")
```
```
36 changes: 36 additions & 0 deletions docs/src/exploration_policies.md
@@ -0,0 +1,36 @@
# Exploration Policies

Exploration policies are often useful in reinforcement learning algorithms to choose an action that is different from the action given by the policy being learned.

This package provides two exploration policies: `EpsGreedyPolicy` and `SoftmaxPolicy`.

```@docs
EpsGreedyPolicy
SoftmaxPolicy
```

## Interface

Exploration policies are subtypes of the abstract `ExplorationPolicy` type and implement the following interface:
`action(exploration_policy::ExplorationPolicy, on_policy::Policy, s)`.

The `action` method is exported by [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl).
To use exploration policies in a solver, you must use the three-argument version of `action`, where `on_policy` is the policy being learned (e.g. a tabular policy or a neural network policy).
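
For example, here is a minimal sketch of epsilon-greedy exploration, modeled on the tests added in this commit; it assumes the `SimpleGridWorld` problem from POMDPModels.jl and uses a `ValuePolicy` as a stand-in for the policy being learned:

```julia
using POMDPs
using POMDPModels # provides SimpleGridWorld and GWPos
using POMDPPolicies

problem = SimpleGridWorld()

# stand-in for the policy being learned
on_policy = ValuePolicy(problem)

# with probability 0.1 take a random action, otherwise follow on_policy
exploration_policy = EpsGreedyPolicy(problem, 0.1)
a = action(exploration_policy, on_policy, GWPos(1, 1))
```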

## Schedules

Exploration policies often rely on a key parameter, such as $\epsilon$ in $\epsilon$-greedy exploration or the temperature in softmax exploration.
Reinforcement learning algorithms often require a decay schedule for these parameters.
`POMDPPolicies.jl` exports an interface for implementing decay schedules as well as a few convenient schedules.

```@docs
LinearDecaySchedule
ConstantSchedule
```
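
For example, here is a sketch of an $\epsilon$-greedy policy whose $\epsilon$ decays linearly, modeled on the schedule used in the tests added by this commit:

```julia
using POMDPs
using POMDPModels
using POMDPPolicies

problem = SimpleGridWorld()
on_policy = ValuePolicy(problem)

# epsilon decays from 1.0 to 0.1 over 1000 updates, then stays at 0.1
schedule = LinearDecaySchedule(start_val=1.0, end_val=0.1, steps=1000)
exploration_policy = EpsGreedyPolicy(problem, 1.0, schedule=schedule)

# the schedule updates `eps` each time `action` is called
action(exploration_policy, on_policy, GWPos(1, 1))
exploration_policy.eps # now slightly below 1.0
```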

To implement your own schedule, you must define a schedule type that is a subtype of `ExplorationSchedule`, as well as the function `update_value` that returns the new parameter value updated according to your schedule.

```@docs
ExplorationSchedule
update_value
```
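
For instance, here is a sketch of a hypothetical exponential-decay schedule; the name `ExpDecaySchedule` and its fields are illustrative and not part of the package:

```julia
using POMDPPolicies
import POMDPPolicies: update_value

# multiply the value by `decay` at each update, never going below `min_val`
struct ExpDecaySchedule{R<:Real} <: ExplorationSchedule
    decay::R
    min_val::R
end

update_value(s::ExpDecaySchedule, value) = max(value * s.decay, s.min_val)

# usage: EpsGreedyPolicy(problem, 1.0, schedule=ExpDecaySchedule(0.99, 0.05))
```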
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -6,6 +6,7 @@ It currently provides:
- an alpha vector policy type
- a random policy
- a stochastic policy type
- exploration policies
- a vector policy type
- a wrapper to collect statistics and errors about policies

5 changes: 0 additions & 5 deletions docs/src/stochastic.md
@@ -5,7 +5,6 @@ Types for representing randomized policies:
- `StochasticPolicy` samples actions from an arbitrary distribution.
- `UniformRandomPolicy` samples actions uniformly (see `RandomPolicy` for a similar use)
- `CategoricalTabularPolicy` samples actions from a categorical distribution with weights given by a `ValuePolicy`.
- `EpsGreedyPolicy` uses epsilon-greedy action selection.

```@docs
StochasticPolicy
@@ -14,7 +13,3 @@ StochasticPolicy
```@docs
CategoricalTabularPolicy
```

```@docs
EpsGreedyPolicy
```
13 changes: 11 additions & 2 deletions src/POMDPPolicies.jl
@@ -4,6 +4,7 @@ using LinearAlgebra
using Random
using StatsBase # for Weights
using SparseArrays # for sparse vectors in alpha_vector.jl
using Parameters

using POMDPs
import POMDPs: action, value, solve, updater
@@ -52,11 +53,19 @@ include("vector.jl")
export
StochasticPolicy,
UniformRandomPolicy,
CategoricalTabularPolicy,
EpsGreedyPolicy
CategoricalTabularPolicy

include("stochastic.jl")

export ExplorationSchedule,
EpsGreedyPolicy,
SoftmaxPolicy,
ExplorationPolicy,
LinearDecaySchedule,
ConstantSchedule

include("exploration_policies.jl")

export
PolicyWrapper,
payload
120 changes: 120 additions & 0 deletions src/exploration_policies.jl
@@ -0,0 +1,120 @@


# exploration schedule
"""
ExplorationSchedule
Abstract type for exploration schedules.
A schedule defines how a parameter of an exploration policy evolves over time.
The effect of a schedule is defined by the `update_value` function.
"""
abstract type ExplorationSchedule end

"""
update_value(::ExplorationSchedule, value)
Returns an updated value according to the schedule.
"""
function update_value(::ExplorationSchedule, value) end


"""
LinearDecaySchedule
A schedule that linearly decreases a value from `start_val` to `end_val` over `steps` steps.
Once the value reaches `end_val`, it stays constant.
# Constructor
`LinearDecaySchedule(;start_val, end_val, steps)`
"""
@with_kw struct LinearDecaySchedule{R<:Real} <: ExplorationSchedule
start_val::R
end_val::R
steps::Int
end

function update_value(schedule::LinearDecaySchedule, value)
    # decrease the value by a constant rate, never going below end_val
    rate = (schedule.start_val - schedule.end_val) / schedule.steps
    return max(value - rate, schedule.end_val)
end

"""
ConstantSchedule
A schedule that keeps the value constant.
"""
struct ConstantSchedule <: ExplorationSchedule
end

update_value(::ConstantSchedule, value) = value


"""
ExplorationPolicy <: Policy
An abstract type for exploration policies.
Sampling from an exploration policy is done using `action(exploration_policy, on_policy, state)`.
"""
abstract type ExplorationPolicy <: Policy end



"""
EpsGreedyPolicy <: ExplorationPolicy
An epsilon-greedy policy: with probability `eps` it samples a random action, otherwise it returns the action given by the on-policy.
The evolution of `eps` can be controlled using a schedule, which is useful when these policies are used in reinforcement learning algorithms.
# Constructor
`EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Real; rng=Random.GLOBAL_RNG, schedule=ConstantSchedule())`
"""
mutable struct EpsGreedyPolicy{T<:Real, S<:ExplorationSchedule, R<:AbstractRNG, A} <: ExplorationPolicy
eps::T
schedule::S
rng::R
actions::A
end

function EpsGreedyPolicy(problem::Union{MDP, POMDP}, eps::Real;
rng::AbstractRNG=Random.GLOBAL_RNG,
schedule::ExplorationSchedule=ConstantSchedule())
return EpsGreedyPolicy(eps, schedule, rng, actions(problem))
end


function POMDPs.action(p::EpsGreedyPolicy{T}, on_policy::Policy, s) where T<:Real
    p.eps = update_value(p.schedule, p.eps)
    # with probability eps take a random action, otherwise follow the on-policy
    if rand(p.rng) < p.eps
        return rand(p.rng, p.actions)
    else
        return action(on_policy, s)
    end
end

# softmax
"""
SoftmaxPolicy <: ExplorationPolicy
A softmax policy: samples a random action according to a softmax function.
The softmax function converts the action values of the on-policy into probabilities used for sampling: the probability of picking an action `a` in state `s` is proportional to `exp(q(s, a)/temperature)`.
A lower temperature concentrates the distribution on the highest-valued actions, while a higher temperature makes it closer to uniform.
# Constructor
`SoftmaxPolicy(problem, temperature::Real; rng=Random.GLOBAL_RNG, schedule=ConstantSchedule())`
"""
mutable struct SoftmaxPolicy{T<:Real, S<:ExplorationSchedule, R<:AbstractRNG, A} <: ExplorationPolicy
temperature::T
schedule::S
rng::R
actions::A
end

function SoftmaxPolicy(problem, temperature::Real;
rng::AbstractRNG=Random.GLOBAL_RNG,
schedule::ExplorationSchedule=ConstantSchedule())
return SoftmaxPolicy(temperature, schedule, rng, actions(problem))
end

function POMDPs.action(p::SoftmaxPolicy, on_policy::Policy, s)
    p.temperature = update_value(p.schedule, p.temperature)
    vals = actionvalues(on_policy, s)
    vals ./= p.temperature
    # softmax with the maximum subtracted for numerical stability
    maxval = maximum(vals)
    exp_vals = exp.(vals .- maxval)
    exp_vals /= sum(exp_vals)
    return p.actions[sample(p.rng, Weights(exp_vals))]
end
26 changes: 0 additions & 26 deletions src/stochastic.jl
@@ -57,29 +57,3 @@ function action(policy::CategoricalTabularPolicy, s)
policy.stochastic.distribution = Weights(policy.value.value_table[stateindex(policy.value.mdp, s),:])
return policy.value.act[sample(policy.stochastic.rng, policy.stochastic.distribution)]
end

"""
EpsGreedyPolicy
represents an epsilon greedy policy, sampling a random action with a probability `eps` or sampling from a given stochastic policy otherwise.
constructor:
`EpsGreedyPolicy(mdp::Union{MDP,POMDP}, eps::Float64; rng=Random.GLOBAL_RNG)`
"""
mutable struct EpsGreedyPolicy <: Policy
eps::Float64
val::ValuePolicy
uni::StochasticPolicy
end

EpsGreedyPolicy(mdp::Union{MDP,POMDP}, eps::Float64;
rng=Random.GLOBAL_RNG) = EpsGreedyPolicy(eps, ValuePolicy(mdp), UniformRandomPolicy(mdp, rng))

function action(policy::EpsGreedyPolicy, s)
if rand(policy.uni.rng) > policy.eps
return action(policy.val, s)
else
return action(policy.uni, s)
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -28,3 +28,6 @@ end
@testset "pretty_printing" begin
include("test_pretty_printing.jl")
end
@testset "exploration policies" begin
include("test_exploration_policies.jl")
end
22 changes: 22 additions & 0 deletions test/test_exploration_policies.jl
@@ -0,0 +1,22 @@
using POMDPModels

problem = SimpleGridWorld()
# e greedy
policy = EpsGreedyPolicy(problem, 0.5)
a = first(actions(problem))
@inferred action(policy, FunctionPolicy(s->a::Symbol), GWPos(1,1))
policy.eps = 0.0
@test action(policy, FunctionPolicy(s->a), GWPos(1,1)) == a

# softmax
policy = SoftmaxPolicy(problem, 0.5)
on_policy = ValuePolicy(problem)
@inferred action(policy, on_policy, GWPos(1,1))

# test linear schedule
policy = EpsGreedyPolicy(problem, 1.0, schedule=LinearDecaySchedule(start_val=1.0, end_val=0.0, steps=10))
for i=1:11
action(policy, FunctionPolicy(s->a), GWPos(1,1))
@test policy.eps < 1.0
end
@test policy.eps ≈ 0.0
4 changes: 0 additions & 4 deletions test/test_stochastic_policy.jl
@@ -12,8 +12,4 @@ policy = CategoricalTabularPolicy(problem)
sim = RolloutSimulator(max_steps=10)
simulate(sim, problem, policy)

policy = EpsGreedyPolicy(problem, 0.5)
sim = RolloutSimulator(max_steps=10)
simulate(sim, problem, policy)

end
