In [31]:
import sys
from pathlib import Path
from stable_baselines3.dqn import DQN
import optuna

# Project root: two directory levels above the current working directory.
PROJECT_ROOT_DIR = Path.cwd().parent.parent

# Register the root on sys.path only if it is not already there, so re-running
# this cell never adds a duplicate entry. Presumably this is what lets the
# project-local imports that follow resolve — confirm against the repo layout.
if sys.path.count(str(PROJECT_ROOT_DIR)) == 0:
    sys.path.append(str(PROJECT_ROOT_DIR))

from envs.sys_id_env import SystemIdentificationEnv
from utils.wrappers.multibinary_to_discrete import MultiBinaryToDiscreteWrapper

optuna.logging.set_verbosity(optuna.logging.WARNING)  # turn off Optuna's console output (log level raised to WARNING)

# Bare expression: shows the resolved project root as this cell's output.
PROJECT_ROOT_DIR

PosixPath('/home/ucav/pythonprojects/rl_sys_id')

In [32]:
# Per-parameter settings passed to the system-identification environment.
# Each row is (name, initial value, lower bound, upper bound); every parameter
# is flagged for optimization.
# NOTE(review): g / m / l are presumably gravity, mass and length of the
# pendulum model — confirm against the dynamics environment.
params_config = {
    name: {
        "initial_value": initial,
        "optimize": True,
        "range": [lower, upper],
    }
    for name, initial, lower, upper in (
        ("g", 10.0, 9.0, 10.0),
        ("m", 1.0, 0.8, 1.2),
        ("l", 1.0, 0.7, 1.3),
    )
}

# Build the system-identification environment and wrap it in one expression,
# so `env` always refers to the MultiBinaryToDiscreteWrapper-wrapped instance.
# NOTE(review): the meaning of the numeric settings (BO trials/jobs, reward_b,
# max_steps, loss_threshold) is defined by SystemIdentificationEnv, which is
# not visible from this notebook — consult that class before tuning them.
env = MultiBinaryToDiscreteWrapper(
    SystemIdentificationEnv(
        dynamics_env_id="CustomPendulum-v0",
        params_config=params_config,
        # Three .npy files under the project root; presumably the recorded
        # (obs, act, next_obs) transitions of the real system — TODO confirm.
        obs_real_file_path=PROJECT_ROOT_DIR / "data/custom_pendulum/obs_real.npy",
        act_real_file_path=PROJECT_ROOT_DIR / "data/custom_pendulum/act_real.npy",
        next_obs_real_file_path=PROJECT_ROOT_DIR / "data/custom_pendulum/next_obs_real.npy",
        bo_optimizer_n_trials=30,
        bo_optimizer_n_jobs=16,
        reward_b=0.2,
        max_steps=20,
        loss_threshold=1e-10,
    )
)

In [33]:
# Location of the saved policy checkpoint, relative to the project root.
# NOTE(review): the directory name contains "sac" but the checkpoint is loaded
# with DQN below — confirm the folder is merely misnamed and that this is the
# intended model.
policy_path = PROJECT_ROOT_DIR.joinpath(
    "checkpoints/sys_id/custom_pendulum/sac/g_9_5_m_0_9_l_1_2/seed_2/best_model.zip"
)

# Restore the trained agent and attach the wrapped environment to it.
algo = DQN.load(policy_path, env=env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
# Run one episode with the loaded policy, printing every transition.
obs, info = env.reset(seed=9851)  # fixed seed passed to reset
print(f"info 0: {info}")

episode_over = False
while not episode_over:
    # Ask the policy for its deterministic action on the current observation.
    action, _ = algo.policy.predict(observation=obs, deterministic=True)
    next_obs, reward, terminated, truncated, info = env.step(action=action)

    # Log the transition followed by the info dict returned by the step.
    print(obs, action, reward, next_obs)
    print(info)

    obs = next_obs
    # The episode ends as soon as the environment reports either flag.
    episode_over = terminated or truncated

print("finished")

info 0: {'loss': 0.00024436624}
[10.  1.  1.] 6 -0.37553155349725054 [9.067515  1.0832785 1.1277757]
{'params_before_optimize': {'g': 10.0, 'm': 1.0, 'l': 1.0}, 'params_selected_to_optimize': {'g': {'initial_value': 10.0, 'optimize': True, 'range': [9.0, 10.0]}, 'm': {'initial_value': 1.0, 'optimize': True, 'range': [0.8, 1.2]}, 'l': {'initial_value': 1.0, 'optimize': True, 'range': [0.7, 1.3]}}, 'params_after_optimize': {'g': 9.067515118419042, 'm': 1.0832784975363612, 'l': 1.1277756640256547}, 'loss': 1.8250441371492343e-06, 'success': False}
[9.067515  1.0832785 1.1277757] 2 -0.35299300474835404 [9.067515  1.0742176 1.1253773]
{'params_before_optimize': {'g': 9.067515118419042, 'm': 1.0832784975363612, 'l': 1.1277756640256547}, 'params_selected_to_optimize': {'m': {'initial_value': 1.0, 'optimize': True, 'range': [0.8, 1.2]}, 'l': {'initial_value': 1.0, 'optimize': True, 'range': [0.7, 1.3]}}, 'params_after_optimize': {'m': 1.0742175177856266, 'l': 1.1253772882720297}, 'loss': 1.339

In [35]:
from gymnasium.utils import seeding

# Create a random generator from seed 1 and print ten integers drawn with
# rng.integers(0, 10), one per line.
rng, seed = seeding.np_random(1)
for draw_index in range(10):
    sample = rng.integers(0, 10)
    print(sample)

4
5
7
9
0
1
8
9
2
3
