# Multi-Object Live Selection
1. Use already computed scores stored as npz files
2. For each part, compute scores at the front (current_rgb vs first frame of the part)
3. Compute the full trajectory from the current state to the goal state

In [None]:
import os
import copy
import time
import json
import shutil
import unittest
import subprocess
from pathlib import Path
import numpy as np

from scipy.spatial.transform import Rotation as R

from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.demo.demo_episode_recorder import record_sim
from flow_control.runner import evaluate_control
from flow_control.servoing.module import ServoingModule
from flow_control.servoing.playback_env_servo import PlaybackEnvServo
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout
import seaborn as sns
from tqdm import tqdm

%matplotlib inline

root_dir = "../tmp/recom"

In [None]:
# Every sub-directory of root_dir is treated as one recording; paths are kept in sorted order.
recordings = sorted(
    filter(os.path.isdir, (os.path.join(root_dir, entry) for entry in os.listdir(root_dir)))
)

# Load the demonstration episodes (one playback object per recording directory)
playbacks = list(map(PlaybackEnvServo, recordings))

# Plot the demonstrations
# The notebook backend is selected here; the figure below is later redrawn in place by update().
%matplotlib notebook
fig, ax = plt.subplots(1,figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
# Image handle that update() overwrites via set_data(); initialised from the first playback.
image_h = ax.imshow(playbacks[0].cam.get_image()[0])

def update(demo_index, frame_index):
    """Slider callback: show one frame of one demonstration and print its metadata.

    Replaces the data of the module-level image handle ``image_h`` with the
    selected frame, requests a redraw of ``fig``, prints the frame's
    ``wp_name`` info entry and, if the playback provides a foreground mask,
    the percentage of foreground pixels.

    Args:
        demo_index: index into the module-level ``playbacks`` list.
        frame_index: index of the frame within the selected playback.
    """
    playback = playbacks[demo_index]

    frame_rgb = playback[frame_index].cam.get_image()[0]
    image_h.set_data(frame_rgb)
    fig.canvas.draw_idle()

    wp_name = playback[frame_index].get_info()["wp_name"]
    print("wp_name:", wp_name)

    mask = playback.get_fg_mask()
    if mask is None:
        # Nothing more to report when no foreground mask is available.
        return
    print("percent fg:", np.mean(mask)*100)
    
# Slider choosing the demonstration: valid indices into playbacks.
slider_w = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
                             layout=Layout(width='70%'))
# Slider choosing the frame inside the demonstration.
# NOTE(review): the upper bound 200-1 is hard-coded; presumably no demonstration
# has more than 200 frames, and shorter ones may fail in update() - confirm.
slider_i = widgets.IntSlider(min=0, max=200-1, step=1, value=0,
                             layout=Layout(width='70%'))

# Call update(demo_index, frame_index) whenever one of the sliders changes.
interact(update, demo_index=slider_w, frame_index=slider_i)

In [None]:
# Code to update segmentation mask for the last frame

from robot_io.recorder.simple_recorder import unprocess_seg

object_name = "Oval_Sort_Box"

# Info dictionary of the last frame of the first demonstration.
info = playbacks[0][-1].get_info()
print(info.keys())

# Per-object entries are looked up under the string keys "0" .. "9";
# keys that are absent are skipped. Build a name -> UID lookup from them.
name2uid = {
    info[key]["name"]: info[key]["UID"]
    for key in map(str, range(10))
    if key in info.keys()
}
print(name2uid)

obj_uid, srf_uid = (int(name2uid[name]) for name in (object_name, "surface_red"))

# Only the first value returned by unprocess_seg is used.
seg_mask = info["seg_mask"]
seg_obj, _ = unprocess_seg(seg_mask)

# Foreground = pixels belonging to the selected object or to the red surface.
seg_cmb = np.logical_or(seg_obj == obj_uid, seg_obj == srf_uid)
plt.imshow(seg_cmb)
plt.show()

