In [None]:
import os
import copy
import json
import logging
from glob import glob
from pathlib import Path

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets, interact, Layout

from flow_control.servoing.playback_env_servo import PlaybackEnvServo
import logging
# Set the root logger's level to DEBUG for the whole notebook.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Emit one DEBUG record right away as a quick check of the logging setup.
logging.debug("test")

# Directory whose immediate sub-directories are the demonstration recordings;
# the demo parts JSON file is also read from here.
# NOTE(review): absolute, user-specific path — adjust before running elsewhere.
root_dir = "/misc/student/nayaka/paper/flowcontrol/flow_control/demo/tmp_new/cnn_new"
# NOTE(review): `task` and `object_selected` are not referenced by any other
# cell visible in this notebook (eval_cmb hard-codes task='shape_sorting').
task = "shape_sorting"
object_selected = "trapeze" 
# Key used by eval_cmb to pick the object rotation range for the environment.
# NOTE(review): eval_cmb maps "rP" to pi/2 — confirm the "(+-pi)" remark below.
task_variant = "rP"  # rotation plus (+-pi)


def get_recordings(directory):
    """Return the sorted paths of all sub-directories directly inside `directory`.

    Plain files are skipped; each returned path is `directory` joined with the
    sub-directory name.
    """
    subdirs = []
    for entry in os.listdir(directory):
        candidate = os.path.join(directory, entry)
        if os.path.isdir(candidate):
            subdirs.append(candidate)
    subdirs.sort()
    return subdirs

# Collect every recording directory found directly under root_dir.
recordings = get_recordings(root_dir)

# Report how many recordings were found, then both ends of the sorted list.
print("Number of recordings:", len(recordings))
for label, position in (("first", 0), ("last ", -1)):
    print(label, recordings[position])

In [None]:
live_seed = 0
demo_seed = 0
demo_dir = recordings[demo_seed]
print(f"live: {live_seed} demo: {demo_seed} @ {demo_dir}")
print()

# Open the selected demonstration recording and list its keep-frame indices.
demo = PlaybackEnvServo(demo_dir, load="keep")
print("demo keep:", list(demo.keep_dict.keys()))
print()

# Per-demo list of named parts, each with a "start" and an "end" frame.
demo_parts_fn = os.path.join(root_dir, "demo_parts_manual3.json")
with open(demo_parts_fn) as f_obj:
    demo_parts = json.load(f_obj)

demo_keep = sorted(demo.keep_dict.keys())
keep_all = copy.copy(demo.keep_dict)
keep_parts = {}
for part in demo_parts[str(demo_seed)]:
    # A part covers the keep indices in (start, end]; a part that starts at
    # frame 0 additionally includes index 0 itself.
    lower = part["start"] if part["start"] != 0 else -1
    part_indices = [idx for idx in demo_keep if lower < idx <= part["end"]]
    print(part["name"], '\t', part_indices)

    keep_parts[part["name"]] = {idx: demo.keep_dict[idx] for idx in part_indices}
    print(keep_parts[part["name"]])
# Restrict the demo's keep_dict to the frames of the "locate" part.
demo.keep_dict = keep_parts["locate"]
#servo_module


In [None]:
# One playback object per recording directory.
playbacks = [PlaybackEnvServo(rec_dir, load='all') for rec_dir in list(recordings)]
# Demo indices that have an entry in the parts file, as a list and as an array.
good_demos = [int(key) for key in demo_parts]
demo_good = np.array(good_demos)

In [None]:
# Visual Similarity scores

# from sklearn.preprocessing import minmax_scale
# # Load Servoing Module
# from flow_control.servoing.module import ServoingModule
# control_config = dict(mode="pointcloud-abs-rotz", threshold=0.40)
# servo_module = ServoingModule(demo_dir, control_config=control_config,
#                               start_paused=False)

# def similarity_from_reprojection(live_rgb, demo_rgb, demo_mask, return_images=False):
#     # evaluate the similarity via flow reprojection error
#     flow = servo_module.flow_module.step(demo_rgb, live_rgb)
#     warped = servo_module.flow_module.warp_image(live_rgb / 255.0, flow)
#     error = np.linalg.norm((warped - (demo_rgb / 255.0)), axis=2) * demo_mask
#     error = error.sum() / demo_mask.sum()
#     mean_flow = np.linalg.norm(flow[demo_mask],axis=1).mean()
#     if return_images:
#         return error, mean_flow, flow, warped
#     return error, mean_flow

# def normalize_errors(errors, flows):
#     errors_l = errors[demo_good]
#     mean_flows_l = flows[demo_good]
#     errors_norm = np.ones(errors.shape)
#     w = .5
#     errors_norm[demo_good] = np.mean((1*minmax_scale(errors_l), w*minmax_scale(mean_flows_l)),axis=0)/(1+w) 
    
#     return errors_norm


# def compute_current_scores(playbacks, current_rgb, demo_parts, demo_good, traj_idx=0, live_seed=0):    
#     sim_errors = np.ones(len(playbacks)) # lower is better
#     mean_flows = np.zeros(len(playbacks))

