
# Demonstration Viewer

This notebook is for viewing and processing demonstrations so that they can be used for servoing. It has two main components. The first is subsampling the trajectory to obtain a set of keysteps to which we want to servo. It is not necessary to servo to every intermediate state, and skipping them speeds up servoing dramatically.

The second is computing a foreground segmentation mask. This is needed in order to focus on the objects of interest and to avoid being confounded by other objects in the scene.

Script Arguments:

    recording: the directory in which the recording is located. Should include:
        `frame_000000.npz`
        `camera_info.npz`
   
Returns:
    `servo_keep.json`
    `servo_mask.npz`
    

# Setup

First we start by loading the demonstration.

In [None]:
import os
import numpy as np
import logging

def is_notebook():
    """Report whether the code is running inside a Jupyter kernel.

    Only the ZMQ-based shell (Jupyter notebook / qtconsole) counts as a
    notebook. A terminal IPython session, any other shell type, and the plain
    Python interpreter (where ``get_ipython`` does not exist) all give False.
    """
    try:
        shell_name = type(get_ipython()).__name__
    except NameError:
        # get_ipython is injected by IPython; without it we are in plain Python.
        return False
    return shell_name == 'ZMQInteractiveShell'

# True only inside a Jupyter kernel (see is_notebook); the plotting/widget
# code further down is guarded by this flag.
interactive = is_notebook()  # becomes overwritten
if interactive:
    # Interactive-only dependencies: these are imported only in this branch,
    # so non-interactive runs never import ipywidgets or matplotlib here.
    get_ipython().run_line_magic('matplotlib', 'notebook')
    from ipywidgets import widgets, interact, Layout
    import matplotlib.pyplot as plt

# Load Demonstration

In [None]:
import copy
import json
from tqdm import tqdm

from robot_io.envs.playback_env import PlaybackEnv
from robot_io.recorder.simple_recorder import unprocess_seg

if interactive:
    # set parameters here
    #recording = "../tmp_test/pick_n_place/"
    #recording = "../tmp_test/shape_sorting_trapeze_rN/"
    recording = "/home/argusm/CLUSTER/robot_recordings/flow/sick_wtt/16-51-30"
else:
    # expect commandline input: Demonstration_Viewer.py <episode_dir>
    import sys
    if len(sys.argv) < 2:
        print("Usage: Demonstration_Viewer.py <episode_dir>")
        sys.exit(1)
    recording = sys.argv[1]

if not os.path.isdir(recording):
    raise ValueError(f"Recording directory not found: {recording}")

# Input config and output files all live inside the recording directory.
segment_conf_fn = os.path.join(recording, "segment_conf.json")
keep_fn = os.path.join(recording, "servo_keep.json")
mask_fn = os.path.join(recording, "servo_mask.npz")

# Optional color-segmentation config. `orig_seg_conf` keeps the values as
# loaded so later cells can report thresholds changed interactively;
# `seg_conf` is the working copy.
try:
    with open(segment_conf_fn, "r") as f_obj:
        orig_seg_conf = json.load(f_obj)
except FileNotFoundError:
    seg_conf = None
    orig_seg_conf = None
else:
    if isinstance(orig_seg_conf["objects"], list):
        # some configs store `objects` as a list; only the first entry is used
        orig_seg_conf["objects"] = orig_seg_conf["objects"][0]
    seg_conf = copy.deepcopy(orig_seg_conf)

In [None]:
# Load every frame of the recording and extract the per-frame signals used
# below: camera images, motion actions, TCP pose and gripper width.
rec = PlaybackEnv(recording).to_list()
video_recording = np.array([renv.cam.get_image()[0] for renv in rec])
actions = np.array([renv.get_action()["motion"] for renv in rec], dtype=object)
tcp_pos = np.array([renv.robot.get_tcp_pos() for renv in rec])
tcp_orn = np.array([renv.robot.get_tcp_orn() for renv in rec])
gripper_width = np.array([renv.robot.gripper.width() for renv in rec])

# The gripper command is the last component of each motion action
# (for 3-component actions that is column 2, so no special case is needed).
gripper_actions = np.array(actions[:, -1], dtype=float)

