In [None]:
import os
import json
from pathlib import Path
from scipy.spatial.transform import Rotation as R

# All inputs live under one demo root, so the notebook can be re-pointed in one place.
DEMO_ROOT = Path("/misc/student/nayaka/paper/flowcontrol/flow_control/demo")
runs_dir = DEMO_ROOT / "tmp_new" / "cnn_run"
demos_dir = DEMO_ROOT / "tmp_new" / "cnn_new"
demo_dir_jpg = DEMO_ROOT / "train_sim" / "demo_imgs"
live_dir_jpg = DEMO_ROOT / "train_sim" / "live_imgs"
# The rewards file sits inside the runs directory next to the run folders.
rewards_fn = runs_dir / "rewards.json"

with open(demos_dir / "demo_parts_manual3.json") as f_obj:
    demo_parts = json.load(f_obj)

with open(rewards_fn) as f_obj:
    rewards = json.load(f_obj)

# the number in runs is the index of this list
demos_list = sorted(os.listdir(demos_dir))
print(demos_list)
# demos_dir also holds demo_parts_manual3.json; only sub-directories are demos.
demo_dirs = [name for name in demos_list if (demos_dir / name).is_dir()]
print(len(demo_dirs))

demo_idx_to_dir = dict(enumerate(demo_dirs))
print(demo_idx_to_dir)
# runs_dir also holds rewards.json; only sub-directories are runs.
runs = sorted(name for name in os.listdir(runs_dir) if (runs_dir / name).is_dir())

In [None]:
import numpy as np
def get_image(demo_dir, frame_index, depth=False):
    """Load the ``rgb_gripper`` image stored in one recorded frame.

    Args:
        demo_dir: directory (str or Path) containing ``frame_XXXXXX.npz`` files.
        frame_index: frame number; zero-padded to six digits in the filename.
        depth: currently ignored; kept so existing callers keep working.

    Returns:
        The ``rgb_gripper`` array from the frame file.
    """
    frame_path = os.path.join(demo_dir, f"frame_{frame_index:06d}.npz")
    # The context manager closes the .npz file handle; the array is read into
    # memory by the indexing, so it stays valid after the file is closed.
    with np.load(frame_path) as arr:
        return arr["rgb_gripper"]

def get_info(demo_dir, frame_index):
    """Return the ``(robot_state, info)`` objects pickled into one frame file.

    Both entries are stored as 0-d object arrays, so ``.item()`` unwraps them.
    ``allow_pickle=True`` is required for that and must only be used on trusted
    recordings.
    """
    frame_path = os.path.join(demo_dir, f"frame_{frame_index:06d}.npz")
    # Close the .npz handle once both objects have been unpickled.
    with np.load(frame_path, allow_pickle=True) as arr:
        return arr["robot_state"].item(), arr["info"].item()

def pos_orn_to_matrix(pos, orn):
    """Build a 4x4 homogeneous transform from a position and an orientation.

    Args:
        pos: translation, 3 values.
        orn: either a quaternion (4 values, scipy's default scalar-last
            ``(x, y, z, w)`` order) or Euler angles (3 values, extrinsic
            ``'xyz'`` convention, radians).

    Returns:
        A 4x4 ``np.ndarray`` with the rotation in ``[:3, :3]`` and the
        translation in ``[:3, 3]``.

    Raises:
        ValueError: if ``orn`` has neither 3 nor 4 elements.
    """
    mat = np.eye(4)
    if len(orn) == 4:
        mat[:3, :3] = R.from_quat(orn).as_matrix()
    elif len(orn) == 3:
        mat[:3, :3] = R.from_euler('xyz', orn).as_matrix()
    else:
        # Reject malformed input instead of silently returning an identity rotation.
        raise ValueError(
            f"orn must have 3 (euler) or 4 (quaternion) elements, got {len(orn)}")
    mat[:3, 3] = pos
    return mat

def get_tcp_pose(demo_dir, frame_index):
    """Return the recorded TCP pose of one frame as a 4x4 homogeneous matrix.

    Reads ``tcp_pos``/``tcp_orn`` from the pickled ``robot_state`` of the frame
    file and converts them with ``pos_orn_to_matrix``.
    NOTE(review): presumably expressed in the robot base frame — confirm.
    """
    frame_path = os.path.join(demo_dir, f"frame_{frame_index:06d}.npz")
    # Close the .npz handle after unpickling the robot state.
    with np.load(frame_path, allow_pickle=True) as arr:
        state = arr["robot_state"].item()
    return pos_orn_to_matrix(state["tcp_pos"], state["tcp_orn"])

def get_extr_cal(demo_dir):
    """Return the ``gripper_extrinsic_calibration`` entry of ``camera_info.npz``.

    Args:
        demo_dir: recording directory (str or Path) containing ``camera_info.npz``.
    """
    # Close the .npz handle once the calibration array has been read.
    with np.load(Path(demo_dir) / "camera_info.npz", allow_pickle=True) as camera_info:
        return camera_info["gripper_extrinsic_calibration"]




# Define Distance Functions

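$d(s_t, s_d) \in \mathbb{R}$: a distance between the live state $s_t$ (first frame of a run part) and the demonstration state $s_d$ (start frame of the corresponding demo part). Smaller values mean the two states are more alike.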

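1. `GT_pose` (selected in code with `distance = "GP_pose"`): ground-truth pose distance, i.e. the translation norm and rotation angle between the object poses of the live and demo states, each expressed relative to the gripper.
2. `VS`: visual similarity, a hand-crafted measure given by the optical-flow reprojection error between the live and demo images.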

In [None]:
def get_cam2obj(demo_dir, frame_num, use_extrinsic=False, obj_id="0"):
    """Return the 4x4 transform of an object relative to the gripper (camera) frame.

    Computes ``inv(tcp_pose @ tcp_to_cam) @ obj_pose``, where ``tcp_pose``
    comes from ``get_tcp_pose`` and ``obj_pose`` is built from
    ``info[obj_id]["position"]`` / ``["orientation"]`` of the frame.

    Args:
        demo_dir: recording directory containing the frame files.
        frame_num: index of the frame to read.
        use_extrinsic: if True, use the gripper camera calibration from
            ``camera_info.npz`` as ``tcp_to_cam``. By default ``tcp_to_cam`` is
            the identity, i.e. the calibration is not applied and the result is
            the object pose relative to the TCP.
        obj_id: key of the object inside the frame's ``info`` dict.
    """
    _, info = get_info(demo_dir, frame_num)
    obj_info = info[obj_id]
    obj_pose = pos_orn_to_matrix(obj_info["position"], obj_info["orientation"])
    # Only touch camera_info.npz when the calibration is actually used.
    tcp_to_cam = get_extr_cal(demo_dir) if use_extrinsic else np.eye(4)
    tcp_pose = get_tcp_pose(demo_dir, frame_num)
    return np.linalg.inv(tcp_pose @ tcp_to_cam) @ obj_pose


# Reference transforms: for every demo and every part of it, the object pose
# relative to the gripper at the part's start frame, keyed by (demo index, part index).
cam2obj_demo = {
    (demo_index, part_index): get_cam2obj(demos_dir / demo_name, part_info["start"])
    for demo_index, demo_name in demo_idx_to_dir.items()
    for part_index, part_info in enumerate(demo_parts[str(demo_index)])
}

        
def get_scores_GP_pose(run_obs_fn, demo_index, part_num):
    """Ground-truth pose distance d(s_t, s_d) between a live run and a demo part.

    s_t is frame 0 of the run observation ``run_obs_fn``; s_d is the start frame
    of part ``part_num`` of demo ``demo_index`` (precomputed in ``cam2obj_demo``).

    Returns:
        ``(pos_diff, orn_diff)``: the translation norm and the rotation angle
        (radians, ``Rotation.magnitude``) of ``demo @ inv(live)``.
    """
    live_transform = get_cam2obj(run_obs_fn, 0)
    demo_transform = cam2obj_demo[(demo_index, part_num)]

    relative = demo_transform @ np.linalg.inv(live_transform)
    rotation_part = relative[:3, :3]
    translation_part = relative[:3, 3]
    return np.linalg.norm(translation_part), R.from_matrix(rotation_part).magnitude()

In [None]:
from flow_control.servoing.module import ServoingModule
# Build one ServoingModule; this notebook only uses its optical-flow model
# (`servo_module.flow_module`) for the reprojection-based similarity.
# It is constructed from the demo at index 0 (the name `random_demo_dir` notwithstanding).
# NOTE(review): presumably the chosen demo does not influence the flow model — confirm.
control_config = dict(mode="pointcloud-abs-rotz", threshold=0.40)
random_demo_dir = demos_dir / demo_idx_to_dir[0]
servo_module = ServoingModule(random_demo_dir, control_config=control_config,
                              start_paused=False)

def get_mask(demo_dir, frame_index):
    """Return the boolean foreground mask of one demo frame.

    Reads ``servo_mask.npz`` in ``demo_dir`` and compares the frame's label image
    ``mask[frame_index]`` against its foreground label ``fg[frame_index]``.

    Args:
        demo_dir: demo directory (str or Path).
        frame_index: index into the stored mask/fg stacks.
    """
    # Path() accepts str too, matching the other frame helpers; the context
    # manager closes the .npz handle after reading.
    with np.load(Path(demo_dir) / "servo_mask.npz") as arr:
        return arr["mask"][frame_index] == arr["fg"][frame_index]


def similarity_from_reprojection(live_rgb, demo_rgb, demo_mask, return_images=False,
                                 flow_module=None):
    """Score how well the live image reprojects onto the demo image via optical flow.

    The flow from demo to live is estimated, the live image is warped with it,
    and the per-pixel RGB error to the demo image is averaged over the demo mask.

    Args:
        live_rgb, demo_rgb: images with a channel axis last (presumably uint8
            RGB in [0, 255]; both are scaled by 1/255 here — TODO confirm).
        demo_mask: per-pixel foreground mask of the demo image; treated as boolean.
        return_images: if True, also return the flow field and the warped image.
        flow_module: object with ``step(demo, live)`` and ``warp_image(img, flow)``;
            defaults to the global ``servo_module.flow_module``.

    Returns:
        ``(error, mean_flow)`` or, with ``return_images``,
        ``(error, mean_flow, flow, warped)``: mean masked reprojection error and
        mean flow magnitude over the masked pixels.

    Raises:
        ValueError: if ``demo_mask`` selects no pixels (the mean would be undefined).
    """
    if flow_module is None:
        flow_module = servo_module.flow_module
    # Boolean indexing below requires a bool mask; 0/1 integer masks would
    # otherwise be interpreted as fancy indices.
    mask = np.asarray(demo_mask, dtype=bool)
    n_masked = mask.sum()
    if n_masked == 0:
        raise ValueError("demo_mask is empty; reprojection error is undefined")

    flow = flow_module.step(demo_rgb, live_rgb)
    warped = flow_module.warp_image(live_rgb / 255.0, flow)
    pixel_error = np.linalg.norm(warped - (demo_rgb / 255.0), axis=2)
    error = (pixel_error * mask).sum() / n_masked
    mean_flow = np.linalg.norm(flow[mask], axis=1).mean()
    if return_images:
        return error, mean_flow, flow, warped
    return error, mean_flow


def get_scores_VS(run_obs_fn, demo_index, part_num):
    """Visual-similarity distance d(s_t, s_d) between a live run and a demo part.

    Compares frame 0 of the run observation ``run_obs_fn`` with the start frame
    of part ``part_num`` of demo ``demo_index``, restricted to the demo mask.

    Returns:
        Whatever ``similarity_from_reprojection`` returns, i.e. ``(error, mean_flow)``.
    """
    live_image = get_image(run_obs_fn, 0)

    start_frame = demo_parts[str(demo_index)][part_num]["start"]
    demo_path = demos_dir / demo_idx_to_dir[demo_index]
    return similarity_from_reprojection(
        live_image,
        get_image(demo_path, start_frame),
        get_mask(demo_path, start_frame),
    )
    

In [None]:
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm

# Distance to evaluate: "GP_pose" (ground-truth pose) or "VS" (visual similarity).
distance = "VS"
# Only the first MAX_RUNS runs (sorted by name) are scored.
MAX_RUNS = 200
# Parts to score in every run; set to None to score every part directory in the run.
PARTS = ["p2"]

score_fns = {"GP_pose": get_scores_GP_pose, "VS": get_scores_VS}
if distance not in score_fns:
    raise ValueError(f"unknown distance {distance!r}; expected one of {sorted(score_fns)}")
score_fn = score_fns[distance]

scores_list = []
rew_list = []
for run in tqdm(runs[:MAX_RUNS]):
    run_dir = runs_dir / run
    if not run_dir.is_dir():
        # runs_dir also contains rewards.json, which is not a run.
        continue
    parts = PARTS if PARTS is not None else sorted(os.listdir(run_dir))
    for part in parts:
        part_dir = run_dir / part
        entries = os.listdir(part_dir)
        # Each part directory holds exactly one "<live_seed>_<demo_index>" observation.
        assert len(entries) == 1, f"expected one entry in {part_dir}, found {entries}"
        part_str = entries[0]
        live_seed, demo_index = (int(x) for x in part_str.split("_"))
        part_num = int(part.replace("p", ""))

        scores_list.append(score_fn(part_dir / part_str, demo_index, part_num))
        rew_list.append(rewards[run])

In [None]:
import sklearn.metrics
    
# Turn distances into success scores (higher = more likely to succeed) by negating
# them, then measure how well they rank successful runs via ROC AUC.
if distance == "GP_pose":
    pos_diffs = [pos_diff for pos_diff, _ in scores_list]
    orn_diffs = [orn_diff for _, orn_diff in scores_list]
    # NOTE(review): adds translation and rotation angle (radians) with equal
    # weight; the translation units are presumably metres — confirm the weighting.
    cmb_diffs = np.array(pos_diffs) + np.array(orn_diffs)
    pred = -1 * cmb_diffs
elif distance == "VS":
    rp_errs = [rp_err for rp_err, _ in scores_list]
    mean_flows = [mean_flow for _, mean_flow in scores_list]
    pred = -1 * np.array(rp_errs)
else:
    raise ValueError(f"unknown distance {distance!r}; expected 'GP_pose' or 'VS'")

y = rew_list
assert len(y) == len(pred) > 0, "no scores to evaluate"
fpr, tpr, thresholds = sklearn.metrics.roc_curve(y, pred, pos_label=1)
print(sklearn.metrics.auc(fpr, tpr))
    