# Horcrux Joystick 입력 학습 진행

## 필요 패키지 import

In [None]:
import gymnasium as gym
import numpy as np
import pandas as pd

# 조이스틱 환경 삽입
from horcrux_terrain_v1.envs import PlaneJoyWorld
from horcrux_terrain_v1.envs import PlaneWorld

# Ray 패키지 삽입
import ray
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.sac import SACConfig

from ray.tune.registry import register_env

## Ray 실행

In [None]:
import socket
import psutil

# Find the IPv4 address of the OpenVPN TAP adapter so the Ray dashboard can be
# bound to it and reached over the VPN.
conn_ip = ""
interfaces = psutil.net_if_addrs()
for interface_name, addresses in interfaces.items():
    lowered_name = interface_name.lower()
    if "openvpn" not in lowered_name or "tap" not in lowered_name:
        continue
    for address in addresses:
        # Only IPv4 entries are used; if several match, the last one wins.
        if address.family == socket.AF_INET:
            conn_ip = address.address

# Binding the dashboard to the VPN address makes it reachable from outside
# through the VPN. When no matching interface was found, conn_ip is still
# empty, so fall back to the loopback host instead of passing an empty string.
ray.init(dashboard_host=conn_ip or "127.0.0.1", dashboard_port=8265)

## Gym 환경 등록하기

In [None]:
# Default constructor arguments shared by the registered environments.
env_config = {
    "forward_reward_weight": 6.5,
    "side_cost_weight": 2.0,
    "unhealthy_max_steps": 100,
    "healthy_reward": 0.5,
    "healthy_roll_range": (-35, 35),
    "terminating_roll_range": (-85, 85),
    "rotation_norm_cost_weight": 0.01,
    "rotation_orientation_cost_weight": 1.2,
    "termination_reward": 0,
    "gait_params": (30, 30, 60, 60, 0),
    "use_friction_chg": True,
    "joy_input_random": True,
}

# Joystick-related keys that are never forwarded to PlaneWorld.
# NOTE(review): "joy_input" is assumed to be a joystick-only argument; confirm
# against the PlaneJoyWorld / PlaneWorld signatures.
_JOY_ONLY_KEYS = ("joy_input_random", "joy_input")


def _make_joy_env(config):
    """Create a PlaneJoyWorld from the defaults plus per-env overrides.

    Args:
        config: Mapping handed to the env creator by RLlib. Its entries take
            precedence over the module-level ``env_config`` defaults. An empty
            mapping reproduces the defaults exactly.

    Returns:
        The constructed PlaneJoyWorld instance.
    """
    return PlaneJoyWorld(**{**env_config, **(config or {})})


def _make_plane_env(config):
    """Create a PlaneWorld from the defaults plus per-env overrides.

    Args:
        config: Mapping handed to the env creator by RLlib. Its entries take
            precedence over the module-level ``env_config`` defaults.

    Returns:
        The constructed PlaneWorld instance.
    """
    params = {**env_config, **(config or {})}
    # PlaneWorld is built without any joystick arguments.
    for key in _JOY_ONLY_KEYS:
        params.pop(key, None)
    return PlaneWorld(**params)


# JoyWorld
register_env("joy-v1", _make_joy_env)

# Plane
register_env("plane-v1", _make_plane_env)

## 학습된 가중치 불러오기 (현재 사용 불가)

In [None]:
# algo = Algorithm.from_checkpoint("./Paper_agents/good/Linear/SAC_layer_512_5_32_Linear_restart_final")

# trained_weights = algo.get_weights()
# algo.cleanup()

## 학습 알고리즘 설정하기

In [None]:
# Build the SAC configuration on the legacy API stack (RLModule/Learner and
# EnvRunner/ConnectorV2 both disabled), targeting the registered "joy-v1"
# environment with the torch framework.
config = (
    SACConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("joy-v1")
    .framework("torch")
)

# Parallel sampling / resource settings.
# NOTE(review): total_workers is not used below; num_env_runners is set to 16
# independently. Confirm which value is intended.
total_workers = 12
config.resources(num_gpus=1)
config.env_runners(num_env_runners=16)

# Same fully connected architecture for the Q-network(s) and the policy.
config.training(
    gamma=0.95,
    replay_buffer_config={
        "_enable_replay_buffer_api": True,
        "capacity": int(1e6),
        # If True prioritized replay buffer will be used.
        "prioritized_replay": False,
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 1e-6,
        # Whether to compute priorities already on the remote worker side.
        "worker_side_prioritization": False,
    },
    q_model_config={
        "fcnet_hiddens": [512, 512, 512, 512, 512, 32],
        "fcnet_activation": "tanh",
        "post_fcnet_hiddens": [],
        "post_fcnet_activation": None,
        "custom_model": None,  # Use this to define custom Q-model(s).
        "custom_model_config": {},
    },
    policy_model_config={
        "fcnet_hiddens": [512, 512, 512, 512, 512, 32],
        "fcnet_activation": "tanh",
        "post_fcnet_hiddens": [],
        "post_fcnet_activation": None,
        "custom_model": None,  # Use this to define a custom policy model.
        "custom_model_config": {},
    },
    # The new API stack is disabled for this config, and the old stack is
    # expected to read `train_batch_size` rather than
    # `train_batch_size_per_learner`. Set both so 8192 applies on either stack.
    # NOTE(review): confirm against the installed Ray version.
    train_batch_size=8192,
    train_batch_size_per_learner=8192,
    num_steps_sampled_before_learning_starts=9000,
)

# Evaluate every 100 training iterations with a fixed joystick command.
# Environment constructor arguments are algorithm-config overrides only when
# nested under "env_config"; at the top level of evaluation_config they are
# not environment settings. The key is "render_mode" (was misspelled
# "reder_mode").
# NOTE(review): these values only take effect if the registered env creator
# reads the config it is given and the environment accepts these keyword
# arguments. Confirm both.
config.evaluation(
    evaluation_interval=100,
    evaluation_config={
        "env_config": {
            "render_mode": "rgb_array",
            "render_camera_name": "ceiling",
            "joy_input_random": False,
            "joy_input": (0.7, 0, 0.6),
        },
    },
)
algo = config.build()
