In [None]:
%matplotlib inline

In [None]:
import os
import random
import numpy as np
import copy
import matplotlib.pyplot as plt


# pdc
from dense_correspondence_manipulation.utils.utils import set_cuda_visible_devices

# Select which physical GPU(s) this notebook may use. This call is placed
# before `import torch` in this cell.
# NOTE(review): presumably the helper sets CUDA_VISIBLE_DEVICES, which has to
# happen before CUDA is initialized -- confirm before reordering the imports.
GPU_LIST = [1]
set_cuda_visible_devices(GPU_LIST)

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter



# pdc
from dense_correspondence.dataset.dynamic_drake_sim_dataset import DynamicDrakeSimDataset
from dense_correspondence.correspondence_tools.correspondence_finder import reproject_pixels
from dense_correspondence.correspondence_tools import correspondence_plotter
from dense_correspondence.correspondence_tools.correspondence_finder import compute_correspondence_data, pad_correspondence_data
from dense_correspondence_manipulation.utils.utils import getDenseCorrespondenceSourceDir
import dense_correspondence.loss_functions.utils as loss_utils
import dense_correspondence_manipulation.utils.utils as pdc_utils
from dense_correspondence.network import predict

# pdc
from dense_correspondence.dataset.dynamic_drake_sim_dataset import DynamicDrakeSimDataset
from dense_correspondence.correspondence_tools.correspondence_finder import reproject_pixels
from dense_correspondence.correspondence_tools import correspondence_plotter
from dense_correspondence.correspondence_tools.correspondence_finder import compute_correspondence_data, pad_correspondence_data
from dense_correspondence_manipulation.utils.utils import getDenseCorrespondenceSourceDir, getDictFromYamlFilename
import dense_correspondence_manipulation.utils.utils as pdc_utils
import dense_correspondence.loss_functions.utils as loss_utils
from dense_correspondence.dataset.spartan_episode_reader import SpartanEpisodeReader
import dense_correspondence_manipulation.utils.visualization as vis_utils



# Sanity check: print a blank separator, then the number of CUDA devices
# torch reports after the GPU selection made earlier in this cell.
print('\n\n')
print("device_count", torch.cuda.device_count())

## Load Dataset

In [None]:
# Root directory of all data, supplied through the DATA_ROOT environment
# variable. Fail early with a clear message instead of letting os.path.join
# raise an opaque TypeError on None.
DATA_ROOT = os.getenv("DATA_ROOT")
if DATA_ROOT is None:
    raise RuntimeError("The DATA_ROOT environment variable must be set before running this notebook")

# Directory that holds the recorded episodes.
episodes_root = os.path.join(DATA_ROOT, "pdc/logs_proto")
print("episodes_root", episodes_root)

# YAML file listing which episodes make up the dataset.
episode_list_file = os.path.join(
    getDenseCorrespondenceSourceDir(),
    'config/dense_correspondence/dataset/single_object/caterpillar_9_episodes.yaml')
episode_list_config = getDictFromYamlFilename(episode_list_file)
multi_episode_dict = SpartanEpisodeReader.load_dataset(episode_list_config, episodes_root)

# Global dataset / loss / network configuration used to build the dataset.
config_file = os.path.join(getDenseCorrespondenceSourceDir(),
                           'config/dense_correspondence/global/drake_sim_dynamic.yaml')
config = getDictFromYamlFilename(config_file)
dataset = DynamicDrakeSimDataset(config, multi_episode_dict)

## Load Model

In [None]:
# Checkpoint to evaluate.
# The original cell first assigned a different checkpoint and immediately
# overwrote it; that dead assignment is kept here only as a reference:
#   pdc/dev/experiments/heatmap/trained_models/2020-02-06-20-42-49_resnet50__dataset_caterpillar_9/net_dy_epoch_0_iter_10000_model.pth
model_file = os.path.join(
    DATA_ROOT,
    "pdc/dev/experiments/heatmap/trained_models/"
    "2020-02-13-21-38-56_resnet50__dataset_caterpillar_9_spatial_expectation_enabled/"
    "net_dy_epoch_0_iter_3000_model.pth")

# The training config is stored next to the checkpoint.
model_config = getDictFromYamlFilename(os.path.join(os.path.dirname(model_file), 'config.yaml'))

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
# map_location="cpu" makes deserialization independent of the GPU index the
# checkpoint was saved from (that index may not be visible in this process);
# the model is moved to the current CUDA device right afterwards.
model = torch.load(model_file, map_location="cpu")
model = model.cuda()
model = model.eval()  # inference mode

In [None]:
## Load a pair of images, visualize ground truth and estimated correspondences
##
## For each target index j: draw K ground-truth matches in green on both
## images, draw the network's spatial-expectation prediction in red on image b,
## then show the K heatmaps that produced those predictions.

sz = 3
figsize = (6.4*sz, 4.8*sz)
num_rows = 1
num_cols = 2
K = 4 # num matches to display