# Read scores (errors) from the stored files

In [None]:
# Load the precomputed error arrays. np.load on an .npz archive returns a
# file-backed NpzFile, so each archive is opened in a context manager and
# closed as soon as the array has been read.
# error_matrix is the 2-D array indexed [i][j] by get_best_segments;
# errors_rear is indexed by the second index j.
with np.load('em_norm_old_fn.npz') as npz_file:
    error_matrix = npz_file['arr_0']
with np.load('errors_rear_new_seg_mask.npz') as npz_file:
    errors_rear = npz_file['arr_0']

# Turn errors (lower is better) into scores (higher is better).
score_matrix = 1 - error_matrix
scores_rear = 1 - errors_rear

In [None]:
# Get good demonstrations and demonstration part information
def filter_demo(pb):
    """Decide whether a demonstration playback is usable.

    A demonstration is accepted when the reward stored in its last frame is
    positive and its foreground mask covers more than 0.5% of the pixels.

    Args:
        pb: playback object supporting ``pb[-1].data['rew']`` and
            ``pb.get_fg_mask()``.

    Returns:
        A truthy value if the demonstration passes both checks, a falsy value
        otherwise. A playback without a foreground mask (``get_fg_mask()``
        returning ``None``) is rejected instead of raising.
    """
    if not pb[-1].data['rew'] > 0:
        return False

    fg_mask = pb.get_fg_mask()
    if fg_mask is None:
        # np.mean(None) would raise a TypeError; without a mask the demo cannot be used.
        return False

    return np.mean(fg_mask) > 0.005

# One flag per playback: does the demonstration pass filter_demo?
demo_good = list(map(filter_demo, playbacks))

# Indices of the accepted demonstrations as plain Python ints, capped at the first 38.
good_demonstrations = [int(idx) for idx in np.flatnonzero(demo_good)][0:38]

# The slice bound 39 exceeds the cap above, so this is simply a copy of the list.
live_seeds = good_demonstrations[0:39]
print(good_demonstrations)

In [None]:
# Load demo segmentation file
# Path of the JSON file holding the manual split of each demonstration into parts.
demo_seg_file = f'{root_dir}/demo_parts_manual2.json'

# Read the file inside a context manager so the handle is closed right after parsing.
with open(demo_seg_file) as fp:
    demo_parts = json.load(fp)

# The top-level JSON keys are demonstration indices stored as strings.
live_seeds = [int(key) for key in demo_parts.keys()]
demo_keys = demo_parts.keys()

## Compute live scores 

In [None]:
def compute_current_scores(playbacks, current_rgb, demo_parts, demo_good, traj_idx=0, live_seed=0):
    """Score the current image against the first frame of one part of every good demonstration.

    For each index in the module-level ``good_demonstrations`` the reprojection
    error and mean flow between ``current_rgb`` and the demonstration frame at
    the part's ``'start'`` index are computed, the values are normalised with
    ``normalize_errors`` and converted to scores.

    Args:
        playbacks: sequence of demonstration playbacks.
        current_rgb: current live image.
        demo_parts: mapping from ``str(demo_index)`` to a list of parts, each
            with a ``'start'`` entry.
        demo_good: per-demonstration selection passed on to ``normalize_errors``.
        traj_idx: index of the part whose first frame is compared.
        live_seed: not used by this function.

    Returns:
        Array with one score per playback (``1 - normalised error``).
    """
    num_demos = len(playbacks)
    reproj_errors = np.ones(num_demos)  # lower is better
    flow_magnitudes = np.zeros(num_demos)

    for seed in good_demonstrations:
        playback = playbacks[seed]
        first_frame = demo_parts[str(seed)][traj_idx]['start']
        reference_rgb = playback[first_frame].cam.get_image()[0]
        reference_mask = playback.fg_masks[first_frame]
        reproj_errors[seed], flow_magnitudes[seed] = similarity_from_reprojection(
            current_rgb, reference_rgb, reference_mask)

    return 1 - normalize_errors(reproj_errors, flow_magnitudes, demo_good)

