# Comparing different methods of Demonstration Selection

In this notebook, we compare different methods that can be used for demonstration selection.
We compare Hierarchical Localization (hloc), visual similarity (optical-flow reprojection error), and R3M embeddings (masked, unmasked, and pixel features).

In [None]:
%load_ext autoreload
%autoreload 2
import json
from pathlib import Path
from tqdm.notebook import tqdm  # notebook-friendly progress bars

from hloc import extract_features, match_features, reconstruction, visualization, pairs_from_exhaustive
from hloc.visualization import plot_images, read_image, plot_keypoints
from hloc.utils import viz_3d

from scipy import stats

In [None]:
from flow_control.demo.playback_env_servo import PlaybackEnvServo
from flow_control.localize.hloc_utils import export_images_by_parts

# Machine-specific location of the demonstration recordings; adjust per machine.
# root_dir = Path("/home/argusm/CLUSTER/robot_recordings/flow/recombination/2023-01-24")
root_dir = Path("/home/argusm/Desktop/Demonstrations/2023-01-24")
# root_dir = Path("/home/nayakab/Desktop/Courses/WS2022/Project/cluster/")
parts_fn = root_dir / 'parts.json'
# hloc artifacts live in a sibling directory named "<root_dir name>_hloc".
hloc_root = root_dir.parent / (str(root_dir.name) + '_hloc')

mapping_dir = hloc_root / 'mapping'
outputs = hloc_root / 'outputs'
sfm_pairs = outputs / 'pairs-sfm.txt'
loc_pairs = outputs / 'pairs-loc.txt'
sfm_dir = outputs / 'sfm'
features_path = outputs / 'features.h5'
matches_path = outputs / 'matches.h5'
features_seg_path = outputs / 'features_seg.h5'

if parts_fn.is_file():
    with open(parts_fn) as f_obj:
        parts_references = json.load(f_obj)
    print(f"Successfully loaded parts from {parts_fn}.")

In [None]:
from flow_control.localize.hloc_utils import to_hloc_ref
# Read parts.json and turn the first "locate" frame of every recording into an hloc reference.
with open(parts_fn) as parts_file:
    parts_raw = json.load(parts_file)

parts_references = {
    'locate': [to_hloc_ref(rec_name, rec_parts['locate'][0])
               for rec_name, rec_parts in parts_raw.items()],
}

In [None]:
import shutil

# Start from a clean slate: drop previous hloc outputs and exported mapping images.
# shutil.rmtree replaces `!rm -rf $path`, whose unquoted expansion would split paths
# containing spaces into several (unintended) delete targets.
for stale_dir in (outputs, mapping_dir):
    shutil.rmtree(stale_dir, ignore_errors=True)
parts_references = export_images_by_parts(root_dir, parts_fn, mapping_dir)

In [None]:
# Flatten all per-part reference lists and verify every reference was exported
# into the mapping directory before features are extracted.
references_all = [ref for ref_part in parts_references.values() for ref in ref_part]
references_files = sorted(p.relative_to(hloc_root).as_posix() for p in mapping_dir.iterdir())
missing_refs = set(references_all) - set(references_files)
assert not missing_refs, f"references missing from {mapping_dir}: {sorted(missing_refs)[:10]}"
references = parts_references['locate']

In [None]:
# Show how many "locate" references exist and preview the first few of them.
preview_refs = references[:4]
print(len(references), "mapping images")
plot_images([read_image(hloc_root / ref) for ref in preview_refs], dpi=50)

In [None]:
from flow_control.localize.hloc_utils import save_features_seg


# hloc configurations: SuperPoint keypoints (Aachen preset) matched with SuperGlue.
feature_conf = extract_features.confs['superpoint_aachen']
matcher_conf = match_features.confs['superglue']

