Skip to content

Commit

Permalink
Updated SAC example
Browse files (browse the repository at this point in the history)
  • Loading branch information
boris-il-forte committed Dec 4, 2023
1 parent 6131c70 commit 8620449
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions examples/pendulum_sac.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,7 +8,7 @@
from mushroom_rl.algorithms.actor_critic import SAC
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments.gym_env import Gym
from mushroom_rl.utils.dataset import compute_J, parse_dataset
from mushroom_rl.utils import TorchUtils

from tqdm import trange

Expand Down Expand Up @@ -66,10 +66,10 @@ def forward(self, state):
return a


def experiment(alg, n_epochs, n_steps, n_steps_test):
def experiment(alg, n_epochs, n_steps, n_steps_test, save):
np.random.seed()

logger = Logger(alg.__name__, results_dir=None)
logger = Logger(alg.__name__, results_dir='./logs' if save else None)
logger.strong_line()
logger.info('Experiment Algorithm: ' + alg.__name__)

Expand Down Expand Up @@ -121,11 +121,10 @@ def experiment(alg, n_epochs, n_steps, n_steps_test):

# RUN
dataset = core.evaluate(n_steps=n_steps_test, render=False)
s, *_ = parse_dataset(dataset)

J = np.mean(compute_J(dataset, mdp.info.gamma))
R = np.mean(compute_J(dataset))
E = agent.policy.entropy(s)
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy(dataset.state)

logger.epoch_info(0, J=J, R=R, entropy=E)

Expand All @@ -134,23 +133,22 @@ def experiment(alg, n_epochs, n_steps, n_steps_test):
for n in trange(n_epochs, leave=False):
core.learn(n_steps=n_steps, n_steps_per_fit=1)
dataset = core.evaluate(n_steps=n_steps_test, render=False)
s, *_ = parse_dataset(dataset)

J = np.mean(compute_J(dataset, mdp.info.gamma))
R = np.mean(compute_J(dataset))
E = agent.policy.entropy(s)
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy(dataset.state)

logger.epoch_info(n+1, J=J, R=R, entropy=E)

if save:
logger.log_best_agent(agent, J)

logger.info('Press a button to visualize pendulum')
input()
core.evaluate(n_episodes=5, render=True)


if __name__ == '__main__':
algs = [
SAC
]

for alg in algs:
experiment(alg=alg, n_epochs=40, n_steps=1000, n_steps_test=2000)
save = False
TorchUtils.set_default_device('cpu')
experiment(alg=SAC, n_epochs=40, n_steps=1000, n_steps_test=2000, save=save)

0 comments on commit 8620449

Please sign in to comment.