In [None]:
import os
import time
import json
import shutil
import unittest
import subprocess
from pathlib import Path
import numpy as np

from scipy.spatial.transform import Rotation as R

from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.demo.demo_episode_recorder import record_sim
from flow_control.runner import evaluate_control
from flow_control.servoing.module import ServoingModule
from flow_control.servoing.playback_env_servo import PlaybackEnvServo
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout
import seaborn as sns

%matplotlib inline

# Directory holding the recorded demonstration episodes; each sub-directory is
# loaded as one PlaybackEnvServo episode in the next cell.
root_dir = "../tmp/ss_traj_based"

In [None]:
recordings = sorted([os.path.join(root_dir, rec) for rec in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, rec))])

# Load the demonstration episodes
playbacks = [PlaybackEnvServo(rec) for rec in recordings[:]]

# Plot the demonstrations
%matplotlib notebook
fig, ax = plt.subplots(1,figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
image_h = ax.imshow(playbacks[0].cam.get_image()[0])

def update(demo_index, frame_index):
    """Slider callback for the demonstration viewer.

    Draws frame ``frame_index`` of demonstration ``demo_index`` into the shared
    image artist and prints that frame's waypoint name. If the episode has a
    foreground mask, it also prints the mask mean times 100 (percent foreground,
    assuming a 0/1 mask).
    """
    demo = playbacks[demo_index]
    frame = demo[frame_index]
    image_h.set_data(frame.cam.get_image()[0])
    fig.canvas.draw_idle()
    print("wp_name:", frame.get_info()["wp_name"])
    fg_mask = demo.get_fg_mask()
    if fg_mask is None:
        return
    print("percent fg:", np.mean(fg_mask) * 100)
    
# Slider selecting which loaded demonstration to show.
slider_w = widgets.IntSlider(
    min=0,
    max=len(playbacks) - 1,
    step=1,
    value=0,
    layout=Layout(width='70%'),
)
# Slider selecting the frame within the demonstration.
# NOTE(review): the range is fixed at 0..199 rather than derived from the
# selected demo's length — TODO confirm every demo has at least 200 frames.
slider_i = widgets.IntSlider(
    min=0,
    max=199,
    step=1,
    value=0,
    layout=Layout(width='70%'),
)

interact(update, demo_index=slider_w, frame_index=slider_i)

In [None]:
def filter_demo(pb, min_fg_fraction=0.005):
    """Decide whether a recorded demonstration is usable.

    A demo is kept when its last step has a positive reward and its foreground
    mask covers more than ``min_fg_fraction`` of the image (mask mean).

    Args:
        pb: PlaybackEnvServo episode; ``pb[-1].data['rew']`` is the final reward.
        min_fg_fraction: minimum mean of the foreground mask (default 0.005,
            i.e. 0.5%).

    Returns:
        bool: True if the demo passes both checks.
    """
    if not pb[-1].data['rew'] > 0:
        return False
    fg_mask = pb.get_fg_mask()
    if fg_mask is None:
        # get_fg_mask() may return None (the viewer above checks for this);
        # without a mask we cannot verify the object is visible, so reject
        # instead of crashing inside np.mean.
        return False
    return bool(np.mean(fg_mask) > min_fg_fraction)

# Per-episode pass/fail flags, then the indices of the episodes that passed.
demo_good = list(map(filter_demo, playbacks))
good_demonstrations = np.where(demo_good)[0]
print(good_demonstrations)
# Convert to plain Python ints; they are used as env seeds and list indices.
good_demonstrations = list(map(int, good_demonstrations))
# Alias of good_demonstrations (not referenced elsewhere in the visible notebook).
live_seeds = good_demonstrations

In [None]:
from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.servoing.module import ServoingModule
from flow_control.runner import evaluate_control
from math import pi

def eval_cmb_single(save_dir, live_seed, demo_rec, object_selected='trapeze',
                    task_variant='rR', servo_max_steps=130):
    """Run one servoing episode of a live env against a single demonstration.

    Args:
        save_dir: directory passed to evaluate_control for saving the run.
        live_seed: seed of the live shape-sorting RobotSimEnv.
        demo_rec: path of the demonstration recording given to ServoingModule.
        object_selected: object name passed to the env via param_info.
        task_variant: key selecting the object rotation range, "rP" or "rR"
            (any other value raises KeyError).
        servo_max_steps: step budget passed to evaluate_control as max_steps.

    Returns:
        The reward returned by evaluate_control for the episode.
    """
    renderer = "debug"
    task = 'shape_sorting'
    control_config = dict(mode="pointcloud-abs-rotz", threshold=0.4)
    # Object rotation range passed to task_info for each task variant.
    object_rot_ranges = {"rP": pi / 2., "rR": pi / 6.}
    param_info = {"object_selected": object_selected}
    env = RobotSimEnv(task=task, renderer=renderer, act_type='continuous',
                      initial_pose='close', max_steps=200, control='absolute-full',
                      img_size=(256, 256),
                      param_randomize=("geom",),
                      param_info=param_info,
                      task_info=dict(object_rot_range=object_rot_ranges[task_variant]),
                      seed=live_seed)
    # NOTE(review): env is never closed here. Confirm whether evaluate_control
    # closes it; otherwise simulator instances may accumulate over a sweep.

    servo_module = ServoingModule(demo_rec, control_config=control_config,
                                  start_paused=False, plot=False, plot_save_dir=None)
    _, reward, _, _ = evaluate_control(env, servo_module,
                                       max_steps=servo_max_steps,
                                       save_dir=save_dir)
    return reward

In [None]:
# Cross-evaluation sweep for the rR variant. Each of the first N_EVAL_DEMOS good
# demos is used both as a live env seed and as a demonstration; rewards[i, j]
# is the reward for eval_seeds[i] (live) servoing with eval_seeds[j] (demo).
N_EVAL_DEMOS = 10
eval_seeds = good_demonstrations[:N_EVAL_DEMOS]
# Sized from the seeds actually evaluated. A fixed 10x10 left extra all-zero
# rows/columns whenever fewer than 10 good demos exist.
rewards = np.zeros((len(eval_seeds), len(eval_seeds)))
save_root = "../tmp/single_demo_run_ss_rR"
rewards_path = "../tmp/single_demo_rewards_ss_rR.npz"
os.makedirs(save_root, exist_ok=True)

for live_idx, live_seed in enumerate(eval_seeds):
    for demo_idx, demo_seed in enumerate(eval_seeds):
        if live_seed == demo_seed:
            # Self-pairs are not run; their entry stays 0.
            continue
        demo_rec = recordings[demo_seed]
        save_dir = f"{save_root}/run_ss_trapeze_rR_{live_seed}_{demo_seed}"
        reward = eval_cmb_single(save_dir, live_seed, demo_rec)
        rewards[live_idx, demo_idx] = reward
        # Checkpoint after every episode so completed results survive a crash.
        np.savez(rewards_path, rewards)

In [None]:
# runs = [os.path.join(save_root, run) for run in os.listdir(save_root)]
# run_playbacks = [PlaybackEnvServo(rec) for rec in runs]

In [None]:
# Rebinds save_root (previously the rR sweep directory) to the directory of the
# rP runs that the next cell reads back.
save_root = "../tmp/single_demo_run_ss"

In [None]:
# Named differently from the demonstration filter `filter_demo` so that running
# this cell does not shadow it. Shadowing would make a re-run of the demo-filter
# cell silently use this run-success criterion instead.
def filter_run(pb, expected_object=2):
    """Return 1 if a servoing run succeeded on the expected object, else 0.

    Args:
        pb: PlaybackEnvServo of a saved run. Its last step's ``data`` must hold
            'rew', and also 'info' when the reward is positive.
        expected_object: required value of the first info entry's
            'object_selected'. NOTE(review): 2 presumably identifies the
            trapeze — confirm.

    Returns:
        int: 1 on success, 0 otherwise.
    """
    last_step = pb[-1]
    # 'info' is only inspected for rewarded runs (short-circuit as before).
    if not last_step.data['rew'] > 0:
        return 0
    info = np.atleast_1d(last_step.data['info'])[0]
    return 1 if info['object_selected'] == expected_object else 0


# Aggregate the rP runs over the same seed subset the sweep evaluated.
rp_eval_seeds = good_demonstrations[0:10]
n_rp = len(rp_eval_seeds)
single_rewards = np.zeros((n_rp, n_rp))  # raw final reward per (live, demo) pair
mask = np.zeros((n_rp, n_rp))            # filtered 0/1 success per pair

# Rebinds `rewards` (the rR matrix) to a list with one row per live seed of
# filtered 0/1 successes, self-pairs excluded.
rewards = []
for live_idx, live_i in enumerate(rp_eval_seeds):
    rew_temp = []
    for demo_idx, demo_i in enumerate(rp_eval_seeds):
        if live_i == demo_i:
            continue
        save_dir = f"{save_root}/run_ss_trapeze_rP_{live_i}_{demo_i}"
        pb = PlaybackEnvServo(save_dir)
        filtered_reward = filter_run(pb)
        rew_temp.append(filtered_reward)

        mask[live_idx, demo_idx] = filtered_reward
        single_rewards[live_idx, demo_idx] = pb[-1].data['rew']
    rewards.append(rew_temp)
print(np.mean(rewards))
        

In [None]:
# Success rate per column (demo position), then its spread across columns.
rew_mean = np.asanyarray(rewards).mean(axis=0)
np.std(rew_mean)

In [None]:
# Load the rP reward matrix saved by np.savez (stored under the default key
# 'arr_0'). The context manager closes the NpzFile's file handle, which a bare
# np.load(...)[key] leaves open.
# NOTE(review): if the rP sweep used the same loop as the rR sweep, self-pairs
# were never run and keep their initial value, so they count in the means below.
with np.load("../tmp/single_demo_rewards_ss_rP.npz") as rewards_file:
    rew = rewards_file['arr_0']
rew_mean = np.mean(rew, axis=0)
np.std(rew_mean)

In [None]:
# import numpy as np
# rew = np.array([[-1., -1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,
#          0.,  1.,  1.,  0.,  1.,  1.,  1.],
#        [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
#         -1., -1., -1., -1., -1., -1., -1.],
#        [ 1., -1., -1.,  0.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  0.,  1.,
#          1.,  0.,  0.,  0.,  0.,  1.,  0.],
#        [ 0., -1.,  0., -1.,  0.,  0.,  1.,  1.,  0.,  0.,  1.,  1.,  0.,
#          1.,  1.,  1.,  0.,  1.,  1.,  1.],
#        [ 1., -1.,  0.,  0., -1.,  0.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,
#          1.,  0.,  1.,  0.,  1.,  1.,  1.],
#        [ 1., -1.,  1.,  0.,  0., -1.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,
#          1.,  1.,  1.,  0.,  1.,  1.,  0.],
#        [ 0., -1.,  1.,  0.,  1.,  0., -1.,  0.,  1.,  0.,  1.,  1.,  1.,
#          1.,  0.,  1.,  0.,  1.,  1.,  1.],
#        [ 0., -1.,  1.,  0.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  0.,  1.,
#          1.,  0.,  1.,  0.,  1.,  1.,  1.],
#        [ 1., -1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  0.,  0.,  0.,  1.,
#          1.,  0.,  1.,  0.,  1.,  0.,  1.],
#        [ 0., -1.,  1.,  0.,  1.,  1.,  1.,  0.,  1., -1.,  1.,  0.,  0.,
#          1.,  0.,  1.,  0.,  1.,  1.,  0.],
#        [ 1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,
#          1.,  1.,  0.,  0.,  1.,  1.,  0.],
#        [ 0., -1.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  1., -1.,  1.,
#          0.,  1.,  1.,  1.,  1.,  1.,  1.],
#        [ 1., -1.,  1.,  0.,  0.,  1.,  0.,  1.,  1.,  1.,  0.,  0., -1.,
#          1.,  0.,  0.,  0.,  1.,  0.,  1.],
#        [ 1., -1.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,
#         -1.,  0.,  1.,  0.,  1.,  1.,  1.],
#        [ 0., -1.,  1.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,
#          1., -1.,  1.,  1.,  1.,  1.,  1.],
#        [ 0., -1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,
#          1.,  0., -1.,  0.,  1.,  1.,  1.],
#        [ 0., -1.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,
#          1.,  0.,  1., -1.,  1.,  1.,  1.],
#        [ 0., -1.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,
#          1.,  1.,  1.,  0., -1.,  1.,  1.],
#        [ 0., -1.,  1.,  1.,  1.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,
#          0.,  0.,  1.,  1.,  1., -1.,  0.],
#        [ 1., -1.,  1.,  0.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,
#          1.,  1.,  1.,  0.,  1.,  1., -1.]])

In [None]:
# `rew_zero_one` is not defined anywhere else in this notebook, so this cell
# raised NameError on a fresh kernel. Rebuild it from `rew` (the rP matrix
# loaded earlier), clipping into [0, 1] so negative markers such as -1 count
# as failures.
# NOTE(review): confirm this matches the original (deleted) definition.
rew_zero_one = np.clip(rew, 0, 1)
np.mean(rew_zero_one)

In [None]:
# Per-column (per-demo) mean of the 0/1 success matrix.
rew_mean = np.asanyarray(rew_zero_one).mean(axis=0)

In [None]:
# Spread of the per-demo success rate, matching the other summary cells, which
# report np.std of the axis-0 mean. The std over all individual 0/1 entries
# only reflects the overall success rate.
np.std(np.mean(rew_zero_one, axis=0))