# Extract features for every exported reference image (all parts), then write a second
# feature file with the project helper save_features_seg (judging by its name and later
# use, presumably keypoints restricted to the segmentation mask -- TODO confirm).
extract_features.main(feature_conf, hloc_root, image_list=references_all, feature_path=features_path)
save_features_seg(root_dir, features_seg_path, features_path, references_all)

# Build exhaustive pairs over the "locate" references only, then match every pair.
pairs_from_exhaustive.main(sfm_pairs, image_list=references)
match_features.main(matcher_conf, sfm_pairs, features=features_path, matches=matches_path)

In [None]:
from hloc.utils.io import get_keypoints

# Overlay the extracted keypoints on the first few reference images.
num_images = 4
preview = references[:num_images]
plot_images([read_image(hloc_root / ref) for ref in preview], dpi=75)
plot_keypoints([get_keypoints(features_path, ref) for ref in preview], colors='lime', ps=4)

## Load Match Database

hloc saves all features and matches in HDF5 files (`features.h5`, `matches.h5`), so reading these files directly is the easiest option.

In [None]:
from hloc.utils.io import get_keypoints
from flow_control.localize.hloc_utils import get_segmentation

# Compare all keypoints of one reference image with the segmentation-filtered ones.
name0 = references[1]
kps0, noise0 = get_keypoints(features_path, name0, return_uncertainty=True)
# Keep the segmented uncertainties separate instead of overwriting noise0.
kps0_seg, noise0_seg = get_keypoints(features_seg_path, name0, return_uncertainty=True)
seg = get_segmentation(root_dir, name0)

# Left: the image with all keypoints; right: the segmentation with the filtered keypoints.
plot_images([read_image(hloc_root / name0), seg], dpi=75)
plot_keypoints([kps0, kps0_seg], colors='lime', ps=4)

In [None]:
from hloc.utils.io import get_matches
from flow_control.localize.hloc_utils import kp_seg_filter

name_q, name_d = references[1], references[3]

matches, scores = get_matches(matches_path, name_q, name_d)
kps_q, noise_q = get_keypoints(features_path, name_q, return_uncertainty=True)
kps_d, noise_d = get_keypoints(features_path, name_d, return_uncertainty=True)

# Column 0 of `matches` indexes query keypoints, column 1 database keypoints.
idx_q, idx_d = matches[:, 0], matches[:, 1]
kps_q_match = kps_q[idx_q]
kps_d_match = kps_d[idx_d]

# Keep only matches whose database keypoint passes the segmentation filter.
in_seg = kp_seg_filter(kps_d_match, name_d, features_seg_path)
print("in_seg", in_seg)
print(kps_d_match[in_seg].shape)

kps_q_seg, kps_d_seg = kps_q_match[in_seg], kps_d_match[in_seg]

## Get ground truth positions and orientations

Use the robot's TCP position and orientation at the grasp frame (the first `insert` index of each recording in `parts.json`) as a proxy for the ground-truth object position and orientation.

In [None]:
import numpy as np
from scipy.spatial.transform import Rotation as R

import ipdb

# Per-recording frame indices from parts.json (the 'insert' entry is used below).
with open(parts_fn, 'r') as parts_file:
    part_info = json.load(parts_file)

def get_tcp_position_orn(demo_dir, frame_index):
    """Return the TCP position and orientation stored in one recorded frame.

    Args:
        demo_dir: recording directory containing ``frame_XXXXXX.npz`` files.
        frame_index: integer frame index (zero-padded to six digits in the filename).

    Returns:
        ``(tcp_pos, tcp_orn)`` taken from the frame's ``robot_state`` dict.
    """
    import os  # local import: `os` is otherwise only imported in a later cell

    frame_fn = os.path.join(demo_dir, f"frame_{frame_index:06d}.npz")
    # allow_pickle is required because robot_state is a pickled object -- only load
    # trusted recordings. The context manager closes the underlying npz file.
    with np.load(frame_fn, allow_pickle=True) as arr:
        state = arr["robot_state"].item()
    return state["tcp_pos"], state['tcp_orn']

