In [50]:
import sys
from pathlib import Path
from stable_baselines3.dqn import DQN
import optuna

# Resolve the repository root as the grandparent of the kernel's working
# directory. This assumes the notebook is started two directory levels below
# the project root — TODO confirm if the notebook is moved.
PROJECT_ROOT_DIR = Path.cwd().parent.parent

# Put the project root at the *front* of sys.path so the local `envs` and
# `utils` packages take precedence over any same-named installed package
# (appending would let such a package shadow them). The membership check keeps
# the cell idempotent when it is re-run.
if str(PROJECT_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT_DIR))

from envs.sys_id_env import SystemIdentificationEnv
from utils.wrappers.multibinary_to_discrete import MultiBinaryToDiscreteWrapper

# Raise Optuna's log threshold to WARNING so its INFO-level console messages are suppressed.
optuna.logging.set_verbosity(verbosity=optuna.logging.WARNING)

# Show the resolved project root as this cell's output.
PROJECT_ROOT_DIR

PosixPath('/home/ucav/pythonprojects/rl_sys_id')

In [51]:
# Simulator parameters handed to the system-identification env.
# Each entry is: name -> (initial value, [low, high] range). The range is presumably
# the bound the BO optimizer searches within. Every parameter is marked optimizable.
_param_specs = {
    "gravity": (9.8, [9.5, 10.0]),
    "masscart": (1.0, [0.8, 1.2]),
    "masspole": (0.1, [0.08, 0.12]),
    "length": (0.5, [0.4, 0.6]),
    "force_mag": (10.0, [9.0, 11.0]),
    "tau": (0.02, [0.018, 0.022]),
}

params_config = {
    name: {"initial_value": initial, "optimize": True, "range": bounds}
    for name, (initial, bounds) in _param_specs.items()
}

# Drop the helper table so the notebook namespace only gains `params_config`.
del _param_specs

# Directory with the recorded "real" transitions (obs, action, next obs) the
# simulator parameters are fitted against.
real_data_dir = PROJECT_ROOT_DIR / "data" / "custom_cartpole" / "permanent"

# Build the system-identification env, then wrap it. Judging by the wrapper's
# name, it converts a MultiBinary action space into a Discrete one — TODO confirm.
env = MultiBinaryToDiscreteWrapper(
    SystemIdentificationEnv(
        dynamics_env_id="CustomCartPole-v0",
        params_config=params_config,
        obs_real_file_path=real_data_dir / "obs_real.npy",
        act_real_file_path=real_data_dir / "act_real.npy",
        next_obs_real_file_path=real_data_dir / "next_obs_real.npy",
        bo_optimizer_n_trials=30,
        bo_optimizer_n_jobs=16,
        bo_optimizer_sample_num_in_optimize=1000,
        reward_b=0.2,
        max_steps=5,
        loss_threshold=1e-10,
    )
)

In [52]:
# Trained parameter-selection policy. The run-tag directory presumably encodes the
# "real" system parameters (g=9.8, masscart=1.1, masspole=0.09, length=0.55,
# force_mag=10.0, tau=0.02) — TODO confirm.
# NOTE(review): the path says `ppo` but the checkpoint is loaded with DQN; confirm
# which algorithm actually produced it.
policy_path = (
    PROJECT_ROOT_DIR
    / "checkpoints" / "sys_id" / "custom_cartpole" / "ppo"
    / "g_9_8_mc_1_1_mp_0_09_l_0_55_fm_10_0_tau_0_02"
    / "bo_opt_step_30" / "seed_1" / "best_model.zip"
)

# Load the model and attach the wrapped env defined above.
algo = DQN.load(policy_path, env=env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [53]:
# Roll out one episode with the policy's deterministic action, printing every transition.
# The reset seed fixes the env's initial state. NOTE(review): the BO step inside the env
# may draw its own randomness, so the printed losses could still vary between runs.
MAX_ROLLOUT_STEPS = 100  # safety cap; the env itself is built with max_steps=5

obs, info = env.reset(seed=259324)
print(f"info 0: {info}")

# Bounded loop instead of `while True`: if the env never reports terminated/truncated,
# the cell stops after MAX_ROLLOUT_STEPS instead of hanging the kernel.
for step_idx in range(MAX_ROLLOUT_STEPS):
    action, _ = algo.policy.predict(observation=obs, deterministic=True)
    next_obs, reward, terminated, truncated, info = env.step(action=action)
    print(obs, action, reward, next_obs)
    print(info)
    obs = next_obs
    if terminated or truncated:
        print("finished")
        break
else:
    # Reached only when the loop exhausts without hitting `break`.
    print(f"stopped after {MAX_ROLLOUT_STEPS} steps without termination")

info 0: {'loss': 0.00069014146}
[ 9.8   1.    0.1   0.5  10.    0.02] 46 -0.30730543173187924 [9.751348   1.         0.10847028 0.55959064 9.288495   0.02      ]
{'params_before_optimize': {'gravity': 9.8, 'masscart': 1.0, 'masspole': 0.1, 'length': 0.5, 'force_mag': 10.0, 'tau': 0.02}, 'params_selected_to_optimize': {'gravity': {'initial_value': 9.8, 'optimize': True, 'range': [9.5, 10.0]}, 'masspole': {'initial_value': 0.1, 'optimize': True, 'range': [0.08, 0.12]}, 'length': {'initial_value': 0.5, 'optimize': True, 'range': [0.4, 0.6]}, 'force_mag': {'initial_value': 10.0, 'optimize': True, 'range': [9.0, 11.0]}}, 'params_after_optimize': {'gravity': 9.751347247625441, 'masspole': 0.10847028124988517, 'length': 0.5595906297462993, 'force_mag': 9.288495167143376}, 'loss': 1.8914257680080482e-06, 'success': False}
[9.751348   1.         0.10847028 0.55959064 9.288495   0.02      ] 45 -0.3915815245430152 [9.9761505  1.         0.10690418 0.5449301  9.288495   0.01989127]
{'params_before