Commit

v4.1.4 perf(pettingzoo): added multi-agent environment support pettingzoo. (#41, #34)

1. fixed a little bug in maddpg
StepNeverStop committed Jul 30, 2021
1 parent 964f78b commit cca69b2
Showing 36 changed files with 488 additions and 156 deletions.
20 changes: 12 additions & 8 deletions README.md
@@ -41,6 +41,7 @@ This project supports:
- Box -> Discrete
- Box -> Box
- Box/Discrete -> Tuple(Discrete, Discrete, Discrete)
- [PettingZoo](https://www.pettingzoo.ml/#)
- MultiAgent training.
- MultiImage input. Images will be resized to the same shape before being stored into the replay buffer, e.g. `[84, 84, 3]`.
- Four types of Replay Buffer, Default is ER:
@@ -93,6 +94,10 @@ $ pip install gym[atari]
$ pip install gym[box2d]
```

```bash
$ pip install pettingzoo[all]
```
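
As a quick sanity check after installing, a PettingZoo MPE environment can be created and stepped roughly as below. This is a minimal sketch using PettingZoo's parallel API; RLs' own wrapper and the exact reset/step return signatures may differ depending on the PettingZoo version.

```python
# Minimal sketch: create and step a PettingZoo MPE parallel environment.
# Assumes `pettingzoo[all]` is installed; reset/step signatures vary
# slightly across PettingZoo versions.
from pettingzoo.mpe import simple_adversary_v2

env = simple_adversary_v2.parallel_env(continuous_actions=True)
observations = env.reset()
print(env.agents)  # e.g. ['adversary_0', 'agent_0', 'agent_1']

# one step with a random action for every live agent
actions = {agent: env.action_spaces[agent].sample() for agent in env.agents}
observations, rewards, dones, infos = env.step(actions)
env.close()
```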

You can download the built docker image from [here](https://hub.docker.com/r/keavnn/rls):
```bash
$ docker pull keavnn/rls:latest
@@ -178,20 +183,19 @@ For now, these algorithms are available:

```python
"""
usage: run.py [-h] [-c COPYS] [--seed SEED] [-r] [-p {gym,unity}]
usage: run.py [-h] [-c COPYS] [--seed SEED] [-r] [-p {gym,unity,pettingzoo}]
[-a {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn}]
[-i] [-l LOAD_PATH] [-m MODELS] [-n NAME] [-s SAVE_FREQUENCY]
[--apex {learner,worker,buffer,evaluator}] [--config-file CONFIG_FILE] [--store-dir STORE_DIR]
[--episode-length EPISODE_LENGTH] [--prefill-steps PREFILL_STEPS] [--prefill-choose] [--hostname]
[--info INFO] [-e ENV_NAME] [-f FILE_NAME] [--no-save] [-d DEVICE]
[-i] [-l LOAD_PATH] [-m MODELS] [-n NAME] [-s SAVE_FREQUENCY] [--apex {learner,worker,buffer,evaluator}] [--config-file CONFIG_FILE]
[--store-dir STORE_DIR] [--episode-length EPISODE_LENGTH] [--prefill-steps PREFILL_STEPS] [--prefill-choose] [--hostname] [--info INFO]
[-e ENV_NAME] [-f FILE_NAME] [--no-save] [-d DEVICE]
optional arguments:
-h, --help show this help message and exit
-c COPYS, --copys COPYS
nums of environment copys that collect data in parallel
--seed SEED specify the random seed of module random, numpy and pytorch
-r, --render whether render game interface
-p {gym,unity}, --platform {gym,unity}
-p {gym,unity,pettingzoo}, --platform {gym,unity,pettingzoo}
specify the platform of training environment
-a {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn}, --algorithm {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn}
specify the training algorithm
@@ -211,8 +215,7 @@ optional arguments:
--episode-length EPISODE_LENGTH
specify the maximum step per episode
--prefill-steps PREFILL_STEPS
specify the number of experiences that should be collected before start training, use for
off-policy algorithms
specify the number of experiences that should be collected before start training, use for off-policy algorithms
--prefill-choose whether choose action using model or choose randomly
--hostname whether concatenate hostname with the training name
--info INFO write another information that describe this training task
@@ -223,6 +226,7 @@ optional arguments:
--no-save specify whether save models/logs/summaries while training or not
-d DEVICE, --device DEVICE
specify the device that operate Torch.Tensor
"""
```

```python
2 changes: 1 addition & 1 deletion rls/_metadata.py
@@ -8,7 +8,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '4'
_MINOR_VERSION = '1'
_PATCH_VERSION = '3'
_PATCH_VERSION = '4'

# Example: '0.4.2'
__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
3 changes: 2 additions & 1 deletion rls/algorithms/base/base.py
@@ -26,7 +26,8 @@ class Base:
def __init__(self,
no_save=False,
base_dir='',
device='cpu'):
device='cpu',
**kwargs): # TODO: remove this
'''
inputs:
a_dim: action spaces
2 changes: 1 addition & 1 deletion rls/algorithms/base/ma_policy.py
@@ -102,7 +102,7 @@ def write_summaries(self,
'''
write summaries showing in tensorboard.
'''
if 'model' in summaries.keys():
if 'model' in summaries.keys():
super().write_summaries(global_step, summaries=summaries.pop('model'), writer=self.writer)
for i, summary in summaries.items():
super().write_summaries(global_step, summaries=summary, writer=self.writers[i])
54 changes: 30 additions & 24 deletions rls/algorithms/multi/maddpg.py
@@ -86,15 +86,14 @@ def __init__(self,
critic_target.eval()
actor_target = deepcopy(actor)
actor_target.eval()
self.rep_nets.append(ep_net)
self.rep_nets.append(rep_net)
self.critics.append(critic)
self.target_rep_nets.append(target_rep_net)
self.critic_targets.append(critic_target)
self.actors.append(actor)
self.actor_targets.append(actor_target)

actor_oplr = OPLR(actor,
actor_lr)
actor_oplr = OPLR(actor, actor_lr)
critic_oplr = OPLR([rep_net, critic], critic_lr)
self.actor_oplrs.append(actor_oplr)
self.critic_oplrs.append(critic_oplr)
@@ -188,8 +187,25 @@ def _train(self, BATCHs):
j = 0 if self.share_params else i
q_targets.append(self.critic_targets[j](t.cat(feats_, -1), target_actions))

q_loss = []
for i in range(self.n_agents_percopy):
j = 0 if self.share_params else i
q = self.critics[j](
t.cat(feats, -1),
t.cat([BATCH.action for BATCH in BATCHs], -1)
)
dc_r = (BATCHs[i].reward + self.gamma * q_targets[i] * (1 - BATCHs[i].done)).detach()

td_error = dc_r - q
q_loss.append(0.5 * td_error.square().mean())
if self.share_params:
self.critic_oplrs[0].step(sum(q_loss))

actor_loss = []
feats = [feat.detach() for feat in feats]
for i in range(self.n_agents_percopy):
j = 0 if self.share_params else i

if self.envspecs[i].is_continuous:
mu = self.actors[j](feats[i])
else:
@@ -205,26 +221,16 @@ def _train(self, BATCHs):
t.cat(feats, -1),
t.cat([BATCH.action for BATCH in BATCHs[:i]]+[mu]+[BATCH.action for BATCH in BATCHs[i+1:]], -1)
)
actor_loss = -q_actor.mean()

q = self.critics[j](
t.cat(feats, -1),
t.cat([BATCH.action for BATCH in BATCHs], -1)
)
dc_r = (BATCHs[i].reward + self.gamma * q_targets[i] * (1 - BATCHs[i].done)).detach()

td_error = dc_r - q
q_loss = 0.5 * td_error.square().mean()

self.critic_oplrs[j].step(q_loss)
self.actor_oplrs[j].step(actor_loss)

summaries[i] = dict([
['LOSS/actor_loss', actor_loss],
['LOSS/critic_loss', q_loss],
['Statistics/q_min', q.min()],
['Statistics/q_mean', q.mean()],
['Statistics/q_max', q.max()]
])
actor_loss.append(-q_actor.mean())
if self.share_params:
self.actor_oplrs[0].step(sum(actor_loss))

# summaries[i] = dict([
# ['LOSS/actor_loss', actor_loss],
# ['LOSS/critic_loss', q_loss],
# ['Statistics/q_min', q.min()],
# ['Statistics/q_mean', q.mean()],
# ['Statistics/q_max', q.max()]
# ])
self.global_step.add_(1)
return summaries
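
The substance of the maddpg change above is that, when `share_params` is true, every agent reuses the networks and optimizer at index 0, so the per-agent critic and actor losses are now collected into lists and applied in a single `critic_oplrs[0].step(sum(q_loss))` / `actor_oplrs[0].step(sum(actor_loss))` call instead of being stepped once per agent. A rough sketch of that pattern follows; the names are illustrative stand-ins, and RLs' `OPLR` wrapper presumably does the backward/step internally.

```python
# Illustrative sketch of shared vs. independent optimizer stepping.
# `optimizers` and `losses` are hypothetical stand-ins, one entry per agent.
from typing import List
import torch


def step_agents(optimizers: List[torch.optim.Optimizer],
                losses: List[torch.Tensor],
                share_params: bool) -> None:
    if share_params:
        # All agents share optimizers[0]; sum the per-agent losses and do a
        # single backward/step so gradients accumulate exactly once.
        optimizers[0].zero_grad()
        sum(losses).backward()
        optimizers[0].step()
    else:
        # Each agent owns its optimizer; step it on that agent's loss alone.
        for loss, opt in zip(losses, optimizers):
            opt.zero_grad()
            loss.backward()
            opt.step()
```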
2 changes: 1 addition & 1 deletion rls/algorithms/multi/vdn.py
@@ -62,7 +62,7 @@ def __init__(self,
target_rep_net.eval()
q_target_net = deepcopy(q_net)
q_target_net.eval()
self.rep_nets.append(ep_net)
self.rep_nets.append(rep_net)
self.q_nets.append(q_net)
self.target_rep_nets.append(target_rep_net)
self.q_target_nets.append(q_target_net)
2 changes: 1 addition & 1 deletion rls/algorithms/wrapper/IndependentMA.py
@@ -62,7 +62,7 @@ def save(self, **kwargs):

def resume(self, base_dir: Optional[str] = None) -> Dict:
for i, model in enumerate(self.models):
if self._n_agents > 1:
if self._n_agents > 1 and base_dir is not None:
base_dir += f'/i{model.__class__.__name__}-{i}'
training_info = model.resume(base_dir)
else:
5 changes: 1 addition & 4 deletions rls/common/decorator.py
@@ -32,10 +32,7 @@ def wrapper(*args, **kwargs):
args = [to_tensor(x, dtype=dtype, device=device) for x in args]
kwargs = {k: to_tensor(v, dtype=dtype, device=device) for k, v in kwargs.items()}
output = func(*args, **kwargs)
if isinstance(output, (tuple, list)):
output = [to_numpy(x) for x in output]
else:
output = to_numpy(output)
output = to_numpy(output)
return output

return wrapper
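
The simplification above implies that `to_numpy` now handles containers itself rather than the decorator special-casing tuples and lists. A hedged sketch of what such a container-aware helper can look like (RLs' actual implementation may differ):

```python
# Sketch of a container-aware tensor-to-numpy conversion (illustrative only).
import torch as t


def to_numpy(x):
    if isinstance(x, (tuple, list)):
        # convert element-wise, preserving the container type (plain tuple/list)
        return type(x)(to_numpy(v) for v in x)
    if isinstance(x, t.Tensor):
        return x.detach().cpu().numpy()
    return x
```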
7 changes: 5 additions & 2 deletions rls/common/trainer.py
@@ -42,21 +42,24 @@ def __init__(self,

# ENV
self.env = make_env(self.env_args)
# logger.info(self.env.GroupsSpec)

# ALGORITHM CONFIG
self.agent_class, self.policy_mode, self.is_multi = get_model_info(self.train_args.algorithm)

if self.policy_mode == 'on-policy': # TODO:
self.train_args.pre_fill_steps = 0 # if on-policy, prefill experience replay is no longer needed.
self.train_args.prefill_steps = 0 # if on-policy, prefill experience replay is no longer needed.

self.initialize()
self.start_time = time.time()

def initialize(self):
logger.info('Initialize Agent Begin.')
if not self.is_multi:
self.model = IndependentMA(self.agent_class, self.env.GroupsSpec, self.algo_args)
else:
self.model = self.agent_class(envspecs=self.env.GroupsSpec, **self.algo_args)
logger.info('Initialize Agent Successfully.')
_train_info = self.model.resume(self.train_args.load_path)
self.begin_train_step = _train_info['train_step']
self.begin_frame_step = _train_info['frame_step']
@@ -78,7 +81,7 @@ def __call__(self) -> NoReturn:
model=self.model,
reset_config=self.train_args.reset_config,
step_config=self.train_args.step_config,
pre_fill_steps=self.train_args.pre_fill_steps,
prefill_steps=self.train_args.prefill_steps,
prefill_choose=self.train_args.prefill_choose)
train(env=self.env,
model=self.model,
15 changes: 8 additions & 7 deletions rls/configs/algorithms.yaml
@@ -1,10 +1,10 @@
# algorithm config = general UNION on_policy/off_policy UNION specific algorithm config
# if same key happened in multiple configurations, then specific algorithm config override on_policy/off_policy override general
# algorithm config = policy UNION on_policy/off_policy UNION specific algorithm config
# if same key happened in multiple configurations, then specific algorithm config override on_policy/off_policy override policy
# priority 1: specific algorithm config
# priority 2: on_policy/off_policy
# priority 3: general
# priority 3: policy

general: &general
policy: &policy
decay_lr: false

use_curiosity: false # whether to use ICM or not
@@ -13,6 +13,7 @@ general: &general
curiosity_beta: 0.2 # weight that scale the forward loss and inverse loss of ICM

normalize_vector_obs: false

# ----- could be overrided in specific algorithms, i.e. dqn, so as to using different type of visual net, memory net.
representation_net_params: &representation_net_params
use_encoder: false
@@ -30,11 +31,11 @@ general: &general
# -----

on_policy: &on_policy
<<: *general
<<: *policy
rnn_time_steps: 8

off_policy: &off_policy
<<: *general
<<: *policy
train_times_per_step: 1 # train multiple times per agent step
# PER
use_isw: false
@@ -662,7 +663,7 @@ maddpg:
batch_size: 4
buffer_size: 100000
n_step: 4
share_params: true
share_params: false
network_settings:
actor_continuous: [64, 64]
actor_discrete: [64, 64]
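
The `general` → `policy` anchor rename above only changes naming; the override order is unchanged and relies on YAML merge keys. A small, self-contained illustration of that priority with PyYAML (the snippet is illustrative, not taken from the repo):

```python
# Demonstrates the override priority: specific algorithm config
# > on_policy/off_policy > policy, via YAML merge keys (`<<`).
import yaml

doc = """
policy: &policy
  decay_lr: false
  use_curiosity: false

off_policy: &off_policy
  <<: *policy
  train_times_per_step: 1

maddpg:
  <<: *off_policy
  share_params: false   # keys set here win over the merged ones
"""

cfg = yaml.safe_load(doc)
assert cfg['maddpg'] == {'decay_lr': False,
                         'use_curiosity': False,
                         'train_times_per_step': 1,
                         'share_params': False}
```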
4 changes: 2 additions & 2 deletions rls/configs/examples/gym_config.yaml
@@ -55,7 +55,7 @@ environment:
obs_scale: false
obs_stack: false
platform: gym
render_mode: random_1
render_all: false
resize:
- 84
- 84
@@ -91,7 +91,7 @@ train:
off_policy_step_eval_episodes: 100
off_policy_train_interval: 1
platform: gym
pre_fill_steps: 10000
prefill_steps: 10000
prefill_choose: false
prefill_steps: 10000
render: false
4 changes: 2 additions & 2 deletions rls/configs/examples/test.yaml
@@ -55,7 +55,7 @@ environment:
obs_scale: false
obs_stack: false
platform: gym
render_mode: random_1
render_all: false
resize:
- 84
- 84
@@ -91,7 +91,7 @@ train:
off_policy_step_eval_episodes: 100
off_policy_train_interval: 1
platform: gym
pre_fill_steps: 1000
prefill_steps: 1000
prefill_choose: true
prefill_steps: 10000
render: false
3 changes: 2 additions & 1 deletion rls/configs/gym/env.yaml
@@ -1,5 +1,6 @@
vector_env_type: multiprocessing # todo
render_mode: random_1 # first last [list] random_[num] or all.
env_name: CartPole-v0
render_all: false
action_skip: false
skip: 4
obs_stack: false
4 changes: 4 additions & 0 deletions rls/configs/pettingzoo/env.yaml
@@ -0,0 +1,4 @@
vector_env_type: vector # todo
env_name: mpe.simple_adversary_v2
env_config:
continuous_actions: True
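
The new config names a scenario by the dotted string `mpe.simple_adversary_v2` and passes `env_config` through as keyword arguments. A hypothetical sketch of how such a string plus kwargs could be resolved into a PettingZoo environment (the loader RLs actually ships may differ):

```python
# Hypothetical loader sketch: resolve "mpe.simple_adversary_v2" + env_config
# into a PettingZoo parallel environment. Not RLs' actual code.
import importlib


def make_pettingzoo_env(env_name: str, **env_config):
    family, scenario = env_name.split('.', 1)   # 'mpe', 'simple_adversary_v2'
    module = importlib.import_module(f'pettingzoo.{family}.{scenario}')
    return module.parallel_env(**env_config)


env = make_pettingzoo_env('mpe.simple_adversary_v2', continuous_actions=True)
```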
Empty file.
2 changes: 1 addition & 1 deletion rls/configs/train.yaml
@@ -17,7 +17,7 @@ config_file: ~
store_dir: ./data
# episode_length, max_train_step, max_frame_step, max_train_episode are set to sys.maxsize if default value is set to zero.
episode_length: 1000 # episode_length per episode, if gym.env.max_episode_steps > episode_length, then PEB(Partial Episode Bootstraping), else TimeLimit.
pre_fill_steps: 10000 # pre_fill_steps should be set to an integer multiple of '--copy' to get an accurate pre-fill number
prefill_steps: 10000 # prefill_steps should be set to an integer multiple of '--copy' to get an accurate pre-fill number
prefill_choose: false
hostname: false
no_save: true
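
For context, `prefill_steps` (renamed from `pre_fill_steps` throughout this commit) is the number of transitions gathered before off-policy training starts, and `prefill_choose` decides whether those actions come from the untrained policy or are sampled randomly. A purely illustrative sketch of such a pre-fill loop (all names here are hypothetical, not RLs' internals):

```python
# Illustrative pre-fill loop; `env`, `model`, and `buffer` are hypothetical.
def prefill(env, model, buffer, prefill_steps: int, prefill_choose: bool, copys: int):
    obs = env.reset()
    # each env.step advances `copys` parallel copies, which is why
    # prefill_steps should be an integer multiple of --copys
    for _ in range(prefill_steps // copys):
        if prefill_choose:
            actions = model.choose_action(obs)   # act with the (untrained) model
        else:
            actions = env.sample_actions()       # act randomly
        next_obs, reward, done, info = env.step(actions)
        buffer.add(obs, actions, reward, next_obs, done)
        obs = next_obs
```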
2 changes: 1 addition & 1 deletion rls/envs/__init__.py
@@ -1,2 +1,2 @@

platform_list = ['gym', 'unity']
platform_list = ['gym', 'unity', 'pettingzoo']
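
This list presumably backs the `-p/--platform` choices shown in the README usage text above; a minimal illustration of how such a whitelist typically feeds argparse (illustrative only, not the project's actual run.py):

```python
# Illustrative only: how a platform whitelist usually becomes argparse choices.
import argparse

platform_list = ['gym', 'unity', 'pettingzoo']

parser = argparse.ArgumentParser()
parser.add_argument('-p', '--platform', choices=platform_list, default='gym',
                    help='specify the platform of training environment')

print(parser.parse_args(['-p', 'pettingzoo']).platform)  # -> pettingzoo
```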