def similarity_from_reprojection(live_rgb, demo_rgb, demo_mask, return_images=False):
    """Evaluate the similarity of two images via the flow reprojection error.

    The flow between the demonstration and the live image is computed with the
    module-level ``servo_module.flow_module``, the live image is warped with
    that flow, and the per-pixel colour difference to the demonstration image
    is averaged over the demonstration mask.

    Args:
        live_rgb: live image; divided by 255 before warping.
        demo_rgb: demonstration image; divided by 255 before comparison.
        demo_mask: 2-D foreground mask of the demonstration image. It weights
            the per-pixel error and selects the flow vectors that are averaged.
        return_images: if True, also return the flow and the warped image.

    Returns:
        ``(error, mean_flow)`` or, with ``return_images``,
        ``(error, mean_flow, flow, warped)``. Both values are NaN when the
        mask is empty, because each is a mean over zero pixels.
    """
    flow_module = servo_module.flow_module
    flow = flow_module.step(demo_rgb, live_rgb)
    warped = flow_module.warp_image(live_rgb / 255.0, flow)

    # Colour difference per pixel, computed once and reduced over the channel axis.
    diff = warped - (demo_rgb / 255.0)
    error = np.linalg.norm(diff, axis=2) * demo_mask
    error = error.sum() / demo_mask.sum()

    # Boolean indexing selects the masked pixels only if the mask has bool dtype;
    # cast explicitly so an integer-typed mask is not interpreted as indices.
    # NOTE(review): assumes the mask only holds 0/1 (or False/True) values - confirm.
    mean_flow = np.linalg.norm(flow[np.asarray(demo_mask, dtype=bool)], axis=1).mean()

    if return_images:
        return error, mean_flow, flow, warped
    return error, mean_flow

from sklearn.preprocessing import minmax_scale

def normalize_errors(errors, flows, demo_good):
    """Combine reprojection errors and mean flows into one normalised error.

    Only the entries selected by ``demo_good`` are normalised: errors and flows
    are each min-max scaled over the selected entries, the flows are weighted
    by 0.5, and the two are combined. All other entries are set to 1.

    Args:
        errors: array of reprojection errors.
        flows: array of mean flow magnitudes, same layout as ``errors``.
        demo_good: index or boolean selection of the entries to normalise.

    Returns:
        Array with the shape of ``errors`` holding the combined error.
    """
    flow_weight = .5
    scaled_errors = minmax_scale(errors[demo_good])
    scaled_flows = flow_weight * minmax_scale(flows[demo_good])

    # NOTE(review): the mean over the two terms already divides by 2 and the result
    # is divided by (1 + weight) again, so the maximum is 0.5 rather than 1 - confirm
    # this is intended.
    combined = np.mean((scaled_errors, scaled_flows), axis=0) / (1 + flow_weight)

    normalized = np.ones(errors.shape)
    normalized[demo_good] = combined
    return normalized

# Servo Module
# Load Servoing Module
from flow_control.servoing.module import ServoingModule
# Control settings shared by every ServoingModule created in this notebook.
control_config = {"mode": "pointcloud-abs-rotz", "threshold": 0.40}

# Module-level servoing module built on the first recording; its flow_module
# is what similarity_from_reprojection uses.
servo_module = ServoingModule(
    recordings[0],
    control_config=control_config,
    start_paused=False,
)

