In [1]:
# SPDX-FileCopyrightText: Copyright (c) 2022 Guillaume Bellegarda. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2022 EPFL, Guillaume Bellegarda

"""
Train a Stable-Baselines3 agent (PPO or SAC) on the quadruped gym environment.
Check the documentation! https://stable-baselines3.readthedocs.io/en/master/
"""
import os
from datetime import datetime
# stable baselines 3
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO, SAC
from stable_baselines3.common.env_util import make_vec_env
# utils
from utils.utils import CheckpointCallback
from utils.file_utils import get_latest_model
# gym environment
from env.quadruped_gym_env import QuadrupedGymEnv


# ---------------------------------------------------------------------------
# Training options (defaults)
# ---------------------------------------------------------------------------
LEARNING_ALG = "PPO"  # algorithm to train with: "PPO" or "SAC"
LOAD_NN = False       # True: initialize training from a previously saved model
NUM_ENVS = 1          # number of pybullet environments created for data collection
USE_GPU = False       # only enable once all necessary drivers are installed

# Overrides for this run; they take precedence over the defaults above.
LEARNING_ALG = "SAC"
USE_GPU = True

# After implementing, you will want to test how well the agent learns with your MDP:
# env_configs = {"motor_control_mode":"CPG",
#                "task_env": "FWD_LOCOMOTION", #  "LR_COURSE_TASK",
#                "observation_space_mode": "LR_COURSE_OBS"}
# Keyword arguments for the environment constructor (empty: use its defaults).
env_configs = dict()

# Device string placed in the algorithm configs: "auto" only when the GPU is
# requested together with SAC, "cpu" in every other case.
gpu_arg = "auto" if (USE_GPU and LEARNING_ALG == "SAC") else "cpu"

if LOAD_NN:
    # Locate the previous run to resume from: its normalization statistics
    # and the model file chosen by get_latest_model (project helper;
    # presumably the most recent checkpoint in log_dir -- TODO confirm).
    interm_dir = "./logs/intermediate_models/"
    log_dir = interm_dir + '' # add path
    stats_path = os.path.join(log_dir, "vec_normalize.pkl")
    model_name = get_latest_model(log_dir)

# directory to save policies and normalization parameters
# (one timestamped sub-directory per run)
SAVE_PATH = './logs/intermediate_models/'+ datetime.now().strftime("%m%d%y%H%M%S") + '/'
os.makedirs(SAVE_PATH, exist_ok=True)
# checkpoint to save policy network periodically
checkpoint_callback = CheckpointCallback(save_freq=30000, save_path=SAVE_PATH,name_prefix='rl_model', verbose=2)

# Create the vectorized gym environment exactly once. Previously, resuming
# (LOAD_NN) built a second set of environments and dropped the first one
# without closing it, leaving stray simulator instances and re-opening the
# monitor files inside SAVE_PATH.
env = make_vec_env(lambda: QuadrupedGymEnv(**env_configs), monitor_dir=SAVE_PATH, n_envs=NUM_ENVS)

if LOAD_NN:
    # Resume: wrap the environments with the saved normalization statistics.
    env = VecNormalize.load(stats_path, env)
else:
    # normalize observations to stabilize learning (why?)
    env = VecNormalize(env, norm_obs=True, norm_reward=False, clip_obs=100.)

# Multi-layer perceptron (MLP) policy with two layers of size 256, 256
policy_kwargs = {"net_arch": [256, 256]}

# PPO hyperparameters, documented at:
# https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
n_steps = 4096


def learning_rate(_progress):
    """Constant schedule: return 1e-4 whatever argument is passed in."""
    return 1e-4


ppo_config = dict(
    gamma=0.99,
    # n_steps is divided by NUM_ENVS so the value passed here shrinks as
    # more environments are used
    n_steps=int(n_steps / NUM_ENVS),
    ent_coef=0.0,
    learning_rate=learning_rate,
    vf_coef=0.5,
    max_grad_norm=0.5,
    gae_lambda=0.95,
    batch_size=128,
    n_epochs=10,
    clip_range=0.2,
    clip_range_vf=1,
    verbose=1,
    tensorboard_log=None,
    _init_setup_model=True,
    policy_kwargs=policy_kwargs,
    device=gpu_arg,
)

# SAC hyperparameters, documented at:
# https://stable-baselines3.readthedocs.io/en/master/modules/sac.html
sac_config = dict(
    learning_rate=1e-4,
    buffer_size=300000,
    batch_size=256,
    ent_coef='auto',
    gamma=0.99,
    tau=0.005,
    train_freq=1,
    gradient_steps=1,
    learning_starts=10000,
    verbose=1,
    tensorboard_log=None,
    policy_kwargs=policy_kwargs,
    seed=None,
    device=gpu_arg,
)

