
Commit e3cbe2e

put game running parameters in play method
1 parent 861d803 commit e3cbe2e

11 files changed (+188, -214 lines)

README.md

Lines changed: 6 additions & 7 deletions
@@ -132,15 +132,14 @@ learners = [EpsGreedy(arm_num=len(arms)),
 ```Python
 # Horizon of the game
 horizon = 2000
-# Record intermediate regrets for each trial
-intermediate_regrets = list(range(0, horizon+1, 50))
 # Set up simulator using single-player protocol
-game = SinglePlayerProtocol(bandit=bandit,
-                            learners=learners,
-                            intermediate_regrets=intermediate_regrets,
-                            horizon=horizon)
+game = SinglePlayerProtocol(bandit=bandit, learners=learners)
+# Record intermediate regrets after these horizons
+intermediate_horizons = list(range(0, horizon+1, 50))
 # Start playing the game and for each setup we run 200 trials
-game.play(trials=200)
+game.play(trials=200,
+          intermediate_horizons=intermediate_horizons,
+          horizon=horizon)
 ```
 
 The following figure shows the simulation results.

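For context, the README snippet above corresponds to the following end-to-end sketch of the updated API. The arm and learner import paths are assumptions for illustration; the protocol construction and the `play` arguments follow the diff above, and an `output_filename` is included because the base `Protocol.play` signature in `utils.py` below still requires one.

```Python
# Minimal sketch of the updated API; import paths marked below are assumptions.
from banditpylib.arms import BernoulliArm                # assumed import path
from banditpylib.bandits import MultiArmedBandit
from banditpylib.learners import EpsGreedy               # assumed import path
from banditpylib.protocols import SinglePlayerProtocol

# Three-armed Bernoulli bandit
arms = [BernoulliArm(mu) for mu in [0.3, 0.5, 0.7]]
bandit = MultiArmedBandit(arms)
learners = [EpsGreedy(arm_num=len(arms))]

# Horizon of the game
horizon = 2000
# Record intermediate regrets after these horizons
intermediate_horizons = list(range(0, horizon + 1, 50))

# Set up simulator using single-player protocol
game = SinglePlayerProtocol(bandit=bandit, learners=learners)
# Run 200 trials per learner; results are appended to the output file as
# serialized `Trial` records
game.play(trials=200,
          output_filename='ordinary_mab.data',
          intermediate_horizons=intermediate_horizons,
          horizon=horizon)
```
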
banditpylib/protocols/collaborative_learning_protocol.py

Lines changed: 17 additions & 16 deletions
@@ -13,31 +13,32 @@
 class CollaborativeLearningProtocol(Protocol):
   """Collaborative learning protocol :cite:`tao2019collaborative`
 
-  This protocol is used to simulate the multi-agent game
-  as discussed in the paper. It runs in rounds. During each round,
-  the protocol runs the following steps in sequence:
+  This class defines the communication protocol for the collaborative learning
+  multi-agent game as discussed in the reference paper. The game runs in
+  rounds. During each round, the protocol runs the following steps in sequence:
 
   - For each agent,
 
-    * fetch the state of the corresponding environment and ask the agent for
-      actions;
-    * send the actions to the enviroment for execution;
-    * update the agent with the feedback of the environment;
+    * fetch the state of the corresponding bandit environment and ask the agent
+      for actions;
+    * send the actions to the bandit environment for execution;
+    * update the agent with the feedback of the bandit environment;
     * repeat the above steps until the agent enters the `WAIT` or `STOP` state.
 
-  - If there is at least one agent in `WAIT` state, then receive information
+  - If there is at least one agent in `WAIT` state, then fetch information
     broadcasted from every waiting agent and send them to master to decide
-    arm assignment of next round. Otherwise, stop the simulaiton.
+    arm assignment of next round. Otherwise, stop the game.
 
   :param Bandit bandit: bandit environment
   :param List[CollaborativeLearner] learners: learners that will be compared
+    with
 
   .. note::
     Each agent interacts with an independent bandit environment.
 
   .. note::
     Each action counts as a timestep. The time (or sample) complexity equals to
-    the maximum number of pulls used by the agents.
+    the maximum number of pulls across different agents.
 
   .. note::
     According to the protocol, number of rounds always equals to number of
@@ -50,23 +51,23 @@ def __init__(self, bandit: Bandit, learners: List[CollaborativeLearner]):
   def name(self) -> str:
     return 'collaborative_learning_protocol'
 
-  def _one_trial(self, random_seed: int, debug: bool) -> bytes:
-    if debug:
+  def _one_trial(self, random_seed: int) -> bytes:
+    if self._debug:
       logging.set_verbosity(logging.DEBUG)
     np.random.seed(random_seed)
 
     # Initialization
-    current_learner = cast(CollaborativeLearner, self.current_learner)
+    current_learner = cast(CollaborativeLearner, self._current_learner)
     current_learner.reset()
     agents = current_learner.agents
     bandits = []
     master = current_learner.master
     for _ in range(len(agents)):
-      bandits.append(dcopy(self.bandit))
+      bandits.append(dcopy(self._bandit))
       bandits[-1].reset()
 
     trial = Trial()
-    trial.bandit = self.bandit.name
+    trial.bandit = self._bandit.name
     trial.learner = current_learner.name
 
     communication_rounds, total_pulls = 0, 0
@@ -124,6 +125,6 @@ def _one_trial(self, random_seed: int, debug: bool) -> bytes:
     result = trial.results.add()
     result.rounds = communication_rounds
     result.total_actions = total_pulls
-    result.regret = self.bandit.regret(current_learner.goal)
+    result.regret = self._bandit.regret(current_learner.goal)
 
     return trial.SerializeToString()

banditpylib/protocols/collaborative_learning_protocol_test.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def test_simple_run(self):
     collaborative_learner = CollaborativeLearningProtocol(
         bandit=bandit, learners=[lil_ucb_collaborative_learner])
     temp_file = tempfile.NamedTemporaryFile()
-    collaborative_learner.play(trials=3, output_filename=temp_file.name)
+    collaborative_learner.play(3, temp_file.name)
 
     with open(temp_file.name, 'rb') as f:
       # Check number of records is 3

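Both test files count the serialized records written by `play`. As a rough sketch of how those records could be read back, assuming the compiled protobuf module lives at `banditpylib.data_pb2` and that each `Trial` message is written with a varint length prefix (which the `_VarintBytes` import in `utils.py` suggests):

```Python
# Hypothetical reader for the dumped trial data; the proto module path is an
# assumption, and the varint-delimited framing mirrors `_VarintBytes` usage
# in utils.py.
from google.protobuf.internal.decoder import _DecodeVarint32

from banditpylib.data_pb2 import Trial  # assumed import path


def count_trials(filename: str) -> int:
  """Count the `Trial` records dumped by `Protocol.play`"""
  with open(filename, 'rb') as f:
    data = f.read()
  count, pos = 0, 0
  while pos < len(data):
    # Each record is prefixed with its varint-encoded length
    msg_len, pos = _DecodeVarint32(data, pos)
    trial = Trial()
    trial.ParseFromString(data[pos:pos + msg_len])
    pos += msg_len
    count += 1
  return count
```
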
banditpylib/protocols/single_player_protocol.py

Lines changed: 19 additions & 32 deletions
@@ -13,58 +13,45 @@
 class SinglePlayerProtocol(Protocol):
   """Single player protocol
 
-  This protocol is used to simulate the ordinary single-player game. It runs in
-  rounds. During each round, the protocol runs the following steps in sequence:
+  This class defines the communication protocol for the ordinary single-player
+  game. The game runs in rounds and during each round, the protocol runs the
+  following steps in sequence:
 
-  * fetch the state of the environment and ask the learner for actions;
-  * send the actions to the enviroment for execution;
-  * update the learner with the feedback of the environment.
+  * fetch the state of the bandit environment and ask the learner for actions;
+  * send the actions to the bandit environment for execution;
+  * update the learner with the feedback of the bandit environment.
 
-  The simulation stopping criteria is one of the following two:
+  The game runs until one of the following two stopping conditions is satisfied:
 
   * no actions are returned by the learner;
   * total number of actions achieve `horizon`.
 
-
   :param Bandit bandit: bandit environment
   :param List[SinglePlayerLearner] learners: learners to be compared with
-  :param List[int] intermediate_regrets: a list of rounds. If set, the regrets
-    after these rounds will be recorded
-  :param int horizon: horizon of the game (i.e., total number of actions a
-    leaner can make)
 
   .. note::
     During a round, a learner may want to perform multiple actions, which is
-    so-called batched learner. The total number of rounds shows how often the
-    learner wants to communicate with the bandit environment which is at most
-    `horizon`.
+    so-called batched learner.
   """
-  def __init__(self,
-               bandit: Bandit,
-               learners: List[SinglePlayerLearner],
-               intermediate_regrets: List[int] = None,
-               horizon: int = np.inf):  # type: ignore
+  def __init__(self, bandit: Bandit, learners: List[SinglePlayerLearner]):
     super().__init__(bandit=bandit, learners=cast(List[Learner], learners))
-    self.__intermediate_regrets = \
-        intermediate_regrets if intermediate_regrets is not None else []
-    self.__horizon = horizon
 
   @property
   def name(self) -> str:
     return 'single_player_protocol'
 
-  def _one_trial(self, random_seed: int, debug: bool) -> bytes:
-    if debug:
+  def _one_trial(self, random_seed: int) -> bytes:
+    if self._debug:
      logging.set_verbosity(logging.DEBUG)
     np.random.seed(random_seed)
 
     # Reset the bandit environment and the learner
-    self.bandit.reset()
-    current_learner = cast(SinglePlayerLearner, self.current_learner)
+    self._bandit.reset()
+    current_learner = cast(SinglePlayerLearner, self._current_learner)
     current_learner.reset()
 
     trial = Trial()
-    trial.bandit = self.bandit.name
+    trial.bandit = self._bandit.name
     trial.learner = current_learner.name
     rounds = 0
     # Number of actions the learner has made
@@ -74,20 +61,20 @@ def add_result():
       result = trial.results.add()
       result.rounds = rounds
       result.total_actions = total_actions
-      result.regret = self.bandit.regret(current_learner.goal)
+      result.regret = self._bandit.regret(current_learner.goal)
 
-    while total_actions < self.__horizon:
-      actions = current_learner.actions(self.bandit.context)
+    while total_actions < self._horizon:
+      actions = current_learner.actions(self._bandit.context)
 
       # Stop the game if no actions are returned by the learner
       if not actions.arm_pulls:
         break
 
       # Record intermediate regrets
-      if rounds in self.__intermediate_regrets:
+      if rounds in self._intermediate_horizons:
         add_result()
 
-      feedback = self.bandit.feed(actions)
+      feedback = self._bandit.feed(actions)
      current_learner.update(feedback)
 
       for arm_pull in actions.arm_pulls:

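One practical consequence of moving `horizon` and `intermediate_horizons` into `play` is that a single protocol object can be replayed under different horizons without being rebuilt. A minimal sketch (the output file names are placeholders, and `bandit` and `learners` are assumed to be set up as in the README snippet above):

```Python
# Reuse the same protocol instance across several horizons; the running
# parameters now travel with each `play` call instead of the constructor.
game = SinglePlayerProtocol(bandit=bandit, learners=learners)
for h in [500, 1000, 2000]:
  game.play(trials=50,
            output_filename='ordinary_mab_h%d.data' % h,
            intermediate_horizons=list(range(0, h + 1, 50)),
            horizon=h)
```
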
banditpylib/protocols/single_player_protocol_test.py

Lines changed: 2 additions & 3 deletions
@@ -15,10 +15,9 @@ def test_simple_run(self):
     ordinary_bandit = MultiArmedBandit(arms)
     eps_greedy_learner = EpsGreedy(arm_num=3)
     single_player = SinglePlayerProtocol(bandit=ordinary_bandit,
-                                         learners=[eps_greedy_learner],
-                                         horizon=10)
+                                         learners=[eps_greedy_learner])
     temp_file = tempfile.NamedTemporaryFile()
-    single_player.play(trials=3, output_filename=temp_file.name)
+    single_player.play(3, temp_file.name, horizon=10)
 
     with open(temp_file.name, 'rb') as f:
       # check number of records is 3

banditpylib/protocols/utils.py

Lines changed: 41 additions & 12 deletions
@@ -7,6 +7,7 @@
 from absl import logging
 
 from google.protobuf.internal.encoder import _VarintBytes  # type: ignore
+import numpy as np
 
 from banditpylib.bandits import Bandit
 from banditpylib.learners import Learner
@@ -23,8 +24,8 @@ def time_seed() -> int:
 
 
 class Protocol(ABC):
-  """Abstract class for a protocol which is used to coordinate the interactions
-  between the learner and the bandit environment.
+  """Abstract class for a communication protocol which defines the principles of
+  the interactions between the learner and the bandit environment.
 
   :param Bandit bandit: bandit environment
   :param List[Learner] learners: learners used to run simulations
@@ -48,26 +49,39 @@ def __init__(self, bandit: Bandit, learners: List[Learner]):
 
     self.__bandit = bandit
     self.__learners = learners
-    # The learner simulated currently
-    self.__current_learner: Learner = None
 
   @property
   @abstractmethod
   def name(self) -> str:
     """Protocol name"""
 
   @property
-  def bandit(self) -> Bandit:
-    """Bandit environment the simulator is using the learners to play with"""
+  def _bandit(self) -> Bandit:
+    """Bandit environment"""
     return self.__bandit
 
   @property
-  def current_learner(self) -> Learner:
-    """The learner used by the simulator currently"""
+  def _current_learner(self) -> Learner:
+    """The learner in simulation currently"""
     return self.__current_learner
 
+  @property
+  def _horizon(self) -> int:
+    """Horizon of the game"""
+    return self.__horizon
+
+  @property
+  def _intermediate_horizons(self) -> List[int]:
+    """Horizons used to report intermediate regrets"""
+    return self.__intermediate_horizons
+
+  @property
+  def _debug(self) -> bool:
+    """Debug mode"""
+    return self.__debug
+
   @abstractmethod
-  def _one_trial(self, random_seed: int, debug: bool) -> bytes:
+  def _one_trial(self, random_seed: int) -> bytes:
     """One trial of the game
 
     This method defines how to run one trial of the game.
@@ -91,23 +105,38 @@ def __write_to_file(self, data: bytes):
       f.write(data)
       f.flush()
 
-  def play(self, trials: int, output_filename: str, processes=-1, debug=False):
+  def play(
+      self,
+      trials: int,
+      output_filename: str,
+      processes: int = -1,
+      debug: bool = False,
+      # pylint: disable=dangerous-default-value
+      intermediate_horizons: List[int] = [],
+      horizon: int = np.inf):  # type: ignore
     """Start playing the game
 
     Args:
       trials: number of repetitions
-      output_filename: name of the file used to dump the results
+      output_filename: name of the file used to dump the simulation results
       processes: maximum number of processes to run. -1 means no limit
       debug: debug mode. When it is set to `True`, `trials` will be
         automatically set to 1 and debug information of the trial will be
        printed out.
+      intermediate_horizons: report intermediate regrets after these horizons
+      horizon: horizon of the game. Different protocols may have different
+        interpretations.
 
     .. warning::
       By default, `output_filename` will be opened with mode `a`.
     """
     if debug:
       trials = 1
 
+    self.__debug = debug
+    self.__horizon = horizon
+    self.__intermediate_horizons = intermediate_horizons
+
     for learner in self.__learners:
       # Set current learner
       self.__current_learner = learner
@@ -123,7 +152,7 @@ def play(self, trials: int, output_filename: str, processes=-1, debug=False):
       trial_results = []
       for _ in range(trials):
         result = pool.apply_async(self._one_trial,
-                                  args=[time_seed(), debug],
+                                  args=[time_seed()],
                                   callback=self.__write_to_file)
 
         trial_results.append(result)

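The base class changes above also fix the shape of custom protocols: `_one_trial` now takes only a random seed, and the running parameters are read from the new protected properties set inside `play`. A minimal, hypothetical subclass sketch (not part of the library; the import paths for `Protocol`, `SinglePlayerLearner`, and `Trial` are assumptions) might look like:

```Python
from typing import cast

import numpy as np
from absl import logging

from banditpylib.data_pb2 import Trial                 # assumed import path
from banditpylib.learners import SinglePlayerLearner   # assumed import path
from banditpylib.protocols import Protocol             # assumed import path


class OneRoundProtocol(Protocol):
  """Toy protocol where the learner acts exactly once per trial"""
  @property
  def name(self) -> str:
    return 'one_round_protocol'

  def _one_trial(self, random_seed: int) -> bytes:
    # Running parameters come from the protected properties set by `play`
    if self._debug:
      logging.set_verbosity(logging.DEBUG)
    np.random.seed(random_seed)

    self._bandit.reset()
    learner = cast(SinglePlayerLearner, self._current_learner)
    learner.reset()

    # Single round of interaction: fetch context, act, update
    actions = learner.actions(self._bandit.context)
    feedback = self._bandit.feed(actions)
    learner.update(feedback)

    trial = Trial()
    trial.bandit = self._bandit.name
    trial.learner = learner.name
    result = trial.results.add()
    result.rounds = 1
    # Count one action per requested arm pull (a simplification)
    result.total_actions = len(actions.arm_pulls)
    result.regret = self._bandit.regret(learner.goal)
    return trial.SerializeToString()
```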