Skip to content

Commit

Permalink
Fix in examples
Browse files Browse the repository at this point in the history
- Pendulum trust region example is now the same for vectorized and normal
core environments
- Vectorized pendulum still seems to be broken...
  • Loading branch information
boris-il-forte committed Jan 18, 2024
1 parent 1d2f04a commit 867fd3d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 17 deletions.
15 changes: 6 additions & 9 deletions examples/pendulum_trust_region.py
Expand Up @@ -8,7 +8,7 @@

from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import Gym
from mushroom_rl.algorithms.actor_critic import TRPO, PPO
from mushroom_rl.algorithms.actor_critic import PPO, TRPO

from mushroom_rl.policy import GaussianTorchPolicy

Expand All @@ -24,12 +24,9 @@ def __init__(self, input_shape, output_shape, n_features, **kwargs):
self._h2 = nn.Linear(n_features, n_features)
self._h3 = nn.Linear(n_features, n_output)

nn.init.xavier_uniform_(self._h1.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h2.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h3.weight,
gain=nn.init.calculate_gain('linear'))
nn.init.xavier_uniform_(self._h1.weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h2.weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h3.weight, gain=nn.init.calculate_gain('linear'))

def forward(self, state, **kwargs):
features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
Expand Down Expand Up @@ -116,8 +113,8 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
cg_residual_tol=1e-10)

algs_params = [
(TRPO, 'trpo', trpo_params),
(PPO, 'ppo', ppo_params)
(PPO, 'ppo', ppo_params),
(TRPO, 'trpo', trpo_params)
]

for alg, alg_name, alg_params in algs_params:
Expand Down
Expand Up @@ -8,7 +8,7 @@

from mushroom_rl.core import VectorCore, Logger, MultiprocessEnvironment
from mushroom_rl.environments import Gym
from mushroom_rl.algorithms.actor_critic import PPO
from mushroom_rl.algorithms.actor_critic import PPO, TRPO

from mushroom_rl.policy import GaussianTorchPolicy

Expand All @@ -27,12 +27,9 @@ def __init__(self, input_shape, output_shape, n_features, **kwargs):
self._h2 = nn.Linear(n_features, n_features)
self._h3 = nn.Linear(n_features, n_output)

nn.init.xavier_uniform_(self._h1.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h2.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h3.weight,
gain=nn.init.calculate_gain('linear'))
nn.init.xavier_uniform_(self._h1.weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h2.weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h3.weight, gain=nn.init.calculate_gain('linear'))

def forward(self, state, **kwargs):
features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
Expand Down Expand Up @@ -119,7 +116,8 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
cg_residual_tol=1e-10)

algs_params = [
(PPO, 'ppo', ppo_params)
(PPO, 'ppo', ppo_params),
(TRPO, 'trpo', trpo_params)
]

for alg, alg_name, alg_params in algs_params:
Expand Down

0 comments on commit 867fd3d

Please sign in to comment.