Minor PEP8 fixes in DDPG
boris-il-forte committed Nov 30, 2023
1 parent cb99e5c commit 503735f
Showing 1 changed file with 10 additions and 20 deletions.
mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py (10 additions, 20 deletions)
@@ -61,21 +61,15 @@ def __init__(self, mdp_info, policy_class, policy_params,
         self._replay_memory = ReplayMemory(initial_replay_size, max_replay_size)
 
         target_critic_params = deepcopy(critic_params)
-        self._critic_approximator = Regressor(TorchApproximator,
-                                              **critic_params)
-        self._target_critic_approximator = Regressor(TorchApproximator,
-                                                     **target_critic_params)
+        self._critic_approximator = Regressor(TorchApproximator, **critic_params)
+        self._target_critic_approximator = Regressor(TorchApproximator, **target_critic_params)
 
         target_actor_params = deepcopy(actor_params)
-        self._actor_approximator = Regressor(TorchApproximator,
-                                             **actor_params)
-        self._target_actor_approximator = Regressor(TorchApproximator,
-                                                    **target_actor_params)
+        self._actor_approximator = Regressor(TorchApproximator, **actor_params)
+        self._target_actor_approximator = Regressor(TorchApproximator, **target_actor_params)
 
-        self._init_target(self._critic_approximator,
-                          self._target_critic_approximator)
-        self._init_target(self._actor_approximator,
-                          self._target_actor_approximator)
+        self._init_target(self._critic_approximator, self._target_critic_approximator)
+        self._init_target(self._actor_approximator, self._target_actor_approximator)
 
         policy = policy_class(self._actor_approximator, **policy_params)
 
@@ -100,23 +94,19 @@ def __init__(self, mdp_info, policy_class, policy_params,
     def fit(self, dataset):
         self._replay_memory.add(dataset)
         if self._replay_memory.initialized:
-            state, action, reward, next_state, absorbing, _ =\
-                self._replay_memory.get(self._batch_size())
+            state, action, reward, next_state, absorbing, _ = self._replay_memory.get(self._batch_size())
 
             q_next = self._next_q(next_state, absorbing)
             q = reward + self.mdp_info.gamma * q_next
 
-            self._critic_approximator.fit(state, action, q,
-                                          **self._critic_fit_params)
+            self._critic_approximator.fit(state, action, q, **self._critic_fit_params)
 
             if self._fit_count % self._policy_delay() == 0:
                 loss = self._loss(state)
                 self._optimize_actor_parameters(loss)
 
-            self._update_target(self._critic_approximator,
-                                self._target_critic_approximator)
-            self._update_target(self._actor_approximator,
-                                self._target_actor_approximator)
+            self._update_target(self._critic_approximator, self._target_critic_approximator)
+            self._update_target(self._actor_approximator, self._target_actor_approximator)
 
             self._fit_count += 1
 
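For context on the fit hunk above: _next_q supplies the bootstrap target used in q = reward + gamma * q_next. A minimal sketch of DDPG-style bootstrapping, assuming the target approximators named in the diff and a 0/1 NumPy absorbing mask (illustrative only; the real method is defined elsewhere in ddpg.py and may differ in detail):

    def _next_q(self, next_state, absorbing):
        # The target policy proposes the next action...
        a = self._target_actor_approximator.predict(next_state)
        # ...and the target critic scores the resulting state-action pair.
        q = self._target_critic_approximator.predict(next_state, a)
        # Absorbing (terminal) states contribute no future value.
        q *= 1 - absorbing
        return q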

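Likewise, _init_target and _update_target are inherited helpers, not code touched by this commit. A rough sketch of the target-network bookkeeping they typically perform in DDPG, using hypothetical free-function names and a typical tau value, and assuming Regressor exposes flat get_weights()/set_weights() vectors (the library's actual implementation may differ):

    def init_target(online, target):
        # Hard copy: the target network starts as an exact clone of the online one.
        target.set_weights(online.get_weights())

    def update_target(online, target, tau=1e-3):
        # Polyak (soft) update: nudge the target slightly toward the online weights.
        w = tau * online.get_weights() + (1.0 - tau) * target.get_weights()
        target.set_weights(w)

Since every hunk in this diff only rejoins continuation lines, none of that behavior changes; the commit is purely cosmetic line-length cleanup.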
