In [None]:
import os
import copy
import json
import logging
from glob import glob
from pathlib import Path

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout

from flow_control.servoing.playback_env_servo import PlaybackEnvServo
import logging
# --- Experiment configuration -------------------------------------------------
# Directory that holds the demo recordings and the demo-part annotation file.
root_dir = "../tmp/ss_traj_based"
# Tags that are joined into every recording's directory name.
task, object_selected = "shape_sorting", "trapeze"
task_variant = "rP"  # rotation plus (+-pi)

# --- Logging ------------------------------------------------------------------
# Root logger at DEBUG level, followed by one test record.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.debug("test")

def get_configurations(root_dir=root_dir, num_episodes=70, prefix=""):
    """Lazily yield ``(object_selected, seed, save_dir)`` for each episode seed.

    This is a generator: ``root_dir`` is only created (if missing) once the
    first item is requested.

    Args:
        root_dir: directory under which the per-episode directories live.
        num_episodes: number of seeds to enumerate, starting at 0.
        prefix: leading tag of each directory name (e.g. "demo").

    Yields:
        Tuples ``(object_selected, seed, save_dir)`` where ``save_dir`` is
        ``<root_dir>/<prefix>_<task>_<object_selected>_<task_variant>_seedNNN``.
    """
    os.makedirs(root_dir, exist_ok=True)
    dir_stem = f"{prefix}_{task}_{object_selected}"
    for seed in range(num_episodes):
        dir_name = f"{dir_stem}_{task_variant}_seed{seed:03d}"
        yield object_selected, seed, os.path.join(root_dir, dir_name)
            
# Expand the demo configurations into the list of recording directories.
# Position i in `recordings` belongs to seed i; the generator is exhausted afterwards.
demo_cfgs = get_configurations(prefix="demo")
recordings = [cfg_dir for _cfg_obj, _cfg_seed, cfg_dir in demo_cfgs]

print("Number of recordings:", len(recordings))
print("first", recordings[0])
print("last ", recordings[-1])

In [None]:
# Seed pair to inspect and the recording directory of the chosen demo.
live_seed = 0
demo_seed = 0
demo_dir = recordings[demo_seed]
print(f"live: {live_seed} demo: {demo_seed} @ {demo_dir}")
print()

demo = PlaybackEnvServo(demo_dir, load="keep")
print("demo keep:", list(demo.keep_dict.keys()))
print()

# Manual segmentation of the demos into named parts, keyed by the seed as a string.
demo_parts_fn = os.path.join(root_dir, "demo_parts_manual3.json")
with open(demo_parts_fn) as f_obj:
    demo_parts = json.load(f_obj)

demo_keep = sorted(demo.keep_dict.keys())
keep_all = copy.copy(demo.keep_dict)
keep_parts = {}
for part in demo_parts[str(demo_seed)]:
    # A part that starts at frame 0 includes frame 0; any other start is exclusive.
    lower = -1 if part["start"] == 0 else part["start"]
    part_indices = [idx for idx in demo_keep if lower < idx <= part["end"]]
    print(part["name"], '\t', part_indices)

    keep_parts[part["name"]] = {idx: demo.keep_dict[idx] for idx in part_indices}
    print(keep_parts[part["name"]])
# set keep_dict to first part
demo.keep_dict = keep_parts["locate"]
#servo_module


In [None]:
# Load every recording in full; list position equals the recording's seed.
playbacks = [PlaybackEnvServo(rec_dir, load='all') for rec_dir in list(recordings)]
# Seeds that have an entry in the part annotations (JSON keys are strings).
good_demos = list(map(int, demo_parts))
# Alias of the same list object.
demo_good = good_demos

In [None]:
# Display the seeds that have an entry in demo_parts.
good_demos