# Resolve the configured algorithm name to its class and hyperparameters.
if LEARNING_ALG == "PPO":
    algorithm_cls, algorithm_config = PPO, ppo_config
elif LEARNING_ALG == "SAC":
    algorithm_cls, algorithm_config = SAC, sac_config
else:
    # Separator added so the message no longer reads "<name>not implemented".
    raise ValueError(LEARNING_ALG + ' not implemented')

if LOAD_NN:
    # Resume from the saved model instead of first building a fresh model
    # that would be discarded immediately. The hyperparameters above are not
    # passed to load(), exactly as before; the device is now passed
    # explicitly so a resumed run uses the same device setting as a new one
    # (load() would otherwise fall back to its own default).
    model = algorithm_cls.load(model_name, env, device=gpu_arg)
    print("\nLoaded model", model_name, "\n")
else:
    model = algorithm_cls('MlpPolicy', env, **algorithm_config)

# Learn and save (may need to train for longer)
try:
    model.learn(total_timesteps=1000000, log_interval=1, callback=checkpoint_callback)
finally:
    # Save in a `finally` clause so that an interrupted run (e.g. Ctrl-C /
    # KeyboardInterrupt) still writes its final state; the exception, if
    # any, propagates once the files are written.
    model.save(os.path.join(SAVE_PATH, "rl_model"))
    # Don't forget to save the VecNormalize statistics when saving the agent:
    # resuming with LOAD_NN reads "vec_normalize.pkl" from the run directory.
    env.save(os.path.join(SAVE_PATH, "vec_normalize.pkl"))
    if LEARNING_ALG == "SAC":  # save replay buffer
        model.save_replay_buffer(os.path.join(SAVE_PATH, "off_policy_replay_buffer"))





