In [None]:
%matplotlib inline

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import time
import cv2
import PIL
from PIL import Image


# torch
import torch

# key_dynam
from key_dynam.utils.utils import get_project_root, load_yaml, get_data_root
from key_dynam.dataset.drake_sim_episode_reader import DrakeSimEpisodeReader
from key_dynam.utils import transform_utils
from key_dynam.utils import drake_image_utils
from key_dynam.dense_correspondence.descriptor_net import sample_descriptors, PrecomputedDescriptorNet

# pdc
from dense_correspondence.dataset.dynamic_drake_sim_dataset import DynamicDrakeSimDataset
from dense_correspondence.correspondence_tools.correspondence_finder import reproject_pixels
from dense_correspondence.correspondence_tools import correspondence_plotter
from dense_correspondence.correspondence_tools.correspondence_finder import compute_correspondence_data, pad_correspondence_data
from dense_correspondence_manipulation.utils.utils import getDenseCorrespondenceSourceDir, set_cuda_visible_devices
import dense_correspondence.loss_functions.utils as loss_utils
import dense_correspondence_manipulation.utils.utils as pdc_utils
import dense_correspondence_manipulation.utils.visualization as vis_utils
from dense_correspondence.network.predict import get_argmax_l2


# Set to True to make the datasets use the previous (deprecated) pdc
# image-to-tensor transform instead of the current one.
USE_DEPRECATED_IMAGE_TRANSFORM = False

# GPU ids passed to set_cuda_visible_devices for this kernel.
GPU_LIST = [1]
set_cuda_visible_devices(GPU_LIST)

## Load Network

In [None]:
# Load the trained descriptor network checkpoint and switch it to eval mode on the GPU.
#
# Other checkpoints that have been used with this notebook:
#   "2019-12-03-21-25-46-128603_top_down_same_view"
#   "2019-12-03-22-51-17-805212_top_down_rotated"
#   "2019-12-03-23-40-30-486661_top_down_rotated_sigma_5"
model_name = "2019-12-04-01-32-12-010393_top_down_rotated_sigma_5"
trained_models_dir = os.path.join(get_project_root(), "data/dev/experiments/05/trained_models")
network_folder = os.path.join(trained_models_dir, model_name)

# checkpoint files are named after the epoch they were saved at
epoch = 155
checkpoint_filename = "net_dy_epoch_%d_iter_0_model.pth" % epoch
model_file = os.path.join(network_folder, checkpoint_filename)
print("model_file", model_file)

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model = torch.load(model_file)
model.cuda()
model = model.eval()

## Load Dataset

In [None]:
# Load the multi-episode dataset (optionally together with precomputed
# descriptor images) and build one DynamicDrakeSimDataset per phase.
LOAD_PRECOMPUTED_DESCRIPTORS = True
DATA_ROOT = get_data_root()
# DATASET_NAME = "2019-11-26-00-03-48-223155"
# DATASET_NAME = "top_down_same_view"
DATASET_NAME = "top_down_rotated"
# DATASET_NAME = "2019-12-05-15-58-48-462834_top_down_rotated_250"

dataset_root = os.path.join(get_project_root(), "data/dev/experiments/05/data", DATASET_NAME)

# Stays None when precomputed descriptors are not requested.
des_images_root = None
if LOAD_PRECOMPUTED_DESCRIPTORS:
    des_images_root = os.path.join(DATA_ROOT,
                                   "dev/experiments/05/precomputed_descriptors",
                                   DATASET_NAME,
                                   model_name)

    metadata = load_yaml(os.path.join(des_images_root, 'metadata.yaml'))

    # The precomputed descriptors must come from the checkpoint loaded above,
    # otherwise comparing them against the live network output is meaningless.
    assert model_file == metadata['model_file'], (
        "precomputed descriptors were generated by %s, but the loaded model is %s"
        % (metadata['model_file'], model_file))


multi_episode_dict = DrakeSimEpisodeReader.load_dataset(dataset_root,
                                                        descriptor_images_root=des_images_root)

