/
random_policy.jl
77 lines (62 loc) · 2.02 KB
/
random_policy.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
export RandomPolicy
using Random
"""
    RandomPolicy(action_space=nothing; rng=Random.GLOBAL_RNG)

A policy that selects a random action using `rng`.

When `action_space` is `nothing`, the action is drawn at runtime from the
`legal_action_space` of the environment. Otherwise a random element of
`action_space` is returned.

!!! note
    Always set `action_space=nothing` when dealing with environments of
    `FULL_ACTION_SET`.
"""
struct RandomPolicy{S,R<:AbstractRNG} <: AbstractPolicy
    # Either `nothing` or the collection that actions are sampled from.
    action_space::S
    # Random number generator used for every draw.
    rng::R
end
# Reseed the policy by forwarding `seed` to `Random.seed!` on the policy's own
# generator (`p.rng`); returns whatever that call returns.
function Random.seed!(p::RandomPolicy, seed)
    return Random.seed!(p.rng, seed)
end
# Convenience constructor: the optional positional `s` is stored as the action
# space (default `nothing`) and the keyword `rng` as the generator
# (default `Random.GLOBAL_RNG`).
function RandomPolicy(s = nothing; rng = Random.GLOBAL_RNG)
    return RandomPolicy(s, rng)
end
# No action space is stored, so sample from `legal_action_space(env)` at call
# time using the policy's generator.
function (p::RandomPolicy{Nothing})(env)
    legal = legal_action_space(env)
    return rand(p.rng, legal)
end
# A fixed action space is stored: sample from it with the policy's generator.
# `env` is accepted for interface compatibility but not read here.
function (p::RandomPolicy)(env)
    return rand(p.rng, p.action_space)
end
# Environment entry point for policies with a vector/tuple action space:
# reads `state(env)` and forwards to the state-based `prob` method.
RLBase.prob(p::RandomPolicy{<:Union{AbstractVector,Tuple}}, env::AbstractEnv) =
    prob(p, state(env))
# Action distribution for policies with a vector/tuple action space: a
# `Categorical` with equal weight `1 / k` on each of the `k` stored actions.
# The second argument `s` is not read.
function RLBase.prob(p::RandomPolicy{<:Union{AbstractVector,Tuple}}, s)
    k = length(p.action_space)
    weights = fill(1 / k, k)
    # `check_args = false` is forwarded unchanged to `Categorical`
    # (presumably to skip argument validation — confirm against its docs).
    return Categorical(weights; check_args = false)
end
# With no stored action space, pick the implementation by the environment's
# chance style: the value of `ChanceStyle(env)` is passed as a third argument
# so dispatch selects the matching method.
function RLBase.prob(p::RandomPolicy{Nothing}, env::AbstractEnv)
    chance_style = ChanceStyle(env)
    return prob(p, env, chance_style)
end
# Fallback for a policy without an action space when given something that is
# not an `AbstractEnv`: there is nothing to derive probabilities from, so an
# error-level log record is emitted. This logs rather than throws; the value
# of the `@error` expression is what the method returns.
function RLBase.prob(p::RandomPolicy{Nothing}, x)
    @error "no I really don't know how to calculate the prob from nothing"
end
# Probability vector over the full action set for a policy without a stored
# action space: every position selected by `legal_action_space_mask(env)`
# gets `1 / (number of selected positions)`, all other positions stay `0.0`.
function RLBase.prob(
    p::RandomPolicy{Nothing},
    env::AbstractEnv,
    ::RLBase.AbstractChanceStyle,
)
    legal = legal_action_space_mask(env)
    share = 1 / sum(legal)
    weights = zeros(length(legal))
    # Logical indexing: only the entries where the mask is true are written.
    weights[legal] .= share
    return weights
end
# Variant for `ExplicitStochastic` environments. When it is the chance
# player's turn, the distribution is obtained from the environment itself via
# `prob(env, chance_player(env))`; otherwise the computation is delegated to
# the method selected by passing `DETERMINISTIC` as the chance style.
function RLBase.prob(
    p::RandomPolicy{Nothing},
    env::AbstractEnv,
    ::RLBase.ExplicitStochastic,
)
    if current_player(env) != chance_player(env)
        return prob(p, env, DETERMINISTIC)
    end
    return prob(env, chance_player(env))
end
# Updating a `RandomPolicy` is a no-op: the arguments are ignored and
# `nothing` is returned.
function RLBase.update!(p::RandomPolicy, x)
    return nothing
end
# Probability that the policy selects action `a` when it samples from its
# stored `action_space`.
#
# Arguments:
# - `p`: the policy; `p.action_space` must support `in` and `length`.
# - `env_or_state`: accepted for interface compatibility, not read.
# - `a`: the action whose probability is requested.
#
# Returns `1 / length(p.action_space)` when `a` is an element of the action
# space and `0.0` otherwise. Previously `1 / n` was returned for every `a`,
# including actions the policy can never produce; the membership check makes
# this method agree with the `RandomPolicy{Nothing}` variant, which already
# returns `0.0` for actions outside the legal set.
function RLBase.prob(p::RandomPolicy, env_or_state, a)
    space = p.action_space
    return a in space ? 1 / length(space) : 0.0
end
# Probability of a single action `a` for a policy without a stored action
# space: `1.0 / (number of legal actions)` if `a` is currently legal in
# `env`, and `0.0` if it is not.
function RLBase.prob(p::RandomPolicy{Nothing}, env::AbstractEnv, a)
    # The legal action space is assumed to be discrete here.
    legal = legal_action_space(env)
    return a in legal ? 1.0 / length(legal) : 0.0
end