In [None]:
%matplotlib inline

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import time
import cv2


# torch
import torch

# key_dynam
from key_dynam.utils.utils import get_project_root, get_data_root


# pdc
from dense_correspondence.dataset.dynamic_drake_sim_dataset import DynamicDrakeSimDataset
from dense_correspondence.correspondence_tools.correspondence_finder import reproject_pixels
from dense_correspondence.correspondence_tools import correspondence_plotter
from dense_correspondence.correspondence_tools.correspondence_finder import compute_correspondence_data, pad_correspondence_data
from dense_correspondence_manipulation.utils.utils import getDenseCorrespondenceSourceDir, getDictFromYamlFilename
import dense_correspondence.loss_functions.utils as loss_utils
from dense_correspondence.dataset.spartan_episode_reader import SpartanEpisodeReader

In [None]:
# --- Paths and configuration ---------------------------------------------
# Processed episode directory, located under the key_dynam data root.
dataset_processed_dir = os.path.join(
    get_data_root(),
    'dev/pdc/2018-04-16-14-25-19/processed',
)

# Dataset configuration YAML, located under the dense-correspondence source dir.
config_file = os.path.join(
    getDenseCorrespondenceSourceDir(),
    'config/dense_correspondence/global/drake_sim_dynamic.yaml',
)
config = getDictFromYamlFilename(config_file)

# --- Figure sizing ---------------------------------------------------------
# Figures are drawn at `sz` times a 6.4 x 4.8 inch base size.
sz = 2
figsize = (6.4 * sz, 4.8 * sz)

# --- Episode and dataset ---------------------------------------------------
episode = SpartanEpisodeReader(config, dataset_processed_dir)
indices = episode.indices

# The dataset takes a dict of episodes keyed by name; only one episode is used here.
episode_name = "episode_0"
multi_episode_dict = {
    episode_name: episode,
}
print("multi_episode_dict.keys()", multi_episode_dict.keys())

dataset = DynamicDrakeSimDataset(config, multi_episode_dict)



# Frame pair to compare: entries 0 and 20 of the episode's `indices`.
idx_a = indices[0]
idx_b = indices[20]

# Look the episode up again by name (same object that was stored in the dict).
episode = multi_episode_dict[episode_name]

# Both views use the first camera, so the pair differs only in the frame index.
camera_names = list(episode.camera_names)
camera_name_a = camera_name_b = camera_names[0]

# Per-frame image data, looked up by key below ('rgb', 'mask').
data_a = episode.get_image_data(camera_name_a, idx_a)
data_b = episode.get_image_data(camera_name_b, idx_b)
print("data_a.keys()", data_a.keys())

# RGB image of frame a, drawn at the enlarged figure size.
plt.figure(figsize=figsize)
plt.imshow(data_a['rgb'])
plt.title('rgb_a')

# Mask of frame a, drawn at matplotlib's default figure size.
mask = data_a['mask']
plt.figure()
plt.imshow(mask)
plt.title('mask')

In [None]:


# Correspondence data for the (camera_name_a, idx_a) / (camera_name_b, idx_b) pair.
cd = dataset._getitem(
    episode,
    None,
    camera_name_a,
    camera_name_b,
    idx_a=idx_a,
    idx_b=idx_b,
)

# The return value is ignored here.
# NOTE(review): presumably `cd` is padded in place - confirm against pad_correspondence_data.
pad_correspondence_data(
    cd,
    N_matches=10,
    N_masked_non_matches=20,
    N_background_non_matches=30,
    verbose=False,
)
uv_a, uv_b = cd['matches']['uv_a'], cd['matches']['uv_b']

# Image height / width are read from axes 1 and 2 of the rgb tensor of image a.
H, W = cd['data_a']['rgb_tensor'].shape[1:3]

# Heatmap width: a fixed fraction of the image diagonal.
sigma_fraction = 0.003
diag = np.sqrt(H**2 + W**2)
sigma = diag * sigma_fraction