In [None]:
from sklearn.preprocessing import minmax_scale
# Load Servoing Module
from flow_control.servoing.module import ServoingModule
# Controller settings reused by every ServoingModule built in this notebook.
control_config = {"mode": "pointcloud-abs-rotz", "threshold": 0.40}
# Module for the inspected demo; its flow module is used by the similarity helper below.
servo_module = ServoingModule(
    demo_dir,
    control_config=control_config,
    start_paused=False,
)
def similarity_from_reprojection(live_rgb, demo_rgb, demo_mask, return_images=False):
    """Rate how well a live image matches a demo image via flow reprojection.

    The flow between the demo and the live image is computed with the flow
    module of the notebook-level ``servo_module``. The live image is warped
    with that flow and compared against the demo image inside the demo's mask.

    Args:
        live_rgb: live image; scaled by 1/255 before warping (presumably
            HxWx3 with values in 0..255 -- TODO confirm).
        demo_rgb: demo image; scaled by 1/255 for the comparison.
        demo_mask: ndarray mask over the demo image selecting the pixels to
            compare. Boolean and 0/1-valued masks are both accepted.
        return_images: if True, also return the flow and the warped image.

    Returns:
        ``(error, mean_flow)``: the mask-averaged norm of the colour difference
        (lower is better) and the mean flow magnitude over the masked pixels.
        With ``return_images`` the tuple is ``(error, mean_flow, flow, warped)``.
    """
    flow_module = servo_module.flow_module
    flow = flow_module.step(demo_rgb, live_rgb)
    warped = flow_module.warp_image(live_rgb / 255.0, flow)

    # Per-pixel colour distance, kept only inside the mask and averaged over it.
    pixel_error = np.linalg.norm(warped - (demo_rgb / 255.0), axis=2) * demo_mask
    error = pixel_error.sum() / demo_mask.sum()

    # Select the masked flow vectors with an explicitly boolean mask. Indexing
    # with an integer 0/1 mask would be interpreted as fancy (row) indexing and
    # silently produce a meaningless result; a float mask would raise.
    mask_bool = np.asarray(demo_mask, dtype=bool)
    mean_flow = np.linalg.norm(flow[mask_bool], axis=1).mean()

    if return_images:
        return error, mean_flow, flow, warped
    return error, mean_flow

def normalize_errors(errors, flows, demo_good, flow_weight=.5):
    """Combine reprojection errors and mean flows into one normalized error.

    Both quantities are min-max scaled over the demos listed in ``demo_good``
    and merged as a weighted average with weights 1 (error) and
    ``flow_weight`` (flow), so the result for those demos lies in [0, 1].

    The previous implementation took ``np.mean`` over the two weighted terms
    and then divided by ``1 + w`` as well, which halved the weighted average.
    The ranking of the demos is unaffected by this correction.

    Args:
        errors: 1-D array of reprojection errors, indexed by demo seed.
        flows: 1-D array of mean flow magnitudes, indexed by demo seed.
        demo_good: indices of the demos to normalize over.
        flow_weight: weight of the flow term relative to the error term.

    Returns:
        Array shaped like ``errors``; entries not listed in ``demo_good``
        stay at 1 (the worst value).
    """
    errors_norm = np.ones(errors.shape)
    errors_scaled = minmax_scale(errors[demo_good])
    flows_scaled = minmax_scale(flows[demo_good])
    # Weighted average: divide the weighted sum by the sum of the weights.
    errors_norm[demo_good] = (errors_scaled + flow_weight * flows_scaled) / (1 + flow_weight)

    return errors_norm

In [None]:
def compute_current_scores(playbacks, current_rgb, demo_parts, demo_good, traj_idx=0, live_seed=0):
    """Score every candidate demo against the current live image.

    For each demo in ``demo_good`` the frame at the start of part ``traj_idx``
    is compared with ``current_rgb`` via ``similarity_from_reprojection``.

    For the first part (``traj_idx == 0``) the demo recorded with the live
    seed itself is excluded. It is now also left out of the normalization.
    Previously its placeholder values (error 1, flow 0) were min-max scaled
    together with the real measurements, so the excluded demo could still end
    up with the best score.

    Args:
        playbacks: sequence of playbacks indexed by demo seed.
        current_rgb: current live image.
        demo_parts: mapping ``str(seed) -> list of parts``; each part has a
            ``'start'`` frame index.
        demo_good: seeds of the demos that may be selected.
        traj_idx: index of the part whose start frame is compared.
        live_seed: seed of the live episode.

    Returns:
        Array of length ``len(playbacks)`` with scores (higher is better);
        demos that were not evaluated score 0.
    """
    sim_errors = np.ones(len(playbacks))  # lower is better
    mean_flows = np.zeros(len(playbacks))

    evaluated = []
    for demo_seed in demo_good:
        if traj_idx == 0 and demo_seed == live_seed:
            continue
        start_idx = demo_parts[str(demo_seed)][traj_idx]['start']
        demo_rgb = playbacks[demo_seed][start_idx].cam.get_image()[0]
        demo_mask = playbacks[demo_seed].fg_masks[start_idx]
        error, mean_flow = similarity_from_reprojection(current_rgb, demo_rgb, demo_mask)
        sim_errors[demo_seed] = error
        mean_flows[demo_seed] = mean_flow
        evaluated.append(int(demo_seed))

    if not evaluated:
        # Nothing left after the exclusion (e.g. the live seed is the only
        # candidate): keep the former behaviour and normalize over demo_good.
        evaluated = list(demo_good)

    errors_norm = normalize_errors(sim_errors, mean_flows, evaluated)
    scores = 1 - errors_norm

    return scores

