diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py
index e3d8cd0010..8ee3247557 100644
--- a/ml-agents/mlagents/trainers/policy/torch_policy.py
+++ b/ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -77,6 +77,7 @@ def __init__(
             conditional_sigma=self.condition_sigma_on_obs,
             tanh_squash=tanh_squash,
         )
+        self._clip_action = not tanh_squash
         # Save the m_size needed for export
         self._export_m_size = self.m_size
         # m_size needed for training is determined by network, not trainer settings
@@ -203,8 +204,13 @@ def evaluate(
         action, log_probs, entropy, memories = self.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories
         )
-        run_out["action"] = ModelUtils.to_numpy(action)
+
+        if self._clip_action and self.use_continuous_act:
+            clipped_action = torch.clamp(action, -3, 3) / 3
+        else:
+            clipped_action = action
         run_out["pre_action"] = ModelUtils.to_numpy(action)
+        run_out["action"] = ModelUtils.to_numpy(clipped_action)
         # Todo - make pre_action difference
         run_out["log_probs"] = ModelUtils.to_numpy(log_probs)
         run_out["entropy"] = ModelUtils.to_numpy(entropy)
diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
index 8660d820aa..7376a032c5 100644
--- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
+++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -136,7 +136,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
         act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
         if self.policy.use_continuous_act:
-            actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
+            actions = ModelUtils.list_to_tensor(batch["actions_pre"]).unsqueeze(-1)
         else:
             actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
 
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py b/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
index 16b9ced0c6..8c201a01d6 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
@@ -68,7 +68,7 @@ def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
     buffer = create_agent_buffer(behavior_spec, 5)
     curiosity_rp.update(buffer)
     reward_old = curiosity_rp.evaluate(buffer)[0]
-    for _ in range(10):
+    for _ in range(20):
         curiosity_rp.update(buffer)
     reward_new = curiosity_rp.evaluate(buffer)[0]
     assert reward_new < reward_old
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
index 62ce456bc5..af155ae36c 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
@@ -121,7 +121,7 @@ def test_recurrent_ppo(use_discrete):
         PPO_TORCH_CONFIG,
         hyperparameters=new_hyperparams,
         network_settings=new_network_settings,
-        max_steps=5000,
+        max_steps=6000,
     )
     check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)
 
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
index c52bfd1802..bab2738215 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
@@ -160,7 +160,7 @@ def test_get_probs_and_entropy():
         action_list, dist_list
     )
     assert log_probs.shape == (1, 2, 2)
-    assert entropies.shape == (1, 2, 2)
+    assert entropies.shape == (1, 1, 2)
     assert all_probs is None
 
     for log_prob in log_probs.flatten():
diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py
index 8540909f6d..3ad4274a90 100644
--- a/ml-agents/mlagents/trainers/torch/distributions.py
+++ b/ml-agents/mlagents/trainers/torch/distributions.py
@@ -66,7 +66,11 @@ def pdf(self, value):
         return torch.exp(log_prob)
 
     def entropy(self):
-        return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)
+        return torch.mean(
+            0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON),
+            dim=1,
+            keepdim=True,
+        )  # Use equivalent behavior to TF
 
 
 class TanhGaussianDistInstance(GaussianDistInstance):
@@ -131,7 +135,7 @@ def __init__(
             hidden_size,
             num_outputs,
             kernel_init=Initialization.KaimingHeNormal,
-            kernel_gain=0.1,
+            kernel_gain=0.2,
             bias_init=Initialization.Zero,
         )
         self.tanh_squash = tanh_squash
@@ -140,7 +144,7 @@ def __init__(
                 hidden_size,
                 num_outputs,
                 kernel_init=Initialization.KaimingHeNormal,
-                kernel_gain=0.1,
+                kernel_gain=0.2,
                 bias_init=Initialization.Zero,
             )
         else:
diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py
index 983d44b2fb..9a88fb9d63 100644
--- a/ml-agents/mlagents/trainers/torch/encoders.py
+++ b/ml-agents/mlagents/trainers/torch/encoders.py
@@ -133,7 +133,7 @@ def __init__(
                 self.final_flat,
                 self.h_size,
                 kernel_init=Initialization.KaimingHeNormal,
-                kernel_gain=1.0,
+                kernel_gain=1.41,  # Use ReLU gain
             ),
             nn.LeakyReLU(),
         )
@@ -165,7 +165,7 @@ def __init__(
                 self.final_flat,
                 self.h_size,
                 kernel_init=Initialization.KaimingHeNormal,
-                kernel_gain=1.0,
+                kernel_gain=1.41,  # Use ReLU gain
             ),
             nn.LeakyReLU(),
         )
@@ -200,7 +200,7 @@ def __init__(
                 self.final_flat,
                 self.h_size,
                 kernel_init=Initialization.KaimingHeNormal,
-                kernel_gain=1.0,
+                kernel_gain=1.41,  # Use ReLU gain
             ),
             nn.LeakyReLU(),
         )
@@ -251,7 +251,7 @@ def __init__(
             n_channels[-1] * height * width,
             output_size,
             kernel_init=Initialization.KaimingHeNormal,
-            kernel_gain=1.0,
+            kernel_gain=1.41,  # Use ReLU gain
         )
         self.sequential = nn.Sequential(*layers)
 
diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py
index ec1a73fa3a..c7be7db8ed 100644
--- a/ml-agents/mlagents/trainers/torch/layers.py
+++ b/ml-agents/mlagents/trainers/torch/layers.py
@@ -39,12 +39,18 @@ def linear_layer(
     :param output_size: The size of the output tensor
     :param kernel_init: The Initialization to use for the weights of the layer
    :param kernel_gain: The multiplier for the weights of the kernel. Note that in
-    TensorFlow, calling variance_scaling with scale 0.01 is equivalent to calling
-    KaimingHeNormal with kernel_gain of 0.1
+    TensorFlow, the gain is square-rooted. Therefore calling with scale 0.01 is equivalent to calling
+    KaimingHeNormal with kernel_gain of 0.1
     :param bias_init: The Initialization to use for the weights of the bias layer
     """
     layer = torch.nn.Linear(input_size, output_size)
-    _init_methods[kernel_init](layer.weight.data)
+    if (
+        kernel_init == Initialization.KaimingHeNormal
+        or kernel_init == Initialization.KaimingHeUniform
+    ):
+        _init_methods[kernel_init](layer.weight.data, nonlinearity="linear")
+    else:
+        _init_methods[kernel_init](layer.weight.data)
     layer.weight.data *= kernel_gain
     _init_methods[bias_init](layer.bias.data)
     return layer
diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py
index 6583e47ce4..f4c32f3a30 100644
--- a/ml-agents/mlagents/trainers/torch/networks.py
+++ b/ml-agents/mlagents/trainers/torch/networks.py
@@ -292,6 +292,9 @@ def __init__(
            self.distribution = MultiCategoricalDistribution(
                 self.encoding_size, self.action_spec.discrete_branches
             )
+        # During training, clipping is done in TorchPolicy, but we need to clip before ONNX
+        # export as well.
+        self._clip_action_on_export = not tanh_squash
 
     @property
     def memory_size(self) -> int:
@@ -339,6 +342,8 @@ def forward(
         if self.action_spec.is_continuous():
             action_list = self.sample_action(dists)
             action_out = torch.stack(action_list, dim=-1)
+            if self._clip_action_on_export:
+                action_out = torch.clamp(action_out, -3, 3) / 3
         else:
             action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
         return (
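For context, here is a minimal standalone sketch of the two behaviours these hunks introduce: clamping raw Gaussian actions to [-3, 3] and rescaling them into [-1, 1] when tanh squashing is off, and Kaiming He initialization with nonlinearity="linear" plus an explicit gain to mirror TF's variance_scaling. The helper names (clip_action, kaiming_linear_layer) and the sizes used are illustrative only, not part of the ML-Agents API.

```python
# Sketch only: illustrates the clipping and init behaviour from the diff above,
# outside the ML-Agents codebase. Names and sizes here are made up.
import torch


def clip_action(action: torch.Tensor) -> torch.Tensor:
    # Same expression as TorchPolicy.evaluate: clamp raw Gaussian samples to
    # [-3, 3], then rescale so the clipped range maps onto [-1, 1].
    return torch.clamp(action, -3, 3) / 3


def kaiming_linear_layer(
    input_size: int, output_size: int, kernel_gain: float
) -> torch.nn.Linear:
    # kaiming_normal_ with nonlinearity="linear" samples weights with
    # std = 1 / sqrt(fan_in); the explicit kernel_gain then plays the role
    # of sqrt(scale) in TF's variance_scaling initializer.
    layer = torch.nn.Linear(input_size, output_size)
    torch.nn.init.kaiming_normal_(layer.weight.data, nonlinearity="linear")
    layer.weight.data *= kernel_gain
    torch.nn.init.zeros_(layer.bias.data)
    return layer


if __name__ == "__main__":
    raw = torch.randn(4, 2) * 2          # stand-in for pre-clip continuous actions
    print(clip_action(raw))              # every value now lies in [-1, 1]
    layer = kaiming_linear_layer(128, 2, kernel_gain=0.2)
    print(layer.weight.std())            # roughly 0.2 / sqrt(128)
```

The division by 3 is what makes the clamp range [-3, 3] line up with the [-1, 1] action range the environment expects, which is why the optimizer is switched to train on the unclipped "actions_pre" buffer while the clipped values are what get sent to the environment and baked into the exported model.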