# Recombination

In [None]:
import os
import shutil
import unittest
import subprocess

import cv2
import ipdb
import numpy as np

from scipy.spatial.transform import Rotation as R

from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.demo.demo_episode_recorder import record_sim
from flow_control.flow_control_main import evaluate_control
from flow_control.servoing.module import ServoingModule
%matplotlib inline

In [None]:
# Create the simulated environment with a fixed seed, then gather the
# recorded demonstration segments and the servoing configuration.
seed = 100
env_options = {
    'task': 'shape_sorting',
    'renderer': 'egl',
    'act_type': 'continuous',
    'initial_pose': 'close',
    'max_steps': 500,
    'control': 'absolute-full',
    'img_size': (256, 256),
    'sample_params': False,
    'seed': seed,
}
env = RobotSimEnv(**env_options)

goal_im_path = '../frame_000074.jpg'
rec_path = '../tmp_test_segmented/tmp_test'

# Recordings are split by directory-name suffix: '...seg0' and '...seg1'.
recordings_seg0 = []
for entry in os.listdir(rec_path):
    if entry.endswith('seg0'):
        recordings_seg0.append(os.path.join(rec_path, entry))

recordings_seg1 = []
for entry in os.listdir(rec_path):
    if entry.endswith('seg1'):
        recordings_seg1.append(os.path.join(rec_path, entry))

control_config = {"mode": "pointcloud-abs-rotz", "threshold": 0.25}

In [None]:
# Build one (recording_path, ServoingModule) pair per seg0 recording.
# The paths are sorted first so that list positions are reproducible.
recordings_seg0 = sorted(recordings_seg0)
servo_modules_seg0 = []
for rec_dir in recordings_seg0:
    module = ServoingModule(rec_dir, control_config=control_config,
                            plot=False, save_dir=None)
    servo_modules_seg0.append((rec_dir, module))

In [None]:
import matplotlib.pyplot as plt

# Step the environment once with a None action and keep the gripper-camera
# image from the returned state as the live view.
state, _, _, _ = env.step(None)
live_rgb = state['rgb_gripper']

# cv2.imread does not raise when the file is missing or unreadable; it
# returns None, which would otherwise only fail later with an obscure error.
goal_frame = cv2.imread(goal_im_path)
if goal_frame is None:
    raise FileNotFoundError(
        "Could not read goal image: {!r}".format(goal_im_path))

# NOTE(review): cv2.imread yields channels in BGR order while matplotlib
# (and presumably the 'rgb_gripper' images) use RGB -- confirm how the goal
# image was written before deciding whether a colour conversion is needed.
plt.imshow(goal_frame)

In [None]:
# Select demonstration either based on first frame or last frame
def select_demo(servo_modules, live_rgb, goal_frame=None, use_goal=False):
    """
    Selects the demonstration with the minimum reprojection error

    Depending on ``use_goal`` each demonstration is scored either by
    comparing ``live_rgb`` with the demonstration's first frame, or by
    comparing ``goal_frame`` with the demonstration's fifth-from-last frame.

    Args:
        servo_modules: Iterable of ``(task, servo_module)`` pairs. Each servo
            module must expose ``demo.steps``, ``demo.fg_masks`` and
            ``flow_module`` (with ``step`` and ``warp_image``).
        live_rgb: Array with the live view (used when ``use_goal`` is False)
        goal_frame: Array with the goal view (required when ``use_goal``
            is True)
        use_goal: If True score against the goal frame, otherwise against
            the live view.

    Returns:
        best_servo_module: Servo module with the lowest error, or None if
            ``servo_modules`` is empty.
        best_task: The task entry paired with ``best_servo_module``.
        errors_list: Error of every demonstration, in input order.

    Raises:
        ValueError: If ``use_goal`` is True but no ``goal_frame`` was given.
    """
    if use_goal and goal_frame is None:
        raise ValueError("goal_frame is required when use_goal=True")

    best_task = None
    best_servo_module = None
    best_error = np.inf
    errors_list = []

    # Multipliers for Front and Rear errors
    alpha, beta = 1.0, 1.0

    for t, s in servo_modules:
        error_front, error_rear = 0.0, 0.0
        if use_goal:
            # The rear comparison uses the frame (and mask) five steps before
            # the end of the demonstration rather than the very last one.
            error_rear = _masked_reprojection_error(
                s.flow_module, goal_frame,
                s.demo.steps[-5].cam.get_image()[0], s.demo.fg_masks[-5])
        else:
            error_front = _masked_reprojection_error(
                s.flow_module, live_rgb,
                s.demo.steps[0].cam.get_image()[0], s.demo.fg_masks[0])

        error = error_front * alpha + error_rear * beta
        errors_list.append(error)

        # Strict '<' keeps the earliest demonstration on ties.
        if error < best_error:
            best_error = error
            best_task = t
            best_servo_module = s

    return best_servo_module, best_task, errors_list


