# Recombination & Retrieval for Servoing

This notebook shows an example of how a suitable demonstration can be selected when several candidates are available. The current selection is based on reprojection error.


## 1. Load Existing Demonstrations

In [None]:
import os
import copy
import json
import logging
from glob import glob
from pathlib import Path

from tqdm import tqdm
import numpy as np
#get_ipython().run_line_magic('matplotlib', 'notebook')
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout

from flow_control.recombination.record_multi import get_configurations
from flow_control.servoing.playback_env_servo import PlaybackEnvServo

# Experiment configuration used by the cells below.
root_dir = "/tmp/flow_experiments3"
num_episodes = 20

# Collect the recording directory of every demonstration configuration.
# Each configuration is unpacked as a 5-tuple; only the last two fields
# (seed and directory) are named here.
demo_cfgs = get_configurations(prefix="demo", num_episodes=num_episodes)
recordings = []
# NOTE: `demo_seed` and `demo_dir` are module-level loop variables, so they
# remain bound to the values of the last configuration after the loop ends.
for _, _, _, demo_seed, demo_dir in demo_cfgs:
    recordings.append(demo_dir)
    
# Quick sanity check: how many recordings were found, and the first/last one.
print("Number of recordings:", len(recordings))
print(recordings[0])
print(recordings[-1])

In [None]:
# Load the demonstration episodes: one PlaybackEnvServo per recording directory.
playbacks = list(map(PlaybackEnvServo, recordings))

In [None]:
# Plot the demonstrations
%matplotlib notebook
fig, ax = plt.subplots(1,figsize=(8, 6))
fig.suptitle("Demonstration Frames")
ax.set_axis_off()
image_h = ax.imshow(playbacks[0].cam.get_image()[0])

def update(demo_index, frame_index):
    image = playbacks[demo_index][frame_index].cam.get_image()[0]
    image_h.set_data(image)
    fig.canvas.draw_idle()
    print("wp_name:", playbacks[demo_index][frame_index].get_info()["wp_name"])
    fg_mask = playbacks[demo_index].get_fg_mask()
    if fg_mask is not None:
        print("percent fg:", np.mean(fg_mask)*100)
    
slider_w = widgets.IntSlider(min=0, max=len(playbacks)-1, step=1, value=0,
                             layout=Layout(width='70%'))
slider_i = widgets.IntSlider(min=0, max=200-1, step=1, value=0,
                             layout=Layout(width='70%'))

interact(update, demo_index=slider_w, frame_index=slider_i)

In [None]:
# Summarize the demonstration info
# Percentage of foreground pixels in each demonstration's mask.
fg_pixels = []
for pb in playbacks:
    fg_pixels.append(np.mean(pb.get_fg_mask()) * 100)

header = "ep\t len\t rew\t fg[%]"
underline = "-" * len(header.replace("\t", " " * 5))
print(header, "\n" + underline)
# One row per demonstration: index, length, reward of the last frame, fg percentage.
for i, (pb, fg_percent) in enumerate(zip(playbacks, fg_pixels)):
    final_reward = pb[-1].data["rew"]
    print(i, "\t", len(pb), "\t", final_reward, "\t", round(fg_percent, 2))

## 2. Filter Bad Demonstrations

In [None]:
# Identify bad demonstrations
# All matrices are indexed [demo, live]; an entry is True when the pair must not be used.
n_demos = len(playbacks)

# A demonstration whose last frame has zero reward invalidates its whole row and column.
failed = np.array([pb[-1].data["rew"] == 0 for pb in playbacks], dtype=bool)
demo_fails = failed[:, None] | failed[None, :]

# Same for demonstrations with less than 0.5 % foreground pixels.
low_seg = np.array([fg_percent < 0.5 for fg_percent in fg_pixels], dtype=bool)
demo_lowseg = low_seg[:, None] | low_seg[None, :]

# An episode is never paired with itself.
demo_same = np.eye(n_demos, dtype=bool)

demo_bad = demo_fails | demo_lowseg | demo_same
print(demo_bad.astype(int))
print("Viable perc:", 1 - demo_bad.mean())
print("Viable runs:", np.count_nonzero(~demo_bad))

## Compute Errors for Viable Pairs

