several updates
rename sac to rl in agent and sac_main
remove temp issue pytorch/pytorch#80809
add depth camera to realcar
add dm_control mujoco sample
BlueFisher committed Aug 18, 2022
1 parent ff07397 commit 98a6f84
Showing 14 changed files with 1,996 additions and 690 deletions.
57 changes: 28 additions & 29 deletions algorithm/agent.py
@@ -251,20 +251,20 @@ def set_model_abs_dir(self, model_abs_dir: Path):
         model_abs_dir.mkdir(parents=True, exist_ok=True)
         self.model_abs_dir = model_abs_dir

-    def set_sac(self, sac: SAC_Base):
-        self.sac = sac
-        self.seq_encoder = sac.seq_encoder
+    def set_rl(self, rl: SAC_Base):
+        self.rl = rl
+        self.seq_encoder = rl.seq_encoder

     def pre_run(self, num_agents: int):
-        self['initial_pre_action'] = self.sac.get_initial_action(num_agents)  # [n_agents, action_size]
+        self['initial_pre_action'] = self.rl.get_initial_action(num_agents)  # [n_agents, action_size]
         self['pre_action'] = self['initial_pre_action']
         if self.seq_encoder is not None:
-            self['initial_seq_hidden_state'] = self.sac.get_initial_seq_hidden_state(num_agents)  # [n_agents, *seq_hidden_state_shape]
+            self['initial_seq_hidden_state'] = self.rl.get_initial_seq_hidden_state(num_agents)  # [n_agents, *seq_hidden_state_shape]
             self['seq_hidden_state'] = self['initial_seq_hidden_state']

         self.agents: List[Agent] = [
             Agent(i, self.obs_shapes, self.action_size,
-                  seq_hidden_state_shape=self.sac.seq_hidden_state_shape
+                  seq_hidden_state_shape=self.rl.seq_hidden_state_shape
                   if self.seq_encoder is not None else None)
             for i in range(num_agents)
         ]
@@ -273,11 +273,11 @@ def get_action(self,
                     disable_sample: bool = False,
                     force_rnd_if_available: bool = False):
         if self.seq_encoder == SEQ_ENCODER.RNN:
-            action, prob, next_seq_hidden_state = self.sac.choose_rnn_action(self['obs_list'],
-                                                                              self['pre_action'],
-                                                                              self['seq_hidden_state'],
-                                                                              disable_sample=disable_sample,
-                                                                              force_rnd_if_available=force_rnd_if_available)
+            action, prob, next_seq_hidden_state = self.rl.choose_rnn_action(self['obs_list'],
+                                                                             self['pre_action'],
+                                                                             self['seq_hidden_state'],
+                                                                             disable_sample=disable_sample,
+                                                                             force_rnd_if_available=force_rnd_if_available)

         elif self.seq_encoder == SEQ_ENCODER.ATTN:
             ep_length = min(512, max([a.episode_length for a in self.agents]))
@@ -306,18 +306,18 @@ def get_action(self,
                              for o, t_o in zip(ep_obses_list, self['obs_list'])]
             ep_pre_actions = gen_pre_n_actions(ep_actions, True)

-            action, prob, next_seq_hidden_state = self.sac.choose_attn_action(ep_indexes,
-                                                                               ep_padding_masks,
-                                                                               ep_obses_list,
-                                                                               ep_pre_actions,
-                                                                               ep_attn_states,
-                                                                               disable_sample=disable_sample,
-                                                                               force_rnd_if_available=force_rnd_if_available)
+            action, prob, next_seq_hidden_state = self.rl.choose_attn_action(ep_indexes,
+                                                                              ep_padding_masks,
+                                                                              ep_obses_list,
+                                                                              ep_pre_actions,
+                                                                              ep_attn_states,
+                                                                              disable_sample=disable_sample,
+                                                                              force_rnd_if_available=force_rnd_if_available)

         else:
-            action, prob = self.sac.choose_action(self['obs_list'],
-                                                   disable_sample=disable_sample,
-                                                   force_rnd_if_available=force_rnd_if_available)
+            action, prob = self.rl.choose_action(self['obs_list'],
+                                                  disable_sample=disable_sample,
+                                                  force_rnd_if_available=force_rnd_if_available)
             next_seq_hidden_state = None

         self['action'] = action
@@ -355,13 +355,14 @@ def train(self):
         # ep_obses_list, ep_actions, ep_rewards, next_obs_list, ep_dones, ep_probs,
         # ep_seq_hidden_states
         for episode_trans in self['episode_trans_list']:
-            self.sac.put_episode(*episode_trans)
+            self.rl.put_episode(*episode_trans)

-        trained_steps = self.sac.train()
+        trained_steps = self.rl.train()

         return trained_steps

-    def post_step(self, local_done):
+    def post_step(self, next_obs_list, local_done):
+        self['obs_list'] = next_obs_list
         self['pre_action'] = self['action']
         self['pre_action'][local_done] = self['initial_pre_action'][local_done]
         if self.seq_encoder is not None:
@@ -422,7 +423,7 @@ def reset(self):
     def burn_in_padding(self):
         for n, mgr in self:
             for a in [a for a in mgr.agents if a.is_empty()]:
-                for _ in range(mgr.sac.burn_in_step):
+                for _ in range(mgr.rl.burn_in_step):
                     a.add_transition(
                         obs_list=[np.zeros(t, dtype=np.float32) for t in mgr.obs_shapes],
                         action=mgr['initial_pre_action'][0],
@@ -464,14 +465,12 @@ def train(self, trained_steps: int):
         return trained_steps

     def post_step(self, ma_next_obs_list, ma_local_done):
-        self.set_obs_list(ma_next_obs_list)
-
         for n, mgr in self:
-            mgr.post_step(ma_local_done[n])
+            mgr.post_step(ma_next_obs_list[n], ma_local_done[n])

     def save_model(self, save_replay_buffer=False):
         for n, mgr in self:
-            mgr.sac.save_model(save_replay_buffer)
+            mgr.rl.save_model(save_replay_buffer)


 if __name__ == "__main__":
5 changes: 0 additions & 5 deletions algorithm/sac_base.py
@@ -521,11 +521,6 @@ def _init_or_restore(self, last_ckpt: int):
                 else:
                     model.eval()

-                # https://github.com/pytorch/pytorch/issues/80809
-                if 'opt' in name:
-                    for g in model.param_groups:
-                        g['capturable'] = True
-
             self._logger.info(f'Restored from {ckpt_restore_path}')

         if self.train_mode and self.use_replay_buffer:
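The block deleted above was a temporary workaround for pytorch/pytorch#80809, where calling step() on an Adam/AdamW optimizer restored from a checkpoint on PyTorch 1.12.0 failed with a capturable-related assertion; marking the restored param groups as capturable sidestepped the crash, and this commit drops it, presumably because the upstream fix (reportedly shipped in PyTorch 1.12.1) makes it unnecessary. A minimal standalone sketch of the same workaround, with a hypothetical checkpoint path, in case anyone is still pinned to the affected release:

import torch

model = torch.nn.Linear(4, 2).cuda()
opt = torch.optim.Adam(model.parameters())

# Restoring optimizer state is what triggered the assertion on PyTorch 1.12.0.
opt.load_state_dict(torch.load('ckpt_opt.pth'))  # hypothetical path

# The removed workaround: flag every param group as capturable before step() runs.
for g in opt.param_groups:
    g['capturable'] = True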
26 changes: 13 additions & 13 deletions algorithm/sac_main.py
@@ -184,19 +184,19 @@ def _init_sac(self, config_abs_dir: Path):
             spec.loader.exec_module(nn)
             mgr.config['sac_config']['nn'] = nn

-            mgr.set_sac(SAC_Base(obs_shapes=mgr.obs_shapes,
-                                 d_action_size=mgr.d_action_size,
-                                 c_action_size=mgr.c_action_size,
-                                 model_abs_dir=mgr.model_abs_dir,
-                                 device=self.device,
-                                 ma_name=None if len(self.ma_manager) == 1 else n,
-                                 train_mode=self.train_mode,
-                                 last_ckpt=self.last_ckpt,
+            mgr.set_rl(SAC_Base(obs_shapes=mgr.obs_shapes,
+                                d_action_size=mgr.d_action_size,
+                                c_action_size=mgr.c_action_size,
+                                model_abs_dir=mgr.model_abs_dir,
+                                device=self.device,
+                                ma_name=None if len(self.ma_manager) == 1 else n,
+                                train_mode=self.train_mode,
+                                last_ckpt=self.last_ckpt,

-                                 nn_config=mgr.config['nn_config'],
-                                 **mgr.config['sac_config'],
+                                nn_config=mgr.config['nn_config'],
+                                **mgr.config['sac_config'],

-                                 replay_config=mgr.config['replay_config']))
+                                replay_config=mgr.config['replay_config']))

     def _run(self):
         self.ma_manager.pre_run(self.base_config['n_agents'])
@@ -282,15 +282,15 @@ def _run(self):
     def _log_episode_summaries(self):
         for n, mgr in self.ma_manager:
             rewards = np.array([a.reward for a in mgr.agents])
-            mgr.sac.write_constant_summaries([
+            mgr.rl.write_constant_summaries([
                 {'tag': 'reward/mean', 'simple_value': rewards.mean()},
                 {'tag': 'reward/max', 'simple_value': rewards.max()},
                 {'tag': 'reward/min', 'simple_value': rewards.min()}
             ])

     def _log_episode_info(self, iteration, iter_time):
         for n, mgr in self.ma_manager:
-            global_step = format_global_step(mgr.sac.get_global_step())
+            global_step = format_global_step(mgr.rl.get_global_step())
             rewards = [a.reward for a in mgr.agents]
             rewards = ", ".join([f"{i:6.1f}" for i in rewards])
             max_step = max([a.steps for a in mgr.agents])
4 changes: 2 additions & 2 deletions algorithm/sac_main_hitted.py
@@ -65,7 +65,7 @@ def _log_episode_summaries(self):
             rewards = np.array([a.reward for a in mgr.agents])
             hitted = sum([a.hitted for a in mgr.agents])

-            mgr.sac.write_constant_summaries([
+            mgr.rl.write_constant_summaries([
                 {'tag': 'reward/mean', 'simple_value': rewards.mean()},
                 {'tag': 'reward/max', 'simple_value': rewards.max()},
                 {'tag': 'reward/min', 'simple_value': rewards.min()},
@@ -74,7 +74,7 @@ def _log_episode_summaries(self):

     def _log_episode_info(self, iteration, iter_time):
         for n, mgr in self.ma_manager:
-            global_step = format_global_step(mgr.sac.get_global_step())
+            global_step = format_global_step(mgr.rl.get_global_step())
             rewards = [a.reward for a in mgr.agents]
             rewards = ", ".join([f"{i:6.1f}" for i in rewards])
             hitted = sum([a.hitted for a in mgr.agents])
17 changes: 17 additions & 0 deletions envs/dm_control/cartpole/config.yaml
@@ -0,0 +1,17 @@
+default:
+  base_config:
+    env_type: DM_CONTROL
+    env_name: cartpole/swingup
+
+    max_step: -1
+    max_iter: 100
+    n_agents: 10
+
+  sac_config:
+    seed: 42
+
+    n_step: 10
+
+    tau: 0.9
+    v_lambda: 0.99
+    init_log_alpha: 0
13 changes: 13 additions & 0 deletions envs/dm_control/cartpole/nn.py
@@ -0,0 +1,13 @@
+import algorithm.nn_models as m
+
+ModelRep = m.ModelSimpleRep
+
+
+class ModelQ(m.ModelQ):
+    def _build_model(self):
+        super()._build_model(c_dense_n=64, c_dense_depth=2)
+
+
+class ModelPolicy(m.ModelPolicy):
+    def _build_model(self):
+        super()._build_model(c_dense_n=64, c_dense_depth=2)
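The new dm_control sample targets the cartpole/swingup task from dm_control's suite. A quick sanity-check sketch using the stock dm_control API (the repo's DM_CONTROL wrapper itself is not part of this diff, so how it maps the observation dict onto obs_shapes is assumed rather than shown):

import numpy as np
from dm_control import suite

# Load the task that env_name "cartpole/swingup" refers to.
env = suite.load(domain_name='cartpole', task_name='swingup')
action_spec = env.action_spec()

time_step = env.reset()
for _ in range(5):
    # Sample a uniform random action within the spec's bounds.
    action = np.random.uniform(action_spec.minimum, action_spec.maximum,
                               size=action_spec.shape)
    time_step = env.step(action)
    print(time_step.reward, {k: v.shape for k, v in time_step.observation.items()})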
17 changes: 17 additions & 0 deletions envs/gym/humanoid/config.yaml
@@ -0,0 +1,17 @@
+default:
+  base_config:
+    env_type: GYM
+    env_name: Humanoid-v3
+
+    max_step: -1
+    max_iter: 100
+    n_agents: 10
+
+  sac_config:
+    seed: 42
+
+    n_step: 10
+
+    tau: 0.9
+    v_lambda: 0.99
+    init_log_alpha: 0
13 changes: 13 additions & 0 deletions envs/gym/humanoid/nn.py
@@ -0,0 +1,13 @@
+import algorithm.nn_models as m
+
+ModelRep = m.ModelSimpleRep
+
+
+class ModelQ(m.ModelQ):
+    def _build_model(self):
+        super()._build_model(c_dense_n=64, c_dense_depth=2)
+
+
+class ModelPolicy(m.ModelPolicy):
+    def _build_model(self):
+        super()._build_model(c_dense_n=64, c_dense_depth=2)
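The matching Gym sample points env_name at MuJoCo's Humanoid-v3, apparently standing in for the deleted envs/mujoco configs. A sketch using the pre-0.26 Gym API that was current when this commit landed (Humanoid-v3 relies on the older mujoco-py bindings):

import gym

env = gym.make('Humanoid-v3')
print(env.observation_space.shape, env.action_space.shape)

obs = env.reset()
for _ in range(5):
    # Classic 4-tuple step API (gym < 0.26).
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
env.close()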
22 changes: 0 additions & 22 deletions envs/mujoco/half_cheetah/config.yaml

This file was deleted.

18 changes: 0 additions & 18 deletions envs/mujoco/half_cheetah/nn.py

This file was deleted.

72 changes: 0 additions & 72 deletions envs/mujoco/half_cheetah/nn_hard.py

This file was deleted.
