18 changes: 10 additions & 8 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -124,16 +124,16 @@ def sample_actions(
         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,
         all_log_probs: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         :param vec_obs: List of vector observations.
         :param vis_obs: List of visual observations.
         :param masks: Loss masks for RNN, else None.
         :param memories: Input memories when using RNN, else None.
         :param seq_len: Sequence length when using RNN.
         :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
-        :return: Tuple of actions, log probabilities (dependent on all_log_probs), entropies, and
-            output memories, all as Torch Tensors.
+        :return: Tuple of actions, actions clipped to -1, 1, log probabilities (dependent on all_log_probs),
+            entropies, and output memories, all as Torch Tensors.
         """
         if memories is None:
             dists, memories = self.actor_critic.get_dists(
@@ -155,8 +155,14 @@ def sample_actions(
             actions = actions[:, 0, :]
         # Use the sum of entropy across actions, not the mean
         entropy_sum = torch.sum(entropies, dim=1)
+
+        if self._clip_action and self.use_continuous_act:
+            clipped_action = torch.clamp(actions, -3, 3) / 3
+        else:
+            clipped_action = actions
         return (
             actions,
+            clipped_action,
             all_logs if all_log_probs else log_probs,
             entropy_sum,
             memories,
@@ -201,14 +207,10 @@ def evaluate(

         run_out = {}
         with torch.no_grad():
-            action, log_probs, entropy, memories = self.sample_actions(
+            action, clipped_action, log_probs, entropy, memories = self.sample_actions(
                 vec_obs, vis_obs, masks=masks, memories=memories
             )

-        if self._clip_action and self.use_continuous_act:
-            clipped_action = torch.clamp(action, -3, 3) / 3
-        else:
-            clipped_action = action
         run_out["pre_action"] = ModelUtils.to_numpy(action)
         run_out["action"] = ModelUtils.to_numpy(clipped_action)
         # Todo - make pre_action difference
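For reference, a runnable sketch of the clipping rule that sample_actions now applies; the input values below are made up for illustration:

import torch

# Pre-clip actions sampled from the policy can fall outside [-1, 1].
raw_actions = torch.tensor([[2.7, -4.1, 0.5]])

# Same rule as in sample_actions above: clamp to [-3, 3], then rescale into [-1, 1].
clipped_action = torch.clamp(raw_actions, -3, 3) / 3
print(clipped_action)  # tensor([[ 0.9000, -1.0000,  0.1667]])

Callers now unpack five values, e.g. action, clipped_action, log_probs, entropy, memories = self.sample_actions(...), as in the updated evaluate above.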
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/sac/optimizer_torch.py
@@ -464,7 +464,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         self.target_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
-        (sampled_actions, log_probs, _, _) = self.policy.sample_actions(
+        (sampled_actions, _, log_probs, _, _) = self.policy.sample_actions(
             vec_obs,
             vis_obs,
             masks=act_masks,
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
@@ -79,10 +79,10 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
     ).unsqueeze(0)

     with torch.no_grad():
-        _, log_probs1, _, _ = policy1.sample_actions(
+        _, _, log_probs1, _, _ = policy1.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
         )
-        _, log_probs2, _, _ = policy2.sample_actions(
+        _, _, log_probs2, _, _ = policy2.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
         )
12 changes: 11 additions & 1 deletion ml-agents/mlagents/trainers/tests/torch/test_policy.py
@@ -126,7 +126,13 @@ def test_sample_actions(rnn, visual, discrete):
     if len(memories) > 0:
         memories = torch.stack(memories).unsqueeze(0)

-    (sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
+    (
+        sampled_actions,
+        clipped_actions,
+        log_probs,
+        entropies,
+        memories,
+    ) = policy.sample_actions(
         vec_obs,
         vis_obs,
         masks=act_masks,
@@ -141,6 +147,10 @@ def test_sample_actions(rnn, visual, discrete):
         )
     else:
         assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size)
+        assert clipped_actions.shape == (
+            64,
+            policy.behavior_spec.action_spec.continuous_size,
+        )
     assert entropies.shape == (64,)

     if rnn:
12 changes: 9 additions & 3 deletions ml-agents/mlagents/trainers/torch/components/bc/module.py
@@ -62,7 +62,7 @@ def update(self) -> Dict[str, np.ndarray]:
         # Don't continue training if the learning rate has reached 0, to reduce training time.

         decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
-        if self.current_lr <= 0:
+        if self.current_lr <= 1e-10:  # Unlike in TF, this never actually reaches 0.
             return {"Losses/Pretraining Loss": 0}

         batch_losses = []
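A runnable illustration of the comparison above, using a made-up decay schedule: a float-valued learning rate can shrink toward zero without ever reaching it exactly, so a small epsilon serves as the stopping threshold.

current_lr = 3.0e-4
for _ in range(20):
    current_lr *= 0.1  # hypothetical multiplicative decay, for illustration only

print(current_lr <= 0)      # False: still a tiny positive float
print(current_lr <= 1e-10)  # True: treated as "reached zero"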
@@ -164,7 +164,13 @@ def _update_batch(
         else:
             vis_obs = []

-        selected_actions, all_log_probs, _, _ = self.policy.sample_actions(
+        (
+            selected_actions,
+            clipped_actions,
+            all_log_probs,
+            _,
+            _,
+        ) = self.policy.sample_actions(
             vec_obs,
             vis_obs,
             masks=act_masks,
@@ -173,7 +179,7 @@
             all_log_probs=True,
         )
         bc_loss = self._behavioral_cloning_loss(
-            selected_actions, all_log_probs, expert_actions
+            clipped_actions, all_log_probs, expert_actions
         )
         self.optimizer.zero_grad()
         bc_loss.backward()
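The behavioral cloning loss now receives the clipped actions, i.e. the values the environment would actually be sent. A minimal sketch, assuming continuous actions and a mean-squared-error comparison; continuous_bc_loss is a hypothetical stand-in for _behavioral_cloning_loss, not the actual implementation:

import torch
import torch.nn.functional as F

def continuous_bc_loss(clipped_actions: torch.Tensor, expert_actions: torch.Tensor) -> torch.Tensor:
    # Compare environment-ready (clipped) actions against recorded expert actions.
    return F.mse_loss(clipped_actions, expert_actions)

policy_actions = torch.tensor([[2.7, -4.1, 0.5]])
clipped = torch.clamp(policy_actions, -3, 3) / 3   # same clipping as sample_actions
expert = torch.tensor([[0.8, -1.0, 0.2]])          # demo actions assumed to lie in [-1, 1]
loss = continuous_bc_loss(clipped, expert)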