# make pdc dataset now

# placeholder for now
config_file = os.path.join(getDenseCorrespondenceSourceDir(),
                           'config/dense_correspondence/global/drake_sim_dynamic.yaml')

config = load_yaml(config_file)

# Only import the deprecated transform when it is actually requested.
if USE_DEPRECATED_IMAGE_TRANSFORM:
    from dense_correspondence_manipulation.utils.torch_utils import get_deprecated_image_to_tensor_transform

datasets = dict()
for phase in ["train", "valid"]:
    # Pass the loop's phase through; previously "train" was hard-coded here,
    # so datasets["valid"] was silently a second train-phase dataset.
    dataset = DynamicDrakeSimDataset(config, multi_episode_dict, phase=phase)

    if USE_DEPRECATED_IMAGE_TRANSFORM:
        dataset._rgb_image_to_tensor = get_deprecated_image_to_tensor_transform()

    datasets[phase] = dataset


# Sort the episode names so that indexing into the list is reproducible.
episode_name_list = list(multi_episode_dict.keys())
episode_name_list.sort()
episode_0 = multi_episode_dict[episode_name_list[0]]


print("episode_0.name", episode_0.name)

# The first two cameras of episode_0 are used as views "a" and "b".
camera_names = list(episode_0.camera_names)
camera_name_a = camera_names[0]
camera_name_b = camera_names[1]

global_config = load_yaml(os.path.join(get_project_root(), 'experiments/05/config.yaml'))

## Select reference descriptors

In [None]:
# Sample NUM_REF_DESCRIPTORS reference descriptors from frame `idx` of episode_0
# (camera_name_a), store them in a PrecomputedDescriptorNet, and plot where
# they were sampled on the reference image.
DEBUG = False
pdc_utils.reset_random_seed()
device = next(model.parameters()).device

print("device:", device)
idx = 0
idx_step = 5  # frame stride; read by the localization cell below

MODEL_ENABLED = True
DATA_TYPE = "valid"
NUM_REF_DESCRIPTORS = 5

sz = 2
num_rows = 1
num_cols = 1
figsize = (6.4*sz*num_cols, 4.8*sz*num_rows)

# indices (not read anywhere in this cell)
b = 0
n = 10

# dataset for the selected phase; also read by the localization cell below
dataset_tmp = datasets[DATA_TYPE]

with torch.no_grad():

    data = dataset_tmp._getitem(episode_0,
                                idx,
                                camera_name_a=camera_name_a,
                                camera_name_b=camera_name_b)

    data_a = data['data_a']
    rgb_tensor_a = data_a['rgb_tensor'].to(device).unsqueeze(0)
    rgb_a = data_a['rgb']

    # [1,D,H,W]
    out_a = model.forward(rgb_tensor_a)
    des_img_a = out_a['descriptor_image']

    # sample reference descriptors
    img_mask = torch.tensor(data_a['mask']).to(device)
    print("img_mask.shape", img_mask.shape)
    # squeeze(0) drops only the batch dimension; a bare squeeze() would also
    # drop any other dimension that happens to have size 1.
    ref_descriptors_dict = sample_descriptors(des_img_a.squeeze(0), img_mask, NUM_REF_DESCRIPTORS)

    ref_descriptors = ref_descriptors_dict['descriptors']
    ref_descriptors_indices = ref_descriptors_dict['indices']

    if DEBUG:
        print("ref_descriptors.shape", ref_descriptors.shape)
        print("ref_descriptors_indices.shape", ref_descriptors_indices.shape)

    # make PrecomputeDescriptorNet and overwrite its reference descriptors
    # with the ones sampled above
    pdn = PrecomputedDescriptorNet(global_config)
    pdn._ref_descriptors.data = ref_descriptors


# draw reference image: work on a copy so rgb_a itself stays unmodified
label_color = [0, 255, 0]
rgb_a_wr = np.copy(rgb_a)
vis_utils.draw_reticles(rgb_a_wr,
                        ref_descriptors_indices[:, 0],
                        ref_descriptors_indices[:, 1],
                        label_color)

