## PPO-stable-baselines3 实现

In [1]:
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from gym.wrappers import TimeLimit
from stable_baselines3.common.env_checker import check_env
import gym

def run_train():
    """Train a PPO agent on ``Walker2d-v2`` and then roll the policy out.

    The environment is wrapped in a 1000-step ``TimeLimit`` and validated
    with ``check_env``. Training runs for 2e6 timesteps, writing TensorBoard
    logs under ``log/`` and saving a model snapshot every 51200 callback
    calls. Afterwards the policy is stepped for 1e6 environment steps and the
    accumulated reward is printed each time the environment is reset.
    """
    env = gym.make('Walker2d-v2')
    env = TimeLimit(env, max_episode_steps=int(1e3))
    check_env(env)

    # define callback function
    class TensorboardCallback(BaseCallback):
        """
        Callback that periodically saves a snapshot of the model being trained.
        """
        def __init__(self, log_dir, verbose=0):
            super().__init__(verbose)
            # Root directory under which the snapshots are written.
            self.log_dir = log_dir

        def _on_step(self) -> bool:
            # This is a periodic checkpoint keyed on the call counter; no
            # "best model" comparison is made despite the printed message.
            if self.n_calls % 51200 == 0:
                print("Saving new best model")
                self.model.save(self.log_dir + f"/model_saved/PPO/admit_diana_{self.n_calls}")
            # Returning True lets training continue.
            return True

    log_dir = "log/"

    model = PPO(
        policy="MlpPolicy",
        env=env,
        verbose=1,
        tensorboard_log=log_dir,
        device="cuda:0",
    )
    # model = A2C.load("./log/model_saved/admit_diana_51200.zip")
    model.learn(total_timesteps=int(2e6), callback=TensorboardCallback(log_dir=log_dir))
    # model.save("admit_diana")

    # Roll out the trained policy and report the reward collected per segment.
    obs = env.reset()
    ep_reward = 0
    for i in range(int(1e6)):
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        ep_reward += reward
        if i % 2000 == 0 or done:
            # Keep the observation returned by reset(); dropping it would make
            # the next predict() act on the stale observation from the segment
            # that just ended.
            obs = env.reset()
            print(ep_reward)
            ep_reward = 0


# Run training followed by the post-training rollout when this cell executes.
run_train()


Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to log/PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 18.6     |
|    ep_rew_mean     | -0.527   |
| time/              |          |
|    fps             | 457      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
--------------------------------------
| rollout/                |          |
|    ep_len_mean          | 19.8     |
|    ep_rew_mean          | 1.57     |
| time/                   |          |
|    fps                  | 384      |
|    iterations           | 2        |
|    time_elapsed         | 10       |
|    total_timesteps      | 4096     |
| train/                  |          |
|    approx_kl            | 0.01872  |
|    clip_fraction        | 0.262    |
|    clip_range           | 0.2      |
|    entropy_loss         | -8.51  



-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 205         |
|    ep_rew_mean          | 228         |
| time/                   |             |
|    fps                  | 332         |
|    iterations           | 26          |
|    time_elapsed         | 160         |
|    total_timesteps      | 53248       |
| train/                  |             |
|    approx_kl            | 0.010342469 |
|    clip_fraction        | 0.117       |
|    clip_range           | 0.2         |
|    entropy_loss         | -8.21       |
|    explained_variance   | 0.818       |
|    learning_rate        | 0.0003      |
|    loss                 | 138         |
|    n_updates            | 250         |
|    policy_gradient_loss | -0.0218     |
|    std                  | 0.95        |
|    value_loss           | 248         |
-----------------------------------------
-----------------------------------------
| rollout/                |       