In [1]:
%load_ext autoreload
%autoreload 2

import os
import copy
import json
import logging
from glob import glob
from pathlib import Path

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout

from flow_control.demo.playback_env_servo import PlaybackEnvServo

# Directory that holds one sub-directory per recorded demonstration.
root_dir = "/home/argusm/Desktop/Demonstrations/2023-01-24/"

# Join every entry (in sorted name order) onto the root and keep only the directories.
recordings = list(filter(
    os.path.isdir,
    (os.path.join(root_dir, entry) for entry in sorted(os.listdir(root_dir))),
))

print("Number of recordings:", len(recordings))
print("first", recordings[0])
print("last ", recordings[-1])

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Number of recordings: 30
first /home/argusm/Desktop/Demonstrations/2023-01-24/14-18-38
last  /home/argusm/Desktop/Demonstrations/2023-01-24/18-08-23


In [2]:
# Build one playback environment per recording directory (constructed with load='keep').
playbacks = [
    PlaybackEnvServo(recording_dir, load='keep')
    for recording_dir in list(recordings)
]

In [3]:
# Load Servoing Module
from flow_control.servoing.module import ServoingModule
# Settings handed to the servoing module: control mode and threshold value.
control_config = {"mode": "pointcloud-abs-rotz", "threshold": 0.40}

# The module is initialised from the first recording and asked to use RAFT as flow module.
servo_module = ServoingModule(
    recordings[0],
    control_config=control_config,
    start_paused=False,
    flow_module='RAFT',
)

def similarity_from_reprojection(live_rgb, demo_rgb, demo_mask, return_images=False):
    """
    Score how well a live image matches a demo image via flow reprojection.

    The flow between the demo and the live image is computed with the flow
    module of the global ``servo_module``; the live image is warped with that
    flow and compared against the demo image inside the demo mask.

    Arguments:
        live_rgb: live image, divided by 255 before warping
        demo_rgb: demo image, divided by 255 before comparison
        demo_mask: mask of the demo image, used both as a per-pixel weight and
            as an index into the flow (presumably a boolean array with one
            entry per pixel -- TODO confirm)
        return_images: when True, the flow and the warped image are returned too
    Returns:
        tuple:
            error: masked sum of per-pixel colour differences divided by the mask sum
            mean_flow: mean flow magnitude over the masked pixels
            flow, warped: only when return_images is True
    """
    flow_module = servo_module.flow_module

    flow = flow_module.step(demo_rgb, live_rgb)
    warped = flow_module.warp_image(live_rgb / 255.0, flow)

    # Colour distance per pixel (norm over the last axis), restricted to the mask.
    per_pixel_error = np.linalg.norm(warped - demo_rgb / 255.0, axis=2)
    error = (per_pixel_error * demo_mask).sum() / demo_mask.sum()

    # Length of the flow vectors selected by the mask, averaged.
    mean_flow = np.linalg.norm(flow[demo_mask], axis=1).mean()

    result = (error, mean_flow)
    if return_images:
        result += (flow, warped)
    return result

pybullet build time: May 20 2022 19:44:17
INFO - 2023-06-13 16:10:39,929 - module - Loading ServoingModule...
INFO - 2023-06-13 16:10:39,929 - module - Using RAFT
INFO - 2023-06-13 16:10:40,546 - module - Loading recording (make take a bit): /home/argusm/Desktop/Demonstrations/2023-01-24/14-18-38
INFO - 2023-06-13 16:10:40,878 - module - Loading time was 0.331 s
INFO - 2023-06-13 16:10:42,230 - module_raft - Loading RAFT model, may take a bit...
INFO - 2023-06-13 16:10:42,321 - module - Threshold: {18: 0.6789367811625633, 185: 0.5880885725701316, 235: 0.5880789829366987, 269: 0.5834225891593978, 327: 0.5014027643760912, 386: 0.4, 400: 1.0, 471: 0.8962306117608927, 539: 0.9175077815972037, 681: 0.46160886946574853, 796: 0.42284438410685893, 849: 0.4, 881: 0.4000662074167735}


0.4 1.0


In [4]:
# Quality flag per demonstration; every loaded playback is currently marked as good.
demo_good = [True] * len(playbacks)

# bad_pair_arr[i, j] is True for pairs that must not be compared:
# the diagonal (an episode with itself) and every row/column of a demo that is not good.
bad_pair_arr = np.eye(len(playbacks), dtype=bool)
for idx in np.flatnonzero(np.logical_not(demo_good)):
    bad_pair_arr[idx, :] = True
    bad_pair_arr[:, idx] = True