figname = "reference descriptors"
fig = plt.figure(figname, figsize=figsize)
axes = fig.subplots(num_rows, num_cols, squeeze=False)
axes[0,0].set_title(figname)
axes[0,0].imshow(rgb_a_wr);





## Localize reference descriptors in new episode

In [None]:
# Localize the reference descriptors (sampled in the previous cell) in frames of
# a different episode and show reference / target images side by side.
# Relies on dataset_tmp, idx_step, rgb_a, ref_descriptors_indices, pdn, model and
# device from the cells above.
DEBUG = False
pdc_utils.reset_random_seed()

sz = 2
num_rows = 1
num_cols = 2
figsize = (6.4*sz*num_cols, 4.8*sz*num_rows)

# upper bound on the number of frames that get visualized
MAX_NUM_FRAMES = 100

# this is a randomly selected episode change the idx to
# get a different one
episode_1 = multi_episode_dict[episode_name_list[3]]

# draw reference image once: it does not depend on the target frame, so there
# is no need to rebuild it on every loop iteration
ref_label_color = [0, 255, 0]
rgb_a_wr = np.copy(rgb_a)
vis_utils.draw_reticles(rgb_a_wr,
                        ref_descriptors_indices[:, 0],
                        ref_descriptors_indices[:, 1],
                        ref_label_color)

target_label_color = [255, 0, 0]

with torch.no_grad():

    for i in range(MAX_NUM_FRAMES):
        idx_cur = i*idx_step

        # stop once we step past the end of the episode
        if idx_cur >= episode_1.length:
            break

        data = dataset_tmp._getitem(episode_1,
                                    idx_cur,
                                    camera_name_a=camera_name_a,
                                    camera_name_b=camera_name_b)

        rgb = data['data_a']['rgb']
        rgb_tensor = data['data_a']['rgb_tensor'].unsqueeze(0).to(device)
        descriptor_net_out = model.forward(rgb_tensor)
        des_img = descriptor_net_out['descriptor_image']

        # sanity check: the precomputed descriptor image stored with the
        # dataset should match what the model produces for the same frame
        des_img_precomputed = torch.Tensor(data['data_a']['descriptor']).to(device).unsqueeze(0)
        print("des_img_precomputed.shape", des_img_precomputed.shape)
        print("des_img.shape", des_img.shape)

        des_img_delta_norm = torch.norm(des_img - des_img_precomputed)
        print("des_img_delta norm", des_img_delta_norm)

        # localize using the live model output; swap in des_img_precomputed
        # here to run the same localization on the precomputed descriptors
        out = pdn.forward_descriptor_image(des_img)

        # [N, 2]
        best_match_indices = out['best_match_dict']['indices'].squeeze()

        # draw target image on a copy so the dataset's rgb array is untouched
        rgb_wr = np.copy(rgb)
        vis_utils.draw_reticles(rgb_wr,
                                best_match_indices[:, 0],
                                best_match_indices[:, 1],
                                target_label_color)

        figname = "target idx: %d" %(idx_cur)
        fig = plt.figure(figname, figsize=figsize)
        axes = fig.subplots(num_rows, num_cols, squeeze=False)
        axes[0,0].imshow(rgb_a_wr)
        axes[0,1].imshow(rgb_wr)

        # render now and release the figure; otherwise up to MAX_NUM_FRAMES
        # figures stay open at once and matplotlib warns about too many open figures
        plt.show()
        plt.close(fig)
        
        
        
        
        

In [None]:
# Image size in pixels: width and height.
W = 640
H = 480

# Length of the image diagonal in pixels.
image_diagonal_pixels = np.sqrt(W**2 + H**2)

# Cell output: the diagonal scaled by 0.00625, i.e. 0.625% of it in pixels.
image_diagonal_pixels * 0.00625

In [None]:
# NOTE(review): scratch arithmetic; the notebook does not say what 8, 250 and 60
# stand for — TODO confirm the intent and name the constants, or remove this cell.
8*250/60