assert gripper_width.ndim == 1
gripper_width = np.array(gripper_width, dtype=float)
# Fraction of frames whose gripper width reading is NaN.
percent_invalid = np.sum(np.isnan(gripper_width))/len(gripper_width)
if percent_invalid > .1:
    logging.warning("gripper positions are not valid")
    gripper_width_valid = False
else:
    gripper_width_valid = True

# Optional per-frame info written by the recorder: simulator segmentation
# mask, move anchor and waypoint name (None / "" when absent).
masks_sim_l = []
move_anchors = []
wp_names = []
for frame_env in rec:
    frame_info = frame_env.data["info"].item()
    if "seg_mask" in frame_info:
        masks_sim_l.append(unprocess_seg(frame_info["seg_mask"])[0])
    else:
        masks_sim_l.append(None)
    move_anchors.append(frame_info.get("move_anchor"))
    wp_names.append(frame_info.get("wp_name", ""))

# No simulator masks at all -> masks_sim is None; otherwise stack them.
if all(sim_mask is None for sim_mask in masks_sim_l):
    masks_sim = None
else:
    masks_sim = np.array(masks_sim_l)

num_frames = len(rec)
max_frame = num_frames-1

print(f"loaded segmentation masks {np.sum([s is not None for s in masks_sim_l])}/{num_frames}" )
print(f"loaded move_anchors {np.sum([a is not None for a in move_anchors])}/{num_frames}" )
print(f"loaded waypoint names {np.sum([wp !='' for wp in wp_names])}/{num_frames}")

print("loaded.")

# 1. Compute Keep Steps 

Decide which frames to keep. This information is saved as a dictionary whose keys are the frame indices.

There are several possible sources of information to use when deciding which frames to keep.

1. The waypoint names. If these are provided they show when a motion segment ends.
2. Gripper action and state. Segment by the state of the gripper, also keep only those where gripper is stable.
3. TCP motion. Find those frames where movement is minimal.

In [None]:
# use actions here instead of state position recordings as these
# are more direct and reliable, but in general try to use states
# as these are less susceptible to problems.
# gripper_change_steps holds every index c where the gripper command differs
# between frame c and frame c+1.
gripper_change_steps = np.where(np.diff(gripper_actions))[0].tolist()

# divide sequence into steps, defined by gripper action: segment_steps[i] is
# the number of gripper-command changes that happened at or before frame i.
segment_steps = np.zeros(num_frames)
# dtype=int keeps the index array integer-typed even when there are no
# gripper changes (np.array([]) would be float and raise IndexError).
segment_steps[np.array(gripper_change_steps, dtype=int) + 1] = 1
segment_steps = np.cumsum(segment_steps).astype(int)

In [None]:
# Hand-picked keyframes for one specific recording. The check matches on the
# path suffix (like the "sick_wtt/16-51-30" check used for the segmentation
# config) so it also applies when the recording is given with a trailing
# slash or via a different mount point, e.g. on the command line.
if os.path.normpath(recording).endswith(os.path.join("sick_wtt", "16-51-30")):
    keep_manual = {580: dict(name="manual-1"), 699: dict(name="manual-2")}
else:
    keep_manual = {}

In [None]:
from demo_trajectory_utils import get_demo_continous 
from demo_trajectory_utils import get_keep_from_wpnames, get_keep_from_gripper, get_keep_from_motion
from demo_trajectory_utils import filter_by_move_anchors, filter_by_motions, check_names_grip
from demo_trajectory_utils import set_trajectory_actions
from demo_trajectory_utils import get_rel_motion, get_servo_anchors


# module reloading for when updating demo_trajectory_utils functions
from importlib import reload
import demo_trajectory_utils
reload(demo_trajectory_utils)
from demo_trajectory_utils import set_trajectory_actions

trajectory_debug_plots = False

is_continous = get_demo_continous(tcp_pos)
keep_wpnames = get_keep_from_wpnames(wp_names)
# the first and last frame are always keyframes
keep_edge = {0: dict(name="demo_start"), int(max_frame): dict(name="demo_end")}
keep_motion = {}
# shallow copy: keep_cmb shares its per-frame info dicts with keep_edge
keep_cmb = copy.copy(keep_edge)