In [None]:
# Load Servoing Module
from flow_control.servoing.module import ServoingModule
control_config = dict(mode="pointcloud-abs-rotz", threshold=0.40)
# Pick the demonstration explicitly instead of relying on the `demo_dir` loop
# variable left over from the recordings loop; `recordings[-1]` is the value
# that variable holds once that loop has finished.
servo_module = ServoingModule(recordings[-1], control_config=control_config,
                              start_paused=False)

In [None]:
fig, (ax, ax2, ax3) = plt.subplots(1, 3)
ax.set_title("live")
ax2.set_title("demo")
ax3.set_title("warped")
ax.set_axis_off()
ax2.set_axis_off()
ax3.set_axis_off()
empty_image = np.zeros((256, 256), dtype=np.uint8)
image_h = ax.imshow(empty_image)
image_h2 = ax2.imshow(empty_image)
image_h3 = ax3.imshow(empty_image)

def update(live_index, demo_index):
    """Show a live frame, a demo frame and the live frame warped onto the demo.

    Args:
        live_index: index of the episode whose first frame acts as the live view.
        demo_index: index of the episode whose first frame acts as the demonstration.

    For a viable pair the masked reprojection error and the mean flow magnitude
    inside the demo foreground mask are printed. For a pair flagged in
    `demo_bad` the demo and warped panels are blanked instead.
    """
    image_l = playbacks[live_index][0].cam.get_image()[0]
    # Always show the selected live frame, also when the pair turns out unusable.
    image_h.set_data(image_l)

    if demo_bad[demo_index, live_index]:
        image_h2.set_data(empty_image)
        image_h3.set_data(empty_image)
        # Redraw before returning, otherwise the panels keep showing the previous pair.
        fig.canvas.draw_idle()
        print("no valid demos.")
        return

    image_d = playbacks[demo_index][0].cam.get_image()[0]
    flow = servo_module.flow_module.step(image_d, image_l)
    warped = servo_module.flow_module.warp_image(image_l / 255.0, flow)

    image_h2.set_data(image_d)
    image_h3.set_data(warped)
    fig.canvas.draw_idle()

    demo_mask = playbacks[demo_index].get_fg_mask()
    # Mean colour distance over the masked pixels (sum/sum, as in get_error).
    error = np.linalg.norm((warped - (image_d / 255.0)), axis=2) * demo_mask
    error = error.sum() / demo_mask.sum()
    mean_flow = np.linalg.norm(flow[demo_mask], axis=1).mean()
    print(f"error {error:.2f} mean_flow {mean_flow:.2f}")

slider_l = widgets.IntSlider(min=0, max=num_episodes-1, step=1, value=3,
                             layout=Layout(width='70%'))
slider_d = widgets.IntSlider(min=0, max=num_episodes-1, step=1, value=18,
                             layout=Layout(width='70%'))
interact(update, live_index=slider_l, demo_index=slider_d)
print("l: live frame   d: demo frame")

In [None]:
def get_error(live_rgb, demo_rgb, demo_mask):
    """Compute the reprojection error between a live frame and a demo frame.

    The flow is estimated with `servo_module.flow_module`, the live image
    (divided by 255) is warped with it, and the warped image is compared to the
    demo image (divided by 255).

    Args:
        live_rgb: live image.
        demo_rgb: demonstration image.
        demo_mask: foreground mask of the demonstration; it weights the error
            and indexes the flow field (presumably a boolean H x W array — TODO confirm).

    Returns:
        (error, mean_flow): the mask-weighted sum of per-pixel colour distances
        divided by the mask sum, and the mean flow-vector norm over the pixels
        selected by the mask.
    """
    flow_module = servo_module.flow_module
    flow = flow_module.step(demo_rgb, live_rgb)
    warped = flow_module.warp_image(live_rgb / 255.0, flow)

    # Per-pixel distance along the last (axis=2) dimension.
    pixel_dist = np.linalg.norm(warped - (demo_rgb / 255.0), axis=2)
    masked_dist = pixel_dist * demo_mask
    error = masked_dist.sum() / demo_mask.sum()

    flow_norms = np.linalg.norm(flow[demo_mask], axis=1)
    return error, flow_norms.mean()

# Pairwise results indexed [demo, live]. Pairs flagged in demo_bad are never
# evaluated and keep the initial values (error 1, flow 0).
errors = np.ones_like(demo_bad, dtype=float)  # lower is better
mean_flows = np.zeros_like(demo_bad, dtype=float)