# Index tuples of all remaining (viable) pairs.
good_pairs = list(zip(*np.nonzero(np.logical_not(bad_pair_arr))))

print(f"Good demos:   {np.mean(demo_good)*100:.1f} %\t", np.count_nonzero(demo_good), "/", len(demo_good))
print(f"Viable pairs: {(1-bad_pair_arr.mean())*100:.1f} %\t", np.count_nonzero(bad_pair_arr == 0), "/", bad_pair_arr.size)



Good demos:   100.0 %	 30 / 30
Viable pairs: 96.7 %	 870 / 900


In [7]:
# Score matrices indexed [demo, live]. Pairs that are never evaluated keep their
# initial value: 1 for the reprojection error, 0 for the mean flow.
sim_scores = np.ones_like(bad_pair_arr, dtype=float)  # lower is better
mean_flows = np.zeros_like(bad_pair_arr, dtype=float)

for live_i, demo_i in tqdm(good_pairs):
    # Frame 18 of both episodes is compared; the first element of get_image() is
    # used as the image (presumably the RGB image -- TODO confirm).
    live_rgb = playbacks[live_i][18].cam.get_image()[0]
    demo_rgb = playbacks[demo_i][18].cam.get_image()[0]
    demo_mask = playbacks[demo_i].get_fg_mask()

    error, mean_flow = similarity_from_reprojection(live_rgb, demo_rgb, demo_mask)
    sim_scores[demo_i, live_i], mean_flows[demo_i, live_i] = error, mean_flow

100%|█████████████████████████████████████████████████████████████████████████████████████████| 870/870 [03:33<00:00,  4.08it/s]


In [8]:
from sklearn.preprocessing import minmax_scale

# Rows of both score matrices that belong to demonstrations flagged as good.
sim_l = sim_scores[demo_good]
mean_flows_l = mean_flows[demo_good]

# Scatter the two min-max scaled measures against each other.
fig, ax = plt.subplots()
ax.scatter(minmax_scale(sim_l), minmax_scale(mean_flows_l))
ax.set_xlabel("reprojection")
ax.set_ylabel("mean flow")
plt.show()

# Combined score: reprojection term plus the mean-flow term weighted by w.
# Rows of demos that are not flagged good keep the initial value 1.
# NOTE(review): the initial 1 / 0 values of never-evaluated pairs (e.g. the diagonal)
# are part of what gets scaled here -- confirm this is intended.
w = .5
sim_scores_norm = np.ones_like(sim_scores)
sim_scores_norm[demo_good] = np.stack(
    [minmax_scale(sim_l), w * minmax_scale(mean_flows_l)]
).mean(axis=0) / (1 + w)

In [9]:
# Boolean mask with one entry per episode. Reducing the 1-D flag list with
# np.any(..., axis=0) would give a single scalar, and indexing with a scalar True
# only adds an axis instead of filtering out episodes that are not good.
good_episode = np.asarray(demo_good, dtype=bool)

print("live episode", np.arange(len(recordings))[good_episode])
# For every live episode (column) pick the demo (row) with the lowest combined score.
print("demo episode", np.argmin(sim_scores_norm, axis=0)[good_episode])

live episode [[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
  24 25 26 27 28 29]]
demo episode [[ 9 13 19  1  2  9 22  4  3  5  1 19  5  1  3 11 12 22  5  7  3  3  6  4
   3  2  3  3  3  3]]


In [30]:
from scipy.spatial.transform import Rotation as R

# Zero-padded width of the frame index in recording file names (frame_{idx:06d}.npz).
N_DIGITS = 6

def pos_orn_to_matrix(pos, orn):
    """
    Assemble a homogeneous transform from a translation and a quaternion.

    Arguments:
        pos: translation (x, y, z)
        orn: quaternion in scalar-last order (q_x, q_y, q_z, w)
    Returns:
        mat: 4x4 homogeneous transformation
    """
    assert len(pos) == 3
    assert len(orn) == 4

    rotation = R.from_quat(orn).as_matrix()

    # Start from identity so the bottom row stays (0, 0, 0, 1).
    mat = np.identity(4)
    mat[:3, :3] = rotation
    mat[:3, 3] = pos
    return mat