episode_name = list(multi_episode_dict.keys())[0]
episode = dataset.episodes[episode_name]
camera_names = list(episode.camera_names)
# both images come from the same camera
camera_name_a = camera_names[0]
camera_name_b = camera_names[0]

idx_list = [10,30,60,100]

# These do not change between iterations, so they are set once up front.
idx = 0
idx_a = 0

for j in idx_list:

    idx_b = episode.indices[j]
    data = dataset._getitem(episode, idx, camera_name_a, camera_name_b, idx_a=idx_a, idx_b=idx_b)

    # keep only the first K ground-truth matches
    uv_a = data['matches']['uv_a'][:,:K] # [2, num_matches]
    uv_b = data['matches']['uv_b'][:,:K] # [2, num_matches]

    # copy the images so that drawing reticles does not modify `data`
    rgb_a = np.copy(data['data_a']['rgb'])
    rgb_b = np.copy(data['data_b']['rgb'])

    # compute sigma as a fraction of the image diagonal
    # (not consumed by the spatial-expectation path below; it is the sigma for
    # the alternative loss_utils.compute_heatmap_from_descriptors heatmap)
    H = data['data_a']['rgb_tensor'].shape[1]
    W = data['data_a']['rgb_tensor'].shape[2]
    sigma_fraction = config['loss_function']['heatmap']['sigma_fraction']
    diag = np.sqrt(W**2 + H**2)
    sigma = sigma_fraction * diag


    # draw ground truth reticles (green) on image a
    label_color = [0, 255, 0]
    vis_utils.draw_reticles(rgb_a,
                            uv_a[0, :],
                            uv_a[1, :],
                            label_color)

    # draw ground truth reticles (green) on image b
    label_color = [0, 255, 0]
    vis_utils.draw_reticles(rgb_b,
                            uv_b[0, :],
                            uv_b[1, :],
                            label_color)


    with torch.no_grad():
        # now localize the correspondences
        # push rgb_a, rgb_b through the network
        rgb_tensor_a = data['data_a']['rgb_tensor'].unsqueeze(0).cuda()
        out_a = model.forward(rgb_tensor_a)
        des_img_a = out_a['descriptor_image']

        rgb_tensor_b = data['data_b']['rgb_tensor'].unsqueeze(0).cuda()
        out_b = model.forward(rgb_tensor_b)
        des_img_b = out_b['descriptor_image']

        # extract descriptors corresponding to uv_a in des_a
        # [B, K, D]
        des_a = pdc_utils.index_into_batch_image_tensor(des_img_a, uv_a.unsqueeze(0).cuda()).permute([0,2,1])

        # localize these in the other image
        # find best match in des_img_b

        # argmax (computed for reference; only its shape is printed, it is not drawn)
        best_match_dict = predict.get_argmax_l2(des_a, des_img_b)

        # [2, K]
        best_match_uv_b = best_match_dict['indices'].permute([0, 2, 1]).squeeze()

        # spatial expectation
        spatial_pred = predict.get_spatial_expectation(des_a,
                                        des_img_b,
                                        sigma=config['network']['sigma_descriptor_heatmap'],
                                        type='exp', return_heatmap=True)

        # [K, 2]
        uv_spatial_pred = spatial_pred['uv'].squeeze()
        print("uv_spatial_pred.shape", uv_spatial_pred.shape)


        print("des_a.shape", des_a.shape)
        print("best_match_uv_b.shape", best_match_uv_b.shape)


    # draw the spatial-expectation prediction (red) on image b
    label_color = [255, 0, 0]
    vis_utils.draw_reticles(rgb_b,
                            uv_spatial_pred[:, 0],
                            uv_spatial_pred[:, 1],
                            label_color)


    figname = "learned correspondences %d" %(j)
    fig = plt.figure(figname, figsize=figsize)
    axes = fig.subplots(num_rows, num_cols, squeeze=False)
    ax = axes[0,0]
    ax.imshow(rgb_a)
    ax.set_title("image a (green: ground truth)")

    ax = axes[0,1]
    ax.imshow(rgb_b)
    ax.set_title("image b, target %d (green: ground truth, red: spatial expectation)" %(j))

    # heatmaps
    # this is the heatmap that an individual descriptor from rgb_a
    # will induce given descriptor image b (des_img_b)
    # [B,N,H,W]
    heatmap_pred = spatial_pred['heatmap']


    # visualize heatmaps
    for k in range(K):
        heatmap = heatmap_pred[0, k].detach().cpu().numpy() # [H, W]
        heatmap_rgb = vis_utils.colormap_from_heatmap(heatmap)
        # The figure label must include j: plt.figure() returns the already
        # existing figure when a label is reused, which would stack the
        # heatmaps of every target on the same K figures.
        figname = "heatmap %d (target %d)" %(k, j)
        fig = plt.figure(figname, figsize=(6.4, 4.8))
        axes = fig.subplots(1, 1, squeeze=False)
        ax = axes[0,0]
        ax.imshow(heatmap_rgb)
        ax.set_title(figname)