for live_i in tqdm(range(num_episodes)):
    live_rgb = playbacks[live_i][0].cam.get_image()[0]
    viable_demos = [idx for idx in range(num_episodes) if not demo_bad[idx, live_i]]
    for demo_i in viable_demos:
        demo_rgb = playbacks[demo_i][0].cam.get_image()[0]
        demo_mask = playbacks[demo_i].get_fg_mask()
        error, mean_flow = get_error(live_rgb, demo_rgb, demo_mask)
        # Skipped pairs keep error 1, so a computed error must not exceed it.
        assert error <= 1.0
        errors[demo_i, live_i] = error
        mean_flows[demo_i, live_i] = mean_flow

In [None]:
#errors_l = errors[np.logical_not(demo_bad)]
#mean_flows_l = mean_flows[np.logical_not(demo_bad)]
#fig, ax = plt.subplots(1, 1)
#ax.scatter( errors_l, mean_flows_l)
#plt.show()

## 3. Pick Best Demonstrations

In [None]:
# A live episode is usable when at least one demonstration is viable for it.
viable_pairs = ~demo_bad
good_episode = viable_pairs.any(axis=0)
# Lowest-error demonstration for every live episode (column-wise argmin).
best_demo = np.argmin(errors, axis=0)
print("live episode", np.arange(num_episodes)[good_episode])
print("demo episode", best_demo[good_episode])

In [None]:
fig, (ax, ax2) = plt.subplots(1, 2)
ax.set_title("live")
ax2.set_title("demo")
ax.set_axis_off()
ax2.set_axis_off()
empty_image = np.zeros((256, 256), dtype=np.uint8)
image_h = ax.imshow(empty_image)
image_h2 = ax2.imshow(empty_image)

def update(live_index, demo_rank):
    """Show a live frame next to its `demo_rank`-th best demonstration.

    Args:
        live_index: index of the episode whose first frame acts as the live view.
        demo_rank: rank of the demonstration to show when demonstrations are
            sorted by ascending error for this live episode (0 is the best).

    If the demonstration at that rank is flagged in `demo_bad`, the demo panel
    is blanked and a message is printed.
    """
    image_l = playbacks[live_index][0].cam.get_image()[0]
    # Always show the selected live frame, also when no good demo exists at this rank.
    image_h.set_data(image_l)

    # choose the demo_rank'th best demo
    demo_i = np.argsort(errors[:, live_index])[demo_rank]
    if demo_bad[demo_i, live_index]:
        image_h2.set_data(empty_image)
        # Redraw before returning, otherwise the panels keep showing the previous pair.
        fig.canvas.draw_idle()
        print("no good demo.")
        return

    print(f"{demo_rank} th best demo is {demo_i}   err {errors[demo_i, live_index]:.2f}   dist {mean_flows[demo_i, live_index]:.2f}")
    image_d = playbacks[demo_i][0].cam.get_image()[0]
    image_h2.set_data(image_d)
    fig.canvas.draw_idle()

slider_l = widgets.IntSlider(min=0, max=num_episodes-1, step=1, value=3,
                             layout=Layout(width='70%'))
slider_d = widgets.IntSlider(min=0, max=num_episodes-1, step=1, value=0,
                             layout=Layout(width='70%'))
interact(update, live_index=slider_l, demo_rank=slider_d)

## Run Servoing

In [None]:
import shutil
from importlib import reload

from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.servoing.module import ServoingModule
from flow_control.runner import evaluate_control

renderer = "debug"
control_config = dict(mode="pointcloud-abs-rotz", threshold=0.35)

