# Recombination

In [None]:
import os
import shutil
import unittest
import subprocess

import cv2
import ipdb
import numpy as np

from scipy.spatial.transform import Rotation as R
from robot_io.envs.playback_env import PlaybackEnv
from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.demo.demo_episode_recorder import record_sim
from flow_control.flow_control_main import evaluate_control
from flow_control.servoing.module import ServoingModule
import os
import numpy as np
import logging

def is_notebook():
    """Return True only when running inside a Jupyter kernel (notebook/qtconsole).

    A terminal IPython session, any other IPython front-end, and the plain
    Python interpreter (where the ``get_ipython`` builtin does not exist)
    all report False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython is only injected into builtins by IPython
        return False
    # 'ZMQInteractiveShell' is the Jupyter kernel; 'TerminalInteractiveShell'
    # and anything else are treated as non-notebook.
    return shell_name == 'ZMQInteractiveShell'

# Detect whether we run inside Jupyter; plotting cells below are guarded on this.
# NOTE(review): originally annotated "becomes overwritten", but nothing in this
# file reassigns `interactive` — presumably an external runner does. Confirm.
interactive = is_notebook()
if interactive:
    # Select the inline backend through the IPython API. The previous code first
    # chose the 'notebook' backend and then immediately overrode it with a raw
    # `%matplotlib inline` magic, which is also a SyntaxError when this file is
    # run as a plain Python script.
    get_ipython().run_line_magic('matplotlib', 'inline')
    from ipywidgets import widgets, interact, Layout
    import matplotlib.pyplot as plt

In [None]:
# Instantiate environment for the seed
if interactive:
    seed = 101
    rec_path = '../recombination/tmp_test_split/'
else:
    import sys
    # Expected invocation: Recombination-new.py <recordings dir> <seed>.
    # Missing arguments now abort with the usage text instead of an IndexError;
    # extra trailing arguments are tolerated (the old check expected 4 argv
    # entries while only argv[1] and argv[2] are read).
    if len(sys.argv) < 3:
        sys.exit("Usage: Recombination-new.py <Split recordings location> <seed>")
    rec_path = sys.argv[1]
    seed = int(sys.argv[2])

env = RobotSimEnv(task='recombination', renderer='egl', act_type='continuous',
                  initial_pose='close', max_steps=500, control='absolute-full',
                  img_size=(256, 256),
                  sample_params=False,
                  seed=seed)

# Goal image = last camera frame of a fixed seg1 recording.
# NOTE(review): this path is hard-coded relative to './' while the interactive
# rec_path uses '../' — confirm which working directory is intended.
goal_rec = './recombination/tmp_test_split/pick_n_place_trapeze_rR_000080_seg1'
rec = PlaybackEnv(goal_rec).to_list()
video_recording = np.array([renv.cam.get_image()[0] for renv in rec])
goal_frame = video_recording[-1]

# Split recordings are directories whose names end in 'seg0' / 'seg1'.
split_entries = os.listdir(rec_path)
recordings_seg0 = sorted(os.path.join(rec_path, name) for name in split_entries if name.endswith('seg0'))
recordings_seg1 = sorted(os.path.join(rec_path, name) for name in split_entries if name.endswith('seg1'))
control_config = dict(mode="pointcloud-abs-rotz", threshold=0.25)

In [None]:
# Create directories to store outputs (makedirs also creates store_path itself)
store_path = './recombination/errors'
seed_dir = os.path.join(store_path, str(seed))
os.makedirs(seed_dir, exist_ok=True)

# Decide whether to compute the seg0->seg1 error matrix and the error w.r.t. the
# goal frame. Both are cached directly in store_path (not seed_dir), so a
# non-interactive run skips them when both files exist. Interactive sessions
# always compute: later cells need errors_list_rear / best_recordings_rear,
# which only exist when compute is True (previously a NameError when cached).
cached = all(os.path.isfile(os.path.join(store_path, fname))
             for fname in ('error_matrix.npz', 'error_rear.npz'))
compute = interactive or not cached

In [None]:
# Advance the env with a None action and keep its 'rgb_gripper' image as the live view
observation, _reward, _done, _info = env.step(None)
live_rgb = observation['rgb_gripper']

if interactive:
    import matplotlib.pyplot as plt
    plt.imshow(live_rgb)

In [None]:
# Select a demonstration either by its first frame (vs. the live image) or by a
# frame near its end (steps[-5], vs. the goal frame).
def _masked_reprojection_error(flow_module, query_rgb, demo_rgb, demo_mask):
    """Return the masked mean squared colour error between *demo_rgb* and the warped *query_rgb*.

    Flow is estimated with ``flow_module.step(query_rgb, demo_rgb)`` and
    ``query_rgb / 255`` is warped with it. The squared difference to
    ``demo_rgb / 255`` is summed over the last (channel) axis and averaged over
    the pixels where ``demo_mask == True``. Returns 2.0 when the mask selects
    no pixel (sentinel value inherited from the original code).
    """
    flow = flow_module.step(query_rgb, demo_rgb)
    warped = flow_module.warp_image(query_rgb / 255.0, flow)

    # Elementwise `== True` (not a cast to bool) keeps the original semantics
    # for 0/1 integer masks. The mask shape is taken from the mask itself
    # instead of a hard-coded (256, 256).
    selected = np.asarray(demo_mask) == True  # noqa: E712
    # Weight 255 on selected pixels; the factor cancels in the ratio below but
    # is kept so results match the previous implementation bit-for-bit.
    weights = np.zeros(selected.shape)
    weights[selected] = 255.0

    per_pixel = ((warped - (demo_rgb / 255.0)) ** 2.0).sum(axis=2)
    weight_sum = weights.sum()
    if weight_sum == 0.0:
        return 2.0
    return (per_pixel * weights).sum() / weight_sum


def select_demo(servo_modules, live_rgb, goal_frame=None, use_goal=False):
    """
    Select the demonstration with the minimum masked reprojection error.

    For each candidate, optical flow from the query image to one demo frame is
    used to warp the query, and the squared colour error is averaged over that
    demo frame's foreground mask.

    Args:
        servo_modules: iterable of ``(task, servo_module)`` pairs; each module
            must expose ``demo.steps``, ``demo.fg_masks`` and ``flow_module``.
        live_rgb: live image compared against each demo's first frame
            (``steps[0]`` / ``fg_masks[0]``). It is divided by 255, so
            presumably 0-255 valued — confirm. Unused when ``use_goal`` is True.
        goal_frame: image compared against each demo's ``steps[-5]`` frame
            (mask ``fg_masks[-5]``); required when ``use_goal`` is True.
        use_goal: score against ``goal_frame`` instead of ``live_rgb``.

    Returns:
        ``(best_servo_module, best_task, errors_list)``: the lowest-error module
        and its task (both None for an empty input), plus every candidate's
        error in input order. A candidate with an empty mask scores 2.0.

    Raises:
        ValueError: if ``use_goal`` is True and ``goal_frame`` is None.
    """
    if use_goal and goal_frame is None:
        raise ValueError("goal_frame is required when use_goal=True")

    # Multipliers for front (live) and rear (goal) errors; per call only one
    # of the two terms is non-zero.
    alpha, beta = 1.0, 1.0
    # NOTE(review): the goal is compared to steps[-5] rather than steps[-1];
    # presumably to avoid the demo's final frames — confirm.
    rear_idx = -5

    best_task = None
    best_servo_module = None
    best_error = np.inf
    errors_list = []

    for t, s in servo_modules:
        error_front, error_rear = 0.0, 0.0
        if use_goal:
            error_rear = _masked_reprojection_error(
                s.flow_module, goal_frame,
                s.demo.steps[rear_idx].cam.get_image()[0],
                s.demo.fg_masks[rear_idx])
        else:
            error_front = _masked_reprojection_error(
                s.flow_module, live_rgb,
                s.demo.steps[0].cam.get_image()[0],
                s.demo.fg_masks[0])

        error = error_front * alpha + error_rear * beta
        errors_list.append(error)

        # Strict '<' keeps the first candidate on ties.
        if error < best_error:
            best_error = error
            best_task = t
            best_servo_module = s

    return best_servo_module, best_task, errors_list

In [None]:
# Build one (path, ServoingModule) pair per seg0 recording, in sorted order
recordings_seg0 = sorted(recordings_seg0)
servo_kwargs = dict(control_config=control_config, plot=False, save_dir=None)
servo_modules_seg0 = [(path, ServoingModule(path, **servo_kwargs))
                      for path in recordings_seg0]

In [None]:
# Front error of every seg0 demo: its first frame against the live image
front_result = select_demo(servo_modules_seg0, live_rgb, goal_frame, use_goal=False)
errors_list_front = front_result[-1]
del servo_modules_seg0, front_result

In [None]:
if compute:
    recordings_seg1 = sorted(recordings_seg1)
    # Rear error of every seg1 demo, scored against the goal frame
    servo_modules_seg1 = []
    for demo_path in recordings_seg1:
        servo_modules_seg1.append(
            (demo_path, ServoingModule(demo_path, control_config=control_config,
                                       plot=False, save_dir=None)))
    errors_list_rear = select_demo(servo_modules_seg1, live_rgb, goal_frame,
                                   use_goal=True)[2]
    del servo_modules_seg1

In [None]:
# Rank seg0 demos by front error and keep the five best (ties keep list order)
res = sorted(range(len(errors_list_front)), key=errors_list_front.__getitem__)[:5]
best_recordings_front = [recordings_seg0[k] for k in res]

print(best_recordings_front)

In [None]:
if compute:
    # Same ranking for the seg1 demos, by rear error
    sorted_seg1 = sorted(recordings_seg1)
    res = sorted(range(len(errors_list_rear)), key=errors_list_rear.__getitem__)[:5]
    best_recordings_rear = [sorted_seg1[k] for k in res]

    print(best_recordings_rear)

In [None]:
# Re-display the live 'rgb_gripper' image (interactive sessions only)
if interactive:
    plt.imshow(live_rgb)

In [None]:
if interactive:
    # One row per top-ranked seg0 demo: first frame (left), last frame (right)
    fig, ax = plt.subplots(5, 2, figsize=(20, 20))
    for row_axes, demo_path in zip(ax, best_recordings_front):
        demo_steps = ServoingModule(demo_path, control_config=control_config,
                                    plot=False, save_dir=None).demo.steps
        for axis, step in zip(row_axes, (demo_steps[0], demo_steps[-1])):
            axis.imshow(step.cam.get_image()[0])
            axis.set_axis_off()

In [None]:
# Show first/last frames of the top-ranked seg1 demos.
# Guarded on `compute`: best_recordings_rear is only defined when the rear
# errors were computed in this run (previously this cell raised NameError).
if interactive and compute:
    fig, ax = plt.subplots(5, 2, figsize=(20, 20))
    for idx, t in enumerate(best_recordings_rear):
        servo_module = ServoingModule(t, control_config=control_config, plot=False, save_dir=None)
        first_frame = servo_module.demo.steps[0].cam.get_image()[0]
        last_frame = servo_module.demo.steps[-1].cam.get_image()[0]
        ax[idx, 0].imshow(first_frame)
        ax[idx, 1].imshow(last_frame)
        ax[idx, 0].set_axis_off()
        ax[idx, 1].set_axis_off()

# Create Trajectory from source to destination

In [None]:
# Error matrix: row i scores the last frame of seg0 recording i against the
# first frame of every seg1 recording (select_demo with use_goal=False).
if interactive:
    compute = True
rec_seg0, rec_seg1 = sorted(recordings_seg0), sorted(recordings_seg1)

if compute:
    seg1_sm = [(path, ServoingModule(path, control_config=control_config,
                                     plot=False, save_dir=None))
               for path in rec_seg1]

    error_matrix = []
    for seg0_path in rec_seg0:
        seg0_module = ServoingModule(seg0_path, control_config=control_config,
                                     plot=False, save_dir=None)
        seg0_last = seg0_module.demo.steps[-1].cam.get_image()[0]
        del seg0_module  # only the last frame is needed
        error_matrix.append(select_demo(seg1_sm, seg0_last, use_goal=False)[2])
    del seg1_sm

In [None]:
# Persist the errors: front errors per seed, matrix and rear errors in store_path
front_file = os.path.join(seed_dir, 'error_front.npz')
error_front = np.array(errors_list_front)
np.savez(front_file, error_front)

if compute:
    matrix_file = os.path.join(store_path, 'error_matrix.npz')
    error_matrix_arr = np.array(error_matrix)
    np.savez(matrix_file, error_matrix_arr)

    rear_file = os.path.join(store_path, 'error_rear.npz')
    error_rear = np.array(errors_list_rear)
    np.savez(rear_file, error_rear)

In [None]:
if interactive:
    # Pick the (seg0, seg1) pair with the lowest combined error under two
    # scoring rules: fn1 adds the three errors, fn2 multiplies them.
    n_seg0, n_seg1 = error_matrix_arr.shape  # unpacking also enforces a 2-D matrix
    best_error_fn1 = dict(error=np.inf, idx1=-1, idx2=-1)
    best_error_fn2 = dict(error=np.inf, idx1=-1, idx2=-1)

    for (i, j), bridge_error in np.ndenumerate(error_matrix_arr):
        scored = ((best_error_fn1, bridge_error + error_front[i] + error_rear[j]),
                  (best_error_fn2, bridge_error * error_front[i] * error_rear[j]))
        for best, total in scored:
            # strict '<' keeps the first pair in row-major order on ties
            if total < best['error']:
                best.update(error=total, idx1=i, idx2=j)

    print(best_error_fn1)
    print(best_error_fn2)

In [None]:
if interactive:
    fig, ax = plt.subplots(1, 6, figsize=(20, 20))
    idx1, idx2 = best_error_fn1['idx1'], best_error_fn1['idx2']

    # Panels: live | seg0 first | seg0 last | seg1 first | seg1 last | goal
    frames = [live_rgb]
    for demo_path in (rec_seg0[idx1], rec_seg1[idx2]):
        servo_module = ServoingModule(demo_path, control_config=control_config, plot=False, save_dir=None)
        steps = servo_module.demo.steps
        frames += [steps[0].cam.get_image()[0], steps[-1].cam.get_image()[0]]
        del servo_module, steps
    frames.append(goal_frame)
    seg0_firstframe, seg0_lastframe, seg1_firstframe, seg1_lastframe = frames[1:5]

    for axis in ax:
        axis.set_axis_off()
    for axis, frame in zip(ax, frames):
        axis.imshow(frame)

    print(rec_seg0[idx1], rec_seg1[idx2])

In [None]:
if interactive:
    fig, ax = plt.subplots(1, 6, figsize=(20, 20))
    idx1, idx2 = (best_error_fn2[key] for key in ('idx1', 'idx2'))

    # Collect first/last frames of the chosen seg0 and seg1 demos
    pair_frames = []
    for demo_path in (rec_seg0[idx1], rec_seg1[idx2]):
        servo_module = ServoingModule(demo_path, control_config=control_config, plot=False, save_dir=None)
        pair_frames.append(servo_module.demo.steps[0].cam.get_image()[0])
        pair_frames.append(servo_module.demo.steps[-1].cam.get_image()[0])
        del servo_module
    seg0_firstframe, seg0_lastframe, seg1_firstframe, seg1_lastframe = pair_frames

    panels = (live_rgb, seg0_firstframe, seg0_lastframe,
              seg1_firstframe, seg1_lastframe, goal_frame)
    for axis in ax:
        axis.set_axis_off()
    for axis, panel in zip(ax, panels):
        axis.imshow(panel)

    print(rec_seg0[idx1], rec_seg1[idx2])

In [None]:
# idx1 = np.argmin(error_front)
# idx2 = np.argmin(error_rear)
# print(rec_seg0[idx1], rec_seg1[idx2])