if keep_wpnames:
    logging.info("Trajectory segmentation method: waypoint-names")
    # if all our waypoints have names, we can use these to segment
    # the trajectory into individual steps.
    check_names_grip(wp_names, gripper_change_steps)
    keep_gripper = get_keep_from_gripper(gripper_actions)
    # NOTE: filtering keyframes by "rel" move anchors is currently disabled:
    # filter_by_move_anchors(keep_wpnames, wp_names, [a == "rel" for a in move_anchors])
    keep_cmb.update(keep_wpnames)
    keep_cmb.update(keep_gripper)
    motion_threshold = 0.001

elif not is_continous:
    logging.info("Trajectory segmentation method: discrete-demo")
    # if we have a hand-crafted recording, which includes only steps for
    # servoing, every frame is a keyframe. keep_cmb maps frame index -> info
    # dict (later code reads v["name"]), so this must be a dict, not a list.
    keep_all = {i: dict(name=f"discrete-{i}") for i in range(len(tcp_pos))}
    keep_gripper = get_keep_from_gripper(gripper_actions)
    keep_cmb.update(keep_all)
    keep_cmb.update(keep_gripper)
    motion_threshold = 0.001
else:
    logging.info("Trajectory segmentation method: motion-cues")
    # this case is a bit more complicated, we can use heuristics like low velocity
    # to segment our trajectory.
    motion_frames = get_keep_from_motion(tcp_pos)
    keep_gripper = get_keep_from_gripper(gripper_actions)

    # int(k) for json
    keep_motion = {int(k): dict(name=f"motion-{i}") for i, k in enumerate(motion_frames)}
    keep_cmb.update(keep_motion)
    keep_cmb.update(keep_gripper)
    motion_threshold = 0.02

keep_cmb = {k: keep_cmb[k] for k in sorted(keep_cmb)}

# postprocess keep frames; the return value is unused, so filter_by_motions
# presumably edits keep_cmb in place — confirm in demo_trajectory_utils.
filter_by_motions(keep_cmb, tcp_pos, tcp_orn, gripper_actions, threshold=motion_threshold)

# after filtering do manual overrides
keep_cmb.update(keep_manual)
keep_cmb = {k: keep_cmb[k] for k in sorted(keep_cmb)}

# also sets grip_dist
# NOTE(review): grip_dist is first read after set_trajectory_actions below;
# confirm which helper actually sets it.
servo_anchors = get_servo_anchors(move_anchors)

# Initial skip flags: with waypoint names, keyframes whose servo anchor is -1
# are skipped. bool() keeps the values JSON-serializable.
if keep_wpnames:
    for k in keep_cmb:
        keep_cmb[k]["skip"] = bool(servo_anchors[k] == -1)
else:
    for info in keep_cmb.values():
        info["skip"] = False

grp_o_df = -1 if "sick_wtt" in recording else 0
set_trajectory_actions(keep_cmb, segment_steps, tcp_pos, tcp_orn, gripper_actions, grip_open_default=grp_o_df)

# NOTE(review): this overwrites every "skip" flag set above, so the
# anchor-based flags only matter if set_trajectory_actions reads them.
# Confirm whether the two criteria were meant to be combined.
for info in keep_cmb.values():
    info["skip"] = bool(info["grip_dist"] >= 2)

print()
for k, v in keep_cmb.items():
    print(f"{k}".ljust(5), f"{v['name']}".ljust(15),
          f"grip_dist={v['grip_dist']}".ljust(15), f"skip={int(v['skip'])}", f"pre={len(v['pre'])}")

# save keep frames (json converts the int frame keys to strings)
with open(keep_fn, 'w') as outfile:
    json.dump(keep_cmb, outfile)
print("Saved to", keep_fn)

# TODO(max): redo filter_keep in an iterative greedy manner.