def eval_cmb(live_seed, demo_seed, exist_ok=False):
    """Servo a simulated episode with one demonstration and return the reward.

    Args:
        live_seed: seed of the simulated environment (converted with `int`).
        demo_seed: index into `recordings` selecting the demonstration; also
            used, zero-padded to three digits, as suffix of the save directory.
        exist_ok: if True, an existing save directory is deleted before the
            run; if False, an existing save directory is an error.

    Returns:
        The reward returned by `evaluate_control`.

    Raises:
        ValueError: if the save directory already exists and `exist_ok` is False.
    """
    # NOTE(review): save_dir depends only on demo_seed, so two live seeds that
    # are evaluated with the same demonstration share one directory — confirm
    # whether live_seed should be part of the path.
    save_dir = f"{root_dir}_{demo_seed:03d}"
    demo_dir = recordings[demo_seed]
    if Path(save_dir).is_dir():
        if not exist_ok:
            raise ValueError(
                f"save directory {save_dir!r} already exists; "
                "pass exist_ok=True to delete and overwrite it")
        shutil.rmtree(save_dir)
    servo_module = ServoingModule(demo_dir, control_config=control_config,
                                  start_paused=False, plot=False, plot_save_dir=None)
    env = RobotSimEnv(task='shape_sorting', renderer=renderer, act_type='continuous',
                      initial_pose='close', max_steps=500, control='absolute-full',
                      img_size=(256, 256), param_randomize=("geom",), seed=int(live_seed))
    _, reward, _, info = evaluate_control(env, servo_module,
                                          max_steps=130,
                                          save_dir=save_dir)
    return reward

In [None]:
# Evaluate a single hand-picked (live, demo) pair, overwriting any earlier result.
live_seed = 3
demo_seed = 19
reward = eval_cmb(live_seed, demo_seed, exist_ok=True)
print(f"live_i {live_seed} demo_seed {demo_seed} -> reward {reward}")

In [None]:
# Reward of the lowest-error demonstration for every live episode, indexed
# [demo, live]; -1 marks pairs that were not evaluated.
rewards = np.full(demo_bad.shape, -1, dtype=int)

# NOTE(review): eval_cmb is called with its default exist_ok=False, so it raises
# ValueError when its save directory (keyed on the demo index only) already exists.
for live_i in np.arange(num_episodes):
    demo_i = np.argmin(errors[:, live_i])
    if not demo_bad[demo_i, live_i]:
        print(live_i, demo_i)
        reward = eval_cmb(live_i, demo_i)
        rewards[demo_i, live_i] = reward

In [None]:
def hacky_get_reward(save_dir, demo_seed):
    """Return the reward stored in the last recorded frame of a servoing run.

    The run directory is ``f"{save_dir}_{demo_seed:03d}"``; its frames are the
    ``frame_*.npz`` files inside it, ordered by file name.

    Args:
        save_dir: path prefix of the run directory.
        demo_seed: integer demonstration seed, appended zero-padded to three digits.

    Returns:
        The ``"rew"`` entry of the last frame as a Python scalar, or ``-1`` if
        the run directory is missing or contains no frames.
    """
    run_dir = f"{save_dir}_{demo_seed:03d}"
    frame_names = sorted(glob(f"{run_dir}/frame_*.npz"))
    if not frame_names:
        # Empty or missing episode.
        return -1
    # The context manager closes the archive's file handle after reading.
    with np.load(frame_names[-1]) as last_frame:
        return last_frame["rew"].item()
    
# Read the final reward of every (demo, run) combination from disk.
# Rows follow the demo configurations, columns the run configurations;
# hacky_get_reward returns -1 for runs without recorded frames.
row_rewards = []
demo_cfgs = get_configurations(prefix="demo", num_episodes=num_episodes)
for _, _, _, demo_seed, demo_dir in demo_cfgs:
    run_cfgs = get_configurations(prefix="run", num_episodes=num_episodes)
    rewards = [hacky_get_reward(save_dir, demo_seed)
               for _, _, _, seed, save_dir in run_cfgs]
    row_rewards.append(rewards)

print("done!")
tmp = np.array(row_rewards, dtype=int)
print(tmp)
print("mean performance", tmp.mean())
# Same mean, ignoring the pairs flagged in demo_bad.
masked_rewards = np.ma.masked_array(tmp, mask=demo_bad)
print("mean performance (masked): ", masked_rewards.mean().round(3))

fig, ax = plt.subplots()
fig.suptitle("Success for Demo/Live")
ax.set_xlabel("live ep")
ax.set_ylabel("demo ep")
ax.imshow(tmp)
plt.show()

In [None]:
# One point per (demo, live) pair: foreground percentage of the demo (X) and of
# the live episode (Y), coloured by the stored reward of that pair.
n_rows, n_cols = tmp.shape[0], tmp.shape[1]
pair_indices = [(row, col) for row in range(n_rows) for col in range(n_cols)]
X = [fg_pixels[row] for row, _ in pair_indices]
Y = [fg_pixels[col] for _, col in pair_indices]
c = [tmp[row, col] for row, col in pair_indices]