print("H", H)
print("W", W)

# Heatmaps for the matches in image b; create_heatmap receives uv_b with its
# two axes swapped.
heatmap_tensor = loss_utils.create_heatmap(
    uv_b.permute([1, 0]),
    H,
    W,
    sigma,
    type='exp',
)
print("heatmap_tensor.shape", heatmap_tensor.shape)


# Restrict the correspondence plot to the first match only.
idx_range = range(1)
n = 0  # index of the match whose heatmap is visualized below
uv_a_short = uv_a[:, idx_range]
uv_b_short = uv_b[:, idx_range]

# plot matches
images = [data_a['rgb'], data_b['rgb']]
correspondence_plotter.plot_correspondences(images, uv_a_short, uv_b_short)

# plot heatmap for match n
heatmap = heatmap_tensor[n, :, :].numpy()

# Pixel location of match n in image b: row 0 of uv_b is u, row 1 is v, and
# the heatmap is indexed as [v, u]. Convert to plain ints before indexing the
# numpy arrays.
u_b = int(uv_b[0, n])
v_b = int(uv_b[1, n])
print("heatmap[v,u]", heatmap[v_b, u_b])

# Scale to 8-bit for cv2.applyColorMap.
# NOTE(review): assumes heatmap values lie in [0, 1] - the min/max printed below let you confirm.
heatmap_255 = np.uint8(255*heatmap)
print("heatmap_255.dtype", heatmap_255.dtype)
print("heatmap_255[v,u]", heatmap_255[v_b, u_b])

# applyColorMap returns a BGR image, matplotlib expects RGB. Convert all three
# channels: copying only channels 0 and 2 into a zero array would leave the
# green channel empty and distort the JET colours.
colormap = cv2.applyColorMap(heatmap_255, cv2.COLORMAP_JET)
colormap_rgb = cv2.cvtColor(colormap, cv2.COLOR_BGR2RGB)
print("heatmap.max", np.max(heatmap))
print("heatmap min", np.min(heatmap))

plt.figure(figsize=figsize)
plt.imshow(colormap_rgb)
plt.show()


# Alpha-blend the coloured heatmap over the rgb image of frame b.
alpha = 0.3
blend = alpha * colormap_rgb + (1-alpha) * data_b['rgb']
# Cast the float blend back to an integer dtype so imshow treats it as 0-255 RGB.
blend = np.int16(blend)
plt.figure(figsize=figsize)
plt.imshow(blend)
plt.title('blend')
plt.show()



if False:
    # Disabled debugging aid: change the condition above to overlay the
    # non-matches on the image pair (restricted to `idx_range`, drawn in red).

    # masked non-matches
    masked_non_matches_uv_a = cd['masked_non_matches']['uv_a'][:, idx_range]
    masked_non_matches_uv_b = cd['masked_non_matches']['uv_b'][:, idx_range]
    correspondence_plotter.plot_correspondences(
        images,
        masked_non_matches_uv_a,
        masked_non_matches_uv_b,
        circ_color='r',
    )

    # background non-matches
    background_non_matches_uv_a = cd['background_non_matches']['uv_a'][:, idx_range]
    background_non_matches_uv_b = cd['background_non_matches']['uv_b'][:, idx_range]
    correspondence_plotter.plot_correspondences(
        images,
        background_non_matches_uv_a,
        background_non_matches_uv_b,
        circ_color='r',
    )

    # dtype inspection of the non-match tensors.
    # NOTE(review): the `bnm_` names suggest background non-matches, but these
    # read cd['masked_non_matches'] - confirm which key is intended.
    bnm_uv_a = cd['masked_non_matches']['uv_a']
    bnm_uv_b = cd['masked_non_matches']['uv_b']
    print("bnm_uv_a.dtype", bnm_uv_a.dtype)
    print("bnm_uv_b.dtype", bnm_uv_b.dtype)