# CommonRLInterface.jl — glue between RLBase and CommonRLInterface
import CommonRLInterface
const CRL = CommonRLInterface
#####
# CommonRLEnv
#####
# Wrapper that exposes an RLBase `AbstractEnv` through the CommonRLInterface
# `AbstractEnv` API. The wrapped environment is stored untouched in `env`.
struct CommonRLEnv{T<:AbstractEnv} <: CRL.AbstractEnv
    env::T
end
# Converting an RLBase env to the CommonRL abstract type simply wraps it.
Base.convert(::Type{CRL.AbstractEnv}, env::AbstractEnv) = CommonRLEnv(env)
# Straight delegations from the CommonRLInterface API to the wrapped RLBase env.
function CRL.reset!(wrapper::CommonRLEnv)
    return reset!(wrapper.env)
end
function CRL.actions(wrapper::CommonRLEnv)
    return action_space(wrapper.env)
end
function CRL.terminated(wrapper::CommonRLEnv)
    return is_terminated(wrapper.env)
end
# Apply action `a` to the wrapped env and return the reward observed after
# the step, as required by the CommonRLInterface `act!` contract (which
# returns rewards from `act!` rather than via a separate query).
function CRL.act!(wrapper::CommonRLEnv, a)
    act!(wrapper.env, a)
    return reward(wrapper.env)
end
# Look up a state style of type `s` among the styles advertised by `env`.
find_state_style(env::AbstractEnv, s) = find_state_style(StateStyle(env), s)
# Base case: the tuple is exhausted and no style matched the requested type.
find_state_style(::Tuple{}, s) = nothing

# Recursively walk the tuple of state styles and return the first element
# that is an instance of the requested style type `s`, or `nothing` if none
# matches. Recursion over `Base.tail` keeps the search type-stable.
function find_state_style(styles::Tuple, s)
    head = first(styles)
    return head isa s ? head : find_state_style(Base.tail(styles), s)
end
# !!! may need to be extended by user
# Observation defaults to the wrapped env's `state`.
CRL.observe(wrapper::CommonRLEnv) = state(wrapper.env)

# `CRL.state` is available only when the wrapped env supports `InternalState`.
function CRL.provided(::typeof(CRL.state), wrapper::CommonRLEnv)
    return find_state_style(wrapper.env, InternalState) !== nothing
end

function CRL.state(wrapper::CommonRLEnv)
    return state(wrapper.env, find_state_style(wrapper.env, InternalState))
end
# Cloning wraps a copy of the underlying env; rendering is not supported yet.
function CRL.clone(wrapper::CommonRLEnv)
    return CommonRLEnv(copy(wrapper.env))
end
CRL.render(wrapper::CommonRLEnv) = @error "unsupported yet..."
CRL.player(wrapper::CommonRLEnv) = current_player(wrapper.env)

# Legal-action queries are only provided when the wrapped env has a full
# action set (i.e. the legal actions may be a subset of `actions`).
CRL.valid_actions(wrapper::CommonRLEnv) = legal_action_space(wrapper.env)
function CRL.provided(::typeof(CRL.valid_actions), wrapper::CommonRLEnv)
    return ActionStyle(wrapper.env) === FullActionSet()
end
CRL.valid_action_mask(wrapper::CommonRLEnv) = legal_action_space_mask(wrapper.env)
function CRL.provided(::typeof(CRL.valid_action_mask), wrapper::CommonRLEnv)
    return ActionStyle(wrapper.env) === FullActionSet()
end
CRL.observations(wrapper::CommonRLEnv) = state_space(wrapper.env)
#####
# RLBaseEnv
#####
# Wrapper exposing a CommonRLInterface env through the RLBase API.
# `r` caches the reward returned by the most recent `act!` call, because
# CommonRLInterface returns rewards from `act!` while RLBase queries them
# separately via `reward`. Mutable so `act!` can update the cache in place.
mutable struct RLBaseEnv{T<:CRL.AbstractEnv,R} <: AbstractEnv
    env::T
    r::R
end
# Converting a CommonRL env to RLBase wraps it together with a reward cache.
Base.convert(::Type{AbstractEnv}, env::CRL.AbstractEnv) = convert(RLBaseEnv, env)
function Base.convert(::Type{RLBaseEnv}, env::CRL.AbstractEnv)
    # can not determine reward ahead. Assume `Float32`.
    return RLBaseEnv(env, 0.0f0)
end
# Advertise only the state styles the wrapped env actually provides, always
# in the fixed order: observation first, internal state second.
function RLBase.StateStyle(env::RLBaseEnv)
    styles = ()
    if CRL.provided(CRL.observe, env.env)
        styles = (styles..., Observation{Any}())
    end
    if CRL.provided(CRL.state, env.env)
        styles = (styles..., InternalState{Any}())
    end
    return styles
end
# RLBase API implemented by delegating to the wrapped CommonRL env.
state(x::RLBaseEnv, ::Observation) = CRL.observe(x.env)
state(x::RLBaseEnv, ::InternalState) = CRL.state(x.env)
state_space(x::RLBaseEnv, ::Observation) = CRL.observations(x.env)
action_space(x::RLBaseEnv) = CRL.actions(x.env)

# Reward is whatever was cached by the most recent `act!` (see struct docs).
reward(x::RLBaseEnv) = x.r
is_terminated(x::RLBaseEnv) = CRL.terminated(x.env)
legal_action_space(x::RLBaseEnv) = CRL.valid_actions(x.env)
legal_action_space_mask(x::RLBaseEnv) = CRL.valid_action_mask(x.env)
reset!(x::RLBaseEnv) = CRL.reset!(x.env)

# Step the wrapped env and cache the reward it returns.
function act!(x::RLBaseEnv, a)
    return x.r = CRL.act!(x.env, a)
end
Base.copy(env::CommonRLEnv) = RLBaseEnv(CRL.clone(env.env), env.r)
# The env has a full action set exactly when the wrapped CommonRL env can
# report its currently-valid actions.
function ActionStyle(env::RLBaseEnv)
    if CRL.provided(CRL.valid_actions, env.env)
        return FullActionSet()
    else
        return MinimalActionSet()
    end
end

current_player(env::RLBaseEnv) = CRL.player(env.env)
"""
players(env::RLBaseEnv)
Players in the game. This is a no-op for single-player games. `MultiAgent` games should implement this method.
"""
players(env::RLBaseEnv) = CRL.players(env.env)
#
"""
next_player!(env::E) where {E<:AbstractEnv}
Advance to the next player. This is a no-op for single-player and simultaneous games. `Sequential` `MultiAgent` games should implement this method.
"""
next_player!(env::E) where {E<:AbstractEnv} = nothing