def matrix_to_pos_orn(mat):
    """
    Split a homogeneous transform into its translation and its rotation.

    Arguments:
        mat: 4x4 homogeneous transformation
    Returns:
        tuple:
            position: (x, y, z), taken from the last column
            orientation: quaternion (q_x, q_y, q_z, w) of the upper-left 3x3 block
    """
    rotation_part = mat[:3, :3]
    translation_part = mat[:3, 3]
    return translation_part, R.from_matrix(rotation_part).as_quat()


def _load_tcp_pose(frame_path):
    """Return (tcp_pos, tcp_orn) as arrays from the 'robot_state' entry of one frame file."""
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted recordings.
    # The context manager closes the underlying archive once the values are read.
    with np.load(frame_path, allow_pickle=True) as frame_data:
        robot_state = np.atleast_1d(frame_data['robot_state'])[0]
        return np.array(robot_state['tcp_pos']), np.array(robot_state['tcp_orn'])


def get_actions(meas_path, frame_idx):
    """
    Relative end-effector motion from a servo keyframe to the following one.

    The keyframes are the keys of ``servo_keep.json`` in their stored order.
    The poses of ``frame_idx`` and of the next keyframe are read from the
    corresponding ``frame_XXXXXX.npz`` files and the next pose is expressed
    relative to the current one (``inv(current) @ next``).

    Arguments:
        meas_path: directory of one recording
        frame_idx: frame index of the keyframe the action starts from
    Returns:
        tuple:
            rel_pos: relative translation (x, y, z), float32
            rel_orn: relative rotation as quaternion (q_x, q_y, q_z, w), float32
    Raises:
        ValueError: if frame_idx is not a keyframe of the recording, or if it is
            the last keyframe and therefore has no following one.
    """
    servo_file = f"{meas_path}/servo_keep.json"

    with open(servo_file, 'r') as f_obj:
        servo_keep = json.load(f_obj)
    servo_frames = [int(key) for key in servo_keep]

    # Fail with a clear message instead of running into an unbound next_frame_idx.
    if frame_idx not in servo_frames:
        raise ValueError(f"frame {frame_idx} is not a keyframe listed in {servo_file}")
    position = servo_frames.index(frame_idx)
    if position + 1 >= len(servo_frames):
        raise ValueError(f"frame {frame_idx} is the last keyframe in {servo_file}; "
                         "there is no following keyframe")
    next_frame_idx = servo_frames[position + 1]

    frame_path = f"{meas_path}/frame_{frame_idx:0{N_DIGITS}d}.npz"
    next_frame_path = f"{meas_path}/frame_{next_frame_idx:0{N_DIGITS}d}.npz"

    current_pos, current_orn = _load_tcp_pose(frame_path)
    next_pos, next_orn = _load_tcp_pose(next_frame_path)

    start_m = pos_orn_to_matrix(current_pos, current_orn)
    finish_m = pos_orn_to_matrix(next_pos, next_orn)

    # Pose of the next keyframe expressed in the frame of the current one.
    rel_trf = np.linalg.inv(start_m) @ finish_m
    rel_pos, rel_orn = matrix_to_pos_orn(rel_trf)

    # Relative actions
    return np.array(rel_pos, dtype='float32'), np.array(rel_orn, dtype='float32')

In [33]:
# Per-episode distance between the action of a query episode and the action of its
# best-matching demo. Despite the `_mse` suffix these lists hold L2 norms; the names
# are kept because they are read again when the means are computed.
pos_mse = []
orn_mse = []

for idx in range(len(recordings)):
    query_demo = recordings[idx]
    # Column idx holds the combined scores of all demos for this live episode.
    best_idx = np.argmin(sim_scores_norm[:, idx])
    best_demo = recordings[best_idx]

    # Relative action that starts at frame 18 in both episodes.
    q_pos, q_orn = get_actions(query_demo, 18)
    d_pos, d_orn = get_actions(best_demo, 18)

    pmse = np.linalg.norm(q_pos - d_pos)
    # q and -q encode the same rotation, so the smaller of the two distances is the
    # meaningful one; a plain difference would report ~2 for identical rotations
    # whose quaternions happen to have opposite sign.
    omse = min(np.linalg.norm(q_orn - d_orn), np.linalg.norm(q_orn + d_orn))

    pos_mse.append(pmse)
    orn_mse.append(omse)

In [34]:
# Average of the per-episode position and orientation distances (pos_mse, orn_mse).
np.mean(pos_mse), np.mean(orn_mse)

(0.036350686, 0.0075728735)