def get_tcp_position_at_grasp(name_q):
    """Return the TCP pose at the grasp frame of the recording a reference belongs to.

    Args:
        name_q: hloc reference path like ``"mapping/<file name>"``; the recording name
            is the part of the second path component before its first underscore.

    Returns:
        ``(tcp_pos, tcp_orn)`` at frame ``part_info[rec_name]['insert'][0]``.
    """
    import os  # local import: `os` is otherwise only imported in a later cell

    file_name = name_q.strip().split('/')[1]
    rec_name = file_name.split('_')[0]

    rec_path = os.path.join(root_dir, rec_name)
    # NOTE(review): the grasp frame is taken from the 'insert' entry of parts.json.
    gripper_close_idx = part_info[rec_name]['insert'][0]

    return get_tcp_position_orn(rec_path, gripper_close_idx)

def compute_distance(name_q, name_d):
    """Ground-truth pose difference between the grasp poses of two references.

    Returns:
        ``(pos_err, orn_err)``: Euclidean distance between the TCP positions (in the
        units of ``tcp_pos``) and the angle of the relative rotation in degrees.
    """
    pos_q, orn_q = get_tcp_position_at_grasp(name_q)
    pos_d, orn_d = get_tcp_position_at_grasp(name_d)

    # asarray so list/tuple-valued positions subtract element-wise.
    pos_err = np.linalg.norm(np.asarray(pos_q) - np.asarray(pos_d))

    # Relative rotation R_q * R_d^-1; the norm of its rotation vector is the angle
    # between both orientations. Assumes tcp_orn is a quaternion in the scalar-last
    # order expected by Rotation.from_quat.
    rel_rot = R.from_quat(orn_q) * R.from_quat(orn_d).inv()
    orn_err = np.linalg.norm(rel_rot.as_rotvec(degrees=True))

    return pos_err, orn_err

## HLOC Errors 

In [None]:
from hloc.visualization import plot_matches
from flow_control.localize.hloc_utils import get_playback, align_pointclouds
import matplotlib.pyplot as plt
import os
import json
    
def find_best_demo(name_q, query_cam, references):
    """Select the reference demonstration that best aligns with a query image.

    ``align_pointclouds`` (project helper) is run between the query and every other
    reference; references for which it returns None are skipped. The reference with
    the most inliers wins (ties: the first one in ``references`` order).

    Args:
        name_q: hloc reference name of the query image (never a candidate itself).
        query_cam: camera of the query frame, forwarded to ``align_pointclouds``.
        references: candidate hloc reference names.

    Returns:
        ``(name_d_best, res_best, results)``: the best reference, its alignment result,
        and a dict of all successful alignment results keyed by reference name.

    Raises:
        ValueError: if no candidate could be aligned with the query.
    """
    results = {}
    for name_d in tqdm(references):
        if name_d == name_q:
            continue

        res = align_pointclouds(root_dir, matches_path, features_path, features_seg_path,
                                name_q, name_d, query_cam=query_cam)
        if res is None:
            continue

        # Cast counts to plain Python numbers.
        res['num_inliers'] = int(res['num_inliers'])
        res['num_candidates'] = int(res['num_candidates'])
        # NOTE(review): in_score is derived from num_candidates, not num_inliers -- confirm intent.
        res['in_score'] = float(res['num_candidates'])
        results[name_d] = res

    if not results:
        raise ValueError(f"no reference could be aligned with query {name_q!r}")

    # max() keeps the first maximum, matching the previous stable descending sort.
    name_d_best = max(results, key=lambda name: results[name]["num_inliers"])
    return name_d_best, results[name_d_best], results

hloc_pos_errors = []
hloc_orn_errors = []

