-
-
Notifications
You must be signed in to change notification settings - Fork 110
/
TinyHanabiEnv.jl
109 lines (92 loc) · 3.07 KB
/
TinyHanabiEnv.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
export TinyHanabiEnv
# Payoff tensor for the two-player Tiny Hanabi game (https://arxiv.org/abs/1902.00506).
# Indexed as [action_of_player_1, action_of_player_2, card_of_player_1, card_of_player_2];
# each 3×3 slice is the joint payoff matrix for one combination of dealt cards.
# A `let` block is used instead of `begin` so the temporary `t` does not leak
# into module scope as a global binding.
const TINY_HANABI_REWARD_TABLE = let t = Array{Int,4}(undef, 3, 3, 2, 2)
    # cards = (1, 1)
    t[:, :, 1, 1] = [
        10 0 0
        4 8 4
        10 0 0
    ]
    # cards = (1, 2)
    t[:, :, 1, 2] = [
        0 0 10
        4 8 4
        0 0 10
    ]
    # cards = (2, 1)
    t[:, :, 2, 1] = [
        0 0 10
        4 8 4
        0 0 0
    ]
    # cards = (2, 2)
    t[:, :, 2, 2] = [
        10 0 0
        4 8 4
        10 0 0
    ]
    t
end
# A tiny two-player cooperative card game (see https://arxiv.org/abs/1902.00506).
struct TinyHanabiEnv <: AbstractEnv
    reward_table::Array{Int,4} # payoff tensor, indexed [action1, action2, card1, card2]
    cards::Vector{Int}         # cards dealt so far by the chance player (at most 2)
    actions::Vector{Int}       # actions taken so far by players 1 and 2 (at most 2)
end
"""
    TinyHanabiEnv()

Create a Tiny Hanabi environment in its initial state: no cards dealt and
no actions taken, using the standard reward table.

See https://arxiv.org/abs/1902.00506.
"""
function TinyHanabiEnv()
    return TinyHanabiEnv(TINY_HANABI_REWARD_TABLE, Int[], Int[])
end
# Return the environment to its initial state by clearing dealt cards and recorded actions.
RLBase.reset!(env::TinyHanabiEnv) = (empty!(env.cards); empty!(env.actions))
RLBase.players(env::TinyHanabiEnv) = 1:2

# The chance player moves until both cards are dealt; then player 1 acts, then player 2.
function RLBase.current_player(env::TinyHanabiEnv)
    length(env.cards) < 2 && return CHANCE_PLAYER
    return isempty(env.actions) ? 1 : 2
end
# Calling the env with an action records it: chance-player moves deal a card,
# regular-player moves (dispatched on the `Int` player id) append to the action history.
(env::TinyHanabiEnv)(action, ::ChancePlayer) = push!(env.cards, action)
(env::TinyHanabiEnv)(action, ::Int) = push!(env.actions, action)
# Each player picks one of 3 actions; the chance player deals one of 2 cards.
RLBase.action_space(env::TinyHanabiEnv, ::Int) = Base.OneTo(3)
RLBase.action_space(env::TinyHanabiEnv, ::ChancePlayer) = Base.OneTo(2)

# Only cards that have not already been dealt remain legal for the chance player.
RLBase.legal_action_space(env::TinyHanabiEnv, ::ChancePlayer) =
    [c for c in 1:2 if c ∉ env.cards]
RLBase.legal_action_space_mask(env::TinyHanabiEnv, ::ChancePlayer) =
    map(c -> c ∉ env.cards, 1:2)
