From 26b0b231ec68c2f9470d943c0172110e5d435a04 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 6 Aug 2020 15:13:45 -0700 Subject: [PATCH 1/6] Decay for PPO --- .../mlagents/trainers/ppo/optimizer_torch.py | 50 +++++++++++---- ml-agents/mlagents/trainers/torch/utils.py | 62 ++++++++++++++++++- 2 files changed, 100 insertions(+), 12 deletions(-) diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 0c92f35d8b..9012ff8c08 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -26,6 +26,7 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): self.hyperparameters: PPOSettings = cast( PPOSettings, trainer_settings.hyperparameters ) + self.decay_schedule = self.hyperparameters.learning_rate_schedule self.optimizer = torch.optim.Adam( params, lr=self.trainer_settings.hyperparameters.learning_rate @@ -37,22 +38,25 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): self.stream_names = list(self.reward_signals.keys()) - def ppo_value_loss(self, values, old_values, returns): + def ppo_value_loss( + self, + values: Dict[str, torch.Tensor], + old_values: Dict[str, torch.Tensor], + returns: Dict[str, torch.Tensor], + epsilon: float, + ) -> torch.Tensor: """ Creates training-specific Tensorflow ops for PPO models. :param returns: :param old_values: :param values: """ - - decay_epsilon = self.hyperparameters.epsilon - value_losses = [] for name, head in values.items(): old_val_tensor = old_values[name] returns_tensor = returns[name] clipped_value_estimate = old_val_tensor + torch.clamp( - head - old_val_tensor, -decay_epsilon, decay_epsilon + head - old_val_tensor, -1 * epsilon, epsilon ) v_opt_a = (returns_tensor - head) ** 2 v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 @@ -89,6 +93,28 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: :param num_sequences: Number of sequences to process. :return: Results of update. 
""" + # Get decayed parameters + decay_learning_rate = ModelUtils.get_decayed_parameter( + self.decay_schedule, + self.hyperparameters.learning_rate, + 1e-10, + self.trainer_settings.max_steps, + self.policy.get_current_step(), + ) + decay_epsilon = ModelUtils.get_decayed_parameter( + self.decay_schedule, + self.hyperparameters.beta, + 0.1, + self.trainer_settings.max_steps, + self.policy.get_current_step(), + ) + decay_beta = ModelUtils.get_decayed_parameter( + self.decay_schedule, + self.hyperparameters.beta, + 1e-5, + self.trainer_settings.max_steps, + self.policy.get_current_step(), + ) returns = {} old_values = {} for name in self.reward_signals: @@ -128,18 +154,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) - value_loss = self.ppo_value_loss(values, old_values, returns) + value_loss = self.ppo_value_loss(values, old_values, returns, decay_epsilon) policy_loss = self.ppo_policy_loss( ModelUtils.list_to_tensor(batch["advantages"]), log_probs, ModelUtils.list_to_tensor(batch["action_probs"]), ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32), ) - loss = ( - policy_loss - + 0.5 * value_loss - - self.hyperparameters.beta * torch.mean(entropy) - ) + loss = policy_loss + 0.5 * value_loss - decay_beta * torch.mean(entropy) + + # Set optimizer learning rate + ModelUtils.apply_learning_rate(self.optimizer, decay_learning_rate) self.optimizer.zero_grad() loss.backward() @@ -147,6 +172,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: update_stats = { "Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), "Losses/Value Loss": value_loss.detach().cpu().numpy(), + "Policy/Learning Rate": decay_learning_rate, + "Policy/Epsilon": decay_epsilon, + "Policy/Beta": decay_beta, } return update_stats diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index 600c7eb8d9..e284001aef 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -10,7 +10,7 @@ VectorEncoder, VectorAndUnnormalizedInputEncoder, ) -from mlagents.trainers.settings import EncoderType +from mlagents.trainers.settings import EncoderType, ScheduleType from mlagents.trainers.exception import UnityTrainerException from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance @@ -29,6 +29,66 @@ def swish(input_activation: torch.Tensor) -> torch.Tensor: """Swish activation function. For more info: https://arxiv.org/abs/1710.05941""" return torch.mul(input_activation, torch.sigmoid(input_activation)) + @staticmethod + def apply_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None: + """ + Apply a learning rate to a torch optimizer. + :param optim: Optimizer + :param lr: Learning rate + """ + for param_group in optim.param_groups: + param_group["lr"] = lr + + @staticmethod + def get_decayed_parameter( + schedule: ScheduleType, + initial_value: float, + min_value: float, + max_step: int, + global_step: int, + ) -> float: + """ + Get the value of a parameter that should be decayed, assuming it is a function of + global_step. + :param schedule: Type of learning rate schedule. + :param initial_value: Initial value before decay. + :param min_value: Decay value to this value by max_step. + :param max_step: The final step count where the return value should equal min_value. + :param global_step: The current step count. + :return: The value. 
+ """ + if schedule == ScheduleType.CONSTANT: + return initial_value + elif schedule == ScheduleType.LINEAR: + return ModelUtils.polynomial_decay( + initial_value, min_value, max_step, global_step + ) + else: + raise UnityTrainerException(f"The schedule {schedule} is invalid.") + + @staticmethod + def polynomial_decay( + initial_value: float, + min_value: float, + max_step: int, + global_step: int, + power: float = 1.0, + ) -> float: + """ + Get a decayed value based on a polynomial schedule, with respect to the current global step. + :param initial_value: Initial value before decay. + :param min_value: Decay value to this value by max_step. + :param max_step: The final step count where the return value should equal min_value. + :param global_step: The current step count. + :param power: Power of polynomial decay. 1.0 (default) is a linear decay. + :return: The current decayed value. + """ + global_step = min(global_step, max_step) + decayed_value = (initial_value - min_value) * ( + 1 - float(global_step) / max_step + ) ** (power) + min_value + return decayed_value + @staticmethod def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module: ENCODER_FUNCTION_BY_TYPE = { From 42506bce6d38c77090f4f2537bdb3d2de05e8ac5 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 6 Aug 2020 15:22:04 -0700 Subject: [PATCH 2/6] Decay for SAC --- .../mlagents/trainers/sac/optimizer_torch.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index b5653a9f65..20a54015a0 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -65,7 +65,7 @@ def forward( def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): super().__init__(policy, trainer_params) hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) - lr = hyperparameters.learning_rate + self.initial_lr = hyperparameters.learning_rate # lr_schedule = hyperparameters.learning_rate_schedule # max_step = trainer_params.max_steps self.tau = hyperparameters.tau @@ -137,9 +137,12 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): for param in policy_params: logger.debug(param.shape) - self.policy_optimizer = torch.optim.Adam(policy_params, lr=lr) - self.value_optimizer = torch.optim.Adam(value_params, lr=lr) - self.entropy_optimizer = torch.optim.Adam([self._log_ent_coef], lr=lr) + self.decay_schedule = hyperparameters.learning_rate_schedule + self.policy_optimizer = torch.optim.Adam(policy_params, lr=self.initial_lr) + self.value_optimizer = torch.optim.Adam(value_params, lr=self.initial_lr) + self.entropy_optimizer = torch.optim.Adam( + [self._log_ent_coef], lr=self.initial_lr + ) def sac_q_loss( self, @@ -351,6 +354,14 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: indexed by name. If none, don't update the reward signals. :return: Output from update process. 
""" + decay_learning_rate = ModelUtils.get_decayed_parameter( + self.decay_schedule, + self.initial_lr, + 1e-10, + self.trainer_settings.max_steps, + self.policy.get_current_step(), + ) + rewards = {} for name in self.reward_signals: rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"]) @@ -436,14 +447,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: total_value_loss = q1_loss + q2_loss + value_loss + ModelUtils.apply_learning_rate(self.policy_optimizer, decay_learning_rate) self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() + ModelUtils.apply_learning_rate(self.value_optimizer, decay_learning_rate) self.value_optimizer.zero_grad() total_value_loss.backward() self.value_optimizer.step() + ModelUtils.apply_learning_rate(self.entropy_optimizer, decay_learning_rate) self.entropy_optimizer.zero_grad() entropy_loss.backward() self.entropy_optimizer.step() @@ -459,6 +473,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: .detach() .cpu() .numpy(), + "Policy/Learning Rate": decay_learning_rate, } return update_stats From 53bd2550f5878a2bc92ebffc0696e2504948722c Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 6 Aug 2020 15:25:20 -0700 Subject: [PATCH 3/6] Fix issue with epsilon --- ml-agents/mlagents/trainers/ppo/optimizer_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 9012ff8c08..fbaa83f2b4 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -103,7 +103,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: ) decay_epsilon = ModelUtils.get_decayed_parameter( self.decay_schedule, - self.hyperparameters.beta, + self.hyperparameters.epsilon, 0.1, self.trainer_settings.max_steps, self.policy.get_current_step(), From 7d7a401bba55184b3d5621393445deb4f6267a12 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Thu, 6 Aug 2020 15:45:35 -0700 Subject: [PATCH 4/6] Tests for decay functions --- .../trainers/tests/torch/test_utils.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py index 25c7a6c05e..a8ea62b7bd 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py @@ -2,7 +2,7 @@ import torch import numpy as np -from mlagents.trainers.settings import EncoderType +from mlagents.trainers.settings import EncoderType, ScheduleType from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.exception import UnityTrainerException from mlagents.trainers.torch.encoders import ( @@ -77,6 +77,40 @@ def test_create_encoders( assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type)) +def test_get_decayed_parameter(): + test_steps = [0, 4, 9] + # Test constant decay + for _step in test_steps: + _param = ModelUtils.get_decayed_parameter( + ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1], _step + ) + assert _param == 1.0 + + test_results = [1.0, 0.6444, 0.2] + # Test linear decay + for _step, _result in zip(test_steps, test_results): + _param = ModelUtils.get_decayed_parameter( + ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1], _step + ) + assert _param == pytest.approx(_result, abs=0.01) + + # Test invalid + with 
pytest.raises(UnityTrainerException): + ModelUtils.get_decayed_parameter( + "SomeOtherSchedule", 1.0, 0.2, test_steps[-1], _step + ) + + +def test_polynomial_decay(): + test_steps = [0, 4, 9] + test_results = [1.0, 0.7, 0.2] + for _step, _result in zip(test_steps, test_results): + decayed = ModelUtils.polynomial_decay( + 1.0, 0.2, test_steps[-1], _step, power=0.8 + ) + assert decayed == pytest.approx(_result, abs=0.01) + + def test_list_to_tensor(): # Test converting pure list unconverted_list = [[1, 2], [1, 3], [1, 4]] From 9a5ef0eee8fa88f55c155165771965cfb96b2b4e Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Fri, 7 Aug 2020 11:08:18 -0700 Subject: [PATCH 5/6] Address comments --- .../mlagents/trainers/ppo/optimizer_torch.py | 55 ++++++++-------- .../mlagents/trainers/sac/optimizer_torch.py | 40 +++++------- ml-agents/mlagents/trainers/torch/utils.py | 65 +++++++++++-------- 3 files changed, 83 insertions(+), 77 deletions(-) diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index fbaa83f2b4..ec09c74af5 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -26,7 +26,24 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): self.hyperparameters: PPOSettings = cast( PPOSettings, trainer_settings.hyperparameters ) - self.decay_schedule = self.hyperparameters.learning_rate_schedule + self.decay_learning_rate = ModelUtils.DecayedValue( + self.hyperparameters.learning_rate_schedule, + self.hyperparameters.learning_rate, + 1e-10, + self.trainer_settings.max_steps, + ) + self.decay_epsilon = ModelUtils.DecayedValue( + self.hyperparameters.learning_rate_schedule, + self.hyperparameters.epsilon, + 0.1, + self.trainer_settings.max_steps, + ) + self.decay_beta = ModelUtils.DecayedValue( + self.hyperparameters.learning_rate_schedule, + self.hyperparameters.beta, + 1e-5, + self.trainer_settings.max_steps, + ) self.optimizer = torch.optim.Adam( params, lr=self.trainer_settings.hyperparameters.learning_rate @@ -94,27 +111,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: :return: Results of update. 
""" # Get decayed parameters - decay_learning_rate = ModelUtils.get_decayed_parameter( - self.decay_schedule, - self.hyperparameters.learning_rate, - 1e-10, - self.trainer_settings.max_steps, - self.policy.get_current_step(), - ) - decay_epsilon = ModelUtils.get_decayed_parameter( - self.decay_schedule, - self.hyperparameters.epsilon, - 0.1, - self.trainer_settings.max_steps, - self.policy.get_current_step(), - ) - decay_beta = ModelUtils.get_decayed_parameter( - self.decay_schedule, - self.hyperparameters.beta, - 1e-5, - self.trainer_settings.max_steps, - self.policy.get_current_step(), - ) + decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) + decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step()) + decay_bet = self.decay_beta.get_value(self.policy.get_current_step()) returns = {} old_values = {} for name in self.reward_signals: @@ -154,17 +153,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) - value_loss = self.ppo_value_loss(values, old_values, returns, decay_epsilon) + value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps) policy_loss = self.ppo_policy_loss( ModelUtils.list_to_tensor(batch["advantages"]), log_probs, ModelUtils.list_to_tensor(batch["action_probs"]), ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32), ) - loss = policy_loss + 0.5 * value_loss - decay_beta * torch.mean(entropy) + loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy) # Set optimizer learning rate - ModelUtils.apply_learning_rate(self.optimizer, decay_learning_rate) + ModelUtils.update_learning_rate(self.optimizer, decay_lr) self.optimizer.zero_grad() loss.backward() @@ -172,9 +171,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: update_stats = { "Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), "Losses/Value Loss": value_loss.detach().cpu().numpy(), - "Policy/Learning Rate": decay_learning_rate, - "Policy/Epsilon": decay_epsilon, - "Policy/Beta": decay_beta, + "Policy/Learning Rate": decay_lr, + "Policy/Epsilon": decay_eps, + "Policy/Beta": decay_bet, } return update_stats diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 20a54015a0..85dc2e218d 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -65,18 +65,12 @@ def forward( def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): super().__init__(policy, trainer_params) hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) - self.initial_lr = hyperparameters.learning_rate - # lr_schedule = hyperparameters.learning_rate_schedule - # max_step = trainer_params.max_steps self.tau = hyperparameters.tau self.init_entcoef = hyperparameters.init_entcoef self.policy = policy self.act_size = policy.act_size policy_network_settings = policy.network_settings - # h_size = policy_network_settings.hidden_units - # num_layers = policy_network_settings.num_layers - # vis_encode_type = policy_network_settings.vis_encode_type self.tau = hyperparameters.tau self.burn_in_ratio = 0.0 @@ -137,11 +131,20 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): for param in policy_params: logger.debug(param.shape) - self.decay_schedule = hyperparameters.learning_rate_schedule - self.policy_optimizer = torch.optim.Adam(policy_params, 
lr=self.initial_lr) - self.value_optimizer = torch.optim.Adam(value_params, lr=self.initial_lr) + self.decay_learning_rate = ModelUtils.DecayedValue( + hyperparameters.learning_rate_schedule, + hyperparameters.learning_rate, + 1e-10, + self.trainer_settings.max_steps, + ) + self.policy_optimizer = torch.optim.Adam( + policy_params, lr=hyperparameters.learning_rate + ) + self.value_optimizer = torch.optim.Adam( + value_params, lr=hyperparameters.learning_rate + ) self.entropy_optimizer = torch.optim.Adam( - [self._log_ent_coef], lr=self.initial_lr + [self._log_ent_coef], lr=hyperparameters.learning_rate ) def sac_q_loss( @@ -354,14 +357,6 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: indexed by name. If none, don't update the reward signals. :return: Output from update process. """ - decay_learning_rate = ModelUtils.get_decayed_parameter( - self.decay_schedule, - self.initial_lr, - 1e-10, - self.trainer_settings.max_steps, - self.policy.get_current_step(), - ) - rewards = {} for name in self.reward_signals: rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"]) @@ -447,17 +442,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: total_value_loss = q1_loss + q2_loss + value_loss - ModelUtils.apply_learning_rate(self.policy_optimizer, decay_learning_rate) + decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) + ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr) self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() - ModelUtils.apply_learning_rate(self.value_optimizer, decay_learning_rate) + ModelUtils.update_learning_rate(self.value_optimizer, decay_lr) self.value_optimizer.zero_grad() total_value_loss.backward() self.value_optimizer.step() - ModelUtils.apply_learning_rate(self.entropy_optimizer, decay_learning_rate) + ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr) self.entropy_optimizer.zero_grad() entropy_loss.backward() self.entropy_optimizer.step() @@ -473,7 +469,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: .detach() .cpu() .numpy(), - "Policy/Learning Rate": decay_learning_rate, + "Policy/Learning Rate": decay_lr, } return update_stats diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index ddfdc97940..0644defec9 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -30,7 +30,7 @@ def swish(input_activation: torch.Tensor) -> torch.Tensor: return torch.mul(input_activation, torch.sigmoid(input_activation)) @staticmethod - def apply_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None: + def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None: """ Apply a learning rate to a torch optimizer. :param optim: Optimizer @@ -39,32 +39,43 @@ def apply_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None: for param_group in optim.param_groups: param_group["lr"] = lr - @staticmethod - def get_decayed_parameter( - schedule: ScheduleType, - initial_value: float, - min_value: float, - max_step: int, - global_step: int, - ) -> float: - """ - Get the value of a parameter that should be decayed, assuming it is a function of - global_step. - :param schedule: Type of learning rate schedule. - :param initial_value: Initial value before decay. - :param min_value: Decay value to this value by max_step. 
- :param max_step: The final step count where the return value should equal min_value. - :param global_step: The current step count. - :return: The value. - """ - if schedule == ScheduleType.CONSTANT: - return initial_value - elif schedule == ScheduleType.LINEAR: - return ModelUtils.polynomial_decay( - initial_value, min_value, max_step, global_step - ) - else: - raise UnityTrainerException(f"The schedule {schedule} is invalid.") + class DecayedValue: + def __init__( + self, + schedule: ScheduleType, + initial_value: float, + min_value: float, + max_step: int, + ): + """ + Object that represnets value of a parameter that should be decayed, assuming it is a function of + global_step. + :param schedule: Type of learning rate schedule. + :param initial_value: Initial value before decay. + :param min_value: Decay value to this value by max_step. + :param max_step: The final step count where the return value should equal min_value. + :param global_step: The current step count. + :return: The value. + """ + self.schedule = schedule + self.initial_value = initial_value + self.min_value = min_value + self.max_step = max_step + + def get_value(self, global_step: int) -> float: + """ + Get the value at a given global step. + :param global_step: Step count. + :returns: Decayed value at this global step. + """ + if self.schedule == ScheduleType.CONSTANT: + return self.initial_value + elif self.schedule == ScheduleType.LINEAR: + return ModelUtils.polynomial_decay( + self.initial_value, self.min_value, self.max_step, global_step + ) + else: + raise UnityTrainerException(f"The schedule {self.schedule} is invalid.") @staticmethod def polynomial_decay( From 7cbb205e0abd9411f31965f9e6b5ac11debb57d5 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Fri, 7 Aug 2020 11:10:49 -0700 Subject: [PATCH 6/6] Update tests --- .../trainers/tests/torch/test_utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py index e34d1359a5..f6f286c84a 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py @@ -79,28 +79,26 @@ def test_create_encoders( assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type)) -def test_get_decayed_parameter(): +def test_decayed_value(): test_steps = [0, 4, 9] # Test constant decay + param = ModelUtils.DecayedValue(ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1]) for _step in test_steps: - _param = ModelUtils.get_decayed_parameter( - ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1], _step - ) + _param = param.get_value(_step) assert _param == 1.0 test_results = [1.0, 0.6444, 0.2] # Test linear decay + param = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1]) for _step, _result in zip(test_steps, test_results): - _param = ModelUtils.get_decayed_parameter( - ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1], _step - ) + _param = param.get_value(_step) assert _param == pytest.approx(_result, abs=0.01) # Test invalid with pytest.raises(UnityTrainerException): - ModelUtils.get_decayed_parameter( - "SomeOtherSchedule", 1.0, 0.2, test_steps[-1], _step - ) + ModelUtils.DecayedValue( + "SomeOtherSchedule", 1.0, 0.2, test_steps[-1] + ).get_value(0) def test_polynomial_decay():
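
For reference, a minimal usage sketch of the schedule helper introduced in PATCH 5/6: the series replaces the free function ModelUtils.get_decayed_parameter with a ModelUtils.DecayedValue object that is constructed once per optimizer and queried each update via get_value(). This is only a sketch, assuming a checkout with these patches applied and the import paths shown in the diffs; the expected numbers mirror the linear case in test_decayed_value.

    from mlagents.trainers.settings import ScheduleType
    from mlagents.trainers.torch.utils import ModelUtils

    # Linear decay from 1.0 down to 0.2 over 9 steps, as in test_decayed_value.
    lr_schedule = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1.0, 0.2, 9)
    for step in (0, 4, 9):
        # Prints roughly 1.0, 0.644, 0.2 (linear interpolation toward min_value).
        print(step, lr_schedule.get_value(step))

    # ScheduleType.CONSTANT always returns the initial value, matching the CONSTANT test case.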