# For every reference: let hloc pick the best other demonstration, then record the
# ground-truth pose difference between the query and the selected demonstration.
for name_q in references:
    pb, frame_index = get_playback(root_dir, name_q)
    query_cam = pb[frame_index].cam

    name_d_best, res_best, results = find_best_demo(name_q, query_cam, references)

    pos_err, orn_err = compute_distance(name_q, name_d_best)
    hloc_pos_errors.append(pos_err)
    hloc_orn_errors.append(orn_err)

## Visual Similarity Errors

In [None]:
from sklearn.preprocessing import minmax_scale

# Recordings are the sub-directories of root_dir (root_dir also holds parts.json,
# which the previous `[:-1]` only dropped if it happened to sort last).
recordings = sorted(p for p in root_dir.iterdir() if p.is_dir())
# Later cells index `references` with recording indices, so the counts must agree.
assert len(recordings) == len(references), (
    f"{len(recordings)} recordings but {len(references)} references")

# Load Servoing Module (only its flow module is used for the similarity score).
from flow_control.servoing.module import ServoingModule
control_config = dict(mode="pointcloud-abs-rotz", threshold=0.40)
servo_module = ServoingModule(recordings[0], control_config=control_config,
                              start_paused=False, flow_module='RAFT')

def similarity_from_reprojection(live_rgb, demo_rgb, demo_mask, return_images=False):
    """Score how well the live image reprojects onto the demo image via optical flow.

    The flow returned by ``servo_module.flow_module.step(demo_rgb, live_rgb)`` warps the
    live image; the per-pixel colour error against the demo image is averaged over
    ``demo_mask``.

    Returns:
        ``(error, mean_flow)`` -- lower error means more similar; ``mean_flow`` is the
        mean flow magnitude inside the mask. With ``return_images=True`` the flow field
        and the warped image are appended.
    """
    flow_module = servo_module.flow_module
    flow = flow_module.step(demo_rgb, live_rgb)
    warped = flow_module.warp_image(live_rgb / 255.0, flow)

    pixel_err = np.linalg.norm(warped - demo_rgb / 255.0, axis=2)
    error = (pixel_err * demo_mask).sum() / demo_mask.sum()
    mean_flow = np.linalg.norm(flow[demo_mask], axis=1).mean()

    if not return_images:
        return error, mean_flow
    return error, mean_flow, flow, warped

# Pairwise matrices indexed [demo_i, live_i]; only "good" pairs are computed.
sim_scores = np.ones((len(recordings), len(recordings)))  # lower is better
mean_flows = np.zeros((len(recordings), len(recordings)))

# All demos are currently usable; set entries to False to exclude a recording.
demo_good = [True] * len(recordings)

bad_pair_arr = np.zeros((len(recordings), len(recordings)), dtype=bool)
for idx in np.flatnonzero(~np.asarray(demo_good)):
    bad_pair_arr[:, idx] = True
    bad_pair_arr[idx, :] = True
bad_pair_arr |= np.eye(len(recordings), dtype=bool)  # never match a recording with itself

good_pairs = list(zip(*np.where(~bad_pair_arr)))

for live_i, demo_i in tqdm(good_pairs):
    live_rgb = read_image(hloc_root / references[live_i])
    demo_rgb = read_image(hloc_root / references[demo_i])
    demo_mask = get_segmentation(root_dir, references[demo_i])

    error, mean_flow = similarity_from_reprojection(live_rgb.copy(), demo_rgb.copy(), demo_mask.copy())
    assert error <= 1.0
    sim_scores[demo_i, live_i] = error
    mean_flows[demo_i, live_i] = mean_flow

# Min-max scale both cues per live image (column) over the computed pairs only: the
# placeholder values of excluded pairs (error 1, flow 0) would otherwise pin the
# scaling range of every column. minmax_scale ignores NaNs when fitting.
w = .5
sim_scaled = minmax_scale(np.where(bad_pair_arr, np.nan, sim_scores))
flow_scaled = minmax_scale(np.where(bad_pair_arr, np.nan, mean_flows))
# Weighted average of both cues in [0, 1]; excluded pairs can never be the argmin.
sim_scores_norm = np.where(bad_pair_arr, np.inf, (sim_scaled + w * flow_scaled) / (1 + w))

