In [133]:
from pathlib import Path
import sys
import gymnasium as gym
import panda_gym
import numpy as np
import pandas as pd
from gymnasium.envs.registration import register
from gymnasium.wrappers import FlattenObservation

# Project root = parent of the current working directory. It is appended to
# sys.path (once) so the project-local imports that follow can be resolved.
PROJECT_ROOT_DIR = Path().absolute().parent

# PROJECT_ROOT_DIR is already absolute, so a plain str() is enough here; the
# membership check keeps re-runs of this cell from adding duplicate entries.
if str(PROJECT_ROOT_DIR) not in sys.path:
    sys.path.append(str(PROJECT_ROOT_DIR))

from rollout.rollout_by_policy import rollout_by_goal_with_policy
from utils.sb3_env_wrappers import ScaledObservationWrapper
from my_reach_env import MyPandaReachEnv
from utils.load_data import load_data, split_data
from models.sb3_model import PPOWithBCLoss

In [134]:
# --- Environment parameters (forwarded as kwargs when "my-reach" is registered) ---
distance_threshold = 0.01
goal_range = 0.3

# --- Experiment bookkeeping ---
# Sub-path under <project root>/checkpoints/rl/ that holds the policy to load.
EXPERIMENT_NAME = "iter_3/reacher_256_256_1e7steps_8envs_kl_1e-1_loss_5"

# Expert dataset handed to load_data(), given relative to the project root.
# Despite the *_DIR suffix this points at a single CSV file.
# Alternative dataset: "rollout/cache/myreach_pid_speed_1.5.csv"
EXPERT_DATA_CACHE_DIR = "rollout/cache/myreach_from_iter_2_rl_bc.csv"

# File name used when the collected trajectories are written out with to_csv().
csv_save_name = "iter_3_myreach_rl_bc_5.csv"

In [135]:
# Register the custom reach environment under the id "my-reach".
# Re-running this cell re-registers the id on purpose (gymnasium then logs an
# "Overriding environment ... already in registry" warning) so that edits to
# goal_range / distance_threshold take effect in later gym.make() calls.
register(
    id="my-reach",
    # Plain string literal: there is nothing to interpolate, so no f-prefix.
    entry_point="my_reach_env:MyPandaReachEnv",
    kwargs={
        "reward_type": "sparse",
        "control_type": "ee",
        "goal_range": goal_range,
        "distance_threshold": distance_threshold,
    },
    # Episodes are truncated after 50 steps.
    max_episode_steps=50,
)

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [136]:
# Load the expert dataset; load_data also returns the observation scaler that
# is handed to the environment wrapper below.
data_file: Path = PROJECT_ROOT_DIR.joinpath(EXPERT_DATA_CACHE_DIR)
scaled_obs, acts, infos, obs_scaler = load_data(data_file)

# Build the evaluation environment in one expression:
# "my-reach" -> FlattenObservation -> ScaledObservationWrapper(obs_scaler).
env = ScaledObservationWrapper(
    env=FlattenObservation(gym.make("my-reach")),
    scaler=obs_scaler,
)

argv[0]=--background_color_red=0.8745098114013672
argv[1]=--background_color_green=0.21176470816135406
argv[2]=--background_color_blue=0.1764705926179886


In [137]:
# Restore the trained policy stored at
# <project root>/checkpoints/rl/<EXPERIMENT_NAME>/best_model.
policy_save_dir = PROJECT_ROOT_DIR.joinpath("checkpoints", "rl", EXPERIMENT_NAME)
algo_ppo = PPOWithBCLoss.load(str(policy_save_dir.joinpath("best_model").absolute()))

In [138]:
# Roll the loaded policy out towards every goal on a regular grid and keep the
# trajectories of the rollouts that terminated (i.e. were reported as success).
#
# Grid coordinates are handled as integer centimetres and converted to metres
# (divided by 100) only when the goal array is built.

# round() rather than int(): goal_range * 100 is a float product and may land
# just below the intended integer (e.g. 0.29 * 100 == 28.999999999999996),
# which int() would truncate to one centimetre less.
goal_range_int = round(goal_range * 100)
half_range = goal_range_int // 2
grid_step = 2  # spacing between neighbouring goals, in centimetres

x_low, x_high = -half_range, half_range
y_low, y_high = -half_range, half_range
z_low, z_high = 0, goal_range_int