def _masked_reprojection_error(flow_module, query_im, demo_im, demo_mask):
    """Mean squared colour error between the warped query and a demo frame.

    The flow between ``query_im`` and ``demo_im`` is used to warp the query
    image; the per-pixel squared difference (summed over the colour axis) is
    then averaged over the demonstration's foreground mask.

    The mask takes its shape from ``demo_mask`` instead of assuming a fixed
    256x256 image, so any image size works as long as mask and images agree.

    Returns 2.0 as a fixed penalty when the foreground mask is empty.
    """
    flow = flow_module.step(query_im, demo_im)
    warped = flow_module.warp_image(query_im / 255.0, flow)

    # The demo mask is expected to be a logical (boolean) foreground mask.
    mask = np.asarray(demo_mask, dtype=bool)
    n_foreground = mask.sum()
    if n_foreground == 0:
        return 2.0

    sq_err = ((warped - (demo_im / 255.0)) ** 2.0).sum(axis=2)
    return sq_err[mask].sum() / n_foreground

In [None]:
# Score every seg0 demonstration against the live view (first-frame
# comparison); only the per-demonstration error list is kept.
_, _, errors_list_front = select_demo(
    servo_modules=servo_modules_seg0,
    live_rgb=live_rgb,
    goal_frame=goal_frame,
    use_goal=False,
)

In [None]:
# Drop the reference to the seg0 servo modules once their errors have been
# computed -- presumably to release memory before the seg1 modules are
# loaded; confirm.
del servo_modules_seg0

In [None]:
# Build one (recording_path, ServoingModule) pair per seg1 recording.
# The paths are sorted first so that list positions are reproducible.
recordings_seg1 = sorted(recordings_seg1)
servo_modules_seg1 = []
for rec_dir in recordings_seg1:
    module = ServoingModule(rec_dir, control_config=control_config,
                            plot=False, save_dir=None)
    servo_modules_seg1.append((rec_dir, module))

In [None]:
# Score every seg1 demonstration against the goal frame (rear comparison);
# only the per-demonstration error list is kept.
_, _, errors_list_rear = select_demo(
    servo_modules=servo_modules_seg1,
    live_rgb=live_rgb,
    goal_frame=goal_frame,
    use_goal=True,
)

In [None]:
# Drop the reference to the seg1 servo modules once their errors have been
# computed -- presumably to release memory; confirm.
del servo_modules_seg1

In [None]:
# Rank the seg0 recordings by ascending front error and keep the five best.
# sorted() is stable, so equal errors keep their original relative order.
ranked_front = sorted(enumerate(errors_list_front), key=lambda pair: pair[1])
res = [position for position, _ in ranked_front[:5]]
best_recordings_front = []
for position in res:
    best_recordings_front.append(recordings_seg0[position])

print(best_recordings_front)

