diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py
index c19efa30..59719dba 100644
--- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py
+++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py
@@ -61,21 +61,15 @@ def __init__(self, mdp_info, policy_class, policy_params,
         self._replay_memory = ReplayMemory(initial_replay_size, max_replay_size)
 
         target_critic_params = deepcopy(critic_params)
-        self._critic_approximator = Regressor(TorchApproximator,
-                                              **critic_params)
-        self._target_critic_approximator = Regressor(TorchApproximator,
-                                                     **target_critic_params)
+        self._critic_approximator = Regressor(TorchApproximator, **critic_params)
+        self._target_critic_approximator = Regressor(TorchApproximator, **target_critic_params)
 
         target_actor_params = deepcopy(actor_params)
-        self._actor_approximator = Regressor(TorchApproximator,
-                                             **actor_params)
-        self._target_actor_approximator = Regressor(TorchApproximator,
-                                                    **target_actor_params)
+        self._actor_approximator = Regressor(TorchApproximator, **actor_params)
+        self._target_actor_approximator = Regressor(TorchApproximator, **target_actor_params)
 
-        self._init_target(self._critic_approximator,
-                          self._target_critic_approximator)
-        self._init_target(self._actor_approximator,
-                          self._target_actor_approximator)
+        self._init_target(self._critic_approximator, self._target_critic_approximator)
+        self._init_target(self._actor_approximator, self._target_actor_approximator)
 
         policy = policy_class(self._actor_approximator, **policy_params)
 
@@ -100,23 +94,19 @@ def __init__(self, mdp_info, policy_class, policy_params,
 
     def fit(self, dataset):
         self._replay_memory.add(dataset)
         if self._replay_memory.initialized:
-            state, action, reward, next_state, absorbing, _ =\
-                self._replay_memory.get(self._batch_size())
+            state, action, reward, next_state, absorbing, _ = self._replay_memory.get(self._batch_size())
 
             q_next = self._next_q(next_state, absorbing)
             q = reward + self.mdp_info.gamma * q_next
 
-            self._critic_approximator.fit(state, action, q,
-                                          **self._critic_fit_params)
+            self._critic_approximator.fit(state, action, q, **self._critic_fit_params)
 
             if self._fit_count % self._policy_delay() == 0:
                 loss = self._loss(state)
                 self._optimize_actor_parameters(loss)
 
-            self._update_target(self._critic_approximator,
-                                self._target_critic_approximator)
-            self._update_target(self._actor_approximator,
-                                self._target_actor_approximator)
+            self._update_target(self._critic_approximator, self._target_critic_approximator)
+            self._update_target(self._actor_approximator, self._target_actor_approximator)
 
             self._fit_count += 1
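
Note (not part of the patch): the change only joins continuation lines onto single lines; DDPG's update logic is untouched. The reflowed `fit` lines compute the TD target `q = reward + gamma * q_next`, where `q_next` is produced by `_next_q` from the target networks. As a reading aid, here is a minimal sketch of that backup, assuming `_next_q` follows the standard DDPG form (the actual mushroom_rl helper may differ in detail):

    def _next_q(self, next_state, absorbing):
        # Sketch only: target actor proposes the next action,
        # target critic evaluates it (standard DDPG backup).
        a = self._target_actor_approximator.predict(next_state)
        q = self._target_critic_approximator.predict(next_state, a)

        # Absorbing (terminal) states contribute no bootstrapped value.
        q *= 1 - absorbing

        return q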