In [None]:
import ipdb
def split_keypoints(pb, demo_part):
    """Group the keyframe indices of a demonstration by the part they belong to.

    A keyframe index ``k`` belongs to a part when ``start < k <= end``; a part
    starting at frame 0 additionally includes keyframe 0. After grouping, the
    first keyframe of the ``'insert'`` part is appended to the ``'locate'``
    part.

    Args:
        pb: playback object whose ``keep_dict`` keys are the keyframe indices.
        demo_part: iterable of parts, each a mapping with ``'name'``,
            ``'start'`` and ``'end'`` entries.

    Returns:
        Dict mapping each part name to its sorted list of keyframe indices.

    Raises:
        KeyError: if there is no part named ``'locate'`` or ``'insert'``.
        IndexError: if the ``'insert'`` part contains no keyframe.
    """
    demo_keep = sorted(pb.keep_dict.keys())
    keep_parts = {}
    for p in demo_part:
        # The lower bound is exclusive, so a part that starts at frame 0 uses -1
        # in order to keep keyframe 0 itself.
        p_start = -1 if p["start"] == 0 else p["start"]
        keep_parts[p["name"]] = [idx for idx in demo_keep if p_start < idx <= p["end"]]

    # The 'locate' part additionally ends on the first keyframe of the 'insert' part.
    keep_parts['locate'].append(keep_parts['insert'][0])
    return keep_parts

# Keyframe indices of every good demonstration, grouped by part name.
keypoint_info = {
    seed: split_keypoints(playbacks[seed], demo_parts[str(seed)])
    for seed in good_demonstrations
}

keypoint_info

## Search for best trajectory 

In [None]:
def get_best_segments(score_front, score_rear, score_matrix, score_fn='min'):
    """Pick the pair of indices (idx1, idx2) with the best combined score.

    Args:
        score_front: scores indexed by the first index (rows of ``score_matrix``).
        score_rear: scores indexed by the second index (columns of ``score_matrix``).
        score_matrix: 2-D array of pairwise scores, indexed ``[idx1][idx2]``.
        score_fn: how the scores are combined:
            ``'min'``  - ignore the matrix and take the argmax of ``score_front``
                         and of ``score_rear`` independently ('score' stays 0.0);
            ``'sum'``  - maximise ``score_matrix[i][j] + score_front[i] + score_rear[j]``;
            ``'prod'`` - maximise ``score_front[i] * score_matrix[i][j] * score_rear[j]``.

    Returns:
        Dict with the keys ``'score'``, ``'idx1'`` and ``'idx2'``. For
        ``'sum'``/``'prod'`` the first maximum in row-major order wins. The
        indices are -1 only when ``score_matrix`` has no elements.

    Raises:
        ValueError: if ``score_fn`` is not one of ``'min'``, ``'sum'``, ``'prod'``.
    """
    x, y = score_matrix.shape
    best_score_fn = {'score': 0.0, 'idx1': -1, 'idx2': -1}

    if score_fn == 'min':
        best_score_fn['idx1'] = np.argmax(score_front)
        best_score_fn['idx2'] = np.argmax(score_rear)
        return best_score_fn

    if score_fn not in ('sum', 'prod'):
        raise ValueError(f"unknown score_fn {score_fn!r}; expected 'min', 'sum' or 'prod'")

    # Start below every possible score: starting at 0.0 would leave the indices at
    # -1 whenever all combined scores are <= 0 (e.g. every product is zero).
    best_total = -np.inf
    for i in range(x):
        for j in range(y):
            if score_fn == 'sum':
                total = score_matrix[i][j] + score_front[i] + score_rear[j]
            else:
                total = score_front[i] * score_matrix[i][j] * score_rear[j]

            if total > best_total:
                best_total = total
                best_score_fn['score'] = total
                best_score_fn['idx1'] = i
                best_score_fn['idx2'] = j

    return best_score_fn

In [None]:
from math import pi
from flow_control.servoing.module import ServoingModule
from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.runner import evaluate_control
import ipdb
import cv2