# Collect one DataFrame per successful rollout and concatenate once at the end;
# calling pd.concat inside the loop re-copies all accumulated rows every time.
successful_trajs = []
for x in range(x_low, x_high + 1, grid_step):
    for y in range(y_low, y_high + 1, grid_step):
        for z in range(z_low, z_high + 1, grid_step):
            terminated, truncated, traj = rollout_by_goal_with_policy(
                env=env,
                goal=np.array([x / 100., y / 100., z / 100.]),
                policy=algo_ppo.policy,
            )
            # Truncated rollouts are discarded; only terminated ones are kept.
            if terminated:
                successful_trajs.append(pd.DataFrame(traj))

success_cnt = len(successful_trajs)
# As before, total_df stays None when no rollout succeeded. Each trajectory
# keeps its own 0-based index (no ignore_index), matching the earlier output.
total_df = pd.concat(successful_trajs) if successful_trajs else None

success_cnt

[32m success. goal: (-0.15, -0.15, 0.0), steps: 9 [0m
[32m success. goal: (-0.15, -0.15, 0.02), steps: 9 [0m
[32m success. goal: (-0.15, -0.15, 0.04), steps: 8 [0m
[32m success. goal: (-0.15, -0.15, 0.06), steps: 9 [0m
[32m success. goal: (-0.15, -0.15, 0.08), steps: 14 [0m
[31m truncated. goal: (-0.15, -0.15, 0.1), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.12), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.14), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.16), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.18), steps: 50 [0m
[32m success. goal: (-0.15, -0.15, 0.2), steps: 9 [0m
[32m success. goal: (-0.15, -0.15, 0.22), steps: 7 [0m
[32m success. goal: (-0.15, -0.15, 0.24), steps: 6 [0m
[31m truncated. goal: (-0.15, -0.15, 0.26), steps: 50 [0m
[32m success. goal: (-0.15, -0.15, 0.28), steps: 7 [0m
[32m success. goal: (-0.15, -0.15, 0.3), steps: 7 [0m
[32m success. goal: (-0.15, -0.13, 0.0), steps: 8 [0m
[32m success. go

3971

In [139]:
# Display the rows gathered from the successful rollouts (the frame's notebook
# repr is shown because this is the last expression in the cell).
total_df

Unnamed: 0,s_x,s_y,s_z,s_v_x,s_v_y,s_v_z,s_g_x,s_g_y,s_g_z,a_x,a_y,a_z
0,0.038440,-2.845172e-12,0.197400,-2.227948e-09,5.934769e-11,5.533001e-09,-0.15,-0.15,1.835453e-09,-1.000000,-0.629405,-0.955481
1,0.017125,-2.195832e-02,0.168294,-9.344083e-01,-5.874947e-01,-1.153287e+00,-0.15,-0.15,1.835453e-09,-0.714563,-0.491378,-0.693177
2,-0.011835,-4.348623e-02,0.130121,-2.971343e-01,-1.588004e-01,-5.053183e-01,-0.15,-0.15,1.835453e-09,-0.948292,-0.566342,-0.662128
3,-0.044668,-6.582601e-02,0.100810,-7.123211e-01,-3.249866e-01,-5.717320e-01,-0.15,-0.15,1.835453e-09,-0.668921,-0.425169,-0.736320
4,-0.071127,-8.361599e-02,0.066224,-4.527286e-01,-1.422098e-01,-6.750474e-01,-0.15,-0.15,1.835453e-09,-0.822079,-0.463381,-0.352213
...,...,...,...,...,...,...,...,...,...,...,...,...
13,0.142495,1.386164e-01,0.297192,7.042279e-03,6.813938e-03,9.437105e-03,0.15,0.15,3.000000e-01,0.110277,0.033291,0.043451
14,0.143413,1.393219e-01,0.298054,6.958847e-03,5.341846e-03,6.544074e-03,0.15,0.15,3.000000e-01,0.105515,0.028905,0.036114
15,0.144316,1.398712e-01,0.298645,6.861021e-03,4.161383e-03,4.481440e-03,0.15,0.15,3.000000e-01,0.101664,0.025511,0.030848
16,0.145201,1.403000e-01,0.299046,6.725033e-03,3.249471e-03,3.025537e-03,0.15,0.15,3.000000e-01,0.098381,0.022911,0.027075


In [140]:
# Average number of recorded rows in total_df per successful rollout.
total_df.shape[0] / success_cnt

5.928481490808361

In [141]:
# Write the collected rows to csv_save_name (a relative name, so it resolves
# against the current working directory); the index is not written.
total_df.to_csv(path_or_buf=csv_save_name, index=False)