In [None]:
def keep2plot(keep, length=None):
    """Return a 0/1 indicator array marking the keyframes of ``keep``.

    Args:
        keep: dict whose keys are frame indices.
        length: length of the output array. Defaults to ``len(actions)``
            (one entry per recorded frame), which is what the plots use.

    Returns:
        float array with 1 at every frame index in ``keep`` and 0 elsewhere.
    """
    if length is None:
        length = len(actions)
    arr = np.zeros(length)
    arr[list(keep)] = 1
    return arr

# Interactive overview of how the keyframe sources combine: the top axis
# shows the current video frame, the bottom axis overlays the gripper
# action, the segment index and one indicator trace per keep source.
if interactive:
    fig, (ax, ax2) = plt.subplots(2, 1)
    fig.suptitle("Keep frame Components")
    line = ax.imshow(video_recording[0])
    ax.set_axis_off()
    #ax2.plot(gripper_width*10, label="grip raw")
    # (a+1)/2 rescales the gripper action for display (maps -1/1 to 0/1,
    # assuming the actions are -1/1 — confirm)
    ax2.plot((gripper_actions+1)/2,"--", label="gripper action")
    # segment index, divided by 10 for display only
    ax2.plot(segment_steps/10, label="steps")
    if keep_wpnames:
        ax2.plot(keep2plot(keep_wpnames), label="keep_wpnames")
    if keep_motion:
        ax2.plot(keep2plot(keep_motion), label="keep_motion")
    ax2.plot(keep2plot(keep_gripper), label="keep_gripper")
    ax2.plot(keep2plot(keep_edge), label="keep_edge")
    
    if servo_anchors:
        # 1 where the servo anchor is -1, plotted as "rel"
        ax2.plot([a == -1 for a in servo_anchors], "--", label="rel")
    
    ax2.set_ylabel("value")
    ax2.set_xlabel("frame number")
    # vertical cursor marking the frame currently shown
    vline = ax2.axvline(x=2, color="k")
    ax2.legend()

    def update(w):
        """Show frame w, move the cursor, and print its keyframe info if any."""
        if wp_names:
            print("wp_name:",wp_names[w])
        if w in keep_cmb:
            print(keep_cmb[w])
            print("pos:", tcp_pos[w],"orn:", tcp_orn[w])
        vline.set_data([w, w], [0, 1])
        line.set_data(video_recording[w])
        fig.canvas.draw_idle()
    slider_w = widgets.IntSlider(min=0, max=max_frame, step=1, value=0,
                                 layout=Layout(width='70%'))
    interact(update, w=slider_w)

### 1.1 TCP Stationary Filter

There are two options here. Either the demonstration is recorded sparsely, in which case every frame should be kept, or it is recorded densely, in which case we need to figure out which frames we want to keep.

For dense demonstrations, a good heuristic is to look for periods of slow robot motion, which indicate that the robot has moved to a stable position.

In [None]:
from trajectory_plots import plot_tcp_stationary

# Debug plot of TCP motion; only shown when trajectory_debug_plots is enabled.
if interactive and trajectory_debug_plots:
    plot_tcp_stationary(tcp_pos, video_recording)

Show when gripping is complete, based on the gripper motion.

In [None]:
from trajectory_plots import plot_gripper_stable 
# Debug plot of gripper width vs. gripper action; skipped when the recorded
# widths were flagged invalid (more than 10% NaN readings).
if interactive and gripper_width_valid and trajectory_debug_plots:
    plot_gripper_stable(gripper_width, gripper_actions, video_recording)

## 1. C. Verify keep frames

In [None]:
def keep2plot2(keep, length=None):
    """Return a per-frame plot level encoding the status of each keyframe.

    Levels: 0.1 if the keyframe's ``anchor`` is ``"rel"`` (takes precedence),
    otherwise 1.0 if ``grip_dist < 2`` and 0.8 if not; frames that are not
    keyframes stay at 0.

    Args:
        keep: dict mapping frame index -> info dict with a ``grip_dist`` key
            and an optional ``anchor`` key.
        length: length of the output array. Defaults to ``len(actions)``
            (one entry per recorded frame), which is what the plots use.

    Returns:
        float array of plot levels, one per frame.
    """
    if length is None:
        length = len(actions)
    arr = np.zeros(length)
    for k, info in keep.items():
        if info.get("anchor") == "rel":
            arr[k] = 0.1
        elif info["grip_dist"] < 2:
            arr[k] = 1.0
        else:
            arr[k] = 0.8
    return arr

