-
Notifications
You must be signed in to change notification settings - Fork 8
/
your_rllib_config.py
65 lines (53 loc) · 2.01 KB
/
your_rllib_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import Dict
from ray.rllib.policy.policy import PolicySpec
from your_rllib_environment import YourEnvironment
from your_openai_spaces import high_level_obs_space, high_level_action_space, \
low_level_obs_space,low_level_action_space
def policy_map_fn(agent_id: str, _episode=None, _worker=None, **_kwargs) -> str:
    """Return the policy_id that controls the given agent.

    Agents whose id contains 'high' go to the high-level policy; ids
    containing 'low' go to the low-level policy. 'high' is checked first,
    so an id containing both markers maps to the high-level policy.

    Raises:
        RuntimeError: if the agent_id contains neither marker.
    """
    # Ordered dispatch table: first matching substring wins.
    routing = (
        ('high', 'high_level_policy'),
        ('low', 'low_level_policy'),
    )
    for marker, policy_id in routing:
        if marker in agent_id:
            return policy_id
    raise RuntimeError(f'Invalid agent_id: {agent_id}')
def get_multiagent_policies() -> Dict[str, PolicySpec]:
    """Build the policy_id -> PolicySpec mapping for the two-level setup.

    Both specs leave ``policy_class=None`` so the trainer substitutes its
    default policy class; each level gets its own observation/action spaces.
    """
    return {
        'high_level_policy': PolicySpec(
            policy_class=None,  # None -> trainer default (could be YourHighLevelPolicy)
            observation_space=high_level_obs_space,
            action_space=high_level_action_space,
            config={},
        ),
        'low_level_policy': PolicySpec(
            policy_class=None,  # None -> trainer default (could be YourLowLevelPolicy)
            observation_space=low_level_obs_space,
            action_space=low_level_action_space,
            config={},
        ),
    }
policies = get_multiagent_policies()

# Multi-agent sub-config, kept as a named variable for readability.
# Values here mirror the RLlib 1.10 trainer defaults we rely on explicitly.
_multiagent_config = {
    "policies": policies,
    "policy_mapping_fn": policy_map_fn,
    "policies_to_train": list(policies.keys()),
    "count_steps_by": "env_steps",
    "observation_fn": None,
    "replay_mode": "independent",
    "policy_map_cache": None,
    "policy_map_capacity": 100,
}

# Trainer config overrides; full option list:
# https://github.com/ray-project/ray/blob/releases/1.10.0/rllib/agents/trainer.py
cust_config = {
    "simple_optimizer": True,
    "ignore_worker_failures": True,
    "batch_mode": "complete_episodes",
    "env": YourEnvironment,
    "env_config": {
        "is_use_visualization": False,
    },
    "framework": "torch",
    "multiagent": _multiagent_config,
}