#     for demo_seed in demo_good:
#         if traj_idx == 0 and demo_seed == live_seed:
#             continue
#         start_idx = demo_parts[str(demo_seed)][traj_idx]['start']
#         demo_rgb =  playbacks[demo_seed][start_idx].cam.get_image()[0]
#         demo_mask =  playbacks[demo_seed].fg_masks[start_idx]
#         error, mean_flow = similarity_from_reprojection(current_rgb, demo_rgb, demo_mask)
#         sim_errors[demo_seed] = error
#         mean_flows[demo_seed] = mean_flow
#     errors_norm = normalize_errors(sim_errors, mean_flows)
#     scores = 1 - errors_norm
    
#     return scores

In [None]:
import ipdb
def split_keypoints(pb, demo_part):
    """Group a playback's keep-frame indices by demonstration part.

    Args:
        pb: object exposing ``keep_dict``, a mapping whose keys are the
            keep-frame indices of one demonstration.
        demo_part: iterable of dicts with the keys ``"name"``, ``"start"``
            and ``"end"`` describing the parts of that demonstration.

    Returns:
        dict mapping each part name to the sorted list of keep-frame indices
        ``i`` with ``start < i <= end``; a part whose start is 0 also
        includes index 0.  When both a ``"grasp"`` part and a non-empty
        ``"insert"`` part exist, the first insert index is additionally
        appended to the grasp list.
    """
    keep_indices = sorted(pb.keep_dict.keys())
    keep_parts = {}
    for part in demo_part:
        # The start is exclusive, except that a part beginning at frame 0
        # must still pick up index 0 itself.
        lower = -1 if part["start"] == 0 else part["start"]
        keep_parts[part["name"]] = [idx for idx in keep_indices
                                    if lower < idx <= part["end"]]

    # The grasp part also receives the first insert keyframe.  Guarded so a
    # demonstration without these parts, or with an insert part that has no
    # keep frames, no longer raises KeyError / IndexError.
    if "grasp" in keep_parts and keep_parts.get("insert"):
        keep_parts["grasp"].append(keep_parts["insert"][0])
    return keep_parts

# Per-part keep-frame indices for every good demo, keyed by the integer demo index.
keypoint_info = {}

# NOTE: this loop rebinds the notebook-level `demo_seed` assigned in an earlier cell.
for demo_seed in good_demos:
    keypoint_info[demo_seed] = split_keypoints(playbacks[demo_seed], demo_parts[str(demo_seed)])

## Load Model

1. Update this cell to match the model being evaluated (network class and checkpoint paths).
2. Pass the loaded model as the `model` argument of the `compute_current_scores` function.

In [None]:
from train_sim.network import SimilarityNet
import ipdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torchvision
from torchvision import transforms as T

# First similarity network (its checkpoint lives under `models_part0`).
# NOTE(review): both networks are wrapped in nn.DataParallel before their
# checkpoints are loaded, so the saved state-dict keys presumably carry the
# `module.` prefix — confirm against how the checkpoints were written.
model0 = nn.DataParallel(SimilarityNet(), device_ids=[0])
model0.cuda()

load_model0 = './train_sim/models/models_part0/simNet_500.pth'

if load_model0 is not None:
    print(f"Loading Model: {load_model0}")
    # Fix: the network used to be re-created right after this call, which
    # threw away the weights that had just been loaded and left model0 with
    # its freshly initialised parameters.
    model0.load_state_dict(torch.load(load_model0))

# Second similarity network, restored from its own checkpoint.
model1 = nn.DataParallel(SimilarityNet(), device_ids=[0])
model1.cuda()
load_model1 = './train_sim/models/simNet_280.pth'

if load_model1 is not None:
    print(f"Loading Model: {load_model1}")
    model1.load_state_dict(torch.load(load_model1))

# Image array -> tensor conversion.
transforms = T.Compose([T.ToTensor()])

# Put both networks into evaluation mode before scoring.
model0.eval()
model1.eval()

In [None]:
# Compute scores using the trained model
# Stack input images (live and demo) and pass them through the model

def compute_current_scores(model, playbacks, current_rgb, demo_parts, demo_good, traj_idx=0, live_seed=0):
    """Score the live image against the start frame of every good demo.

    For each demo index in `demo_good`, the live image and the demo image at
    the start frame of part `traj_idx` are concatenated along the channel
    axis; all pairs are passed through `model` as one batch.

    Args:
        model: callable network applied to the batch of stacked image pairs.
        playbacks: indexable by demo index, such that
            ``playbacks[d][f].cam.get_image()[0]`` is the RGB image of frame f.
        current_rgb: live RGB image array (assumed to hold values in 0..255,
            since it is scaled by 1/255 — TODO confirm).
        demo_parts: mapping ``str(demo index)`` -> list of part dicts, each
            with a ``'start'`` frame.
        demo_good: iterable of demo indices to score.
        traj_idx: index of the part whose start frame is compared.
        live_seed: unused; kept so existing callers keep working.

    Returns:
        numpy array holding the first output column of the model, one value
        per entry of `demo_good` and in the same order.  The array is indexed
        by position in `demo_good`, NOT by demo index.
    """
    to_tensor = T.Compose([T.ToTensor()])
    live = to_tensor(np.float32(current_rgb / 255.0))

    pairs = []
    for demo_idx in demo_good:
        start_idx = demo_parts[str(demo_idx)][traj_idx]['start']
        demo_rgb = playbacks[demo_idx][start_idx].cam.get_image()[0]
        demo = to_tensor(np.float32(demo_rgb / 255.0))
        # Live and demo image stacked along the channel axis.
        pairs.append(torch.cat((live, demo), dim=0))

    # Shape [num_demos, live_channels + demo_channels, H, W]; moved to the
    # GPU in a single transfer.
    batch = torch.stack(pairs, dim=0).cuda()

    # Pure inference: the result is converted to numpy, so no autograd graph
    # is needed.
    with torch.no_grad():
        out = model(batch)

    return out[:, 0].detach().cpu().numpy()