# Visual check of the final keyframes: keep2plot2 encodes each keyframe's
# status (grip_dist, "rel" anchor) as a level on the bottom axis.
if interactive:
    fig, (ax, ax2) = plt.subplots(2, 1)
    fig.suptitle("Verify keep frames")
    line = ax.imshow(video_recording[0])
    ax.set_axis_off()
    # raw gripper width, multiplied by 10 for display only
    ax2.plot(gripper_width*10, label="grip raw")
    ax2.plot(segment_steps/10, label="steps")
    ax2.plot(keep2plot2(keep_cmb), label="keep")
    #ax2.plot((gripper_actions+1)/2, label="gripper action")
    if servo_anchors:
        # 1 where the servo anchor is -1, plotted as "rel"
        ax2.plot([a == -1 for a in servo_anchors], "--", label="rel")

    ax2.set_ylabel("value")
    ax2.set_xlabel("frame number")
    # vertical cursor marking the frame currently shown
    vline = ax2.axvline(x=2, color="k")
    ax2.legend()

    def update(w):
        """Show frame w, move the cursor, and print its keyframe entry if any."""
        if wp_names:
            print("frame name:", wp_names[w])
        vline.set_data([w, w], [0, 1])
        line.set_data(video_recording[w])
        fig.canvas.draw_idle()
        if w in keep_cmb:
            print(keep_cmb[w])
            print()
    slider_w = widgets.IntSlider(min=0, max=max_frame, step=1, value=0,
                                 layout=Layout(width='70%'))
    interact(update, w=slider_w)

# 2. Compute Mask from Color Images

Mask out the foreground object so that foreground specific flow can be calculated.

In [None]:
# Hard-coded color-segmentation settings. They are applied only to sick_wtt
# recordings (seg_conf_manual), where they replace any loaded segment_conf.json.
seg_conf_manual = "sick_wtt" in segment_conf_fn

if "sick_wtt/16-51-30" not in segment_conf_fn:
    # single object: the blue block is the foreground in every segment
    blue_threshold = 0.65
    conf_objects = {
        "blue_block": [{'name': 'color', 'color': [0, 0, 1], 'threshold': blue_threshold},
                       {'name': 'center'}],
    }
    conf_sequence = ("blue_block",) * 3
else:
    # two objects: blue block for the first segment, white block afterwards
    blue_threshold = 0.77
    conf_objects = {
        "blue_block": [{'name': 'color', 'color': [0, 0, 1], 'threshold': blue_threshold},
                       {'name': 'center'}],
        "white_block": [{'name': 'color', 'color': "keep_black", 'threshold': .38},
                        {'name': 'center'}],
    }
    conf_sequence = ("blue_block", "white_block", "white_block")

seg_conf_m = {"objects": conf_objects, "sequence": conf_sequence}

if seg_conf_manual:
    if seg_conf is not None:
        print(f"Overloading color segmentation config with local values:\n{segment_conf_fn}")
    seg_conf = seg_conf_m
    orig_seg_conf = copy.deepcopy(seg_conf_m)
elif seg_conf is None:
    print(f"Skipping color segmentation, config file not found:\n{segment_conf_fn}")

## 2.1 Compute Mask from Color Images

In [None]:
from skimage import measure
from scipy import ndimage
from demo_segment_util import mask_color, erode_mask, label_mask, mask_center