In [None]:
import ipdb
def split_keypoints(pb, demo_part):
    """Group a playback's keep-frame indices by demonstration part.

    Args:
        pb: playback object exposing ``keep_dict`` (keyed by frame index).
        demo_part: list of parts, each a dict with ``"name"``, ``"start"``
            and ``"end"``.

    Returns:
        Dict mapping each part name to the sorted keep indices inside that
        part. The ``'grasp'`` list additionally ends with the first index of
        the ``'insert'`` part.
    """
    keep_indices = sorted(pb.keep_dict.keys())
    keep_parts = {}
    for part in demo_part:
        # A part starting at frame 0 includes frame 0; any other start is exclusive.
        lower = -1 if part["start"] == 0 else part["start"]
        keep_parts[part["name"]] = [idx for idx in keep_indices if lower < idx <= part["end"]]
    # Extend the grasp segment up to the first keypoint of the insert segment.
    keep_parts['grasp'].append(keep_parts['insert'][0])
    return keep_parts

# Keypoint indices per annotated demo seed, grouped by part name, e.g.
# keypoint_info = {0: {'locate': [0, 4], 'grasp': [7, 14, 26], 'insert': [26, 31, 37, 42, 44, 47]}}
keypoint_info = {
    seed: split_keypoints(playbacks[seed], demo_parts[str(seed)])
    for seed in good_demos
}

In [None]:
# Inspect the part annotations and the derived keypoints. Only the last bare
# expression of a notebook cell is rendered, so the `demo_parts` line on its
# own produces no output here.
demo_parts
keypoint_info

In [None]:
# Rebind demo_good to an ndarray so that the slices taken later
# (demo_good[0:step_value]) are arrays usable as NumPy index arrays.
# good_demos itself remains a plain list.
demo_good = np.array(demo_good)

In [None]:
from math import pi
from flow_control.servoing.module import ServoingModule
from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.runner import evaluate_control
import ipdb

def eval_cmb(playbacks, demo_good, live_seed, demo_parts, keypoint_info, exist_ok=False):
    """Run one live episode, re-selecting the best-matching demo before each part.

    The episode is split into the parts 'locate', 'grasp' and 'insert'. Before
    each part the current gripper image is scored against the candidate demos,
    and the keypoints of the best demo for that part are servoed to.

    Args:
        playbacks: sequence of loaded playbacks indexed by demo seed.
        demo_good: seeds of the demos that may be selected.
        live_seed: seed of the simulated live episode.
        demo_parts: part annotations, keyed by ``str(seed)``.
        keypoint_info: per-seed dict of keypoint indices by part name.
        exist_ok: unused.

    Returns:
        The reward reported by ``evaluate_control`` for the last part.
    """
    rot_range_by_variant = {"rP": pi/2., "rR": pi/6.}
    # Instantiate env
    env = RobotSimEnv(task='shape_sorting', renderer='debug', act_type='continuous',
                      initial_pose='close', max_steps=500, control='absolute-full',
                      img_size=(256, 256), param_randomize=("geom",), seed=int(live_seed),
                      task_info=dict(object_rot_range=rot_range_by_variant[task_variant]))

    for idx, part_name in enumerate(('locate', 'grasp', 'insert')):
        # step(None) is used here to obtain the current observation
        # (presumably a no-op action -- TODO confirm against RobotSimEnv).
        state, _, _, _ = env.step(None)
        current_rgb = state['rgb_gripper']

        scores = compute_current_scores(playbacks, current_rgb, demo_parts, demo_good,
                                        traj_idx=idx, live_seed=live_seed)
        best_demo_seed = np.argmax(scores)

        # Servo only to the best demo's keypoints that belong to the current part.
        servo_module = ServoingModule(recordings[best_demo_seed], control_config=control_config,
                                      start_paused=False, plot=False, plot_save_dir=None,
                                      load='select',
                                      selected_kp=keypoint_info[best_demo_seed][part_name])
        _, reward, _, info = evaluate_control(env, servo_module, max_steps=130, save_dir=None,
                                              initial_align=(idx == 0))
    del env
    del servo_module
    return reward

