Commit 924a272

subfolder name change

1 parent 7616ef2 commit 924a272

File tree

8 files changed, +473 -561 lines changed

banditpylib/learners/collaborative_learner/__init__.py renamed to banditpylib/learners/mab_collaborative_ftbai_learner/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -1,6 +1,8 @@
 from .utils import *
 from .collaborative_learning import *
 
-__all__ = ['CollaborativeBAILearner', 'CollaborativeBAIAgent',
+__all__ = [
+    'CollaborativeBAILearner', 'CollaborativeBAIAgent',
     'CollaborativeBAIMaster', 'LilUCBHeuristicCollaborativeBAIAgent',
-    'LilUCBHeuristicCollaborativeBAIMaster']
+    'LilUCBHeuristicCollaborativeBAIMaster'
+]
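After this rename, downstream imports resolve against the new package path. A minimal sketch of the post-commit import; the names are exactly those exported by the __all__ list above:

    # New import path introduced by this commit.
    from banditpylib.learners.mab_collaborative_ftbai_learner import (
        CollaborativeBAILearner, LilUCBHeuristicCollaborativeBAIAgent,
        LilUCBHeuristicCollaborativeBAIMaster)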

banditpylib/learners/collaborative_learner/collaborative_learning.py renamed to banditpylib/learners/mab_collaborative_ftbai_learner/collaborative_learning.py

Lines changed: 46 additions & 37 deletions
@@ -14,6 +14,7 @@
 
 from .utils import CollaborativeBAIAgent, CollaborativeBAIMaster
 
+
 class LilUCBHeuristicCollaborative(MABFixedConfidenceBAILearner):
   """LilUCB heuristic policy :cite:`jamieson2014lil`
   Modified implementation to supplement CollaborativeAgent
@@ -25,18 +26,20 @@ class LilUCBHeuristicCollaborative(MABFixedConfidenceBAILearner):
   :param np.ndarray assigned_arms: arm indices the learner has to work with
   :param str name: alias name
   """
-  def __init__(self, arm_num: int, confidence: float,
-               assigned_arms: np.ndarray = None, name: str = None):
-    assert np.max(assigned_arms)<arm_num and len(assigned_arms)<=arm_num, (
-        "assigned arms should be a subset of [arm_num]\nReceived: "
-        + str(assigned_arms))
+  def __init__(self,
+               arm_num: int,
+               confidence: float,
+               assigned_arms: np.ndarray = None,
+               name: str = None):
+    assert np.max(assigned_arms) < arm_num and len(assigned_arms) <= arm_num, (
+        "assigned arms should be a subset of [arm_num]\nReceived: " +
+        str(assigned_arms))
     super().__init__(arm_num=arm_num, confidence=confidence, name=name)
     if assigned_arms is not None:
       self.__assigned_arms = assigned_arms
     else:
       self.__assigned_arms = np.arange(arm_num)
 
-
   def _name(self) -> str:
     return 'lilUCB_heur_collaborative'
 
@@ -80,7 +83,7 @@ def __ucb(self) -> np.ndarray:
   def actions(self, context=None) -> Actions:
     del context
     if self.__stage == 'initialization':
-      actions = Actions() # default state is normal
+      actions = Actions()  # default state is normal
 
       # 1 pull each for every assigned arm
       for arm_id in self.__assigned_arms:
@@ -109,7 +112,7 @@ def update(self, feedback: Feedback):
     for arm_feedback in feedback.arm_feedbacks:
       # reverse map from bandit index to local index
       pseudo_arm_index = np.where(
-          self.__assigned_arms==arm_feedback.arm.id)[0][0]
+          self.__assigned_arms == arm_feedback.arm.id)[0][0]
       self.__pseudo_arms[pseudo_arm_index].update(
           np.array(arm_feedback.rewards))
       self.__total_pulls += len(arm_feedback.rewards)
@@ -120,12 +123,10 @@ def update(self, feedback: Feedback):
   @property
   def best_arm(self) -> int:
     # map best arm local index to actual bandit index
-    return self.__assigned_arms[
-        argmax_or_min_tuple([
+    return self.__assigned_arms[argmax_or_min_tuple([
         (pseudo_arm.total_pulls, arm_id)
         for (arm_id, pseudo_arm) in enumerate(self.__pseudo_arms)
-        ])
-    ]
+    ])]
 
   def get_total_pulls(self) -> int:
     return self.__total_pulls
@@ -142,9 +143,11 @@ class LilUCBHeuristicCollaborativeBAIAgent(CollaborativeBAIAgent):
   (over all rounds combined)
   :param Optional[str] name: alias name
   """
-
-  def __init__(self, arm_num: int, rounds: int,
-               horizon: int, name: Optional[str] = None):
+  def __init__(self,
+               arm_num: int,
+               rounds: int,
+               horizon: int,
+               name: Optional[str] = None):
     super().__init__(name)
     if arm_num <= 1:
       raise ValueError('Number of arms is expected at least 2. Got %d.' %
@@ -183,8 +186,8 @@ def set_input_arms(self, arms: List[int]):
 
     self.__assigned_arms = np.array(arms)
     # confidence of 0.01 suggested in the paper
-    self.__central_algo = LilUCBHeuristicCollaborative(self.__arm_num,
-        0.99, self.__assigned_arms)
+    self.__central_algo = LilUCBHeuristicCollaborative(self.__arm_num, 0.99,
+                                                       self.__assigned_arms)
     self.__central_algo.reset()
     if self.__stage == "unassigned":
       self.__stage = "preparation"
@@ -207,7 +210,7 @@ def actions(self, context=None) -> Actions:
         self.__stage = "learning"
         self.__learning_arm = self.__assigned_arms[0]
         return self.actions()
-      if self.__central_algo.get_total_pulls() >= self.__horizon//2:
+      if self.__central_algo.get_total_pulls() >= self.__horizon // 2:
         self.__stage = "learning"
         # use whatever best_arm the central algo outputs
         self.__learning_arm = self.__central_algo.best_arm
@@ -233,7 +236,7 @@ def actions(self, context=None) -> Actions:
       return actions
     else:
       arm_pull = actions.arm_pulls.add()
-      arm_pull.arm.id = self.__learning_arm # pylint: disable=protobuf-type-error
+      arm_pull.arm.id = self.__learning_arm  # pylint: disable=protobuf-type-error
       arm_pull.times = self.__num_pulls_learning
       return actions
 
@@ -249,16 +252,16 @@ def actions(self, context=None) -> Actions:
 
     else:
       raise Exception(self.name + ": " + self.__stage +
-          " does not allow actions to be played")
+                      " does not allow actions to be played")
 
   def update(self, feedback: Feedback):
-    self.__learning_mean = None # default in case learning_arm is None
+    self.__learning_mean = None  # default in case learning_arm is None
     num_pulls = 0
     for arm_feedback in feedback.arm_feedbacks:
       num_pulls += len(arm_feedback.rewards)
     if self.__central_algo_action_taken:
       self.__central_algo.update(feedback)
-    elif num_pulls>0:
+    elif num_pulls > 0:
       # non-zero pulls not by central_algo => learning step was done
       for arm_feedback in feedback.arm_feedbacks:
         if arm_feedback.arm.id == self.__learning_arm:
@@ -282,10 +285,11 @@ def broadcast(self) -> Dict[int, Tuple[float, int]]:
     return_dict = {}
     if self.__learning_arm:
       return_dict[self.__learning_arm] = (self.__learning_mean,
-          self.__pulls_used)
+                                          self.__pulls_used)
     self.__complete_round()
     return return_dict
 
+
 class LilUCBHeuristicCollaborativeBAIMaster(CollaborativeBAIMaster):
   r"""Implementation of master in Collaborative Learning Algorithm
 
@@ -297,9 +301,12 @@ class LilUCBHeuristicCollaborativeBAIMaster(CollaborativeBAIMaster):
   :param int num_agents: number of agents
   :param Optional[str] name: alias name
   """
-
-  def __init__(self, arm_num:int, rounds: int,
-               horizon: int, num_agents: int, name: Optional[str] = None):
+  def __init__(self,
+               arm_num: int,
+               rounds: int,
+               horizon: int,
+               num_agents: int,
+               name: Optional[str] = None):
     super().__init__(name)
     if arm_num <= 1:
       raise ValueError('Number of arms is expected at least 2. Got %d.' %
@@ -371,8 +378,8 @@ def random_round(x: float) -> int:
       if i >= len(active_arms_copy):
         break
       num_arms = random_round(num_arms_per_agent)
-      agent_arm_assignment[agent_id] += active_arms_copy[i: i + num_arms]
-      i+= num_arms
+      agent_arm_assignment[agent_id] += active_arms_copy[i:i + num_arms]
+      i += num_arms
     if i < len(active_arms_copy):
       agent_arm_assignment[agent_ids[-1]] += active_arms_copy[i:]
 
@@ -381,8 +388,10 @@ def random_round(x: float) -> int:
   def initial_arm_assignment(self) -> Dict[int, List[int]]:
     return self.__assign_arms(list(range(self.__num_agents)))
 
-  def elimination(self, agent_ids: List[int],
-      messages: Dict[int, Dict[int, Tuple[float, int]]]) -> Dict[int, List[int]]:
+  def elimination(
+      self, agent_ids: List[int],
+      messages: Dict[int, Dict[int, Tuple[float,
+                                          int]]]) -> Dict[int, List[int]]:
 
     aggregate_messages: Dict[int, Tuple[float, int]] = {}
     for agent_id in messages.keys():
@@ -398,17 +407,17 @@ def elimination(self, agent_ids: List[int],
 
     accumulated_arm_ids = np.array(list(aggregate_messages.keys()))
     accumulated_em_mean_rewards = np.array(
-        list(map(lambda x: aggregate_messages[x][0], aggregate_messages.keys())))
+        list(map(lambda x: aggregate_messages[x][0],
+                 aggregate_messages.keys())))
 
     # elimination
     confidence_radius = np.sqrt(
-        self.__comm_rounds * np.log(200 * self.__num_agents * self.__comm_rounds)
-        / (self.__T * max(1, self.__num_agents / len(self.__active_arms)))
-    )
+        self.__comm_rounds *
+        np.log(200 * self.__num_agents * self.__comm_rounds) /
+        (self.__T * max(1, self.__num_agents / len(self.__active_arms))))
     highest_em_reward = np.max(accumulated_em_mean_rewards)
     self.__active_arms = list(
-        accumulated_arm_ids[accumulated_em_mean_rewards >=
-            highest_em_reward - 2 * confidence_radius]
-    )
+        accumulated_arm_ids[accumulated_em_mean_rewards >= highest_em_reward -
+                            2 * confidence_radius])
 
     return self.__assign_arms(agent_ids)
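The reflowed elimination hunk above is behavior-preserving: the rule keeps every arm whose aggregated empirical mean lies within twice a confidence radius of the empirical leader. A standalone sketch of that criterion; the numbers and the message dict are illustrative stand-ins, not values from the repository:

    import numpy as np

    # arm_id -> (empirical mean reward, pulls used), as broadcast by agents
    aggregate_messages = {0: (0.61, 120), 1: (0.58, 115), 2: (0.34, 130)}
    comm_rounds, num_agents, T, num_active_arms = 4, 5, 2000, 3

    arm_ids = np.array(list(aggregate_messages.keys()))
    em_means = np.array([aggregate_messages[a][0] for a in arm_ids])

    # Radius from the hunk above:
    # sqrt(R * log(200 * K * R) / (T * max(1, K / |active arms|)))
    confidence_radius = np.sqrt(
        comm_rounds * np.log(200 * num_agents * comm_rounds) /
        (T * max(1, num_agents / num_active_arms)))

    # Keep every arm within twice the radius of the empirical leader.
    active = arm_ids[em_means >= em_means.max() - 2 * confidence_radius]
    print(list(active))  # [0, 1]: arm 2 falls outside the band and is dropped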

banditpylib/learners/collaborative_learner/lilucb_heur_collaborative_test.py renamed to banditpylib/learners/mab_collaborative_ftbai_learner/lilucb_heur_collaborative_test.py

Lines changed: 4 additions & 4 deletions
@@ -10,7 +10,8 @@ def test_simple_run(self):
     arm_num = 3
     confidence = 0.95
     learner = LilUCBHeuristicCollaborative(arm_num=arm_num,
-        confidence=confidence, assigned_arms=np.arange(arm_num))
+                                           confidence=confidence,
+                                           assigned_arms=np.arange(arm_num))
     learner.reset()
 
     while True:
@@ -23,9 +24,8 @@ def test_simple_run(self):
       arm_feedback = feedback.arm_feedbacks.add()
       arm_feedback.arm.id = arm_pull.arm.id
       arm_feedback.rewards.extend(
-          list(
-              np.random.normal(arm_pull.arm.id / arm_num, 1,
-                  arm_pull.times)))
+          list(np.random.normal(arm_pull.arm.id / arm_num, 1,
+                                arm_pull.times)))
       learner.update(feedback)
 
     assert learner.best_arm in list(range(arm_num))

banditpylib/learners/collaborative_learner/utils.py renamed to banditpylib/learners/mab_collaborative_ftbai_learner/utils.py

Lines changed: 11 additions & 8 deletions
@@ -6,14 +6,14 @@
 from banditpylib.data_pb2 import Arm, Feedback, Actions, Context
 from banditpylib.learners import Goal, IdentifyBestArm, Learner
 
+
 class CollaborativeBAIAgent(ABC):
   r"""One individual agent
 
   This agent aims to identify the best arm with other agents.
 
   :param Optional[str] name: alias name
   """
-
   def __init__(self, name: Optional[str]):
     self.__name = self._name() if name is None else name
 
@@ -86,7 +86,6 @@ class CollaborativeBAIMaster(ABC):
 
   :param Optional[str] name: alias name
   """
-
   def __init__(self, name: Optional[str]):
     self.__name = self._name() if name is None else name
 
@@ -119,8 +118,10 @@ def initial_arm_assignment(self) -> Dict[int, List[int]]:
     """
 
   @abstractmethod
-  def elimination(self, agent_ids: List[int],
-      messages: Dict[int, Dict[int, Tuple[float, int]]]) -> Dict[int, List[int]]:
+  def elimination(
+      self, agent_ids: List[int],
+      messages: Dict[int, Dict[int, Tuple[float,
+                                          int]]]) -> Dict[int, List[int]]:
     """Update the set of active arms based on some criteria
     and return arm assignment
 
@@ -141,9 +142,11 @@ class CollaborativeBAILearner(Learner):
   :param int num_agents: total number of agents involved
   :param Optional[str] name: alias name
   """
-  def __init__(self, agent: CollaborativeBAIAgent,
-               master: CollaborativeBAIMaster, num_agents: int,
-               name: Optional[str] = None):
+  def __init__(self,
+               agent: CollaborativeBAIAgent,
+               master: CollaborativeBAIMaster,
+               num_agents: int,
+               name: Optional[str] = None):
     super().__init__(name)
     self.__agents = []
     for _ in range(num_agents):
@@ -176,7 +179,7 @@ def goal(self) -> Goal:
     best_arm = self.__agents[0].best_arm
     for agent in self.__agents[1:]:
       if best_arm != agent.best_arm:
-        best_arm = -1 # implies regret of 1
+        best_arm = -1  # implies regret of 1
         break
     arm.id = best_arm
     return IdentifyBestArm(best_arm=arm)

banditpylib/protocols/__init__.py

Lines changed: 1 addition & 5 deletions
@@ -2,8 +2,4 @@
 from .single_player import *
 from .collaborative_protocol import *
 
-__all__ = [
-    'Protocol',
-    'SinglePlayerProtocol',
-    'CollaborativeLearningProtocol'
-]
+__all__ = ['Protocol', 'SinglePlayerProtocol', 'CollaborativeLearningProtocol']

banditpylib/protocols/collaborative_protocol.py

Lines changed: 7 additions & 8 deletions
@@ -7,7 +7,8 @@
 from banditpylib.bandits import Bandit
 from banditpylib.data_pb2 import Trial, Actions
 from banditpylib.learners import Learner
-from banditpylib.learners.collaborative_learner import CollaborativeBAILearner
+from banditpylib.learners.mab_collaborative_ftbai_learner \
+    import CollaborativeBAILearner
 from .utils import Protocol
 
 
@@ -38,8 +39,7 @@ class CollaborativeLearningProtocol(Protocol):
   so-called batched learner. In this case, each action counts as a timestep
   used.
   """
-  def __init__(self,
-               bandit: Bandit, learners: List[CollaborativeBAILearner]):
+  def __init__(self, bandit: Bandit, learners: List[CollaborativeBAILearner]):
     super().__init__(bandit=bandit, learners=cast(List[Learner], learners))
 
   @property
@@ -90,7 +90,7 @@ def _one_trial(self, random_seed: int, debug: bool) -> bytes:
         elif actions.state == Actions.WAIT:
           agent_in_wait_ids.append(agent_id)
           break
-        else: # actions.state == Actions.STOP
+        else:  # actions.state == Actions.STOP
           break
       max_pulls = max(max_pulls, pulls)
       total_pulls += max_pulls
@@ -113,8 +113,8 @@ def _one_trial(self, random_seed: int, debug: bool) -> bytes:
       # Send info to master for elimination to get arm assignment for next round
       # agent_arm_assignment: key is agent_id, value is a list storing arm ids
      # assigned to this agent
-      agent_arm_assignment = master.elimination(
-          agent_in_wait_ids, accumulated_messages)
+      agent_arm_assignment = master.elimination(agent_in_wait_ids,
+                                                accumulated_messages)
       for agent_id in agent_arm_assignment:
         agents[agent_id].set_input_arms(agent_arm_assignment[agent_id])
       communication_rounds += 1
@@ -123,7 +123,6 @@ def _one_trial(self, random_seed: int, debug: bool) -> bytes:
     result = trial.results.add()
     result.rounds = communication_rounds
     result.total_actions = total_pulls
-    result.regret = self.bandit.regret(
-        current_learner.goal)
+    result.regret = self.bandit.regret(current_learner.goal)
 
     return trial.SerializeToString()
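For context, the constructors touched across this commit wire together as below. The signatures match the hunks above; the final comment about running trials is an assumption, since the protocol's run entry point is not part of this diff:

    from banditpylib.learners.mab_collaborative_ftbai_learner import (
        CollaborativeBAILearner, LilUCBHeuristicCollaborativeBAIAgent,
        LilUCBHeuristicCollaborativeBAIMaster)

    num_agents, arm_num, rounds, horizon = 5, 10, 4, 2000
    agent = LilUCBHeuristicCollaborativeBAIAgent(arm_num=arm_num,
                                                 rounds=rounds,
                                                 horizon=horizon)
    master = LilUCBHeuristicCollaborativeBAIMaster(arm_num=arm_num,
                                                   rounds=rounds,
                                                   horizon=horizon,
                                                   num_agents=num_agents)
    learner = CollaborativeBAILearner(agent, master, num_agents=num_agents)
    # CollaborativeLearningProtocol(bandit=..., learners=[learner]) would then
    # drive _one_trial against a Bandit instance (run API assumed, not shown).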