# create a segmentation mask
def get_mask(frame, step_conf, depth=None):
    """
    Create a segmentation mask for a single frame.

    The options in ``step_conf`` are applied in order: a ``color`` option
    creates the mask, and later options refine it. Unknown option names are
    ignored.

    Args:
        frame: input image, H x W x 3 array (presumably values in [0, 255]).
        step_conf: list of option dicts, each with a ``name`` key. The color
            threshold is always read from ``step_conf[0]["threshold"]``
            (the entry the interactive threshold slider edits).
        depth: unused; reserved for the unimplemented ``height`` option.

    Returns:
        mask: binary numpy array, with True == keep

    Raises:
        NotImplementedError: for the ``height`` and ``labels`` options.
        ValueError: if a mask-refining option comes before any ``color``
            option, or if ``step_conf`` has no ``color`` option at all.
    """
    threshold = step_conf[0]["threshold"]
    image = frame.copy()
    mask = None

    for seg_option in step_conf:
        name = seg_option["name"]

        if name == "color":
            color_choice = seg_option["color"]
            mask = mask_color(image, color_choice=color_choice, threshold=threshold)
            continue

        if name in ("height", "labels"):
            # depth- and label-based refinement were never finished
            raise NotImplementedError(f"segmentation option {name!r} is not implemented")

        if name in ("erode", "imgheight", "center") and mask is None:
            raise ValueError(f"segmentation option {name!r} requires a preceding 'color' option")

        if name == "erode":
            mask = erode_mask(mask)
        elif name == "imgheight":
            # clear rows 0 .. height-1 (the top of the image)
            mask[:seg_option["height"], :] = False
        elif name == "center":
            mask = mask_center(mask)

    if mask is None:
        raise ValueError("step_conf contains no 'color' option, so no mask was created")
    return mask

def get_cur_mask(i):
    """Segment frame ``i`` using the object configured for its gripper segment.

    The segment index of the frame selects an object name from
    ``seg_conf["sequence"]``, whose options are then passed to ``get_mask``.
    """
    object_name = seg_conf["sequence"][segment_steps[i]]
    object_options = seg_conf["objects"][object_name]
    return get_mask(video_recording[i], object_options)

# Plot
# Interactive threshold tuning for the color segmentation. Slider `i` picks
# the frame, slider `t` the threshold in percent. When the frame moves into a
# different gripper segment, that segment's stored threshold is loaded into
# `t`; otherwise the value of `t` is written back into seg_conf.
if seg_conf and interactive:
    print("Colored stuff is keept (mask==True)")
    print("gripper_change_steps:", gripper_change_steps)
    # one entry of seg_conf["sequence"] per gripper segment
    print("segments: ", len(seg_conf["sequence"]))

    fig, ax = plt.subplots(1, 1)
    line = ax.imshow(video_recording[0])
    ax.set_axis_off()
    prev_step = 0
    def update(i, t):
        """Show frame i masked with the current threshold t (percent)."""
        global prev_step
        cur_step = segment_steps[i]
        cur_obj = seg_conf["sequence"][cur_step]
        if cur_step != prev_step:
            # don't change order here, without double checking: prev_step is
            # updated before slider_t.value is set, since changing the slider
            # presumably re-invokes update(), which must then take the
            # write-back branch.
            saved_t = seg_conf["objects"][cur_obj][0]["threshold"]
            print(f"switching step {prev_step} -> {cur_step}, loading t={saved_t}")
            prev_step = cur_step
            # round instead of truncating: e.g. 0.29*100 == 28.999..., which an
            # IntSlider would store as 28 and then write back as 0.28
            slider_t.value = round(saved_t*100)
        else:
            seg_conf["objects"][cur_obj][0]["threshold"] = t/100

        mask = get_cur_mask(i)
        image = video_recording[i].copy()
        image[np.logical_not(mask)] = 255, 255, 255
        line.set_data(image)
        fig.canvas.draw_idle()

    slider_i = widgets.IntSlider(min=0, max=max_frame, step=1, value=0,
                                 layout=Layout(width='70%'))
    first_obj = seg_conf["sequence"][0]
    slider_t = widgets.IntSlider(min=0, max=100, step=1,
                                 value=round(seg_conf["objects"][first_obj][0]["threshold"]*100),
                                 layout=Layout(width='70%'))
    interact(update, i=slider_i, t=slider_t)

In [None]:
# Report, per object, its color and current threshold, together with the
# originally loaded threshold when it was changed interactively.
if seg_conf:
    for obj_name, obj_options in seg_conf["objects"].items():
        print("name:", obj_name)
        color = obj_options[0]["color"]
        cur_t = obj_options[0]["threshold"]
        init_t = orig_seg_conf["objects"][obj_name][0]["threshold"]
        if cur_t == init_t:
            print(f"c={color}, t={cur_t}")
        else:
            print(f"c={color}, t={cur_t} / was t'={init_t}")
        print()


