# 以固定间隔枚举目标空间采样

In [149]:
from pathlib import Path
import sys
import gymnasium as gym
import panda_gym
import numpy as np
import pandas as pd
from gymnasium.envs.registration import register
from gymnasium.wrappers import FlattenObservation

# The project root is the parent of the notebook's working directory. Put it on
# sys.path (once) so the project-local imports that follow can be resolved.
PROJECT_ROOT_DIR = Path().absolute().parent
_project_root_str = str(PROJECT_ROOT_DIR.absolute())
if _project_root_str not in sys.path:
    sys.path.append(_project_root_str)

from rollout.rollout_by_policy import rollout_by_goal_with_policy
from utils.sb3_env_wrappers import ScaledObservationWrapper
from my_reach_env import MyPandaReachEnv
from utils.load_data import load_data, split_data
from models.sb3_model import PPOWithBCLoss

In [150]:
# ---- experiment configuration ----

# Extent of the goal grid swept later in this notebook; also forwarded to the env
# as its `goal_range` kwarg.
goal_range = 0.3
# Forwarded to the env as its `distance_threshold` kwarg.
distance_threshold = 0.01

# Sub-directory under checkpoints/bc/ holding the policy checkpoint to evaluate.
EXPERIMENT_NAME = "iter_3/reacher_256_256_100epochs_loss_5"

# Expert data CSV, relative to the project root. Its observation scaler is reused
# to wrap the env. Earlier data source: "rollout/cache/myreach_pid_speed_1.5.csv".
EXPERT_DATA_CACHE_DIR = "rollout/cache/myreach_from_iter_2_rl_bc.csv"

# Output CSV for the successful trajectories, relative to the working directory.
csv_save_name = "iter_3_myreach_bc_5.csv"

In [151]:
# Register the custom reach env under the id "my-reach". Re-running this cell
# overrides the existing spec, and gymnasium logs an "Overriding environment" warning.
my_reach_kwargs = dict(
    reward_type="sparse",
    control_type="ee",
    goal_range=goal_range,
    distance_threshold=distance_threshold,
)
register(
    id="my-reach",
    entry_point="my_reach_env:MyPandaReachEnv",
    max_episode_steps=50,
    kwargs=my_reach_kwargs,
)

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [152]:
# Load the expert data. The observation scaler returned by load_data is reused so
# the env's flattened observations go through the same scaling.
data_file: Path = PROJECT_ROOT_DIR / EXPERT_DATA_CACHE_DIR
scaled_obs, acts, infos, obs_scaler = load_data(data_file)

env = ScaledObservationWrapper(
    env=FlattenObservation(gym.make("my-reach")),
    scaler=obs_scaler,
)

argv[0]=--background_color_red=0.8745098114013672
argv[1]=--background_color_green=0.21176470816135406
argv[2]=--background_color_blue=0.1764705926179886


In [153]:
# Load the policy checkpoint for EXPERIMENT_NAME from checkpoints/bc/.
policy_save_dir = PROJECT_ROOT_DIR.joinpath("checkpoints", "bc", EXPERIMENT_NAME)
checkpoint_path = (policy_save_dir / "bc_checkpoint").absolute()
algo_ppo = PPOWithBCLoss.load(str(checkpoint_path))

In [154]:
# Sweep goals on a regular integer grid, expressed in hundredths of a goal unit
# and divided by 100 before being passed to the env. Keep the trajectory of every
# rollout that terminated; each one counts as a success.
# Bounds: x, y in [-goal_range/2, goal_range/2]; z in [0, goal_range].
grid_step = 2  # grid spacing in hundredths, i.e. 0.02 between neighbouring goals
speed = 2  # NOTE(review): not used by any visible cell; TODO remove if obsolete

# Use round(), not int(): float error makes e.g. int(0.29 * 100) == 28.
goal_range_int = round(goal_range * 100)
half_range = goal_range_int // 2
x_low, x_high = -half_range, half_range
y_low, y_high = -half_range, half_range
z_low, z_high = 0, goal_range_int