In [None]:
# Rank the seg1 recordings by ascending rear error and keep the five best.
# Indices refer to the sorted list of seg1 recording paths.
ranked_rear = sorted(enumerate(errors_list_rear), key=lambda pair: pair[1])
res = [position for position, _ in ranked_rear[:5]]
ordered_seg1 = sorted(recordings_seg1)
best_recordings_rear = []
for position in res:
    best_recordings_rear.append(ordered_seg1[position])

print(best_recordings_rear)

In [None]:
# Display the live gripper-camera image for visual comparison with the
# candidate demonstrations plotted in the following cells.
plt.imshow(live_rgb)

In [None]:
# One row per best front candidate: first demo frame on the left, last demo
# frame on the right. Each recording is re-loaded just to read its frames.
fig, ax = plt.subplots(5, 2, figsize=(20, 20))
for axes_row, rec_dir in zip(ax, best_recordings_front):
    servo_module = ServoingModule(rec_dir, control_config=control_config,
                                  plot=False, save_dir=None)
    first_frame = servo_module.demo.steps[0].cam.get_image()[0]
    last_frame = servo_module.demo.steps[-1].cam.get_image()[0]
    for panel, frame in zip(axes_row, (first_frame, last_frame)):
        panel.imshow(frame)
        panel.set_axis_off()

In [None]:
# One row per best rear candidate: first demo frame on the left, last demo
# frame on the right. Each recording is re-loaded just to read its frames.
fig, ax = plt.subplots(5, 2, figsize=(20, 20))
for axes_row, rec_dir in zip(ax, best_recordings_rear):
    servo_module = ServoingModule(rec_dir, control_config=control_config,
                                  plot=False, save_dir=None)
    first_frame = servo_module.demo.steps[0].cam.get_image()[0]
    last_frame = servo_module.demo.steps[-1].cam.get_image()[0]
    for panel, frame in zip(axes_row, (first_frame, last_frame)):
        panel.imshow(frame)
        panel.set_axis_off()

# Create Trajectory from source to destination

In [None]:
# Compute errors between last frame of all recordings in seg0 and first frame of all recordings in seg1
# Row i / column j of the matrix correspond to rec_seg0[i] / rec_seg1[j].
# Set compute = True to rebuild the matrix; otherwise the cached copy on disk
# is used.
compute = False
rec_seg0 = sorted(recordings_seg0)
rec_seg1 = sorted(recordings_seg1)

if compute:
    seg1_sm = [(t, ServoingModule(t, control_config=control_config,
                                  plot=False, save_dir=None)) for t in rec_seg1]

    error_matrix = []

    for rec in rec_seg0:
        # Extract last frame from this segment
        rec_sm = ServoingModule(rec, control_config=control_config, plot=False, save_dir=None)
        last_frame = rec_sm.demo.steps[-1].cam.get_image()[0]

        # Compute all errors: the seg0 last frame plays the role of the
        # live view and is compared with the first frame of every seg1 demo.
        _, _, errors = select_demo(seg1_sm, last_frame, use_goal=False)

        # Append errors to error matrix
        error_matrix.append(errors)

        del rec_sm
    del seg1_sm
    error_matrix_arr = np.array(error_matrix)
    # Saved positionally, so the array is stored under the key 'arr_0'.
    np.savez('error_matrix.npz', error_matrix_arr)
else:
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the array has been read into memory.
    with np.load('error_matrix.npz') as cached_errors:
        error_matrix_arr = cached_errors['arr_0']

In [None]:
# Prepare data: short aliases for the per-recording front (seg0 vs. live view)
# and rear (seg1 vs. goal frame) error lists used by the trajectory search.
error_front, error_rear = errors_list_front, errors_list_rear

In [None]:
# Now compute trajectory using error
# Exhaustively search all (seg0, seg1) pairs and keep the pair minimising
#   fn1: link error + front error + rear error   (additive)
#   fn2: link error * front error * rear error   (multiplicative)
x, y = error_matrix_arr.shape

