In [1]:
import gymnasium as gym
import sys, os
from typing import Callable
import datetime
import time
import optuna
from stable_baselines3.common import type_aliases
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback, CallbackList, EventCallback
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder, VecEnv, sync_envs_normalization, is_vecenv_wrapped
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac import SAC
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common import base_class  # pytype: disable=pyi-error
from newAlgo import HIP, EvalStudentCallback, evaluate_student_policy
from arm_cassie_env.cassie_env.cassieRLEnvMirror import cassieRLEnvMirror
from arm_cassie_env.cassie_env.oldCassie import OldCassieMirrorEnv
from inversepolicies import IPMDPolicy
import wandb
from wandb.integration.sb3 import WandbCallback
from gymnasium.envs.registration import register
import warnings
import numpy as np

In [4]:
def train_SAC(env: str = "HalfCheetah-v4") -> None:
    """Train a SAC agent on one Gymnasium environment, logging to Weights & Biases.

    A new W&B run is opened in the "IRL HIP Tuning" project and named with the
    current timestamp. Training uses 16 subprocess environments; a separate
    single-environment vec env is used for periodic deterministic evaluation.
    Best-model checkpoints and evaluation logs are written under
    ``./logs/<run name>/teacher/``, W&B model snapshots under
    ``models/<run id>``, and TensorBoard logs under
    ``logs/tensorboard/<run name>/``.

    Both vectorized environments are closed and the W&B run is finished even
    if training raises or is interrupted, so repeated calls from the same
    kernel do not accumulate worker processes or leave runs open.

    Args:
        env: Gymnasium environment id to train on.
    """
    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": 5e6,
        "env_id": env,
        'n_envs': 16,
    }
    run = wandb.init(
        project="IRL HIP Tuning",
        config=config,
        name=f'{time.strftime("%Y-%m-%d-%H-%M-%S")}',
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
    )

    # Created inside the try block so that a failure part-way through setup
    # still releases whatever was already started.
    train_env = None
    eval_env = None
    succeeded = False
    try:
        wandbcallback = WandbCallback(
            model_save_path=f"models/{run.id}",
            model_save_freq=10000,
            gradient_save_freq=10000,
            verbose=2,
        )
        # Vectorized training env: one subprocess per environment copy
        train_env = make_vec_env(config['env_id'], n_envs=config['n_envs'], vec_env_cls=SubprocVecEnv)
        # Separate evaluation env
        eval_env = make_vec_env(config['env_id'], n_envs=1, vec_env_cls=SubprocVecEnv)
        # Use deterministic actions for evaluation.
        # NOTE(review): SB3 counts eval_freq in callback calls (one per
        # vectorized step), so with 16 training envs this evaluates roughly
        # every 32k environment timesteps -- confirm this cadence is intended.
        eval_callback = EvalCallback(eval_env,
                                     best_model_save_path=f"./logs/{run.name}/teacher/",
                                     log_path=f"./logs/{run.name}/teacher/",
                                     eval_freq=2000,
                                     n_eval_episodes=3,
                                     deterministic=True,
                                     render=False,
                                     verbose=1)

        callback_list = CallbackList([eval_callback, wandbcallback, ])
        # Init model
        sac_model = SAC(policy=config['policy_type'],
                        env=train_env,
                        tensorboard_log=f'logs/tensorboard/{run.name}/',
                        )
        sac_model.learn(total_timesteps=int(config['total_timesteps']), callback=callback_list, progress_bar=True)
        succeeded = True
    finally:
        try:
            # Shut down the SubprocVecEnv worker processes.
            for vec_env in (train_env, eval_env):
                if vec_env is not None:
                    vec_env.close()
        finally:
            # A non-zero exit code marks the run as failed in W&B instead of
            # reporting a crashed or interrupted training as a success.
            run.finish(exit_code=0 if succeeded else 1)

In [None]:
# Train one SAC agent per environment id, sequentially, in the listed order.
TRAIN_ENV_IDS = (
    'Ant-v4',
    'HalfCheetah-v4',
    'Hopper-v4',
    'Humanoid-v4',
    'Walker2d-v4',
)
for env_id in TRAIN_ENV_IDS:
    train_SAC(env=env_id)