# Computing scores for recombination
1. Compute the errors at the front, for each live seed
2. Compute the error matrix between demonstrations
3. Compute errors with respect to the goal image.

In [None]:
import os
import copy
import time
import json
import shutil
import unittest
import subprocess
from pathlib import Path
import numpy as np

from scipy.spatial.transform import Rotation as R

from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.demo.demo_episode_recorder import record_sim
from flow_control.runner import evaluate_control
from flow_control.servoing.module import ServoingModule
from flow_control.servoing.playback_env_servo import PlaybackEnvServo
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout
import seaborn as sns
from tqdm import tqdm


%matplotlib inline

root_dir = "../tmp/recom"

In [None]:
# Every sub-directory of root_dir is treated as one recording; sorting the
# paths keeps the list index of each recording stable between runs.
candidate_paths = (os.path.join(root_dir, entry) for entry in os.listdir(root_dir))
recordings = sorted(path for path in candidate_paths if os.path.isdir(path))

# Load the demonstration episodes (one playback object per recording path)
playbacks = list(map(PlaybackEnvServo, recordings))

# Plot the demonstrations
%matplotlib notebook
# Single image axis; the slider callback below only swaps the image data.
fig, ax = plt.subplots(1, figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
first_frame = playbacks[0].cam.get_image()[0]
image_h = ax.imshow(first_frame)


def update(demo_index, frame_index):
    """Show one frame of one demonstration and print its diagnostics.

    Args:
        demo_index: index into ``playbacks``.
        frame_index: index of the frame inside the selected playback.
    """
    demo = playbacks[demo_index]
    image_h.set_data(demo[frame_index].cam.get_image()[0])
    fig.canvas.draw_idle()
    print("wp_name:", demo[frame_index].get_info()["wp_name"])
    fg_mask = demo.get_fg_mask()
    if fg_mask is None:
        # No foreground mask available, nothing more to report.
        return
    print("percent fg:", np.mean(fg_mask)*100)


slider_w = widgets.IntSlider(value=0, min=0, max=len(playbacks) - 1, step=1,
                             layout=Layout(width='70%'))
# NOTE(review): the frame slider's upper bound is fixed at 199 and is not
# derived from the length of the selected playback.
slider_i = widgets.IntSlider(value=0, min=0, max=199, step=1,
                             layout=Layout(width='70%'))

interact(update, demo_index=slider_w, frame_index=slider_i)

In [None]:
def filter_demo(pb):
    """Decide whether a demonstration is usable.

    A demonstration passes when the reward stored in its last step is
    positive and the mean of its foreground mask exceeds 0.005 (0.5% of the
    pixels for a binary mask).

    Args:
        pb: playback object supporting ``pb[-1].data['rew']`` and
            ``pb.get_fg_mask()``.

    Returns:
        bool: True when both checks pass, False otherwise.
    """
    if not pb[-1].data['rew'] > 0:
        return False
    fg_mask = pb.get_fg_mask()
    # get_fg_mask() can return None (the frame viewer in this notebook checks
    # for it as well); np.mean(None) would raise, so such demos are rejected.
    if fg_mask is None:
        return False
    return bool(np.mean(fg_mask) > 0.005)

# Flag every playback with filter_demo, then keep the indices of the accepted
# ones as plain Python ints.
demo_good = list(map(filter_demo, playbacks))
accepted_indices = np.where(demo_good)[0]
print(accepted_indices)
good_demonstrations = list(map(int, accepted_indices))
# live_seeds refers to the same list object as good_demonstrations.
live_seeds = good_demonstrations

In [None]:
# Load demo segmentation file
demo_seg_file = f'{root_dir}/demo_parts_manual2.json'
# Context manager so the file handle is closed as soon as parsing is done.
with open(demo_seg_file) as fp:
    demo_parts = json.load(fp)
# The JSON keys are demonstration indices stored as strings; they replace the
# live seeds chosen by the reward/mask filter.
live_seeds = [int(key) for key in demo_parts]
demo_parts

In [None]:
# Servo module: built on the first recording. The notebook's similarity
# computation reads its flow_module from this global.
from flow_control.servoing.module import ServoingModule
control_config = {"mode": "pointcloud-abs-rotz", "threshold": 0.40}
servo_module = ServoingModule(
    recordings[0],
    control_config=control_config,
    start_paused=False,
)

In [None]:
import ipdb
import cv2

def similarity_from_reprojection(live_rgb, demo_rgb, demo_mask, return_images=False):
    """Score how well a live image matches a demo image via flow reprojection.

    The flow between the demo and the live image is computed with the global
    ``servo_module``'s flow module, the live image is warped with that flow,
    and the warped image is compared to the demo image inside ``demo_mask``.

    Args:
        live_rgb: live image; divided by 255 before warping.
            Assumes an (H, W, C) array with 0-255 values -- TODO confirm.
        demo_rgb: demo image, same layout as ``live_rgb``.
        demo_mask: (H, W) foreground mask of the demo image.
        return_images: when True, also return the flow field and the warped
            live image.

    Returns:
        ``(error, mean_flow)``, or ``(error, mean_flow, flow, warped)`` when
        ``return_images`` is True. ``error`` is the masked mean of the
        per-pixel colour distance; ``mean_flow`` is the mean flow magnitude
        over the masked pixels.
    """
    # evaluate the similarity via flow reprojection error
    flow = servo_module.flow_module.step(demo_rgb, live_rgb)
    warped = servo_module.flow_module.warp_image(live_rgb / 255.0, flow)

    # Per-pixel colour distance between the warped live image and the demo.
    pixel_error = np.linalg.norm(warped - (demo_rgb / 255.0), axis=2)
    error = (pixel_error * demo_mask).sum() / demo_mask.sum()

    # Index with a boolean view of the mask: a 0/1 integer mask would
    # otherwise be interpreted as integer (fancy) indices into the flow.
    mask_bool = np.asarray(demo_mask, dtype=bool)
    mean_flow = np.linalg.norm(flow[mask_bool], axis=1).mean()

    if return_images:
        return error, mean_flow, flow, warped
    return error, mean_flow

In [None]:
from sklearn.preprocessing import minmax_scale

def normalize_errors(errors, flows):
    """Blend min-max scaled errors and flows into one normalized error array.

    Only the rows selected by the global ``demo_good`` are scaled; every
    other entry of the result stays at 1.

    Args:
        errors: array of reprojection errors, indexed by demo along axis 0.
        flows: array of mean flow magnitudes with the same shape.

    Returns:
        Array shaped like ``errors`` holding the blended, scaled values.
    """
    flow_weight = 0.5
    scaled_errors = minmax_scale(errors[demo_good])
    scaled_flows = minmax_scale(flows[demo_good])

    normalized = np.ones(errors.shape)
    # NOTE(review): np.mean over the two terms already halves their sum and
    # the result is divided by (1 + weight) again, so rows of good demos top
    # out at 0.5 -- confirm this double normalization is intended.
    weighted_terms = (1 * scaled_errors, flow_weight * scaled_flows)
    normalized[demo_good] = np.mean(weighted_terms, axis=0) / (1 + flow_weight)
    return normalized

## Errors front for all live seeds

In [None]:
# Front errors: how well the first frame of each live seed matches the start
# frame of every demo's first part. Layout is [demo index, live index].
# Pairs that are never computed keep error 1 and flow 0.
errors_front = np.ones((len(playbacks), len(playbacks)))
flows_front = np.zeros((len(playbacks), len(playbacks)))

for live_i in tqdm(live_seeds):
    # Frame 0 of the live seed's own recording is used as the live image.
    live_rgb = playbacks[live_i][0].cam.get_image()[0]
    for k1, v1 in demo_parts.items():
        demo_i = int(k1)
        # A seed is never matched against its own demonstration.
        if live_i == demo_i:
            continue
        # Image and foreground mask at the start of the demo's first part.
        im = playbacks[demo_i][v1[0]['start']].cam.get_image()[0]
        mask = playbacks[demo_i].fg_masks[v1[0]['start']]
        
        # `flow` is the mean flow magnitude returned by the similarity function.
        error, flow = similarity_from_reprojection(live_rgb, im, mask)
        errors_front[demo_i, live_i] = error
        flows_front[demo_i, live_i] = flow

# ax = plt.axes()
# sns.heatmap(errors_front)
# ax.set_xlabel('demo_index')
# ax.set_ylabel('live_index')

# plt.show()

# Normalize and persist (np.savez stores the array under the key 'arr_0').
errors_front_norm = normalize_errors(errors_front, flows_front)
np.savez('errors_front_norm_recom.npz', errors_front_norm)

## Error matrix between demonstration parts 

In [None]:
import seaborn as sns

# Transition errors between demonstrations: entry [i1, i2] compares the end
# frame of demo i1's first part with the start frame of demo i2's second part.
# Entries that are never computed keep error 1 and flow 0.
error_matrix = np.ones((len(playbacks), len(playbacks)))
flow_matrix = np.zeros((len(playbacks), len(playbacks)))

for k1, v1 in tqdm(demo_parts.items()):
    demo_i1 = int(k1)
    # Image at the end of the first part of demo i1 (used as the "live" image).
    im1 = playbacks[demo_i1][v1[0]['end']].cam.get_image()[0]
    for k2, v2 in demo_parts.items():
        demo_i2 = int(k2)      
        # Image and foreground mask at the start of the second part of demo i2.
        im2 = playbacks[demo_i2][v2[1]['start']].cam.get_image()[0]
        mask2 = playbacks[demo_i2].fg_masks[v2[1]['start']]
        
        # `flow` is the mean flow magnitude returned by the similarity function.
        error, flow = similarity_from_reprojection(im1, im2, mask2)
        error_matrix[demo_i1, demo_i2] = error
        flow_matrix[demo_i1, demo_i2] = flow

# Normalize and persist (np.savez stores the array under the key 'arr_0').
error_matrix_norm = normalize_errors(error_matrix, flow_matrix)
np.savez("em_norm_old_fn.npz", error_matrix_norm)


# Plot scores (1 - normalized error), so larger values mean a better match.
sns.heatmap(1 - error_matrix_norm)

## Errors with respect to goal image
We use a hacky function to update the segmentation mask of the goal image

In [None]:
from robot_io.recorder.simple_recorder import unprocess_seg

# Rear errors: how well the end of each demo's second part matches the goal
# image. Demos absent from demo_parts keep error 1 and flow 0.
errors_rear = np.ones((len(playbacks), 1))
flows_rear = np.zeros((len(playbacks), 1))

goal_path = '../tmp/goal_pick_n_place_trapeze_rR_seed1000'
goal_pl = PlaybackEnvServo(goal_path)

object_name = "Trapezium_Sort_Box"
info = goal_pl[-1].get_info()
print(info.keys())
# Info entries stored under the keys "0".."9" carry a name and a UID; build
# the name -> UID lookup from whichever of those keys are present.
name2uid = {
    info[key]["name"]: info[key]["UID"]
    for key in map(str, range(10))
    if key in info.keys()
}
print(name2uid)
obj_uid = int(name2uid[object_name])
srf_uid = int(name2uid["surface_red"])

# The goal mask covers the selected object together with the red surface.
seg_mask = info["seg_mask"]
seg_obj, _ = unprocess_seg(seg_mask)
is_object = seg_obj == obj_uid
is_surface = seg_obj == srf_uid
goal_mask = np.logical_or(is_object, is_surface)
plt.imshow(goal_mask)
plt.show()

goal_rgb = goal_pl[-1].cam.get_image()[0]

for part_key, parts in demo_parts.items():
    demo_idx = int(part_key)
    # Frame at the end of the demo's second part, compared against the goal.
    final_frame = playbacks[demo_idx][parts[1]['end']].cam.get_image()[0]
    rear_error, rear_flow = similarity_from_reprojection(final_frame, goal_rgb, goal_mask)
    errors_rear[demo_idx] = rear_error
    flows_rear[demo_idx] = rear_flow

errors_rear_norm = normalize_errors(errors_rear, flows_rear)
np.savez("errors_rear_trapeze_new_seg_mask.npz", errors_rear_norm)

In [None]:
# Plot the demonstrations
%matplotlib notebook
fig, ax = plt.subplots(1, figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
image_h = ax.imshow(playbacks[0][-1].cam.get_image()[0])


def update(demo_index):
    """Show the last frame of a demonstration and print its rear-error stats.

    Args:
        demo_index: index into ``playbacks`` (and into the rear-error arrays).
    """
    image = playbacks[demo_index][-1].cam.get_image()[0]
    image_h.set_data(image)
    fig.canvas.draw_idle()
    print(f"Error: {errors_rear[demo_index]}, Error Norm: {errors_rear_norm[demo_index]}")
    print(playbacks[demo_index][-1].data['rew'])
    print(f"Good Demo: {demo_good[demo_index]}")


slider_w = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
                             layout=Layout(width='70%'))

# Only demo_index is passed: update() has no frame_index parameter, and the
# frame slider belongs to a different cell that may not have been run.
interact(update, demo_index=slider_w)

In [None]:
errors_rear

In [None]:
# Load errors (np.savez stores an unnamed array under the key 'arr_0').
# errors_front_norm = np.load('../../../res/errors_front_norm.npz')['arr_0']
# error_matrix_norm = np.load('../../../res/error_matrix_norm.npz')['arr_0']
# errors_rear_norm = np.load('../../../res/errors_rear_norm.npz')['arr_0']
errors_front_norm, error_matrix_norm, errors_rear_norm = (
    np.load(npz_name)['arr_0']
    for npz_name in ('errors_front_norm_rR.npz',
                     'error_matrix_norm_rR.npz',
                     'errors_rear_norm_rR.npz')
)
# Scores are the complement of the normalized errors: higher is better.
scores_front_norm, score_matrix_norm, scores_rear_norm = (
    1 - normalized
    for normalized in (errors_front_norm, error_matrix_norm, errors_rear_norm)
)

## Determine best Trajectory Segments 

In [None]:
def get_best_segments(score_front, score_rear, score_matrix, score_fn='min'):
    """Pick the demonstration pair (first part, second part) with the best score.

    Args:
        score_front: ``score_front[i]`` scores demo ``i`` as the first part.
        score_rear: ``score_rear[j]`` scores demo ``j`` as the second part.
        score_matrix: 2-D array; ``score_matrix[i][j]`` scores the transition
            from demo ``i`` to demo ``j``.
        score_fn: ``'min'`` picks the best front and the best rear demo
            independently (``score_matrix`` is ignored); ``'sum'`` and
            ``'prod'`` maximize the sum / product of the three scores over
            all ``(i, j)`` pairs.

    Returns:
        dict with keys ``'score'``, ``'idx1'`` and ``'idx2'``. For ``'min'``
        no combined score is computed and ``'score'`` stays 0.0. For
        ``'sum'``/``'prod'`` the indices stay -1 when no pair scores above 0.

    Raises:
        ValueError: if ``score_fn`` is not one of ``'min'``, ``'sum'``, ``'prod'``.
    """
    if score_fn == 'min':
        return {
            'score': 0.0,
            'idx1': int(np.argmax(score_front)),
            'idx2': int(np.argmax(score_rear)),
        }

    # Operand order matches the historical expressions so float results stay
    # bit-identical.
    combiners = {
        'sum': lambda front, link, rear: link + front + rear,
        'prod': lambda front, link, rear: front * link * rear,
    }
    try:
        combine = combiners[score_fn]
    except KeyError:
        raise ValueError(
            f"unknown score_fn {score_fn!r}; expected 'min', 'sum' or 'prod'"
        ) from None

    n_first, n_second = score_matrix.shape
    best_score_fn = {'score': 0.0, 'idx1': -1, 'idx2': -1}
    for i in range(n_first):
        for j in range(n_second):
            total_score_fn = combine(score_front[i], score_matrix[i][j], score_rear[j])
            # Strict comparison keeps the first best pair in row-major order.
            if total_score_fn > best_score_fn['score']:
                best_score_fn = {'score': total_score_fn, 'idx1': i, 'idx2': j}

    return best_score_fn

In [None]:
# Best (first part, second part) demo pair for every playback index, using
# the product of front, transition and rear scores.
best_traj = {
    live_i: get_best_segments(
        scores_front_norm[:, live_i],  # front scores of all demos for this seed
        scores_rear_norm,
        score_matrix_norm,
        score_fn='prod',
    )
    for live_i in range(len(playbacks))
}

In [None]:
best_traj

In [None]:
# Plot similarity scores
from scipy import ndimage
import cv2

%matplotlib notebook
# Two panels: the live image on the left, the selected demo image on the right.
fig, (ax, ax2) = plt.subplots(1, 2)
fig.suptitle("Flow")
ax.set_title("live")
ax2.set_title("demo")
ax.set_axis_off()
ax2.set_axis_off()
# Placeholder shown until a slider value selects real images.
empty_image = np.zeros((256, 256), dtype=np.uint8)
image_h = ax.imshow(empty_image)
image_h2 = ax2.imshow(empty_image)
num_episodes = len(playbacks)

# Initial dummy arrow; update() removes it and draws a new one on each call.
arrow_flow = ax2.annotate("", xytext=(64, 64), xy=(84, 84), arrowprops=dict(arrowstyle="->"))

def update(live_index):
    """Show a live seed next to the demo chosen as its best first segment.

    Args:
        live_index: index into ``playbacks`` of the live seed to display.
    """
    # The arrow handle is rebound below, so it has to be declared global.
    global arrow_flow
    
    # Frame 0 of the live seed's recording is used as the live image.
    image_l = playbacks[live_index][0].cam.get_image()[0]
    image_h.set_data(image_l)
    ax.set_title(f"live @ {live_index}, 0")
    
    # Demo selected as the first segment for this live seed.
    demo_index = best_traj[live_index]['idx1'] 

    # Seeds outside live_seeds get a blank demo panel and no arrow update.
    if live_index not in live_seeds:
        image_h2.set_data(empty_image)
        ax2.set_title(f"demo")
        print("no valid demos.")
        return
    arrow_flow.remove()
    image_d = playbacks[demo_index][0].cam.get_image()[0]
    demo_mask =  playbacks[demo_index].get_fg_mask()
    error, mean_flow, flow, warped = similarity_from_reprojection(image_l, image_d, demo_mask,
                                                                  return_images=True)
    # center_of_mass returns (row, col); reversed to (x, y) for annotate.
    mask_com = np.array(ndimage.center_of_mass(demo_mask))[::-1]
    size_scl = np.array([1.0, 1.0])

    mean_flow_origin = mask_com * size_scl
    # NOTE(review): mean_flow is a scalar magnitude, so the same offset is
    # added to both x and y and the arrow always points diagonally.
    mean_flow_xy = mean_flow_origin + mean_flow * size_scl
    
    demo_img = image_d
    arrw_f = ax2.annotate("", xytext=mean_flow_origin,
                                      xy=mean_flow_xy,
                                      arrowprops=dict(arrowstyle="->"))
    arrow_flow = arrw_f
    
    print(f"error {error:.2f}")
    image_h2.set_data(demo_img)
    ax2.set_title(f"demo @ {demo_index}, 0")
    fig.canvas.draw_idle()
    
slider_l = widgets.IntSlider(min=0, max=num_episodes-1, step=1, value=3,
                             layout=Layout(width='70%'))
interact(update, live_index=slider_l)
plt.tight_layout()
plt.show()

In [None]:
import ipdb
def split_keypoints(pb, demo_part):
    """Group a playback's keep-frame indices by demonstration part.

    Args:
        pb: object exposing ``keep_dict``, whose keys are frame indices.
        demo_part: iterable of dicts with ``"name"``, ``"start"`` and
            ``"end"`` entries describing each part.

    Returns:
        dict mapping each part name to the sorted keep indices ``k`` with
        ``start < k <= end``; a part that starts at frame 0 also includes
        frame 0.
    """
    keep_indices = sorted(pb.keep_dict.keys())
    keep_parts = {}
    for part in demo_part:
        # The lower bound is exclusive, so when one part ends on the frame
        # where the next one starts, that frame goes to the earlier part only.
        # Frame 0 has no earlier part, hence the -1 bound to include it.
        lower_bound = -1 if part["start"] == 0 else part["start"]
        keep_parts[part["name"]] = [
            index for index in keep_indices
            if lower_bound < index <= part["end"]
        ]
    return keep_parts

# Per good demonstration: keep-frame indices grouped by part name.
# pb_keep = [PlaybackEnvServo(rec, load='keep') for rec in recordings[:]]
keypoint_info = {
    demo_seed: split_keypoints(playbacks[demo_seed], demo_parts[str(demo_seed)])
    for demo_seed in good_demonstrations
}
# keypoint_info = {0: {'locate': [0, 4], 'grasp': [7, 14, 26], 'insert': [26, 31, 37, 42, 44, 47]}}
keypoint_info

In [None]:
def select_key_points(rec, idx, traj_info):
    """Return the servo keep-frame indices that fall inside one trajectory part.

    Args:
        rec: path of the recording directory containing ``servo_keep.json``.
        idx: index of the part inside ``traj_info``.
        traj_info: sequence of dicts with integer ``'start'`` and ``'end'``
            entries.

    Returns:
        list[int]: every key ``k`` of ``servo_keep.json`` with
        ``start <= k <= end``, in file order.
    """
    start_idx, end_idx = traj_info[idx]['start'], traj_info[idx]['end']
    # Context manager so the file handle is closed after parsing.
    with open(f'{rec}/servo_keep.json') as keep_file:
        keep_dict = json.load(keep_file)
    # Filter over all keys instead of stopping at the first key above end_idx,
    # so the result does not depend on the keys being stored in sorted order.
    return [int(key) for key in keep_dict if start_idx <= int(key) <= end_idx]

In [None]:
from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.servoing.module import ServoingModule
from flow_control.runner import evaluate_control
from math import pi
import hydra

# @hydra.main()
def eval_cmb(save_dir, live_seed, rec_seeds, keypoint_info):
    """Run a sequence of demonstration parts in one simulated environment.

    A 'recombination' environment is created for ``live_seed``; then, for
    each entry of ``rec_seeds`` in order, a servoing module restricted to
    that demo's keypoints for the matching part ('locate' first, 'insert'
    second) is evaluated on the same environment.

    Args:
        save_dir: prefix of the output directory; the position of the part
            in ``rec_seeds`` is appended as ``_{position}``.
        live_seed: seed of the simulated environment.
        rec_seeds: indices into the global ``recordings``, one per part.
        keypoint_info: mapping demo index -> part name -> keep-frame indices.

    Returns:
        The reward reported by the last ``evaluate_control`` call.
    """
    task_variant = 'rR'
    rot_range_by_variant = {"rP": pi/2., "rR": pi/6.}
    servo_config = dict(mode="pointcloud-abs-rotz", threshold=0.4)

    env = RobotSimEnv(
        task='recombination',
        renderer="debug",
        act_type='continuous',
        initial_pose='close',
        max_steps=500,
        control='absolute-full',
        img_size=(256, 256),
        param_randomize=("geom",),
        param_info={'object_selected': 'trapeze', 'task_selected': 'pick_n_place'},
        task_info=dict(object_rot_range=rot_range_by_variant[task_variant]),
        seed=int(live_seed),
    )

    # Position in rec_seeds -> name of the demonstration part to replay.
    part_for_position = {0: 'locate', 1: 'insert'}

    for position, demo_seed in enumerate(rec_seeds):
        recording = recordings[demo_seed]
        part_keypoints = keypoint_info[demo_seed][part_for_position[position]]
        part_servo = ServoingModule(recording, control_config=servo_config,
                                    start_paused=False, plot=False, plot_save_dir=None,
                                    load='select', selected_kp=part_keypoints)
        # The initial alignment is requested for the first part only.
        _, reward, _, _ = evaluate_control(env, part_servo,
                                           max_steps=130,
                                           save_dir=f"{save_dir}_{position}",
                                           initial_align=(position == 0))
    return reward

In [None]:
# rewards = []
# for live_i in live_seeds:
#     best_traj_combination = best_traj[live_i]
#     idx1, idx2 = best_traj_combination['idx1'], best_traj_combination['idx2']
#     rec_list = [recordings[idx1], recordings[idx2]]
#     demo_traj_info = [demo_seg[str(idx1)], demo_seg[str(idx2)]]
#     reward = eval_cmb("", live_i, rec_list, demo_traj_info)
#     rewards.append(reward)

In [None]:
@hydra.main()
def main(cfg=None):
    """Hydra entry point: run one recombination with demo 0 for both parts.

    Args:
        cfg: configuration object that Hydra passes to the task function;
            it is not used here. Defaults to None so a direct call without
            arguments keeps working.
    """
    eval_cmb("", 0, [0, 0], keypoint_info)

In [None]:
main()

In [None]:
import logging
# getLogger() without a name returns the root logger; lowering its level to
# DEBUG lets debug records through, which the test message below checks.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.debug("test")

In [None]:
# best_traj_combination = best_traj[0]
# idx1, idx2 = best_traj_combination['idx1'], best_traj_combination['idx2']\
# rewards = []
# for live_i in range(20):    
#     rec_list = [recordings[0], recordings[0]]
#     demo_traj_info = [demo_seg[str(0)], demo_seg[str(0)]]
#     rewards.append(eval_cmb("", live_i, rec_list, demo_traj_info))

In [None]:
# Evaluate recombination while growing the pool of available demonstrations:
# each entry of `steps` is the number of leading demos the selection may use.
# steps = [1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75]
steps = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 60, 70, 75]
num_steps = len(steps)
save_root = "../tmp/recombination_run_rR_DEBUG"

