In [None]:
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout

from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.recombination.record_multi import get_configurations
from flow_control.servoing.module import ServoingModule
from flow_control.runner import evaluate_control

renderer = "debug"
root_dir = "/tmp/flow_experiments3"
num_episodes = 20
control_config = dict(mode="pointcloud-abs-rotz", threshold=0.35)

# Map each demonstration seed to its recording directory. Keying by the seed
# (instead of list position) keeps `recordings[demo_seed]` correct even if
# get_configurations does not yield seeds as 0, 1, 2, ... in order.
demo_cfgs = get_configurations(prefix="demo", num_episodes=num_episodes)
recordings = {demo_seed: demo_dir for _, _, _, demo_seed, demo_dir in demo_cfgs}
    
def eval_cmb(live_seed, demo_seed, exist_ok=False):
    """Servo a live simulation episode toward a recorded demonstration.

    Builds a ``ServoingModule`` for the demonstration ``recordings[demo_seed]``
    and a ``RobotSimEnv`` seeded with ``live_seed``. It then calls
    ``evaluate_control`` with at most 130 steps, passing an output directory
    derived from ``root_dir`` and both seeds as ``save_dir``.

    Args:
        live_seed: Seed of the live ``RobotSimEnv`` episode.
        demo_seed: Key into ``recordings`` selecting the demonstration.
        exist_ok: If True, an existing output directory for this
            (live_seed, demo_seed) pairing is deleted before the run. If
            False, an existing directory is an error.

    Returns:
        tuple: ``(reward, save_dir)``, the reward returned by
        ``evaluate_control`` and the output directory used for the run.

    Raises:
        ValueError: If the output directory exists and ``exist_ok`` is False.
    """
    # Key the output directory on BOTH seeds. One demo is paired with several
    # live episodes, and naming the directory by demo_seed alone made later
    # runs delete the recordings of earlier ones.
    save_dir = f"{root_dir}_{live_seed:03d}_{demo_seed:03d}"
    demo_dir = recordings[demo_seed]
    if Path(save_dir).is_dir():
        if not exist_ok:
            raise ValueError(
                f"{save_dir} already exists; pass exist_ok=True to overwrite it")
        shutil.rmtree(save_dir)
    servo_module = ServoingModule(demo_dir, control_config=control_config,
                                  start_paused=False, plot=False, plot_save_dir=None)
    env = RobotSimEnv(task='shape_sorting', renderer=renderer, act_type='continuous',
                      initial_pose='close', max_steps=500, control='absolute-full',
                      img_size=(256, 256), param_randomize=("geom",), seed=live_seed)
    # evaluate_control returns a 4-tuple; only the reward (second item) is used.
    _, reward, _, _ = evaluate_control(env, servo_module,
                                       max_steps=130,
                                       save_dir=save_dir)
    return reward, save_dir

In [None]:
# Evaluate a single (live, demo) pairing. Another pair worth trying is
# live_seed=3, demo_seed=19.
live_seed = 5
demo_seed = 12
reward, save_dir = eval_cmb(live_seed, demo_seed, exist_ok=True)
print(f"live_i {live_seed} demo_seed {demo_seed} -> reward {reward}")

# Evaluate Multiple Combinations

For this, copy the live and demo episode lists (the pairings to evaluate) from the `Multi_Demo_Viewer.ipynb` notebook.

In [None]:
# Live/demo episode pairs copied from Multi_Demo_Viewer.ipynb: the i-th live
# episode is servoed toward the i-th demo episode.
live_episodes = [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 16, 18, 19]
demo_episodes = [19,  8, 12, 19,  3, 16, 12,  3, 14, 16,  9, 16,  7, 10,  7]
# zip() silently drops unmatched entries, so a copy-paste slip that leaves the
# lists with different lengths would go unnoticed without this check.
assert len(live_episodes) == len(demo_episodes), (
    f"{len(live_episodes)} live episodes vs {len(demo_episodes)} demo episodes")
rewards = []

for live_i, demo_i in zip(live_episodes, demo_episodes):
    reward, _ = eval_cmb(live_i, demo_i, exist_ok=True)
    rewards.append(reward)

In [None]:
import numpy as np

live_arr = np.array(live_episodes)
demo_arr = np.array(demo_episodes)
# A pairing counts as a failure when its reward is falsy (zero).
failed = np.logical_not(rewards)

print("live ", live_arr)
print("demo ", demo_arr)
print("rew  ", np.array(rewards, dtype=int))
print("mean:", np.mean(rewards))

print()
print("failures:")
print("live", live_arr[failed])
print("demo", demo_arr[failed])

In [None]:
# Load the recorded servoing run for playback. `save_dir` is the output
# directory returned by the single-combination eval_cmb call above (the
# multi-combination loop discards its directories).
from flow_control.servoing.playback_env_servo import PlaybackEnvServo

run_dirs = [save_dir]
playbacks = [PlaybackEnvServo(run_dir) for run_dir in run_dirs]

In [None]:
# Plot the demonstrations
%matplotlib notebook
fig, ax = plt.subplots(1,figsize=(8, 6))
fig.suptitle("Servoing Frames")
ax.set_axis_off()
image_h = ax.imshow(playbacks[0].cam.get_image()[0])

def update(run_index, frame_index):
    image = playbacks[run_index][frame_index].cam.get_image()[0]
    image_h.set_data(image)
    fig.canvas.draw_idle()
    #print("wp_name:", playbacks[run_index][frame_index].get_info()["wp_name"])
    fg_mask = playbacks[run_index].get_fg_mask()
    if fg_mask is not None:
        print("percent fg:", np.mean(fg_mask)*100)
    print(playbacks[run_index])
    
slider_w = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
                             layout=Layout(width='70%'))
slider_i = widgets.IntSlider(min=0, max=200-1, step=1, value=0,
                             layout=Layout(width='70%'))

interact(update, run_index=slider_w, frame_index=slider_i)

In [None]:
from collections import defaultdict

print("num_steps", len(playbacks[0]))

servo_keys = ["loss", "demo_index", "threshold"]
servo_list = defaultdict(list)
for i in range(len(playbacks[0])):
    info = playbacks[0][i].get_info()
    for key in servo_keys:
        try:
            servo_list[key].append(info[key])
        except KeyError:
            # Record NaN instead of skipping, so every series stays aligned
            # with the frame index; matplotlib draws NaN points as gaps.
            servo_list[key].append(float("nan"))

fig, ax = plt.subplots(1, figsize=(8, 6))
ax2 = ax.twinx()
ax.plot(servo_list["loss"], label="loss")
ax.plot(servo_list["threshold"], color="k", label="threshold")
ax2.plot(servo_list["demo_index"], color="orange", label="demo_index")
ax.set_title("Servoing loss and demo index per frame")
ax.set_xlabel("frame")
ax.set_ylabel("loss / threshold")
ax2.set_ylabel("demo_index")
# plt.legend() only covers the current axes (the twin created by twinx), which
# would hide "loss" and "threshold"; merge the handles of both axes instead.
handles, labels = ax.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(handles + handles2, labels + labels2)
plt.show()