In [None]:
# Sizes of the demo pool offered to the selection step, and the number of live episodes.
steps = [1, 5, 10, 20, 30, 40, 50, 60]
num_live_seeds = 20

# Fewer than num_live_seeds annotated demos may exist. Size the result matrix by
# the seeds that are actually evaluated, so that no untouched all-zero rows
# remain to bias any later average over the rows.
live_seeds = good_demos[0:num_live_seeds]
rewards = np.zeros((len(live_seeds), len(steps)))

for live_idx, live_seed in enumerate(live_seeds):
    for step_idx, step_value in enumerate(steps):
        # Restrict the candidate demos to the first `step_value` annotated ones.
        step_demo_good = demo_good[0:step_value]
        reward = eval_cmb(playbacks, step_demo_good, live_seed, demo_parts, keypoint_info)
        rewards[live_idx, step_idx] = reward

In [None]:
# Persist the reward matrix. The array is passed positionally, so np.savez
# stores it under the default key 'arr_0'.
np.savez('rewards_ss_steps.npz', rewards)

In [None]:
import matplotlib.pyplot as plt
# Mean reward over the live seeds for each demo-pool size.
rew_mean = rewards.mean(axis=0)

fig, ax = plt.subplots()
ax.plot(steps, rew_mean, ".-")
ax.set_xlabel("#Recordings")
ax.set_ylabel("Mean Rewards")
ax.set_title("Shape-Sorting Task (Debug Mode)")
fig.savefig('rewards_ss.jpg', dpi=800)

In [None]:
renderer = "debug"
def eval_cmb_single(live_seed, demo_seed, exist_ok=False):
    """Run one live episode servoing to a single, fixed demo.

    Args:
        live_seed: seed of the simulated live episode.
        demo_seed: index into ``recordings`` of the demo to servo to.
        exist_ok: unused.

    Returns:
        ``(reward, save_dir)``. ``save_dir`` is only a derived name; nothing
        is written to it here, since ``evaluate_control`` gets ``save_dir=None``.
    """
    save_dir = f"{root_dir}_{demo_seed:03d}"
    servo_module = ServoingModule(recordings[demo_seed], control_config=control_config,
                                  start_paused=False, plot=False, plot_save_dir=None)
    rot_range = {"rP": pi/2., "rR": pi/6.}[task_variant]
    env = RobotSimEnv(task='shape_sorting', renderer=renderer, act_type='continuous',
                      initial_pose='close', max_steps=500, control='absolute-full',
                      img_size=(256, 256), param_randomize=("geom",), seed=int(live_seed),
                      task_info=dict(object_rot_range=rot_range))
    _, reward, _, info = evaluate_control(env, servo_module, max_steps=130, save_dir=None)
    return reward, save_dir

In [None]:
# Cross-evaluate every annotated demo (rows) on every other annotated seed (columns).
# A demo is never run on its own seed, so the diagonal keeps its initial 0.
num_good = len(good_demos)
single_reward = np.zeros((num_good, num_good))
for d_idx, demo_seed in enumerate(good_demos):
    for l_idx, live_seed in enumerate(good_demos):
        if live_seed != demo_seed:
            reward, _ = eval_cmb_single(live_seed, demo_seed)
            single_reward[d_idx, l_idx] = reward

In [None]:
# Display the demo-by-live-seed reward matrix (rows: demo, columns: live seed).
single_reward

In [None]:
# Mean reward per demo (row) over the live seeds it was actually run on.
# The cross-evaluation skips live_seed == demo_seed, which leaves zeros on the
# diagonal. Averaging them in would bias every row mean low by (n-1)/n.
evaluated = ~np.eye(single_reward.shape[0], dtype=bool)
runs_per_demo = np.maximum(evaluated.sum(axis=1), 1)  # avoids 0/0 with a single demo
single_rew_mean = np.where(evaluated, single_reward, 0.0).sum(axis=1) / runs_per_demo
single_rew_mean

In [None]:
# Grand mean of the per-demo mean rewards.
np.mean(single_rew_mean)

In [None]:
# Standard deviation of the per-demo mean rewards (np.std default, ddof=0).
np.std(single_rew_mean)