# Root directory used to build the save_dir of the second-part runs in eval_cmb.
save_root = '../tmp/recom_run_new_mask_oval'
# Maps live_seed -> list of the recording paths selected for each part; filled by eval_cmb.
selected_recordings = {}
def eval_cmb(playbacks, demo_good, live_seed, demo_parts, keypoint_info, exist_ok=False, best_idx=None):
    """Run the two-part ('locate', then 'insert') servoing episode for one live seed.

    For each part the current gripper image is scored against the first frame
    of that part in the good demonstrations, the best demonstration is picked
    with get_best_segments, and a ServoingModule restricted to that part's
    keypoints is run in the simulator through evaluate_control.

    Args:
        playbacks: sequence of demonstration playbacks.
        demo_good: per-demonstration selection forwarded to compute_current_scores.
        live_seed: seed of the simulated environment; also the key under which
            the selected recordings are stored in ``selected_recordings``.
        demo_parts: mapping from ``str(demo_index)`` to the parts of that demonstration.
        keypoint_info: mapping from demonstration index to ``{part name: keypoints}``.
        exist_ok: not used by this function.
        best_idx: optional demonstration index to use for the first part instead
            of scoring; it is reset to None after each part, so the second part
            is always selected by score.

    Returns:
        The reward returned by the last evaluate_control call.

    Side effects:
        Writes ``selected_recordings[live_seed]``. ``save_dir`` is None for the
        first part and a not-yet-existing directory name under ``save_root``
        for the second part.
    """
    renderer = 'debug'
    task_variant = 'rP'
    task = 'pick_n_place'
    object_selected = 'oval'
    
    # Instantiate env
    # task_variant selects the object rotation range: pi/2 for "rP", pi/6 for "rR".
    env = RobotSimEnv(task='recombination', renderer=renderer, act_type='continuous',
                      initial_pose='close', max_steps=500, control='absolute-full',
                      img_size=(256, 256), param_randomize=("geom",),
                      param_info={'object_selected': object_selected, 'task_selected': task},
                      task_info=dict(object_rot_range={"rP":pi/2.,"rR":pi/6.}[task_variant]),
                      seed=int(live_seed))
    
    # Part index -> part name used as key into keypoint_info[demo].
    traj_map = {0: 'locate', 1: 'insert'}
    
    save_dir = None
    reward = 0
    for idx in range(2):
        
        # NOTE(review): stepping with a None action is presumably a no-op used only
        # to obtain the current observation - confirm against RobotSimEnv.
        state, _, _, _ = env.step(None)
        current_rgb = state['rgb_gripper']

        if best_idx is None:
            scores = compute_current_scores(playbacks, current_rgb, demo_parts, demo_good, traj_idx=idx, live_seed=live_seed)
        
            # For the second part the pairwise matrix is replaced by ones, so it
            # adds the same constant to every (i, j) pair.
            sm = score_matrix
            if idx == 1:
                sm = np.ones_like(score_matrix)
            best_segments = get_best_segments(scores, scores_rear, sm, score_fn='sum')
            # Only the first index of the best pair is used as the demonstration to follow.
            best_idx = best_segments['idx1']
#         return scores, current_rgb, best_segments
        best_demo = recordings[best_idx]
        kp_info = keypoint_info[best_idx]
        kps = kp_info[traj_map[idx]]
        
        if idx == 0:
            selected_recordings[live_seed] = [best_demo]
        
        if idx == 1:
            # Find a directory name that does not exist yet by appending _1, _2, ...
            save_dir = f"{save_root}/run_pnp_{live_seed}_{best_idx}"
            folder_idx = 1
            updated_dir = save_dir
            while os.path.isdir(updated_dir):
                updated_dir = f"{save_dir}_{folder_idx}"
                folder_idx += 1
            
            save_dir = updated_dir
            selected_recordings[live_seed].append(best_demo)
        # This local servo_module shadows the module-level one inside this function.
        servo_module = ServoingModule(best_demo, control_config=control_config,
                                      start_paused=False, plot=False, plot_save_dir=None,
                                      load='select', selected_kp=kps)
        # initial_align is requested for the first part only.
        _, reward, _, info = evaluate_control(env, servo_module, max_steps=130, save_dir=save_dir,
                                             initial_align=True if idx == 0 else False)
        
        # Force a fresh score-based selection for the next part.
        best_idx = None
    # NOTE(review): the environment is only dereferenced here, not explicitly
    # closed - confirm whether RobotSimEnv needs a close() call.
    del env
    del servo_module
    return reward