success_frames = []
for x in range(x_low, x_high + 1, grid_step):
    for y in range(y_low, y_high + 1, grid_step):
        for z in range(z_low, z_high + 1, grid_step):
            goal = np.array([x / 100.0, y / 100.0, z / 100.0])
            terminated, truncated, traj = rollout_by_goal_with_policy(
                env=env, goal=goal, policy=algo_ppo.policy
            )
            if terminated:
                success_frames.append(pd.DataFrame(traj))

success_cnt = len(success_frames)
# Concatenate once at the end. Calling pd.concat inside the loop re-copies every
# previously collected row on each success, which is quadratic. The index is not
# reset, which matches the earlier per-iteration concat; total_df stays None when
# no rollout succeeded.
total_df = pd.concat(success_frames) if success_frames else None

success_cnt

[31m truncated. goal: (-0.15, -0.15, 0.0), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.02), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.04), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.06), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.08), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.1), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.12), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.14), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.15, 0.16), steps: 50 [0m
[32m success. goal: (-0.15, -0.15, 0.18), steps: 11 [0m
[32m success. goal: (-0.15, -0.15, 0.2), steps: 10 [0m
[32m success. goal: (-0.15, -0.15, 0.22), steps: 10 [0m
[32m success. goal: (-0.15, -0.15, 0.24), steps: 9 [0m
[32m success. goal: (-0.15, -0.15, 0.26), steps: 9 [0m
[32m success. goal: (-0.15, -0.15, 0.28), steps: 9 [0m
[31m truncated. goal: (-0.15, -0.15, 0.3), steps: 50 [0m
[31m truncated. goal: (-0.15, -0.13, 0.0), steps: 50 [0m


2695

In [155]:
# Successful trajectories, concatenated in sweep order. pd.concat is called without
# ignore_index, so the index restarts for each trajectory (visible in the output).
total_df

Unnamed: 0,s_x,s_y,s_z,s_v_x,s_v_y,s_v_z,s_g_x,s_g_y,s_g_z,a_x,a_y,a_z
0,0.038440,-2.845172e-12,0.197400,-2.227948e-09,5.934769e-11,5.533001e-09,-0.15,-0.15,0.18,-0.818392,-1.000000,-0.026094
1,0.012749,-2.571467e-02,0.194471,-2.681201e-01,-1.240516e+00,1.325766e-01,-0.15,-0.15,0.18,-1.000000,-0.739918,-0.103200
2,-0.019951,-6.730323e-02,0.191532,-4.324457e-01,-3.996804e-01,-6.455885e-02,-0.15,-0.15,0.18,-0.767588,-0.587459,0.073793
3,-0.048744,-9.259280e-02,0.192626,-1.921438e-01,-2.722411e-01,1.516464e-01,-0.15,-0.15,0.18,-0.828790,-0.612349,-0.069604
4,-0.077410,-1.182547e-01,0.190659,-3.708776e-01,-3.681294e-01,-5.332219e-02,-0.15,-0.15,0.18,-0.560042,-0.487128,0.086997
...,...,...,...,...,...,...,...,...,...,...,...,...
7,0.134086,1.375646e-01,0.258653,3.742145e-02,1.093792e-01,1.202703e-01,0.15,0.15,0.28,0.310955,0.205043,0.231976
8,0.136880,1.442531e-01,0.266102,2.187905e-02,5.616264e-02,7.397930e-02,0.15,0.15,0.28,0.226537,0.084572,0.147792
9,0.138971,1.463822e-01,0.270720,1.655194e-02,1.547260e-02,3.929159e-02,0.15,0.15,0.28,0.187572,0.017968,0.091029
10,0.141016,1.459456e-01,0.273284,1.561577e-02,-3.725023e-03,1.993250e-02,0.15,0.15,0.28,0.176052,-0.002270,0.062020


In [156]:
# Average number of recorded rows per successful trajectory.
mean_rows_per_success = len(total_df) / success_cnt
mean_rows_per_success

6.657884972170686

In [157]:
# Write the successful trajectories to CSV without the (non-unique) index column.
# The path is relative to the current working directory.
output_csv_path = csv_save_name
total_df.to_csv(output_csv_path, index=False)