# A cached error matrix may not match the current recordings; fail early
# instead of indexing the wrong recordings.
if len(error_front) != x or len(error_rear) != y:
    raise ValueError(
        "error matrix shape {} does not match {} front / {} rear errors".format(
            (x, y), len(error_front), len(error_rear)))

best_error_fn1 = {'error': np.inf, 'idx1': -1, 'idx2': -1}
best_error_fn2 = {'error': np.inf, 'idx1': -1, 'idx2': -1}

for i in range(x):
    for j in range(y):
        link_error = error_matrix_arr[i][j]
        total_error_fn1 = link_error + error_front[i] + error_rear[j]
        total_error_fn2 = link_error * error_front[i] * error_rear[j]

        # Strict '<' keeps the first pair (row-major order) on ties.
        if total_error_fn1 < best_error_fn1['error']:
            best_error_fn1['error'] = total_error_fn1
            best_error_fn1['idx1'] = i
            best_error_fn1['idx2'] = j

        if total_error_fn2 < best_error_fn2['error']:
            best_error_fn2['error'] = total_error_fn2
            best_error_fn2['idx1'] = i
            best_error_fn2['idx2'] = j


print(best_error_fn1)
print(best_error_fn2)

# Print the selected recordings for both criteria. The indices are read from
# the result dicts; no separate idx1/idx2 variables exist in this cell.
for best in (best_error_fn1, best_error_fn2):
    # idx stays at -1 when no pair was selected (e.g. an empty matrix).
    if best['idx1'] >= 0 and best['idx2'] >= 0:
        print(rec_seg0[best['idx1']], rec_seg1[best['idx2']])

In [None]:
# Visualise the chain chosen by the additive criterion (best_error_fn1):
# live view, seg0 first/last frame, seg1 first/last frame, goal frame.
fig, ax = plt.subplots(1, 6, figsize=(20, 20))
idx1, idx2 = best_error_fn1['idx1'], best_error_fn1['idx2']

# Load each recording only long enough to read its first and last frames.
servo_module = ServoingModule(rec_seg0[idx1], control_config=control_config,
                              plot=False, save_dir=None)
seg0_firstframe = servo_module.demo.steps[0].cam.get_image()[0]
seg0_lastframe = servo_module.demo.steps[-1].cam.get_image()[0]
del servo_module

servo_module = ServoingModule(rec_seg1[idx2], control_config=control_config,
                              plot=False, save_dir=None)
seg1_firstframe = servo_module.demo.steps[0].cam.get_image()[0]
seg1_lastframe = servo_module.demo.steps[-1].cam.get_image()[0]
del servo_module

chain_images = (live_rgb, seg0_firstframe, seg0_lastframe,
                seg1_firstframe, seg1_lastframe, goal_frame)
for panel, image in zip(ax, chain_images):
    panel.set_axis_off()
    panel.imshow(image)



In [None]:
# Visualise the chain chosen by the multiplicative criterion (best_error_fn2):
# live view, seg0 first/last frame, seg1 first/last frame, goal frame.
fig, ax = plt.subplots(1, 6, figsize=(20, 20))
idx1, idx2 = best_error_fn2['idx1'], best_error_fn2['idx2']

# Load each recording only long enough to read its first and last frames.
servo_module = ServoingModule(rec_seg0[idx1], control_config=control_config,
                              plot=False, save_dir=None)
seg0_firstframe = servo_module.demo.steps[0].cam.get_image()[0]
seg0_lastframe = servo_module.demo.steps[-1].cam.get_image()[0]
del servo_module

servo_module = ServoingModule(rec_seg1[idx2], control_config=control_config,
                              plot=False, save_dir=None)
seg1_firstframe = servo_module.demo.steps[0].cam.get_image()[0]
seg1_lastframe = servo_module.demo.steps[-1].cam.get_image()[0]
del servo_module

chain_images = (live_rgb, seg0_firstframe, seg0_lastframe,
                seg1_firstframe, seg1_lastframe, goal_frame)
for panel, image in zip(ax, chain_images):
    panel.set_axis_off()
    panel.imshow(image)