In [None]:
# Run the servoing evaluation for the live seeds 0 .. num_live_seeds - 1 and
# collect one reward per seed. Results are appended one by one so that the
# rewards gathered so far survive an interrupted run.
num_live_seeds = 20
rewards = []
for live_seed in range(num_live_seeds):
    rewards.append(eval_cmb(playbacks, demo_good, live_seed, demo_parts, keypoint_info))

In [None]:
# Results recorded from earlier runs (per-seed rewards followed by their mean):
# Trapeze: [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 0.9(semicircle once)
# Oval: [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0] 0.8
mean_reward = np.mean(rewards)
print(rewards, mean_reward)

In [None]:
# num_live_seeds = 20
# rewards = []
# ##
# scores_all_live_seeds = []
# live_imgs = []
# best_segments = []
# for live_idx, live_seed in enumerate(range(20)):
#     scores, live_img, best_seg = eval_cmb(playbacks, demo_good, live_seed, demo_parts, keypoint_info)
#     scores_all_live_seeds.append(scores)
#     live_imgs.append(live_img)
#     best_segments.append(best_seg)

In [None]:
# Plot the demonstrations
# %matplotlib notebook
# fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(8, 6))
# fig.suptitle("Demonstration Frames")
# ax1.set_axis_off()
# ax2.set_axis_off()
# image_h = ax1.imshow(playbacks[0].cam.get_image()[0])
# image_live = ax2.imshow(live_imgs[0])

# def update(demo_index, live_index, frame_index):
#     image = playbacks[demo_index][frame_index].cam.get_image()[0]
#     image_h.set_data(image)
#     fig.canvas.draw_idle()
#     image_live.set_data(live_imgs[live_index])
#     print(f"Score: {scores_all_live_seeds[live_index][demo_index]}")
#     print(f"Best Indices: {best_segments[live_index]['idx1']}, {best_segments[live_index]['idx2']}")
#     print(f"Max Score at demo_index = {np.argmax(scores_all_live_seeds[live_index])}")
#     print(f"Demonstration: {recordings[demo_index]}")
    
# slider_w = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
#                              layout=Layout(width='70%'))
# slider_i = widgets.IntSlider(min=0, max=20, step=1, value=0,
#                              layout=Layout(width='70%'))
# slider_demo = widgets.IntSlider(min=0, max=200-1, step=1, value=0,
#                              layout=Layout(width='70%'))

# interact(update, demo_index=slider_w, live_index=slider_i, frame_index= slider_demo)

### Score matrix plot

In [None]:
# Plot the demonstrations
# %matplotlib notebook
# fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5,figsize=(8, 6))
# fig.suptitle("Demonstration Frames")
# ax1.set_axis_off()
# ax2.set_axis_off()
# ax3.set_axis_off()
# ax4.set_axis_off()
# ax5.set_axis_off()

# image_h1 = ax1.imshow(playbacks[0][demo_parts['0'][0]['end']].cam.get_image()[0])
# image_h2 = ax2.imshow(playbacks[0][demo_parts['0'][1]['start']].cam.get_image()[0])
# image_h3 = ax3.imshow(playbacks[0].fg_masks[demo_parts['0'][1]['start']])
# image_h4 = ax4.imshow(np.zeros((256, 256, 3)))
# image_h5 = ax5.imshow(np.zeros((256, 256, 3)))

# def update(demo_idx1, demo_idx2):
#     if str(demo_idx1) in demo_parts:
#         image_1 = playbacks[demo_idx1][demo_parts[str(demo_idx1)][0]['end']].cam.get_image()[0]
#     else:
#         image_1 = np.zeros((256, 256, 3))
    