fig, ax = plt.subplots()
ax.scatter(X, Y, c=c)
fig.suptitle("Success vs %FG")
ax.set_xlabel("demo %FG pixels")
ax.set_ylabel("live %FG pixels")
plt.show()

In [None]:
from collections import defaultdict

def hacky_get_edge_frames(save_dir, demo_seed):
    """Return the paths of the first and last recorded frame of a servoing run.

    The run directory is ``f"{save_dir}_{demo_seed:03d}"``; its frames are the
    ``frame_*.npz`` files inside it, ordered by file name.

    Args:
        save_dir: path prefix of the run directory.
        demo_seed: integer demonstration seed, appended zero-padded to three digits.

    Returns:
        Tuple ``(first_frame_path, last_frame_path)``.

    Raises:
        IndexError: if the run directory is missing or contains no frames.
    """
    run_dir = f"{save_dir}_{demo_seed:03d}"
    frame_names = sorted(glob(f"{run_dir}/frame_*.npz"))
    if not frame_names:
        # Same exception type as plain list indexing, but naming the directory.
        raise IndexError(f"no frame_*.npz files found in {run_dir}")
    return frame_names[0], frame_names[-1]

def hacky_get_ep_len(save_dir, demo_seed):
    """Count the recorded frames of a servoing run.

    Frames are the ``frame_*.npz`` files in the directory
    ``f"{save_dir}_{demo_seed:03d}"``.

    Args:
        save_dir: path prefix of the run directory.
        demo_seed: integer demonstration seed, appended zero-padded to three digits.

    Returns:
        Number of matching frame files (0 if there are none).
    """
    frame_pattern = f"{save_dir}_{demo_seed:03d}/frame_*.npz"
    return len(glob(frame_pattern))

def to_vec(frame_info):
    """Turn a frame info mapping into a fixed-order feature list.

    Args:
        frame_info: mapping that may contain the keys "fit_pc_size",
            "fit_inliers", "fit_q_pos" and "fit_q_col".

    Returns:
        List with the four values in that order; a key that raises KeyError
        contributes -1 instead.
    """
    def _value_or_default(name):
        # Missing entries are encoded as -1.
        try:
            return frame_info[name]
        except KeyError:
            return -1

    feature_names = ("fit_pc_size", "fit_inliers", "fit_q_pos", "fit_q_col")
    return [_value_or_default(name) for name in feature_names]

# Collect per-run records (run_data) and classifier inputs (X features, Y rewards)
# for every (demo, run) pair that is not flagged in demo_bad.
run_data = []
X, Y = [], []

# Pass num_episodes explicitly, like the other get_configurations calls in this
# notebook, because the seeds index demo_bad, which has that size.
demo_cfgs = get_configurations(prefix="demo", num_episodes=num_episodes)
for _, _, _, demo_seed, demo_dir in demo_cfgs:
    run_cfgs = get_configurations(prefix="run", num_episodes=num_episodes)
    for _, _, _, seed, save_dir in run_cfgs:
        if demo_bad[demo_seed, seed]:
            continue

        # A dedicated name keeps the reward matrix `tmp` of the earlier cell intact.
        first_frame_path, last_frame_path = hacky_get_edge_frames(save_dir, demo_seed)
        run_info = {
            "run_name": f"{save_dir}_{demo_seed:03d}",
            "first_frame": first_frame_path,
            "last_frame": last_frame_path,
            "reward": hacky_get_reward(save_dir, demo_seed),
            "len": hacky_get_ep_len(save_dir, demo_seed),
        }

        # NOTE: allow_pickle=True unpickles objects stored in the file; only
        # load recordings from a trusted source.
        with np.load(run_info["first_frame"], allow_pickle=True) as first_frame:
            ff_info = first_frame["info"].item()

        Y.append(run_info["reward"])
        X.append(to_vec(ff_info))

        print(run_info["reward"], ff_info.keys())
        try:
            run_info["initial_q"] = ff_info["fit_q_col"]
        except KeyError:
            # Without a fit quality the run is left out of run_data; its
            # features and reward above are still part of X and Y.
            continue
        run_data.append(run_info)

