# Horcrux Joystick 입력 학습 진행

## 필요 패키지 import

In [1]:
import gymnasium as gym
import numpy as np
import pandas as pd

# 조이스틱 환경 삽입
import horcrux_terrain_v2
# from horcrux_terrain_v2.envs import PlaneJoyWorld
from horcrux_terrain_v2.envs import PlaneJoyDirWorld

# Ray 패키지 삽입
import ray
import os
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig

from ray.tune.registry import register_env

import mediapy as media

from scipy.ndimage import uniform_filter1d
from scipy.spatial.transform import Rotation

import matplotlib.pyplot as plt

from gymnasium.utils.save_video import save_video

from IPython.display import Video

# 사용자 구성 모델 정의

In [2]:
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models import ModelCatalog

class CustomSACModel(TorchModelV2, nn.Module):
    """RLlib torch model made of a fully connected trunk and a separate value head.

    The trunk is an RLlib ``FullyConnectedNetwork`` sized to emit
    ``model_config['fcnet_hiddens'][-1]`` features. The value head is a small
    MLP (two hidden layers of 64 units with ReLU) that maps those features to
    a single value per sample.

    NOTE(review): ``forward`` returns the trunk features themselves as the
    model output, so the output width is ``fcnet_hiddens[-1]`` rather than
    ``num_outputs`` -- confirm this is what the action distribution expects.
    NOTE(review): the class is named "SAC" but nothing here is SAC specific.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        """Build the trunk and the value head from ``model_config['fcnet_hiddens']``."""
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        hidden_sizes = model_config['fcnet_hiddens']
        print(model_config)
        feature_dim = hidden_sizes[-1]

        # Trunk whose output feeds both the returned features and the value head.
        self.shared = FullyConnectedNetwork(
            obs_space, action_space, feature_dim, model_config, name + "_shared"
        )

        # Extended value head: features -> 64 -> 64 -> 1.
        value_layers = [
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.value_branch = nn.Sequential(*value_layers)

        # Filled in by forward(); read back by value_function().
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        """Run the trunk, cache the value estimate and return ``(features, state)``."""
        trunk_features, _trunk_state = self.shared(input_dict, state, seq_lens)
        self._value_out = self.value_branch(trunk_features)
        return trunk_features, state

    def value_function(self):
        """Return the value estimate cached by the latest ``forward`` call, with dim 1 squeezed."""
        cached_values = self._value_out
        return cached_values.squeeze(1)


# Make the model selectable through the "custom_model" key of a model config.
ModelCatalog.register_custom_model("custom_sac_model", CustomSACModel)

# 필요 함수 정의

In [3]:
def get_unique_filename(base_path, ext=".mp4"):
    """Return an unused video name prefix together with the file path it maps to.

    Candidate names are tried in the order ``<base>``, ``<base>0``,
    ``<base>1``, ... and the first one whose ``<candidate>-episode-0<ext>``
    file does not exist yet is selected.

    The previous implementation returned the prefix ``rl-video-1`` when no
    file existed, which did not correspond to the returned path; the same
    prefix was then handed out on every call and earlier videos were
    overwritten. The prefix is now always derived from the selected path.

    Args:
        base_path: Directory plus base file name; ``ext`` is appended when missing.
        ext: File extension including the leading dot.

    Returns:
        Tuple ``(name_prefix, full_path)`` where ``full_path`` equals
        ``<dir>/<name_prefix>-episode-0<ext>``.
    """
    if not base_path.endswith(ext):
        base_path += ext  # add the extension automatically

    file_name, file_ext = os.path.splitext(base_path)
    prefix = os.path.basename(file_name)

    suffix = ""  # first candidate carries no counter
    count = 0
    while os.path.exists(f"{file_name}{suffix}-episode-0{file_ext}"):
        suffix = str(count)
        count += 1

    return f"{prefix}{suffix}", f"{file_name}{suffix}-episode-0{file_ext}"


def default_plot(x, y, f_name='default_plot', legends=['acc_x', 'acc_y', 'acc_z'], title=''):
    """Plot every column of ``y`` against ``x``, save the figure as PNG and show it.

    Args:
        x: X-axis values, one per row of ``y``.
        y: 1-D or 2-D array-like. A 1-D input is drawn as a single line; each
            column of a 2-D input is drawn as its own line.
        f_name: File name (without extension) of the PNG written to ``./figs``.
        legends: Legend labels, one per column. The default list is only read,
            never modified. Columns without a label use ``series_<index>``.
        title: Figure title.

    Returns:
        ``-1`` when ``y`` has more than two dimensions (nothing is drawn in
        that case); otherwise ``None``.
    """
    series = np.asarray(y)

    # Validate before creating a figure so the error path leaves nothing open.
    if series.ndim > 2:
        print("The dimmension of data must be less than 3. (1D or 2D)")
        return -1
    if series.ndim < 2:
        # Treat 1-D data as a single column instead of failing on shape[1].
        series = series.reshape(-1, 1)

    # Set the font before any text object is created so that the first call
    # is rendered like the following ones.
    plt.rcParams["font.family"] = "Arial"

    colors = plt.get_cmap("tab10").colors
    fig, ax = plt.subplots(figsize=(15/2.54, 10/2.54))
    ax.set_facecolor((0.95, 0.95, 0.95))

    for i in range(series.shape[1]):
        label = legends[i] if i < len(legends) else f"series_{i}"
        # Cycle through the palette when there are more columns than colours.
        ax.plot(x, series[:, i], linewidth=1.5, linestyle="-",
                color=colors[i % len(colors)], label=label)

    # Grid: major and minor lines
    ax.grid(True, linestyle="--", linewidth=1, color="#202020", alpha=0.7)
    ax.minorticks_on()
    ax.grid(True, which="minor", linestyle=":", linewidth=0.5, color="#404040", alpha=0.5)

    # Axis frame style
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_linewidth(1.0)

    ax.tick_params(axis="both", labelsize=11, width=1.0)
    ax.xaxis.label.set_size(12)
    ax.yaxis.label.set_size(12)

    # Labels and title
    ax.set_xlabel("X-Axis", fontsize=12, fontweight="bold")
    ax.set_ylabel("Y-Axis", fontsize=12, fontweight="bold")
    ax.set_title(title, fontsize=14, fontweight="bold")

    # Legend
    ax.legend(loc="upper right", ncol=3, fontsize=10, frameon=True)

    # Final aspect ratio 2.1:1 based on a height of 5 inches
    fig.set_size_inches(2.1 * 5, 5)

    # Make sure the output folder exists, then save the figure
    os.makedirs("./figs", exist_ok=True)
    plt.savefig(f"./figs/{f_name}.png", dpi=600, bbox_inches="tight")

    plt.show()

def moving_average(data, window_size):
    """Smooth ``data`` with ``window_size`` equal weights of ``1 / window_size``.

    The convolution runs in ``'same'`` mode, so boundary samples are kept in
    the output rather than dropped.
    """
    weights = np.ones(window_size) / window_size
    smoothed = np.convolve(data, weights, mode='same')
    return smoothed


def get_data_from_info(info):
    """Turn a list of per-step info dicts into a dict of stacked numpy arrays.

    Args:
        info: Sequence of dicts, one per step. Every dict must contain all
            keys listed in ``sources`` below; a missing key raises ``KeyError``.

    Returns:
        Dict mapping an output name to ``np.array`` of that field over all
        steps. ``stat_xy_vel`` pairs ``x_velocity`` and ``y_velocity`` per step.
    """
    # (output key, info key or tuple of info keys), listed in extraction order:
    # action, status, reward/cost terms, then inputs.
    sources = (
        ('action', 'action'),
        ('stat_init_rpy', 'init_rpy'),
        ('stat_init_com', 'init_com'),
        ('stat_xy_vel', ('x_velocity', 'y_velocity')),
        ('stat_yaw_vel', 'yaw_velocity'),
        ('stat_quat', 'head_quat'),
        ('stat_ang_vel', 'head_ang_vel'),
        ('stat_lin_acc', 'head_lin_acc'),
        ('stat_motion_vector', 'motion_vector'),
        ('stat_com_pos', 'com_pos'),
        ('stat_com_ypr', 'com_ypr'),
        ('stat_step_ypr', 'step_ypr'),
        ('stat_com_r_ypr', 'reward_func_orientation'),
        ('rew_linear_movement', 'reward_linear_movement'),
        ('reward_angular_movement', 'reward_angular_movement'),
        ('reward_efficiency', 'reward_efficiency'),
        ('reward_healthy', 'reward_healthy'),
        ('cost_ctrl', 'cost_ctrl'),
        ('cost_unhealthy', 'cost_unhealthy'),
        ('cost_orientation', 'cost_orientation'),
        ('cost_yaw_vel', 'cost_yaw_vel'),
        ('direction_similarity', 'direction_similarity'),
        ('rotation_alignment', 'rotation_alignment'),
        ('vel_orientation', 'velocity_theta'),
        ('input_joy', 'joy_input'),
        ('gait_param', 'gait_params'),
    )

    def _gather(source):
        # A tuple of keys yields one row per step; a single key yields one entry per step.
        if isinstance(source, tuple):
            return np.array([[step[key] for key in source] for step in info])
        return np.array([step[source] for step in info])

    collected = {out_key: _gather(source) for out_key, source in sources}

    # The returned dict lists 'stat_com_r_ypr' ahead of 'stat_step_ypr'.
    ordered_keys = list(collected)
    first = ordered_keys.index('stat_step_ypr')
    second = ordered_keys.index('stat_com_r_ypr')
    ordered_keys[first], ordered_keys[second] = ordered_keys[second], ordered_keys[first]

    data_dict = {key: collected[key] for key in ordered_keys}
    return data_dict


## Ray 실행

In [4]:
# Start Ray for this session with the dashboard configured on port 8265.
# NOTE(review): dashboard_host="0.0.0.0" presumably makes the dashboard reachable
# from other machines on the network -- confirm that this exposure is intended.
ray.init(dashboard_host="0.0.0.0", dashboard_port=8265)

2025-04-28 10:05:26,178	INFO worker.py:1810 -- Started a local Ray instance. View the dashboard at http://10.130.6.78:8265


0,1
Python version:,3.12.9
Ray version:,2.39.0
Dashboard:,http://10.130.6.78:8265


## Gym 환경 등록하기

In [5]:
# Environment settings used by the training env and (with render options added)
# by the validation/recording envs.
env_config = dict(
    forward_reward_weight=100.0,
    rotation_reward_weight=100.0,
    unhealthy_max_steps=150.0,
    healthy_reward=3.0,
    healthy_roll_range=(-40, 40),
    terminating_roll_range=(-85, 85),
    rotation_norm_cost_weight=4.5,
    termination_reward=0,
    gait_params=(30, 30, 40, 40, 0),
    use_friction_chg=True,
    joy_input_random=True,
    use_imu_window=True,
    use_vels_window=True,
    ctrl_cost_weight=0.05,
)

# Same settings plus off-screen rendering from the 'ceiling' camera.
render_env_config = {
    **env_config,
    'render_mode': 'rgb_array',
    'render_camera_name': 'ceiling',
}

# env = gym.make("horcrux_terrain_v2/plane-v2", **render_env_config)

# Settings forwarded to PlaneJoyDirWorld, in the order they are passed.
# NOTE(review): "use_vels_window" is part of env_config but is not forwarded
# here, while render_env_config (used with gym.make) does carry it -- confirm
# whether the training env should receive it as well.
_TRAIN_ENV_KEYS = (
    "forward_reward_weight",
    "rotation_reward_weight",
    "unhealthy_max_steps",
    "healthy_reward",
    "healthy_roll_range",
    "terminating_roll_range",
    "rotation_norm_cost_weight",
    "termination_reward",
    "gait_params",
    "use_friction_chg",
    "joy_input_random",
    "use_imu_window",
    "ctrl_cost_weight",
)

# JoyWorld: the creator ignores the RLlib-supplied config and reads env_config instead.
register_env(
    "joy-v1",
    lambda config: PlaneJoyDirWorld(**{key: env_config[key] for key in _TRAIN_ENV_KEYS}),
)

## 학습 알고리즘 설정하기 PPO
+ 신형 API 구조 사용해보기 (get_policy() 메서드 오류 생김... 추론 못함)

In [6]:
# from ray.rllib.core.rl_module import RLModuleSpec
# from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
#     PPOTorchRLModule
# )

# from ray.rllib.examples.rl_modules.classes.lstm_containing_rlm import (
#     LSTMContainingRLModule,
# )
# from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

# config = PPOConfig()

# # 신형 API 구조 사용 (RLModule/Learner, EnvRunner/ConnectorV2 활성화)
# config.api_stack(
#     enable_rl_module_and_learner=True,
#     enable_env_runner_and_connector_v2=True,
# )

# config.environment("joy-v1")
# config.framework("torch")

# # 병렬 CPU 사용 설정
# total_workers = 16
# config.learners(num_learners = 1, num_gpus_per_learner=1)
# config.env_runners(num_env_runners = total_workers, num_cpus_per_env_runner = 1, rollout_fragment_length = 'auto')
# config.rl_module(
#     rl_module_spec=RLModuleSpec(
#         module_class=LSTMContainingRLModule,
#         learner_only=False,
#         inference_only=False,
#     ),
#     model_config = DefaultModelConfig(
#         fcnet_hiddens=[512, 512, 512, 512, 512, 512],
#         fcnet_activation="swish",
#         use_lstm=True,
#         max_seq_len=100,
#         lstm_use_prev_action=True,
#         lstm_cell_size=256,
#     ),
# )
# config.training(
#     # Default config sets
#     gamma=0.95,
#     lr=0.0001,
#     train_batch_size_per_learner = 100000,
#     minibatch_size = 5000,
#     num_epochs = 10,
#     shuffle_batch_per_epoch = False,

#     # PPO config sets
#     entropy_coeff = 0.01,
#     vf_loss_coeff = 0.5, #이 값 튜닝 진행해야함. (기본값 : 1.0)
#     vf_clip_param = 5,
# )

# algo = config.build()


+ 구형 API 사용해서 구현
RNN 사용 (단, 아래 셀의 현재 설정은 `"use_lstm": False`로 LSTM이 비활성화되어 있음)

In [7]:
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
# Number of parallel env runners (one CPU each).
total_workers = 16

# Network settings handed to config.training(model=...).
# Commented-out alternatives kept in the original cell: fcnet_hiddens
# [256, 256, 256, 256, 64]; post_fcnet_* options; LSTM options (max_seq_len=40,
# lstm_use_prev_action=True, lstm_cell_size=64); custom_model / custom_model_config.
model_settings = {
    "fcnet_hiddens": [256, 256, 256, 256, 256, 256, 64],
    "fcnet_activation": "tanh",
    'vf_share_layers': False,  # original note: earlier training runs also used False
    "use_lstm": False,
}

config = (
    PPOConfig()
    # Use the old API stack (RLModule/Learner and EnvRunner/ConnectorV2 disabled).
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("joy-v1")
    .framework("torch")
    .resources(
        num_cpus_for_main_process=4,
        num_gpus=1,
        num_gpus_per_learner_worker=1,
    )
    .learners(num_learners=1, num_gpus_per_learner=1)
    # Parallel sampling on CPU workers.
    .env_runners(
        num_env_runners=total_workers,
        num_cpus_per_env_runner=1,
        rollout_fragment_length='auto',
    )
    .training(
        # General training settings
        gamma=0.95,
        lr=0.0005,
        train_batch_size=100000,
        minibatch_size=5000,
        num_epochs=40,
        shuffle_batch_per_epoch=True,
        model=model_settings,
        # PPO-specific settings
        clip_param=0.2,
        entropy_coeff=0.01,
        kl_coeff=0.2,
        lambda_=1.0,
        vf_loss_coeff=0.75,  # TODO (original note): this value still needs tuning
        vf_clip_param=5.0,
        grad_clip=0.4,
    )
)

algo = config.build()

# Warm-start the freshly built algorithm from weights captured earlier in the
# session. `prior_weight` only exists when another cell assigned it, so only
# the name lookup is guarded: a failure inside algo.set_weights() is no longer
# reported as "Prior weight does not exist." but surfaces as the real error.
try:
    _prior_weight = prior_weight
except NameError:
    _prior_weight = None

if _prior_weight:
    algo.set_weights(_prior_weight)
    print("Prior weight is set to loaded weight.")
else:
    print("Prior weight does not exist.")

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-04-28 10:05:42,160	INFO trainable.py:161 -- Trainable.setup took 12.642 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


Prior weight does not exist.


In [None]:
# Inspect the model object held by the policy returned from algo.get_policy();
# as the last expression of the cell its repr is what the notebook displays.
algo.get_policy().model

+ RNN 미사용 모델

In [None]:
# Number of parallel env runners (one CPU each).
total_workers = 16

# Network settings handed to config.training(model=...).
# Commented-out alternatives kept in the original cell: post_fcnet_* options
# and custom_model / custom_model_config.
model_settings = {
    "fcnet_hiddens": [512, 512, 512, 512, 512, 512],
    "fcnet_activation": "swish",
}

config = (
    PPOConfig()
    # Use the old API stack (RLModule/Learner and EnvRunner/ConnectorV2 disabled).
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("joy-v1")
    .framework("torch")
    # num_learner_workers / num_gpus_per_learner_worker were left commented out originally.
    .resources(
        num_cpus_for_main_process=8,
        num_gpus=1,
    )
    .learners(
        num_learners=0,
        num_gpus_per_learner=1,
    )
    # Parallel sampling on CPU workers.
    .env_runners(
        num_env_runners=total_workers,
        num_cpus_per_env_runner=1,
        rollout_fragment_length='auto',
    )
    .training(
        # General training settings
        gamma=0.95,
        lr=0.0005,
        train_batch_size=100000,
        minibatch_size=10000,
        num_epochs=10,
        shuffle_batch_per_epoch=False,
        model=model_settings,
        # PPO-specific settings
        entropy_coeff=0.01,
        vf_loss_coeff=0.5,  # TODO (original note): this value still needs tuning
        vf_clip_param=5,
    )
)

algo = config.build()

+ 학습된 모델 불러오기

In [None]:
# Root of the training workspace (absolute path on the author's machine).
base_path = '/home/bong/Project/snake_RL/GD_tor/learning/ray239-ppo'

# Checkpoints recorded in the original notes:
#   LP Final: GD_tor/learning/ray239-ppo/learned_policy/20250418-LP-Final/PPO_NEW_512_7_416_vf_noshare_38
#   LP Best : GD_tor/learning/ray239-ppo/learned_policy/20250418-LP-Final/PPO_NEW_512_7_416_vf_noshare_2_best
checkpoint_subdir = '/learned_policy/20250418-LP-Final/PPO_NEW_512_7_416_vf_noshare_2_best/'

# Restore the algorithm from the "LP Best" checkpoint.
algo = Algorithm.from_checkpoint(path=base_path + checkpoint_subdir)

# Optional follow-ups left disabled in the original cell:
# config = algo.get_config()
# prior_weight = algo.get_weights()
# algo.cleanup()

In [None]:
from pprint import pprint
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

# pprint(DefaultModelConfig())
# pprint(algo.get_config().to_dict())
# algo.get_default_policy_class(config).
# algo.get_module()

# algo.get_config().to_dict()
# algo.get_module().get_initial_state()
# algo.get_metadata()
# Inspect the model object held by the policy returned from algo.get_policy();
# as the last expression of the cell its repr is what the notebook displays.
algo.get_policy().model

## 학습시작
+ 구형 API

In [8]:
from pprint import pprint
import datetime
from scipy.io import savemat

n_iter = 1000
save_iter = 0
save_name = "PPO_8DIR_256_6_param_update_428"

# Upper bound on the number of control steps recorded per validation rollout.
max_rollout_steps = 3000

# Create the output folders up front so a missing folder cannot abort the run
# after a training iteration has already been paid for.
os.makedirs("./video", exist_ok=True)
os.makedirs("./data", exist_ok=True)

for i in range(n_iter):
    result = algo.train()
    print(f"{i:03d}", end=", ")
    # result.pop("config")
    # pprint(result)

    # Checkpoint and record a validation rollout every 50th iteration only.
    if i % 50 != 0:
        continue

    checkpoint_dir = algo.save(save_name + "_" + str(save_iter))
    print(f"Checkpoint saved in directory {checkpoint_dir}")
    save_iter += 1

    # Record Validation Env
    env = gym.make("horcrux_terrain_v2/plane-v3", **render_env_config)
    try:
        obs = env.reset()[0]
        env_done = False
        init_prev_a = prev_a = np.array([0]*14)
        lstm_cell_size = config["model"]["lstm_cell_size"]

        if algo.config.enable_rl_module_and_learner:
            init_state = state = algo.get_policy().model.get_initial_state()
        else:
            init_state = state = [np.zeros([lstm_cell_size], np.float32) for _ in range(2)]

        rew_return = 0
        frames = []
        info = []

        # A separate loop variable is used so the training counter `i` is not overwritten.
        for _step in range(max_rollout_steps):
            act, _state_out, _ = algo.compute_single_action(
                observation=obs, state=state, prev_action=prev_a, explore=False
            )
            # Gymnasium's step() yields (obs, reward, terminated, truncated, info).
            obs, _step_rew, terminated, truncated, env_info = env.step(act)
            frames.append(env.render())
            info.append(env_info)
            rew_return += _step_rew
            state = _state_out
            prev_a = act

            # Stop at the end of the episode: stepping an env that has
            # terminated or been truncated without reset() is not supported.
            env_done = bool(terminated or truncated)
            if env_done:
                break

        _video_base_name = 'rl-video'

        _f_name, _full_path = get_unique_filename(f"./video/{_video_base_name}")
        rew_dict = get_data_from_info(info)
        rew_dict['rew_return'] = rew_return
        rew_dict['motionMatrix'] = info[-1]['motionMatrix']

        # Save Video
        save_video(frames, "./video/", name_prefix=_f_name, fps=env.metadata['render_fps'])

        # Save Video Info (the context manager closes the file even if a write fails)
        with open(f"./video/joy_input.txt", 'a') as _f_video_info:
            _f_video_info.write(f'File creation time: {datetime.datetime.now()}\n')
            _f_video_info.write(f'Video file name: {_f_name}, Joy input: {info[0]["joy_input"]}, Friction: {info[0]["friction_coeff"]}\n')

        # Save Reward Info mat file
        savemat(f"./data/{save_name}_{_f_name}.mat", rew_dict)
    finally:
        # Always release the environment and its renderer, even when the rollout fails.
        env.close()


algo.save(save_name + "_final")



000, Checkpoint saved in directory TrainingResult(checkpoint=Checkpoint(filesystem=local, path=PPO_8DIR_256_6_param_update_428_0), metrics={'custom_metrics': {}, 'episode_media': {}, 'info': {'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'cur_kl_coeff': np.float64(0.2), 'cur_lr': np.float64(0.0005000000000000001), 'total_loss': np.float64(3.4996494367718696), 'policy_loss': np.float64(-0.05383548174926545), 'vf_loss': np.float64(4.998071069717407), 'vf_explained_var': np.float64(-9.05841588973999e-06), 'kl': np.float64(0.020831007733104343), 'entropy': np.float64(19.92345423936844), 'entropy_coeff': np.float64(0.009999999999999998)}, 'model': {}, 'num_grad_updates_lifetime': np.float64(400.5), 'diff_num_grad_updates_vs_sampler_policy': np.float64(399.5)}}, 'num_env_steps_sampled': 100000, 'num_env_steps_trained': 100000, 'num_agent_steps_sampled': 100000, 'num_agent_steps_trained': 100000}, 'env_runners': {'episode_reward_max': np.float64(-1320.0428637975058), 

TrainingResult(checkpoint=Checkpoint(filesystem=local, path=PPO_8DIR_256_6_param_update_428_final), metrics={'custom_metrics': {}, 'episode_media': {}, 'info': {'learner': {'default_policy': {'custom_metrics': {}, 'learner_stats': {'cur_kl_coeff': np.float64(12.974633789062496), 'cur_lr': np.float64(0.0005000000000000001), 'total_loss': np.float64(3.4393532010912895), 'policy_loss': np.float64(-0.07590399413253181), 'vf_loss': np.float64(4.993869961500168), 'vf_explained_var': np.float64(-0.10325241863727569), 'kl': np.float64(0.009018513053961784), 'entropy': np.float64(34.715717606544494), 'entropy_coeff': np.float64(0.009999999999999998)}, 'model': {}, 'num_grad_updates_lifetime': np.float64(799600.5), 'diff_num_grad_updates_vs_sampler_policy': np.float64(399.5)}}, 'num_env_steps_sampled': 100000000, 'num_env_steps_trained': 100000000, 'num_agent_steps_sampled': 100000000, 'num_agent_steps_trained': 100000000}, 'env_runners': {'episode_reward_max': np.float64(230.72442737621512), 'e

정책 녹화하기

In [9]:
from pprint import pprint
import datetime
from scipy.io import savemat

config = algo.get_config()

save_name = "8DIR-1-0"
num_videos = 30

# Upper bound on the number of control steps recorded per video.
max_rollout_steps = 1000

# Videos and .mat files go to ./test while the text log is appended under
# ./video; create both so neither write fails on a fresh checkout.
# NOTE(review): the log path differs from the video folder -- confirm it is intended.
os.makedirs("./test", exist_ok=True)
os.makedirs("./video", exist_ok=True)

for iteration in range(num_videos):
    # Record Validation Env
    env = gym.make("horcrux_terrain_v2/plane-v3", **render_env_config)
    try:
        obs = env.reset()[0]
        env_done = False
        init_prev_a = prev_a = np.array([0]*14)
        lstm_cell_size = config["model"]["lstm_cell_size"]

        if algo.config.enable_rl_module_and_learner:
            init_state = state = algo.get_policy().model.get_initial_state()
        else:
            init_state = state = [np.zeros([lstm_cell_size], np.float32) for _ in range(2)]

        rew_return = 0
        frames = []
        info = []

        for _step in range(max_rollout_steps):
            # `explore` is left at its default here, unlike the validation rollout in training.
            act, _state_out, _ = algo.compute_single_action(
                observation=obs, state=state, prev_action=prev_a
            )
            # Gymnasium's step() yields (obs, reward, terminated, truncated, info).
            obs, _step_rew, terminated, truncated, env_info = env.step(act)
            frames.append(env.render())
            info.append(env_info)
            rew_return += _step_rew
            state = _state_out
            prev_a = act

            # Stop at the end of the episode: stepping an env that has
            # terminated or been truncated without reset() is not supported.
            env_done = bool(terminated or truncated)
            if env_done:
                break

        _video_base_name = 'rl-video'

        _f_name, _full_path = get_unique_filename(f"./test/{_video_base_name}")
        rew_dict = get_data_from_info(info)
        rew_dict['rew_return'] = rew_return

        # Save Video
        save_video(frames, "./test/", name_prefix=_f_name, fps=env.metadata['render_fps'])

        # Save Video Info (the context manager closes the file even if a write fails)
        with open(f"./video/joy_input.txt", 'a') as _f_video_info:
            _f_video_info.write(f'File creation time: {datetime.datetime.now()}\n')
            _f_video_info.write(f'Video file name: {_f_name}, Joy input: {info[0]["joy_input"]}, Friction: {info[0]["friction_coeff"]}\n')

        # Save Reward Info mat file
        savemat(f"./test/{save_name}_{_f_name}.mat", rew_dict)
    finally:
        # Always release the environment and its renderer, even when the rollout fails.
        env.close()

In [None]:
# algo.get_module().input_specs_train
# algo.get_module().input_specs_inference()


# Show what the policy model reports as its initial state; the rollout cells
# call the same method when enable_rl_module_and_learner is set.
algo.get_policy().model.get_initial_state()