In [30]:
import sys
from pathlib import Path
from stable_baselines3.dqn import DQN
import optuna

# Project root = two directory levels above the current working directory.
# It is appended to sys.path (once) so the project-local packages imported
# below can be resolved.
PROJECT_ROOT_DIR = Path.cwd().parent.parent
project_root_str = str(PROJECT_ROOT_DIR)
if project_root_str not in sys.path:
    sys.path.append(project_root_str)

from envs.sys_id_env import SystemIdentificationEnv
from utils.wrappers.multibinary_to_discrete import MultiBinaryToDiscreteWrapper

optuna.logging.set_verbosity(optuna.logging.WARNING)  # Quiet Optuna's console output: only WARNING-level and above messages are emitted

# Bare expression: shows the resolved project root as the cell output.
PROJECT_ROOT_DIR

PosixPath('/home/ucav/pythonprojects/rl_sys_id')

In [None]:
# Per-parameter settings handed to the system-identification environment.
# Each table row is: name -> (initial value, range low, range high).
# NOTE(review): "g", "m", "l" are presumably the pendulum's gravity, mass and
# length, and "range" the optimizer's search bounds — confirm against
# SystemIdentificationEnv.
params_config = {
    name: {"initial_value": initial, "optimize": True, "range": [low, high]}
    for name, (initial, low, high) in {
        "g": (10.0, 9.0, 10.0),
        "m": (1.0, 0.8, 1.2),
        "l": (1.0, 0.7, 1.3),
    }.items()
}

# Build the system-identification environment from the recorded transition
# files and wrap it so the agent sees a discrete action space.
# NOTE(review): the meaning of the bo_optimizer_* / reward_b / loss_threshold
# settings is defined by SystemIdentificationEnv, which is not visible here.
env = MultiBinaryToDiscreteWrapper(
    SystemIdentificationEnv(
        dynamics_env_id="CustomPendulum-v0",
        params_config=params_config,
        obs_real_file_path=PROJECT_ROOT_DIR / "data/custom_pendulum/obs_real.npy",
        act_real_file_path=PROJECT_ROOT_DIR / "data/custom_pendulum/act_real.npy",
        next_obs_real_file_path=PROJECT_ROOT_DIR / "data/custom_pendulum/next_obs_real.npy",
        bo_optimizer_n_trials=30,
        bo_optimizer_seed=42,
        bo_optimizer_n_jobs=16,
        reward_b=0.2,
        max_steps=10,
        loss_threshold=1e-10,
    )
)

In [32]:
# Checkpoint of the trained policy to evaluate.
# NOTE(review): the checkpoint lives under a "sac" directory but is loaded with
# DQN — confirm the path/algorithm pairing is intended.
policy_path = (
    PROJECT_ROOT_DIR
    / "checkpoints/sys_id/custom_pendulum/sac/g_9_5_m_0_9_l_1_2/seed_1/best_model.zip"
)
# Restore the saved agent and bind it to the wrapped environment.
algo = DQN.load(path=policy_path, env=env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [33]:
# Roll out a single episode with the deterministic policy, printing the
# observation, chosen action, reward and step info for every transition.
obs, info = env.reset(seed=10)
print(f"info 0: {info}")

episode_over = False
while not episode_over:
    action, _ = algo.policy.predict(observation=obs, deterministic=True)
    next_obs, reward, terminated, truncated, info = env.step(action=action)
    print(obs, action, reward)
    print(info)
    obs = next_obs
    # Stop as soon as the environment reports termination or truncation.
    episode_over = terminated or truncated

info 0: {'loss': 0.00024436624}
[10.  1.  1.] 6 -0.4328812880789597
{'params_before_optimize': {'g': 10.0, 'm': 1.0, 'l': 1.0}, 'params_selected_to_optimize': {'g': {'initial_value': 10.0, 'optimize': True, 'range': [9.0, 10.0]}, 'm': {'initial_value': 1.0, 'optimize': True, 'range': [0.8, 1.2]}, 'l': {'initial_value': 1.0, 'optimize': True, 'range': [0.7, 1.3]}}, 'params_after_optimize': {'g': 9.070518527599743, 'm': 1.0984759516442764, 'l': 1.1363101318100588}, 'loss': 3.7143709050724283e-06, 'success': False}
[9.0705185 1.0984759 1.1363101] 6 -0.27172982439902
{'params_before_optimize': {'g': 9.070518527599743, 'm': 1.0984759516442764, 'l': 1.1363101318100588}, 'params_selected_to_optimize': {'g': {'initial_value': 10.0, 'optimize': True, 'range': [9.0, 10.0]}, 'm': {'initial_value': 1.0, 'optimize': True, 'range': [0.8, 1.2]}, 'l': {'initial_value': 1.0, 'optimize': True, 'range': [0.7, 1.3]}}, 'params_after_optimize': {'g': 9.740522208993532, 'm': 0.8705047602533282, 'l': 1.216168