* Ray RLlib 노트북

필요 패키지 불러오기

In [None]:
import gymnasium as gym
import numpy as np
import pandas as pd

from horcrux_terrain_v1.envs import SandWorld
from horcrux_terrain_v1.envs import PlaneWorld
from horcrux_terrain_v1.envs import PlanePipeWorld

import ray
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.sac import SACConfig

from ray.tune.registry import register_env


Ray 실행 (워커 프로세스의 DeprecationWarning 무시 설정)

In [None]:
# Start Ray; every worker process is launched with PYTHONWARNINGS set so
# that DeprecationWarnings raised inside the workers are suppressed.
ray_worker_env_vars = {"PYTHONWARNINGS": "ignore::DeprecationWarning"}
ray.init(runtime_env={"env_vars": ray_worker_env_vars})

Gym 환경을 RLlib Env로 등록

In [3]:
# Shared reward/termination settings for every terrain environment below.
env_config = {
    "forward_reward_weight": 4,
    "side_cost_weight": 3,
    "unhealthy_max_steps": 100,
    "healthy_roll_range": (-40, 40),
    "terminating_roll_range": (-70, 70),
    "rotation_norm_cost_weight": 0.1,
    "termination_reward": 0,
    # "use_gait": False,
    # "gait_params": (30, 30, 80, 40, 0),
}

# The subset of `env_config` forwarded as keyword arguments to each terrain
# constructor (kept explicit so commented-out keys above never leak through).
_ENV_KWARG_KEYS = (
    "forward_reward_weight",
    "side_cost_weight",
    "unhealthy_max_steps",
    "healthy_roll_range",
    "terminating_roll_range",
    "rotation_norm_cost_weight",
    "termination_reward",
)


def _make_env_creator(env_cls):
    """Return an RLlib env-creator callable that instantiates ``env_cls``.

    The constructor receives the ``_ENV_KWARG_KEYS`` entries of the
    module-level ``env_config``. ``env_config`` is looked up when the creator
    is *called*, not when it is registered, so edits made to it before an
    environment is built take effect.

    Args:
        env_cls: Environment class to build (e.g. ``SandWorld``).

    Returns:
        A one-argument callable suitable for ``ray.tune.registry.register_env``.
    """

    def _creator(config):
        # The RLlib-supplied `config` is intentionally ignored; all settings
        # come from the notebook-level `env_config`.
        return env_cls(**{key: env_config[key] for key in _ENV_KWARG_KEYS})

    return _creator


register_env("sand-v1", _make_env_creator(SandWorld))        # Sand
register_env("plane-v1", _make_env_creator(PlaneWorld))      # Plane
register_env("pipe-v1", _make_env_creator(PlanePipeWorld))   # Pipe

이전 학습 결과(체크포인트) 불러오기

In [None]:
# Restore the previously trained RLlib Algorithm from its checkpoint directory.
checkpoint_dir = "./PPO_LOAD_V4_LP_512x5_4_7"
algo = Algorithm.from_checkpoint(checkpoint_dir)

In [8]:
# Export the restored (default) policy's model into the ./model directory;
# the on-disk format is whatever RLlib's Policy.export_model produces.
policy = algo.get_policy()
policy.export_model("./model")

GPU에서 학습한 모델을 CPU 모델로 변경

In [None]:
# PPO config for a CPU-only copy of the trained policy. Every AlgorithmConfig
# setter returns the config itself, so the calls are chained.
new_algo_config = (
    PPOConfig()
    # Stay on the OLD API stack: RLModule/Learner and EnvRunner/ConnectorV2
    # are both disabled (the new stack was tried and deliberately not used).
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .framework("torch")
    .environment("plane-v1")
    .resources(num_gpus=0)  # no GPU: the new algorithm runs on CPU
    .training(
        gamma=0.9,
        lr=0.001,
        # kl_coeff=0.3,
        # See the model catalog for more options:
        # https://docs.ray.io/en/latest/rllib/rllib-models.html
        model={"fcnet_hiddens": [512] * 5},
    )
)

# Earlier alternatives for the GPU -> CPU conversion (kept for reference):
#   nn_weight = algo.get_weights()
#   algo.get_policy().model.cpu()
#   cp = algo.save("PPO-LP-CPU")

In [None]:
# Build a fresh PPO Algorithm from `new_algo_config`. Nothing in this call
# loads trained weights; they must be copied in afterwards.
new_algo = new_algo_config.build()