In [None]:
# Compute color-based masks for every frame and save them to mask_fn.
# Every object in seg_conf["objects"] gets an integer id (1, 2, ...). The
# saved "mask" array holds, per pixel, the id of the object whose color mask
# covers it (0 = none); "fg" holds, per frame, the id of the foreground
# object of that frame's gripper segment.
if seg_conf:
    switch_frame = gripper_change_steps
    print("switching at:", switch_frame)
    if orig_seg_conf is None:
        orig_seg_conf = seg_conf
        
    # object name -> id, in the iteration order of seg_conf["objects"]
    obj_ids = {}
    obj_ids_list = []
    for i, obj in enumerate(seg_conf["objects"]):
        obj_ids[obj] = i+1
        obj_ids_list.append(i+1)
        print(f"{obj} -> {i+1}")

    fg_obj = []
    masks_list = []
    for i in tqdm(range(len(video_recording))):
        # get foreground object
        cur_step = segment_steps[i]
        cur_obj = seg_conf["sequence"][cur_step]
        fg_obj.append(obj_ids[cur_obj])

        # one binary mask per configured object
        m_masks = []
        for obj_name in seg_conf["objects"]:
            mask = get_mask(video_recording[i], seg_conf["objects"][obj_name])
            m_masks.append(mask)

        # pixels claimed by more than one object; after the id-weighted sum
        # below they hold the sum of several ids, not a single object's id
        overlapp = np.sum(m_masks, axis=0) > 1
        if np.any(overlapp):
            print(f"WARNING: There is overlapp at {i}")
        masks_list.append(m_masks)

    fg_obj = np.array(fg_obj)
    assert fg_obj.ndim == 1

    masks_list = np.array(masks_list)
    # (frames, objects, H, W) -> (objects, frames, H, W)
    masks_list = masks_list.transpose(1, 0, 2, 3).astype(np.uint8)
    # weight each object's binary mask by its id and sum over objects
    obj_ids_arr = np.array(obj_ids_list).reshape(-1, 1, 1, 1)
    masks_list = obj_ids_arr*masks_list
    masks_list = masks_list.sum(axis=0)
    
    np.savez_compressed(mask_fn, mask=masks_list, fg=fg_obj)
    print("Saved to", mask_fn)
    
    # from here on servo_anchors holds the per-frame foreground object ids
    servo_anchors = fg_obj.tolist()

    if seg_conf != orig_seg_conf:
        print("Warning using new segmentation config values")

### 2.1 B. Check Masks from Color with Simulation

In [None]:
# Sanity check: compare the color-based foreground masks with the simulator's
# segmentation. For every frame, the simulator id occurring most often inside
# the computed foreground mask is taken as the segmented object, and the
# computed mask must cover more than 90% of that object's pixels.
if seg_conf and masks_sim is not None:
    for i in range(num_frames):
        mask = masks_list[i] == fg_obj[i]
        # simulator ids of the pixels inside the computed foreground mask
        seg_ids, id_counts = np.unique(masks_sim[i][mask], return_counts=True)
        # id 0 is excluded, as in the previous check (presumably background —
        # confirm against unprocess_seg)
        nonzero = seg_ids != 0
        seg_ids, id_counts = seg_ids[nonzero], id_counts[nonzero]
        if seg_ids.size == 0:
            # nothing (besides id 0) was segmented in this frame
            score = 0.0
        else:
            # np.unique sorts by id, not by count, so pick the most frequent id explicitly
            idx_largest = np.argmax(id_counts)
            seg_id, mask_count = seg_ids[idx_largest], id_counts[idx_largest]
            seg_count = np.sum(masks_sim[i] == seg_id)
            # test how much we segmented / how much there is
            score = mask_count / seg_count
        assert score > .9, f"Segmentation test failed at frame {i}: score={score:.3f}"
    print("Segmentation test passed.")
else:
    print("No comparison possible.")

## 2.2 Compute Masks from Simulation