In [None]:
vs_pos_errors, vs_orn_errors = [], []

# Column-wise argmin: for every live recording, the demo with the lowest normalised score.
best_demo_idx = np.argmin(sim_scores_norm, axis=0)

for idx, demo_idx in enumerate(best_demo_idx):
    name_q = references[idx]
    name_d_best = references[demo_idx]
    pos_err, orn_err = compute_distance(name_q, name_d_best)
    vs_pos_errors.append(pos_err)
    vs_orn_errors.append(orn_err)

## R3M Errors (masked and unmasked)

In [None]:
from r3m import load_r3m
import torchvision.transforms as T
import torch
import ipdb

# Image -> float tensor (ToTensor maps uint8 images to [0, 1]; rescaled to [0, 255] before R3M).
transform = T.Compose(transforms=[T.ToTensor()])

def get_r3m_embeddings(playbacks, transform=None, device='cuda', masked=False, frame_index=18):
    """Compute one R3M (ResNet-50) embedding per playback.

    Args:
        playbacks: sequence of playbacks; frame ``frame_index`` of each is embedded.
        transform: callable mapping an image to a (C, H, W) float tensor in [0, 1];
            defaults to ``T.ToTensor()``.
        device: torch device used for both the model and the input images.
        masked: if True, pixels outside ``pb.get_fg_mask()`` are zeroed first.
        frame_index: which frame of every playback to embed.

    Returns:
        numpy array with one embedding row per playback.
    """
    if transform is None:
        transform = T.ToTensor()

    r3m = load_r3m("resnet50")
    r3m.eval()
    r3m.to(device)

    embeddings = []
    with torch.no_grad():
        for pb in playbacks:
            im = pb[frame_index].cam.get_image()[0]

            if masked:
                mask = pb.get_fg_mask()
                # Repeat the 2-D mask over the three colour channels.
                im = im * mask[..., np.newaxis].repeat(3, axis=2)

            # Inputs go to the same device as the model (previously hardcoded .cuda()).
            im = transform(im).unsqueeze(0).to(device)
            # Scale back to [0, 255] before feeding R3M.
            embeddings.append(r3m(im * 255.0))

    return torch.cat(embeddings).cpu().numpy()

# One playback per recording.
playbacks = [PlaybackEnvServo(rec, load='keep') for rec in recordings]

r3m_kwargs = dict(transform=transform, device='cuda')
embeddings_with_mask = get_r3m_embeddings(playbacks, masked=True, **r3m_kwargs)
embeddings_without_mask = get_r3m_embeddings(playbacks, masked=False, **r3m_kwargs)

In [None]:
from scipy.spatial import distance

# Pairwise R3M distances indexed [demo_i, live_i]. Excluded pairs (incl. self-pairs)
# get +inf: a finite sentinel such as 10.0 can be smaller than real distances between
# high-dimensional embeddings and would then let argmin pick the query itself.
r3m_sim_scores_masked = np.full(bad_pair_arr.shape, np.inf)
r3m_sim_scores_no_mask = np.full(bad_pair_arr.shape, np.inf)

for live_i, demo_i in tqdm(good_pairs):
    # The live image is always the unmasked embedding; the demo side is compared
    # both masked and unmasked.
    live_embedding = embeddings_without_mask[live_i, :]

    demo_embedding = embeddings_without_mask[demo_i, :]
    demo_embedding_masked = embeddings_with_mask[demo_i, :]

    r3m_sim_scores_masked[demo_i, live_i] = np.linalg.norm(demo_embedding_masked - live_embedding)
    r3m_sim_scores_no_mask[demo_i, live_i] = np.linalg.norm(demo_embedding - live_embedding)