"""
    RLBase.prob(env::TinyHanabiEnv, ::ChancePlayer)

Return the probability distribution (a `Vector{Float64}` of length 2) over the
cards the chance player may deal next. Before any deal both cards are equally
likely; after one deal the remaining card has probability 1.

Throws an `ArgumentError` when both cards are already dealt — the chance player
has no move in that state.
"""
function RLBase.prob(env::TinyHanabiEnv, ::ChancePlayer)
    n_dealt = length(env.cards)
    if n_dealt == 0
        return [0.5, 0.5]
    elseif n_dealt == 1
        p = ones(2)
        p[env.cards[]] = 0.0 # the already-dealt card cannot be dealt again
        return p
    else
        # Was `@error "shouldn't reach here."`, which only logs and silently
        # falls through returning `nothing`; throwing surfaces the misuse.
        throw(ArgumentError("`prob` is undefined once both cards are dealt"))
    end
end
# The chance player's information set is its id (0) followed by the cards dealt so far,
# in deal order; the space below enumerates every reachable prefix of deals.
RLBase.state_space(env::TinyHanabiEnv, ::InformationSet, ::ChancePlayer) =
    ((0,), (0, 1), (0, 2), (0, 1, 2), (0, 2, 1)) # (chance_player_id(0), chance_player's actions...)
RLBase.state(env::TinyHanabiEnv, ::InformationSet, ::ChancePlayer) = (0, env.cards...)
# Enumerate all player information sets as flat tuples: (player id, optional card,
# actions so far). `c` ranges over no-card/card-1/card-2 and `a` over no actions,
# a single action (1:3), or both players' actions (all 3×3 pairs); splatting an
# `Int` in a tuple yields the value itself, so `(p, 1..., ())` == `(p, 1)`.
# NOTE(review): the inner `for p in 1:2` shadows the argument `p`, so the same
# space (covering BOTH player ids) is returned regardless of which player asks —
# confirm this is intended rather than restricting to the given `p`.
function RLBase.state_space(env::TinyHanabiEnv, ::InformationSet, p::Int)
    Tuple(
        (p, c..., a...) for p in 1:2 for c in ((), 1, 2) for
        a in ((), 1:3..., ((i, j) for i in 1:3 for j in 1:3)...)
    )
end
# A player's information set: own id, the p-th dealt card (when already dealt),
# and the full public action history.
function RLBase.state(env::TinyHanabiEnv, ::InformationSet, p::Int)
    if length(env.cards) >= p
        return (p, env.cards[p], env.actions...)
    else
        return (p, env.actions...)
    end
end
# The game ends once both players have acted.
RLBase.is_terminated(env::TinyHanabiEnv) = length(env.actions) == 2

# Terminal payoff looked up as reward_table[action1, action2, card1, card2];
# zero for any non-terminal state.
function RLBase.reward(env::TinyHanabiEnv, player)
    is_terminated(env) || return 0
    return env.reward_table[env.actions..., env.cards...]
end
# (Removed duplicates) `(env::TinyHanabiEnv)(action::Int, ::ChancePlayer)` and
# `(env::TinyHanabiEnv)(action::Int, ::Int)` repeated, with identical bodies, the
# untyped functor methods defined earlier in this file; those untyped methods
# already handle `Int` actions, so the redundant more-specific definitions are dropped.
# RLBase trait declarations describing this environment's interface.
RLBase.NumAgentStyle(::TinyHanabiEnv) = MultiAgent(2)
RLBase.DynamicStyle(::TinyHanabiEnv) = SEQUENTIAL
RLBase.ActionStyle(::TinyHanabiEnv) = MINIMAL_ACTION_SET
RLBase.InformationStyle(::TinyHanabiEnv) = IMPERFECT_INFORMATION
RLBase.StateStyle(::TinyHanabiEnv) = InformationSet{Tuple{Vararg{Int}}}() # states are Int tuples
RLBase.RewardStyle(::TinyHanabiEnv) = TERMINAL_REWARD # nonzero reward only at game end (see `reward`)
RLBase.UtilityStyle(::TinyHanabiEnv) = IDENTICAL_UTILITY # `reward` ignores `player`: shared payoff
RLBase.ChanceStyle(::TinyHanabiEnv) = EXPLICIT_STOCHASTIC # chance outcomes exposed via `prob`