# NOTE(review): only 'min' is evaluated here; rewards files for other score
# functions are not written by this cell as it stands.
for score_fn in ['min']:
    # One row per live seed (at most 20 are used), one column per pool size.
    rewards = np.zeros((20, num_steps))

    for live_idx, live_seed in tqdm(enumerate(live_seeds[0:20])):
        # Last evaluated pair of recordings, used to skip repeated simulations.
        last_rec1, last_rec2 = None, None

        # Now you have all required errors_old
        for step_idx, value in enumerate(steps):
            # Restrict all scores to the first `value` demonstrations.
            sf = scores_front_norm[0:value, live_seed]
            sr = scores_rear_norm[0:value]
            sm = score_matrix_norm[0:value, 0:value]
            best_segments = get_best_segments(sf, sr, sm, score_fn)
            
            idx1, idx2 = best_segments['idx1'], best_segments['idx2']
            rec_seeds = [idx1, idx2]
            rec1, rec2 = recordings[idx1], recordings[idx2]
#             rec_list = [rec1, rec2]
#             demo_traj_info = [demo_seg[str(idx1)], demo_seg[str(idx2)]]

            print(f"Recordings selected are: {rec1}, {rec2}")
            # continue
            if last_rec1 == rec1 and last_rec2 == rec2:
                # This was already tested, use the result
                rewards[live_idx, step_idx] = rewards[live_idx, step_idx - 1]
                print(f"skipped {live_seed} and {step_idx}")
            else:
                # This needs to be tested
                save_dir = f"{save_root}/{score_fn}/run_recombination_{live_seed}_{step_idx}"
                rewards[live_idx, step_idx] = eval_cmb(save_dir, live_seed, rec_seeds, keypoint_info)

                last_rec1 = rec1
                last_rec2 = rec2

            # Saved after every step so partial results survive an interruption.
            np.savez(f'rewards_{score_fn}.npz', rewards)