#     if str(demo_idx2) in demo_parts:
#         image_2 = playbacks[demo_idx2][demo_parts[str(demo_idx2)][1]['start']].cam.get_image()[0]
#         image_3 = playbacks[demo_idx2].fg_masks[demo_parts[str(demo_idx2)][1]['start']]
#     else:
#         image_2 = np.zeros((256, 256, 3))
#         image_3 = np.zeros((256, 256, 3))
        
#     error, mean_flow, flow, warped, diff_img_hsv = similarity_from_reprojection(image_1, image_2, image_3, True)
# #     print(diff_img_hsv)
#     image_h1.set_data(image_1)
#     fig.canvas.draw_idle()
#     image_h2.set_data(image_2)
#     image_h3.set_data(image_3)
#     image_h4.set_data(warped)
#     image_h5.set_data(diff_img_hsv)
    
#     print(f"Score: {1 - error}")
    
# slider_d1 = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
#                              layout=Layout(width='70%'))
# slider_d2 = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
#                              layout=Layout(width='70%'))

# interact(update, demo_idx1=slider_d1, demo_idx2=slider_d2)

## Goal image scores

In [None]:
# goal_path = '../tmp/goal_pick_n_place_oval_rR_seed100'
# goal_pl = PlaybackEnvServo(goal_path)

# goal_image = goal_pl[-1].cam.get_image()[0]
# goal_mask = goal_pl.fg_masks[-1]

# from robot_io.recorder.simple_recorder import unprocess_seg

# object_name = "Oval_Sort_Box"
# info = goal_pl[-1].get_info()
# print(info.keys())
# name2uid = {}
# for i in range(10):
#     i = str(i)
#     if i not in info.keys():
#         continue
#     name2uid[info[i]["name"]] = info[i]["UID"]
# print(name2uid)
# obj_uid = int(name2uid[object_name])
# srf_uid = int(name2uid["surface_red"])
# seg_mask = info["seg_mask"]
# seg_obj, _ = unprocess_seg(seg_mask)
# goal_mask = np.logical_or(seg_obj == obj_uid, seg_obj == srf_uid)
# plt.imshow(goal_mask)
# plt.show()

In [None]:
# Plot the demonstrations
# %matplotlib notebook
# fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5,figsize=(8, 6))
# fig.suptitle("Demonstration Frames")
# ax1.set_axis_off()
# ax2.set_axis_off()
# ax3.set_axis_off()
# ax4.set_axis_off()
# ax5.set_axis_off()

# goal_scores = []

# image_h1 = ax1.imshow(goal_image)
# image_h2 = ax2.imshow(playbacks[0][demo_parts['0'][1]['end']].cam.get_image()[0])
# image_h3 = ax3.imshow(goal_mask)
# image_h4 = ax4.imshow(np.zeros((256, 256, 3)))
# image_h5 = ax5.imshow(np.zeros((256, 256, 3)))

# def update(demo_idx1):
#     global goal_scores
    
#     if str(demo_idx1) in demo_parts:
#         image_2 = playbacks[demo_idx1][demo_parts[str(demo_idx1)][1]['end']].cam.get_image()[0]
#     else:
#         image_2 = np.zeros((256, 256, 3))
        
#     error, mean_flow, flow, warped, diff_img_hsv = similarity_from_reprojection(image_2, goal_image, goal_mask, True)
# #     print(diff_img_hsv)
#     fig.canvas.draw_idle()
#     image_h2.set_data(image_2)
#     image_h4.set_data(warped)
#     image_h5.set_data(diff_img_hsv)
    
#     print(f"Score: {1 - error}")
#     goal_scores.append(1 - error)
    
# slider_d1 = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
#                              layout=Layout(width='70%'))
# slider_d2 = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
#                              layout=Layout(width='70%'))

# interact(update, demo_idx1=slider_d1, demo_idx2=slider_d2)