In [None]:
# Split the collected runs by outcome.
runs_success = [run for run in run_data if run["reward"] == 1.0]
runs_failure = [run for run in run_data if run["reward"] == 0]

# Histogram of the chosen per-run quantity, one series per outcome.
key = "initial_q"
fig, ax = plt.subplots(1)
fig.suptitle(f"Outcome vs {key}")
for outcome, outcome_runs in (("success", runs_success), ("failure", runs_failure)):
    values = [run[key] for run in outcome_runs]
    ax.hist(values, bins=20, alpha=0.8, label=outcome)
ax.set_xlabel(f"run {key}")
ax.set_ylabel("ep count")
plt.legend()
plt.show()

In [None]:
# Choose which runs to browse and which frame to show in the left panel.
#cur_data = runs_success
cur_data = runs_failure
#key = "last_frame"
key = "first_frame"

fig, (ax,ax2) = plt.subplots(1, 2)
fig.suptitle(f"Servo Run: {key.replace('_',' ').title()}")
ax.set_axis_off()
ax2.set_axis_off()
image_h = ax.imshow(np.zeros((256,256)))
image_h2 = ax2.imshow(np.zeros((256,256)))


def update(w):
    """Show the `key` frame (left) and the last frame (right) of run `w`.

    Args:
        w: index into `cur_data`.

    Prints the frame's "fit_q_col" info entry when present, then the reward,
    the run name and the path of the displayed frame.
    """
    if not cur_data:
        print("no runs to show.")
        return
    run = cur_data[w]

    # NOTE: allow_pickle=True unpickles objects stored in the file; only load
    # recordings from a trusted source. The context managers close the archives.
    with np.load(run[key], allow_pickle=True) as frame:
        image = frame["rgb_gripper"]
        try:
            print(frame["info"].item()["fit_q_col"])
        except KeyError:
            # Frame has no info entry or no fit quality; nothing to report.
            pass
    with np.load(run["last_frame"], allow_pickle=True) as frame:
        last_frame = frame["rgb_gripper"]

    # Print the run name instead of dumping the whole record under a "seed" label.
    print("reward", run["reward"], "run", run["run_name"])
    print(run[key])

    image_h.set_data(image)
    image_h2.set_data(last_frame)
    fig.canvas.draw_idle()

# Keep the slider range valid even when there are no runs to browse.
slider_w = widgets.IntSlider(min=0, max=max(len(cur_data)-1, 0), step=1, value=0,
                             layout=Layout(width='70%'))
interact(update, w=slider_w)

In [None]:
from sklearn.neural_network import MLPClassifier

# Fit an MLP without hidden layers (hidden_layer_sizes=()) that predicts the
# run reward from the first-frame fit features.
X = np.array(X)
clf = MLPClassifier(
    solver='lbfgs',
    alpha=1e-5,
    hidden_layer_sizes=(),
    random_state=1,
)
clf.fit(X, Y)

# Accuracy on the same data the classifier was fitted on.
Y_pred = clf.predict(X)
train_accuracy = np.mean(Y == Y_pred)
print(train_accuracy)

In [None]:
from collections import defaultdict
import networkx as nx

# Count waypoint transitions over all demonstrations: (from, to) -> occurrences.
graph = defaultdict(int)
for pb in playbacks:
    prev_node = None
    for frame_i in range(len(pb)):
        wp_name = pb[frame_i].get_info()["wp_name"]
        # A transition is a change of waypoint name between consecutive frames.
        if prev_node is not None and wp_name != prev_node:
            graph[(prev_node, wp_name)] += 1
        prev_node = wp_name
print(graph)


# Undirected graph whose edge weight is the transition count divided by num_episodes.
G = nx.Graph()
G.add_weighted_edges_from(
    (src, dst, count / num_episodes) for (src, dst), count in graph.items()
)

# Edges whose weight exceeds 0.5.
elarge = [(u, v) for u, v, attrs in G.edges(data=True) if attrs["weight"] > 0.5]

pos = nx.spring_layout(G, seed=0)  # positions for all nodes - seed for reproducibility

fig, ax = plt.subplots(1)

# Draw nodes, edges and node labels in one call.
nx.draw_networkx(G, pos, node_size=700, node_color='#c4daef')

ax.set_ylim(-1, 1)
ax.set_xlim(-1, 1)
ax.axis("off")
#plt.tight_layout()
#plt.show()