In [None]:
r3m_masked_pos_errors, r3m_masked_orn_errors = [], []
r3m_unmasked_pos_errors, r3m_unmasked_orn_errors = [], []

# Column-wise argmin: best demo for every live recording under each R3M variant.
best_demo_idx_masked = np.argmin(r3m_sim_scores_masked, axis=0)
best_demo_idx_no_mask = np.argmin(r3m_sim_scores_no_mask, axis=0)

for idx in range(len(recordings)):
    name_q = references[idx]

    name_d_best_masked = references[best_demo_idx_masked[idx]]
    name_d_best_no_mask = references[best_demo_idx_no_mask[idx]]

    pos_err_m, orn_err_m = compute_distance(name_q, name_d_best_masked)
    pos_err_nm, orn_err_nm = compute_distance(name_q, name_d_best_no_mask)

    r3m_masked_pos_errors.append(pos_err_m)
    r3m_masked_orn_errors.append(orn_err_m)

    r3m_unmasked_pos_errors.append(pos_err_nm)
    r3m_unmasked_orn_errors.append(orn_err_nm)

## R3M features 

In [None]:
import ipdb

def get_r3m_features(playbacks, transform=None, masked=False):
    """Compute R3M spatial feature maps and one per-location feature per playback.

    The R3M ResNet-50 backbone is truncated before its average-pool and FC layers so
    that it returns a spatial feature map. The feature vector is taken at the
    feature-map cell containing the largest fraction of foreground (``pb.get_fg_mask()``).

    Args:
        playbacks: sequence of playbacks; frame 0 of each is used.
        transform: callable mapping an image to a (C, H, W) float tensor in [0, 1];
            defaults to ``T.ToTensor()``.
        masked: if True, locate the feature cell via the foreground mask; otherwise
            the centre cell of the feature map is used.

    Returns:
        ``(embeddings, features, max_locations)``: feature maps stacked as
        (N, C, H_f, W_f), selected features as (N, C), and the (row, col) cell used
        for each playback.
    """
    import torch.nn as nn
    import torch.nn.functional as F

    if transform is None:
        transform = T.ToTensor()

    r3m = load_r3m("resnet50")
    r3m.cuda()

    # The loaded model is accessed through `.module` (DataParallel wrapper): unwrap it
    # to edit the backbone, then wrap it again.
    r3m = r3m.module
    arch = list(r3m.convnet.children())
    del arch[-2]  # AvgPool
    del arch[-1]  # FC
    r3m.convnet = nn.Sequential(*arch)
    r3m = nn.DataParallel(r3m)
    r3m.eval()

    embeddings = []
    features = []
    max_locations = []

    with torch.no_grad():
        for pb in playbacks:
            # NOTE(review): frame 0 is used here, while get_r3m_embeddings uses frame 18.
            im = pb[0].cam.get_image()[0]
            im = transform(im).unsqueeze(0).cuda()
            # Assumes R3M's forward returns the truncated convnet output (1, C, H_f, W_f).
            emb = r3m(im * 255.0)
            feat_h, feat_w = emb.shape[-2:]

            if masked:
                mask = torch.as_tensor(np.asarray(pb.get_fg_mask(), dtype=np.float32))
                # Downsample the mask to the feature-map grid so the chosen cell indexes
                # `emb` directly; each cell holds its foreground fraction.
                mask_cells = F.adaptive_avg_pool2d(mask[None, None], (feat_h, feat_w))[0, 0]
                row, col = np.unravel_index(int(torch.argmax(mask_cells)), (feat_h, feat_w))
            else:
                row, col = feat_h // 2, feat_w // 2
            row, col = int(row), int(col)

            embeddings.append(emb)
            features.append(emb[:, :, row, col])
            max_locations.append((row, col))

        features = torch.cat(features)
        embeddings = torch.cat(embeddings)

    features = features.cpu().numpy()
    embeddings = embeddings.cpu().numpy()
    max_locations = np.stack(max_locations)

    return embeddings, features, max_locations