# cr = np.zeros((256, 256, 3))

# with torch.no_grad():
#     out = compute_current_scores(playbacks, cr, demo_parts, demo_good, 0, 60)

In [None]:
from math import pi
from flow_control.servoing.module import ServoingModule
from gym_grasping.envs.robot_sim_env import RobotSimEnv
from flow_control.runner import evaluate_control
import ipdb

def eval_cmb(playbacks, live_seed, demo_parts, keypoint_info, models=None, control_config=None):
    """Run one three-part servoing episode, choosing the best demo per part.

    For each part ('locate', 'grasp', 'insert') the current gripper image is
    scored against the start frame of every good demo, the highest-scoring
    demo is selected, and the environment is servoed with that demo's
    keyframes for the part.

    Args:
        playbacks: playback objects, indexable by demo index.
        live_seed: seed of the simulated environment.
        demo_parts: mapping ``str(demo index)`` -> list of part dicts.
        keypoint_info: mapping demo index -> {part name: keyframe indices}.
        models: optional container indexable by part index (0, 1, 2) giving
            the similarity network used for that part.
        control_config: optional control configuration passed to
            ServoingModule.

    Returns:
        The reward returned by evaluate_control for the last part ('insert').
    """
    if models is None:
        # NOTE(review): the original body referenced an undefined global
        # `model`.  model0 is loaded from the `models_part0` checkpoint, so it
        # is assumed to belong to part 0 and model1 to the remaining parts —
        # confirm, or pass `models` explicitly.
        models = (model0, model1, model1)
    if control_config is None:
        # NOTE(review): `control_config` was not defined in any active cell;
        # these values are copied from the commented-out ServoingModule setup
        # earlier in the notebook — confirm they are the intended ones.
        control_config = dict(mode="pointcloud-abs-rotz", threshold=0.40)

    renderer = 'egl'
    # Instantiate env
    env = RobotSimEnv(task='shape_sorting', renderer=renderer, act_type='continuous',
                      initial_pose='close', max_steps=500, control='absolute-full',
                      img_size=(256, 256), param_randomize=("geom",), seed=int(live_seed),
                      task_info=dict(object_rot_range={"rP":pi/2.,"rR":pi/6.}[task_variant]))

    traj_map = {0: 'locate', 1: 'grasp', 2: 'insert'}

    for idx, part_name in traj_map.items():
        # Step with a None action to obtain the current observation.
        state, _, _, _ = env.step(None)
        current_rgb = state['rgb_gripper']

        scores = compute_current_scores(models[idx], playbacks, current_rgb, demo_parts,
                                        demo_good, traj_idx=idx, live_seed=live_seed)

        # Fix: `scores` has one entry per element of demo_good (same order),
        # so argmax yields a position in demo_good.  Map it back to the demo
        # index before indexing recordings / keypoint_info.
        best_demo_idx = int(demo_good[np.argmax(scores)])

        best_demo = recordings[best_demo_idx]
        kps = keypoint_info[best_demo_idx][part_name]

        servo_module = ServoingModule(best_demo, control_config=control_config,
                                      start_paused=False, plot=False, plot_save_dir=None,
                                      load=kps)
        # The initial alignment is only requested for the first part.
        _, reward, _, _ = evaluate_control(env, servo_module, max_steps=130, save_dir=None,
                                           initial_align=(idx == 0))
    del env
    del servo_module
    return reward

# eval_cmb(playbacks, 1, demo_parts, keypoint_info)

In [None]:
# Evaluate 60 live environments, using the seeds 60..119.
num_live_seeds = 60
rewards = np.zeros((num_live_seeds))

seeds = range(60, 60 + num_live_seeds)
# NOTE: this loop rebinds the notebook-level `live_seed` assigned in an earlier cell.
for live_idx, live_seed in enumerate(seeds):
    reward = eval_cmb(playbacks, live_seed, demo_parts, keypoint_info)
    rewards[live_idx] = reward
    
# The array is passed positionally, so np.savez stores it under the key "arr_0".
np.savez('cnn_rewards_3parts.npz', rewards)

In [None]:
# Show the per-seed rewards followed by their mean.
print(rewards, np.mean(rewards))