In [None]:
import sys
from pathlib import Path
from stable_baselines3.dqn import DQN
import optuna

# The project root is taken to be two directory levels above the current
# working directory; register it on sys.path (only once) so that the
# project-local packages imported below can be resolved.
PROJECT_ROOT_DIR = Path.cwd().parent.parent
if str(PROJECT_ROOT_DIR) not in sys.path:
    sys.path.append(str(PROJECT_ROOT_DIR))

from envs.sys_id_env import SystemIdentificationEnv
from utils.wrappers.multibinary_to_discrete import MultiBinaryToDiscreteWrapper

optuna.logging.set_verbosity(optuna.logging.WARNING)  # Raise Optuna's log level to WARNING to quiet its console output

# Bare expression: shows the resolved project root as this cell's output.
PROJECT_ROOT_DIR

In [None]:
# Per-parameter settings passed to the identification environment: a starting
# value, a flag saying whether the parameter is optimized, and a two-element
# range (presumably [low, high] bounds — confirm against SystemIdentificationEnv).
params_config = {
    "g": {"initial_value": 10.0, "optimize": True, "range": [9.0, 10.0]},
    "m": {"initial_value": 1.0, "optimize": True, "range": [0.8, 1.2]},
    "l": {"initial_value": 1.0, "optimize": True, "range": [0.7, 1.3]},
}

# Stored observation / action / next-observation arrays, keyed by the
# constructor argument each one is passed to.
real_data_files = {
    "obs_real_file_path": PROJECT_ROOT_DIR / "data/custom_pendulum/permanent/obs_real.npy",
    "act_real_file_path": PROJECT_ROOT_DIR / "data/custom_pendulum/permanent/act_real.npy",
    "next_obs_real_file_path": PROJECT_ROOT_DIR / "data/custom_pendulum/permanent/next_obs_real.npy",
}

# Build the raw environment first, then expose it through the
# MultiBinary -> Discrete wrapper under the name `env`.
sys_id_env = SystemIdentificationEnv(
    dynamics_env_id="CustomPendulum-v0",
    params_config=params_config,
    **real_data_files,
    bo_optimizer_n_trials=30,
    bo_optimizer_n_jobs=16,
    bo_optimizer_sample_num_in_optimize=1000,
    reward_b=0.2,
    max_steps=10,
    loss_threshold=1e-10,
)
env = MultiBinaryToDiscreteWrapper(sys_id_env)

In [None]:
# NOTE(review): the checkpoint sits under a ".../sac/..." directory but is
# restored with DQN — confirm the model was really trained with DQN.
policy_path = PROJECT_ROOT_DIR / "checkpoints/sys_id/custom_pendulum/sac/g_9_5_m_0_9_l_1_2/seed_2/best_model.zip"
# Restore the saved agent and attach the wrapped environment to it.
algo = DQN.load(path=policy_path, env=env)

In [None]:
# Roll out a single episode with the loaded policy acting deterministically,
# printing every transition and its info dict along the way.
obs, info = env.reset(seed=17954)  # fixed seed for the reset
print(f"info 0: {info}")

episode_over = False
while not episode_over:
    # Query the policy network directly; the second return value is unused.
    action, _ = algo.policy.predict(observation=obs, deterministic=True)
    next_obs, reward, terminated, truncated, info = env.step(action=action)
    print(obs, action, reward, next_obs)
    print(info)
    obs = next_obs
    # Stop once the environment reports termination or truncation.
    episode_over = bool(terminated or truncated)
print("finished")