# Demonstration Recording Script

This notebook records multiple demonstrations in a simulated environment and provides widgets to inspect them.

It has special functions that allow re-orienting the objects; this is done with the `update_object_orn` function, which reads the object pose and overwrites the object's orientation while keeping its position.

TODOs:
1. Make the rotated grasping policy better (it touches the object too often)
2. Combine both sources of demonstrations and make sure the correct one is selected.

In [None]:
import os
import time
import json
import shutil
import unittest
import subprocess
from pathlib import Path

import numpy as np
from scipy.spatial.transform import Rotation as R

from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.demo.demo_episode_recorder import record_sim
from flow_control.runner import evaluate_control
from flow_control.servoing.module import ServoingModule

# Renderer mode passed to every RobotSimEnv created in this notebook.
renderer = "debug"

def get_configurations(root_dir="/tmp/flow_experiments4", num_episodes=10, prefix=""):
    """Generate the recording configurations for the shape-sorting task.

    This is a generator: ``root_dir`` is created (if missing) only once
    iteration starts.

    Args:
        root_dir: directory under which one sub-directory per recording goes.
        num_episodes: number of seeds, enumerated as ``0 .. num_episodes - 1``.
        prefix: optional directory-name prefix; ``""`` means no prefix.

    Yields:
        ``(object_selected, orn_name, orn, seed, save_dir)`` for every seed and
        every orientation preset, with seeds in the outer loop. ``orn`` is the
        quaternion returned by scipy's ``Rotation.as_quat()`` (scalar-last).
    """
    task = "shape_sorting"
    object_selected = "trapeze"  # other objects tried: "semicircle", "oval"

    # Orientation presets (name -> quaternion); uncomment to record more variants.
    orn_options = {
        # "rR": None,  # rotation is randomized
        # "rN": R.from_euler("xyz", (0, 0, 0), degrees=True).as_quat(),
        # "rZ": R.from_euler("xyz", (0, 0, 20), degrees=True).as_quat(),
        "rY": R.from_euler("xyz", (0, 90, 0), degrees=True).as_quat(),
        # "rX": R.from_euler("xyz", (90, 0, 0), degrees=True).as_quat(),
        # "rXZ": R.from_euler("xyz", (180, 0, 160), degrees=True).as_quat(),
    }

    os.makedirs(root_dir, exist_ok=True)

    name_stem = f"{task}_{object_selected}"
    if prefix != "":
        name_stem = f"{prefix}_{name_stem}"
    dir_stem = os.path.join(root_dir, name_stem)

    for seed in range(num_episodes):
        for orn_name, orn in orn_options.items():
            yield (object_selected, orn_name, orn, seed,
                   f"{dir_stem}_{orn_name}_seed{seed:03d}")

In [None]:
def update_object_orn(env, object_selected, orn=None):
    """Reset the orientation of a task object while keeping its position.

    Args:
        env: the simulation environment; ``env._task.object_uids`` maps object
            names to body ids, ``env.p`` provides the (presumably pybullet)
            ``get/resetBasePositionAndOrientation`` calls and ``env.cid`` is
            the physics client id passed to them.
        object_selected: key into ``env._task.object_uids``.
        orn: target orientation quaternion, e.g. the ``orn`` value yielded by
            ``get_configurations``. ``None`` leaves the current orientation
            untouched (previously ``orn`` was read from a global variable and
            ``None`` would have been passed straight to the reset call).
    """
    object_uid = env._task.object_uids[object_selected]
    # Pass the client id on every call so the query and the reset address the
    # same physics client.
    object_pos, object_orn = env.p.getBasePositionAndOrientation(
        object_uid, physicsClientId=env.cid)
    print("pos & orn", object_pos, object_orn)
    if orn is None:
        return

    env.p.resetBasePositionAndOrientation(object_uid, object_pos, orn,
                                          physicsClientId=env.cid)

    object_pos, object_orn = env.p.getBasePositionAndOrientation(
        object_uid, physicsClientId=env.cid)
    print("pos & orn2", object_pos, object_orn)


demo_cfgs = get_configurations(num_episodes=10, prefix="demo")
for object_selected, orn_name, orn, seed, save_dir in demo_cfgs:
    param_info = {"object_selected": object_selected}

    env = RobotSimEnv(task='shape_sorting', renderer=renderer, act_type='continuous',
                      initial_pose='close', max_steps=200, control='absolute-full',
                      img_size=(256, 256),
                      param_randomize=("geom",),
                      param_info=param_info,
                      seed=seed)

    # Apply this configuration's orientation preset explicitly.
    update_object_orn(env, object_selected, orn)

    if os.path.isdir(save_dir):
        # Start each recording in a fresh directory. If rmtree fails on an NFS
        # mount, use `lsof` to find processes still holding files open.
        shutil.rmtree(save_dir)
    record_sim(env, save_dir)

    del env
    time.sleep(.5)
    print(save_dir)

## TODO for Today: 
1. View multiple demonstrations
2. Pick a few good ones.

In [None]:
from flow_control.servoing.playback_env_servo import PlaybackEnvServo

# Collect the demonstration directories in the same order they were recorded.
demo_cfgs = get_configurations(num_episodes=10, prefix="demo")
recordings = [demo_dir for *_, demo_dir in demo_cfgs]