This cell extracts foreground masks from simulation recordings. It does this by looking at the recording's info variables, where an anchor object UID can be specified. This is usually done by the task policy.

The move_anchor is the object relative to which we are moving; this is usually, but not always, the object of interest (i.e. the foreground object).

In [None]:
# Simulation fallback: when no color-segmentation config is in use, save the
# simulator segmentation masks directly, using the per-frame servo anchors
# as foreground ids.
if masks_sim is not None and servo_anchors and not seg_conf:
    fg_obj_sim = servo_anchors
    np.savez_compressed(mask_fn, mask=masks_sim, fg=fg_obj_sim)
    print("Saved to", mask_fn)

In [None]:
# Verify
# Step through the simulator masks saved in the previous cell, with every
# pixel outside the foreground object whitened out. The `not seg_conf` guard
# mirrors that cell: fg_obj_sim only exists when it ran, so without the guard
# update() raised NameError whenever a color config was in use.
if masks_sim is not None and servo_anchors and interactive and not seg_conf:
    fig, ax = plt.subplots(1)
    handle = ax.imshow(video_recording[0])
    ax.set_axis_off()

    def update(i):
        """Show frame i with non-foreground pixels set to white."""
        image = video_recording[i].copy()

        # this is the code used by the servoing module, so don't change.
        mask = masks_sim[i] == fg_obj_sim[i]

        print(round(np.mean(mask)*100), "% fg, mask shape", mask.shape)
        image[np.logical_not(mask)] = 255, 255, 255
        handle.set_data(image)
        fig.canvas.draw_idle()

    slider_i2 = widgets.IntSlider(min=0, max=max_frame, step=1, value=0,
                                 layout=Layout(width='70%'))
    interact(update, i=slider_i2)

## 2.3 Check Results

In [None]:
# Visual check of the saved mask file: print the keyframes that will be
# servoed to (skip is False), then step through the frames with everything
# outside the saved foreground whitened out.
if interactive:
    for k, v in keep_cmb.items():
        if not v["skip"]:
            print(k)
    if not os.path.isfile(mask_fn):
        print(f"No mask file to check: {mask_fn}")
    else:
        # Read each array once: every tmp["..."] lookup on an NpzFile re-reads
        # and decompresses the whole array, which made each slider move slow.
        with np.load(mask_fn) as tmp:
            saved_mask = tmp["mask"]
            saved_fg = tmp["fg"]

        fig, ax = plt.subplots(1)
        handle = ax.imshow(saved_mask[0] == saved_fg[0])
        ax.set_axis_off()

        def update(i):
            """Show frame i with non-foreground pixels set to white."""
            image = video_recording[i].copy()
            mask = saved_mask[i] == saved_fg[i]
            print(round(np.mean(mask)*100), "% fg, mask shape", mask.shape)
            image[np.logical_not(mask)] = 255, 255, 255
            handle.set_data(image)
            fig.canvas.draw_idle()

        slider_i2 = widgets.IntSlider(min=0, max=max_frame, step=1, value=0,
                                      layout=Layout(width='70%'))
        interact(update, i=slider_i2)

In [None]:
# Non-interactive sanity check of the saved masks at every keyframe:
# keyframes whose servo anchor is -1 should have no segmented pixels, and all
# other keyframes should have at least 1% of the image segmented.
if not os.path.isfile(mask_fn):
    logging.warning("No mask file found, skipping mask check: %s", mask_fn)
else:
    tmp = np.load(mask_fn)
    # Read each array once; indexing tmp["mask"] inside the loop would
    # decompress the full array again for every keyframe.
    saved_mask = tmp["mask"]
    saved_fg = tmp["fg"]
    for k in keep_cmb:
        mask = saved_mask[k] == saved_fg[k]
        if servo_anchors[k] == -1:
            if np.any(mask):
                logging.warning("Keyframe %s: segmentation given for relative motion.", k)
        else:
            percent_segmented = np.mean(mask)
            if percent_segmented < .01:
                logging.warning("Keyframe %s: low fraction of image segmented (%.1f%%)",
                                k, percent_segmented * 100)

    print(f"Checking {mask_fn} passed.")

# 4. Masking based on Depth