In [None]:
# Copy the trained weights from the checkpoint-restored `algo` into the
# CPU-only `new_algo`.
#
# Do NOT use `new_algo.get_policy().from_checkpoint(path)` for this:
# `Policy.from_checkpoint` is a static constructor that *returns* a new policy
# (or a {policy_id: Policy} dict for an Algorithm checkpoint) and leaves the
# policy it is invoked on untouched, so discarding its result means
# `new_algo` silently keeps its freshly initialized weights.
#
# Torch policies return their weights as CPU NumPy arrays from get_weights(),
# and set_weights() converts them onto the receiving policy's device, so this
# also performs the GPU -> CPU move. A layer-size mismatch between the two
# configs raises here instead of failing silently.
new_algo.get_policy().set_weights(algo.get_policy().get_weights())

# new_algo.save("MO-PPO-LP-CPU")

In [None]:
# Smoke test: query one deterministic (explore=False) action for a random
# observation. Assumes the env observation is a flat 94-dim vector — TODO confirm.
random_obs = np.random.random((94,))
new_algo.compute_single_action(observation=random_obs, explore=False)

Policy 테스트

In [187]:
import time

# Same settings as training, plus gait parameters for this test run.
new_env_config = env_config.copy()
new_env_config["gait_params"] = (30, 30, 40, 40, 0)


def _run_episode(env, agent, max_steps=1500):
    """Roll out one greedy episode of ``agent`` in ``env`` and return its return.

    The episode ends after ``max_steps`` steps, or earlier as soon as the
    environment reports ``terminated`` or ``truncated``. Stepping a Gymnasium
    env past either signal without a reset is invalid.

    Args:
        env: A Gymnasium environment (5-tuple ``step`` API).
        agent: Object exposing ``compute_single_action(obs, explore=...)``.
        max_steps: Upper bound on the number of steps in the episode.

    Returns:
        The sum of rewards collected during the episode.
    """
    obs, _ = env.reset()
    episode_return = 0
    for _ in range(max_steps):
        action = agent.compute_single_action(obs, explore=False)
        obs, reward, terminated, truncated, _ = env.step(action)
        episode_return += reward
        if terminated:
            print("terminated")
            break
        if truncated:
            print("truncated")
            break
    return episode_return


env = gym.make(
    "horcrux_terrain_v1/plane-v1",
    terminate_when_unhealthy=False,
    render_mode="human",
    # render_camera_name="ceiling",
    use_gait=True,
    **new_env_config,
)

# try/finally so the render window is closed even if a rollout raises.
try:
    for _ in range(3):
        t_start = time.perf_counter()  # includes the reset, as before
        episode_return = _run_episode(env, new_algo)
        print(f"Reached episode return of {episode_return}.")
        print(f"Time elapsed: {time.perf_counter() - t_start}")
finally:
    env.close()

[1.3524626 1.3480861 1.357535  1.348239  1.3535776 1.3541676 1.3468273
 1.352368  1.3504469 1.3488488 1.3536749 1.3481944 1.3477944 1.3538078]
[1.3492264 1.3486934 1.3515334 1.3501146 1.3495238 1.3494679 1.3488638
 1.3497941 1.3502128 1.348424  1.3499427 1.3496855 1.3483757 1.3479973]
[1.3433179 1.3422965 1.3463092 1.3450594 1.3482717 1.3421644 1.347463
 1.3546385 1.358533  1.3445584 1.3523055 1.3500615 1.3521538 1.3408014]
[1.3489414 1.3528799 1.3519994 1.3489983 1.3620064 1.3526528 1.3502693
 1.3437304 1.3486032 1.3530693 1.3501176 1.3472288 1.3515041 1.3497863]
[1.3544271 1.3462383 1.3440564 1.3419213 1.3497539 1.3540344 1.3540262
 1.355004  1.3507057 1.3506647 1.3441888 1.354438  1.3531194 1.3481605]
[1.3425483 1.3534966 1.3468347 1.3510327 1.3513395 1.3506064 1.3498417
 1.346462  1.3485956 1.3533909 1.3450626 1.3541821 1.344963  1.3553522]
[1.3461792 1.3491566 1.3477942 1.344788  1.3527915 1.356246  1.3451666
 1.3554243 1.345032  1.347825  1.3568884 1.3504347 1.3496689 1.3509313]