In [None]:
# steps = [1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75]
def filter_demo(pb):
    """Return 1 for a successful run on object 2, otherwise 0.

    A run counts as successful when the reward stored in its last step is
    positive and the ``'object_selected'`` entry of that step's info equals 2.

    Args:
        pb: playback-like object supporting ``pb[-1].data``.

    Returns:
        int: 1 when both conditions hold, 0 otherwise.
    """
    final_data = pb[-1].data
    if not final_data['rew'] > 0:
        # The info entry is only inspected for runs with a positive reward.
        return 0
    # The info entry may be stored bare or wrapped in an array; atleast_1d
    # makes both cases indexable with [0].
    final_info = np.atleast_1d(final_data['info'])[0]
    if final_info['object_selected'] == 2:
        return 1
    return 0

# Re-read the saved recombination runs and build, per score function, a mask
# of runs that pass filter_demo (same selection logic as the evaluation loop).
steps = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 60, 70, 75]
num_steps = len(steps)
save_root = "../tmp/recombination_run_rR"

# rewards_min = np.load('rewards_min.npz')['arr_0']
# rewards_sum = np.load('rewards_sum.npz')['arr_0']
# rewards_prod = np.load('rewards_prod.npz')['arr_0']

masks = {}

for score_fn in ['min', 'sum', 'prod']:
    # One row per live seed (at most 20 are used), one column per pool size.
    mask = np.zeros((20, num_steps))
    # Filtered rewards for this score function, one list per live seed.
    rewards = []

    for live_idx, live_seed in tqdm(enumerate(live_seeds[0:20])):
        last_rec1, last_rec2 = None, None
        rew = []

        for step_idx, value in enumerate(steps):
            # Restrict all scores to the first `value` demonstrations.
            sf = scores_front_norm[0:value, live_seed]
            sr = scores_rear_norm[0:value]
            sm = score_matrix_norm[0:value, 0:value]
            best_segments = get_best_segments(sf, sr, sm, score_fn)

            idx1, idx2 = best_segments['idx1'], best_segments['idx2']
            rec1, rec2 = recordings[idx1], recordings[idx2]

            print(f"Recordings selected are: {rec1}, {rec2}")
            if last_rec1 == rec1 and last_rec2 == rec2:
                # Same pair as the previous step: reuse its result.
                mask[live_idx, step_idx] = mask[live_idx, step_idx - 1]
                rew.append(rew[-1])
                print(f"skipped {live_seed} and {step_idx}")
            else:
                # The "_1" suffix selects the directory written for the
                # second part of the run.
                save_dir = f"{save_root}/{score_fn}/run_recombination_{live_seed}_{step_idx}_1"
                # Load the playback BEFORE filtering it; filtering first would
                # use an undefined (or the previous run's) playback.
                pb = PlaybackEnvServo(save_dir)
                filtered_reward = filter_demo(pb)
                rew.append(filtered_reward)
                mask[live_idx, step_idx] = filtered_reward

                last_rec1 = rec1
                last_rec2 = rec2

        # Appended per live seed so that no seed's results are dropped.
        rewards.append(rew)

    np.savez(f'rew_{score_fn}.npz', rewards)
    masks[score_fn] = mask