print("Number of recordings:", len(recordings))
print(recordings[0])
print(recordings[-1])
# Load the demonstration episodes
playbacks = [PlaybackEnvServo(rec) for rec in recordings]

In [None]:
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout

# Plot the demonstrations
%matplotlib notebook
fig, ax = plt.subplots(1, figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
image_h = ax.imshow(playbacks[0].cam.get_image()[0])


def update(demo_index, frame_index, fig=fig, image_h=image_h):
    """Show frame ``frame_index`` of demo ``demo_index`` and print its info.

    ``fig`` and ``image_h`` are bound when this cell runs (and handed to
    ``interact`` as ``fixed`` values), so a later cell that rebinds the
    global ``fig``/``image_h`` does not redirect this viewer's drawing.
    """
    playback = playbacks[demo_index]
    # Demos can be shorter than the slider range: clamp to the last frame
    # instead of indexing past the end.
    frame_index = min(frame_index, len(playback) - 1)
    image = playback[frame_index].cam.get_image()[0]
    image_h.set_data(image)
    fig.canvas.draw_idle()
    print("wp_name:", playback[frame_index].get_info()["wp_name"])
    fg_mask = playback.get_fg_mask()
    if fg_mask is not None:
        print("percent fg:", np.mean(fg_mask) * 100)


# Size the frame slider by the longest loaded demo rather than a fixed 200.
longest_demo = max(len(playback) for playback in playbacks)
slider_w = widgets.IntSlider(min=0, max=len(playbacks) - 1, step=1, value=0,
                             layout=Layout(width='70%'))
slider_i = widgets.IntSlider(min=0, max=longest_demo - 1, step=1, value=0,
                             layout=Layout(width='70%'))

interact(update, demo_index=slider_w, frame_index=slider_i,
         fig=widgets.fixed(fig), image_h=widgets.fixed(image_h))

In [None]:
# Indices into `playbacks` of demonstrations picked as good by visual
# inspection with the viewer above (index 9 is a possible further candidate).
good_demonstrations = [
    5,
    6,
    7,
]

## Now show what the image at the end of each block looks like (a block is a run of consecutive waypoints whose names share the prefix before the first `_`)

In [None]:
def get_block_name(wp_name):
    """Return the block part of a waypoint name: everything before the first ``_``.

    A name without ``_`` is returned unchanged.
    """
    head, _sep, _rest = wp_name.partition("_")
    return head

# For each good demo, map block name -> index of the last frame of that block
# (the frame after which the waypoint's block name changes). The final block
# of a demo has no following transition and is therefore not recorded.
block_ends = {}

for demo_index in good_demonstrations:
    playback = playbacks[demo_index]
    # Read every frame's waypoint name once (the pairwise comparison below
    # would otherwise fetch each frame's info twice).
    wp_names = [playback[i].get_info()["wp_name"] for i in range(len(playback))]
    block_end_frames = {}
    for frame_index, (wp_name, next_name) in enumerate(zip(wp_names, wp_names[1:])):
        block_name = get_block_name(wp_name)
        if block_name != get_block_name(next_name):
            print("wp_name:", wp_name, "@", frame_index)
            block_end_frames[block_name] = frame_index
    block_ends[demo_index] = block_end_frames
    print()

In [None]:
# Plot the end of segments
%matplotlib notebook
fig, ax = plt.subplots(1, figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
image_h = ax.imshow(playbacks[0].cam.get_image()[0])


def update(w, i, fig=fig, image_h=image_h):
    """Show the frame where the ``i``-th block of the ``w``-th good demo ends.

    ``w`` indexes the demos in ``block_ends`` (insertion order) and ``i``
    indexes that demo's block ends. ``fig``/``image_h`` are bound when this
    cell runs (and passed to ``interact`` as ``fixed``), so rebinding the
    globals elsewhere does not redirect this viewer's drawing.
    """
    demo_index = list(block_ends)[w]
    end_frames = list(block_ends[demo_index].values())
    if not end_frames:
        print("no block ends found for demo", demo_index)
        return
    # Demos can have different numbers of blocks; clamp the shared slider to
    # this demo's last block end.
    frame_index = end_frames[min(i, len(end_frames) - 1)]

    playback = playbacks[demo_index]
    image = playback[frame_index].cam.get_image()[0]
    image_h.set_data(image)
    fig.canvas.draw_idle()
    print("wp_name:", playback[frame_index].get_info()["wp_name"])
    fg_mask = playback.get_fg_mask()
    if fg_mask is not None:
        print("percent fg:", np.mean(fg_mask) * 100)


# The second slider indexes block ends, so its range is the largest number of
# block ends of any demo (at least one position).
max_blocks = max([len(ends) for ends in block_ends.values()] + [1])
slider_w = widgets.IntSlider(min=0, max=len(block_ends) - 1, step=1, value=0,
                             layout=Layout(width='70%'))
slider_i = widgets.IntSlider(min=0, max=max_blocks - 1, step=1, value=0,
                             layout=Layout(width='70%'))

interact(update, w=slider_w, i=slider_i,
         fig=widgets.fixed(fig), image_h=widgets.fixed(image_h))

### Check if the block size is still randomized