Using cpu device
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1e+03    |
|    ep_rew_mean     | 0.446    |
| time/              |          |
|    episodes        | 1        |
|    fps             | 138      |
|    time_elapsed    | 7        |
|    total_timesteps | 1001     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1e+03    |
|    ep_rew_mean     | 0.669    |
| time/              |          |
|    episodes        | 2        |
|    fps             | 131      |
|    time_elapsed    | 15       |
|    total_timesteps | 2002     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 988      |
|    ep_rew_mean     | 0.566    |
| time/              |          |
|    episodes        | 3        |
|    fps             | 135      |
|    time_elapsed    | 21       |
|    total_timesteps | 2965    

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 794      |
|    ep_rew_mean     | 0.659    |
| time/              |          |
|    episodes        | 16       |
|    fps             | 70       |
|    time_elapsed    | 181      |
|    total_timesteps | 12705    |
| train/             |          |
|    actor_loss      | -90.6    |
|    critic_loss     | 33.1     |
|    ent_coef        | 0.767    |
|    ent_coef_loss   | -4.73    |
|    learning_rate   | 0.0001   |
|    n_updates       | 2704     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 764      |
|    ep_rew_mean     | 0.64     |
| time/              |          |
|    episodes        | 17       |
|    fps             | 67       |
|    time_elapsed    | 193      |
|    total_timesteps | 12988    |
| train/             |          |
|    actor_loss      | -99.2    |
|    critic_loss     | 10.1     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 638      |
|    ep_rew_mean     | 0.664    |
| time/              |          |
|    episodes        | 29       |
|    fps             | 42       |
|    time_elapsed    | 438      |
|    total_timesteps | 18505    |
| train/             |          |
|    actor_loss      | -195     |
|    critic_loss     | 14.6     |
|    ent_coef        | 0.437    |
|    ent_coef_loss   | -12.8    |
|    learning_rate   | 0.0001   |
|    n_updates       | 8504     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 639      |
|    ep_rew_mean     | 0.647    |
| time/              |          |
|    episodes        | 30       |
|    fps             | 41       |
|    time_elapsed    | 467      |
|    total_timesteps | 19179    |
| train/             |          |
|    actor_loss      | -200     |
|    critic_loss     | 11.9     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 628      |
|    ep_rew_mean     | 0.578    |
| time/              |          |
|    episodes        | 41       |
|    fps             | 32       |
|    time_elapsed    | 794      |
|    total_timesteps | 25733    |
| train/             |          |
|    actor_loss      | -213     |
|    critic_loss     | 13.9     |
|    ent_coef        | 0.217    |
|    ent_coef_loss   | -19.9    |
|    learning_rate   | 0.0001   |
|    n_updates       | 15732    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 616      |
|    ep_rew_mean     | 0.578    |
| time/              |          |
|    episodes        | 42       |
|    fps             | 32       |
|    time_elapsed    | 802      |
|    total_timesteps | 25864    |
| train/             |          |
|    actor_loss      | -215     |
|    critic_loss     | 15.2     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 613      |
|    ep_rew_mean     | 0.636    |
| time/              |          |
|    episodes        | 53       |
|    fps             | 28       |
|    time_elapsed    | 1136     |
|    total_timesteps | 32474    |
| train/             |          |
|    actor_loss      | -206     |
|    critic_loss     | 79.2     |
|    ent_coef        | 0.115    |
|    ent_coef_loss   | -18.2    |
|    learning_rate   | 0.0001   |
|    n_updates       | 22473    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 605      |
|    ep_rew_mean     | 0.685    |
| time/              |          |
|    episodes        | 54       |
|    fps             | 28       |
|    time_elapsed    | 1147     |
|    total_timesteps | 32684    |
| train/             |          |
|    actor_loss      | -211     |
|    critic_loss     | 20.6     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 520      |
|    ep_rew_mean     | 0.716    |
| time/              |          |
|    episodes        | 67       |
|    fps             | 27       |
|    time_elapsed    | 1263     |
|    total_timesteps | 34855    |
| train/             |          |
|    actor_loss      | -210     |
|    critic_loss     | 57.9     |
|    ent_coef        | 0.0917   |
|    ent_coef_loss   | -17.2    |
|    learning_rate   | 0.0001   |
|    n_updates       | 24854    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 517      |
|    ep_rew_mean     | 0.716    |
| time/              |          |
|    episodes        | 68       |
|    fps             | 27       |
|    time_elapsed    | 1280     |
|    total_timesteps | 35158    |
| train/             |          |
|    actor_loss      | -203     |
|    critic_loss     | 51.8     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 485      |
|    ep_rew_mean     | 0.709    |
| time/              |          |
|    episodes        | 80       |
|    fps             | 26       |
|    time_elapsed    | 1454     |
|    total_timesteps | 38805    |
| train/             |          |
|    actor_loss      | -204     |
|    critic_loss     | 130      |
|    ent_coef        | 0.0647   |
|    ent_coef_loss   | -9.87    |
|    learning_rate   | 0.0001   |
|    n_updates       | 28804    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 485      |
|    ep_rew_mean     | 0.712    |
| time/              |          |
|    episodes        | 81       |
|    fps             | 26       |
|    time_elapsed    | 1479     |
|    total_timesteps | 39322    |
| train/             |          |
|    actor_loss      | -198     |
|    critic_loss     | 42.6     |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 466      |
|    ep_rew_mean     | 0.688    |
| time/              |          |
|    episodes        | 93       |
|    fps             | 25       |
|    time_elapsed    | 1669     |
|    total_timesteps | 43339    |
| train/             |          |
|    actor_loss      | -197     |
|    critic_loss     | 28.8     |
|    ent_coef        | 0.0541   |
|    ent_coef_loss   | 2.83     |
|    learning_rate   | 0.0001   |
|    n_updates       | 33338    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 464      |
|    ep_rew_mean     | 0.681    |
| time/              |          |
|    episodes        | 94       |
|    fps             | 25       |
|    time_elapsed    | 1685     |
|    total_timesteps | 43644    |
| train/             |          |
|    actor_loss      | -201     |
|    critic_loss     | 404      |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 407      |
|    ep_rew_mean     | 0.668    |
| time/              |          |
|    episodes        | 106      |
|    fps             | 25       |
|    time_elapsed    | 1811     |
|    total_timesteps | 46091    |
| train/             |          |
|    actor_loss      | -196     |
|    critic_loss     | 39.6     |
|    ent_coef        | 0.0501   |
|    ent_coef_loss   | 2.11     |
|    learning_rate   | 0.0001   |
|    n_updates       | 36090    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 404      |
|    ep_rew_mean     | 0.669    |
| time/              |          |
|    episodes        | 107      |
|    fps             | 25       |
|    time_elapsed    | 1829     |
|    total_timesteps | 46419    |
| train/             |          |
|    actor_loss      | -205     |
|    critic_loss     | 26       |
|    ent_coef 

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 350      |
|    ep_rew_mean     | 0.68     |
| time/              |          |
|    episodes        | 120      |
|    fps             | 24       |
|    time_elapsed    | 1973     |
|    total_timesteps | 49062    |
| train/             |          |
|    actor_loss      | -193     |
|    critic_loss     | 29.2     |
|    ent_coef        | 0.0452   |
|    ent_coef_loss   | -0.656   |
|    learning_rate   | 0.0001   |
|    n_updates       | 39061    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 348      |
|    ep_rew_mean     | 0.678    |
| time/              |          |
|    episodes        | 121      |
|    fps             | 24       |
|    time_elapsed    | 1991     |
|    total_timesteps | 49398    |
| train/             |          |
|    actor_loss      | -193     |
|    critic_loss     | 78.4     |
|    ent_coef 

KeyboardInterrupt: 