# Explicit names used by the cells that combine the masks with the rewards.
mask_min, mask_sum, mask_prod = masks['min'], masks['sum'], masks['prod']

            

In [None]:
# Zero out the stored rewards of runs that did not pass the demo filter.
rewards_min, rewards_sum, rewards_prod = (
    np.load(f'rewards_{fn_name}.npz')['arr_0'] * fn_mask
    for fn_name, fn_mask in (('min', mask_min), ('sum', mask_sum), ('prod', mask_prod))
)

In [None]:
plt.cla()

In [None]:
# Mean reward over live seeds for each pool size, one curve per score function.
rew = rewards_min.mean(axis=0)
rew_sum = rewards_sum.mean(axis=0)
rew_prod = rewards_prod.mean(axis=0)
# x = [1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75]
x = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 60, 70, 75]
for curve, curve_label in ((rew_prod, 'score_fn=prod'),
                           (rew_sum, 'score_fn=sum'),
                           (rew, 'score_fn=dis')):
    plt.plot(x, curve, ".-", label=curve_label)
# plt.plot(x, rewards_ml, ".-", label='error_fn=ML')

# Reference level drawn as a dashed line with a +/- one std band.
# (An earlier value pair, 0.38 / 0.19, was overwritten by the one below.)
mean, std = 0.52, 0.20

plt.plot([0, 75], [mean, mean], "k--")
plt.axhspan(mean - std, mean + std, facecolor='gray', alpha=0.2)

plt.xlabel("#Recordings")
plt.ylabel("Mean Rewards")
plt.title("Recombination (Debug Mode)")
plt.legend()
plt.savefig('rewards_abs.jpg', dpi=800)