r3m_embeddings, r3m_features, feature_loc = get_r3m_features(playbacks, masked=True, transform=transform)

In [None]:
from scipy.spatial import distance

# Pairwise distances indexed [demo_i, live_i]: the demo's feature vector at the object
# location vs. the spatially averaged live feature map. Excluded pairs get +inf so
# argmin can never select them.
r3m_sim_scores_features = np.full(bad_pair_arr.shape, np.inf)

for live_i, demo_i in tqdm(good_pairs):
    # r3m_embeddings is a numpy array shaped (N, C, H_f, W_f); averaging over the
    # spatial axes gives a C-dim vector comparable with the per-location demo feature.
    live_embedding = r3m_embeddings[live_i].mean(axis=(1, 2))
    demo_embedding = r3m_features[demo_i, :]

    r3m_sim_scores_features[demo_i, live_i] = np.linalg.norm(demo_embedding - live_embedding)

In [None]:
r3m_feat_pos_errors = []
r3m_feat_orn_errors = []

# Column-wise argmin: best demo for every live recording under R3M pixel features.
best_demo_idx_masked = np.argmin(r3m_sim_scores_features, axis=0)

for idx in range(len(recordings)):
    name_q = references[idx]
    name_d_best = references[best_demo_idx_masked[idx]]

    pos_err_m, orn_err_m = compute_distance(name_q, name_d_best)

    r3m_feat_pos_errors.append(pos_err_m)
    r3m_feat_orn_errors.append(orn_err_m)

## Plotting

In [None]:
# Orientation Errors: cumulative distribution of the orientation error (degrees, as
# returned by compute_distance) of the demonstration selected by each method.
orn_errors_by_method = {
    'VS': vs_orn_errors,
    'HLOC': hloc_orn_errors,
    'R3M_masked': r3m_masked_orn_errors,
    'R3M_no_mask': r3m_unmasked_orn_errors,
    'R3M Pixel Features': r3m_feat_orn_errors,
}
max_orn_error = max(np.max(errors) for errors in orn_errors_by_method.values())

for label, errors in orn_errors_by_method.items():
    res = stats.cumfreq(errors, numbins=30, defaultreallimits=(0.0, max_orn_error))
    # Same x positions for every method: limits and bin count are shared.
    x = res.lowerlimit + np.linspace(0, res.binsize * res.cumcount.size, res.cumcount.size)
    # Convert cumulative counts to percent so the values match the y-axis label.
    plt.plot(x, 100.0 * res.cumcount / len(errors), label=label)

plt.title("Orientation error of the selected demonstration")
plt.xlabel("Error (deg)")
plt.ylabel("Samples (%)")
plt.legend()

plt.show()

In [None]:
# Position Errors: cumulative distribution of the TCP position error of the
# demonstration selected by each method (plotted in mm, i.e. tcp_pos * 1000).
pos_errors_by_method = {
    'VS': vs_pos_errors,
    'HLOC': hloc_pos_errors,
    'R3M_masked': r3m_masked_pos_errors,
    'R3M_no_mask': r3m_unmasked_pos_errors,
    'R3M Pixel Features': r3m_feat_pos_errors,
}
max_dist_error = max(np.max(errors) for errors in pos_errors_by_method.values())

for label, errors in pos_errors_by_method.items():
    res = stats.cumfreq(errors, numbins=30, defaultreallimits=(0.0, max_dist_error))
    # Same x positions for every method: limits and bin count are shared.
    x = res.lowerlimit + np.linspace(0, res.binsize * res.cumcount.size, res.cumcount.size)
    # Convert cumulative counts to percent so the values match the y-axis label.
    plt.plot(x * 1000, 100.0 * res.cumcount / len(errors), label=label)

plt.title("Position error of the selected demonstration")
plt.xlabel("Error (mm)")
plt.ylabel("Samples (%)")
plt.legend()

plt.show()